Pytorch自定義Loss函數(shù)

方案一:只定義loss函數(shù)的前向計(jì)算公式

在pytorch中定義了前向計(jì)算的公式,在訓(xùn)練時(shí)它會(huì)自動(dòng)幫你計(jì)算反向傳播。

import torch.nn as nn

Class YourLoss(nn.Module):
    def __init__():
        pass

    def forward():
        pass

方案二:自定義loss函數(shù)的forward和backward

from numpy.fft import rfft2, irfft2

class BadFFTFunction(Function):

    def forward(self, input):
        numpy_input = input.numpy()
        result = abs(rfft2(numpy_input))
        return input.new(result)

    def backward(self, grad_output):
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return grad_output.new(result)

方案三:自己寫(xiě)一個(gè)pytorch的C擴(kuò)展

這個(gè)了解不多,所以也不太會(huì)

方案四:簡(jiǎn)單定義

看網(wǎng)上有說(shuō)直接定義一個(gè)簡(jiǎn)單函數(shù)就可以了,可以嘗試一下,與只定義forward類(lèi)似。

import torch

...... #模型操作

loss = torch.sum(x - y)

參考

  1. github:Pytorch自定義Loss函數(shù)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容