BN(Batch Normalization)和TF2的BN層

1、Batch Normalization

在討論Batch Normalization之前,先討論一下feature scaling可能會對后續(xù)的討論有很大的幫助。feature scaling,即特征歸一化,是機器學習領域中一種通用的數據預處理方法,其目的是將模式向量中尺度不一致的不同維度特征歸一到同一尺度,以保證訓練速度與精度。

假設有一個大小為n的數據集X^{1,...,n},其中每個模式向量有m個維度的特征X^{i} = {x_{1,...,m}}。如果在這個數據集中,第i維的特征x_{i}服從均值為0、方差為1的高斯分布,而第j維的特征x_{j}服從均值為200,方差為1的高斯分布,那么這個數據集將難以用于模型訓練。其原因在于,x_{i}x_{j}的分布相差甚遠,模型中與x_{i}相關的參數只進行很小的改變往往難以對結果造成顯著性的改變,而與x_{j}相關的參數則恰恰相反,這讓訓練過程的learning rate很難統(tǒng)一,過小收斂過慢,過大則可能不收斂。

為了解決以上問題,feature scaling對每個維度的特征都進行如下變換,變換的結果則是所有維度的特征都歸一化到均值為0、方差為1這個尺度:

\hat{x}_{i}^{r} = \frac{x_{i}^{r} - m_{i}}{\sigma_{i}}

以上方法對于模型的訓練是十分有效的,而在深度神經網絡的研究中,研究人員延續(xù)這種思路提出了Batch Normalization。相對于傳統(tǒng)的模型,深度神經網絡遇到的問題是,隨著網絡深度增加,網絡中一個小小的改變可能在經過若干層的傳播之后令整個網絡出現極大的波動,如bp過程中的梯度消失與爆炸(事實上,ReLU、有效的初始化、設置更小的learning rate等方法都能用于解決該問題)。

Batch Normalization可以用于解決深度神經網絡的Internal Covariate Shift問題,其實質是:使用一定的規(guī)范化方法,把每個隱層神經元的輸入控制為均值為0、方差為1的標準正態(tài)分布,使得非線性變換函數的輸入值落入對輸入比較敏感的區(qū)域(如Sigmoid函數只在0附近具有較好的梯度),以此避免梯度消失問題。

在Batch Normalization中,Batch是指每次訓練時網絡的輸入都是一批訓練數據,這一批數據會同時經過網絡的一層,然后在經過WX^{i}+b=Z^{i}之后,網絡再一起對這一批數據的Z^{i}做規(guī)范化處理。當然,Batch Normalization的論文中還使用了兩個參數處理規(guī)范化之后的數據,即\hat{Z}^{i} = \gamma\odot\tilde{z}^{i}+\beta。事實上,如果\gamma=\sigma,\beta=\mu,這就等價于Normalization的一個逆運算,那么normalization的意義似乎就不存在了,但是,事實并非如此,因為\mu,\sigmaZ^{i}相關,而\gamma,\beta則完全獨立,二者并不等價。合理的解釋是,后續(xù)操作是為了防止normalization矯枉過正增加的人為擾動。Batch Normalization的具體結構如下所示:

bn

2、TF2的BN層

在tensorflow2中使用BN層的方法如下,需要注意的是BN層在訓練和推理兩種模式下存在不同。

BN層有4*num_channels個參數,每4個參數對應一個通道,分別是\mu, \sigma, \beta, \gamma。其中\beta, \gamma和其他層的參數的邏輯是一致的,訓練時不斷調整,推理時不再改變(即只有優(yōu)化器更新參數時才會改變)。而\mu, \sigma不同,在推理時,即使沒有優(yōu)化器更新參數,也可能不斷變化。這兩個參數受BatchNormalization層的參數training控制,當training=False時,二者為移動均值和方差(固定);當training=True時,二者與每次輸入的batch相關,\mu, \sigma是當前batch的均值、方差。

綜上,在使用TF2的BN層時,推理時需要指定當前模式為推理模式,方法如下(還存在其他方法,如顯示地聲明training參數為False)。此外,BN層也有trainable參數,和其他層一樣,該參數意在凍結\beta, \gamma兩個參數,但是當trainable=True時,該BN層會以推理模式運行,\mu, \sigma兩個參數也就隨之固定。

import tensorflow as tf

# BN層的使用
tf.keras.layers.BatchNormalization()(x)

# 訓練、推理模式的選擇,0-推理、1-訓練
tf.keras.backend.set_learning_phase(0)
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

友情鏈接更多精彩內容