生成對(duì)抗網(wǎng)絡(luò)GAN-PyTorch庫(kù)

GAN 自從被提出后,便迅速受到廣泛關(guān)注。大體上GAN 可分為兩類,一類是無條件下的生成;另一類是基于條件信息的生成。近日,來自韓國(guó)浦項(xiàng)科技大學(xué)的碩士生在 GitHub 上開源了一個(gè)項(xiàng)目,提供了條件 / 無條件圖像生成的代表性生成對(duì)抗網(wǎng)絡(luò)(GAN)的實(shí)現(xiàn)。

項(xiàng)目地址:https://github.com/POSTECH-CVLab/PyTorch-StudioGAN

具體而言,該項(xiàng)目具有以下幾個(gè)顯著特征:

(1)提供了大量 PyTorch 框架的 GAN 實(shí)現(xiàn);

(2)基于 CIFAR 10、Tiny ImageNet 和 ImageNet 數(shù)據(jù)集的 GAN 基準(zhǔn);

(3)相較原始實(shí)現(xiàn)的更好的性能和更低的內(nèi)存消耗;

(4)提供完全最新 PyTorch 環(huán)境的預(yù)訓(xùn)練模型;

(5)支持多 GPU(DP、DDP 和多節(jié)點(diǎn) DDP)、混合精度、同步批歸一化、LARS、Tensorboard 可視化和其他分析方法。

如下圖所示,項(xiàng)目作者提供了 18 + 個(gè) SOTA GAN 的實(shí)現(xiàn),包括 DCGAN、LSGAN、GGAN、WGAN-WC、WGAN-GP、WGAN-DRA、ACGAN、ProjGAN、SNGAN、SAGAN、BigGAN、BigGAN-Deep、CRGAN、ICRGAN、LOGAN、DiffAugGAN、ADAGAN、ContraGAN 和 FreezeD。

cBN:條件批歸一化;AC:輔助分類器;PD:Projection 判別器;CL:對(duì)比學(xué)習(xí)

其中,需要注意以下幾點(diǎn):

(1)G/D_type 表示將標(biāo)簽信息注入生成器或判別式的方式;

(2)EMA 表示生成器中應(yīng)用更新后的指數(shù)移動(dòng)平均線;

(3)Tiny ImageNet 數(shù)據(jù)集上的實(shí)驗(yàn)使用的是 ResNet 架構(gòu)而不是 CNN。

環(huán)境要求

用戶可以采用以下方法安裝推薦的環(huán)境:

conda env create -f environment.yml -n studiogan

在 docker 中還可以采用以下方式:

docker pull mgkang/studiogan:latest

以下是創(chuàng)建名字為「studioGAN」容器的命令,同樣也可以使用端口號(hào)為 6006 來連接 tensoreboard:

docker run -it --gpus all --shm-size 128g -p 6006:6006 --name studioGAN -v /home/USER:/root/code --workdir /root/code mgkang/studiogan:latest /bin/bash

使用方法

使用 GPU 0 的情況下,在 CONFIG_PATH 中對(duì)于模型的訓(xùn)練「-t」和評(píng)估「-e」進(jìn)行了定義:

CUDA_VISIBLE_DEVICES=0 python3 src/main.py -t -e -c CONFIG_PATH

在使用 GPU (0, 1, 2, 3) 和 DataParallel 情況下,在 CONFIG_PATH 中對(duì)于模型的訓(xùn)練「-t」和評(píng)估「-e」進(jìn)行了定義:

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 src/main.py -t -e -c CONFIG_PATH

在 python3 src/main.py 程序中查看可用選項(xiàng),通過 Tensorboard 可以監(jiān)控 IS、FID、F_beta、Authenticity Accuracies 以及最大奇異值:

~ PyTorch-StudioGAN/logs/RUN_NAME>>> tensorboard --logdir=./ --port PORT

可視化以及分析生成圖像

StudioGAN 支持圖像可視化、k 最近鄰分析、線性差值以及頻率分析。所有的結(jié)果保存在「./figures/RUN_NAME/*.png」中。

圖像可視化的代碼和示例如下:

CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -iv -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

k 最近鄰分析,這里固定 K=7,第一列中是生成的圖像

CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -knn -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

線性插值(僅適用于有條件的 Big ResNet 模型 )的代碼和示例如下:

CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -itp -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。
禁止轉(zhuǎn)載,如需轉(zhuǎn)載請(qǐng)通過簡(jiǎn)信或評(píng)論聯(lián)系作者。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容