方案一:只定義loss函數(shù)的前向計(jì)算公式
在pytorch中定義了前向計(jì)算的公式,在訓(xùn)練時(shí)它會(huì)自動(dòng)幫你計(jì)算反向傳播。
import torch.nn as nn
Class YourLoss(nn.Module):
def __init__():
pass
def forward():
pass
方案二:自定義loss函數(shù)的forward和backward
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
def forward(self, input):
numpy_input = input.numpy()
result = abs(rfft2(numpy_input))
return input.new(result)
def backward(self, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return grad_output.new(result)
方案三:自己寫(xiě)一個(gè)pytorch的C擴(kuò)展
這個(gè)了解不多,所以也不太會(huì)
方案四:簡(jiǎn)單定義
看網(wǎng)上有說(shuō)直接定義一個(gè)簡(jiǎn)單函數(shù)就可以了,可以嘗試一下,與只定義forward類(lèi)似。
import torch
...... #模型操作
loss = torch.sum(x - y)