https://github.com/leehomyc/cyclegan-1
https://junyanz.github.io/CycleGAN/
cyclegan_datasets.py
# Number of samples registered for each dataset split.
DATASET_TO_SIZES = {
    'horse2zebra_train': 1334,
    'horse2zebra_test': 140,
}

# The image types of each dataset. Currently only supports .jpg or .png
DATASET_TO_IMAGETYPE = {
    'horse2zebra_train': '.jpg',
    'horse2zebra_test': '.jpg',
}

# The path to the output csv file.
PATH_TO_CSV = {
    'horse2zebra_train': './input/horse2zebra/horse2zebra_train.csv',
    'horse2zebra_test': './input/horse2zebra/horse2zebra_test.csv',
}
數據保存在 ./input/horse2zebra,有四個目錄:trainA, trainB, testA, testB
create_cyclegan_dataset.py
"""Create datasets for training and testing."""
import csv
import os
import random
import click
import cyclegan_datasets
def create_list(foldername, fulldir=True, suffix=".jpg"):
    """List the entries of a folder whose names end with a given suffix.

    :param foldername: The full path of the folder.
    :param fulldir: Whether to return the full path or not.
    :param suffix: Filter by suffix.
    :return: The sorted list of filenames in the folder with given suffix.
    """
    # os.listdir() yields entries in arbitrary order; sorting makes the
    # result (and any dataset file built from it) reproducible across runs.
    matching = sorted(
        item for item in os.listdir(foldername) if item.endswith(suffix))
    if fulldir:
        return [os.path.join(foldername, item) for item in matching]
    return matching
@click.command()
@click.option('--image_path_a',
              type=click.STRING,
              default='./input/horse2zebra/trainA',
              help='The path to the images from domain_a.')
@click.option('--image_path_b',
              type=click.STRING,
              default='./input/horse2zebra/trainB',
              help='The path to the images from domain_b.')
@click.option('--dataset_name',
              type=click.STRING,
              default='horse2zebra_train',
              help='The name of the dataset in cyclegan_dataset.')
@click.option('--do_shuffle',
              type=click.BOOL,
              default=False,
              help='Whether to shuffle images when creating the dataset.')
def create_dataset(image_path_a, image_path_b,
                   dataset_name, do_shuffle):
    """Write the CSV file that pairs images of domain A with domain B.

    The CSV has DATASET_TO_SIZES[dataset_name] rows; each row holds one path
    from image_path_a and one from image_path_b. When a folder contains fewer
    images than rows, its images are reused cyclically.

    :param image_path_a: Folder with the images of domain A.
    :param image_path_b: Folder with the images of domain B.
    :param dataset_name: Key into the tables of cyclegan_datasets.
    :param do_shuffle: Whether to shuffle the rows before writing them.
    """
    image_type = cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name]
    list_a = create_list(image_path_a, True, image_type)
    list_b = create_list(image_path_b, True, image_type)

    # Without this guard an empty folder surfaces as a ZeroDivisionError
    # from the modulo below, which hides the real problem.
    if not list_a or not list_b:
        raise click.UsageError(
            'No %s images found in %s and/or %s.'
            % (image_type, image_path_a, image_path_b))

    output_path = cyclegan_datasets.PATH_TO_CSV[dataset_name]
    num_rows = cyclegan_datasets.DATASET_TO_SIZES[dataset_name]

    all_data_tuples = [
        (list_a[i % len(list_a)], list_b[i % len(list_b)])
        for i in range(num_rows)
    ]
    if do_shuffle:
        random.shuffle(all_data_tuples)

    # newline='' lets the csv module control line endings itself; without it
    # extra blank lines appear between rows on platforms that translate '\n'.
    with open(output_path, 'w', newline='') as csv_file:
        csv.writer(csv_file).writerows(all_data_tuples)
@click.command(), @click.option與argparse.ArgumentParser()作用相同。
layers.py
- lrelu
def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
    """Leaky ReLU activation.

    :param x: Input tensor.
    :param leak: Slope applied to negative inputs.
    :param name: Name of the variable scope the ops are created in.
    :param alt_relu_impl: If True, compute the activation as
        0.5 * (1 + leak) * x + 0.5 * (1 - leak) * |x| instead of
        max(x, leak * x).
    :return: The activated tensor.
    """
    with tf.variable_scope(name):
        if not alt_relu_impl:
            return tf.maximum(x, leak * x)
        linear_coeff = 0.5 * (1 + leak)
        abs_coeff = 0.5 * (1 - leak)
        return linear_coeff * x + abs_coeff * abs(x)
兩種是等價的,但是第一種占用內存更少。
- instance normalization
def instance_norm(x):
    """Instance normalization with a learned per-channel scale and offset.

    Mean and variance are computed over axes 1 and 2 of `x` separately for
    every sample (assumes `x` is laid out as [batch, height, width,
    channels] -- confirm against callers).

    :param x: Input tensor.
    :return: scale * (x - mean) / sqrt(var + 1e-5) + offset.
    """
    with tf.variable_scope("instance_norm"):
        epsilon = 1e-5
        # One scale/offset entry per element of the last dimension.
        param_shape = [x.get_shape()[-1]]
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        scale = tf.get_variable(
            'scale', param_shape,
            initializer=tf.truncated_normal_initializer(
                mean=1.0, stddev=0.02))
        offset = tf.get_variable(
            'offset', param_shape,
            initializer=tf.constant_initializer(0.0))
        normalized = tf.div(x - mean, tf.sqrt(var + epsilon))
        return scale * normalized + offset
instance normalization使用單一圖片作為輸入,在GAN、style transfer這類任務上IN的實驗結論要優于BN,給出的普遍的闡述性解釋是:這類生成式方法,自己的風格比較獨立,不應該與batch中其他的樣本產生太大聯繫。
axis、scale、offset可以參考前一篇Batch Normalization的部分。
- 卷積層
def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,
                   padding="VALID", name="conv2d", do_norm=True, do_relu=True,
                   relufactor=0):
    """Convolution followed by optional instance norm and (leaky) ReLU.

    :param inputconv: Input tensor.
    :param o_d: Number of output channels.
    :param f_h: Filter height.
    :param f_w: Filter width.
    :param s_h: Stride along the height.
    :param s_w: Stride along the width.
    :param stddev: Standard deviation of the truncated-normal weight init.
    :param padding: Padding mode passed to the convolution.
    :param name: Variable scope of the layer.
    :param do_norm: Whether to apply instance_norm after the convolution.
    :param do_relu: Whether to apply an activation.
    :param relufactor: 0 selects a plain ReLU; any other value is used as
        the leak of lrelu.
    :return: The output tensor of the layer.
    """
    with tf.variable_scope(name):
        # Pass both filter and stride dimensions. Previously only f_w / s_w
        # were forwarded, so f_h / s_h were silently ignored (the transposed
        # counterpart general_deconv2d already forwards both).
        conv = tf.contrib.layers.conv2d(
            inputconv, o_d, [f_h, f_w], [s_h, s_w], padding,
            activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=stddev
            ),
            biases_initializer=tf.constant_initializer(0.0)
        )
        if do_norm:
            conv = instance_norm(conv)

        if do_relu:
            if relufactor == 0:
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv
tf.truncated_normal_initializer: 如果生成的值與平均值相差超過2個標準偏差,則丟棄並重新選擇。
- 反卷積層
def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                     stddev=0.02, padding="VALID", name="deconv2d",
                     do_norm=True, do_relu=True, relufactor=0):
    """Transposed convolution with optional instance norm and (leaky) ReLU.

    :param inputconv: Input tensor.
    :param outshape: Not used by this function; kept in the signature for
        existing callers.
    :param o_d: Number of output channels.
    :param f_h: Filter height.
    :param f_w: Filter width.
    :param s_h: Stride along the height.
    :param s_w: Stride along the width.
    :param stddev: Standard deviation of the truncated-normal weight init.
    :param padding: Padding mode passed to the transposed convolution.
    :param name: Variable scope of the layer.
    :param do_norm: Whether to apply instance_norm afterwards.
    :param do_relu: Whether to apply an activation.
    :param relufactor: 0 selects a plain ReLU; any other value is used as
        the leak of lrelu.
    :return: The output tensor of the layer.
    """
    with tf.variable_scope(name):
        deconv = tf.contrib.layers.conv2d_transpose(
            inputconv, o_d, [f_h, f_w], [s_h, s_w], padding,
            activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=stddev),
            biases_initializer=tf.constant_initializer(0.0))

        if do_norm:
            deconv = instance_norm(deconv)

        if not do_relu:
            return deconv
        if relufactor == 0:
            return tf.nn.relu(deconv, "relu")
        return lrelu(deconv, relufactor, "lrelu")
losses.py
"""Contains losses used for performing image-to-image domain adaptation."""
import tensorflow as tf
# L(G, F)
def cycle_consistency_loss(real_images, generated_images):
    """Compute the cycle consistency loss.

    The loss is the mean absolute (L1) difference between the real images
    and their reconstructed counterparts, averaged over every element of
    the batch.

    This definition is derived from Equation 2 in:
        Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
        Networks.
        Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros.

    Args:
        real_images: A batch of images from domain X, a `Tensor` of shape
            [batch_size, height, width, channels].
        generated_images: A batch of generated images made to look like they
            came from domain X, a `Tensor` of shape
            [batch_size, height, width, channels].

    Returns:
        The cycle consistency loss.
    """
    absolute_difference = tf.abs(real_images - generated_images)
    return tf.reduce_mean(absolute_difference)
def lsgan_loss_generator(prob_fake_is_real):
    """Computes the LS-GAN loss as minimized by the generator.

    Rather than compute the negative loglikelihood, a least-squares loss is
    used: the mean squared difference between the discriminator's output on
    generated images and the target value 1. See Equation 2 in:
        Least Squares Generative Adversarial Networks
        Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and
        Stephen Paul Smolley.
        https://arxiv.org/pdf/1611.04076.pdf

    Args:
        prob_fake_is_real: The discriminator's estimate that generated images
            made to look like real images are real.

    Returns:
        The total LS-GAN loss.
    """
    squared_error = tf.squared_difference(prob_fake_is_real, 1)
    return tf.reduce_mean(squared_error)
def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
    """Computes the LS-GAN loss as minimized by the discriminator.

    Rather than compute the negative loglikelihood, a least-squares loss is
    used: half the sum of the mean squared difference of the real-image
    outputs from 1 and of the fake-image outputs from 0. See Equation 2 in:
        Least Squares Generative Adversarial Networks
        Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and
        Stephen Paul Smolley.
        https://arxiv.org/pdf/1611.04076.pdf

    Args:
        prob_real_is_real: The discriminator's estimate that images actually
            drawn from the real domain are in fact real.
        prob_fake_is_real: The discriminator's estimate that generated images
            made to look like real images are real.

    Returns:
        The total LS-GAN loss.
    """
    real_term = tf.reduce_mean(tf.squared_difference(prob_real_is_real, 1))
    fake_term = tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 0))
    return (real_term + fake_term) * 0.5



model.py
- ResNet block
def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT"):
    """Build a single residual block.

    The block pads the input by one pixel on each side of axes 1 and 2,
    applies a 3x3 VALID convolution, repeats the pad + convolution once more
    without activation, adds the block input and applies a ReLU.

    :param inputres: Input tensor of the block.
    :param dim: Number of output channels of both convolutions.
    :param name: Variable scope of the block.
    :param padding: Mode passed to tf.pad. For tensorflow version use
        REFLECT; for pytorch version use CONSTANT.
    :return: a single block of resnet.
    """
    one_pixel_border = [[0, 0], [1, 1], [1, 1], [0, 0]]
    with tf.variable_scope(name):
        padded = tf.pad(inputres, one_pixel_border, padding)
        hidden = layers.general_conv2d(
            padded, dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
        padded = tf.pad(hidden, one_pixel_border, padding)
        residual = layers.general_conv2d(
            padded, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)
        return tf.nn.relu(residual + inputres)
ResNet Block:

Reflection Padding:
# Example: pad a 2x3 tensor with 1 row before/after and 2 columns
# before/after, using "REFLECT" mode. The expected result is shown in the
# trailing comments.
t = tf.constant([[1, 2, 3], [4, 5, 6]])
paddings = tf.constant([[1, 1,], [2, 2]])
tf.pad(t, paddings, "REFLECT") # [[6, 5, 4, 5, 6, 5, 4],
# [3, 2, 1, 2, 3, 2, 1],
# [6, 5, 4, 5, 6, 5, 4],
# [3, 2, 1, 2, 3, 2, 1]]
mode="REFLECT"是映射填充,上下(1維)填充順序和paddings是相反的,左右(零維)順序補齊。
- Generator
def build_generator_resnet_9blocks_tf(inputgen, name="generator", skip=False):
    """Build the generator with nine residual blocks.

    Layer sequence: reflect-pad + 7x7 conv (c1), two stride-2 3x3 convs
    (c2, c3), residual blocks r1..r9, two stride-2 transposed convs (c4, c5)
    and a final 7x7 conv (c6) without norm or ReLU, followed by tanh.

    :param inputgen: Input image tensor.
    :param name: Variable scope of the generator.
    :param skip: If True, the input is added to the last convolution's
        output before the tanh.
    :return: The generated image tensor.
    """
    with tf.variable_scope(name):
        f = 7
        ks = 3
        padding = "REFLECT"

        pad_input = tf.pad(
            inputgen, [[0, 0], [ks, ks], [ks, ks], [0, 0]], padding)

        # Encoder.
        net = layers.general_conv2d(
            pad_input, ngf, f, f, 1, 1, 0.02, name="c1")
        net = layers.general_conv2d(
            net, ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2")
        net = layers.general_conv2d(
            net, ngf * 4, ks, ks, 2, 2, 0.02, "SAME", "c3")

        # Nine residual blocks, scoped "r1" .. "r9".
        for block_index in range(1, 10):
            net = build_resnet_block(
                net, ngf * 4, "r" + str(block_index), padding)

        # Decoder.
        net = layers.general_deconv2d(
            net, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02,
            "SAME", "c4")
        net = layers.general_deconv2d(
            net, [BATCH_SIZE, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02,
            "SAME", "c5")
        net = layers.general_conv2d(net, IMG_CHANNELS, f, f, 1, 1,
                                    0.02, "SAME", "c6",
                                    do_norm=False, do_relu=False)

        if skip is True:
            return tf.nn.tanh(inputgen + net, "t1")
        return tf.nn.tanh(net, "t1")
"We use 6 blocks for 128 × 128 training images, and 9 blocks for 256 × 256 or higher-resolution training images.
Let c7s1-k denote a 7 × 7 Convolution-InstanceNorm-ReLU layer with k filters and stride 1.
dk denotes a 3 × 3 Convolution-InstanceNorm-ReLU layer with k filters, and stride 2.
Reflection padding was used to reduce artifacts.
Rk denotes a residual block that contains two 3 × 3 convolutional layers with the same number of filters on both layer.
uk denotes a 3 × 3 fractional-strided-Convolution-InstanceNorm-ReLU layer with k filters, and stride 1/2. The network with 6 blocks consists of:
c7s1-32,d64,d128,R128,R128,R128,R128,R128,R128,u64,u32,c7s1-3
The network with 9 blocks consists of:
c7s1-32,d64,d128,R128,R128,R128,R128,R128,R128,R128,R128,R128,u64,u32,c7s1-3"
- Discriminator
def discriminator_tf(inputdisc, name="discriminator"):
    """Build the discriminator: five 4x4 convolutions, scoped c1..c5.

    c1 uses no instance norm; c1..c4 use a leaky ReLU with factor 0.2;
    c1..c3 have stride 2, c4 and c5 stride 1; c5 outputs a single channel
    without norm or activation.

    :param inputdisc: Input image tensor.
    :param name: Variable scope of the discriminator.
    :return: The output tensor of layer c5.
    """
    with tf.variable_scope(name):
        f = 4

        net = layers.general_conv2d(inputdisc, ndf, f, f, 2, 2,
                                    0.02, "SAME", "c1", do_norm=False,
                                    relufactor=0.2)
        # (scope, channel multiplier, stride) of the normalized layers.
        for scope_name, multiplier, stride in (
                ("c2", 2, 2), ("c3", 4, 2), ("c4", 8, 1)):
            net = layers.general_conv2d(net, ndf * multiplier, f, f,
                                        stride, stride,
                                        0.02, "SAME", scope_name,
                                        relufactor=0.2)
        return layers.general_conv2d(
            net, 1, f, f, 1, 1, 0.02,
            "SAME", "c5", do_norm=False, do_relu=False
        )
"For discriminator networks, we use 70 × 70 PatchGAN [21]. Let Ck denote a 4 × 4 Convolution-InstanceNorm-LeakyReLU layer with k filters and stride 2. After the last layer, we apply a convolution to produce a 1 dimensional output. We do not use InstanceNorm for the first C64 layer. We use leaky ReLUs with slope 0.2. The discriminator architecture is:
C64-C128-C256-C512"
- PatchGAN
def patch_discriminator(inputdisc, name="discriminator"):
    """Build a discriminator that judges a random 70x70 crop of the input.

    A [1, 70, 70, 3] patch is cropped at random from `inputdisc` and passed
    through five 4x4 convolutions (c1..c5). c1..c4 have stride 2 and a leaky
    ReLU with factor 0.2; c1 has no instance norm; c5 outputs one channel
    without norm or activation.

    :param inputdisc: Input image tensor; the crop size implies a batch of 1
        with 3 channels.
    :param name: Variable scope of the discriminator.
    :return: The output tensor of layer c5.
    """
    with tf.variable_scope(name):
        f = 4

        patch_input = tf.random_crop(inputdisc, [1, 70, 70, 3])
        # do_norm must be the boolean False: the string "False" used here
        # before is truthy and therefore enabled instance norm on the first
        # layer, unlike discriminator_tf.
        o_c1 = layers.general_conv2d(patch_input, ndf, f, f, 2, 2,
                                     0.02, "SAME", "c1", do_norm=False,
                                     relufactor=0.2)
        o_c2 = layers.general_conv2d(o_c1, ndf * 2, f, f, 2, 2,
                                     0.02, "SAME", "c2", relufactor=0.2)
        o_c3 = layers.general_conv2d(o_c2, ndf * 4, f, f, 2, 2,
                                     0.02, "SAME", "c3", relufactor=0.2)
        o_c4 = layers.general_conv2d(o_c3, ndf * 8, f, f, 2, 2,
                                     0.02, "SAME", "c4", relufactor=0.2)
        o_c5 = layers.general_conv2d(
            o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False,
            do_relu=False)

        return o_c5
- 輸出
def get_outputs(inputs, network="tensorflow", skip=False):
    """Wire both generators and both discriminators into one graph.

    :param inputs: Dict with the tensors 'images_a', 'images_b',
        'fake_pool_a' and 'fake_pool_b'.
    :param network: "tensorflow" or "pytorch"; selects which generator and
        discriminator builders are used.
    :param skip: Forwarded to the generator builder.
    :return: Dict with the discriminator outputs on real, generated and
        pooled images, plus the generated and cycled images.
    :raises ValueError: If `network` is neither "pytorch" nor "tensorflow".
    """
    real_a = inputs['images_a']
    real_b = inputs['images_b']
    pool_a = inputs['fake_pool_a']
    pool_b = inputs['fake_pool_b']

    with tf.variable_scope("Model") as scope:
        if network == "tensorflow":
            make_discriminator = discriminator_tf
            make_generator = build_generator_resnet_9blocks_tf
        elif network == "pytorch":
            make_discriminator = discriminator
            make_generator = build_generator_resnet_9blocks
        else:
            raise ValueError(
                'network must be either pytorch or tensorflow'
            )

        # First pass creates the variables of d_A, d_B, g_A and g_B.
        d_a_on_real = make_discriminator(real_a, "d_A")
        d_b_on_real = make_discriminator(real_b, "d_B")
        generated_b = make_generator(real_a, name="g_A", skip=skip)
        generated_a = make_generator(real_b, name="g_B", skip=skip)

        # Every later call shares the variables created above.
        scope.reuse_variables()

        d_a_on_fake = make_discriminator(generated_a, "d_A")
        d_b_on_fake = make_discriminator(generated_b, "d_B")
        cycled_a = make_generator(generated_b, "g_B", skip=skip)
        cycled_b = make_generator(generated_a, "g_A", skip=skip)

        scope.reuse_variables()

        d_a_on_pool = make_discriminator(pool_a, "d_A")
        d_b_on_pool = make_discriminator(pool_b, "d_B")

        return {
            'prob_real_a_is_real': d_a_on_real,
            'prob_real_b_is_real': d_b_on_real,
            'prob_fake_a_is_real': d_a_on_fake,
            'prob_fake_b_is_real': d_b_on_fake,
            'prob_fake_pool_a_is_real': d_a_on_pool,
            'prob_fake_pool_b_is_real': d_b_on_pool,
            'cycle_images_a': cycled_a,
            'cycle_images_b': cycled_b,
            'fake_images_a': generated_a,
            'fake_images_b': generated_b,
        }
A: 真馬集 images_a(A) -> fake_images_b(fB) -> cycle_images_a
B: 真斑馬集 images_b(B) -> fake_images_a(fA) -> cycle_images_b
fA: 假馬集 fake_pool_a
fB: 假斑馬集 fake_pool_b
data_loader.py
- load sample
import tensorflow as tf
import cyclegan_datasets
import model
def _load_samples(csv_name, image_type):
    """Read one (A, B) pair of image paths from a CSV and decode both images.

    :param csv_name: Path of the CSV file; each line holds two file names.
    :param image_type: Either '.jpg' or '.png'; selects the decoder.
    :return: A tuple (image_decoded_A, image_decoded_B) of decoded image
        tensors with model.IMG_CHANNELS channels.
    :raises ValueError: If `image_type` is not '.jpg' or '.png'.
    """
    # Reject unsupported types up front; previously this fell through both
    # branches and failed at the return with an UnboundLocalError.
    if image_type not in ('.jpg', '.png'):
        raise ValueError(
            'image type %s is not supported; use .jpg or .png.' % image_type)

    filename_queue = tf.train.string_input_producer(
        [csv_name])

    reader = tf.TextLineReader()
    _, csv_filename = reader.read(filename_queue)

    # Both columns are required strings (empty default means "no default").
    record_defaults = [tf.constant([], dtype=tf.string),
                       tf.constant([], dtype=tf.string)]

    filename_i, filename_j = tf.decode_csv(
        csv_filename, record_defaults=record_defaults)

    file_contents_i = tf.read_file(filename_i)
    file_contents_j = tf.read_file(filename_j)
    if image_type == '.jpg':
        image_decoded_A = tf.image.decode_jpeg(
            file_contents_i, channels=model.IMG_CHANNELS)
        image_decoded_B = tf.image.decode_jpeg(
            file_contents_j, channels=model.IMG_CHANNELS)
    else:
        image_decoded_A = tf.image.decode_png(
            file_contents_i, channels=model.IMG_CHANNELS, dtype=tf.uint8)
        image_decoded_B = tf.image.decode_png(
            file_contents_j, channels=model.IMG_CHANNELS, dtype=tf.uint8)

    return image_decoded_A, image_decoded_B
和之前pixel2pixel的load_sample過程類似,只不過這裡reader是TextLineReader()(因為一行是一組文件)。在這裡對csv的讀取也是標準流程,參考:
http://www.itdecent.cn/p/d063804fb272
http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html
在調用run或者eval去執行read之前,必須調用tf.train.start_queue_runners來將文件名填充到隊列。否則read操作會被阻塞到文件名隊列中有值為止。
# Workflow
q = tf.train.string_input_producer([csv1, csv2, ...])
# string_input_producer opens several files at once; it explicitly creates
# the queue and implicitly creates the QueueRunner
reader = tf.TextLineReader()
content = reader.read(q)
record_defaults = [[], [], ...]
tf.decode_csv(content, record_defaults=record_defaults)
...
coord = tf.train.Coordinator()
# Create the coordinator
threads = tf.train.start_queue_runners(coord=coord)
# Start all queue threads in the graph
- load data
def load_data(dataset_name, image_size_before_crop,
              do_shuffle=True, do_flipping=False):
    """Build the input pipeline that yields batches of one image pair.

    Each image is resized to a square of `image_size_before_crop`, optionally
    flipped left/right at random, randomly cropped to
    [model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS] and rescaled with
    x / 127.5 - 1.

    :param dataset_name: The name of the dataset.
    :param image_size_before_crop: Resize to this size before random cropping.
    :param do_shuffle: Shuffle switch.
    :param do_flipping: Flip switch.
    :return: Dict with the single preprocessed images under 'image_i' and
        'image_j' and the batched tensors (batch size 1) under 'images_i'
        and 'images_j'.
    :raises ValueError: If `dataset_name` is not a known dataset.
    """
    if dataset_name not in cyclegan_datasets.DATASET_TO_SIZES:
        raise ValueError('split name %s was not recognized.'
                         % dataset_name)

    csv_name = cyclegan_datasets.PATH_TO_CSV[dataset_name]

    image_i, image_j = _load_samples(
        csv_name, cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])
    inputs = {
        'image_i': image_i,
        'image_j': image_j
    }

    # Preprocessing:
    resize_shape = [image_size_before_crop, image_size_before_crop]
    # Use the channel count the images were decoded with instead of a
    # hard-coded 3, so the crop stays valid if model.IMG_CHANNELS changes.
    crop_shape = [model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS]

    inputs['image_i'] = tf.image.resize_images(
        inputs['image_i'], resize_shape)
    inputs['image_j'] = tf.image.resize_images(
        inputs['image_j'], resize_shape)

    # Plain truthiness: `is True` silently ignored truthy non-bool flags.
    if do_flipping:
        inputs['image_i'] = tf.image.random_flip_left_right(inputs['image_i'])
        inputs['image_j'] = tf.image.random_flip_left_right(inputs['image_j'])

    inputs['image_i'] = tf.random_crop(inputs['image_i'], crop_shape)
    inputs['image_j'] = tf.random_crop(inputs['image_j'], crop_shape)

    inputs['image_i'] = tf.subtract(tf.div(inputs['image_i'], 127.5), 1)
    inputs['image_j'] = tf.subtract(tf.div(inputs['image_j'], 127.5), 1)

    # Batch
    if do_shuffle:
        # batch_size=1, capacity=5000, min_after_dequeue=100
        inputs['images_i'], inputs['images_j'] = tf.train.shuffle_batch(
            [inputs['image_i'], inputs['image_j']], 1, 5000, 100)
    else:
        inputs['images_i'], inputs['images_j'] = tf.train.batch(
            [inputs['image_i'], inputs['image_j']], 1)

    return inputs
tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue):capacity是隊列的容量,min_after_dequeue是出隊後隊列中至少剩下的數據個數。
main.py
代碼很易懂,注釋也寫的很清楚。
from datetime import datetime
import json
import numpy as np
import os
import random
from scipy.misc import imsave
import click
import tensorflow as tf
import cyclegan_datasets
import data_loader, losses, model
slim = tf.contrib.slim
class CycleGAN:
"""The CycleGAN module."""
...
- 初始化
def __init__(self, pool_size, lambda_a,
             lambda_b, output_root_dir, to_restore,
             base_lr, max_step, network_version,
             dataset_name, checkpoint_dir, do_flipping, skip):
    """Store the configuration and allocate the two fake-image pools.

    Outputs go to a sub-folder of `output_root_dir` named after the current
    time (YYYYmmdd-HHMMSS); images are saved to its 'imgs' sub-folder.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

    # Output locations.
    self._output_dir = os.path.join(output_root_dir, timestamp)
    self._images_dir = os.path.join(self._output_dir, 'imgs')
    self._checkpoint_dir = checkpoint_dir
    self._to_restore = to_restore
    self._num_imgs_to_save = 20

    # Data settings.
    self._dataset_name = dataset_name
    self._size_before_crop = 286
    self._do_flipping = do_flipping

    # Model and optimisation settings.
    self._network_version = network_version
    self._skip = skip
    self._lambda_a = lambda_a  # weight (lambda) applied to the cycle loss
    self._lambda_b = lambda_b
    self._base_lr = base_lr
    self._max_step = max_step
    self._pool_size = pool_size

    # One slot per pooled batch of generated images.
    pool_shape = (self._pool_size, 1, model.IMG_HEIGHT, model.IMG_WIDTH,
                  model.IMG_CHANNELS)
    self.fake_images_A = np.zeros(pool_shape)
    self.fake_images_B = np.zeros(pool_shape)
def model_setup(self):
    """
    This function sets up the model to train.

    self.input_a/self.input_b -> Placeholders for one training image of
    each domain.
    self.fake_pool_A/self.fake_pool_B -> Placeholders for previously
    generated images taken from the image pools.
    self.fake_images_a/self.fake_images_b -> Generated images by
    corresponding generator of input_b and input_a.
    self.learning_rate -> Learning rate placeholder.
    self.cycle_images_a/self.cycle_images_b -> Images generated after
    feeding self.fake_images_b/self.fake_images_a to the corresponding
    generator. This is used to calculate the cyclic loss.
    """
    # Image tensors are laid out as [batch, height, width, channels], which
    # matches the crop in the data loader and the pool arrays. The
    # placeholders used to be declared as [batch, width, height, channels],
    # which only worked because the images are square.
    image_shape = [model.IMG_HEIGHT, model.IMG_WIDTH, model.IMG_CHANNELS]

    self.input_a = tf.placeholder(
        tf.float32, [1] + image_shape, name="input_A")
    self.input_b = tf.placeholder(
        tf.float32, [1] + image_shape, name="input_B")

    self.fake_pool_A = tf.placeholder(
        tf.float32, [None] + image_shape, name="fake_pool_A")
    self.fake_pool_B = tf.placeholder(
        tf.float32, [None] + image_shape, name="fake_pool_B")

    self.global_step = slim.get_or_create_global_step()

    # Number of generated batches offered to the image pools so far.
    self.num_fake_inputs = 0

    self.learning_rate = tf.placeholder(tf.float32, shape=[], name="lr")

    inputs = {
        'images_a': self.input_a,
        'images_b': self.input_b,
        'fake_pool_a': self.fake_pool_A,
        'fake_pool_b': self.fake_pool_B,
    }

    outputs = model.get_outputs(
        inputs, network=self._network_version, skip=self._skip)

    self.prob_real_a_is_real = outputs['prob_real_a_is_real']
    self.prob_real_b_is_real = outputs['prob_real_b_is_real']
    self.fake_images_a = outputs['fake_images_a']
    self.fake_images_b = outputs['fake_images_b']
    self.prob_fake_a_is_real = outputs['prob_fake_a_is_real']
    self.prob_fake_b_is_real = outputs['prob_fake_b_is_real']

    self.cycle_images_a = outputs['cycle_images_a']
    self.cycle_images_b = outputs['cycle_images_b']

    self.prob_fake_pool_a_is_real = outputs['prob_fake_pool_a_is_real']
    self.prob_fake_pool_b_is_real = outputs['prob_fake_pool_b_is_real']
-
計算代價
def compute_losses(self):
    """
    Define the losses, the four training ops and the loss summaries.

    d_loss_A/d_loss_B -> loss for discriminator A/B
    g_loss_A/g_loss_B -> loss for generator A/B
    *_trainer -> Adam minimizer of the matching loss, restricted to the
    variables whose name contains the network's scope (d_A, d_B, g_A, g_B)
    *_summ -> Scalar summary of the matching loss
    """
    # Weighted cycle-consistency terms of both directions.
    cyc_a = self._lambda_a * losses.cycle_consistency_loss(
        real_images=self.input_a, generated_images=self.cycle_images_a,
    )
    cyc_b = self._lambda_b * losses.cycle_consistency_loss(
        real_images=self.input_b, generated_images=self.cycle_images_b,
    )

    # Adversarial terms seen by the generators.
    adv_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
    adv_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)

    # g_A produces fake B images, so it is paired with the B adversarial
    # term (and g_B with the A term); both share the two cycle terms.
    g_loss_A = cyc_a + cyc_b + adv_b
    g_loss_B = cyc_b + cyc_a + adv_a

    # Discriminators are trained against images drawn from the pools.
    d_loss_A = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_a_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_is_real,
    )
    d_loss_B = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_b_is_real,
        prob_fake_is_real=self.prob_fake_pool_b_is_real,
    )

    optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

    self.model_vars = tf.trainable_variables()

    def vars_of(scope_tag):
        # Select variables by scope-name substring.
        return [v for v in self.model_vars if scope_tag in v.name]

    d_A_vars = vars_of('d_A')
    g_A_vars = vars_of('g_A')
    d_B_vars = vars_of('d_B')
    g_B_vars = vars_of('g_B')

    self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
    self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
    self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
    self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)

    for model_var in self.model_vars:
        print(model_var.name)

    # Summary variables for tensorboard
    self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
    self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
    self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
    self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
- 緩存生成圖片
def fake_image_pool(self, num_fakes, fake, fake_pool):
    """
    Store a generated image in the pool and pick the image to train on.

    The pool is filled until it holds self._pool_size images; during that
    phase the new image itself is returned. Afterwards, with probability
    0.5 a randomly chosen stored image is replaced by the new one and the
    stored image is returned; otherwise the new image is returned and the
    pool is left untouched.

    :param num_fakes: Number of images offered to the pool so far.
    :param fake: The newly generated image.
    :param fake_pool: The pool to update in place, indexable by slot.
    :return: The image to feed to the discriminator.
    """
    if num_fakes < self._pool_size:
        fake_pool[num_fakes] = fake
        return fake

    if random.random() > 0.5:
        random_id = random.randint(0, self._pool_size - 1)
        # Copy before overwriting: indexing a NumPy array returns a view, so
        # without the copy the "old" image would silently become the new
        # one as soon as the slot is overwritten.
        previous = np.copy(fake_pool[random_id])
        fake_pool[random_id] = fake
        return previous
    return fake
- 訓練
def train(self):
    """Training Function.

    Builds the input pipeline, the model and the losses, then for every
    epoch saves a checkpoint and sample images and runs one G_A, D_B, G_B,
    D_A update per image pair. The learning rate is constant for the first
    100 epochs and then decays linearly, never going below 0.
    """
    # Load Dataset from the dataset folder
    self.inputs = data_loader.load_data(
        self._dataset_name, self._size_before_crop,
        True, self._do_flipping)

    # Build the network
    self.model_setup()

    # Loss function calculations
    self.compute_losses()

    # Initializing the global variables
    init = (tf.global_variables_initializer(),
            tf.local_variables_initializer())
    saver = tf.train.Saver()

    # Built once, so that advancing the global step does not add a new
    # assign op to the graph on every epoch.
    next_step = tf.placeholder(
        self.global_step.dtype.base_dtype, shape=[],
        name="next_global_step")
    set_global_step = tf.assign(self.global_step, next_step)

    max_images = cyclegan_datasets.DATASET_TO_SIZES[self._dataset_name]

    with tf.Session() as sess:
        sess.run(init)

        # Restore the model to run the model from last checkpoint
        if self._to_restore:
            chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
            saver.restore(sess, chkpt_fname)

        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)
        writer = tf.summary.FileWriter(self._output_dir)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            # Training Loop
            for epoch in range(sess.run(self.global_step), self._max_step):
                print("In the epoch ", epoch)
                saver.save(sess, os.path.join(
                    self._output_dir, "cyclegan"), global_step=epoch)

                # Dealing with the learning rate as per the epoch number.
                # Clamp at 0 so that more than 200 epochs cannot produce a
                # negative learning rate.
                if epoch < 100:
                    curr_lr = self._base_lr
                else:
                    curr_lr = max(
                        0.0,
                        self._base_lr -
                        self._base_lr * (epoch - 100) / 100)

                self.save_images(sess, epoch)

                for i in range(0, max_images):
                    print("Processing batch {}/{}".format(i, max_images))

                    inputs = sess.run(self.inputs)
                    summary_step = epoch * max_images + i
                    # Feeds shared by all four updates of this batch.
                    base_feed = {
                        self.input_a: inputs['images_i'],
                        self.input_b: inputs['images_j'],
                        self.learning_rate: curr_lr,
                    }

                    # Optimizing the G_A network
                    _, fake_B_temp, summary_str = sess.run(
                        [self.g_A_trainer,
                         self.fake_images_b,
                         self.g_A_loss_summ],
                        feed_dict=base_feed
                    )
                    writer.add_summary(summary_str, summary_step)

                    fake_B_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_B_temp,
                        self.fake_images_B)

                    # Optimizing the D_B network
                    d_b_feed = dict(base_feed)
                    d_b_feed[self.fake_pool_B] = fake_B_temp1
                    _, summary_str = sess.run(
                        [self.d_B_trainer, self.d_B_loss_summ],
                        feed_dict=d_b_feed
                    )
                    writer.add_summary(summary_str, summary_step)

                    # Optimizing the G_B network
                    _, fake_A_temp, summary_str = sess.run(
                        [self.g_B_trainer,
                         self.fake_images_a,
                         self.g_B_loss_summ],
                        feed_dict=base_feed
                    )
                    writer.add_summary(summary_str, summary_step)

                    fake_A_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_A_temp,
                        self.fake_images_A)

                    # Optimizing the D_A network
                    d_a_feed = dict(base_feed)
                    d_a_feed[self.fake_pool_A] = fake_A_temp1
                    _, summary_str = sess.run(
                        [self.d_A_trainer, self.d_A_loss_summ],
                        feed_dict=d_a_feed
                    )
                    writer.add_summary(summary_str, summary_step)

                    writer.flush()
                    self.num_fake_inputs += 1

                sess.run(set_global_step, feed_dict={next_step: epoch + 1})

            writer.add_graph(sess.graph)
        finally:
            # Stop the input threads and flush the event file even when
            # training is interrupted by an error.
            coord.request_stop()
            coord.join(threads)
            writer.close()
- 測試
def test(self):
    """Test Function.

    Rebuilds the unshuffled input pipeline and the model, restores the
    latest checkpoint found in self._checkpoint_dir and saves the outputs
    for every sample of the dataset through save_images.
    """
    print("Testing the results")

    self.inputs = data_loader.load_data(
        self._dataset_name, self._size_before_crop,
        False, self._do_flipping)

    self.model_setup()
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)

        latest_checkpoint = tf.train.latest_checkpoint(self._checkpoint_dir)
        saver.restore(sess, latest_checkpoint)

        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(coord=coordinator)

        # Save every sample of the dataset instead of the default count.
        dataset_size = cyclegan_datasets.DATASET_TO_SIZES[self._dataset_name]
        self._num_imgs_to_save = dataset_size
        self.save_images(sess, 0)

        coordinator.request_stop()
        coordinator.join(queue_threads)
- 主程序
@click.command()
@click.option('--to_train',
              type=click.INT,
              default=1,
              help='1: train; 2: resume training from the latest '
                   'checkpoint; 0: test.')
@click.option('--log_dir',
              type=click.STRING,
              default=None,
              help='Where the data is logged to.')
@click.option('--config_filename',
              type=click.STRING,
              default='train',
              help='The name of the configuration file.')
@click.option('--checkpoint_dir',
              type=click.STRING,
              default='',
              help='The directory that holds the latest checkpoint.')
@click.option('--skip',
              type=click.BOOL,
              default=False,
              help='Whether to add skip connection between input and output.')
def main(to_train, log_dir, config_filename, checkpoint_dir, skip):
    """
    :param to_train: Specify whether it is training or testing. 1: training; 2:
     resuming from latest checkpoint; 0: testing.
    :param log_dir: The root dir to save checkpoints and imgs. The actual dir
    is the root dir appended by the folder with the name timestamp.
    :param config_filename: The configuration file (JSON).
    :param checkpoint_dir: The directory that saves the latest checkpoint. It
    is read when resuming (to_train == 2) and when testing (to_train == 0).
    :param skip: A boolean indicating whether to add skip connection between
    input and output.
    """
    # --log_dir has no usable default; fail with a clear message instead of
    # the TypeError raised by os.path.isdir(None).
    if log_dir is None:
        raise click.UsageError('--log_dir must be specified.')
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    with open(config_filename) as config_file:
        config = json.load(config_file)

    # Optional settings fall back to defaults; network_version, dataset_name
    # and do_flipping are required keys of the configuration file.
    lambda_a = float(config.get('_LAMBDA_A', 10.0))
    lambda_b = float(config.get('_LAMBDA_B', 10.0))
    pool_size = int(config.get('pool_size', 50))

    to_restore = (to_train == 2)
    base_lr = float(config.get('base_lr', 0.0002))
    max_step = int(config.get('max_step', 200))
    network_version = str(config['network_version'])
    dataset_name = str(config['dataset_name'])
    do_flipping = bool(config['do_flipping'])

    cyclegan_model = CycleGAN(pool_size, lambda_a, lambda_b, log_dir,
                              to_restore, base_lr, max_step, network_version,
                              dataset_name, checkpoint_dir, do_flipping, skip)

    if to_train > 0:
        cyclegan_model.train()
    else:
        cyclegan_model.test()
- 保存圖片
def save_images(self, sess, epoch):
    """
    Save input, generated and cycled images plus an HTML overview page.

    For each of self._num_imgs_to_save samples the six images are written
    as JPEG files to self._images_dir, and one row of <img> tags is appended
    to 'epoch_<epoch>.html' in self._output_dir.

    :param sess: The session.
    :param epoch: Current epoch.
    """
    if not os.path.exists(self._images_dir):
        os.makedirs(self._images_dir)

    # NOTE(review): in the pairing below 'fakeA_' is saved from
    # fake_images_b and 'fakeB_' from fake_images_a -- confirm whether the
    # prefix is meant to name the source domain rather than the output.
    names = ['inputA_', 'inputB_', 'fakeA_',
             'fakeB_', 'cycA_', 'cycB_']

    html_path = os.path.join(
        self._output_dir, 'epoch_' + str(epoch) + '.html')
    fetches = [
        self.fake_images_a,
        self.fake_images_b,
        self.cycle_images_a,
        self.cycle_images_b
    ]

    with open(html_path, 'w') as v_html:
        for i in range(0, self._num_imgs_to_save):
            print("Saving image {}/{}".format(i, self._num_imgs_to_save))
            inputs = sess.run(self.inputs)
            fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = sess.run(
                fetches,
                feed_dict={
                    self.input_a: inputs['images_i'],
                    self.input_b: inputs['images_j']
                })

            tensors = [inputs['images_i'], inputs['images_j'],
                       fake_B_temp, fake_A_temp, cyc_A_temp, cyc_B_temp]

            for name, tensor in zip(names, tensors):
                image_name = name + str(epoch) + "_" + str(i) + ".jpg"
                # Map the first image of the batch from [-1, 1] to [0, 255].
                pixels = ((tensor[0] + 1) * 127.5).astype(np.uint8)
                imsave(os.path.join(self._images_dir, image_name), pixels)
                img_src = os.path.join('imgs', image_name)
                v_html.write("<img src=\"" + img_src + "\">")
            v_html.write("<br>")
結果
我自己在服務器上跑了100個epoch後的結果:
馬 -> 斑馬 -> 馬



斑馬 -> 馬 -> 斑馬



這還算是比較好的結果,有的慘不忍睹:



總體來說,馬->斑馬遠遠好于斑馬->馬,而且生成的馬身上仍然條紋很多,不知道訓練更長時間會不會好一些。
