梵高眼中的世界(一)實(shí)時(shí)圖像風(fēng)格轉(zhuǎn)換簡介

本文目錄:

  • Introduction
  • Related work
  • Methods
  • Gram 矩陣
  • Batch Normalization

Introduction

不久前,一個(gè)名叫Prisma的APP在微博和朋友圈火了起來。Prisma是個(gè)能夠?qū)D像風(fēng)格轉(zhuǎn)換為藝術(shù)風(fēng)格的APP,它能夠?qū)崿F(xiàn)如下轉(zhuǎn)換:

Prisma

除了引起大眾的好奇心外,業(yè)內(nèi)人士也紛紛猜測Prisma是如何做到實(shí)現(xiàn)快速的圖像風(fēng)格轉(zhuǎn)換。此前,在Gatys的論文<Image Style Transfer Using Convolutional Neural Networks>中,實(shí)現(xiàn)一張圖片的圖像風(fēng)格轉(zhuǎn)換需要較長時(shí)間。

在文中我將講解Prisma是如何實(shí)現(xiàn)實(shí)時(shí)風(fēng)格轉(zhuǎn)換的。本文內(nèi)容基于Fei Fei Li團(tuán)隊(duì)的<Perceptual Losses for Real-Time Style Transfer
and Super-Resolution>一文。

系列文章目錄如下:

  • 梵高眼中的世界(一)實(shí)時(shí)圖像風(fēng)格轉(zhuǎn)換簡介
  • 梵高眼中的世界(二)基于perceptual損失的網(wǎng)絡(luò)
  • 梵高眼中的世界(三)實(shí)現(xiàn)與改進(jìn)

Related work

在進(jìn)行圖像風(fēng)格轉(zhuǎn)換時(shí),我們需要一張風(fēng)格圖像style image和一張內(nèi)容圖像content image。我們構(gòu)造一個(gè)網(wǎng)絡(luò)衡量生成圖像與style image以及content image的loss,再通過訓(xùn)練減小loss得到最終圖像。

在Gatys的方法中,他使用了如下圖所示的方法:

Gatys' method

上圖最左邊是風(fēng)格圖像,梵高的《星夜》;最右邊是內(nèi)容圖像。
算法步驟如下:

  1. 生成了一張白噪聲圖像作為初始圖像。

  2. 將風(fēng)格圖像,內(nèi)容圖像,初始圖像分別通過一個(gè)預(yù)訓(xùn)練的VGG-19網(wǎng)絡(luò),得到某些層的輸出。這里的“某些層”是經(jīng)過實(shí)驗(yàn)得出的,是使得輸出圖像最佳的層數(shù)。

  3. 計(jì)算內(nèi)容損失函數(shù):

    內(nèi)容損失函數(shù)

    其中Pl_ij是原始圖像在第l層位置j與第i個(gè)filter卷積后的輸出,F(xiàn)l_ij是相應(yīng)的生成圖像的輸出。

計(jì)算風(fēng)格損失函數(shù):

單層

多層累加

風(fēng)格損失函數(shù)與圖像有些不同,在這里我們不直接使用某些層卷積后的輸出,而是計(jì)算輸出的Gram矩陣,再用于上式風(fēng)格損失的計(jì)算:

Gram matrix

5.計(jì)算總損失

Total loss

此時(shí)我們可以通過梯度下降算法對(duì)初始化的白噪聲圖像進(jìn)行訓(xùn)練,得到最終的風(fēng)格轉(zhuǎn)換圖像。
Gatys的算法缺點(diǎn)是一次只能訓(xùn)練出一張圖。我們希望得到一個(gè)前饋的神經(jīng)網(wǎng)絡(luò),對(duì)于每一張內(nèi)容圖像,只需要通過這個(gè)前饋神經(jīng)網(wǎng)絡(luò),就能快速得到風(fēng)格轉(zhuǎn)換圖像。

Methods

在這里只對(duì)Gram matrix以及Batch Normalization進(jìn)行講解,具體實(shí)現(xiàn)細(xì)節(jié)請(qǐng)閱讀原文。

Gram matrix

Gram matrix 計(jì)算如下:

Gram matrix

上式的意思為,G^l_i,j意味著第l層特征圖i和j的內(nèi)積。同理可表示為:

Gram matrix

在論文中,作者用高維的特征圖相關(guān)性來表示圖像風(fēng)格。上式矩陣的對(duì)角線表示每一個(gè)特征圖自身的信息,其余元素表示了不同特征圖之間的信息。

Gram matrix的tensorflow實(shí)現(xiàn)如下:

def gram_matrix(x):
    '''
    Args:
        x: Tensor with shape [batch size, length, width, channels]
    Return:
        Tensor with shape [channels, channels]
    '''
    bs, l, w, c = x.get_shape()
    size = l*w*c
    x = tf.reshape(x, (bs, l*w, c))
    x_t = tf.transpose(x, perm=[0,2,1])
    return tf.matmul(x_t, x)/size

Batch Normalization

Batch Normalization 最早由Google在ICML2015的論文<Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift>提出。
其算法如下:

Batch Normalization Algorithm

這個(gè)算法看上去有點(diǎn)復(fù)雜,但直觀上很好理解:
對(duì)于一個(gè)mini-batch里面的值x_i,我們計(jì)算平均值 μ和方差σ。對(duì)于每一個(gè)x_i,我們對(duì)其進(jìn)行z-score歸一化,得到平均值為0,標(biāo)準(zhǔn)差為1的數(shù)據(jù)。式子中的ε是一個(gè)很小的偏差值,防止出現(xiàn)除以0的情況。實(shí)現(xiàn)中可以取ε=1e-3。在對(duì)數(shù)據(jù)進(jìn)行歸一化后,BN算法再進(jìn)行“scale and shift”,將數(shù)據(jù)還原成原來的輸入。
Batch Normalization是為了解決Internal Covariate Shift問題而提出。

Explanation

Batch Normalization在Tensorflow下的實(shí)現(xiàn):

from tensorflow.contrib.layers import batch_norm
def batch_norm_layer(x, is_training, scope):
    bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
    updates_collections=None,
    is_training=True,
    reuse=None,
    trainable=True,
    scope=scope)

    bn_test = batch_norm(x, decay=0.999, center=True, scale=True,
    updates_collections=None,
    is_training=False,
    reuse=True, 
    trainable=True,
    scope=scope)

    bn = tf.cond(is_training, lambda: bn_train, lambda: bn_test)
    return bn

注意其中is_training是一個(gè)placeholder。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容