
最近時不時就有網(wǎng)友發(fā)私信讓我?guī)兔ebug程序,對于這些來信,我的回復(fù)通常都是一句話:“先跑通論文作者的開源代碼,在此基礎(chǔ)上再逐步修改數(shù)據(jù)集和模型?!?/p>
你或許會覺得我在擺譜,凈說些高大上的政治正確的空話,但這還真不是。
和其他的軟件程序不同,神經(jīng)網(wǎng)絡(luò)是個系統(tǒng)工程,數(shù)據(jù)、參數(shù)、模型內(nèi)部結(jié)構(gòu)、訓(xùn)練策略、學(xué)習(xí)率等等,這些因素不管哪一部分出錯,它都不會報錯,只是會輸出一些不是你想要的結(jié)果而已。
參數(shù)初始化就是這么一個容易被忽視的重要因素,因為不僅使用者對其重要性缺乏概念,而且這些操作都被TF、pytorch這些框架封裝了,你可能不知道的是,糟糕的參數(shù)初始化是會阻礙復(fù)雜非線性系統(tǒng)的訓(xùn)練的。
本文以MNIST手寫體數(shù)字識別模型為例來演示參數(shù)初始化對模型訓(xùn)練的影響。點擊這里查看源碼。
Xavier Initialization
早期的參數(shù)初始化方法普遍是將數(shù)據(jù)和參數(shù)normalize為高斯分布(均值0方差1),但隨著神經(jīng)網(wǎng)絡(luò)深度的增加,這方法并不能解決梯度消失問題。

Xavier初始化的作者,Xavier Glorot,在Understanding the difficulty of training deep feedforward neural networks論文中提出一個洞見:激活值的方差是逐層遞減的,這導(dǎo)致反向傳播中的梯度也逐層遞減。要解決梯度消失,就要避免激活值方差的衰減,最理想的情況是,每層的輸出值(激活值)保持高斯分布。

因此,他提出了Xavier初始化:bias初始化為0,為Normalize后的參數(shù)乘以一個rescale系數(shù):1/,n是輸入?yún)?shù)的個數(shù)。
公式的推導(dǎo)過程大致如下:
- 因為E(期望)等于均值,而輸入數(shù)據(jù)(x)和參數(shù)(W)的均值都是0,因此,
- 又因為x和W恒等分布(方差都是1),因此,
- 我們的目標是
,因此,
如果上述這段公式你看暈了,也沒關(guān)系,只要記住結(jié)果就好。
接下來,我們要做實驗來驗證Xavier的洞見。
def linear(x, w, b): return x @ w + b
def relu(x): return x.clamp_min(0.)
nh = 50
W1 = torch.randn(784, nh)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1)
b2 = torch.zeros(1)
z1 = linear(x_train, W1, b1)
print(z1.mean(), z1.std())
tensor(-0.8809) tensor(26.9281)
這是個簡單的線性回歸模型:,(W1, b1)和(W2, b2)分別是隱層和輸出層的參數(shù),W1/W2初始化為高斯分布,b1/b2初始為0。果然,第一個linear層的輸出值(z1)的均值和標準差就已經(jīng)發(fā)生了很大的變化。如果后續(xù)使用sigmoid作為激活函數(shù),那梯度消失就會很明顯。
現(xiàn)在我們按照Xavier的方法來初始化參數(shù):
W1 = torch.randn(784, nh) * math.sqrt(1 / 784)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1) * math.sqrt(1 / nh)
b2 = torch.zeros(1)
z1 = linear(x_train, W1, b1)
print(z1.mean(), z1.std())
tensor(0.1031) tensor(0.9458)
a1 = relu(z1)
a1.mean(), a1.std()
(tensor(0.4272), tensor(0.5915))
參數(shù)經(jīng)過Xavier初始化后,linear層的輸出值的分布沒有大的變化(),依舊接近高斯分布,但是好景不長,relu的激活值分布就開始跑偏了(
)。
Kaiming Initialization
Xavier初始化的問題在于,它只適用于線性激活函數(shù),但實際上,對于深層神經(jīng)網(wǎng)絡(luò)來說,線性激活函數(shù)是沒有價值,神經(jīng)網(wǎng)絡(luò)需要非線性激活函數(shù)來構(gòu)建復(fù)雜的非線性系統(tǒng)。今天的神經(jīng)網(wǎng)絡(luò)普遍使用relu激活函數(shù)。
Kaiming初始化的發(fā)明人kaiming he,在Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification論文中提出了針對relu的kaiming初始化。
因為relu會拋棄掉小于0的值,對于一個均值為0的data來說,這就相當于砍掉了一半的值,這樣一來,均值就會變大,前面Xavier初始化公式中E(x)=mean=0的情況就不成立了。根據(jù)新公式的推導(dǎo),最終得到新的rescale系數(shù):。更多細節(jié)請看論文的section 2.2。
W1 = torch.randn(784, nh) * math.sqrt(2 / 784)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1) * math.sqrt(2 / nh)
b2 = torch.zeros(1)
z1 = linear(x_train, W1, b1)
a1 = relu(z1)
a1.mean(), a1.std()
(tensor(0.4553), tensor(0.7339))
可以看到,Kaiming初始化的表現(xiàn)要優(yōu)于Xavier初始化,relu之后的輸出值標準差還有0.7339(浮動可以達到0.8+)。
實際上,Kaiming初始化已經(jīng)被Pytorch用作默認的參數(shù)初始化函數(shù)。
import torch.nn.init as init
W1 = torch.zeros(784, nh)
b1 = torch.zeros(nh)
W2 = torch.zeros(nh, 1)
b2 = torch.zeros(1)
init.kaiming_normal_(W1, mode='fan_out', nonlinearity='relu')
init.kaiming_normal_(W2, mode='fan_out')
z1 = linear(x_train, W1, b1)
a1 = relu(z1)
print("layer1: ", a1.mean(), a1.std())
z2 = linear(a1, W2, b2)
layer1: tensor(0.5583) tensor(0.8157)
tensor(1.1784) tensor(1.3209)
現(xiàn)在,方差的問題已經(jīng)解決了,接下來就是均值不為0的問題。因為在x軸上平移data并不會影響data的方差,因此,如果把relu的激活值左移5,結(jié)果會如何?
def linear(x, w, b):
return x @ w + b
def relu(x):
return x.clamp_min(0.) - 0.5
def model(x):
x = relu(linear(x, W1, b1))
print("layer1: ", x.mean(), x.std())
x = relu(linear(x, W2, b2))
print("layer2: ", x.mean(), x.std())
x = linear(x, W3, b3)
print("layer3: ", x.mean(), x.std())
return x
nh = [100, 50]
W1 = torch.zeros(784, nh[0])
b1 = torch.zeros(nh[0])
W2 = torch.zeros(nh[0], nh[1])
b2 = torch.zeros(nh[1])
W3 = torch.zeros(nh[1], 1)
b3 = torch.zeros(1)
init.kaiming_normal_(W1, mode='fan_out')
init.kaiming_normal_(W2, mode='fan_out')
init.kaiming_normal_(W3, mode='fan_out')
_ = model(x_train)
layer1: tensor(0.0383) tensor(0.7993)
layer2: tensor(0.0075) tensor(0.7048)
layer3: tensor(-0.2149) tensor(0.4493)
結(jié)果出乎意料的好,這個三層的模型在沒有添加batchnorm的情況下,每層的輸入值和輸出值都接近高斯分布,雖然數(shù)據(jù)方差是會逐層遞減,但相比normalize初始化和Xavier初始化要好很多。
最后,因為Kaiming初始化是pytorch的默認初始化函數(shù),因此我又用pytorch提供的nn.Linear()和nn.Relu()來構(gòu)建相同的模型對比測試,結(jié)果是大跌眼鏡。
class Model(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(784, nh[0])
self.lin2 = nn.Linear(nh[0], nh[1])
self.lin3 = nn.Linear(nh[1], 1)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.lin1(x))
print("layer 1: ", x.mean().item(), x.std().item())
x = self.relu(self.lin2(x))
print("layer 2: ", x.mean().item(), x.std().item())
x = self.relu(self.lin3(x))
print("layer 3: ", x.mean().item(), x.std().item())
return x
m = Model()
_ = m(x_train)
layer 1: 0.2270725518465042 0.32707411050796
layer 2: 0.033514849841594696 0.23475737869739532
layer 3: 0.013271240517497063 0.09185370802879333
可以看到,第三層的輸出已經(jīng)均值為0、方差為0。去看nn.Linear()類的代碼時會看到,它在做初始化時會傳入?yún)?shù)a=math.sqrt(5)。我們知道,當輸入為負數(shù)時,leaky relu的梯度為,
,參數(shù)a就是這個
。雖然kaiming_uniform_()的默認網(wǎng)絡(luò)要使用的激活函數(shù)是leaky relu,但a默認值為0,此時leaky relu就等于relu。但現(xiàn)在數(shù)據(jù)存在負數(shù),因此,mean相比relu模型更接近于0,甚至E(x) > 0的假設(shè)都不成立了,因此,rescale系數(shù)就不準確了,nn.Linear()才會有這樣的表現(xiàn)。
def reset_parameters(self):
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
END
本文通過Xavier和Kaiming初始化來展現(xiàn)了參數(shù)初始化的重要性,因為糟糕的初始化容易讓神經(jīng)網(wǎng)絡(luò)陷入梯度消失的陷阱中。
References
- Understanding the difficulty of training deep feedforward neural networks
- Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
- understanding-xavier-initialization-in-deep-neural-networks
- https://github.com/fastai/course-v3/blob/master/nbs/dl2/02_fully_connected.ipynb
歡迎關(guān)注和點贊,你的鼓勵將是我創(chuàng)作的動力
歡迎轉(zhuǎn)發(fā)至朋友圈,公眾號轉(zhuǎn)載請后臺留言申請授權(quán)~