[Translation] Implementing Batch Normalization in TensorFlow

Original: Implementing Batch Normalization in Tensorflow
Source: R2RT

Translator's note (黑猿大叔): This article uses a minimal fully connected network to show how to build a Batch Norm layer, how to train it, and how to test it correctly. Working through this sample code is the best way to understand Batch Norm.

The code in this article can be run in a jupyter notebook environment.

Batch normalization is a simple and effective way to improve neural network performance, proposed by Sergey Ioffe and Christian Szegedy in their March 2015 paper (BN2015). In it, Ioffe and Szegedy show that batch normalization enables higher learning rates, acts as a regularizer, and can speed up training by as much as 14 times. This article implements batch normalization in TensorFlow.

The problem

The problem batch normalization addresses is that the distribution of each hidden layer's outputs changes as the model's parameters change during learning, which means the later layers have to keep adapting to these changes over the course of training.

What batch normalization is

To solve this problem, the BN2015 paper proposes batch normalization: during training, the input to each neuron's activation function (e.g., sigmoid or ReLU) is normalized so that, over each mini-batch of training examples, it has zero mean and unit variance. That is, an activation σ(Wx+b) becomes σ(BN(Wx+b)), where BN is the batch normalizing transform.

The batch normalization formula

To normalize a value within a batch, we subtract the batch mean and then divide by the batch standard deviation, √(σ² + ε). Note the small constant ε added to the variance to avoid dividing by zero. Given a value x_i, the initial batch normalizing transform is:
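Writing μ_B and σ_B² for the mean and variance of the mini-batch, in the BN2015 paper's notation this is:

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}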

The transform above constrains the activation inputs to a normal distribution, which limits the representational power of the layer. We therefore multiply by a new scale parameter γ and add a new shift parameter β, so that the network can undo the batch normalizing transform if it needs to. γ and β are both learnable parameters.

Adding γ and β gives the final batch normalizing transform:
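With the learned scale γ and shift β, and x̂_i as defined above:

y_i = \mathrm{BN}_{\gamma,\beta}(x_i) = \gamma \hat{x}_i + \beta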

Implementing batch normalization in TensorFlow

We will add batch normalization to a fully connected network with two hidden layers of 100 neurons each, and show results similar to Figures 1 (b) and (c) of the BN2015 paper.

Note that this network is not yet suitable for use at test time. The section "Making predictions with the model" below explains why, and provides a fixed version.

Imports, config

import numpy as np, tensorflow as tf, tqdm
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
%matplotlib inline
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Generate predetermined random weights so the networks are similarly initialized
w1_initial = np.random.normal(size=(784,100)).astype(np.float32)
w2_initial = np.random.normal(size=(100,100)).astype(np.float32)
w3_initial = np.random.normal(size=(100,10)).astype(np.float32)

# Small epsilon value for the BN transform
epsilon = 1e-3

Building the graph

# Placeholders
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# Layer 1 without BN
w1 = tf.Variable(w1_initial)
b1 = tf.Variable(tf.zeros([100]))
z1 = tf.matmul(x,w1)+b1
l1 = tf.nn.sigmoid(z1)

Here is the same first layer with batch normalization:

# Layer 1 with BN
w1_BN = tf.Variable(w1_initial)

# Note that the pre-batch normalization bias is omitted. The effect of this bias would be
# eliminated when subtracting the batch mean. Instead, the role of the bias is performed
# by the new beta variable. See Section 3.2 of the BN2015 paper.
z1_BN = tf.matmul(x,w1_BN)

# Calculate batch mean and variance
batch_mean1, batch_var1 = tf.nn.moments(z1_BN,[0])

# Apply the initial batch normalizing transform
z1_hat = (z1_BN - batch_mean1) / tf.sqrt(batch_var1 + epsilon)

# Create two new parameters, scale and beta (shift)
scale1 = tf.Variable(tf.ones([100]))
beta1 = tf.Variable(tf.zeros([100]))

# Scale and shift to obtain the final output of the batch normalization
# this value is fed into the activation function (here a sigmoid)
BN1 = scale1 * z1_hat + beta1
l1_BN = tf.nn.sigmoid(BN1)

# Layer 2 without BN
w2 = tf.Variable(w2_initial)
b2 = tf.Variable(tf.zeros([100]))
z2 = tf.matmul(l1,w2)+b2
l2 = tf.nn.sigmoid(z2)

TensorFlow provides tf.nn.batch_normalization, which I use to define the second layer below. It behaves the same as the hand-rolled code for layer 1 above. See the official documentation and the open-source implementation for details.

# Layer 2 with BN, using Tensorflows built-in BN function
w2_BN = tf.Variable(w2_initial)
z2_BN = tf.matmul(l1_BN,w2_BN)
batch_mean2, batch_var2 = tf.nn.moments(z2_BN,[0])
scale2 = tf.Variable(tf.ones([100]))
beta2 = tf.Variable(tf.zeros([100]))
BN2 = tf.nn.batch_normalization(z2_BN,batch_mean2,batch_var2,beta2,scale2,epsilon)
l2_BN = tf.nn.sigmoid(BN2)

# Softmax
w3 = tf.Variable(w3_initial)
b3 = tf.Variable(tf.zeros([10]))
y  = tf.nn.softmax(tf.matmul(l2,w3)+b3)

w3_BN = tf.Variable(w3_initial)
b3_BN = tf.Variable(tf.zeros([10]))
y_BN  = tf.nn.softmax(tf.matmul(l2_BN,w3_BN)+b3_BN)

# Loss, optimizer and predictions
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
cross_entropy_BN = -tf.reduce_sum(y_*tf.log(y_BN))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
train_step_BN = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy_BN)

correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
correct_prediction_BN = tf.equal(tf.arg_max(y_BN,1),tf.arg_max(y_,1))
accuracy_BN = tf.reduce_mean(tf.cast(correct_prediction_BN,tf.float32))

Training the network

zs, BNs, acc, acc_BN = [], [], [], []

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
for i in tqdm.tqdm(range(40000)):
    batch = mnist.train.next_batch(60)
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})
    train_step_BN.run(feed_dict={x: batch[0], y_: batch[1]})
    if i % 50 == 0:
        res = sess.run([accuracy,accuracy_BN,z2,BN2],feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        acc.append(res[0])
        acc_BN.append(res[1])
        zs.append(np.mean(res[2],axis=0)) # record the mean value of z2 over the entire test set
        BNs.append(np.mean(res[3],axis=0)) # record the mean value of BN2 over the entire test set

zs, BNs, acc, acc_BN = np.array(zs), np.array(BNs), np.array(acc), np.array(acc_BN)

Improvements in speed and accuracy

As shown below, batch normalization yields a noticeable improvement in both accuracy and training speed. Figure 2 of the BN2015 paper shows that batch normalization also matters a great deal for other network architectures.

fig, ax = plt.subplots()

ax.plot(range(0,len(acc)*50,50),acc, label='Without BN')
ax.plot(range(0,len(acc)*50,50),acc_BN, label='With BN')
ax.set_xlabel('Training steps')
ax.set_ylabel('Accuracy')
ax.set_ylim([0.8,1])
ax.set_title('Batch Normalization Accuracy')
ax.legend(loc=4)
plt.show()

Illustration of the inputs to the activation functions over time

Below are the distributions over time of the inputs to the sigmoid activations for the first five neurons in the network's second layer. Batch normalization has a pronounced effect in removing variance/noise from these inputs.

fig, axes = plt.subplots(5, 2, figsize=(6,12))
fig.tight_layout()

for i, ax in enumerate(axes):
    ax[0].set_title("Without BN")
    ax[1].set_title("With BN")
    ax[0].plot(zs[:,i])
    ax[1].plot(BNs[:,i])

Making predictions with the model

When making predictions with a batch-normalized model, using the batch's own mean and variance can backfire. Imagine what happens when a single example is fed into the trained model: the inputs to the activation functions will always be zero (since we normalize them to have zero mean), and we will always get the same result regardless of the input.
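Concretely, for a "batch" consisting of a single value x, tf.nn.moments returns μ_B = x and σ_B² = 0, so the normalized value and the BN output collapse to constants:

\hat{x} = \frac{x - x}{\sqrt{0 + \epsilon}} = 0, \qquad \mathrm{BN}(x) = \gamma \cdot 0 + \beta = \beta

Every activation input is then just β, independent of the example we fed in.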

We can verify this directly:

predictions = []
correct = 0
for i in range(100):
    pred, corr = sess.run([tf.arg_max(y_BN,1), accuracy_BN],
                         feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]})
    correct += corr
    predictions.append(pred[0])
print("PREDICTIONS:", predictions)
print("ACCURACY:", correct/100)

PREDICTIONS: [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
ACCURACY: 0.02

Our model always predicts 8; there happen to be only two 8s among the first 100 MNIST test examples, so the accuracy is just 2%.

Fixing the model's behavior at test time

To fix this, we need to replace the batch mean and batch variance with the population mean and population variance; see Section 3.1 of the BN2015 paper. As written, though, the model above could only work correctly if we fed it the entire test set at once, since only then would it compute something close to the population statistics.
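In the paper, the inference-time statistics are obtained by averaging the mini-batch statistics collected during training, with Bessel's correction applied to the variance (m is the batch size):

\mathrm{E}[x] = \mathrm{E}_{\mathcal{B}}[\mu_{\mathcal{B}}], \qquad \mathrm{Var}[x] = \frac{m}{m-1}\,\mathrm{E}_{\mathcal{B}}[\sigma_{\mathcal{B}}^2]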

To make a batch-normalized model usable at test time, we need to estimate the population mean and variance for every batch normalization step ahead of test time, and then use those estimates when making predictions. For the same reason we need batch normalization in the first place (the mean and variance of the activation inputs shift during training), it would be best to estimate the population statistics after the weights they depend on have finished updating. Estimating them concurrently is not too bad either, though, since the weights converge toward the end of training.

To implement this fix in TensorFlow, we will write a batch_norm_wrapper function to wrap the inputs to our activations. The function stores the population mean and variance as tf.Variables and decides whether to use batch statistics or population statistics when normalizing. It does this with an is_training flag: when is_training == True, we also learn the population mean and variance during training. A skeleton of the function looks like this:

def batch_norm_wrapper(inputs, is_training):
    ...
    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

    if is_training:
        batch_mean, batch_var = tf.nn.moments(inputs,[0])
        ...
        # learn pop_mean and pop_var here
        ...
        return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)
    else:
        return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)

Note that the variables are declared with trainable=False, because we will update them ourselves rather than letting the optimizer update them.

One way to estimate the population mean and variance during training is an exponential moving average; it is simple and avoids extra bookkeeping. We apply it as follows:

decay = 0.999 # use numbers closer to 1 if you have more data
train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

Finally, we need to decide how to run these training-time ops. For full control, you can add them to a graph collection (see the TensorFlow source linked below), but for simplicity we will run them every time we compute the batch mean and batch variance. To do this, when is_training is True we add them as dependencies of batch_norm_wrapper's return value. The final batch_norm_wrapper function:

# this is a simpler version of Tensorflow's 'official' version. See:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/layers.py#L102
def batch_norm_wrapper(inputs, is_training, decay = 0.999):

    scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))
    beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))
    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

    if is_training:
        batch_mean, batch_var = tf.nn.moments(inputs,[0])
        train_mean = tf.assign(pop_mean,
                               pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var,
                              pop_var * decay + batch_var * (1 - decay))
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(inputs,
                batch_mean, batch_var, beta, scale, epsilon)
    else:
        return tf.nn.batch_normalization(inputs,
            pop_mean, pop_var, beta, scale, epsilon)

Testing the model

Now, to show that the fixed code tests correctly, we rebuild the model using batch_norm_wrapper. Note that we have to build the graph once for training and then again for testing, so we write a build_graph function (real model objects are usually wrapped this way as well):

def build_graph(is_training):
    # Placeholders
    x = tf.placeholder(tf.float32, shape=[None, 784])
    y_ = tf.placeholder(tf.float32, shape=[None, 10])

    # Layer 1
    w1 = tf.Variable(w1_initial)
    z1 = tf.matmul(x,w1)
    bn1 = batch_norm_wrapper(z1, is_training)
    l1 = tf.nn.sigmoid(bn1)

    #Layer 2
    w2 = tf.Variable(w2_initial)
    z2 = tf.matmul(l1,w2)
    bn2 = batch_norm_wrapper(z2, is_training)
    l2 = tf.nn.sigmoid(bn2)

    # Softmax
    w3 = tf.Variable(w3_initial)
    b3 = tf.Variable(tf.zeros([10]))
    y  = tf.nn.softmax(tf.matmul(l2, w3))

    # Loss, Optimizer and Predictions
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))

    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    return (x, y_), train_step, accuracy, y, tf.train.Saver()

#Build training graph, train and save the trained model

sess.close()
tf.reset_default_graph()
(x, y_), train_step, accuracy, _, saver = build_graph(is_training=True)

acc = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in tqdm.tqdm(range(10000)):
        batch = mnist.train.next_batch(60)
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})
        if i % 50 == 0:
            res = sess.run([accuracy],feed_dict={x: mnist.test.images, y_: mnist.test.labels})
            acc.append(res[0])
    saved_model = saver.save(sess, './temp-bn-save')

print("Final accuracy:", acc[-1])

Final accuracy: 0.9721

Everything should now be working correctly, so we repeat the experiment above:

tf.reset_default_graph()
(x, y_), _, accuracy, y, saver = build_graph(is_training=False)

predictions = []
correct = 0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, './temp-bn-save')
    for i in range(100):
        pred, corr = sess.run([tf.arg_max(y,1), accuracy],
                             feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]})
        correct += corr
        predictions.append(pred[0])
print("PREDICTIONS:", predictions)
print("ACCURACY:", correct/100)

PREDICTIONS: [7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5, 4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2, 4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0, 2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 9, 3, 1, 4, 1, 7, 6, 9]
ACCURACY: 0.99

My blog is also syndicated to the Tencent Cloud+ Community: https://cloud.tencent.com/developer/article/1092219
