一文搞懂深度網(wǎng)絡(luò)初始化(Xavier and Kaiming initialization)

最近時不時就有網(wǎng)友發(fā)私信讓我?guī)兔ebug程序,對于這些來信,我的回復(fù)通常都是一句話:“先跑通論文作者的開源代碼,在此基礎(chǔ)上再逐步修改數(shù)據(jù)集和模型?!?/p>

你或許會覺得我在擺譜,凈說些高大上的政治正確的空話,但這還真不是。

和其他的軟件程序不同,神經(jīng)網(wǎng)絡(luò)是個系統(tǒng)工程,數(shù)據(jù)、參數(shù)、模型內(nèi)部結(jié)構(gòu)、訓(xùn)練策略、學(xué)習(xí)率等等,這些因素不管哪一部分出錯,它都不會報錯,只是會輸出一些不是你想要的結(jié)果而已。

參數(shù)初始化就是這么一個容易被忽視的重要因素,因為不僅使用者對其重要性缺乏概念,而且這些操作都被TF、pytorch這些框架封裝了,你可能不知道的是,糟糕的參數(shù)初始化是會阻礙復(fù)雜非線性系統(tǒng)的訓(xùn)練的。

本文以MNIST手寫體數(shù)字識別模型為例來演示參數(shù)初始化對模型訓(xùn)練的影響。點擊這里查看源碼。

Xavier Initialization

早期的參數(shù)初始化方法普遍是將數(shù)據(jù)和參數(shù)normalize為高斯分布(均值0方差1),但隨著神經(jīng)網(wǎng)絡(luò)深度的增加,這方法并不能解決梯度消失問題。

Figure 1: XavierInitialisation.pdf

Xavier初始化的作者,Xavier Glorot,在Understanding the difficulty of training deep feedforward neural networks論文中提出一個洞見:激活值的方差是逐層遞減的,這導(dǎo)致反向傳播中的梯度也逐層遞減。要解決梯度消失,就要避免激活值方差的衰減,最理想的情況是,每層的輸出值(激活值)保持高斯分布。

Figure 2: xavier initialization

因此,他提出了Xavier初始化:bias初始化為0,為Normalize后的參數(shù)乘以一個rescale系數(shù):1/\sqrt n,n是輸入?yún)?shù)的個數(shù)。

公式的推導(dǎo)過程大致如下:

  • y = ax + b = W\vec x + \vec b = \vec w_1x_1 + \vec w_2x_2 + ... + \vec w_nx_n + \vec b
  • var(y) = var(\vec w_1x_1 + ... + \vec w_nx_n + \vec b) = var(\vec w_1x_1) + ... + var(\vec w_nx_n)
  • var(\vec w_ix_i) = E(x_i)^2var(\vec w_i) + E(\vec w_i)^2var(x_i) + var(\vec w_i)var(x_i)
  • 因為E(期望)等于均值,而輸入數(shù)據(jù)(x)和參數(shù)(W)的均值都是0,因此,var(\vec w_ix_i) = var(\vec w_i)var(x_i)
  • var(y) = var(\vec w_1)var(x_1) + var(\vec w_2)var(x_2) + ... + var(\vec w_n)var(x_n)
  • 又因為x和W恒等分布(方差都是1),因此,var(y) = N * var(\vec w_i)var(x_i)
  • 我們的目標是var(y) = var(x),因此,N * var(\vec w_i) = 1, var(\vec w_i) = 1/N
  • std = \sqrt {var}, std(\vec w_i) = 1/\sqrt N

如果上述這段公式你看暈了,也沒關(guān)系,只要記住結(jié)果就好。

接下來,我們要做實驗來驗證Xavier的洞見。

def linear(x, w, b): return x @ w + b

def relu(x): return x.clamp_min(0.)

nh = 50
W1 = torch.randn(784, nh)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1)
b2 = torch.zeros(1)

z1 = linear(x_train, W1, b1)
print(z1.mean(), z1.std())

tensor(-0.8809) tensor(26.9281)

這是個簡單的線性回歸模型:y = ax + b,(W1, b1)和(W2, b2)分別是隱層和輸出層的參數(shù),W1/W2初始化為高斯分布,b1/b2初始為0。果然,第一個linear層的輸出值(z1)的均值和標準差就已經(jīng)發(fā)生了很大的變化。如果后續(xù)使用sigmoid作為激活函數(shù),那梯度消失就會很明顯。

現(xiàn)在我們按照Xavier的方法來初始化參數(shù):

W1 = torch.randn(784, nh) * math.sqrt(1 / 784)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1) * math.sqrt(1 / nh)
b2 = torch.zeros(1)

z1 = linear(x_train, W1, b1)
print(z1.mean(), z1.std())

tensor(0.1031) tensor(0.9458)

a1 = relu(z1)
a1.mean(), a1.std()

(tensor(0.4272), tensor(0.5915))

參數(shù)經(jīng)過Xavier初始化后,linear層的輸出值的分布沒有大的變化(U[0.1031, 0.9458]),依舊接近高斯分布,但是好景不長,relu的激活值分布就開始跑偏了(U[0.4272, 0.5915])。

Kaiming Initialization

Xavier初始化的問題在于,它只適用于線性激活函數(shù),但實際上,對于深層神經(jīng)網(wǎng)絡(luò)來說,線性激活函數(shù)是沒有價值,神經(jīng)網(wǎng)絡(luò)需要非線性激活函數(shù)來構(gòu)建復(fù)雜的非線性系統(tǒng)。今天的神經(jīng)網(wǎng)絡(luò)普遍使用relu激活函數(shù)。

Kaiming初始化的發(fā)明人kaiming he,在Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification論文中提出了針對relu的kaiming初始化。

因為relu會拋棄掉小于0的值,對于一個均值為0的data來說,這就相當于砍掉了一半的值,這樣一來,均值就會變大,前面Xavier初始化公式中E(x)=mean=0的情況就不成立了。根據(jù)新公式的推導(dǎo),最終得到新的rescale系數(shù):\sqrt {2/n}。更多細節(jié)請看論文的section 2.2。

W1 = torch.randn(784, nh) * math.sqrt(2 / 784)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1) * math.sqrt(2 / nh)
b2 = torch.zeros(1)

z1 = linear(x_train, W1, b1)
a1 = relu(z1)
a1.mean(), a1.std()

(tensor(0.4553), tensor(0.7339))

可以看到,Kaiming初始化的表現(xiàn)要優(yōu)于Xavier初始化,relu之后的輸出值標準差還有0.7339(浮動可以達到0.8+)。

實際上,Kaiming初始化已經(jīng)被Pytorch用作默認的參數(shù)初始化函數(shù)。

import torch.nn.init as init

W1 = torch.zeros(784, nh)
b1 = torch.zeros(nh)
W2 = torch.zeros(nh, 1)
b2 = torch.zeros(1)

init.kaiming_normal_(W1, mode='fan_out', nonlinearity='relu')
init.kaiming_normal_(W2, mode='fan_out')
z1 = linear(x_train, W1, b1)
a1 = relu(z1)
print("layer1: ", a1.mean(), a1.std())
z2 = linear(a1, W2, b2)

layer1:  tensor(0.5583) tensor(0.8157)
tensor(1.1784) tensor(1.3209)

現(xiàn)在,方差的問題已經(jīng)解決了,接下來就是均值不為0的問題。因為在x軸上平移data并不會影響data的方差,因此,如果把relu的激活值左移5,結(jié)果會如何?

def linear(x, w, b):
  return x @ w + b

def relu(x):
  return x.clamp_min(0.) - 0.5

def model(x):
  x = relu(linear(x, W1, b1))
  print("layer1: ", x.mean(), x.std())
  x = relu(linear(x, W2, b2))
  print("layer2: ", x.mean(), x.std())
  x = linear(x, W3, b3)
  print("layer3: ", x.mean(), x.std())
  return x

nh = [100, 50]
W1 = torch.zeros(784, nh[0])
b1 = torch.zeros(nh[0])
W2 = torch.zeros(nh[0], nh[1])
b2 = torch.zeros(nh[1])
W3 = torch.zeros(nh[1], 1)
b3 = torch.zeros(1)

init.kaiming_normal_(W1, mode='fan_out')
init.kaiming_normal_(W2, mode='fan_out')
init.kaiming_normal_(W3, mode='fan_out')
_ = model(x_train)

layer1:  tensor(0.0383) tensor(0.7993)
layer2:  tensor(0.0075) tensor(0.7048)
layer3:  tensor(-0.2149) tensor(0.4493)

結(jié)果出乎意料的好,這個三層的模型在沒有添加batchnorm的情況下,每層的輸入值和輸出值都接近高斯分布,雖然數(shù)據(jù)方差是會逐層遞減,但相比normalize初始化和Xavier初始化要好很多。

最后,因為Kaiming初始化是pytorch的默認初始化函數(shù),因此我又用pytorch提供的nn.Linear()和nn.Relu()來構(gòu)建相同的模型對比測試,結(jié)果是大跌眼鏡。

class Model(nn.Module):
  def __init__(self):
    super().__init__()
    self.lin1 = nn.Linear(784, nh[0])
    self.lin2 = nn.Linear(nh[0], nh[1])
    self.lin3 = nn.Linear(nh[1], 1)
    self.relu = nn.ReLU()
  
  def forward(self, x):
    x = self.relu(self.lin1(x))
    print("layer 1: ", x.mean().item(), x.std().item())
    x = self.relu(self.lin2(x))
    print("layer 2: ", x.mean().item(), x.std().item())
    x = self.relu(self.lin3(x))
    print("layer 3: ", x.mean().item(), x.std().item())
    return x

m = Model()
_ = m(x_train)

layer 1:  0.2270725518465042 0.32707411050796
layer 2:  0.033514849841594696 0.23475737869739532
layer 3:  0.013271240517497063 0.09185370802879333

可以看到,第三層的輸出已經(jīng)均值為0、方差為0。去看nn.Linear()類的代碼時會看到,它在做初始化時會傳入?yún)?shù)a=math.sqrt(5)。我們知道,當輸入為負數(shù)時,leaky relu的梯度為[0, \infty],x = \lambda x,參數(shù)a就是這個\lambda。雖然kaiming_uniform_()的默認網(wǎng)絡(luò)要使用的激活函數(shù)是leaky relu,但a默認值為0,此時leaky relu就等于relu。但現(xiàn)在數(shù)據(jù)存在負數(shù),因此,mean相比relu模型更接近于0,甚至E(x) > 0的假設(shè)都不成立了,因此,rescale系數(shù)就不準確了,nn.Linear()才會有這樣的表現(xiàn)。

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

END

本文通過Xavier和Kaiming初始化來展現(xiàn)了參數(shù)初始化的重要性,因為糟糕的初始化容易讓神經(jīng)網(wǎng)絡(luò)陷入梯度消失的陷阱中。

References


歡迎關(guān)注和點贊,你的鼓勵將是我創(chuàng)作的動力

歡迎轉(zhuǎn)發(fā)至朋友圈,公眾號轉(zhuǎn)載請后臺留言申請授權(quán)~

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容