Here we will build a Generative Adversarial Network (GAN), train it on MNIST, and generate new handwritten digits at the end.
First, a couple of demos:
Pix2pix: essentially, you sketch something and it generates a similar-looking image.

CycleGAN: in this video, horses are made to look like zebras.

The idea behind a GAN is that you have a generator and a discriminator locked in a game. The generator produces fake images (fake data) and tries to make them look like real data, while the discriminator tries to tell whether the data it sees is real or fake. So the generator passes fake data to the discriminator, you pass real data to the discriminator, and the discriminator judges which is which. As training proceeds, the generator learns to produce images and data that look as close to the real data as possible; in doing so, it ends up mimicking the probability distribution of the actual real data. This way, you can generate new images and new data that look like they came from the real world.
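This game can be summarized by the standard GAN minimax objective (from Goodfellow et al.; stated here for reference, not derived in this post):

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The discriminator $D$ maximizes this objective (correctly classifying real and fake), while the generator $G$ minimizes it (making $D(G(z))$ approach 1).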
Import the packages and the dataset:
%matplotlib inline
import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')
Model inputs
Here we create two inputs: inputs_real for the discriminator and inputs_z for the generator.
def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real')
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    return inputs_real, inputs_z

The figure above shows what the whole network looks like. The generator's input is z, which is just a random vector, a kind of random white noise. We feed it to the generator, which learns to transform this random vector z into an image through a tanh output layer. Since tanh outputs values between -1 and 1, we need to do some conversion work: rescale the MNIST images so their values lie between -1 and 1 before passing them to the discriminator network.
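The rescaling is a simple linear map; here is a small numpy illustration (not part of the notebook itself) of the same x*2 - 1 transformation used during training:

```python
import numpy as np

# MNIST pixels load as values in [0, 1]; tanh outputs lie in [-1, 1],
# so map [0, 1] -> [-1, 1] with x*2 - 1
pixels = np.array([0.0, 0.25, 0.5, 1.0])
rescaled = pixels * 2 - 1
print(rescaled)  # [-1.  -0.5  0.   1. ]
```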
Generator
def generator(z, out_dim, n_units=128, reuse=False, alpha=0.01):
    ''' Build the generator network.

        Arguments
        ---------
        z : Input tensor for the generator
        out_dim : Shape of the generator output
        n_units : Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU

        Returns
        -------
        out: tanh output of the generator
    '''
    with tf.variable_scope('generator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(z, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)
        # Logits and tanh output
        logits = tf.layers.dense(h1, out_dim, activation=None)
        out = tf.tanh(logits)
        return out
To use tf.variable_scope, declare with tf.variable_scope('scope_name', reuse=False):. Here we use generator as the scope name, so essentially all the variables inside will be prefixed with generator.
We take reuse from the function arguments; by default it is False. When it is True, the scope reuses the variables already created in this network. tf.layers.dense is a fully connected layer; you can use the layers module directly because it is high level and handles all the weight initialization for you.
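As a quick illustration of the leaky ReLU used above (a numpy sketch, not from the notebook): tf.maximum(alpha * h, h) leaves positive values unchanged and scales negative values by alpha, so negative units still pass a small gradient:

```python
import numpy as np

def leaky_relu(h, alpha=0.01):
    # Positive inputs pass through unchanged;
    # negative inputs are scaled by the leak factor alpha
    return np.maximum(alpha * h, h)

h = np.array([-2.0, -0.5, 0.0, 1.0, 3.0])
print(leaky_relu(h))  # [-0.02  -0.005  0.     1.     3.   ]
```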
Discriminator
The discriminator is built in much the same way as the generator.
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    ''' Build the discriminator network.

        Arguments
        ---------
        x : Input tensor for the discriminator
        n_units: Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU

        Returns
        -------
        out, logits: sigmoid output and raw logits
    '''
    with tf.variable_scope('discriminator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(x, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)
        # Logits and sigmoid output
        logits = tf.layers.dense(h1, 1, activation=None)
        out = tf.sigmoid(logits)
        return out, logits
Hyperparameters
# Size of input image to discriminator
input_size = 784 # 28x28 MNIST images flattened
# Size of latent vector to generator
z_size = 100
# Sizes of hidden layers in generator and discriminator
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# Label smoothing
smooth = 0.1
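To see what label smoothing does (a toy numpy illustration, not part of the notebook): instead of training the discriminator against hard targets of 1.0 for real images, we use 1 - smooth = 0.9, which discourages overconfident predictions:

```python
import numpy as np

smooth = 0.1  # same smoothing factor as above

# Hard targets for a batch of 4 "real" images
hard_labels = np.ones(4)

# Smoothed targets, matching tf.ones_like(d_logits_real) * (1 - smooth)
smooth_labels = hard_labels * (1 - smooth)
print(smooth_labels)  # [0.9 0.9 0.9 0.9]
```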
Build the network
tf.reset_default_graph()
# Create our input placeholders
input_real, input_z = model_inputs(input_size, z_size)
# Build the model
g_model = generator(input_z, input_size)
# g_model is the generator output
d_model_real, d_logits_real = discriminator(input_real)
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)
The discriminator uses the same weights for real and fake inputs, so reuse is set to True here.
Discriminator and generator losses
To train the discriminator and generator networks simultaneously, we need a loss for each. The discriminator's total loss is the sum of its losses on real and fake images.
As for the labels: for real images, we want the discriminator to know they are real, so their labels should all be 1. To help the discriminator generalize better, we apply an operation called label smoothing: we create a smooth parameter and use target values slightly less than 1. The fake-image discriminator loss is similar, with the labels all set to 0. Finally, for the generator loss we use d_logits_fake again, but this time with labels all set to 1: we want the generator to fool the discriminator into believing its fake images are real.
# Calculate losses
d_loss_real = tf.reduce_mean(
                  tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                          labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(
                  tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                          labels=tf.zeros_like(d_logits_fake)))
d_loss = d_loss_real + d_loss_fake

g_loss = tf.reduce_mean(
             tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                     labels=tf.ones_like(d_logits_fake)))
Optimizers
We want to update the generator and discriminator variables separately, so we first collect all trainable variables and split them by their scope-name prefix.
# Optimizers
learning_rate = 0.002
# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith('generator')]
d_vars = [var for var in t_vars if var.name.startswith('discriminator')]
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
Training
batch_size = 100
epochs = 100
samples = []
losses = []
# Only save generator variables
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)

            # Get images, reshape and rescale to pass to D
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images*2 - 1

            # Sample random noise for G
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))

            # Run optimizers
            _ = sess.run(d_train_opt, feed_dict={input_real: batch_images, input_z: batch_z})
            _ = sess.run(g_train_opt, feed_dict={input_z: batch_z})

        # At the end of each epoch, get the losses and print them out
        train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
        train_loss_g = g_loss.eval({input_z: batch_z})
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g))
        # Save losses to view after training
        losses.append((train_loss_d, train_loss_g))

        # Sample from generator as we're training for viewing afterwards
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        gen_samples = sess.run(
                          generator(input_z, input_size, reuse=True),
                          feed_dict={input_z: sample_z})
        samples.append(gen_samples)
        saver.save(sess, './checkpoints/generator.ckpt')

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)
Results

Improving the GAN
The GAN I showed you uses only a single hidden layer in both the generator and the discriminator. Its results are already quite good, but there are still many noisy images, and some don't look much like digits. It is entirely possible, however, to make the generator produce images nearly indistinguishable from the MNIST dataset.

This comes from a paper titled Improved Techniques for Training GANs. So how do they generate such good-looking images?
Batch normalization
A word of warning: with three layers, you may not be able to get the network to train well. It becomes very sensitive to the initial weight values and can fail to train at all. We can solve this with Batch Normalization. The idea is simple: just as we do for the network's inputs, we can normalize the inputs to each layer. That is, we scale the layer inputs so they have zero mean and standard deviation 1. Batch normalization has been found to be essential for building deep GANs.
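Here is a minimal numpy sketch of just the normalization step (the learnable scale/shift parameters and running averages of a full batch-norm implementation are omitted):

```python
import numpy as np

def batch_norm(x, epsilon=1e-5):
    # Normalize each feature over the batch dimension so it has
    # zero mean and (approximately) unit standard deviation
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return (x - mean) / np.sqrt(var + epsilon)

np.random.seed(0)
x = np.random.uniform(-1, 1, size=(4, 3))  # toy layer input: batch of 4, 3 features
x_norm = batch_norm(x)
print(x_norm.mean(axis=0))  # ~0 for every feature
print(x_norm.std(axis=0))   # ~1 for every feature
```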
If you're interested, see my earlier post on Batch Normalization.