10分鐘快速入門PyTorch (10)

前面我們已經(jīng)講完了一般的深層網(wǎng)絡(luò),適用于圖像的卷積神經(jīng)網(wǎng)絡(luò),適用于序列的循環(huán)神經(jīng)網(wǎng)絡(luò)。但是要知道Lecun提出第一代卷積網(wǎng)絡(luò)Lenet的時間是1998年,而循環(huán)神經(jīng)網(wǎng)絡(luò)提出的時間更早,是在1986年。這些網(wǎng)絡(luò)在當(dāng)時并沒有火起來,如今隨著計算能力的加強,數(shù)據(jù)集的增多,深度學(xué)習(xí)逐漸火了起來,隨著越來越多的人的研究,各種各樣的神經(jīng)網(wǎng)絡(luò)都在不斷進(jìn)步,CNN里面出現(xiàn)了inception net,resnet等等,RNN演變了LSTM和GRU,雖然神經(jīng)網(wǎng)絡(luò)不斷在發(fā)展,但是本質(zhì)上仍然是在CNN和RNN的基礎(chǔ)上。

直到2014年,深度學(xué)習(xí)三巨頭之一 Ian Goodfellow 提出了生成對抗網(wǎng)絡(luò)(Generative Adversarial Networks, GANs),剛開始的時候并沒有引起轟動,直到16年,學(xué)界、業(yè)界對其的興趣出現(xiàn)了“井噴”,多篇重磅文章陸續(xù)發(fā)表,Lecun也形容GANs“adversarial training is the coolest thing since sliced bread.” 16年12月NIPS大會上,Goodfellow做了GANs的專題報告,使得GANs成為了當(dāng)今最炙手可熱的研究領(lǐng)域,等你看完了這篇文章你就會知道為什么GANs能夠成為當(dāng)今人工智能領(lǐng)域的主要課題之一。

GANs

GANs的全稱叫做生成對抗網(wǎng)絡(luò),根據(jù)這個名字,你就可以猜測這個網(wǎng)絡(luò)是由兩部分組成的,第一部分是生成,第二部分是對抗。那么你已經(jīng)基本猜對了,這個網(wǎng)絡(luò)第一部分是生成網(wǎng)絡(luò),第二部分對抗模型嚴(yán)格來講是一個判別器,簡單來說呢,就是讓兩個網(wǎng)絡(luò)相互競爭,生成網(wǎng)絡(luò)來生成假的數(shù)據(jù),對抗網(wǎng)絡(luò)通過判別器去判別真?zhèn)?,最后希望生成器生成的?shù)據(jù)能夠以假亂真。

可以用這個圖來簡單的看一看這兩個過程。

1.png

下面我們就來依次介紹。

Discriminator Network

首先我們來講一下對抗過程,因為這個過程更加簡單。

對抗過程簡單來說就是一個判斷真假的判別器,相當(dāng)于一個二分類問題,我們輸入一張真的圖片希望判別器輸出的結(jié)果是1,輸入一張假的圖片希望判別器輸出的結(jié)果是0。這其實已經(jīng)和原圖片的label沒有關(guān)系了,不管原圖片到底是一個多少類別的圖片,他們都統(tǒng)一稱為真的圖片,label是1表示真實的;而生成的假的圖片的label是0表示假的。

我們訓(xùn)練的過程就是希望這個判別器能夠正確的判出真的圖片和假的圖片,這其實就是一個簡單的二分類問題,對于這個問題可以用我們前面講過的很多方法去處理,比如logistic回歸,深層網(wǎng)絡(luò),卷積神經(jīng)網(wǎng)絡(luò),循環(huán)神經(jīng)網(wǎng)絡(luò)都可以。

Generative Network

接著我們要看看如何生成一張假的圖片。首先給出一個簡單的高維的正態(tài)分布的噪聲向量,如上圖所示的D-dimensional noise vector,這個時候我們可以通過仿射變換,也就是xw+b將其映射到一個更高的維度,然后將他重新排列成一個矩形,這樣看著更像一張圖片,接著進(jìn)行一些卷積、池化、激活函數(shù)處理,最后得到了一個與我們輸入圖片大小一模一樣的噪音矩陣,這就是我們所說的假的圖片,這個時候我們?nèi)绾稳ビ?xùn)練這個生成器呢?就是通過判別器來得到結(jié)果,然后希望增大判別器判別這個結(jié)果為真的概率,在這一步我們不會更新判別器的參數(shù),只會更新生成器的參數(shù)。

如下圖所示

以上的過程已經(jīng)簡單的闡述了生成對抗網(wǎng)絡(luò)的學(xué)習(xí)過程,如果仍然不太清楚這個過程,下面我們會通過代碼來更清晰地展示整個過程。

Code

我們會使用mnist手寫數(shù)字來做數(shù)據(jù)集,通過生成對抗網(wǎng)絡(luò)我們希望生成一些“以假亂真”的手寫字體。為了加快訓(xùn)練過程,我們不使用卷積網(wǎng)絡(luò)來做判別器,我們使用簡單的多層網(wǎng)絡(luò)來進(jìn)行判別。

Discriminator Network

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.dis(x)
        return x

以上這個網(wǎng)絡(luò)是一個簡單的多層神經(jīng)網(wǎng)絡(luò),將圖片28x28展開成784,然后通過多層感知器,中間經(jīng)過斜率設(shè)置為0.2的LeakyReLU激活函數(shù),最后接sigmoid激活函數(shù)得到一個0到1之間的概率進(jìn)行二分類。之所以使用LeakyRelu而不是用ReLU激活函數(shù)是因為經(jīng)過實驗LeakyReLU的表現(xiàn)更好。

Generative Network

class generator(nn.Module):
    def __init__(self, input_size):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.gen(x)
        return x

輸入一個100維的0~1之間的高斯分布,然后通過第一層線性變換將其映射到256維,然后通過LeakyReLU激活函數(shù),接著進(jìn)行一個線性變換,再經(jīng)過一個LeakyReLU激活函數(shù),然后經(jīng)過線性變換將其變成784維,最后經(jīng)過Tanh激活函數(shù)是希望生成的假的圖片數(shù)據(jù)分布能夠在-1~1之間。

Discriminator Train

判別器的訓(xùn)練由兩部分組成,第一部分是真的圖像判別為真,第二部分是假的圖片判別為假,在這兩個過程中,生成器的參數(shù)不參與更新。

首先我們需要定義loss的度量方式和優(yōu)化函數(shù),loss度量使用二分類的交叉熵,油畫函數(shù)注意使用的學(xué)習(xí)率是0.0003

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

接著進(jìn)入訓(xùn)練

img = img.view(num_img, -1)  # 將圖片展開乘28x28=784
real_img = Variable(img).cuda()  # 將tensor變成Variable放入計算圖中
real_label = Variable(torch.ones(num_img)).cuda()  # 定義真實label為1
fake_label = Variable(torch.zeros(num_img)).cuda()  # 定義假的label為0

# compute loss of real_img
real_out = D(real_img)  # 將真實的圖片放入判別器中
d_loss_real = criterion(real_out, real_label)  # 得到真實圖片的loss  
real_scores = real_out  # 真實圖片放入判別器輸出越接近1越好

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 隨機生成一些噪聲
fake_img = G(z)  # 放入生成網(wǎng)絡(luò)生成一張假的圖片
fake_out = D(fake_img)  # 判別器判斷假的圖片
d_loss_fake = criterion(fake_out, fake_label)  # 得到假的圖片的loss
fake_scores = fake_out  # 假的圖片放入判別器越接近0越好

# bp and optimize
d_loss = d_loss_real + d_loss_fake  # 將真假圖片的loss加起來
d_optimizer.zero_grad()  # 歸0梯度
d_loss.backward()  # 反向傳播
d_optimizer.step()  # 更新參數(shù)

我已經(jīng)把每一步都注釋在了代碼上,這樣更加便于大家閱讀,這是一個判別器的訓(xùn)練過程,我們希望判別器能夠正確辨別出真假圖片。

Generative Train

在生成網(wǎng)絡(luò)的訓(xùn)練中,我們希望生成一張假的圖片,然后經(jīng)過判別器之后希望他能夠判斷為真的圖片,在這個過程中,我們將判別器固定,將假的圖片傳入判別器的結(jié)果與真實label對應(yīng),反向傳播更新的參數(shù)是生成網(wǎng)絡(luò)里面的參數(shù),這樣我們就可以通過跟新生成網(wǎng)絡(luò)里面的參數(shù)來使得判別器判斷生成的假的圖片為真,這樣就達(dá)到了生成對抗的作用。

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 得到隨機噪聲
fake_img = G(z)  # 生成假的圖片
output = D(fake_img)  # 經(jīng)過判別器得到結(jié)果
g_loss = criterion(output, real_label)  # 得到假的圖片與真實圖片label的loss

# bp and optimize
g_optimizer.zero_grad()  # 歸0梯度
g_loss.backward()  # 反向傳播
g_optimizer.step()  # 更新生成網(wǎng)絡(luò)的參數(shù)

這樣我們就寫好了一個簡單的生成網(wǎng)絡(luò),通過不斷地訓(xùn)練我們希望能夠生成很真的圖片。

Result

通過不斷訓(xùn)練,我們可以得到下面的圖片

這是真實圖片

real_images.png

第1幅為第一次生成的噪聲圖片,之后分別是跑完15次生成的圖片,跑完30次,跑完50次,跑完70次,最后一個是跑完100次生成的圖片

fake_images-1.png

fake_images-15.png

fake_images-30.png

fake_images-50.png

fake_images-70.png

fake_images-100.png

怎么樣,是不是特別神奇,我們居然可以生成一副看著很真的圖片,這里我們只是用了簡單的多層感知器來生成和判別模型,我們可以用更復(fù)雜的卷積神經(jīng)網(wǎng)絡(luò)來做同樣的事情,代碼將和本文的代碼放在一起,有興趣的同學(xué)可以自己去看看,然后放幾張卷積網(wǎng)絡(luò)生成的圖片

fake_images-1.png

fake_images-7.png

fake_images-12.png

可以發(fā)現(xiàn)產(chǎn)生的噪聲更少了,訓(xùn)練也更加穩(wěn)定,主要是里面引入了Batchnormalization,另外gan的訓(xùn)練過程是特別困難的,兩個對偶網(wǎng)絡(luò)相互學(xué)習(xí),這個時候有一些訓(xùn)練技巧可以使得訓(xùn)練生成更加穩(wěn)定,詳細(xì)見一下github

最后我們來說一下為何Gans能夠成為最近20年來機器學(xué)習(xí)以及深度學(xué)習(xí)界革命性的發(fā)現(xiàn)。這是因為不管是深度學(xué)習(xí)還是機器學(xué)習(xí)仍然很大一部分是監(jiān)督學(xué)習(xí),但是創(chuàng)建這么多有l(wèi)abel的數(shù)據(jù)集所需要的人力物力是極大的,同時遇到的新的任務(wù)時我們很容易得到原始的沒有l(wèi)abel的數(shù)據(jù)集,這是我們需要花大量的時間去給其標(biāo)定label,所以很多人都認(rèn)為無監(jiān)督學(xué)習(xí)才是機器學(xué)習(xí)的未來,這個時候Gans的出現(xiàn)為無監(jiān)督學(xué)習(xí)提供了有力的支持,這當(dāng)然引起了學(xué)界的大量關(guān)注,同時基于Gans的應(yīng)用也越來越多,業(yè)界對其也非??駸?。

最后引用Yan Lecun的話:"它(Gans)為創(chuàng)建無監(jiān)督學(xué)習(xí)模型提供了強有力的算法框架,有望幫助我們?yōu)?AI 加入常識(common sense)。我們認(rèn)為,沿著這條路走下去,有不小的成功機會能開發(fā)出更智慧的 AI 。"

以上我們簡單的介紹了Gans,通過網(wǎng)絡(luò)實現(xiàn)了手寫字體的生成,當(dāng)然還有更多的變形和應(yīng)用,有興趣的同學(xué)可以自己閱讀相關(guān)論文深入了解。

下一章我們將進(jìn)入pytorch教程的最后一個部分,也是和AI聯(lián)系最為緊密的一個部分,reinforcement learning,增強學(xué)習(xí)。


本文代碼已經(jīng)上傳到了github

歡迎查看我的知乎專欄,深度煉丹

歡迎訪問我的博客

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容