GAN生成對抗網(wǎng)絡(luò)

簡介

生成對抗網(wǎng)絡(luò)(以下簡稱GAN)是通過讓兩個神經(jīng)網(wǎng)絡(luò)相互博弈的方式進行學(xué)習(xí),可以根據(jù)原有的數(shù)據(jù)集生成以假亂真的新的數(shù)據(jù),舉個不是很恰當?shù)睦?,類似于造假鞋,莆田藝術(shù)家通過觀察真鞋,模仿真鞋的特點造出假鞋并賣給消費者,消費者收到鞋子后將它與網(wǎng)上的真鞋信息進行對比找瑕疵,并給出反饋,比如標不正,氣墊彈性不好,莆田藝術(shù)家根據(jù)消費者給出的反饋積極地改進工藝,經(jīng)過不懈努力后最終造出了可以忽悠消費者的假鞋。
在上述情景中,莆田藝術(shù)家相當于生成器,消費者相當于辨別器,在造假的過程中,生成器和判別器一直處于對抗狀態(tài)。
我們把上述情景抽象為神經(jīng)網(wǎng)絡(luò)。首先,通過對生成器輸入一個分布的數(shù)據(jù),生成器通過神經(jīng)網(wǎng)絡(luò)模仿生成出一個輸出(假鞋),將假鞋與真鞋的信息共同輸入到判別器中。然后,判別器通過神經(jīng)網(wǎng)絡(luò)學(xué)著分辨兩者的差異,做一個分類判斷出這雙鞋是真鞋還是假鞋。

這樣,生成器不斷訓(xùn)練為了以假亂真,判別器不斷訓(xùn)練為了區(qū)分二者。最終,生成器真能完全模擬出與真實的數(shù)據(jù)一模一樣的輸出,判別器已經(jīng)無力判斷。基于伊恩·古德費洛最早對 GAN 的定義,GAN 實際上是在完成這樣一個優(yōu)化任務(wù):
\min_{G}\max_{D}V(D,G)=E_{p_{data}} \left ( x \right ) [\log D(x)]+E_{p_{z}}\left ( z \right ) [\log (1-D(G(z))] \tag{1}式中,G 表示生成器;D 表示判別器;V 是定義的價值函數(shù),代表判別器的判別性能,該數(shù)值越大性能越好;p_{data}(x) 表示真實的數(shù)據(jù)分布;p_{z}(z) 表示生成器的輸入數(shù)據(jù)分布;E 表示期望。
第一項 E_{p_{data}}\left ( x \right )[\log D(x)] 是依據(jù)真實數(shù)據(jù)的對數(shù)函數(shù)損失而構(gòu)建的。具體可以理解為,最理想的情況是,判別器 D 能夠?qū)谡鎸崝?shù)據(jù)的分布數(shù)據(jù)給出 1 的判斷。所以,通過優(yōu)化 D 最大化這一項可以使 D(x)=1。其中,x 服從 p_{data}(x) 分布。
第二項,E_{p_{z}}\left ( z \right ) [\log (1-D(G(z))],是相對生成器的生成數(shù)據(jù)而言的。我們希望,當喂給判別器的數(shù)據(jù)是生成器的生成數(shù)據(jù)時,判別器能輸出 0。由于 D 的輸出是,輸入數(shù)據(jù)是真實數(shù)據(jù)的概率,那么 1-D(輸入) 是,輸入數(shù)據(jù)是生成器生成數(shù)據(jù)的概率,通過優(yōu)化 D 最大化這一項,則可以使 D(G(z))=0。其中,z 服從 p_{z} ,也就是生成器的生成數(shù)據(jù)分布。
生成器與判別器是對抗的關(guān)系,價值函數(shù)代表了判別器的判別性能。那么,通過優(yōu)化 G 能夠在第二項 E_{p_{z}}\left ( z \right ) [\log (1-D(G(z))] 上迷惑判別器,讓判別器對于 G(z) 這個輸入,盡可能地得到 D(G(z))=1。本質(zhì)上,生成器就是在最小化這一項,也就是在最小化價值函數(shù)。

KL 散度

為了界定兩個數(shù)據(jù)分布,也就是真實數(shù)據(jù)和生成器生成數(shù)據(jù)之間的差異,需要引入 KL 散度。
D_{KL}(P||Q)=E_{p(x)}[\log\frac{p(x)}{q(x)}]=\int_{x}p(x)\log\frac{p(x)}{q(x)} \tag{2} KL 散度具有非負性。
當且僅當 P,Q 在離散型變量下是相同的分布時,即 p(x)=q(x),D_{KL}(P||Q)=0
KL 散度衡量了兩個分布差異的程度,經(jīng)常被視為兩種分布間的距離。
要注意的是,D_{KL}(P||Q)\neq D_{KL}(Q||P),即 KL 散度沒有對稱性。

最優(yōu)判別器

將價值函數(shù)里的生成器固定不動,將期望寫成積分的形式有:
V(D)=\int_{x}p_{data}(x)\log(D(x))+p_{g}(x)\log(1-D(x))dx \tag{3}整個式子中,只有一個變量 D。次數(shù),對被積函數(shù),令 y=D(x),a=p_{data}(x),b=p_{g}(x),a,b 均為常數(shù)。那么,被積函數(shù)變?yōu)椋?br> f(y)=a\log y + b\log(1-y) \tag{4}為了找到最優(yōu)值 y,需要對上式求一階導(dǎo)數(shù)。而且,在 a+b\neq 0 的情況下有:
f'(y)=0 \rightarrow \frac{a}{y}+\frac{1-y}=0 \rightarrow y = \frac{a}{a+b} \tag{5}驗證 f(y) 的二階導(dǎo)數(shù) f''(y)<0,則 \frac{a}{a+b} 這個點為極大值,這個事實給出了最優(yōu)判別器的存在可能性。

盡管在實踐中我們并不知道 a=p_{data}(x),也就是真實的數(shù)據(jù)的分布。但我們在利用深度學(xué)習(xí)訓(xùn)練判別器時,可以讓 D 向這個目標逐漸逼近。

最優(yōu)生成器

若最優(yōu)的判別器為:
D=\frac{p_{data}(x)}{p_{data}(x) + p_{g}(x)} \tag{6}我們將其代入 V(G,D),此時價值函數(shù)里只有 G 這一個變量:
V(G)=\int_{x}p_{data}(x)\log\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}+p_{g}(x)\log(1-\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)})dx \tag{7}此時,通過變換,我們可以得到下面的式子:
V(G)=-\log2\int_{x}p_{g}(x)+p_{data}(x)dx+\int_{x}p_{data}(x)(\log2+\log\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)})+p_{g}(x)(\log2+\log\frac{p_{g}(x)}{p_{data}(x)+p_{g}(x)})dx \tag{8}這個變換比較復(fù)雜,大家可以檢驗步與步之間的恒等性判斷。根據(jù)對數(shù)的一些基本變換,可以得到:
\log2+\log\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}=\log\frac{2p_{data}(x)}{p_{data}(x)+p_{g}(x)}=\log\frac{p_{data}(x)}{(p_{data}(x)+p_{g}(x))/2} \tag{9}最終得到:
V(G)=-\log4+\int_{x}p_{data}(x)\log\frac{p_{data}(x)}{(p_{data}(x)+p_{g}(x))/2}dx+\int_{x}p_{g}(x)\log\frac{p_{g}(x)}{(p_{data}(x)+p_{g}(x))/2}dx \tag{10} V(G)=-\log4+D_{KL}(p_{data}||\frac{p_{data}+p_{g}}{2})+(p_{g}||\frac{p_{data}+p_{g}}{2}) \tag{11}因為 KL 散度的非負性,那么就可以知道 -\log4 就是 V(G) 的最小值,而且最小值是在當且僅當 p_{data}(x)=p_{g}(x) 時取得。這其實就是真實數(shù)據(jù)分布等于生成器的生成數(shù)據(jù)分布,可以從數(shù)學(xué)理論上證明了它的存在性和唯一性。

GAN的實現(xiàn)過程

生成器的輸入:即上面的 p_{z}(z),我們當然不能讓這個分布任意化,一般會設(shè)為常見的分布類型,如高斯分布、均勻分布等等,然后生成器基于這個分布產(chǎn)生的數(shù)據(jù)生成自己的偽造數(shù)據(jù)來迷惑判別器。
期望如何模擬:實踐中,我們是沒有辦法利用積分求數(shù)學(xué)期望的,所以一般只能從無窮的真實數(shù)據(jù)和無窮的生成器中采樣以逼近真實的數(shù)學(xué)期望。
近似價值函數(shù):若給定生成器 G,并希望計算 maxV(G,D) 以求得判別器 D。那么,首先需要從真實的數(shù)據(jù)分布 p_{data}(x) 中采樣 m 個樣本 {??^{1}, ??^{2}, \dots, ??^{??}}。并從生成器的輸入,即 p_{z}(z) 中采樣 m 個樣本 {\tilde{x}^{1}, \tilde{x}^{2}, \dots, \tilde{x}^{m}}。因此,最大化價值函數(shù) V(G,D) 就可以使用以下表達式近似替代:
\tilde{V}=\frac{1}{m}\sum_{i=1}^{m}\log D(x^{i})+\frac{1}{m}\sum_{i=1}^{m}\log(1-D(G(\tilde{x}^{i}))) \tag{12}可以把 GAN 的訓(xùn)練過程總結(jié)為:

  1. 從真實數(shù)據(jù) p_{data}(x) 采樣 m 個樣本 {??^{1},??^{2}...,??^{??}};
  2. 從生成器的輸入,即噪聲數(shù)據(jù) p_{z}(z) 采樣 m 個樣本 {\tilde{x}^{1},\tilde{x}^{2},...,\tilde{x}^{m}};
  3. 將噪聲樣本 {\tilde{x}^{1}, \tilde{x}^{2}, ..., \tilde{x}^{m}} 投入到生成器中生成{G(\tilde{x}^{1}),G(\tilde{x}^{2}),...,G(\tilde{x}^{m})};
  4. 通過梯度上升的方法,極大化價值函數(shù),更新判別器的參數(shù);
  5. 從生成器的輸入,即噪聲數(shù)據(jù) p_{z}(z) 另外采樣 m 個樣本{z^{1},z^{2},...,z^{m}};
  6. 將噪聲樣本 {z^{1},z^{2},...,z^{m}} 投入到生成器中生成 {G(z^{1}),G(z^{2}),...,G(z^{m})};
  7. 通過梯度下降的方法,極小化價值函數(shù),更新生成器的參數(shù)。

利用PyTorch搭建GAN生成手寫識別數(shù)據(jù)

安裝GPU版本PyTorch

  • 打開終端,在conda 配置中添加清華源
  • 編輯~/.condarc,將- defaults整行刪除
  • 安裝PyTouch GPU版本
    使用conda安裝,不用自己額外配置依賴包和版本兼容問題,conda會自動配置好,而且可以直接在jupyter中調(diào)用,非常方便。
    一般需要等待很長時間,而且會經(jīng)常中斷,中斷直接再重復(fù)運行安裝命令即可,會繼續(xù)安裝之前沒裝上的

    得益于國內(nèi)無與倫比的網(wǎng)絡(luò)環(huán)境,100Mb的寬帶完全失靈,下載了大概一個小時,中途中斷了三四次,終于裝好了!!我感覺天快亮了... ...

訓(xùn)練GAN

為了方便可視化,直接用jupyter notebook

  • 首先,導(dǎo)入需要用的模塊
  • 下載并解壓mnist數(shù)據(jù)集


    transform 函數(shù)允許我們把導(dǎo)入的數(shù)據(jù)集按照一定規(guī)則改變結(jié)構(gòu),我們在這里引入了 Normalize 將會把 Tensor 正則化。即:Normalized_image=(image-mean)/std。這樣做的目的是便于后續(xù)的訓(xùn)練。

  • 接下來,搭建深度學(xué)習(xí)模型,用于構(gòu)建判別器和生成器。這里通過引入 nn.Module 基類的方法來搭建
    判別器構(gòu)建過程,遵照 PyTorch 的 Sequential 網(wǎng)絡(luò)搭建法。我們用 4 層網(wǎng)絡(luò)結(jié)構(gòu),并把每層都使用全連接配上 LeakyReLU 激活再帶上 dropout 防止過擬合。最后一層,用 sigmoid 保證輸出值是一個 01 之間的概率值。設(shè)計前饋過程函數(shù)時,注意把每個樣本大小 28\times28 的輸入矩陣先轉(zhuǎn)換為 784 的向量用于全連接。

  • 接下來構(gòu)建生成器。本模型中的設(shè)定生成器的每個輸入樣本是大小為 100 的向量,通過全連接層配上 LeakyReLU 激活搭建,最后一層用 tanh 激活,且保證每個樣本輸出是一個 784 的向量。

  • 接下來實例化生成器與判別器,設(shè)定學(xué)習(xí)率和損失函數(shù)。價值函數(shù)按照定義是:
    \tilde{V}=\frac{1}{m}\sum_{i=1}^{m} \log D(x^{i})+\frac{1}{m}\sum_{i=1}^{m} \log(1-D(G(\tilde{x}^{i}))) \tag{13}PyTorch 中,BCELoss 表示二項 Cross Entropy,它的展開形式是:
    -[y\log x + (1-y)\log(1-x)] \tag{14}其中 ylabel,x 是輸出。那么,對于 01 這兩種 label 而言,當 y=0,上式第一項不存在,就剩下 \tilde{V} 的第二項。當 y=1,上式第二項不存在,就剩下 \tilde{V} 的第一項。那么 BCELoss 的結(jié)構(gòu)就與損失函數(shù) \tilde{V} 相同,只不過我們定義的損失函數(shù)有對真實數(shù)據(jù)與對生成器生成的數(shù)據(jù)兩種情況的輸出。

  • 接下來,就可以定義如何訓(xùn)練判別器了。值得注意的是,這里需要設(shè)置 zero_grad() 來消除之前的梯度,以免造成梯度疊加。此外,我們通過將真實數(shù)據(jù)的損失和偽造數(shù)據(jù)的損失兩部分相加,作為最終的損失函數(shù)。然后,通過后向傳播,用之前的判定器優(yōu)化器優(yōu)化,通過降低 BCELoss 來增大價值函數(shù)的值。

  • 同樣,接下來需要定義生成器的訓(xùn)練方法。注意,這里的 real_labels 在之后將設(shè)為 1。因為對于所有的生成器輸出,我們希望它向真實的數(shù)據(jù)分布學(xué)習(xí),那么 BCELoss 此時為 -\log x。最終,我們希望判別器的輸出 (x) 接近于 1,即判別器判斷該數(shù)據(jù)為真實數(shù)據(jù)的概率越大。所以,這里依舊是在減少 BCELoss,則直接調(diào)用 criterion 就可以設(shè)定好生成器的損失函數(shù)。

    之前已經(jīng)設(shè)定好生成器的每個樣本輸入為一個 100 大小的向量,這里就將生成器的輸入產(chǎn)生一個 100 大小,且服從標準正態(tài)分布的向量。

  • 一切準備就緒,開始 GAN 的訓(xùn)練。

    以下是剛開始產(chǎn)生的圖片


    以下是最終生成的圖片,可以看到,通過生成器與判定器的不斷博弈,產(chǎn)生的圖片也越來越逼真

GAN的改進

相比起卷積神經(jīng)網(wǎng)絡(luò)之于計算機視覺,循環(huán)神經(jīng)網(wǎng)絡(luò)之于自然語言處理,GAN 尚且沒有一個特別適合的應(yīng)用場景。主要原因是 GAN 目前還存在諸多問題。例如:

  1. 不收斂問題:GAN 是兩個神經(jīng)網(wǎng)絡(luò)之間的博弈。試想,如果判別器提前學(xué)到了非常強的,那么生成器很容易出現(xiàn)梯度消失而無法繼續(xù)學(xué)習(xí)。所有 GAN 的收斂性一直是個問題,這樣也導(dǎo)致 GAN 在實際搭建過程中對各種超參數(shù)都非常敏感,需要精心設(shè)計才能完成一次訓(xùn)練任務(wù);
  2. 崩潰問題:GAN 模型被定義為一個極小極大問題,可以說,GAN 沒有一個清晰的目標函數(shù)。這樣會非常容易導(dǎo)致,生成器在學(xué)習(xí)的過程中開始退化,總是生成相同的樣本點,而這也進一步導(dǎo)致判別器總是被喂給相同的樣本點而無法繼續(xù)學(xué)習(xí),整個模型崩潰;
  3. 模型過于自由: 理論上,我們希望 GAN 能夠模擬出任意的真實數(shù)據(jù)分布,但事實上,由于我們沒有對模型進行事先建模,再加上「真實分布與生成分布的樣本空間并不完全重合」是一個極大概率事件。那么,對于較大的圖片,如果像素一旦過多,GAN 就會變得越來越不可控,訓(xùn)練難度非常大。
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容