pytorch nn.BatchNorm1d 與手動(dòng)python實(shí)現(xiàn)不一樣--解決辦法

由于實(shí)驗(yàn)需要,便用pytorch函數(shù)手動(dòng)實(shí)現(xiàn)了batchnorm函數(shù),但是最后發(fā)現(xiàn)結(jié)果不對(duì),最后在Pytorch論壇上找到了相關(guān)解決辦法!

基礎(chǔ)

前期實(shí)現(xiàn)

上述博客給出了python實(shí)現(xiàn)代碼,我將其中的numpy函數(shù)改成了pytorch的相關(guān)函數(shù):

def fowardbn(x, gam, beta, ):
'''
x:(N,D)維數(shù)據(jù)
'''
    momentum = 0.1
    eps = 1e-05
    running_mean = 0
    running_var = 1
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    mean = x.mean(dim=0)
    var = x.var(dim=0)
    # bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache

然后與nn.BatchNorm1d計(jì)算的結(jié)果比較:

model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3, 5, requires_grad=True)
input2 = input1.clone().detach().requires_grad_()
x = model2(input1)

out, cache = fowardbn(input2, model2.weight, model2.bias) # 使用相同的尺度變換量

發(fā)現(xiàn)結(jié)果x和out的值不一樣。
然后就不停的找問(wèn)題是不是實(shí)現(xiàn)方法有差別。
\color{red}{最后}在官方論壇上找到了,有人遇到了相同的問(wèn)題,官方人員給了答復(fù),還提供了一個(gè)官方的實(shí)現(xiàn)版本。
Pytorch的論壇做的還是挺不錯(cuò)的。

問(wèn)題

我發(fā)現(xiàn)官方實(shí)現(xiàn)的代碼中

var = input.var([0, 2, 3], unbiased=False)

在求輸入的方差時(shí),多了一個(gè)參數(shù)設(shè)置unbiased=False,不懂。
我又查看了一下Pytorch的代碼文檔:

torch.var(input, unbiased=True) → Tensor

Returns the variance of all elements in the input tensor.
If unbiased is False, then the variance will be calculated via the biased estimator. Otherwise, Bessel’s correction will be used.

意思是unbiased = False時(shí),通過(guò)無(wú)偏估計(jì)計(jì)算,反之則通過(guò)貝塞爾矯正方法計(jì)算??捎萌缦聢D片總結(jié):

image.png

這是統(tǒng)計(jì)方面的知識(shí)了,可以參考此博客。

最終實(shí)現(xiàn)代碼

將初始代碼中方差計(jì)算加上參數(shù)unbiased = False,結(jié)果正確,完整代碼如下

def fowardbn(x, gam, beta, ):
    momentum = 0.1
    eps = 1e-05
    running_mean = 0
    running_var = 1
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    mean = x.mean(dim=0)
    var = x.var(dim=0,unbiased=False)
    # bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache

model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3, 5, requires_grad=True)
input2 = input1.clone().detach().requires_grad_()
x = model2(input1)
out, cache = fowardbn(input2, model2.weight, model2.bias)

Reference

Batch Normalization 學(xué)習(xí)筆記
Batch Normalization梯度反向傳播推導(dǎo)
PyTorch論壇問(wèn)題
官方人員給的batchnorm2d的手動(dòng)實(shí)現(xiàn)代碼
方差的貝塞爾校正

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容