CycleGAN-兩個(gè)領(lǐng)域非匹配圖像的相互轉(zhuǎn)換

1. CycleGAN的簡介

? ? ? ? pix2pix可以很好地處理匹配數(shù)據(jù)集圖像轉(zhuǎn)換,但是在很多情況下匹配數(shù)據(jù)集是沒有的或者是很難收集到的,但是我們可以很容易的得到兩個(gè)領(lǐng)域大量的非匹配數(shù)據(jù)。2017年有兩篇非常相似的論文CycleGAN和DiscoGAN,提出了一種解決非匹配數(shù)據(jù)集的圖像轉(zhuǎn)換方案。而且CycleGAN在轉(zhuǎn)換的過程中,只是將A領(lǐng)域圖像的某些特性轉(zhuǎn)換成B領(lǐng)域的一些特性,圖像的其余大部分內(nèi)容都沒有改變。CycleGAN 能實(shí)現(xiàn)兩個(gè)相近數(shù)據(jù)集之間的轉(zhuǎn)換。


2. CycleGAN的網(wǎng)絡(luò)結(jié)構(gòu)

CycleGAN網(wǎng)絡(luò)結(jié)構(gòu)的拆分

? ? ? ? 該結(jié)構(gòu)中,生成器相當(dāng)于一個(gè)自編碼網(wǎng)絡(luò),前半部分進(jìn)行編碼,后半部分進(jìn)行解碼,而且生成器G和生成器F的結(jié)構(gòu)完全相同,其中生成器G負(fù)責(zé)實(shí)現(xiàn)由X到Y(jié)的轉(zhuǎn)換,生成器F負(fù)責(zé)實(shí)現(xiàn)由Y到X的轉(zhuǎn)換,它們的輸入、輸出的大小均為(batch_size, n_channel, cols, rows),判別器的輸入為(batch_size, n_channel, cols, rows), 判別器的輸出為(batch_size, 1, s1, s2)。


3. CycleGAN的損失函數(shù)

(1)對(duì)抗損失

L_{GAN}(G, D_{Y} , X, Y) = E_{y\in p_{data} } (y)[log D_{Y}(y) ] +  E_{x\in p_{data} } (x)[log (1-D_{Y}(G(x)) )]

L_{GAN}(F, D_{Y} , Y, X) = E_{x\in p_{data} } (x)[log D_{X}(x) ] +  E_{y\in p_{data} } (y)[log (1-D_{X}(G(y)) )]

對(duì)抗損失的作用是,使生成的目標(biāo)領(lǐng)域的圖像和目標(biāo)領(lǐng)域的真實(shí)圖像盡可能地接近。

(2)循環(huán)損失

L_{cyc}(G, F) = E_{x\in p_{data} } (x)[||F(G(x))-x|| _{1} ]  + E_{y\in p_{data} } (y)[||G(F(y))-y|| _{1} ]

循環(huán)損失的作用是,使生成的圖像盡可能多的保留原始圖像的內(nèi)容。

在網(wǎng)絡(luò)訓(xùn)練的過程中是將G和F聯(lián)合起來一起訓(xùn)練的,D_{X} D_{Y} 是單獨(dú)進(jìn)行訓(xùn)練的。

G-F聯(lián)合網(wǎng)絡(luò)的損失函數(shù)為:L_{G-F}=L_{GAN} (G, D_{Y} , X, Y) + L_{GAN} (F, D_{X} , Y, X) + \lambda L_{cycle}(G, F)

fake_B = G_AB(real_A)

loss_GAN_AB = torch.nn.MSELoss(D_B(fake_B), valid)

fake_A = G_BA(real_B)

loss_GAN_BA = torch.nn.MSELoss(D_A(fake_A), valid)

loss_G_GAN = (loss_GAN_AB + loss_GAN_BA) / 2? ? ? #? 生成器的對(duì)抗損失

recov_A = G_BA(fake_B)

loss_cycle_A = torch.nn.L1Loss(recov_A, real_A)

recov_cycle_B = G_AB(fake_A)

loss_cycle_B = torch.nn.L1Loss(recov_B, real_B)

loss_cycle = (loss_cycle_A + loss_cycle_B) / 2? ? ? ? ? #? 生成器的循環(huán)損失

Loss_G = loss_G_GAN + lambda_cycle * loss_cycle

D_{X} 的損失函數(shù)為:?L_{D_{X} } = L_{GAN}(F, D_{X} , Y, X)

loss_real = torch.nn.MSELoss(D_A(real_A), valid)

fake_A = fake_A_buffer.push_and_pop(fake_A)

loss_fake = torch.nn.MSELoss(D_A(fake_A.detach()), fake)

loss_D_A = (loss_real + loss_fake) / 2

D_{Y} 的損失函數(shù)為:?L_{D_{Y} } = L_{GAN}(G, D_{Y} , X, Y)

loss_real = torch.nn.MSELoss(D_B(real_B), valid)

fake_B = fake_B_buffer.push_and_pop(fake_B)

loss_fake = torch.nn.MSELoss(D_B(fake_B.detach(), fake)

loss_D_B = (loss_real + loss_fake) / 2

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容