1. StarGAN 的簡介
? ? ? ? Pix2pix解決了兩個領域之間匹配數(shù)據(jù)集之間的轉換,然而在很多情況下匹配數(shù)據(jù)集很難獲得,于是出現(xiàn)了CycleGAN。它可以實現(xiàn)兩個領域之間非匹配數(shù)據(jù)集之間的轉換, 然而這些轉換每次只能在兩個領域之間進行,當需要進行多個領域間的轉換時,就需要訓練多個網絡,非常麻煩。2018年的CVPR上發(fā)表了一篇文章提出了StarGAN,它僅使用一個網絡就實現(xiàn)多個領域之間的圖像轉換,而且圖像轉換的效果也比較好。
2. StarGAN 的網絡結構:

注:上圖中的兩個生成器 G 是同一個生成器,整個StarGAN中只使用了一個生成器和一個判別器。
? ? ? ? StarGAN實現(xiàn)了多個領域圖像之間的轉換,但是網絡結構比CycleGAN 更簡單,整個網絡只包含一個生成器和一個判別器。從結構上StarGAN與ACGAN比較相似,生成器是輸入除了圖像之外還有目標領域的標簽;判別器是輸入除了圖像之外還有相應的類別標簽,而且判別器的輸出除了判別圖像真假之外還要對圖像進行分類。
? ? ? ? 生成器的輸入包含兩個部分,一部分是輸入圖像imgs,大小為(batch_size, n_channel, cols, rows);一部分是目標領域的標簽domain,大小為(batch_size, n_dim)。為了將這兩部拼接,需要通過repeat操作來對domain進行擴展,將其擴展為(batch_size, n_dim, cols, rows),因此,生成器輸入的大小為(batch_size, n_channel + n_dim, cols, rows),生成器的輸出為(batch_size, n_channel, cols, rows)。判別器的輸入為圖像imgs,大小為(batch_size, n_channel, cols, rows),判別器的輸出分為兩部分,一部分是圖像的真假判斷,大小為(batch_size, 1, s1, s2),另一部分為圖像的類別劃分,大小為(batch_size, n_dim)。
3. StarGAN的損失函數(shù)
(1)對抗損失:即常規(guī)的生成對抗網絡的損失的損失函數(shù),判別器在努力地判別輸入圖像的真假,生成器在努力地生成假圖像來欺騙判別器。
? ? ? ? ? ? ? ??
(2)分類損失:即將輸入圖像進行分類的損失。對于判別器D而言,需要將真實圖像分到正確的類別中;對于生成器G而言,需要使得生成圖像分到目標類別中。
對于判別器D:?? ? ? ? ? ? 其中,
代表判別器將真實樣本歸為相應標簽類別
?的概率分布,判別器D的目標是最小化損失函數(shù)
。
對于生成器G:?? ? ? ?生成器希望生成數(shù)據(jù)能夠被判別器判斷為目標分類c, 因此生成器的目標是最小化損失函數(shù)
。
(3)重建損失:為了確保生成數(shù)據(jù)能夠很好地還原到原來的領域分類中,此處將原始圖像和經過兩次生成的圖像的L1范數(shù)作為重建損失。
? ? ? ? ? ? ? ??
因此,StarGAN的生成器和判別器總的損失函數(shù)分別為:
生成器G損失函數(shù):?
gen_imgs = generator(imgs, sampled_c)? ? ? ?#? 生成圖像,sampled_c 為隨機生成的目標類標簽
recov_imgs = generator(gen_imgs, labels)? ? ? ? ? #? ?圖像重建
fake_validity, pred_cls = discriminator(gen_imgs)? ? ? ? ?#? 生成圖像的判別
loss_G_adv = -torch.mean(fake_validity)? ? ? ? ? # 對抗損失
loss_G_cls = torch.nn.functional.binary_cross_entropy_with_logits(sampled_c, pred_cls, size_average=False) / sampled.size(0)? ? ? ? #? 分類損失
loss_G_rec = torch.nn.L1Loss(recov_imgs, imgs)? ? ?# 重建損失
Loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec? ? ?#? 生成器總的損失
判別器D損失函數(shù):?
fake_imgs = generator(imgs, sampled_c)? ? ? #? 生成圖像
real_validity, pred_cls = discriminator(imgs)? ? ? # 真實圖像的判別
fake_validity, _ = discriminator( fake_imgs.detach())? ? ?#? 生成圖像的判別
gradient_penalty = compute_gradient_penalty(discriminator, imgs.data, fake_imgs.data)? ? ? ?# 梯度懲罰
loss_D_adv = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty? ? ? ?# 對抗損失
loss_D_cls =?torch.nn.functional.binary_cross_entropy_with_logits(labels, pred_cls, size_average=False) / sampled.size(0)? ? ? ? ? ?# 分類損失
Loss_D = loss_D_adv + lambda_cls * loss_D_cls? ? ?# 判別器總的損失