5.3.3. Q-learning 算法#
我们可以通过下面这段代码实现Q-Learning算法,首先,让我们定义以下参数:
actions:可能的动作列表
learning_rate:学习率
reward_decay:奖励折减系数
e_greedy:贪心系数
之后,让我们创建一个DataFrame作为Q-Table,其中包含所有可能的状态和动作,并在类中定义了以下函数:
choose_action:根据当前状态和贪婪度选择动作
learn:根据当前状态、动作、奖励和下一个状态更新 Q-Table
check_state_exist:检查当前状态是否在Q-Table中存在
import numpy as np
import pandas as pd
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
class QLearningTable:
def __init__(self, actions, learning_rate=0.01, reward_decay=0.7, e_greedy=0.99
):
# 定义智能体动作
self.actions = actions
# 定义学习率
self.lr = learning_rate
# 定义奖励折减系数
self.gamma = reward_decay
# 定义贪心系数
self.epsilon = e_greedy
# 初始Q表格
self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
def choose_action(self, observation, episode, n):
self.epsilon = 0.8 - episode / (0.8 * n)
# 检测状态是否在Q表格里
self.check_state_exist(observation)
# 动作选择
if np.random.uniform() > self.epsilon:
# 以概率(1-epsilon)选择最大值动作
state_action = self.q_table.loc[observation, :]
state_action = state_action.reindex(np.random.permutation(state_action.index))
# 一些状态有相同的Q值,取索引最大
action = state_action.idxmax()
else:
# 以epsilon概率随机选择动作
action = np.random.choice(self.actions)
return action
def learn(self, s, a, r, s_):
# 检测状态是否在Q表格里
self.check_state_exist(s_)
# 根据Q表格获得预测的Q值
q_predict = self.q_table.loc[s, a]
# 如果智能体没有发生碰撞,即下一个状态不是终止状态
if s_ != 'terminal':
# 计算时序差分 TD error=〖r_(t+1)+(γ max)┬(a_(t+1) )〗〖Q_t (s_(t+1),a_(t+1))〗-Q_t (s_t,a_t )]
q_target = r + self.gamma * self.q_table.loc[s_, :].max()
else:
# 如果智能体发生碰撞,下一个状态为终止状态
q_target = r
# 更新Q值 Q_t (s_t,a_t )+a* TD error
self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
# 保存Q表格
self.q_table.to_csv("q_table.csv")
def check_state_exist(self, state):
# 如果当前状态不在Q表格的索引当中
if state not in self.q_table.index:
# 将当前状态添加到Q表格的索引中
self.q_table = self.q_table.append(
pd.Series(
[0] * len(self.actions),
index=self.q_table.columns,
name=state,
)
)