DDPG算法解析

時間會讓你忘記我嗎

直接看名字就能看出DDPG(Deep Deterministic Policy Gradient )其實就是DPG(Deterministic Policy Gradient )的深度神經(jīng)網(wǎng)絡(luò)版本,它采用Actor-Critic架構(gòu),用來解決連續(xù)控制問題。

其實當(dāng)初在我剛學(xué)了解決離散控制問題的方法的時候,就思考過如果換成連續(xù)控制問題該怎么辦,然后再看DPG所使用的方法之后,發(fā)現(xiàn)跟我想的是一模一樣....
所以,DDPG解決強化學(xué)習(xí)問題的思路跟那些解決離散問題的AC架構(gòu)方法并沒多大區(qū)別,就是做了點微小的改動以適應(yīng)連續(xù)問題罷了。因此,有離散控制問題經(jīng)驗的你,并且比較懂深度學(xué)習(xí)的話,在理解DDPG的時候可以說是非常簡單,甚至覺得不值一提...

來看一個連續(xù)控制問題

機械手臂

如圖,假如我們想用強化學(xué)習(xí)訓(xùn)練一個策略來控制機械手臂,上面的軸可以在[0, 2\pi]之間轉(zhuǎn)動, 下面的軸可以在[0, \pi] 之間轉(zhuǎn)動,那么它的動作空間將會是一個多維的連續(xù)空間:
A \in [0, 2\pi] *[0, \pi]

在有無窮多個action的時候,我們要怎么來實現(xiàn)策略網(wǎng)絡(luò)呢?

回想一下在離散AC框架下的策略網(wǎng)絡(luò),它是輸入狀態(tài)s, 輸出a的概率分布\pi (a|s) :

離散控制的策略網(wǎng)絡(luò)

因為連續(xù)控制問題有無數(shù)個action, 顯然像離散問題那樣通過輸出層softmax后的n個有限action的概率的方式是行不通的。

因為我太了解深度學(xué)習(xí)這一套東西了,所以面對這個問題的時候,直接就想到了兩個解決方案:

  • 確定策略(就是本文要講的DDPG的方法): 既然沒法輸出動作的概率分布,那我用整個策略網(wǎng)絡(luò)代表概率分布,將分類問題改為回歸問題,直接輸出確定動作不就可以了嘛...
  • 隨機策略:不是要輸出分布嘛,不能一個個給,我輸出一個高斯分布的均值和方差不就行了嘛...

就這樣兩個我想當(dāng)然就想到的方法,然后發(fā)現(xiàn)業(yè)界就是這么玩的....
既然本文是講DDPG,自然,我們就沿著第一個想法來實現(xiàn)。
于是我們可以把上面的策略網(wǎng)絡(luò)改造成這樣:

確定策略網(wǎng)絡(luò)

讓神經(jīng)網(wǎng)絡(luò)直接輸出每個機械臂需要轉(zhuǎn)動多少的動作,幾根機械臂就輸出幾維。
這樣,我們就可以利用這個網(wǎng)絡(luò)的輸出動作來操作機械臂,得到相應(yīng)的transition (s_t, a_t, r_t, s_{t+1}), 接下來就可以按照AC架構(gòu)的老路來訓(xùn)練模型了~

再來看看模型的更新過程

DDPG

Critic更新 (更新價值網(wǎng)絡(luò)參數(shù)w

價值網(wǎng)絡(luò)擬合的目標(biāo)一般跟DQN網(wǎng)絡(luò)一樣是最大動作價值函數(shù)Q^*, 期望顯然沒法求,于是通過蒙特卡洛方法,使用觀測值q(s,a;w)來近似,再通過TD算法來改進w:
q_t = q(s_t,a_t;w) \\ q_{t+1} = q(s_{t+1}, a'_{t+1}; w) \\ q_{target} = r_t + \gamma q_{t+1}
于是 TD error為:
\delta_t = q_t - q_{target}
然后通過TD error 梯度下降來更新網(wǎng)絡(luò)參數(shù)w :
w \leftarrow w - \alpha \frac{\partial \delta}{ \partial w}

Actor更新 (更新策略網(wǎng)絡(luò)參數(shù)\theta

Critic 輸出的價值代表了Actor預(yù)測動作的好壞,因此策略網(wǎng)絡(luò)的目標(biāo)是最大化價值Value ,自然就想到了用梯度上升法來最大化q(s,a;w) ,于是,我們可以對 q(s,a;w)\theta 的梯度,讓我們將策略網(wǎng)絡(luò)記作\pi(s;\theta)
pg = \frac{\partial q(s, \pi(s;\theta);w )}{\partial \theta} = \frac{\partial q(s,a;w)}{\partial a} * \frac{\partial a}{\partial \theta}
然后用梯度上升更新\theta :
\theta \leftarrow \theta + \alpha'* pg

優(yōu)化高估或低估問題

觀察上面的推導(dǎo)過程,我們?nèi)菀装l(fā)現(xiàn),這玩意跟DQN類似,因為bootstraping的通病,一開始低估了就會不斷低估,一開始高估了就會不斷高估,將會使得估計誤差一邊倒,導(dǎo)致學(xué)習(xí)的效果不好。為了處理這個問題,有很多種解決方案,大概就是跟DQN 差不多,DDPG就是這么做的。

引入target network

其實就是加入一個延遲更新策略,分別用兩個網(wǎng)絡(luò)來分別估計t+1時刻和t 時刻的值,即:
q_t = q(s_t,a_t;w) \\ q_{t+1} = q(s_{t+1}, a'_{t+1}; w')

a_t = \pi(s; \theta) \\ a_{t+1} = \pi(s_t+1; \theta')
這樣一來就隔斷了用自己的估計來估計自己,避免了不斷被強化的傾向。但是,實際更新target network參數(shù)的過程采用的是這樣一種方式:
w' = \tau ~w + (1-\tau) w' , \tau \in (0,1) \\ \theta' = \tau ~\theta + (1-\tau) \theta' , \tau \in (0,1)
因為target net的參數(shù)還是依賴于原來的網(wǎng)絡(luò)參數(shù),這種傳遞無法完全避免。

經(jīng)驗回放

通常經(jīng)驗回放可以使算法更加穩(wěn)定,因為僅僅使用新數(shù)據(jù)容易導(dǎo)致網(wǎng)絡(luò)過擬合使得訓(xùn)練終止,這給了樣本有了更多的學(xué)習(xí)機會,當(dāng)然如果使用過多的經(jīng)驗也會降低學(xué)習(xí)速度,這需要一定程度上進行權(quán)衡。

當(dāng)然,還有很多常見的方法都可以... 根據(jù)需要來。

總結(jié)

  • DDPG是一種off-policy的算法
  • DDPG只能用于連續(xù)動作空間的環(huán)境
  • DDPG可以被看作是連續(xù)動作空間環(huán)境下的DQN

相關(guān)論文

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容