一、概述
??本文是論文Image-to-Image Translation with Conditional Adversarial Net-works的閱讀筆記。雖然年代已經(jīng)有些舊遠(yuǎn),但是像這種計(jì)算機(jī)視覺(jué)領(lǐng)域的經(jīng)典文獻(xiàn),尤其是pixel2pixel這樣面向一大類問(wèn)題的文章,其涉及的知識(shí)之廣度和深度,都是非常值得去學(xué)習(xí)和探究的。作者試圖用一個(gè)網(wǎng)絡(luò),來(lái)解決任意Image-to-Image Translation的問(wèn)題,并且在其中幾個(gè)有代表性的子問(wèn)題(如真實(shí)圖<->語(yǔ)義分割圖, 衛(wèi)星圖<->谷歌地圖,邊緣圖填充,圖像著色等)上達(dá)到了和領(lǐng)域內(nèi)一些專門定制的方法相媲美的效果,可謂以不變應(yīng)萬(wàn)變。
??以下是一些相關(guān)的論文插圖,包括一些近期的擴(kuò)展工作,比如2的成果(雖然在原始pix2pix基礎(chǔ)上專注于語(yǔ)義圖轉(zhuǎn)換,但是效果確實(shí)驚艷?。?/p>




??pix2pixHD論文中好多插圖對(duì)比實(shí)驗(yàn)甚至都不放原始pixel2pixel2的結(jié)果了...話不多說(shuō),只有打好基礎(chǔ),才能去學(xué)習(xí)最新的fancy方法!
??本文組織結(jié)構(gòu)如下:第二部分介紹原始pixel2pixel的論文筆記;第三部分是自己用Gluon在不同的數(shù)據(jù)集上訓(xùn)練pixel2pixel的結(jié)果。最后一部分總結(jié)這篇論文的收獲和問(wèn)題。
二、Image-to-Image Translation with Conditional Adversarial Networks
2.1 背景介紹
??許多圖像處理領(lǐng)域(本文主要以graph和vision兩個(gè)為例,解釋見(jiàn)下)都涉及從一個(gè)圖像到另一個(gè)圖像的轉(zhuǎn)換問(wèn)題。不同圖像之間轉(zhuǎn)換,如RGB圖,gradient field,edge map,semantic label map等等,此前都是用特定的算法去解決。而這些圖像其實(shí)本質(zhì)上具有相同的特征,即mapping from pixels to pixels。類比相同意思不同語(yǔ)言之間可以相互翻譯,這些相同圖像不同的表現(xiàn)形式之間也可以相互“translate”,故本文的任務(wù)設(shè)定為:
“Translating one possible representation of a scene into another, given sufficient training data.”
??如前面所述,在不同的graph(本文中主要指realistic photo synthesis)以及vision(如語(yǔ)義分割)任務(wù)中,單獨(dú)處理每個(gè)領(lǐng)域都需要通過(guò)設(shè)計(jì)不同的problem-dependent loss來(lái)指導(dǎo)網(wǎng)絡(luò)訓(xùn)練,我們能不能找到一個(gè)較通用的方法,來(lái)處理這一大類圖像轉(zhuǎn)換問(wèn)題呢?其實(shí)有一個(gè)現(xiàn)成的方法,簡(jiǎn)單粗暴——使用歐氏距離(或者說(shuō)squared-L2,MSE等)作為通用目標(biāo)函數(shù),通過(guò)訓(xùn)練逐漸減小預(yù)測(cè)輸出和groundtruth之間的差距。
2.2 MSE作為通用的目標(biāo)函數(shù)有什么不好?
?? 直觀上理解,如果生成的圖像和真實(shí)圖像之間的差距非常小,那么生成的結(jié)果應(yīng)該看起來(lái)非常逼真。然而事實(shí)上這種訓(xùn)練方法有兩個(gè)問(wèn)題,如下所述:

??上圖展示了一個(gè)預(yù)測(cè)視頻下一幀的例子。如果用實(shí)際的下一幀圖像和預(yù)測(cè)得到的下一幀圖像之間的均方誤差作為損失函數(shù)來(lái)訓(xùn)練網(wǎng)絡(luò),最后得到的結(jié)果如中間圖所示。產(chǎn)生這樣結(jié)果的原因在于,這一任務(wù)可能存在多個(gè)可能的正確結(jié)果(這里的多個(gè)可能結(jié)果能不能理解成,和groundtruth恰好具有相同MSE的輸出圖像組成的一個(gè)集合?),不同結(jié)果之間具有微小的差異,比如眼睛、耳朵的位置,人臉角度等。模型被要求只輸出一個(gè)結(jié)果,所以最終結(jié)果是多個(gè)可能結(jié)果的平均(average all plausible outputs)。我在另一篇博客中更詳細(xì)地介紹了L2 loss引入模糊的原因。
2.3 結(jié)構(gòu)化與非結(jié)構(gòu)化損失
??在許多圖像重建任務(wù)中,相比使用L2(MSE)作為通用損失函數(shù),還有另一個(gè)非常通用的概念,即論文中提到的"per-pixel classification or regression ",我自己的理解就是所謂的圖像回歸任務(wù)。當(dāng)然這個(gè)說(shuō)法非常寬泛,比如下圖的例子,輸入一個(gè)圖像,通過(guò)一個(gè)CNN,輸出一定空間維度的特征圖,然后使用一個(gè)loss函數(shù)約束輸出特征圖,使其逐步逼近我們想要得到的輸出(label圖),這就可以看做一個(gè)回歸任務(wù),有的教材將其視作分類任務(wù),如下圖,個(gè)人還是覺(jué)得看出回歸任務(wù)比較直觀,因?yàn)槊總€(gè)像素點(diǎn)的輸出值可以范圍非常大。這里也體現(xiàn)出分類和回歸其實(shí)并非涇渭分明。

??注意上面說(shuō)的是“一定空間維度的輸出”,上圖給的例子是輸入和輸出維度一致的。然而這種回歸的思想廣泛用于各種視覺(jué)領(lǐng)域,比如對(duì)于單張圖像樣本而言,輸出維度是輸入的倍數(shù)——超分辨率重建任務(wù);輸出維度是一固定長(zhǎng)度的向量,比如如 可以代表YOLO算法輸出的N個(gè)bbox的位置和概率信息(yolo將圖像回歸到一條長(zhǎng)向量上,雖然對(duì)于類別信息和位置信息回歸時(shí)使用不同的loss);再比如輸出維度是一定通道數(shù)目的高斯響應(yīng),如landmark回歸任務(wù)(要得到多少個(gè)key point 就輸出多少個(gè)通道,每個(gè)通道去學(xué)習(xí)一個(gè)以key point 位置為中心的高斯響應(yīng)圖)。
??扯得有點(diǎn)遠(yuǎn)。總之,計(jì)算機(jī)視覺(jué)領(lǐng)域很多看起來(lái)比較復(fù)雜的任務(wù)其實(shí)都可以歸結(jié)為回歸任務(wù),而不是非得擬合曲線那種線性回歸。使用簡(jiǎn)單的圖像回歸模式來(lái)formulate大部分的Image-to-image translation,雖然是一種較通用的方法,但是這里面有一個(gè)問(wèn):潛在地任務(wù)輸出空間的每一個(gè)像素,在給定的輸入下,是條件獨(dú)立的。這被稱為一種非結(jié)構(gòu)化(unstructured)的學(xué)習(xí)方式,簡(jiǎn)單來(lái)說(shuō)就是忽略了像素與像素之間的聯(lián)系。與之相對(duì)應(yīng)的是結(jié)構(gòu)化損失函數(shù)(structured loss),比如條件隨機(jī)場(chǎng)CRF,Perceptual loss, Feature matching 等一些暫時(shí)還沒(méi)研究過(guò)的算法。這些不同的結(jié)構(gòu)化損失函數(shù)都是為了懲罰輸入和輸出結(jié)構(gòu)上的差異,而不僅僅是像素級(jí)差異?!径疚奶岢龅幕赾GAN的方法和上面這些不同之處在于,其可以根據(jù)需要自動(dòng)學(xué)習(xí)這些結(jié)構(gòu)化的損失函數(shù)?!?/p>
問(wèn)題:如何理解“GAN可以自動(dòng)學(xué)習(xí)損失函數(shù)” 這句話?(如論文中說(shuō)的:
"GANs not only learn the mapping from input image to output image, but also learn a loss function to train this mapping.")
2.4 設(shè)計(jì)網(wǎng)絡(luò)架構(gòu)
??本文中在介紹網(wǎng)絡(luò)架構(gòu)的同時(shí)介紹了很多有用的信息點(diǎn)。一些我認(rèn)為很有用的點(diǎn)被列在下面,順帶介紹下網(wǎng)絡(luò)結(jié)構(gòu)以及訓(xùn)練中的事項(xiàng)。
??首先一個(gè)問(wèn)題是,為什么要使用cGAN,而不是直接用GAN生成。我們知道,單純的GAN屬于隱式的生成模型,不會(huì)顯式的進(jìn)行概率密度估計(jì),而是從訓(xùn)練集所屬的數(shù)據(jù)分布中隨機(jī)采樣。但是由于在高維空間中采樣難度較大,所以我們采取一種間接的方式:在低維空間中采樣,如一維的高斯噪聲。如果能建立一種映射關(guān)系,把在低維度采樣和在高維度采樣這兩件事聯(lián)系起來(lái),就能通過(guò)在低維采樣來(lái)代替高維采樣。這個(gè)映射關(guān)系自然由萬(wàn)能近似器——神經(jīng)網(wǎng)絡(luò)來(lái)學(xué)習(xí)。如上述所說(shuō),GAN這種setting的第一個(gè)問(wèn)題,在于無(wú)法點(diǎn)對(duì)點(diǎn)的進(jìn)行Image Translation任務(wù)。因?yàn)槲覀冏罱K得到的只是一些采樣的結(jié)果,而無(wú)法根據(jù)輸入生成想要的輸出。
??另一個(gè)問(wèn)題在于,GAN生成的圖像可能會(huì)無(wú)視輸入圖像的不同,而輸出較單一的模式。接下來(lái)我們具體來(lái)看下論文中提出的網(wǎng)絡(luò)架構(gòu)有哪些有價(jià)值的信息:
2.4.1 生成器
- 整體架構(gòu)使用Unet。原因在于很多圖像轉(zhuǎn)換任務(wù)(比如圖像上色)不只是需要語(yǔ)義信息,還需要網(wǎng)絡(luò)保留一些低級(jí)信息,如顏色、輪廓等。故采用Unet對(duì)稱的跳躍連接結(jié)構(gòu)將低層信息直接復(fù)制到高層特征圖上。Unet實(shí)際上就是加了跳躍連接的AutoEncoder結(jié)構(gòu)。
- 在Unet-Block的設(shè)計(jì)中參考了DCGAN的結(jié)構(gòu),如不使用pooling,使用LeakeyRelu作為激活函數(shù)等。詳見(jiàn)后面的代碼。
- 【很多條件生成對(duì)抗網(wǎng)絡(luò)不僅僅輸入先驗(yàn)
,往往還會(huì)引入噪聲
,一起作為輸入。這部分我沒(méi)看懂...說(shuō)是如果不引入一些噪聲,網(wǎng)絡(luò)的輸出會(huì)變成確定的。所以要引入一點(diǎn)噪聲。然后作者又說(shuō)他們嘗試了在輸入引入高斯噪聲,但是發(fā)現(xiàn)沒(méi)效果。所以最后網(wǎng)絡(luò)架構(gòu)中使用的是Dropout噪聲(可能是因?yàn)镈ropout引入了隨機(jī)性,所以可以看做某種噪聲?)】

2.4.2 判別器
- 輔助的L1 loss與patchGAN 判別器
論文中提到,雖然L2 loss這類損失函數(shù)會(huì)引入模糊,但是可以精確地捕捉到低頻信息。故本文使用L1 loss作為輔助來(lái)學(xué)習(xí)低頻信息。這樣還有另一個(gè)好處,就是GAN的判別器部分只需要學(xué)習(xí)高頻信息即可。這樣判別器只需要關(guān)注圖像局部的結(jié)構(gòu),而不需要看到全圖的信息才能做出判斷。具體來(lái)說(shuō),本文使用一種叫patchGAN的判別器結(jié)構(gòu),相比用來(lái)判斷整張生成的圖像是否為real,patchGAN一次只判斷一個(gè)N×N的圖像塊是否為real,就好像拿著這樣一個(gè)分類器在一張圖像上滑動(dòng),每個(gè)小塊對(duì)應(yīng)的結(jié)果求平均,作為這一張圖像的判別結(jié)果。這樣也減小了訓(xùn)練中的參數(shù)量。 - 論文中提到,這種patchGAN的設(shè)計(jì)實(shí)際上是把圖像建模成馬爾科夫隨機(jī)場(chǎng),即一個(gè)距離超過(guò)一個(gè)patch的像素是獨(dú)立的,距離在一個(gè)patch之內(nèi)的像素點(diǎn)之間才有關(guān)聯(lián)。
- 在優(yōu)化D網(wǎng)絡(luò)時(shí)乘以一個(gè)因子0.5,使得D網(wǎng)絡(luò)相對(duì)G網(wǎng)絡(luò)的學(xué)習(xí)速率要慢一些。(注意,實(shí)際操作時(shí)只需將D_loss乘以0.5即可實(shí)現(xiàn)。)
- 采用常見(jiàn)的全卷積輸出(Global Average Pooling 后接 Dense)。
2.4.3 訓(xùn)練設(shè)置
- BN和Dropout的不同設(shè)置:
Dropout:測(cè)試集也使用;
BatchNorm:測(cè)試階段不使用訓(xùn)練時(shí)的moving average,而直接用測(cè)試集的均值; - BatchSize較小,1-10(當(dāng)設(shè)為1時(shí),BN稱為instance normalization);
- 使用Adam train時(shí),第一動(dòng)量
的值設(shè)為0.5,而一般深度學(xué)習(xí)平臺(tái)的默認(rèn)值是0.9
- PatchGAN中Patch設(shè)置:70。
- GAN loss和L1 loss的比例設(shè)置:L1 loss乘以100
2.5 其他
??本文也對(duì)評(píng)價(jià)生成圖像質(zhì)量的指標(biāo)進(jìn)行了探索。其中一種評(píng)價(jià)方法是將生成的圖像用在對(duì)應(yīng)的vision任務(wù)上,如識(shí)別或者分割,然后用這些任務(wù)的評(píng)價(jià)指標(biāo)來(lái)間接評(píng)價(jià)生成的圖像質(zhì)量。具體來(lái)說(shuō)本文用到了一個(gè)FCN-scores。后續(xù)可能需要進(jìn)一步了解。
代碼實(shí)現(xiàn):
??代碼是學(xué)習(xí)Gluon GAN相關(guān)教程的一個(gè)pix2pix的練習(xí)。注意中文版Gluon教程后面好像沒(méi)有GAN,我當(dāng)時(shí)看的是英文版教程。
# coding: utf-8
import os
import pdb
import time
import logging
import mxnet as mx
import numpy as np
from os.path import join
from mxnet.gluon import nn
import matplotlib.pyplot as plt
from mxnet import gluon, nd, autograd
from mxnet.gluon.nn import Conv2D, LeakyReLU, BatchNorm, Dropout, Activation, Conv2DTranspose
from datetime import datetime
# ### Step 0: Data Preprocessing
# * 這里使用和上一節(jié)類似的`os.walk + mx.io.NDArrayIter`的套路
def preprocess_single_img(img):
assert isinstance(img, mx.ndarray.ndarray.NDArray), "input must be NDArray type."
img = mx.image.imresize(img, 2 * img_wd, img_ht)
img_in = img[:,:img_wd].transpose((2,0,1)).expand_dims(0)
img_out = img[:,img_wd:].transpose((2,0,1)).expand_dims(0)
assert img_in.shape==(1, 3, 256, 256), "image shape not correct."
return img_in, img_out
def load_data(data_path, batch_size, reverse=False):
img_in_list, img_out_list = [], []
for path, _, files in os.walk(data_path):
for file in files:
if not file[-4:] in ['.jpg']:
continue
img_arr = mx.image.imread(join(path, file)).astype(np.float32)/127.5 - 1
img_in, img_out = preprocess_single_img(img_arr)
if not reverse:
img_in_list.append(img_in)
img_out_list.append(img_out)
else:
img_in_list.append(img_out)
img_out_list.append(img_in)
return mx.io.NDArrayIter(data = [nd.concatenate(img_in_list), nd.concatenate(img_out_list)], batch_size=batch_size)
def visualize(img_arr):
plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
plt.axis('off')
def show_samples(data_iter, num_samples=4):
img_in_list, img_out_list = data_iter.next().data
for i in range(num_samples):
plt.subplot(2, num_samples, i+1)
visualize(img_in_list[i])
plt.subplot(2, num_samples, i+num_samples+1)
visualize(img_out_list[i])
plt.show()
# ### Step 2 Network Design
# #### 2.1 Unet網(wǎng)絡(luò)塊定義:
# 注意:
# ①同TensorFlow一樣,卷積層默認(rèn): use_bias=True
# ②Unet Block定義的基本思路:先定義好Encoder-Decoder結(jié)構(gòu),最后再hybrid_forward中將Encoder-Decoder輸入特征級(jí)聯(lián)到輸出特征即可
# ③**除了最內(nèi)層的Block,其他所有層,輸入Encoder的feature map的通道數(shù)都是輸入Decoder的feature map通道數(shù)的一半**
# ④BatchNorm層默認(rèn)參數(shù)設(shè)置momentum為0.9,而G和D中都設(shè)置為0.1?
class UnetSkipUnit(nn.HybridBlock):
def __init__(self, inner_channels, outer_channels, inner_block=None, innermost=False, outermost=False, use_dropout=False, use_bias=False):
super(UnetSkipUnit, self).__init__()
# 先定義一些基本的組件
self.outermost = outermost
en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1, in_channels=outer_channels, use_bias=use_bias)
en_relu = LeakyReLU(alpha=0.2)
en_bn = BatchNorm(momentum=0.1, in_channels=inner_channels)
deconv_innermost = Conv2DTranspose(
channels=outer_channels, kernel_size=4, strides=2, padding=1, in_channels=inner_channels, use_bias=use_bias)
deconv_output = Conv2DTranspose(
channels=outer_channels, kernel_size=4, strides=2, padding=1, in_channels=2*inner_channels, use_bias=True)
deconv_common = de_conv_innermost = Conv2DTranspose(
channels=outer_channels, kernel_size=4, strides=2, padding=1, in_channels=2*inner_channels, use_bias=use_bias)
de_relu = Activation('relu')
de_bn = BatchNorm(momentum=0.1, in_channels=outer_channels)
# Unet網(wǎng)絡(luò)塊可以分為三種:最里面的,中間的,最外面的。
if innermost:
encoder = [en_relu, en_conv]
decoder = [de_relu, deconv_innermost, de_bn]
model = encoder + decoder
elif outermost:
encoder = [en_conv]
decoder = [de_relu, deconv_output]
model = encoder + [inner_block] + decoder
model += [Activation('tanh')]
else:
encoder = [en_relu, en_conv, en_bn]
decoder = [de_relu, deconv_common, de_bn]
model = encoder + [inner_block] + decoder
if use_dropout:
model += [Dropout(0.5)]
self.model = nn.HybridSequential()
with self.model.name_scope():
for block in model:
self.model.add(block)
def hybrid_forward(self, F, x):
# 除了outermost之外的block都要加skip connection
if self.outermost:
return self.model(x)
else:
#pdb.set_trace()
return F.concat(self.model(x), x, dim=1)
class UnetGenerator(nn.HybridBlock):
def __init__(self, input_channels, num_downs, ngf=64, use_dropout=True):
super(UnetGenerator, self).__init__()
unet= UnetSkipUnit(ngf * 8, ngf * 8, innermost=True)
for _ in range(num_downs - 5):
unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout)
unet = UnetSkipUnit(ngf * 8, ngf * 4, unet)
unet = UnetSkipUnit(ngf * 4, ngf * 2, unet)
unet = UnetSkipUnit(ngf * 2, ngf * 1, unet)
unet = UnetSkipUnit(ngf, input_channels, unet, outermost=True)
self.model = unet
def hybrid_forward(self, F, x):
return self.model(x)
class Discriminator(nn.HybridBlock):
def __init__(self, in_channels, n_layers=3, ndf=64, use_sigmoid=False, use_bias=False):
super(Discriminator, self).__init__()
# 用下面一段代碼來(lái)配置標(biāo)準(zhǔn)的2x 下采樣卷積
kernel_size=4
padding = int(np.ceil((kernel_size-1)/2))
self.model = nn.HybridSequential()
# 先用一個(gè)卷積將輸入轉(zhuǎn)為第一層feature map
self.model.add(Conv2D(channels=ndf, kernel_size=kernel_size, strides=2, padding=padding, use_bias=use_bias, in_channels=in_channels))
self.model.add(LeakyReLU(alpha=0.2))
nf_mult = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
self.model.add(
Conv2D(channels=ndf*nf_mult, kernel_size=kernel_size, strides=2, padding=padding, use_bias=use_bias, in_channels=ndf*nf_mult_prev),
BatchNorm(momentum=0.1, in_channels=ndf*nf_mult),
LeakyReLU(alpha=0.2))
# 若layers較少,channel未達(dá)到512, 可以繼續(xù)升一點(diǎn)維度
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
self.model.add(
Conv2D(channels=ndf*nf_mult, kernel_size=kernel_size, strides=1, padding=padding, use_bias=use_bias, in_channels=ndf*nf_mult_prev),
BatchNorm(momentum=0.1, in_channels=ndf*nf_mult),
LeakyReLU(alpha=0.2))
# 輸出: output channel為什么設(shè)為1?
self.model.add(Conv2D(channels=1, kernel_size=kernel_size, strides=1, padding=padding, use_bias=True, in_channels=ndf*nf_mult))
if use_sigmoid:
self.model.add(Activation('sigmoid'))
def hybrid_forward(self, F, x):
return self.model(x)
# ### 2.4 Construct Network
# 注意:
# ①這里的loss使用binary_cross_entropy + L1 loss 作為最終的loss。L1 loss用來(lái)capture 圖像中的low frequencies
# ②使用自定義的初始化方式: (這里說(shuō)的初始化均為實(shí)值初始化,而不是僅僅定義初始化方式)
# - 卷積層:
# - $weight$: 標(biāo)準(zhǔn)差為0.02的高斯隨機(jī)初始化
# - $bias$: 全零初始化
# - BN層:
# - 除了$gamma$之外,所有的bn參數(shù)($beta, running__mean, running__var$)初始化為0; $gamma$: **均值為1**,標(biāo)準(zhǔn)差0.02的高斯隨機(jī)初始化
#
# ③這里設(shè)置的Trainer中的beta1參數(shù)是bn中的嗎?bn中的beta不應(yīng)該是參數(shù)而不是超參數(shù)嗎? 答:是Adam中的第一動(dòng)量 $β_1$
def init_param(param):
if param.name.find('conv') != -1: # conv層的參數(shù),包括w和b
if param.name.find('weight') != -1:
param.initialize(init=mx.init.Normal(0.02), ctx=ctx)
else:
param.initialize(init=mx.init.Zero(), ctx=ctx)
elif param.name.find('batchnorm') != -1: #bn層的參數(shù)
param.initialize(init=mx.init.Zero(), ctx=ctx)
if param.name.find('gamma')!=-1:
param.set_data(nd.random_normal(1, 0.02, param.data().shape))
def network_init(net):
for param in net.collect_params().values():
init_param(param)
# 正式定義網(wǎng)絡(luò)架構(gòu)
def set_networks(num_downs=8, n_layers=3, ckpt=None):
netG = UnetGenerator(input_channels=3, num_downs=8)
netD = Discriminator(in_channels=6, n_layers=3)
if ckpt is not None:
print('[+]Loading Checkpoints {} ...'.format(ckpt))
netG.load_parameters(ckpt+'G.params', ctx=ctx)
netD.load_parameters(ckpt+'D.params', ctx=ctx)
print('[+]Checkpoint loaded successfully!')
else:
network_init(netG)
network_init(netD)
trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate':lr, 'beta1':beta1})
trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate':lr, 'beta1':beta1})
return netG, netD, trainerG, trainerD
###################### Set loss function #######################
# ## Step 3: Training Loop
# ### 3.1 為判別模型專門定義一個(gè)ImagePool,使得判別模型不僅僅比較當(dāng)前的真實(shí)輸入和虛假輸出的損失,還要考慮歷史損失
# * 理解:
#
# 首先在pool滿之前,讀入的每張圖像都會(huì)被存儲(chǔ)在pool的images成員變量中。同時(shí)也會(huì)返回一份給ret,用于傳遞到函數(shù)外面。
# pool中只能存50張images,很快就會(huì)被占滿。當(dāng)pool滿了以后,再query一個(gè)樣本時(shí),pool可能以百分之五十的幾率選擇如下兩種操作中的一個(gè):
#
# ①使用讀入的image替換掉images列表中的隨機(jī)一張,替換得到的images中的old image被分給ret,隨后返回。
# ②新的image被加入到ret中,pool中的images列表不更新
# * 問(wèn)題:
#
# ①ImagePool的作用是什么?
# ②pool會(huì)對(duì)每張圖像進(jìn)行qurey操作,最起碼有一些nd運(yùn)算。這會(huì)對(duì)訓(xùn)練的迭代速度產(chǎn)生多大的影響?
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
if self.pool_size > 0:
self.num_imgs = 0
self.images = []
def query(self, images):
if self.pool_size == 0:
return images
ret_imgs = []
for i in range(images.shape[0]):
image = nd.expand_dims(images[i], axis=0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
ret_imgs.append(image)
else:
p = nd.random_uniform(0, 1, shape=(1,)).asscalar()
if p > 0.5:
random_id = nd.random_uniform(0, self.pool_size - 1, shape=(1,)).astype(np.uint8).asscalar()
tmp = self.images[random_id].copy()
self.images[random_id] = image
ret_imgs.append(tmp)
else:
ret_imgs.append(image)
ret_imgs = nd.concat(*ret_imgs, dim=0)
return ret_imgs
def facc(label, pred):
return ((pred.ravel()>0.5) == (label.ravel())).mean()
def train(lamda=100, lr_decay=0.2, period=50, ckpt='.', viz=False):
image_pool = ImagePool(pool_size)
metric = mx.metric.CustomMetric(facc)
stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
logging.basicConfig(level=logging.DEBUG)
#fig = plt.figure()
for epoch in range(num_epochs):
epoch_tic = time.time()
btic = time.time()
train_data.reset()
for iter, batch in enumerate(train_data):
real_in, real_out = batch.data[0].as_in_context(ctx), batch.data[1].as_in_context(ctx)
fake_out = netG(real_in)
fake_concat = image_pool.query(nd.Concat(real_in, fake_out, dim=1))
with autograd.record():
# Train with fake images
output = netD(fake_concat) #?????????????????? 這里把x和fake一同送入D,是Conditional GAN的體現(xiàn)?如何理解這里的條件概率?
fake_label = nd.zeros(output.shape, ctx=ctx)
errD_fake = GAN_loss(output, fake_label)
metric.update([fake_label,],[output,]) ## metric應(yīng)該何時(shí)update???
# Train with real images
real_concat = image_pool.query(nd.Concat(real_in, real_out, dim=1))
output = netD(real_concat)
real_label = nd.ones(output.shape, ctx=ctx)
errD_real = GAN_loss(output, real_label)
errD = (errD_fake + errD_real) * 0.5 ## 如論文所述,D loss乘以0.5以降低相對(duì)G的更新速率
errD.backward()
metric.update([real_label,],[output,])
trainerD.step(batch_size)
with autograd.record():
fake_out = netG(real_in) # 這里的G為什么沒(méi)有體現(xiàn)出Conditional GAN?? ####### 重要 #######
#fake_concat = image_pool.query(nd.Concat(real_in, fake_out, dim=1))
# 注意:image_pool只用于記錄判別器
fake_concat = nd.Concat(real_in, fake_out) # Conditional GAN的先驗(yàn):real_in,即 x
output = netD(fake_concat)
errG = GAN_loss(output, real_label) + lamda * L1_loss(real_out, fake_out)
errG.backward()
trainerG.step(batch_size)
if iter % 10 == 0:
name, acc = metric.get()
logging.info('Epoch {}, lr {:.6f}, D loss: {:3f}, G loss:{:3f}, binary training acc: {:2f}, at iter {}, Speed: {} samples/s'.format(
epoch, trainerD.learning_rate, errD.mean().asscalar(), errG.mean().asscalar(), acc, iter, 0.1*batch_size/ (time.time()-btic)))
btic = time.time()
if epoch % period == 0:
trainerD.set_learning_rate(trainerD.learning_rate * lr_decay)
trainerG.set_learning_rate(trainerG.learning_rate * lr_decay)
if epoch % 100 == 0:
print('[+]saving checkpoints to {}'.format(ckpt))
netG.save_parameters(join(ckpt, 'pixel_netG_epoch_{}.params'.format(epoch)))
netD.save_parameters(join(ckpt, 'pixel_netD_epoch_{}.params'.format(epoch)))
name, epoch_acc = metric.get()
metric.reset()
logging.info('\n[+]binary training accuracy at epoch %d %s=%f' % (epoch, name, epoch_acc))
logging.info('[+]time: {:3f}'.format(time.time() - epoch_tic))
if __name__=='__main__':
#### 超參數(shù)列表
ctx = mx.gpu(0)
lr = 0.001
batch_size = 10
beta1 = 0.5 ## beta_1(第一動(dòng)量)默認(rèn)設(shè)置是0.9,為什么這里差別也這么大???
pool_size = 50
num_epochs = 500
Dataset_Path= 'CMP_Dataset/facades/'
img_wd, img_ht = 256, 256
GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
L1_loss = gluon.loss.L1Loss()
netG, netD, trainerG, trainerD = set_networks(n_layers=2, ckpt='pixel_net')
train_data = load_data(join(Dataset_Path,'train'), batch_size, reverse=True)
val_data = load_data(join(Dataset_Path,'val'), batch_size, reverse=True)
train(lamda=100, lr_decay=0.8, period=50, ckpt='pix2pix/models')
print('[+]Training complete. Saving parameters...')
netG.save_parameters('pixel_netG.params')
netD.save_parameters('pixel_netD.params')
實(shí)驗(yàn)結(jié)果:
- 從下圖可見(jiàn)實(shí)驗(yàn)還是存在一些問(wèn)題的。雖然基本上重建出了真實(shí)圖像,但是有兩個(gè)明顯的缺點(diǎn):
①局部地區(qū)出現(xiàn)了偽影 ②重建的模式趨于單一模式。
以后有時(shí)間還會(huì)繼續(xù)調(diào)整參數(shù)進(jìn)行實(shí)驗(yàn)。如果有朋友做過(guò)類似實(shí)驗(yàn),還請(qǐng)不吝賜教~


三、參考
- https://arxiv.org/pdf/1611.07004.pdf
- https://www.researchgate.net/publication/321417929_High-Resolution_Image_Synthesis_and_Semantic_Manipulation_with_Conditional_GANs
- http://www.itdecent.cn/p/704b14752308
- 原教程鏈接:https://nbviewer.jupyter.org/github/zackchase/mxnet-the-straight-dope/blob/master/chapter14_generative-adversarial-networks/pixel2pixel.ipynb