5.3.3. Q-learning 算法#

我们可以通过下面这段代码实现Q-Learning算法,首先,让我们定义以下参数:

  • actions:可能的动作列表

  • learning_rate:学习率

  • reward_decay:奖励折减系数

  • e_greedy:贪心系数

之后,让我们创建一个DataFrame作为Q-Table,其中包含所有可能的状态和动作,并在类中定义了以下函数:

  • choose_action:根据当前状态和贪婪度选择动作

  • learn:根据当前状态、动作、奖励和下一个状态更新 Q-Table

  • check_state_exist:检查当前状态是否在Q-Table中存在

import numpy as np
import pandas as pd

import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)


class QLearningTable:
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.7, e_greedy=0.99
                 ):
        # 定义智能体动作
        self.actions = actions

        # 定义学习率
        self.lr = learning_rate

        # 定义奖励折减系数
        self.gamma = reward_decay

        # 定义贪心系数
        self.epsilon = e_greedy
        # 初始Q表格
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)


def choose_action(self, observation, episode, n):
    self.epsilon = 0.8 - episode / (0.8 * n)

    # 检测状态是否在Q表格里
    self.check_state_exist(observation)
    # 动作选择
    if np.random.uniform() > self.epsilon:
        # 以概率(1-epsilon)选择最大值动作
        state_action = self.q_table.loc[observation, :]
        state_action = state_action.reindex(np.random.permutation(state_action.index))
        # 一些状态有相同的Q值,取索引最大
        action = state_action.idxmax()
    else:
        # 以epsilon概率随机选择动作
        action = np.random.choice(self.actions)
    return action


def learn(self, s, a, r, s_):
    # 检测状态是否在Q表格里
    self.check_state_exist(s_)
    # 根据Q表格获得预测的Q值
    q_predict = self.q_table.loc[s, a]

    # 如果智能体没有发生碰撞,即下一个状态不是终止状态
    if s_ != 'terminal':
        # 计算时序差分 TD error=〖r_(t+1)+(γ max)┬(a_(t+1) )〗⁡〖Q_t  (s_(t+1),a_(t+1))〗-Q_t  (s_t,a_t )]
        q_target = r + self.gamma * self.q_table.loc[s_, :].max()
    else:
        # 如果智能体发生碰撞,下一个状态为终止状态
        q_target = r

        # 更新Q值 Q_t  (s_t,a_t )+a* TD error
    self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
    # 保存Q表格
    self.q_table.to_csv("q_table.csv")


def check_state_exist(self, state):
    # 如果当前状态不在Q表格的索引当中
    if state not in self.q_table.index:
        # 将当前状态添加到Q表格的索引中
        self.q_table = self.q_table.append(
            pd.Series(
                [0] * len(self.actions),
                index=self.q_table.columns,
                name=state,
            )
        )