Generative Adversarial Network

這里我們將建立 一個(gè)對(duì)抗生成網(wǎng)絡(luò) (GAN)訓(xùn)練MNIST,并在最后生成新的手寫數(shù)字。

這里先介紹幾個(gè)Demo:

Pix2pix 基本上就是你畫一個(gè)東西它就能生成類似的圖片


Pix2pix生成的貓

CycleGAN 這里視頻可以讓馬看起來(lái)像斑馬。

gan_diagram

GAN背后的思想是你有一個(gè)生成器和辨別器,它們都處在這樣的一個(gè)博弈中,生成器產(chǎn)生假圖像,比如假數(shù)據(jù),讓它看起來(lái)更像真數(shù)據(jù),然后辨別器努力辨識(shí)該數(shù)據(jù)是真或是假。所以生成器將假數(shù)據(jù)傳遞給辨別器,而你將真數(shù)據(jù)傳遞給辨別器,然后由辨別器判定它是真是假。當(dāng)你在訓(xùn)練時(shí),生成器會(huì)學(xué)習(xí)生成圖像和數(shù)據(jù),讓它們看起來(lái)盡量與真實(shí)數(shù)據(jù)一樣,在這個(gè)過(guò)程中它會(huì)模仿實(shí)際真實(shí)數(shù)據(jù)的概率分布,通過(guò)這種方式,你可以生成與真實(shí)世界中看起來(lái)一樣的新圖像、新數(shù)據(jù)。
這里導(dǎo)入包和數(shù)據(jù)集

%matplotlib inline

import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')

模型輸入

這里創(chuàng)建兩個(gè)輸入,辨別器的輸入為inputs_real,生成器的輸入為inputs_z。

def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32,(None ,real_dim),name ='input_real')
    inputs_z = tf.placeholder(tf.float32,(None,z_dim),name = 'input_z') 
    return inputs_real, inputs_z
gan_network

上圖顯示了整個(gè)網(wǎng)絡(luò)的樣子,這里生成器輸入是我們的z,它只是一個(gè)隨機(jī)向量,一種隨機(jī)白噪聲,我們會(huì)將其傳入生成器,然后生成器學(xué)習(xí)如何將這個(gè)隨機(jī)向量Z轉(zhuǎn)變?yōu)閠anh層中的圖像,tanh的輸出范圍為-1到1,這意味我們需要做轉(zhuǎn)化工作,需要轉(zhuǎn)換MNIST,使其取值-1到1之間。然后再將其傳入到辨別器網(wǎng)絡(luò)。

生成器

def generator(z, out_dim, n_units=128, reuse=False,  alpha=0.01):
    ''' Build the generator network.
    
        Arguments
        ---------
        z : Input tensor for the generator
        out_dim : Shape of the generator output
        n_units : Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out, logits: 
    '''
    with tf.variable_scope('generator',reuse = reuse) :
        # Hidden layer
        h1 = tf.layers.dense(z,n_units,activation = None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1,h1)
        
        # Logits and tanh output
        logits = tf.layers.dense(h1,out_dim)
        out = tf.tanh(logits)
        
        return out

使用tf.variable_scope,需要聲明with tf.variable_scope('scope_name', reuse=False):這里我們使用generator作為域的名稱,所以基本上所有的變量都將以generator開(kāi)頭。
這里我們選擇重用,所以它將告訴作用域重用本網(wǎng)絡(luò)中的變量。那么,我們從函數(shù)參數(shù)中獲得reuse,默認(rèn)情況下它是False。tf.layers.dense是一個(gè)全連接層,你可以直接使用層模塊,因?yàn)樗歉呒?jí)的,它會(huì)為你執(zhí)行所有權(quán)重初始化。

辨別器

辨別器和生成器構(gòu)造方法差不多。

def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    ''' Build the discriminator network.
    
        Arguments
        ---------
        x : Input tensor for the discriminator
        n_units: Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out, logits: 
    '''
    with tf.variable_scope('discriminator',reuse = reuse):
        # Hidden layer
        h1 =tf.layers.dense(x,n_units,activation = None)
        # Leaky ReLU
        h1 =tf.maximum(alpha * h1,h1)
        
        logits = tf.layers.dense(h1,1,activation = None)
        out =tf.sigmod(logits)
        
        return out, logits

超參數(shù)

# Size of input image to discriminator
input_size = 784 # 28x28 MNIST images flattened
# Size of latent vector to generator
z_size = 100
# Sizes of hidden layers in generator and discriminator
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# Label smoothing 
smooth = 0.1

構(gòu)建網(wǎng)絡(luò)

tf.reset_default_graph()
# Create our input placeholders
input_real, input_z = model_inputs(input_size, z_size)

# Build the model
g_model = generator(input_z, input_size)
# g_model is the generator output

d_model_real, d_logits_real = discriminator(input_real)
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

這里辨別器用相同的權(quán)重,所以reuse這里為true.

計(jì)算辨別器及生成器的損失

同時(shí)訓(xùn)練辨別器和生成器網(wǎng)絡(luò),我們需要這兩個(gè)不同網(wǎng)絡(luò)的損失。對(duì)辨別器總損失:是真實(shí)圖像和假圖像損失之和。
關(guān)于標(biāo)簽,對(duì)于真實(shí)圖像,我們想讓辨別器知道它們是真的,我們希望標(biāo)簽全部是1。為了幫助辨別器更好的泛化,我們要執(zhí)行一個(gè)叫做標(biāo)簽平滑的操作,創(chuàng)建一個(gè)smooth的參數(shù),略小于1。假數(shù)據(jù)辨別器損失也類似,設(shè)定這些標(biāo)簽全部為0。最后對(duì)于生成器,再次使用d_logits_fake,但這次我們的標(biāo)簽全部為1,我們想讓生成器欺騙辨別器,我們想讓辨別器認(rèn)為假圖像是真的

# Calculate losses
d_loss_real = tf.reduce_mean(
                  tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, 
                                                          labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(
                  tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, 
                                                          labels=tf.zeros_like(d_logits_real)))
d_loss = d_loss_real + d_loss_fake

g_loss = tf.reduce_mean(
             tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                     labels=tf.ones_like(d_logits_fake)))

優(yōu)化器

我們要分別更新生成器和辨別器變量,首先獲取所有可訓(xùn)練的變量

# Optimizers
learning_rate = 0.002

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith('generator')]
d_vars = [var for var in t_vars if var.name.startswith('discriminator')]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

訓(xùn)練

batch_size = 100
epochs = 100
samples = []
losses = []
# Only save generator variables
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            # Get images, reshape and rescale to pass to D
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images*2 - 1
            
            # Sample random noise for G
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            
            # Run optimizers
            _ = sess.run(d_train_opt, feed_dict={input_real: batch_images, input_z: batch_z})
            _ = sess.run(g_train_opt, feed_dict={input_z: batch_z})
        
        # At the end of each epoch, get the losses and print them out
        train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
        train_loss_g = g_loss.eval({input_z: batch_z})
            
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g))    
        # Save losses to view after training
        losses.append((train_loss_d, train_loss_g))
        
        # Sample from generator as we're training for viewing afterwards
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        gen_samples = sess.run(
                       generator(input_z, input_size, reuse=True),
                       feed_dict={input_z: sample_z})
        samples.append(gen_samples)
        saver.save(sess, './checkpoints/generator.ckpt')

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)

結(jié)果

改進(jìn)GAN

我向你展示的 GAN,在生成器和辨別器中只使用了一個(gè)隱藏層。這個(gè) GAN 的結(jié)果已經(jīng)非常不錯(cuò)了,但仍然有很多噪聲圖像,以及有些圖像看起來(lái)不太像數(shù)字。但是,要讓生成器生成的圖像與 MNIST 數(shù)據(jù)集幾乎一樣,是完全可能的。


這來(lái)自一篇題為 Improved Techniques for Training GANs 的文章。那么,它們?nèi)绾紊扇绱嗣烙^的圖像呢?

批歸一化

提醒一下,在三層情況下你可能無(wú)法使它很好地工作。網(wǎng)絡(luò)會(huì)變得對(duì)權(quán)重的初始值非常敏感,導(dǎo)致無(wú)法訓(xùn)練。我們可以使用 批歸一化(Batch Normalization) 來(lái)解決這個(gè)問(wèn)題。原理很簡(jiǎn)單。就像我們對(duì)網(wǎng)絡(luò)輸入的做法一樣,我們可以對(duì)每個(gè)層的輸入進(jìn)行歸一化。也就是說(shuō),縮放層輸入,使它具有零均值和標(biāo)準(zhǔn)差 1。經(jīng)發(fā)現(xiàn),批歸一化對(duì)于構(gòu)建深度 GAN 非常有必要。
歡迎大家看我以前寫的Batch Normalization

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容