For an experiment I implemented batch normalization by hand with PyTorch functions, but the results did not match PyTorch's own layer. I eventually found the fix on the PyTorch forum!
Basics
Initial implementation
The referenced blog post gives a Python implementation; I replaced its numpy functions with the corresponding PyTorch functions:
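```python
import torch

def fowardbn(x, gam, beta):
    '''
    x: input of shape (N, D)
    gam, beta: scale and shift parameters of shape (D,)
    '''
    momentum = 0.1
    eps = 1e-05
    # NOTE: these running statistics restart from 0 / 1 on every call and are
    # never returned, so they do not affect the output
    running_mean = 0
    running_var = 1
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    # Per-feature batch statistics
    mean = x.mean(dim=0)
    var = x.var(dim=0)
    # bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    # Intermediates kept for a manual backward pass
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache
```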
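The `cache` tuple holds the intermediates a hand-written backward pass needs (see the gradient derivation in the references). As a hedged sketch of such a backward pass, the hypothetical helper `bn_forward_backward` below recomputes the forward pass with the biased (divide-by-N) batch variance, applies the standard closed-form gradients, and checks them against autograd in float64. It is an illustration, not the code from the referenced posts.

```python
import torch

def bn_forward_backward(x, gamma, beta, dout, eps=1e-5):
    """Hypothetical helper: BN forward plus hand-derived backward for (N, D) input.

    Assumes the biased (divide-by-N) batch variance in the forward pass.
    """
    n = x.shape[0]
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    inv_std = 1.0 / torch.sqrt(var + eps)
    x_hat = (x - mean) * inv_std
    out = gamma * x_hat + beta

    # Gradients of the affine parameters
    dgamma = (dout * x_hat).sum(dim=0)
    dbeta = dout.sum(dim=0)
    # Gradient w.r.t. the input (compact closed form)
    dx_hat = dout * gamma
    dx = inv_std / n * (n * dx_hat - dx_hat.sum(dim=0)
                        - x_hat * (dx_hat * x_hat).sum(dim=0))
    return out, dx, dgamma, dbeta

torch.manual_seed(0)
x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(3, dtype=torch.float64, requires_grad=True)
beta = torch.randn(3, dtype=torch.float64, requires_grad=True)
dout = torch.randn(4, 3, dtype=torch.float64)

out, dx, dgamma, dbeta = bn_forward_backward(
    x.detach(), gamma.detach(), beta.detach(), dout)

# Reference: let autograd differentiate the same forward computation
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
ref = gamma * (x - mean) / torch.sqrt(var + 1e-5) + beta
ref.backward(dout)
print(torch.allclose(dx, x.grad),
      torch.allclose(dgamma, gamma.grad),
      torch.allclose(dbeta, beta.grad))
```

The closed form avoids storing the separate partial derivatives of the mean and variance; the single `dx` line already includes both paths.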
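Then I compared it with the output of nn.BatchNorm1d:
```python
import torch
import torch.nn as nn

model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3, 5, requires_grad=True)
input2 = input1.clone().detach().requires_grad_()
x = model2(input1)
out, cache = fowardbn(input2, model2.weight, model2.bias)  # use the same scale and shift parameters
print(torch.allclose(x, out))  # False
```
The values of x and out did not match.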
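A quick way to see how large the mismatch is: with batch size N, the default `x.var(dim=0)` divides by N - 1 instead of N, so (ignoring eps) every normalized value is multiplied by sqrt((N - 1) / N), about 0.816 for N = 3. The check below is a sketch assuming the default `nn.BatchNorm1d` initialization (weight = 1, bias = 0); the loose `rtol` only absorbs the small effect of eps.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
n, d = 3, 5                      # same shape as the experiment above
x = torch.randn(n, d)

bn = nn.BatchNorm1d(d)           # training mode, weight = 1, bias = 0 by default
ref = bn(x)

# First attempt: default x.var() divides by N - 1
mean = x.mean(dim=0)
wrong = (x - mean) / torch.sqrt(x.var(dim=0) + bn.eps)

factor = math.sqrt((n - 1) / n)  # about 0.816 for N = 3
# Up to eps, every entry is shrunk by the same factor
print(torch.allclose(wrong, ref * factor, rtol=1e-2))
```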
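So I kept digging to find where my implementation differed from PyTorch's.
On the official forum I found someone who had hit the same problem; a PyTorch team member replied with an explanation and also provided a reference implementation.
The PyTorch forum is genuinely helpful.
The problem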
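In the reference implementation I noticed this line:
```python
var = input.var([0, 2, 3], unbiased=False)
```
It passes an extra argument, unbiased=False, when computing the variance of the input. At first I did not understand why.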
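The forum code targets `BatchNorm2d`, where the input has shape (N, C, H, W) and statistics are taken per channel over dims 0, 2 and 3. The following is a minimal sketch of that idea of my own (not the forum code); the helper name `manual_bn2d` is hypothetical. It is compared against `nn.BatchNorm2d` in training mode, with random affine parameters so the scale and shift are actually exercised.

```python
import torch
import torch.nn as nn

def manual_bn2d(x, weight, bias, eps=1e-5):
    """Hypothetical helper: training-mode batch norm for (N, C, H, W) input."""
    # Per-channel statistics over batch and spatial dims, kept broadcastable
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    # Reshape the per-channel affine parameters to (1, C, 1, 1)
    return weight.view(1, -1, 1, 1) * x_hat + bias.view(1, -1, 1, 1)

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
x = torch.randn(2, 4, 3, 3)
print(torch.allclose(manual_bn2d(x, bn.weight, bn.bias, bn.eps), bn(x), atol=1e-6))
```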
So I checked the PyTorch documentation:
torch.var(input, unbiased=True) → Tensor
Returns the variance of all elements in the input tensor.
If unbiased is False, then the variance will be calculated via the biased estimator. Otherwise, Bessel's correction will be used.
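A small worked example makes the difference concrete. For the values 1, 2, 3, 4 the mean is 2.5 and the squared deviations sum to 5, so the two settings give 5 / 4 and 5 / 3:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
# mean = 2.5, squared deviations: 2.25 + 0.25 + 0.25 + 2.25 = 5.0
print(x.var(unbiased=False))  # 5 / 4 = 1.25 (biased: divide by N)
print(x.var())                # 5 / 3 ~ 1.6667 (default: Bessel's correction, divide by N - 1)
```

Newer PyTorch releases also expose this through a `correction` argument, but `unbiased` is what the forum code and the documentation quoted above use.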
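In other words: with unbiased=False the variance is computed with the biased estimator (dividing by N); with the default unbiased=True, Bessel's correction is applied (dividing by N - 1), which gives an unbiased estimate of the population variance.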

This is basic statistics; see the reference on Bessel's correction below.
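For reference, the two estimators for N samples x_1, ..., x_N are the standard definitions below. Batch Normalization normalizes each mini-batch with the 1/N form, which is what `unbiased=False` computes.

```latex
\bar{x} = \frac{1}{N}\sum_{i=1}^{N} x_i
\qquad
\sigma^2_{\text{biased}} = \frac{1}{N}\sum_{i=1}^{N} (x_i - \bar{x})^2
\qquad
s^2 = \frac{1}{N-1}\sum_{i=1}^{N} (x_i - \bar{x})^2 = \frac{N}{N-1}\,\sigma^2_{\text{biased}}
```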
Final implementation
After adding unbiased=False to the variance computation in the initial code, the results match. The complete code:
```python
import torch
import torch.nn as nn

def fowardbn(x, gam, beta):
    momentum = 0.1
    eps = 1e-05
    # NOTE: as before, these running statistics are local and never returned
    running_mean = 0
    running_var = 1
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    mean = x.mean(dim=0)
    # Biased variance (divide by N), as nn.BatchNorm1d uses to normalize in training mode
    var = x.var(dim=0, unbiased=False)
    # bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache

model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3, 5, requires_grad=True)
input2 = input1.clone().detach().requires_grad_()
x = model2(input1)
out, cache = fowardbn(input2, model2.weight, model2.bias)
print(torch.allclose(x, out, atol=1e-6))  # True
```
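In `fowardbn`, `running_mean` and `running_var` restart from 0 and 1 on every call and are discarded, so the function only reproduces training-mode output. If you also want to match `nn.BatchNorm1d`'s buffers and eval-mode behavior, one possible sketch is below; the function name `bn1d` and its signature are hypothetical. It passes the buffers in and updates them in place, and, following the `nn.BatchNorm1d` documentation, normalizes with the biased variance while feeding the unbiased variance into the running average.

```python
import torch
import torch.nn as nn

def bn1d(x, weight, bias, running_mean, running_var,
         training=True, momentum=0.1, eps=1e-5):
    """Hypothetical helper: BatchNorm1d forward on (N, D) input.

    running_mean / running_var are updated in place when training=True.
    """
    if training:
        mean = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)  # used for normalization
        with torch.no_grad():
            # Exponential moving averages; the variance buffer uses the unbiased estimate
            running_mean.mul_(1 - momentum).add_(momentum * mean)
            running_var.mul_(1 - momentum).add_(momentum * x.var(dim=0, unbiased=True))
    else:
        # Eval mode: normalize with the stored statistics instead of the batch's
        mean, var = running_mean, running_var
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return weight * x_hat + bias

torch.manual_seed(0)
model = nn.BatchNorm1d(5)
rm, rv = torch.zeros(5), torch.ones(5)  # same initial values as PyTorch's buffers
x = torch.randn(3, 5)

out_train = bn1d(x, model.weight, model.bias, rm, rv)
ref_train = model(x)                    # also updates model.running_mean / running_var

model.eval()
out_eval = bn1d(x, model.weight, model.bias, rm, rv, training=False)
ref_eval = model(x)

print(torch.allclose(out_train, ref_train, atol=1e-6),
      torch.allclose(rm, model.running_mean),
      torch.allclose(rv, model.running_var),
      torch.allclose(out_eval, ref_eval, atol=1e-6))
```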
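References
Batch Normalization study notes
Derivation of gradient backpropagation through Batch Normalization
PyTorch forum thread on this question
Manual BatchNorm2d implementation posted by a PyTorch team member
Bessel's correction for the variance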