BN和LN 在pytorch

BN對于nchw的形狀來說,是對channel維度以外的形狀做歸一化,會有一點損失,但是能防止過擬合


batch_norm公式
# coding=utf8
import torch
from torch import nn

# track_running_stats=False,求當(dāng)前 batch 真實平均值和標(biāo)準(zhǔn)差,
# 而不是更新全局平均值和標(biāo)準(zhǔn)差
# affine=False, 只做歸一化,不乘以 gamma 加 beta(通過訓(xùn)練才能確定)
# num_features 為 feature map 的 channel 數(shù)目
# eps 設(shè)為 0,讓官方代碼和我們自己的代碼結(jié)果盡量接近
bn = nn.BatchNorm2d(num_features=3, eps=0, affine=False, track_running_stats=False)

# 乘 10000 為了擴大數(shù)值,如果出現(xiàn)不一致,差別更明顯
x = torch.rand(10, 3, 5, 5)*10000 
official_bn = bn(x)
# 把 channel 維度單獨提出來,而把其它需要求均值和標(biāo)準(zhǔn)差的維度融合到一起
x1 = x.permute(1,0,2,3).view(3, -1)
 
mu = x1.mean(dim=1).view(1,3,1,1)
# unbiased=False, 求方差時不做無偏估計(除以 N-1 而不是 N),和原始論文一致
# 個人感覺無偏估計僅僅是數(shù)學(xué)上好看,實際應(yīng)用中差別不大
std = x1.std(dim=1, unbiased=False).view(1,3,1,1)

my_bn = (x-mu)/std

diff=(official_bn-my_bn).sum()
print('diff={}'.format(diff)) # 差別是 10-5 級的,證明和官方版本基本一致

LN 對每個樣本的 C、H、W 維度上的數(shù)據(jù)求均值和標(biāo)準(zhǔn)差,保留 N 維度

import torch
from torch import nn

x = torch.rand(10, 3, 5, 5)*10000

# normalization_shape 相當(dāng)于告訴程序這本書有多少頁,每頁多少行多少列
# eps=0 排除干擾
# elementwise_affine=False 不作映射
# 這里的映射和 BN 以及下文的 IN 有區(qū)別,它是 elementwise 的 affine,
# 即 gamma 和 beta 不是 channel 維的向量,而是維度等于 normalized_shape 的矩陣
ln = nn.LayerNorm(normalized_shape=[3, 5, 5], eps=0, elementwise_affine=False)

official_ln = ln(x)

x1 = x.view(10, -1)
mu = x1.mean(dim=1).view(10, 1, 1, 1)
std = x1.std(dim=1,unbiased=False).view(10, 1, 1, 1)

my_ln = (x-mu)/std

diff = (my_ln-official_ln).sum()

print('diff={}'.format(diff)) # 差別和官方版本數(shù)量級在 1e-5

還有Group Normalization和Instance Normalization

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容