tensorflow實現(xiàn)深度卷積生成對抗網(wǎng)絡(DCGAN)生成手寫數(shù)字圖片

生成對抗網(wǎng)絡(GANs)是當今計算機科學領域最有趣的想法之一。兩個模型通過對抗過程同時訓練。一個生成器(“藝術家”)學習創(chuàng)造看起來真實的圖像,而判別器(“藝術評論家”)學習區(qū)分真假圖像。

訓練過程中,生成器在生成逼真圖像方面逐漸變強,而判別器在辨別這些圖像的能力上逐漸變強。當判別器不再能夠區(qū)分真實圖片和偽造圖片時,訓練過程達到平衡。

本筆記在 MNIST 數(shù)據(jù)集上演示了該過程。下方動畫展示了當訓練了 50 個epoch (全部數(shù)據(jù)集迭代50次) 時生成器所生成的一系列圖片。圖片從隨機噪聲開始,隨著時間的推移越來越像手寫數(shù)字。


dcgan.gif

本教程演示了如何使用深度卷積生成對抗網(wǎng)絡(DCGAN)生成手寫數(shù)字圖片。該代碼是使用 Keras Sequential API 與 tf.GradientTape 訓練循環(huán)編寫的。

生成器和判別器均使用 Keras Sequential API 定義。

生成器 生成器使用 tf.keras.layers.Conv2DTranspose (上采樣)層來從種子(隨機噪聲)中產(chǎn)生圖片。以一個使用該種子作為輸入的 Dense 層開始,然后多次上采樣直到達到所期望的 28x28x1 的圖片尺寸。注意除了輸出層使用 tanh 之外,其他每層均使用 tf.keras.layers.LeakyReLU 作為激活函數(shù)。

import tensorflow as tf 
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 將圖片標準化到 [-1, 1] 區(qū)間內(nèi)

BUFFER_SIZE = 60000
BATCH_SIZE = 256
# 批量化和打亂數(shù)據(jù)
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 沒有限制

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)

# 該方法返回計算交叉熵損失的輔助函數(shù)
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                discriminator_optimizer=discriminator_optimizer,
                                generator=generator,
                                discriminator=discriminator)

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16


# 我們將重復使用該種子(因此在動畫 GIF 中更容易可視化進度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

# 注意 `tf.function` 的使用
# 該注解使函數(shù)被“編譯”
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

    # 繼續(xù)進行時為 GIF 生成圖像
        display.clear_output(wait=True)
        generate_and_save_images(generator,
                            epoch + 1,
                            seed)

    # 每 15 個 epoch 保存一次模型
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)
        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# 最后一個 epoch 結(jié)束后生成圖片
    display.clear_output(wait=True)
    generate_and_save_images(generator,epochs,seed)

def generate_and_save_images(model, epoch, test_input):
    # 注意 training` 設定為 False
    # 因此,所有層都在推理模式下運行(batchnorm)。
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4,4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()



train(train_dataset, EPOCHS)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# 使用 epoch 數(shù)生成單張圖片
def display_image(epoch_no):
    return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))

display_image(EPOCHS)

輸出結(jié)果:


image.png

Out[57]:

image.png

實際例子代碼下載:https://github.com/wennaz/Deep_Learning

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容