PyTorch常用的初始化和正則

原文: https://www.pytorchtutorial.com/pytorch-goodies/

模型統(tǒng)計(jì)數(shù)據(jù)(Model Statistics)

統(tǒng)計(jì)參數(shù)總數(shù)量
num_params = sum(param.numel() for param in model.parameters())

參數(shù)初始化(Weight Initialization)

PyTorch 中參數(shù)的默認(rèn)初始化在各個(gè)層的 reset_parameters() 方法中。例如:nn.Linear 和 nn.Conv2D,都是在 [-limit, limit] 之間的均勻分布(Uniform distribution),其中 limit 是 1. / sqrt(fan_in) ,fan_in 是指參數(shù)張量(tensor)的輸入單元的數(shù)量

下面是幾種常見(jiàn)的初始化方式。

Xavier Initialization

Xavier初始化的基本思想是保持輸入和輸出的方差一致,這樣就避免了所有輸出值都趨向于0。這是通用的方法,適用于任何激活函數(shù)。

# 默認(rèn)方法
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform(m.weight)

也可以使用 gain 參數(shù)來(lái)自定義初始化的標(biāo)準(zhǔn)差來(lái)匹配特定的激活函數(shù):

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform(m.weight(), gain=nn.init.calculate_gain(\\'relu\\'))

參考資料:

He et. al Initialization

He initialization的思想是:在ReLU網(wǎng)絡(luò)中,假定每一層有一半的神經(jīng)元被激活,另一半為0。推薦在ReLU網(wǎng)絡(luò)中使用。

# he initialization
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal(m.weight, mode=\\'fan_in\\')

正交初始化(Orthogonal Initialization)

主要用以解決深度網(wǎng)絡(luò)下的梯度消失、梯度爆炸問(wèn)題,在RNN中經(jīng)常使用的參數(shù)初始化方法

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.orthogonal(m.weight)

Batchnorm Initialization

在非線性激活函數(shù)之前,我們想讓輸出值有比較好的分布(例如高斯分布),以便于計(jì)算梯度和更新參數(shù)。Batch Normalization 將輸出值強(qiáng)行做一次 Gaussian Normalization 和線性變換:

image.png
for m in model:
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant(m.weight, 1)
        nn.init.constant(m.bias, 0)

參數(shù)正則化(Weight Regularization)

L2/L1 Regularization

機(jī)器學(xué)習(xí)中幾乎都可以看到損失函數(shù)后面會(huì)添加一個(gè)額外項(xiàng),常用的額外項(xiàng)一般有兩種,稱作L1正則化和L2正則化,或者L1范數(shù)和L2范數(shù)。

L1 正則化和 L2 正則化可以看做是損失函數(shù)的懲罰項(xiàng)。所謂 “懲罰” 是指對(duì)損失函數(shù)中的某些參數(shù)做一些限制。

L1 正則化是指權(quán)值向量 w 中各個(gè)元素的絕對(duì)值之和,通常表示為 ||w||1
L2 正則化是指權(quán)值向量 w 中各個(gè)元素的平方和然后再求平方根,通常表示為 ||w||2
下面是L1正則化和L2正則化的作用,這些表述可以在很多文章中找到。

L1 正則化可以產(chǎn)生稀疏權(quán)值矩陣,即產(chǎn)生一個(gè)稀疏模型,可以用于特征選擇
L2 正則化可以防止模型過(guò)擬合(overfitting);一定程度上,L1也可以防止過(guò)擬合

  • L2 正則化的實(shí)現(xiàn)方法:
reg = 1e-6
l2_loss = Variable(torch.FloatTensor(1), requires_grad=True)
for name, param in model.named_parameters():
    if \'bias\' not in name:
        l2_loss = l2_loss   (0.5 * reg * torch.sum(torch.pow(W, 2)))
  • L1 正則化的實(shí)現(xiàn)方法:
reg = 1e-6
l1_loss = Variable(torch.FloatTensor(1), requires_grad=True)
for name, param in model.named_parameters():
    if \'bias\' not in name:
        l1_loss = l1_loss   (reg * torch.sum(torch.abs(W)))
  • Orthogonal Regularization
reg = 1e-6
orth_loss = Variable(torch.FloatTensor(1), requires_grad=True)
for name, param in model.named_parameters():
    if \'bias\' not in name:
        param_flat = param.view(param.shape[0], -1)
        sym = torch.mm(param_flat, torch.t(param_flat))
        sym -= Variable(torch.eye(param_flat.shape[0]))
        orth_loss = orth_loss   (reg * sym.sum())
  • Max Norm Constraint
    簡(jiǎn)單來(lái)講就是對(duì) w 的指直接進(jìn)行限制。
ef max_norm(model, max_val=3, eps=1e-8):
    for name, param in model.named_parameters():
        if \'bias\' not in name:
            norm = param.norm(2, dim=0, keepdim=True)
            desired = torch.clamp(norm, 0, max_val)
            param = param * (desired / (eps   norm))
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容