Batch Normalization
相反,Batch normalization被立即認(rèn)為是有巨大影響的。當(dāng)它剛出現(xiàn)時(shí),就馬上得到大家的認(rèn)同。我清楚地記得在2015,每個(gè)人都在談?wù)撍?,這是因?yàn)樗男Ч鴮?shí)顯著,如這張圖所示。

這是當(dāng)時(shí)最先進(jìn)的Inception結(jié)構(gòu)的ImageNet模型。這是(黑色虛線)它得到一個(gè)很好的結(jié)果花的時(shí)間,然后他們用這個(gè)叫btach norm的新東西做了同樣的事情,它們(藍(lán)色實(shí)線)做得非常非??臁?斓阶屓藗儗?duì)它的神速嘆為觀止。
他們管這個(gè)叫批量歸一化(Batch Normalization)。它通過(guò)減少內(nèi)部協(xié)變量轉(zhuǎn)移(internal covariate shift)來(lái)加速訓(xùn)練。什么是內(nèi)部協(xié)變量轉(zhuǎn)移?其實(shí)它不重要。因?yàn)閎atch norm是一些研究者靠直覺(jué)提出來(lái)的、要試試看的東西。他們?cè)囼?yàn)了,它的效果很好,他們事后補(bǔ)充了一些數(shù)學(xué)分析來(lái)試著說(shuō)明為什么需要它。事實(shí)證明,他們完全錯(cuò)了。

最近兩個(gè)月,有兩篇論文證明了這一點(diǎn),而人們花了三年才真正明白,在過(guò)去兩個(gè)月里,有兩篇論文指出,batch normalization根本沒(méi)有減少內(nèi)部協(xié)變量轉(zhuǎn)移。并且,就算真的減少了,也也跟它為什么能加速訓(xùn)練無(wú)關(guān)。我認(rèn)為這是一個(gè)有意義的領(lǐng)悟,也說(shuō)明為什么我們要專(zhuān)注于做實(shí)踐者和實(shí)驗(yàn)主義者,并提升我們的直覺(jué)。
batch norm的作用是由這篇論文里的圖(the positive impact of BatchNorm on training might be somewhat serendipitous)可以說(shuō)明。橫軸是步數(shù)(steps)或者batch(x-axis),縱軸是損失(y-axis)。紅線是沒(méi)用batch norm訓(xùn)練時(shí)的情況,波動(dòng)很大。藍(lán)線是用了batch norm后,訓(xùn)練的情況,波動(dòng)小了很多。這意味著,使用batch norm你可以提高你的學(xué)習(xí)率。因?yàn)檫@些波動(dòng)劇烈的部分代表這你的權(quán)重設(shè)置帶來(lái)的不確定性的時(shí)段,程序跳進(jìn)某個(gè)權(quán)重空間中的糟糕區(qū)域,再也挑不出來(lái)了。如果波動(dòng)沒(méi)那么劇烈,那么就可以用更好的學(xué)習(xí)率進(jìn)行訓(xùn)練,這就是實(shí)際的情況。
這就是這種算法了。它很簡(jiǎn)單。算法會(huì)取一個(gè)mini batch。我們有了一個(gè)mini batch,記住,這是一個(gè)網(wǎng)絡(luò)層,進(jìn)入這個(gè)層的是激活值。網(wǎng)絡(luò)層輸入的是激活值。激活值用等來(lái)表示。

- 首先,我們要計(jì)算這些激活值的平均值(mean),均值就是激活值的和除以其數(shù)量就是平均值。
- 第二件事,我們找到激活值的方差(variance),差值就是激活值與均值之差的的平方和,再除以數(shù)量是就方差。
- 然后做標(biāo)準(zhǔn)化(normalize),激活值減去均值除以標(biāo)準(zhǔn)差,就是標(biāo)準(zhǔn)化。這實(shí)際上這幾步不是很重要。我們?cè)?jīng)以為很重要,但后來(lái)發(fā)現(xiàn)不是。真正重要的部分是下面的東西。
- 我們?nèi)∵@些值,加上一個(gè)偏置向量(這里把它叫做
)。
我們之前已經(jīng)看過(guò)了。我們之前用過(guò)偏差項(xiàng)。所以我們會(huì)像之前一樣,加上一個(gè)偏差項(xiàng)。然后,我們要用另外一個(gè)和bias很像的東西,但不是加上它,我們會(huì)乘以它。這些參數(shù)
和
是要學(xué)習(xí)的參數(shù)。
記住,在神經(jīng)網(wǎng)絡(luò)里,只有兩種數(shù)字:激活值和參數(shù)。這些是參數(shù)。是用梯度下降學(xué)習(xí)到的東西。只是一個(gè)普通的bias層,
是一個(gè)做乘法的bias層。沒(méi)有人這樣叫它,但它就是這樣的。
就是使用乘法而非加法的偏差項(xiàng)。這就是batch norm。這就是這一層做的事。
為什么batch norm可以實(shí)現(xiàn)了不起的結(jié)果?我不清楚有沒(méi)有人之前準(zhǔn)確地把這寫(xiě)過(guò)關(guān)于這個(gè)問(wèn)題的文章。如果有,抱歉這里沒(méi)有引用它,因?yàn)槲覜](méi)有看過(guò)。讓我解釋下。究竟發(fā)生了什么。我們的預(yù)測(cè)值是各權(quán)重的函數(shù),可以達(dá)到上百萬(wàn)個(gè)權(quán)重值,
也是一個(gè)層輸入的函數(shù)。
這個(gè)函數(shù)是神經(jīng)網(wǎng)絡(luò)函數(shù),不管在神經(jīng)網(wǎng)絡(luò)里發(fā)生了什么。我們的損失值,假定是均方誤差,就是實(shí)際值減去預(yù)測(cè)值的平方。
假設(shè)我們要預(yù)測(cè)電影評(píng)分的結(jié)果,預(yù)測(cè)值在1到5之間。我們一直在訓(xùn)練模型,最后的激活值在-1到1之間。這些激活值與實(shí)際需要的值相差太遠(yuǎn)。縮放和均值都沒(méi)用,我們應(yīng)該怎么做?一個(gè)方法是用一組新的權(quán)重,讓傳播值增加,讓平均值增長(zhǎng)。但是這很難,因?yàn)樗羞@些參數(shù)有很密切而復(fù)雜的相互作用。我們得到的所有非線性因素匯聚在一起,因此想要提高數(shù)值,需要在這種復(fù)雜的情形下找到一條出路。我們使用像動(dòng)量、Adam之類(lèi)的東西來(lái)幫助我們,但還要做大量的旋轉(zhuǎn)(twidding)才能得到我們想要的。這會(huì)花很長(zhǎng)時(shí)間,學(xué)習(xí)曲線會(huì)跌宕起伏。
我們這樣做怎么樣?在的公示后面乘以
,再加上
會(huì)怎么樣?

我們多加了兩個(gè)參數(shù)向量?,F(xiàn)在它就簡(jiǎn)單了。為了提高縮放比例,這個(gè)數(shù) 得到直接梯度來(lái)增加縮放比例。為了改變均值(向量b)可以直接得到梯度來(lái)改變均值?!総o change the mean that number has a direct gradient to change the mean there's no interactiions or complexities】?jī)烧邲](méi)有相互作用和復(fù)雜性,都是直接的升降或縮放,這就是batch norm做的事。所以batch norm讓使輸出變大變小這個(gè)重要工作更容易做到,使輸出提升或下降,放大或縮小。這就是為什么我們能得到這樣的結(jié)果。
這些細(xì)節(jié),在某種意義上,不是特別重要。真正重要的是你肯定想要用它。如果不用它,也會(huì)用類(lèi)似的東西?,F(xiàn)在,有很多其它類(lèi)型的歸一化方法,但batch norm效果很好。我們?cè)趂astai里用的另一種的標(biāo)準(zhǔn)化方法主要是weight norm,這是最近幾個(gè)月新開(kāi)發(fā)的。

這就是batch norm,我們?yōu)槊恳粋€(gè)連續(xù)變量創(chuàng)建了一個(gè)btach norm層。n_cont是連續(xù)變量的數(shù)量。在fastai里,n_something通常代表這個(gè)東西的數(shù)量。cont通常代表連續(xù)(continuous)。這里是我們使用它的地方。我們?nèi)〉竭B續(xù)變量,然后將其送入batch norm層。


你可以在模型里的這個(gè)地方看到它。

一個(gè)值得注意的東西是這里的動(dòng)量(momentum)。這不是優(yōu)化里的動(dòng)量法,是exponentially weighted moving average里的動(dòng)量。具體來(lái)說(shuō),這個(gè)(batch norm算法里的)平均值和標(biāo)準(zhǔn)差,我們沒(méi)有為每一個(gè)mini batch用不同的平均值和標(biāo)準(zhǔn)差。如果這樣做了,這些值的變化會(huì)很大,從而導(dǎo)致難以訓(xùn)練。因此,我們采用平均值和標(biāo)準(zhǔn)差的指數(shù)加權(quán)移動(dòng)平均值。如果你不記得這是什么意思,就回去看下上周的課程,復(fù)習(xí)下指數(shù)加權(quán)移動(dòng)平均,上節(jié)課我們?cè)趀xcel里實(shí)現(xiàn)動(dòng)量法和Adam算法的梯度平方項(xiàng)。

你可以通過(guò)往pytorch構(gòu)造器傳入一個(gè)不同的值來(lái)改變一個(gè)batch norm層里動(dòng)量的數(shù)量。如果你用了一個(gè)較小的數(shù),這意味著從一個(gè)mini batch間的平均值和標(biāo)準(zhǔn)差會(huì)變小,它們受正則化的效果欠佳。輸入一個(gè)較大的數(shù),意味著均值和標(biāo)準(zhǔn)差在mini-batch間的差異也較大,這樣的正則化的效果也會(huì)較好。這樣的訓(xùn)練會(huì)更好,因?yàn)槠鋮?shù)化的程度較好,包含均值和標(biāo)準(zhǔn)差的動(dòng)量項(xiàng),產(chǎn)生了這種較好的正則化效果。
當(dāng)你用了batch norm,你需要用更大的學(xué)習(xí)率。這是我們的模型。你可以運(yùn)行lr_find,你可以看下結(jié)果。然后可以用fit()函數(shù)訓(xùn)練,在保存結(jié)果。畫(huà)損失曲線,接下來(lái)再次用fit()訓(xùn)練:
learn.lr_find()
learn.recorder.plot()

learn.fit_one_cycle(5, 1e-3, wd=0.2)

learn.load('1');
learn.fit_one_cycle(5, 3e-4)

最終得到0.103。競(jìng)賽的第十名是0.108,這看起來(lái)很好。但不要太當(dāng)回事,因?yàn)槟闳绻呀Y(jié)果提交到Kaggle比賽, 還要用真實(shí)的訓(xùn)練集, 但可以看到,我們的模型至少是2015年的先進(jìn)水平了。而且我說(shuō)過(guò),這些模型至今基本上沒(méi)有做過(guò)架構(gòu)上的改進(jìn)。當(dāng)時(shí)沒(méi)有batch norm,我們添加了batch norm應(yīng)該能得到更好的結(jié)果,并且運(yùn)行得更快。他們的模型用一個(gè)比較低的學(xué)習(xí)率,訓(xùn)練了很久??梢钥吹?,這個(gè)用了不到45分鐘。這很好很快。
提問(wèn): 你使用dropout和其他正則化方法比如權(quán)重衰減(weight decay),L2正則化等等的比例是怎樣的?
記得嗎,L2正則化和權(quán)重衰減是做同一件事的兩種方式,我們應(yīng)該總是用權(quán)重衰減,不用L2正則化。我們現(xiàn)在的方法有權(quán)重衰減,可產(chǎn)生正則化效果的batch norm。batch norm有正則化的效果,還有我們馬上要學(xué)習(xí)的數(shù)據(jù)增強(qiáng)(data augmentation )以及dropout。我們應(yīng)該總是用batch norm。它很簡(jiǎn)單。我們等下會(huì)學(xué)數(shù)據(jù)增強(qiáng)。剩下就是比較dropout和權(quán)重衰減。我不知道。我沒(méi)有看到過(guò)有人做過(guò)令人信服的關(guān)于如何結(jié)合這兩個(gè)東西的研究??梢钥偸怯闷渲幸粋€(gè),而不用另一個(gè)嗎?為什么能?為什么不能?我想沒(méi)人可以解答。在實(shí)踐中,看起來(lái),通常你需要同時(shí)用這兩個(gè)。你通常需要用權(quán)重衰減,但也經(jīng)常需要用dropout。說(shuō)實(shí)話,我不知道為什么。我沒(méi)有見(jiàn)過(guò)有人解釋為什么,如何做選擇。這個(gè)需要你去嘗試,來(lái)獲得這樣一種感覺(jué),知道對(duì)你的問(wèn)題,哪個(gè)方法是有效的。我認(rèn)為我們?cè)诿總€(gè)learner里提供的默認(rèn)方法可以在大多數(shù)場(chǎng)景工作得很好。但是,肯定可以去嘗試你自己定義的值。