StarGAN-多領域間的轉換

1. StarGAN 的簡介

? ? ? ? Pix2pix解決了兩個領域之間匹配數(shù)據(jù)集之間的轉換,然而在很多情況下匹配數(shù)據(jù)集很難獲得,于是出現(xiàn)了CycleGAN。它可以實現(xiàn)兩個領域之間非匹配數(shù)據(jù)集之間的轉換, 然而這些轉換每次只能在兩個領域之間進行,當需要進行多個領域間的轉換時,就需要訓練多個網絡,非常麻煩。2018年的CVPR上發(fā)表了一篇文章提出了StarGAN,它僅使用一個網絡就實現(xiàn)多個領域之間的圖像轉換,而且圖像轉換的效果也比較好。


2. StarGAN 的網絡結構:

StarGAN的結構拆分

注:上圖中的兩個生成器 G 是同一個生成器,整個StarGAN中只使用了一個生成器和一個判別器。

? ? ? ? StarGAN實現(xiàn)了多個領域圖像之間的轉換,但是網絡結構比CycleGAN 更簡單,整個網絡只包含一個生成器和一個判別器。從結構上StarGAN與ACGAN比較相似,生成器是輸入除了圖像之外還有目標領域的標簽;判別器是輸入除了圖像之外還有相應的類別標簽,而且判別器的輸出除了判別圖像真假之外還要對圖像進行分類。

? ? ? ? 生成器的輸入包含兩個部分,一部分是輸入圖像imgs,大小為(batch_size, n_channel, cols, rows);一部分是目標領域的標簽domain,大小為(batch_size, n_dim)。為了將這兩部拼接,需要通過repeat操作來對domain進行擴展,將其擴展為(batch_size, n_dim, cols, rows),因此,生成器輸入的大小為(batch_size, n_channel + n_dim, cols, rows),生成器的輸出為(batch_size, n_channel, cols, rows)。判別器的輸入為圖像imgs,大小為(batch_size, n_channel, cols, rows),判別器的輸出分為兩部分,一部分是圖像的真假判斷,大小為(batch_size, 1, s1, s2),另一部分為圖像的類別劃分,大小為(batch_size, n_dim)。


3. StarGAN的損失函數(shù)

(1)對抗損失:即常規(guī)的生成對抗網絡的損失的損失函數(shù),判別器在努力地判別輸入圖像的真假,生成器在努力地生成假圖像來欺騙判別器。

? ? ? ? ? ? ? ??L_{adv} = E[log D_{src}(x) ] + E_{c, x} [log(1-D_{src}( G(x, c)) )]

(2)分類損失:即將輸入圖像進行分類的損失。對于判別器D而言,需要將真實圖像分到正確的類別中;對于生成器G而言,需要使得生成圖像分到目標類別中。

對于判別器D:?L_{cls}^r  = E_{x, \tilde{c} } [-log D_{cls} (\tilde{c}|x )]? ? ? ? ? ? 其中,D_{cls} (\tilde{c}|x )代表判別器將真實樣本歸為相應標簽類別\tilde{c}?的概率分布,判別器D的目標是最小化損失函數(shù)L_{cls}^r 。

對于生成器G:?L_{cls}^f  = E_{x, c } [-log D_{cls} (c| G(x, c) )]? ? ? ?生成器希望生成數(shù)據(jù)能夠被判別器判斷為目標分類c, 因此生成器的目標是最小化損失函數(shù)L_{cls}^f 。

(3)重建損失:為了確保生成數(shù)據(jù)能夠很好地還原到原來的領域分類中,此處將原始圖像和經過兩次生成的圖像的L1范數(shù)作為重建損失。

? ? ? ? ? ? ? ??L_{rec} = E_{x, c, \tilde{c} }  [||x-G(G(x, c), \tilde{c} )||_{1} ]

因此,StarGAN的生成器和判別器總的損失函數(shù)分別為:

生成器G損失函數(shù):?L_{G} = -L_{adv} + \lambda _{cls} L_{cls}^f  + \lambda _{rec} L_{rec}

gen_imgs = generator(imgs, sampled_c)? ? ? ?#? 生成圖像,sampled_c 為隨機生成的目標類標簽

recov_imgs = generator(gen_imgs, labels)? ? ? ? ? #? ?圖像重建

fake_validity, pred_cls = discriminator(gen_imgs)? ? ? ? ?#? 生成圖像的判別

loss_G_adv = -torch.mean(fake_validity)? ? ? ? ? # 對抗損失

loss_G_cls = torch.nn.functional.binary_cross_entropy_with_logits(sampled_c, pred_cls, size_average=False) / sampled.size(0)? ? ? ? #? 分類損失

loss_G_rec = torch.nn.L1Loss(recov_imgs, imgs)? ? ?# 重建損失

Loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec? ? ?#? 生成器總的損失

判別器D損失函數(shù):?L_{D} = -L_{adv} + \lambda _{cls} L_{cls}^r

fake_imgs = generator(imgs, sampled_c)? ? ? #? 生成圖像

real_validity, pred_cls = discriminator(imgs)? ? ? # 真實圖像的判別

fake_validity, _ = discriminator( fake_imgs.detach())? ? ?#? 生成圖像的判別

gradient_penalty = compute_gradient_penalty(discriminator, imgs.data, fake_imgs.data)? ? ? ?# 梯度懲罰

loss_D_adv = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty? ? ? ?# 對抗損失

loss_D_cls =?torch.nn.functional.binary_cross_entropy_with_logits(labels, pred_cls, size_average=False) / sampled.size(0)? ? ? ? ? ?# 分類損失

Loss_D = loss_D_adv + lambda_cls * loss_D_cls? ? ?# 判別器總的損失

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容