Sarsa
Sarsa原理
Sarsa的決策過程和Q-Learning類似,都是在Q表中挑選值較大的動作值施加在環(huán)境中來換取獎懲。不同之處在于更新方式。
如下圖所示,在狀態(tài)s采取行動a,到達下一個狀態(tài)s'時,Q-Learning算法會根據(jù)Q(s')的最大值,假設自己走能使maxQ(s')這條路,來更新剛剛走過的Q(s,a)。此時在s‘的agent還沒做出任何決策。
與Q-Learning不同,Sarsa會在狀態(tài)s’做出實際的行為(該行為并不一定能使Q(s')最大化),并根據(jù)實際做出行為的Q值來更新剛剛走過的Q(s,a)。

Sarsa是一種on-policy在線學習算法,Q-learning是一種off-policy離線學習算法。
Q-Learning更新狀態(tài)s時只看到maxQ(s'),忽視掉該行為可能帶來的懲罰,因此它是一個大膽、貪婪的策略。Sarsa算法在接近收斂時,允許對探索性的行動進行可能的懲罰(Q-Learning會直接忽略)這使得Sarsa算法更加保守。
Sarsa算法更新
還是agent走迷宮的例子
與Q-Learning的不同在于
- 兩個choose_action,第一個在循環(huán)外面(真的做出了行為并刷新了環(huán)境
- RL.learn(str(observation), action, reward, str(observation_), action_)這里多了一個action_
def update():
for episode in range(100):
# 初始化環(huán)境
observation = env.reset()
# Sarsa 根據(jù) state 觀測選擇行為
action = RL.choose_action(str(observation))
while True:
# 刷新環(huán)境
env.render()
# 在環(huán)境中采取行為, 獲得下一個 state_ (obervation_), reward, 和是否終止
observation_, reward, done = env.step(action)
# 根據(jù)下一個 state (obervation_) 選取下一個 action_
action_ = RL.choose_action(str(observation_))
# 從 (s, a, r, s, a) 中學習, 更新 Q_tabel 的參數(shù) ==> Sarsa
RL.learn(str(observation), action, reward, str(observation_), action_)
# 將下一個當成下一步的 state (observation) and action
observation = observation_
action = action_
# 終止時跳出循環(huán)
if done:
break
# 大循環(huán)完畢
print('game over')
env.destroy()
if __name__ == "__main__":
env = Maze()
RL = SarsaTable(actions=list(range(env.n_actions)))
env.after(100, update)
env.mainloop()
Sarsa思維決策
定義一個RL父類
import numpy as np
import pandas as pd
class RL(object):
def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
... # 和 QLearningTable 中的代碼一樣
def check_state_exist(self, state):
... # 和 QLearningTable 中的代碼一樣
def choose_action(self, observation):
... # 和 QLearningTable 中的代碼一樣
def learn(self, *args):
pass # 每種的都有點不同, 所以用 pass
定義Q-Learning子類
class QLearningTable(RL): # 繼承了父類 RL
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) # 表示繼承關系
def learn(self, s, a, r, s_): # learn 的方法在每種類型中有不一樣, 需重新定義
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, :].max()
else:
q_target = r
self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
定義Sarsa子類
class SarsaTable(RL): # 繼承 RL class
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) # 表示繼承關系
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # q_target 基于選好的 a_ 而不是 Q(s_) 的最大值
else:
q_target = r # 如果 s_ 是終止符
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # 更新 q_table
Sarsa(lambda)
Sarsa(lambda)是一種Sarsa的提速算法
上述Sarsa算法是單步更新法,即每次獲取到reward,只更新獲取到reward的前一步
Sarsa-lambda就是更新獲取到reward的前l(fā)ambda步
lambda就是一個衰變值,取值在0-1之間。當lambda取0,就變成了Sarsa的單步更新,當lambda取 1,就變成了回合更新。lambda取值越大,離寶藏越近的步更新力度越大。
Sarsa(lambda)例子
還是那個走迷宮的例子
SarsaLambdaTable在算法更新迭代的部分,和SarsaTable 是一樣的,思維決策部分有所不同,如下圖所示:

從上圖可以看出,和Sarsa相比,Sarsa(lambda)算法中多了一個矩陣E (eligibility trace),它是用來保存在路徑中所經(jīng)歷的每一步,因此在每次更新時也會對之前經(jīng)歷的步進行更新
"""
This part of code is the Q learning brain, which is a brain of the agent.
All decisions are made in here.
View more on my tutorial page: https://morvanzhou.github.io/tutorials/
"""
import numpy as np
import pandas as pd
# 預設值里增加了trace_decay=0.9,也就是lambda的值
class SarsaLambdaTable(RL): # 繼承RL父類
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
# backward view, eligibility trace.
self.lambda_ = trace_decay
self.eligibility_trace = self.q_table.copy()
def check_state_exist(self, state):
if state not in self.q_table.index:
# append new state to q table
to_be_append = pd.Series(
[0] * len(self.actions),
index=self.q_table.columns,
name=state,
)
self.q_table = self.q_table.append(to_be_append)
# also update eligibility trace
self.eligibility_trace = self.eligibility_trace.append(to_be_append)
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # next state is not terminal
else:
q_target = r # next state is terminal
error = q_target - q_predict
# increase trace amount for visited state-action pair
# Method 1:
# self.eligibility_trace.loc[s, a] += 1
# Method 2:
self.eligibility_trace.loc[s, :] *= 0
self.eligibility_trace.loc[s, a] = 1
# Q update
self.q_table += self.lr * error * self.eligibility_trace
# decay eligibility trace after update
self.eligibility_trace *= self.gamma*self.lambda_
除了圖中和上面代碼這種更新方式, 還有一種會更加有效率. 我們可以將上面的這一步替換成下面這樣:
# 上面代碼中的方式:
self.eligibility_trace.ix[s, a] += 1
# 更有效的方式:
self.eligibility_trace.ix[s, :] *= 0
self.eligibility_trace.ix[s, a] = 1
他們的不同之處可以用這張圖來概括:

這是針對于一個 state-action 值按經(jīng)歷次數(shù)的變化,最上面是經(jīng)歷 state-action 的時間點,第二張圖是使用這種方式所帶來的 “不可或缺性值”:
self.eligibility_trace.ix[s, a] += 1
下面圖是使用這種方法帶來的 “不可或缺性值”:
self.eligibility_trace.ix[s, :] *= 0; self.eligibility_trace.ix[s, a] = 1
最后不要忘了,eligibility trace只是記錄每個回合的每一步,新回合(episode)開始的時候需要將 Trace 清零
for episode in range(100):
...
# 新回合, 清零
RL.eligibility_trace *= 0
while True: # 開始回