Batch Normalization的原理和效果

論文:Batch Normalization: Accelerating Deep Network Training b y Reducing Internal Covariate Shift

摘要

深度網(wǎng)絡(luò)由于前面的網(wǎng)絡(luò)層參數(shù)變化,導(dǎo)致下一層的輸入分布發(fā)生變化,從而導(dǎo)致網(wǎng)絡(luò)難以訓(xùn)練。這降低了訓(xùn)練速度(因為需要使用曉得學(xué)習(xí)率和精細(xì)的初始化),并且使飽和非線性網(wǎng)絡(luò)的訓(xùn)練變得困難。我們把這個現(xiàn)象叫做internal covariate shift,并通過歸一化輸入來解決問題。我們的方法的優(yōu)勢在于使標(biāo)準(zhǔn)化成為模型體系結(jié)構(gòu)的一部分,并為每個小訓(xùn)練 batch 執(zhí)行標(biāo)準(zhǔn)化。Batch Normalization使我們能夠使用更大的學(xué)習(xí)率,并且不需要太關(guān)心初始化。它的行為就像正則化,在一些情況下可以不必使用Dropout。在 state-of-the-art的分類模型上使用Batch Normalization,可以少用14倍的training steps就可以達(dá)到同樣的準(zhǔn)確率,并且和原始模型拉開了個顯著差距。使用集成了batch normalization的網(wǎng)路,我們達(dá)到了Imagenet分類的最好的公開結(jié)果:達(dá)到了 4.9% top-5 validation error (and 4.8% test error),超過了人類識別率。


相關(guān)鏈接:《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》閱讀筆記與實現(xiàn) - CSDN博客

深度學(xué)習(xí)中 Batch Normalization為什么效果好? - 知乎

深度學(xué)習(xí)中 Batch Normalization為什么效果好? - 知乎

詳解深度學(xué)習(xí)中的Normalization,不只是BN

Batch Normalization的作用:

1、加速收斂,并且有更好的收斂結(jié)果

2、可以使用更大的學(xué)習(xí)率,并且不必做精細(xì)的參數(shù)初始化

3、有正則化的效果


使用:

在Tensorflow中使用batch norm需要在更新op的使用加上


示例代碼:

```

def train_model(base_lr, loss, data_num):

? ? """

? ? train model

? ? :param base_lr: base learning rate

? ? :param loss: loss

? ? :param data_num:

? ? :return:

? ? train_op, lr_op

? ? """

? ? update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

? ? with tf.control_dependencies(update_ops):

? ? ? ? lr_factor = 0.1

? ? ? ? global_step = tf.Variable(0, trainable=False)

? ? ? ? #LR_EPOCH [8,14]

? ? ? ? #boundaried [num_batch,num_batch]

? ? ? ? boundaries = [int(epoch * data_num / config.BATCH_SIZE) for epoch in config.LR_EPOCH]

? ? ? ? #lr_values[0.01,0.001,0.0001,0.00001]

? ? ? ? lr_values = [base_lr * (lr_factor ** x) for x in range(0, len(config.LR_EPOCH) + 1)]

? ? ? ? #control learning rate

? ? ? ? lr_op = tf.train.piecewise_constant(global_step, boundaries, lr_values)

? ? ? ? optimizer = tf.train.MomentumOptimizer(lr_op, 0.9)

? ? ? ? train_op = optimizer.minimize(loss, global_step)

? ? return train_op, lr_op

```

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容