SeqGAN解讀

SeqGAN的概念來自AAAI 2017的SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient一文。

Motivation

如題所示,這篇文章的核心思想是將GAN與強(qiáng)化學(xué)習(xí)的Policy Gradient算法結(jié)合到一起——這也正是D2IA-GAN在處理Generator的優(yōu)化時(shí)使用的技巧。
而該論文的出發(fā)點(diǎn)也是意識(shí)到了標(biāo)準(zhǔn)的GAN在處理像序列這種離散數(shù)據(jù)時(shí)會(huì)遇到的困難,主要體現(xiàn)在兩個(gè)方面:Generator難以傳遞梯度更新,Discriminator難以評(píng)估非完整序列。
對(duì)于前者,作者給出的解決方案對(duì)我來說比較熟悉,即把整個(gè)GAN看作一個(gè)強(qiáng)化學(xué)習(xí)系統(tǒng),用Policy Gradient算法更新Generator的參數(shù);對(duì)于后者,作者則借鑒了蒙特卡洛樹搜索(Monte Carlo tree search,MCTS)的思想,對(duì)任意時(shí)刻的非完整序列都可以進(jìn)行評(píng)估。

問題定義

根據(jù)強(qiáng)化學(xué)習(xí)的設(shè)定,在時(shí)刻t,當(dāng)前的狀態(tài)s被定義為“已生成的序列”


,記作

,而動(dòng)作a是接下來要選出的元素

,所以policy模型就是

值得一提的是,這里的policy模型是stochastic,輸出的是動(dòng)作的概率分布;而狀態(tài)的轉(zhuǎn)移則顯然是deterministic,一旦動(dòng)作確定了,接下來的狀態(tài)也就確定了。

根據(jù)Policy Gradient算法,Generator的優(yōu)化目標(biāo)是令從初始狀態(tài)開始的value(累積的reward期望值)最大化:


其中,

是完整序列的reward,

action-value函數(shù),是指“在狀態(tài)s下選擇動(dòng)作a,此后一直遵循著policy做決策,最終得到的value”。所以對(duì)于最右邊的式子我們可以這樣來理解:在初始狀態(tài)下,對(duì)于policy可能選出的每個(gè)y,都計(jì)算對(duì)應(yīng)的value,把這些value根據(jù)policy的概率分布加權(quán)求和,就得到了初始狀態(tài)的value。

action-value函數(shù)

接下來的關(guān)鍵是如何定義

因?yàn)镈iscriminator充當(dāng)了這個(gè)強(qiáng)化學(xué)習(xí)系統(tǒng)的environment,所以Discriminator的輸出應(yīng)當(dāng)作為reward。但是Discriminator只能對(duì)生成的完整序列進(jìn)行評(píng)估,因此目前只能對(duì)完整序列狀態(tài)的value進(jìn)行定義:



這是遠(yuǎn)遠(yuǎn)不夠的,必須要對(duì)任意狀態(tài)的value都有定義。

蒙特卡洛樹搜索(MCTS)

在評(píng)估任意時(shí)刻的序列時(shí),我們考慮的其實(shí)都是它能帶來的long-term reward,就像下圍棋或象棋一樣,每下一步棋都要以全局為考量。在圍棋和象棋的求解算法中,MCTS是一個(gè)很重要的組成部分,所以作者想到了將它應(yīng)用到當(dāng)前的問題。
從名字得知,這種算法屬于一種蒙特卡洛方法(Monte Carlo method)——根據(jù)維基百科,也稱統(tǒng)計(jì)模擬方法,是指使用隨機(jī)數(shù)(或更常見的偽隨機(jī)數(shù))來解決很多計(jì)算問題的方法。MCTS正是這樣一種基于統(tǒng)計(jì)模擬的啟發(fā)式搜索算法,常用于游戲的決策過程。
MCTS可以無限循環(huán),而每一次循環(huán)都由以下4個(gè)步驟構(gòu)成:

  • Selection:從根節(jié)點(diǎn)開始,連續(xù)選擇子節(jié)點(diǎn)向下搜索,直至抵達(dá)一個(gè)葉節(jié)點(diǎn)。子節(jié)點(diǎn)的選擇方法一般采用UCT(Upper Confidence Bound applied to trees)算法,根據(jù)節(jié)點(diǎn)的“勝利次數(shù)”和“游戲次數(shù)”來計(jì)算被選中的概率,保持了Exploitation和Exploration的平衡,是保證搜索向最優(yōu)發(fā)展的關(guān)鍵。
  • Expansion:在葉節(jié)點(diǎn)創(chuàng)建多個(gè)子節(jié)點(diǎn)。
  • Simulation:在創(chuàng)建的子節(jié)點(diǎn)中根據(jù)roll-out policy選擇一個(gè)節(jié)點(diǎn)進(jìn)行模擬,又稱為playout或者rollout。它和Selection的區(qū)別在于:Selection指的是對(duì)于搜索樹中已有節(jié)點(diǎn)的選擇,從根節(jié)點(diǎn)開始,有歷史統(tǒng)計(jì)數(shù)據(jù)作為參考,使用UCT算法選擇每次的子節(jié)點(diǎn);Simulation是簡(jiǎn)單的模擬,從葉節(jié)點(diǎn)開始,用自定義的roll-out policy(可以只是簡(jiǎn)單的隨機(jī)概率)來選擇子節(jié)點(diǎn),且模擬經(jīng)過的節(jié)點(diǎn)并不加入樹中。
  • Backpropagation:根據(jù)Simulation的結(jié)果,沿著搜索樹的路徑向上更新節(jié)點(diǎn)的統(tǒng)計(jì)信息,包括“勝利次數(shù)”和“游戲次數(shù)”,用于Selection做決策。

在SeqGAN中,實(shí)際上只應(yīng)用了上述的Simulation過程:對(duì)于非完整的序列


,以

(等同于Generator)作為roll-out policy,將剩余的T-t個(gè)元素模擬出來,這樣就可以利用Discriminator進(jìn)行評(píng)估了。為了減小對(duì)value估計(jì)的誤差,會(huì)進(jìn)行N次模擬,對(duì)這N個(gè)結(jié)果取平均值。
最終得到了完整的action-value函數(shù):

policy gradient計(jì)算

Generator目標(biāo)函數(shù)的梯度可以初步推導(dǎo)為:



在此基礎(chǔ)上,可以去掉期望項(xiàng),構(gòu)造一個(gè)無偏估計(jì)再繼續(xù)推導(dǎo):



源碼對(duì)loss的實(shí)現(xiàn)為:
  • 111行:x是一個(gè)batch生成的所有序列,原來是一個(gè)三維數(shù)組,這里進(jìn)行了reshape并轉(zhuǎn)化為one-hot vector,最終得到一個(gè)二維數(shù)組,每一行以one-hot的形式代表這些生成序列的每一個(gè)元素,行數(shù)是batch size*sequence length。

  • 113行:最終也是得到一個(gè)二維數(shù)組,行數(shù)與上面相同,每一行代表這些生成序列每個(gè)時(shí)刻t關(guān)于所有候選元素的log概率分布,形如

  • 114行:這里的括號(hào)對(duì)應(yīng)110行,運(yùn)算得到這些序列每個(gè)元素被選中的log likelihood,即
  • 116行:這些生成序列每個(gè)時(shí)刻的reward。

  • 117行:括號(hào)對(duì)應(yīng)于109行的結(jié)尾,括號(hào)內(nèi)的運(yùn)算得到了每個(gè)時(shí)刻的

    ,reduce_sum的意義是,對(duì)一個(gè)batch中所有序列的所有
    進(jìn)行總的求和,負(fù)號(hào)的作用則是把梯度上升問題轉(zhuǎn)化為梯度下降。雖然沒有顯式地計(jì)算期望值,但歸因于大量的取樣和學(xué)習(xí)率的存在,最終自動(dòng)推導(dǎo)出來的梯度是與上述公式相符的。
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容