【莫煩強化學習】關于Sarsa算法

Sarsa

Sarsa原理

Sarsa的決策過程和Q-Learning類似,都是在Q表中挑選值較大的動作值施加在環(huán)境中來換取獎懲。不同之處在于更新方式。

如下圖所示,在狀態(tài)s采取行動a,到達下一個狀態(tài)s'時,Q-Learning算法會根據(jù)Q(s')的最大值,假設自己走能使maxQ(s')這條路,來更新剛剛走過的Q(s,a)。此時在s‘的agent還沒做出任何決策。

與Q-Learning不同,Sarsa會在狀態(tài)s’做出實際的行為(該行為并不一定能使Q(s')最大化),并根據(jù)實際做出行為的Q值來更新剛剛走過的Q(s,a)。

Sarsa是一種on-policy在線學習算法,Q-learning是一種off-policy離線學習算法。
Q-Learning更新狀態(tài)s時只看到maxQ(s'),忽視掉該行為可能帶來的懲罰,因此它是一個大膽、貪婪的策略。Sarsa算法在接近收斂時,允許對探索性的行動進行可能的懲罰(Q-Learning會直接忽略)這使得Sarsa算法更加保守。

Sarsa算法更新

還是agent走迷宮的例子

與Q-Learning的不同在于

  • 兩個choose_action,第一個在循環(huán)外面(真的做出了行為并刷新了環(huán)境
  • RL.learn(str(observation), action, reward, str(observation_), action_)這里多了一個action_
def update():
    for episode in range(100):
        # 初始化環(huán)境
        observation = env.reset()

        # Sarsa 根據(jù) state 觀測選擇行為
        action = RL.choose_action(str(observation))

        while True:
            # 刷新環(huán)境
            env.render()

            # 在環(huán)境中采取行為, 獲得下一個 state_ (obervation_), reward, 和是否終止
            observation_, reward, done = env.step(action)

            # 根據(jù)下一個 state (obervation_) 選取下一個 action_
            action_ = RL.choose_action(str(observation_))

            # 從 (s, a, r, s, a) 中學習, 更新 Q_tabel 的參數(shù) ==> Sarsa
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # 將下一個當成下一步的 state (observation) and action
            observation = observation_
            action = action_

            # 終止時跳出循環(huán)
            if done:
                break

    # 大循環(huán)完畢
    print('game over')
    env.destroy()

if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))

    env.after(100, update)
    env.mainloop()

Sarsa思維決策

定義一個RL父類

import numpy as np
import pandas as pd

class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        ... # 和 QLearningTable 中的代碼一樣

    def check_state_exist(self, state):
        ... # 和 QLearningTable 中的代碼一樣

    def choose_action(self, observation):
        ... # 和 QLearningTable 中的代碼一樣

    def learn(self, *args):
        pass # 每種的都有點不同, 所以用 pass

定義Q-Learning子類

class QLearningTable(RL):   # 繼承了父類 RL
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)    # 表示繼承關系

    def learn(self, s, a, r, s_):   # learn 的方法在每種類型中有不一樣, 需重新定義
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

定義Sarsa子類

class SarsaTable(RL):   # 繼承 RL class

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)    # 表示繼承關系

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # q_target 基于選好的 a_ 而不是 Q(s_) 的最大值
        else:
            q_target = r  # 如果 s_ 是終止符
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新 q_table

Sarsa(lambda)

Sarsa(lambda)是一種Sarsa的提速算法

上述Sarsa算法是單步更新法,即每次獲取到reward,只更新獲取到reward的前一步

Sarsa-lambda就是更新獲取到reward的前l(fā)ambda步

lambda就是一個衰變值,取值在0-1之間。當lambda取0,就變成了Sarsa的單步更新,當lambda取 1,就變成了回合更新。lambda取值越大,離寶藏越近的步更新力度越大。

Sarsa(lambda)例子

還是那個走迷宮的例子

SarsaLambdaTable在算法更新迭代的部分,和SarsaTable 是一樣的,思維決策部分有所不同,如下圖所示:

從上圖可以看出,和Sarsa相比,Sarsa(lambda)算法中多了一個矩陣E (eligibility trace),它是用來保存在路徑中所經(jīng)歷的每一步,因此在每次更新時也會對之前經(jīng)歷的步進行更新

"""
This part of code is the Q learning brain, which is a brain of the agent.
All decisions are made in here.

View more on my tutorial page: https://morvanzhou.github.io/tutorials/
"""

import numpy as np
import pandas as pd

# 預設值里增加了trace_decay=0.9,也就是lambda的值 
class SarsaLambdaTable(RL): # 繼承RL父類
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        # backward view, eligibility trace.
        self.lambda_ = trace_decay
        self.eligibility_trace = self.q_table.copy()

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table
            to_be_append = pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            self.q_table = self.q_table.append(to_be_append)

            # also update eligibility trace
            self.eligibility_trace = self.eligibility_trace.append(to_be_append)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        error = q_target - q_predict

        # increase trace amount for visited state-action pair

        # Method 1:
        # self.eligibility_trace.loc[s, a] += 1

        # Method 2:
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1

        # Q update
        self.q_table += self.lr * error * self.eligibility_trace

        # decay eligibility trace after update
        self.eligibility_trace *= self.gamma*self.lambda_

除了圖中和上面代碼這種更新方式, 還有一種會更加有效率. 我們可以將上面的這一步替換成下面這樣:

# 上面代碼中的方式:
self.eligibility_trace.ix[s, a] += 1

# 更有效的方式:
self.eligibility_trace.ix[s, :] *= 0
self.eligibility_trace.ix[s, a] = 1

他們的不同之處可以用這張圖來概括:

這是針對于一個 state-action 值按經(jīng)歷次數(shù)的變化,最上面是經(jīng)歷 state-action 的時間點,第二張圖是使用這種方式所帶來的 “不可或缺性值”:

self.eligibility_trace.ix[s, a] += 1

下面圖是使用這種方法帶來的 “不可或缺性值”:

self.eligibility_trace.ix[s, :] *= 0; self.eligibility_trace.ix[s, a] = 1

最后不要忘了,eligibility trace只是記錄每個回合的每一步,新回合(episode)開始的時候需要將 Trace 清零

for episode in range(100):
    ...

    # 新回合, 清零
    RL.eligibility_trace *= 0

    while True: # 開始回
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容