STEP-8:Pytorch-從0實(shí)現(xiàn)DCGAN

感謝伯禹學(xué)習(xí)平臺(tái),本次學(xué)習(xí)將記錄記錄如何使用Pytorch高效實(shí)現(xiàn)網(wǎng)絡(luò),熟練掌握Pytorch的基礎(chǔ)知識(shí)。記錄不包含理論知識(shí)的細(xì)節(jié)展開。

DCGAN

DCGAN,實(shí)在GAN的基礎(chǔ)上,使用卷積網(wǎng)絡(luò)替換了原有的G和D中的全連接層,在圖像生成上具有較好的表現(xiàn)。

GAN的結(jié)構(gòu)示意圖
初始準(zhǔn)備

使用的輸入圖像的大小為(64,64,3)彩色的RGB圖像。在DCGAN中比較重要的一點(diǎn)是,G生成的圖像大小需要和真實(shí)圖片大小保持一致性,確定G的輸入大小后,可以設(shè)計(jì)G的基本結(jié)構(gòu)。

Generator

G的作用是將一個(gè)噪聲z使用反卷積的操作,將其拉伸到真實(shí)數(shù)據(jù)大小,反卷積的參數(shù)是可以學(xué)習(xí),所以與D的學(xué)習(xí)參數(shù)迭代,有了對(duì)抗。

import torch
import torchvision
import torch.nn as nn
# G_block,G與之前的網(wǎng)絡(luò)設(shè)計(jì)思想一致,將復(fù)用的模塊封裝為塊
class G_block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=4,strides=2, padding=1):
        super(G_block,self).__init__()
        self.conv2d_trans=nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                             stride=strides, padding=padding, bias=False)
        # 前幾層都是 trans_conv+bn+relu
        self.batch_norm=nn.BatchNorm2d(out_channels,0.8)
        self.activation=nn.ReLU()
    def forward(self,x):
        return self.activation(self.batch_norm(self.conv2d_trans(x)))

如上block中的反卷積輸入大小可以參考如下計(jì)算公式,在輸入n_k = n_h=16,k_h=k_w=4,p_h=p_w = 1,s_w=s_h=2的情況下輸出大小為(32,32),也就w,h增大一倍
\begin{aligned} n_h^{'} \times n_w^{'} &= [(n_h k_h - (n_h-1)(k_h-s_h)- 2p_h] \times [(n_w k_w - (n_w-1)(k_w-s_w)- 2p_w]\\ &= [(k_h + s_h (n_h-1)- 2p_h] \times [(k_w + s_w (n_w-1)- 2p_w]\\ &= [(4 + 2 \times (16-1)- 2 \times 1] \times [(4 + 2 \times (16-1)- 2 \times 1]\\ &= 32 \times 32 .\\ \end{aligned}
在G的網(wǎng)絡(luò)結(jié)構(gòu)中,它每層的輸出通道數(shù)時(shí)遞減的,這點(diǎn)與D剛好有點(diǎn)對(duì)稱的意思。

class net_G(nn.Module):
    def __init__(self,in_channels):
        super(net_G,self).__init__()

        n_G=64
        self.model=nn.Sequential(
            G_block(in_channels,n_G*8,strides=1,padding=0), 
            G_block(n_G*8,n_G*4),
            G_block(n_G*4,n_G*2),
            G_block(n_G*2,n_G),
            # 最后的輸出卷積層使用的激活函數(shù)為Tanh,具有較好的泛化能力
            nn.ConvTranspose2d(
                n_G,3,kernel_size=4,stride=2,padding=1,bias=False
            ),
            nn.Tanh()
        )
    def forward(self,x):
        x=self.model(x)
        return x

在給定輸入為(1,1)的噪聲情況下,不難驗(yàn)證輸出為(64,64)

Discriminator

D的作用是區(qū)分真假,生成對(duì)抗網(wǎng)絡(luò)是一個(gè)無(wú)監(jiān)督學(xué)習(xí),其標(biāo)簽區(qū)分只有0,1,分別對(duì)應(yīng)生成數(shù)據(jù)和真實(shí)數(shù)據(jù)。D的結(jié)構(gòu),類似于G的反過(guò)來(lái)的意思。

# 定義D的復(fù)用塊,使用正常的卷積網(wǎng)絡(luò)
class D_block(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=4,strides=2,
                 padding=1,alpha=0.2):
        super(D_block,self).__init__()
        self.conv2d=nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding,bias=False)
        # 這里用的時(shí)conv+bn+leakyrelu,論文中給出這樣有利于收斂。
        self.batch_norm=nn.BatchNorm2d(out_channels,0.8)
        self.activation=nn.LeakyReLU(alpha)
    def forward(self,X):
        return self.activation(self.batch_norm(self.conv2d(X)))

對(duì)應(yīng)的D

class net_D(nn.Module):
    def __init__(self,in_channels):
        super(net_D,self).__init__()
        n_D=64
        self.model=nn.Sequential(
            D_block(in_channels,n_D),
            D_block(n_D,n_D*2),
            D_block(n_D*2,n_D*4),
            D_block(n_D*4,n_D*8)
        )
        self.conv=nn.Conv2d(n_D*8,1,kernel_size=4,bias=False)
        # 最后使用的sigmoid激活,常的分類激活函數(shù)
        self.activation=nn.Sigmoid()
    def forward(self,x):
        x=self.model(x)
        x=self.conv(x)
        x=self.activation(x)
        return x

這里由于D輸入尺寸為為(64,64)所以其最后的輸出為(1,1)

如何訓(xùn)練這樣的網(wǎng)絡(luò)

在DCGAN中,網(wǎng)絡(luò)通常是迭代訓(xùn)練的,固定G訓(xùn)練D,固定D訓(xùn)練G。。。
訓(xùn)練過(guò)程中的損失計(jì)算參考如下代碼

def update_D(X,Z,net_D,net_G,loss,trainer_D):
    batch_size=X.shape[0]
    Tensor=torch.cuda.FloatTensor
    ones=Variable(Tensor(np.ones(batch_size,)),requires_grad=False).view(batch_size,1)
    zeros = Variable(Tensor(np.zeros(batch_size,)),requires_grad=False).view(batch_size,1)
    #訓(xùn)練D的時(shí)候,給原始圖1標(biāo)簽,生成圖0標(biāo)簽
    real_Y=net_D(X).view(batch_size,-1)
    fake_X=net_G(Z)
    fake_Y=net_D(fake_X).view(batch_size,-1)
    loss_D=(loss(real_Y,ones)+loss(fake_Y,zeros))/2
    loss_D.backward()
    trainer_D.step()
    return float(loss_D.sum())

def update_G(Z,net_D,net_G,loss,trainer_G):
    batch_size=Z.shape[0]
    Tensor=torch.cuda.FloatTensor
    ones=Variable(Tensor(np.ones((batch_size,))),requires_grad=False).view(batch_size,1)
    # 在訓(xùn)練G的時(shí)候我們需要給定生成圖1標(biāo)簽
    fake_X=net_G(Z)
    fake_Y=net_D(fake_X).view(batch_size,-1)
    loss_G=loss(fake_Y,ones)
    loss_G.backward()
    trainer_G.step()
    return float(loss_G.sum())

讀取數(shù)據(jù)訓(xùn)練參考如下

    for epo in range(epochs):
        for data in dataiter:
            z = ..
            d.zero_grad()
            update_D(...)
            g.zero_grad()
            update_G

總結(jié)

DCGAN 可以說(shuō)是入門GAN的開始,后面有很多基于GAN思路的改進(jìn),本質(zhì)上來(lái)說(shuō),GAN的思路使得神經(jīng)網(wǎng)絡(luò)具有了可控的創(chuàng)造性,但距離人的差距還是很大。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容