涉及WGAN的論文總共三篇:
WGAN前作:Towards Principled Methods for Training Generative Adversarial Networks
WGAN:Wasserstein GAN
改進的WGAN:Improved Training of Wasserstein GANs
代碼:各種GAN的實現(xiàn)
這三篇論文理論性都比較強,尤其是第一篇,涉及到比較多的理論公式推導。知乎鄭華濱的兩個論述令人拍案叫絕的Wasserstein GAN,Wasserstein GAN最新進展:從weight clipping到gradient penalty,更加先進的Lipschitz限制手法在理論方面已經(jīng)做了一個很好的介紹。不過對于很多數(shù)學不太好的同學(包括我自己),看著還是不太好理解,所以這里盡量站在做工程的角度,理一下這三篇文章的思路,這樣可以對作者的思路有一個比較清晰的理解。
GAN的Loss存在的問題
判別器的Loss:
<div align=center>
原始生成器Loss:
<div align=center>
Ian Goodfellow提出的改進的判別器Loss:
<div align=center>
在WGAN前作中指出,原始判別器的Loss在判別器達到最優(yōu)的時候,等價于最小化生成分布與真實分布之間的JS散度,由于隨機生成分布很難與真實分布有不可忽略的重疊以及JS散度的突變特性,使得生成器面臨梯度消失的問題;而對于Ian Goodfellow提出的改進的判別器Loss,在最優(yōu)判別器下,等價于既要最小化生成分布與真實分布直接的KL散度,又要最大化其JS散度,相互矛盾,導致梯度不穩(wěn)定,而且KL散度的不對稱性使得生成器寧可喪失多樣性也不愿喪失準確性,導致collapse mode現(xiàn)象。 ------令人拍案叫絕的Wasserstein GAN
總結(jié)起來,就是,不管判別器的Loss是第一種設(shè)計還是第二種設(shè)計,訓練到后面,判別器肯定是越來越好的,越來越趨近最優(yōu)判別器的??墒菃栴}就在于這里,為了得到最優(yōu)判別器,這會導致梯度消失,collapse mode的現(xiàn)象。于是,作者提出了一個解決方案:
WGAN前作針對分布重疊問題提出了一個過渡解決方案,通過對生成樣本和真實樣本加噪聲使得兩個分布產(chǎn)生重疊,理論上可以解決訓練不穩(wěn)定的問題,可以放心訓練判別器到接近最優(yōu),但是未能提供一個指示訓練進程的可靠指標,也未做實驗驗證。 ------令人拍案叫絕的Wasserstein GAN
上面的一大段介紹不太可能看懂,總之就是作者通過一大堆數(shù)學推導,發(fā)現(xiàn)原始GAN的判別器Loss有問題,作者提了一個湊合的方案,但是也沒實驗,不知道行不行。我們接下來重點關(guān)注WGAN,也就是作者給出的解決方案是什么?至于作者給出的方案為什么能解決前面分析的問題,就需要去仔細琢磨公式了。
WGAN原理
前面提到了原始GAN使用的loss本質(zhì)上來說是最小化KS散度,或者KL散度,這樣是不合理的,于是作者就提出用Wasserstein距離來作為衡量兩個分布之間的距離。作者根據(jù)Wasserstein距離又推導出了相應的Loss:
WGAN生成器Loss:
<div align=center>
WGAN判別器Loss:
<div align=center>
具體到代碼實現(xiàn)層面:
- 判別器最后一層去掉sigmoid
- 生成器和判別器的loss不取log
- 每次更新判別器的參數(shù)之后把它們的絕對值截斷到不超過一個固定常數(shù)c
-
不要用基于動量的優(yōu)化算法(包括momentum和Adam),推薦RMSProp,SGD也行
第四點是作者的經(jīng)驗得出的,前面三點都是有理論推導的。具體代碼實現(xiàn)如下,注意判別器最后一層不要sigmoid,loss設(shè)計的時候不要log:
# 生成器loss
errG = -netD(inputNegative)
# 判別器loss
errD_real = netD(inputPostive)
errD_fake = netD(inputNegative)
errD = -errD_real + errD_fake
其實從代碼實現(xiàn)上來看很直觀,就是按照作者說的上面幾個要點,把原始GAN的Loss改一下就好了。不過這樣改的理由,作者卻花了兩篇論文來論述。
WGAN存在的問題
實際實驗過程發(fā)現(xiàn),WGAN沒有那么好用,主要原因在于WAGN進行梯度截斷。梯度截斷將導致判別網(wǎng)絡趨向于一個二值網(wǎng)絡,造成模型容量的下降。
于是作者提出使用梯度懲罰來替代梯度裁剪。公式如下:
由于上式是對每一個梯度進行懲罰,所以不適合使用BN,因為它會引入同個batch中不同樣本的相互依賴關(guān)系。如果需要的話,可以選擇Layer Normalization。具體代碼實現(xiàn)如下:
gradients = tf.gradients(pred, x)[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients),reduction_indices=range(1, x.shape.ndims)))
gp = tf.reduce_mean((slopes - 1.)**2)
d_loss = errD + gp * lambda
總結(jié)
總的來說,WGAN的三篇論文,前兩篇討論loss設(shè)計導致的問題,提出了新的loss設(shè)計方式,公式推導比較復雜。不過代碼實現(xiàn)起來很簡答。本博客也重點關(guān)注其實現(xiàn),以及簡要說了一下loss公式的形式,忽略了許多中間的理論細節(jié)。有興趣深挖的同學可以去翻一下論文原文。