PPO(Proximal Policy Optimization)(一)算法原理

之前在學(xué)習(xí)大模型的時候?qū)W習(xí)過 PPO 算法, 但是當時也沒有來得及對該算法進行梳理, 因此原理也基本忘了差不多了。 現(xiàn)在在做具身算法, 強化學(xué)習(xí)還是要重新?lián)旎貋恚?這里又需要重新?lián)旎卦撍惴ǎ?不想在吃虧忘了, 幫該原理及時整理一下, 方便后面快速回憶。

這次 PPO 算法的學(xué)習(xí)就沒有直接看論文了, 因為網(wǎng)上的博客以及視頻是非常多的, 參考了幾篇文章以及視頻進行整理, 有幾個我覺得還是說的蠻好的視頻和博客, 我在這里分享下對應(yīng)鏈接也是對作者的尊重。

一、 背景介紹

1.1 強化學(xué)習(xí)基本概念

蒙特卡洛公式 vs TD

先看下下面的圖


超級馬里奧

上述圖表示的元素有幾種分別Agent這里是馬里奧,表達的是游戲的主角, State(Observation)環(huán)境表達的是是隨著游戲的變化整體環(huán)境的狀態(tài)變化,以及環(huán)境 Environment給 Agent 的Reward游戲獎勵, 以及最后主角需要根據(jù)策略做出對應(yīng)的 action 動作,最終的目的是希望 Agent 獲得整體最大的 Reward 值。下面我們看下強化學(xué)習(xí)主要的幾個基本要素:

  • Action Space: 可選擇的動作, 比如向左走、向右走向上跳
  • Policy: 策略函數(shù), 輸入State, 輸出Action概率分布, 一般用\pi表示。如\pi(left|s_t) = 0.1、 \pi(up|s_t) = 0.2、\pi(right|s_t)=0.7, 我們一般希望其具備更豐富的多樣性
  • Trajectory: 軌跡用\tau表示, 一連串狀態(tài)和動作的序列。 {s_0, a_0, s_1, a_1, ...} , 其中有的s_{t+1} = f(s_t, a_t)有的狀態(tài)轉(zhuǎn)移確定狀態(tài),比如大模型每次預(yù)測下一個 token。 有的狀態(tài)轉(zhuǎn)移s_{t+1}=P(\cdot|s_t, a_t)隨機的, 比如游戲中開寶箱。
  • Return: 表示回報, 從當前時間點到游戲結(jié)束的 Reward獎勵的累計和, 更多的希望可以從長遠的考慮

1.2 期望概念

舉一個簡單例子, 小明考試情況如下, 請計算出小明考試期望結(jié)果

  • 小明考試 20% 的概率考 80 分
  • 小明考試 80% 的概率考 90 分

計算過程: Exception = 0.2 * 80 + 0.8 * 90 = 88
當然期望也可以表示每個可能結(jié)果的概率與其結(jié)果的值的乘積之和\mathrm{E}(x)_{x \sim p(x)} = \sum_{x} x * p(x) \approx \frac{1}{n} \sum_{i=1}^{n} x \quad x \sim p(x), 這里的約等于成立條件在于采樣次數(shù)在無窮大的情況才成立。

  • 目標: 訓(xùn)練一個 Policy神經(jīng)網(wǎng)絡(luò)\pi, 在所有狀態(tài)z下, 給出相應(yīng)的 Action, 得到 Return的期望最大。
  • 目標:訓(xùn)練一個 Policy 神經(jīng)網(wǎng)絡(luò)\pi, 在所有的 Trajctory 中,得到的期望最大。

1.3 軌跡的期望

如果我們將軌跡期望帶入上述公式應(yīng)該是什么情況呢?
下面將您提供的兩部分推導(dǎo)過程整合為一個完整的序列,展示策略梯度的完整推導(dǎo)過程:

首先,我們的目標是最大化期望回報,需要計算期望回報關(guān)于參數(shù)θ的梯度:

\begin{align*} E(R(\tau))_{\tau \sim P_{\theta}(\tau)} &= \sum_{\tau} R(\tau) P_{\theta}(\tau) \quad \text{(我們需要對θ求梯度,進行梯度上升以最大化期望)} \\ \nabla E(R(\tau))_{\tau \sim P_{\theta}(\tau)} &= \nabla \sum_{\tau} R(\tau) P_{\theta}(\tau) \\ &= \sum_{\tau} R(\tau) \nabla P_{\theta}(\tau) \quad \text{(交換求和與梯度運算)} \\ &= \sum_{\tau} R(\tau) \nabla P_{\theta}(\tau) \frac{P_{\theta}(\tau)}{P_{\theta}(\tau)} \quad \text{(分子分母同乘 } P_{\theta}(\tau)\text{)} \\ &= \sum_{\tau} \color{blue}{P_{\theta}(\tau)} R(\tau) \frac{\nabla P_{\theta}(\tau)}{P_{\theta}(\tau)} \quad \text{(重組項)} \\ &\approx \frac{1}{N} \sum_{n=1}^{N} R(\tau^{n}) \frac{\nabla P_{\theta}(\tau^{n})}{P_{\theta}(\tau^{n})} \quad \text{(利用大數(shù)定律思想, 的用樣本均值近似期望,其中 } \tau^n \sim P_{\theta}(\tau)\text{)} \\ &= \frac{1}{N} \sum_{n=1}^{N} R(\tau^{n}) \nabla \log P_{\theta}(\tau^{n}) \quad \text{(利用對數(shù)導(dǎo)數(shù)性質(zhì):} \nabla \log f(x) = \frac{\nabla f(x)}{f(x)}\text{)} \\ &= \frac{1}{N} \sum_{n=1}^{N} R(\tau^n) \nabla \log \prod_{t=1}^{T_n} P_\theta(a_n^t | s_n^t) \quad \text{(軌跡概率分解為單步動作概率的乘積)} \\ &= \frac{1}{N} \sum_{n=1}^{N} R(\tau^n) \sum_{t=1}^{T_n} \nabla \log P_\theta(a_n^t | s_n^t) \quad \text{(對數(shù)乘積轉(zhuǎn)為求和)} \\ &= \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_n} R(\tau^n) \nabla \log P_\theta(a_n^t | s_n^t) \quad \text{(交換求和順序)} \end{align*}

這個最終結(jié)果表明,策略梯度可以通過對每個軌跡的每個時間步,將該軌跡的總回報與該時間步動作概率的對數(shù)梯度相乘,再取平均值來近似計算, 最終的公式結(jié)果如下圖所示, 該函數(shù)是單調(diào)遞增的, 所以 :

  • 當一個 trajectory 得到的 Return>0, 則增大這個 trajctory里面所有狀態(tài)下這個采取當前action的概率
  • 當一個 trajectory 得到的 Return<0, 則減少這個 trajctory里面所有狀態(tài)下這個采取當前action的概率
    Policy gradient

1.4 損失函數(shù)設(shè)置

\text{Loss} = -\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_n} \color{blue} {R(\tau^n)}\log \color{red} {P_\theta(a_n^t | s_n^t)}
首先因為是期望, 我們希望期望越大越好, 因此是一個梯度上升的過程, 所以在損失函數(shù)前面要加上負號。其中公式中 \color{red} {P_\theta(a_n^t | s_n^t)}使用的是神經(jīng)網(wǎng)絡(luò)進行求解, 根據(jù)當前狀態(tài)求解出對應(yīng)的 action 的值是多少, 如下圖所示, 圖中紅色的部分代表是不同狀態(tài)下動作的概率:

Policy network

同時為了查看回報函數(shù), 我們讓 Agent 基于該神經(jīng)網(wǎng)絡(luò)嘗試多次游戲得到 n 個 trajectory 以及對應(yīng)的 Reward 值, 如下圖藍色的部分,注意這里的每一個 action 是隨機采樣, 不是選取最大概率值, 這樣就到的 loss 函數(shù)所有的值, 就可以進行模型訓(xùn)練。不斷的采集在訓(xùn)練就是所謂的 On Policy

Reward

上述基于 On Policy 會有一個問題, 就是模型在訓(xùn)練的過程中,大部分時間都是在采集數(shù)據(jù),訓(xùn)練非常慢, 這也就是 PPO 算法需要解決的問題。

1.5 損失函數(shù)的優(yōu)化

\text{Loss} = -\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_n} \color{blue} {R(\tau^n)}\log \color{red} {P_\theta(a_n^t | s_n^t)}

這里有兩個改進想法

  1. 首先我們的獎勵值Reward 希望的是看當前的 action 對整體Reward 的影響, 因此需要考慮的是當前狀態(tài)做了這個動作之后到結(jié)束狀態(tài) Reward累計的 Reward, 而不應(yīng)該考慮整個 trajectory 整體的 Reward, 因為一個動作只能影響之后的Reward而不能影響之前的。

  2. 一個動作是可以對接下來產(chǎn)生的 Reward 產(chǎn)生影響, 但是只能影響接下來幾步, 而且影響會不斷衰減,后面的 Reward 主要還是當前動作的影響。

根據(jù)上述的想法我們看下如何來優(yōu)化我們的損失函數(shù)。
Loss=-\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} \color{blue}{R\left(\tau^{n}\right)}\nabla \log P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right) \color{blue}{\quad R\left(\tau^{n}\right) \rightarrow \sum_{t^{\prime}=t}^{T_{n}} \gamma^{t^{\prime}-t} r_{t^{\prime}}^{n}=R_{t}^{n}}

可以看到該函數(shù)不是對整體的過程求解 Reward 而是從 t到 T_n(結(jié)束時刻)進行求和, 同時引入衰減因子\gamma, \gamma < 1, 距離當前動作越遠, 當前動作對 Reward影響越小, 呈指數(shù)衰減

這里我們用R_{t}^{n}替代R\left(\tau^{n}\right), 總的思想是希望用R\left(\tau^{n}\right)表示當前動作對trajectory Returned 的影響, 這樣我們就可得到下面更新后的損失函數(shù):
Loss =-\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} \color{blue}{R_{t}^{n}}\nabla \log P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)

  1. 但是還有一個改進想法, 就是在好的局勢下和壞的局勢下是不一樣的。如下圖所示, 比如在好的局勢下, 不同動作都能得到正的 Reward, 相反在壞的局勢下, 不論做什么動作得到的都是負的 Reward。 以好的局勢為例, 所有的動作都能得到正的 Reward, 那么算法就會增加所有動作的概率得到 Reward 大的概率增加的概率大一點,但是模型訓(xùn)練的速度很慢, 我們希望的是相對好的動作概率增加, 相對差的動作概率降低,以此來提高訓(xùn)練效率, 就是將所有的動作 Reward 減去 Baseline, 這樣就能反映當前動作相對其他動作的價值,有點類似歸一化的思想,因此需要引入相對 Reward 的概念, 就相當于我的 Reward 值有正也有負的概念。
    不同局勢下的 Reward

    再舉一個例子, 應(yīng)該會加深大家的理解

    我們可以看下上面的圖, 比如上面s_1s_2我們希望其移動到金幣的地方, 我們可以看到s_2離金幣最近, 所以V(s_2) > V(s_1)(V代表狀態(tài)價值函數(shù)),這里就可以理解s2比s1 有更好的優(yōu)勢 而對于s_2棋子來說, 向上的回報肯定大于向下的回報所以Q(s_2, up) > Q(s_2, down)(Q代表動作價值函數(shù))

所以我們需要對我們的 Loss 函數(shù)做進一步修改, 如下所示:
Loss=-\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}}\left(\color{blue}{R_{t}^{n}}-\color{red}{B\left(s_{n}^{t}\right)}\right) \nabla \log P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)
這里的B(我們所謂的 baseline)其實是用神經(jīng)網(wǎng)絡(luò)進行求解的, 來分析當前的優(yōu)勢, 這就是Actor-Critic 算法, 用來做動作的就是 Actor, 對動作進行打分就是Critc, 用來對 Actor 進行打分。

  1. 下面介紹幾個概念
    Loss=\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}}\left(\color{blue}{R_{t}^{n}}-\color{red}{B\left(s_{n}^{t}\right)}\right) \nabla \log P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)
  • Action-Value Function(動作價值函數(shù))\color{blue}{Q_{\theta}(s,a)}
    \color{blue}{{R_{t}^{n}}} 每次都是一次隨機采樣, 方差很大, 訓(xùn)練不穩(wěn)定, 需要無限多次采樣, 因此我們通過下面\color{blue}{Q_{\theta}(s,a)} 函數(shù)來解決

    \color{blue}{Q_{\theta}(s,a)} 在 狀態(tài) s下,做出任何動作a期望的回報, 也就是所謂的動作價值函數(shù) Action-Value Function

  • State-Value Function(狀態(tài)價值函數(shù)) \color{red}{V_{\theta}(s)}
    \color{red}{V_{\theta}(s)}表示在狀態(tài)s下期望的回報, 即為狀態(tài)價值函數(shù)State-Value Function

  • Advantage Function(優(yōu)勢函數(shù))
    \color{green}{A_{\theta}(s,a)} = \color{blue}{Q_{\theta}(s,a)} - \color{red}{V_{\theta}(s)} 表示的是在狀態(tài)s下,做出的動作a比其他動作能帶來多少優(yōu)勢。

根據(jù)上述公式我們可以進一步對損失函數(shù)進行優(yōu)化即:
Loss = \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} \color{green}{A_{\theta}\left(s_{n}^{t}, a_{n}^{t}\right)} \nabla \log P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)

為了能進一步看下我們的優(yōu)勢函數(shù)里面的值分別如何表達我們擁有以下公式來說明。
首先對于動作價值函數(shù) \color{blue}{Q_{\theta}(s,a)}
\color{blue}{Q_{\theta}(s,a)} = \color{purple}{r_t + \gamma * V_{\theta}(s_{t+1})} , 其實就是當前t 時刻獲得的 Reward 值r_t加上衰減系數(shù)\gamma乘以t+1時刻的狀態(tài)價值函數(shù), 把上述等式帶入到優(yōu)勢函數(shù)中,得到如下的動作價值函數(shù), 可以看到公式只要狀態(tài)價值函數(shù)了, 這樣模型也只需要預(yù)測狀態(tài)價值函數(shù)而不是一開始又要預(yù)測狀態(tài)價值函數(shù)以及動作價值函數(shù)了。
\color{green}{A_{\theta}(s,a)} = \color{purple}{r_t + \gamma * V_{\theta}(s_{t+1})} - \color{red}{V_{\theta}(s)}

總結(jié)

我們看下其他時間優(yōu)勢函數(shù), 需要注意的是采樣的步數(shù)越多, 方差越大但是偏差越小
\begin{aligned} A_{\theta}^{1}\left(s_{t}, a\right) &= r_{t}+\gamma * V_{\theta}\left(s_{t+1}\right)-V_{\theta}\left(s_{t}\right) \\ A_{\theta}^{2}\left(s_{t}, a\right) &= r_{t}+\gamma * r_{t+1}+\gamma^{2} * V_{\theta}\left(s_{t+2}\right)-V_{\theta}\left(s_{t}\right) \\ A_{\theta}^{3}\left(s_{t}, a\right) &= r_{t}+\gamma * r_{t+1}+\gamma^{2} * r_{t+2}+\gamma^{3} V_{\theta}\left(s_{t+3}\right)-V_{\theta}\left(s_{t}\right) \\ &\quad \vdots \\ A_{\theta}^{T}\left(s_{t}, a\right) &= r_{t}+\gamma * r_{t+1}+\gamma^{2} * r_{t+2}+\gamma^{3} * r_{t+3}+\cdots+\gamma^{T} * r_{T}-V_{\theta}\left(s_{t}\right) \end{aligned}
為了表示方便我們用\delta_{t}^V表示 A_{\theta}^T
\begin{aligned} \delta_{t}^{V} &= r_{t}+\gamma * V_{\theta}\left(s_{t+1}\right)-V_{\theta}\left(s_{t}\right) \\ \delta_{t+1}^{V} &= r_{t+1}+\gamma * V_{\theta}\left(s_{t+2}\right)-V_{\theta}\left(s_{t+1}\right) \\ \end{aligned}
因此最終表示如下:
\begin{aligned} A_{\theta}^{1}\left(s_{t}, a\right) &= \delta_{t}^{V} \\ A_{\theta}^{2}\left(s_{t}, a\right) &= \delta_{t}^{V}+\gamma \delta_{t+1}^{V} \\ A_{\theta}^{3}\left(s_{t}, a\right) &= \delta_{t}^{V}+\gamma \delta_{t+1}^{V}+\gamma^{2} \delta_{t+2}^{V} \\ &\quad \vdots \end{aligned}
最終基于 Generalized Adavanced Estimation(GAE)廣義優(yōu)勢函數(shù)采樣所有的步數(shù), 它通過優(yōu)勢函數(shù)一步采樣、兩步采樣, 三步采樣分配不同的權(quán)重, 然后將他們的加和來表示 GAE 優(yōu)勢函數(shù), 函數(shù)如下所示 :
\begin{align*} A_{\theta}^{GAE}(s_t, a) &= \color{blue}{(1 - \lambda)}(A_{\theta}^{1} + \color{blue}{\lambda} A_{\theta}^{2} + \color{blue}{\lambda^2} A_{\theta}^{3} + \cdots) \\ \lambda &= 0.9 \\ A_{\theta}^{GAE} &= 0.1 A_{\theta}^{1} + 0.09 A_{\theta}^{2} + 0.081 A_{\theta}^{3} + \cdots \\ &= (1 - \lambda)\left(\delta_{t}^{V}\left(1 + \lambda + \lambda^{2} + \cdots\right) + \gamma \delta_{t+1}^{V} * \left(\lambda + \lambda^{2} + \cdots\right) + \cdots\right) \\ \end{align*}
其中上述函數(shù)為等比數(shù)列, 并基于等比數(shù)列求和公式進行化簡如下:
\begin{align*} &= (1 - \lambda)\left(\delta_{t}^{V} \frac{1}{1 - \lambda} + \gamma \delta_{t+1}^{V} \frac{\lambda}{1 - \lambda} + \cdots\right) \\ &= \sum_{b=0}^{\infty} (\gamma \lambda)^ \delta_{t+b}^{V} \end{align*}
上述公式表示在狀態(tài)s_t做動作 a 的優(yōu)勢, 并且平衡了采樣不同步帶來的方差以及偏差的問題,廣義優(yōu)勢估計(Generalized Advantage Estimation,簡稱GAE), 其實也是蒙特卡羅估計(MC, Mente Carlo)是強化學(xué)習(xí)中用于估算策略梯度的一種方法,特別適用于異策性策略優(yōu)化算法,如PPO(Proximal Policy Optimization)。它通過結(jié)合多步回報來改進優(yōu)勢函數(shù)的估計,從而在減少方差的同時盡量保持偏差不變。

1.6 重要幾個公式總結(jié)

  • \delta_{t}^{V} = r_{t} + \gamma * V_{\theta}(s_{t+1}) - V_{\theta}(s_{t})稱為TD(Temporal Difference)誤差

  • A_{\theta}^{GAE}(s_{t}, a) = \sum_{b=0}^{\infty} (\gamma \lambda)^ \delta_{t+b}^{V}, GAE 廣義優(yōu)勢估計, 它通過結(jié)合多步回報來改進優(yōu)勢函數(shù)的估計,從而在減少方差的同時盡量保持偏差不變, GAE的核心思想在于平衡偏差和方差之間的關(guān)系。通常情況下,單步的優(yōu)勢估計具有較高的方差但較低的偏差`,而使用蒙特卡洛方法計算的多步回報雖然方差較大但是偏差較小。GAE通過引入一個調(diào)節(jié)參數(shù)\gamma\lambda來權(quán)衡這兩者,得到一個更穩(wěn)定的估計值。
    上述公式中:
    1?? r_t是時間步 t 時的獎勵
    2?? \gamma是折扣因子,決定了未來獎勵的重要性
    3?? V(s)是狀態(tài)價值函數(shù), 代表從狀態(tài) s 開始按照當前策略行動獲得的期望回報。
    4?? \lambda 是GAE的一個超參數(shù),用于控制估計中的偏差和方差之間的權(quán)衡。當\lambda=0 時,GAE退化為一步TD誤差;當\lambda=1 時,GAE相當于累積了所有的TD誤差,接近于蒙特卡洛方法。

我們可以用 “算獎金” 這個生活場景,把 GAE 的 λ 和偏差、方差的權(quán)衡講明白,完全不用復(fù)雜公式:
先明確核心問題:GAE 是干嘛的?
假設(shè)你是公司老板,要給員工 “預(yù)估獎金”—— 這個 “預(yù)估獎金”,對應(yīng)強化學(xué)習(xí)里的 “狀態(tài)價值”(就是判斷 “現(xiàn)在這個局面好不好、未來能拿到多少收益”)。
但未來是不確定的:你沒法精準算到員工年底能賺多少,只能根據(jù) “當前線索” 估算。GAE 的作用,就是用 “逐步修正” 的方式算這個預(yù)估獎金,而 λ 就是控制 “修正到多遠” 的開關(guān)。

  1. 當 λ=0:只看 “眼前一步”(對應(yīng)一步 TD 誤差)
    λ=0 的意思是:只信 “馬上能看到的反饋”,后面的都不管。
    比如你給員工算獎金:
    只看 “這個月員工簽了 1 個小單”,就直接按 “1 個小單的提成” 預(yù)估全年獎金(比如預(yù)估 1 萬塊)。
    至于 “下個月會不會簽大單、年底會不會有額外獎勵”,你完全不考慮。
    特點(偏差 vs 方差):
    方差小:因為只看眼前一步,數(shù)據(jù)很確定(這個月的單是實打?qū)嵉赜校?,估算結(jié)果不會忽高忽低。
    偏差大:只看一步太短視了 —— 萬一員工下個月簽了 10 個大單,實際獎金能有 10 萬,你預(yù)估的 1 萬就差太遠了。
  2. 當 λ=1:看 “直到最后”(接近蒙特卡洛方法)
    λ=1 的意思是:要等 “所有結(jié)果都出來”,再回頭算總賬,中間的步驟都不提前預(yù)估。
    還是算獎金:
    你不按月預(yù)估,而是等到 “年底最后一天”,把員工一整年的所有單子、獎勵、扣罰都加起來,算出實際獎金(比如 12 萬)。
    相當于 “不預(yù)估,直接等真實結(jié)果”。
    特點(偏差 vs 方差):
    偏差?。阂驗橛玫氖?“最終真實結(jié)果”,沒有預(yù)估誤差(算出來 12 萬就是實際拿 12 萬)。
    方差大:結(jié)果完全看 “運氣”—— 比如今年市場好,員工能拿 20 萬;明年市場差,可能只拿 5 萬,估算結(jié)果波動極大。
  3. 當 0<λ<1:“眼前” 和 “未來” 折中(GA 的核心價值)
    λ 在 0 到 1 之間時,就是 “既看眼前,也適當考慮未來,但未來看得越遠,權(quán)重越小”。
    比如 λ=0.5(可以理解為 “未來每多一步,信任度打 5 折”):
    算獎金時,你會看:這個月的單(100% 信) + 下個月可能的單(信 50%) + 下下個月可能的單(信 25%) + …… 后面的越來越不信,直到忽略。
    特點:
    通過調(diào) λ,能在 “偏差” 和 “方差” 之間找平衡 ——
    想讓估算更穩(wěn)定(方差?。桶?λ 調(diào)小一點,多信眼前;
    想讓估算更準確(偏差?。桶?λ 調(diào)大一點,多信未來。
    一句話總結(jié)
    λ 就像 “望遠鏡的焦距”:
    λ=0:只看腳下(準但短視,偏差大、方差?。?;
    λ=1:看無窮遠(遠但模糊,偏差小、方差大);
    0<λ<1:調(diào)焦距,既看腳下也看前方,找一個 “看得清又看得遠” 的平衡點。


image.png
  • \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} A_{\theta}^{GAE}(s_{n}^{t}, a_{n}^{t}) \nabla \log P_{\theta}(a_{n}^{t} \mid s_{n}^{t})
方差 vs 偏差

上述的狀態(tài)價值函數(shù)V_{\theta}(s_{t})用神經(jīng)網(wǎng)絡(luò)去擬合, 一般可以和策略函數(shù)(預(yù)測 action)共用一個網(wǎng)絡(luò)參數(shù), 只是最后一層不同只要預(yù)測單一的值來代表當前狀態(tài)的價值即可, 這里的label公式為
\color{blue}{\quad R\left(\tau^{n}\right) \rightarrow \sum_{t^{\prime}=t}^{T_{n}} \gamma^{t^{\prime}-t} r_{t^{\prime}}^{n}=R_{t}^{n}}
整體如下圖所示:

二、 PPO(Proximal Policy Optimization)鄰近策略優(yōu)化

基于上述背景的介紹, 我們可以正式進入PPO 算法原理理解。 之前已經(jīng)說過了 On Policy的模式,問題是采集的數(shù)據(jù)僅能用一次就需要拋棄, 需要重新采集一次數(shù)據(jù)才能進行訓(xùn)練, 因此采集的效率很慢。 PPO 算法希望我們進入到 Off Policy 的模式,即采集的模型和訓(xùn)練的模型不是同一個, 且采集的數(shù)據(jù)可以被用來多次訓(xùn)練,這樣可以提高模型訓(xùn)練效率如下所示:


On Policy vs Off Policy

這里關(guān)于 Off Policy可以舉一個例子

例子

比如老師針對小明的表現(xiàn)表揚或批評小明, 如果表揚小明, 則小明會加強老師表揚的行為, 如果批評小明, 小明就會減少老師批評小明的行為, 這里小明調(diào)整的都是小明自己的行為, 所以這里就是 On Policy. 但是如果其他學(xué)生基于老師對小明的評價去調(diào)整自己的行為就是 Off Policy 的過程

如果老師批評小明上課玩手機,但是你上課玩手機的頻率比小明還多, 那你應(yīng)該調(diào)整你上課玩手機的行為比小明還要多一些。 如果你上課玩手機的頻率比小明少, 那你應(yīng)該調(diào)整你上課玩手機的行為比小明還要少一些.

2.1 重要性采樣(Importance Sampling)

\begin{align*} \mathrm{E}(f(x))_{x \sim p(x)} &= \sum_{x} f(x) * p(x) \\ &= \sum_{x} f(x) * p(x) \frac{q(x)}{q(x)} \\ &= \sum_{x} f(x) \frac{p(x)}{q(x)} * q(x) \\ &= \mathrm{E}\left(f(x) \frac{p(x)}{q(x)}\right)_{x \sim q(x)} \\ &\approx \frac{1}{N} \sum_{n=1}^{N} f(x) \frac{p(x)}{q(x)}_{\substack{x \sim q(x)}} \end{align*}
上述的公式想表示在 q 的分布下如何的到 q 的期望。上述\frac{p(x)}{q(x)}可以理解為相對重要性, 用于衡量新舊策略的差異, 這個比率表示策略變化程度?;谏鲜龉轿覀兛梢詫⒛繕撕瘮?shù)中的期望進行調(diào)整, 將 On Policy 轉(zhuǎn)為 Off Policy 。 這里的\theta^{'}是小明的策略,\theta則是你的策略. \theta^{'}的優(yōu)勢函數(shù)A_{\theta^{'}}^{GAE}就是老師對小明的評價,你不能直接用老師對小明的評價來更新自己的準則, 小明上課玩手機多, 但你上課玩手機少, 所以你修改自己的行為就少點。

得到如下的公式:
\begin{align*} &\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} A_{\theta}^{GAE}\left(s_{n}^{t}, a_{n}^{t}\right) \nabla \log P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right) \quad \color{blue}{\nabla \log f(x)=\frac{\nabla f(x)}{f(x)}} \\ &= \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} A_{\theta'}^{GAE}\left(s_{n}^{t}, a_{n}^{t}\right) \frac{P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)}{P_{\theta'}\left(a_{n}^{t} \mid s_{n}^{t}\right)} \nabla \log P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right) \\ &= \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} A_{\theta'}^{GAE}\left(s_{n}^{t}, a_{n}^{t}\right) \frac{P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)}{P_{\theta'}\left(a_{n}^{t} \mid s_{n}^{t}\right)}\frac{\nabla P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)}{P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)} \\ &= \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} A_{\theta'}^{GAE}\left(s_{n}^{t}, a_{n}^{t}\right) \frac{\nabla P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)}{P_{\theta'}\left(a_{n}^{t} \mid s_{n}^{t}\right)} \\ \\ &\text{Loss} = -\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} A_{\theta'}^{GAE}\left(s_{n}^{t}, a_{n}^{t}\right) \frac{P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)}{P_{\theta'}\left(a_{n}^{t} \mid s_{n}^{t}\right)} \end{align*}
可以看到用參考策略\theta^{'}來進行數(shù)據(jù)采樣來計算優(yōu)勢函數(shù),然后用訓(xùn)練策略\theta做某個動作的概率除以參考策略\theta^{'}做某個動作概率來調(diào)整優(yōu)勢函數(shù),這樣我們就可以用參考策略做數(shù)據(jù)采樣, 同時采樣數(shù)據(jù)可以用來多次用來訓(xùn)練 Policy網(wǎng)絡(luò)這樣解決 On Policy 訓(xùn)練效率低的問題。 這樣需要注意的是這個參考策略和訓(xùn)練策略不能在同一情況下給出各種動作的差別太大。 舉一個例子老師對和你差不多學(xué)生的評價不能差距太大, 不然你很難學(xué)習(xí)到對你有用的經(jīng)驗和教訓(xùn), 這里通過 KL 散度來進行約束來保證分布盡可能一致(PPO又稱PPO-Penalty),當然可以用截斷函數(shù)表示(PPO2又稱 PPO-Clip) , 公式如下所示(PPO 以及 PPO2 的公式):
\begin{align*} Loss_{ppo} &= -\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} A_{\theta'}^{GAE}\left(s_{n}^{t}, a_{n}^{t}\right) \frac{P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)}{P_{\theta'}\left(a_{n}^{t} \mid s_{n}^{t}\right)} + \beta KL(P_{\theta}, P_{\theta'}) \\ Loss_{ppo2} &= -\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_{n}} \min\left(A_{\theta'}^{GAE}\left(s_{n}^{t}, a_{n}^{t}\right) \frac{P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)}{P_{\theta'}\left(a_{n}^{t} \mid s_{n}^{t}\right)}, \text{clip}\left(\frac{P_{\theta}\left(a_{n}^{t} \mid s_{n}^{t}\right)}{P_{\theta'}\left(a_{n}^{t} \mid s_{n}^{t}\right)}, 1 - \varepsilon, 1 + \varepsilon\right) A_{\theta'}^{GAE}\left(s_{n}^{t}, a_{n}^{t}\right)\right) \end{align*}

image.png

情況 A > 0(鼓勵動作) A < 0(減少動作)
r > 1 + \epsilon \min選裁剪后的(1 + \epsilon) \cdot A \min選未裁剪的r \cdot A
r < 1 - \epsilon \min選未裁剪的r \cdot A \min選裁剪后的(1 - \epsilon) \cdot A
1 - \epsilon \leq r \leq 1 + \epsilon 兩者相等,隨便選 兩者相等,隨便選

2.2 總損失函數(shù)

1. 優(yōu)化目標損失函數(shù)
這一點前面已經(jīng)提到, 為了限制策略的更新幅度, PPO2 引入了剪輯目標函數(shù), PPO的目標是找到一個折中:在保持改進的同時防止策略變化過大。

L^{CLIP}(\theta) = \mathbb{E}_t\left[ \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right]

2.值函數(shù)優(yōu)化
PPO不僅優(yōu)化策略,還同時更新值函數(shù)V(s_t),通過最小化均方誤差來更新,該損失函數(shù)使得 Critic 能夠更準確地估計狀態(tài)值
L^{VF}(\theta) = \mathbb{E}_t\left[ \left( V(s_i; \theta) - R_t \right)^2 \right]

  • V(s_t; \theta): 當前狀態(tài)的值函數(shù)預(yù)測, 用來評估當前游戲的難度
  • R_t = \sum_{k=0}^{n} \gamma^k r_{t+k}(label)

3. 策略正則化
為了鼓勵策略的探索,PPO 引入了熵正則化項:
L^{ENT}(\theta) = \mathbb{E}_t \left[ H\left( \pi_\theta(s_t) \right) \right]= -\sum \pi_\theta(a|s_t) \log \pi_\theta(a|s_t)

H\left( \pi_\theta(s_t) \right):策略的熵,表示策略分布的不確定性。增加熵可以防止策略過早收斂到局部最優(yōu), 系數(shù)控制探索強度。

總損失函數(shù)
L(\theta) = -\mathbb{E}_t \left[ L^{CLIP}(\theta) - c_1 L^{VF}(\theta) + c_2 L^{ENT}(\theta) \right]
c_1和c_2:權(quán)重系數(shù),用于平衡策略優(yōu)化、值函數(shù)更新和熵正則化。該函數(shù)中的重要想法:

  • 核心目標: 優(yōu)化策略, 使r_t(\theta)A_t的改進在限制范圍內(nèi)
  • 限制更新幅度: 通過剪輯函數(shù)clip(), 避免函數(shù)更新過大導(dǎo)致不穩(wěn)定
  • 同時優(yōu)化值函數(shù): 通過L^{VF}(\theta), 提高 Critic 的預(yù)測精度
  • 探索與穩(wěn)定性平衡: 通過L^{ENT}(\theta), 鼓勵策略探索,通常希望 越大越好(更鼓勵探索、更不容易過早變得確定)。

2.3 PPO算法整體流程

  1. 采樣: 使用當前策略\pi_{\theta_{old}}與環(huán)境交互, 收集狀態(tài)s_t、動作a_t、獎勵r_t
  2. 計算優(yōu)勢函數(shù): 評估某個動作a_t在狀態(tài)s_t下相對平均表現(xiàn)的優(yōu)劣(優(yōu)勢函數(shù)A_t), 利用A_t引導(dǎo)策略改進
  3. 計算概率比率r_t(\theta): 比較新策略和舊策略對動作a_t的選擇概率, 這里的概率比例, 用于衡量新舊策略的差異, 這個比率表示策略變化程度
    r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}
  • \pi_\theta(a_t|s_t): 新策略對動作a_t的概率
  • \pi_{\theta_{\text{old}}}(a_t|s_t): 舊策略對動作a_t的概率
  1. 策略更新:如果更新過大(超出剪輯范圍1-\epsilon1+\epsilon, 會被懲罰), 保證更新幅度適中,既不太保守,也不太激進。
  2. 值函數(shù)更新: 用該損失函數(shù)優(yōu)化值函數(shù)L^{VF}(\theta) = \mathbb{E}_t\left[ \left( V(s_i; \theta) - R_t \right)^2 \right]
  3. 重復(fù)以上步驟: 通過多輪迭代, 使得策略逐步優(yōu)化, 直到收斂。
    image.png

參考:

  1. 【強化學(xué)習(xí)】近端策略優(yōu)化算法(PPO)萬字詳解(附代碼)
  2. PPO算法(附pytorch代碼)
    ˇˇ
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容