最近在学强化学习,看了不少的教程,还是觉得莫烦大神的强化学习教程写的不错。所以,特意仔细研究莫烦的RL代码。在这贴上自己的理解。 莫烦RL教程:https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/ 代码:https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents
下图的是Sarsa算法的伪代码: 下图对比了Sarsa算法和Q-Learning算法
Qlearing和Sarsa更新Q表的不同之处在于,QLearning使用的Q现实是用的Q(S_)中的最大值(下一步不一定使用该(S_,A_),只是想象的), 而Sarsa使用的是下一步将要采用的Q(S_,A_)
# 编写了一个RL父类,方便Q-learning 和 Sarsa 子类继承。 # RL类中前一节都已解读过 class RL(object): def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9): self.actions = action_space # a list self.lr = learning_rate self.gamma = reward_decay self.epsilon = e_greedy self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64) def check_state_exist(self, state): if state not in self.q_table.index: # append new state to q table self.q_table = self.q_table.append( pd.Series( [0]*len(self.actions), index=self.q_table.columns, name=state, ) ) def choose_action(self, observation): self.check_state_exist(observation) # action selection if np.random.rand() < self.epsilon: # choose best action state_action = self.q_table.loc[observation, :] # some actions may have the same value, randomly choose on in these actions action = np.random.choice(state_action[state_action == np.max(state_action)].index) else: # choose random action action = np.random.choice(self.actions) return action # Q-learning 和 Sarsa的learn函数都不一样,所以需要各自编写 def learn(self, *args): pass # off-policy class QLearningTable(RL): def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9): # super()继承QLearning的父类RL super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) def learn(self, s, a, r, s_): self.check_state_exist(s_) q_predict = self.q_table.loc[s, a] if s_ != 'terminal': q_target = r + self.gamma * self.q_table.loc[s_, :].max() # next state is not terminal else: q_target = r # next state is terminal self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # update # on-policy class SarsaTable(RL): def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9): # super()继承Sarsa的父类RL super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) def learn(self, s, a, r, s_, a_): # 检查S_是否在表中 self.check_state_exist(s_) # Q现实 q_predict = self.q_table.loc[s, a] # Q估计 if s_ != 'terminal': q_target = r + self.gamma * self.q_table.loc[s_, a_] # next state is not terminal else: q_target = r # next state is terminal # 更新Q表 self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # updateQ-learning是获取S,根据S选择A,使用A得到R和S_。之后使用max(Q(S_))来更新Q(S,A)。 更新使用的Q(S_,A_),下一步时不一定使用,这里只是想象的。
Sarsa是通过S、A,使用A得到R和S_,再根据S_选择A_。这个A_下一步肯定会使用。 哈哈,一个有趣的事,Sarsa使用的(S,A,R,S_,A_),连起来刚好就是Sarsa的拼写。
from maze_env import Maze from RL_brain import SarsaTable def update(): # 训练次数 for episode in range(100): # 初始化环境 observation = env.reset() # 根据state选择action action = RL.choose_action(str(observation)) # 时间步,回合 while True: # 刷新环境 env.render() # 采取action,返回 下一步的状态S_, 奖励R, 游戏结束flag observation_, reward, done = env.step(action) # 根据S_选择动作 action_ = RL.choose_action(str(observation_)) # Sarsa根据(S,A,R,S_,A_)来学习 RL.learn(str(observation), action, reward, str(observation_), action_) # 交换 S and A observation = observation_ action = action_ # 如果游戏结束,结束本次训练 if done: break # end of game print('game over') env.destroy() if __name__ == "__main__": env = Maze() RL = SarsaTable(actions=list(range(env.n_actions))) env.after(100, update) env.mainloop()