SeqGAN——對抗思想與增強(qiáng)學(xué)習(xí)的碰撞

保留初心,砥礪前行

SeqGAN這篇paper從大半年之前就開始看,斷斷續(xù)續(xù)看到現(xiàn)在,接下來的工作或許會與GAN + RL有關(guān),因此又把它翻出來,又一次仔細(xì)拜讀了一番。接下來就記錄下我的一點(diǎn)理解。

paper鏈接

1. 背景

GAN在之前發(fā)的文章里已經(jīng)說過了,不了解的同學(xué)點(diǎn)我,雖然現(xiàn)在GAN的變種越來越多,用途廣泛,但是它們的對抗思想都是沒有變化的。簡單來說,就是在生成的過程中加入一個可以鑒別真實(shí)數(shù)據(jù)和生成數(shù)據(jù)的鑒別器,使生成器G和鑒別器D相互對抗,D的作用是努力地分辨真實(shí)數(shù)據(jù)和生成數(shù)據(jù),G的作用是努力改進(jìn)自己從而生成可以迷惑D的數(shù)據(jù)。當(dāng)D無法再分別出真假數(shù)據(jù),則認(rèn)為此時的G已經(jīng)達(dá)到了一個很優(yōu)的效果。
它的諸多優(yōu)點(diǎn)是它如今可以這么火爆的原因:

  • 可以生成更好的樣本
  • 模型只用到了反向傳播,而不需要馬爾科夫鏈
  • 訓(xùn)練時不需要對隱變量做推斷
  • G的參數(shù)更新不是直接來自數(shù)據(jù)樣本,而是使用來自D的反向傳播
  • 理論上,只要是可微分函數(shù)都可以用于構(gòu)建D和G,因?yàn)槟軌蚺c深度神經(jīng)網(wǎng)絡(luò)結(jié)合做深度生成式模型

它的最后一條優(yōu)點(diǎn)也恰恰就是它的局限,之前我發(fā)過的文章中也有涉及到,點(diǎn)點(diǎn)點(diǎn)點(diǎn)點(diǎn)我,在NLP中,數(shù)據(jù)不像圖片處理時是連續(xù)的,可以微分,我們在優(yōu)化生成器的過程中不能找到“中國 + 0.1”這樣的東西代表什么,因此對于離散的數(shù)據(jù),普通的GAN是無法work的。

2. 大體思路

這位還在讀本科的作者想到了使用RL來解決這個問題。

如上圖(左)所示,仍然是對抗的思想,真實(shí)數(shù)據(jù)加上G的生成數(shù)據(jù)來訓(xùn)練D。但是從前邊背景章節(jié)所述的內(nèi)容中,我們可以知道G的離散輸出,讓D很難回傳一個梯度用來更新G,因此需要做一些改變,看上圖(右),paper中將policy network當(dāng)做G,已經(jīng)存在的紅色圓點(diǎn)稱為現(xiàn)在的狀態(tài)(state),要生成的下一個紅色圓點(diǎn)稱作動作(action),因?yàn)镈需要對一個完整的序列評分,所以就是用MCTS(蒙特卡洛樹搜索)將每一個動作的各種可能性補(bǔ)全,D對這些完整的序列產(chǎn)生reward,回傳給G,通過增強(qiáng)學(xué)習(xí)更新G。這樣就是用Reinforcement learning的方式,訓(xùn)練出一個可以產(chǎn)生下一個最優(yōu)的action的生成網(wǎng)絡(luò)。

3. 主要內(nèi)容

不論怎么對抗,目的都是為了更好的生成,因此我們可以把生成作為切入點(diǎn)。生成器G的目標(biāo)是生成sequence來最大化reward的期望。

在這里把這個reward的期望叫做J(θ)。就是在s0和θ的條件下,產(chǎn)生某個完全的sequence的reward的期望。其中Gθ()部分可以輕易地看出就是Generator Model。而QDφGθ()(我在這里叫它Q值)在文中被叫做一個sequence的action-value function 。因此,我們可以這樣理解這個式子:G生成某一個y1的概率乘以這個y1的Q值,這樣求出所有y1的概率乘Q值,再求和,則得到了這個J(θ),也就是我們生成模型想要最大化的函數(shù)。

所以問題來了,這個Q值怎么求?
paper中使用的是REINFORCE algorithm 并且就把這個Q值看作是鑒別器D的返回值。

因?yàn)椴煌暾能壽E產(chǎn)生的reward沒有實(shí)際意義,因此在原有y_1到y(tǒng)_t-1的情況下,產(chǎn)生的y_t的Q值并不能在y_t產(chǎn)生后直接計(jì)算,除非y_t就是整個序列的最后一個。paper中想了一個辦法,使用蒙特卡洛搜索(就我所知“蒙特卡洛”這四個字可以等同于“隨意”)將y_t后的內(nèi)容進(jìn)行補(bǔ)全。既然是隨意補(bǔ)全就說明會產(chǎn)生多種情況,paper中將同一個y_t后使用蒙特卡洛搜索補(bǔ)全的所有可能的sequence全都計(jì)算reward,然后求平均。如下圖所示。

就這樣,我們生成了一些逼真的sequence。我們就要用如下方式訓(xùn)練D。

這個式子很容易理解,最大化D判斷真實(shí)數(shù)據(jù)為真加上D判斷生成數(shù)據(jù)為假,也就是最小化它們的相反數(shù)。

D訓(xùn)練了一輪或者多輪(因?yàn)镚AN的訓(xùn)練一直是個難題,找好G和D的訓(xùn)練輪數(shù)比例是關(guān)鍵)之后,就得到了一個更優(yōu)秀的D,此時要用D去更新G。G的更新可以看做是梯度下降。

其中,

αh代表學(xué)習(xí)率。

以上就是大概的seqGAN的原理。

4. 算法

首先隨機(jī)初始化G網(wǎng)絡(luò)和D網(wǎng)絡(luò)參數(shù)。

通過MLE預(yù)訓(xùn)練G網(wǎng)絡(luò),目的是提高G網(wǎng)絡(luò)的搜索效率。

使用預(yù)訓(xùn)練的G生成一些數(shù)據(jù),用來通過最小化交叉熵來預(yù)訓(xùn)練D。

  1. 開始生成sequence,并使用方程(4)計(jì)算reward(這個reward來自于G生成的sequence與D產(chǎn)生的Q值)。

  2. 使用方程(8)更新G的參數(shù)。

  3. 更優(yōu)的G生成更好的sequence,和真實(shí)數(shù)據(jù)一起通過方程(5)訓(xùn)練D。

以上1,2,3循環(huán)訓(xùn)練直到收斂。

5. 實(shí)驗(yàn)

論文的實(shí)驗(yàn)部分就不是本文的重點(diǎn)了,有興趣的話看一下paper就可以了。

后邊說的比較敷衍了,那...就這樣吧。


參考資料:SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient | 百度&google

如果你也喜歡機(jī)器學(xué)習(xí),并且也像我一樣在ML之路上努力,請關(guān)注我,這里會不定期進(jìn)行分享,希望可以與你一同進(jìn)步。
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容