pytorch中的損失函數(shù)

1. 多標(biāo)簽分類損失函數(shù)

pytorch中能計算多標(biāo)簽分類任務(wù)loss的方法有好幾個。
binary_cross_entropy和binary_cross_entropy_with_logits都是來自torch.nn.functional的函數(shù),BCELoss和BCEWithLogitsLoss都來自torch.nn,它們的區(qū)別:

函數(shù)名 解釋
binary_cross_entropy Function that measures the Binary Cross Entropy between the target and the output
binary_cross_entropy_with_logits Function that measures Binary Cross Entropy between target and output logits
BCELoss Function that measures the Binary Cross Entropy between the target and the output
BCEWithLogitsLoss Function that measures Binary Cross Entropy between target and output logits

區(qū)別只在于這個logits,損失函數(shù)(類)名字中帶了with_logits,這里的logits指的是該損失函數(shù)已經(jīng)內(nèi)部自帶了計算logit的操作,無需在傳入給這個loss函數(shù)之前手動使用sigmoid/softmax將之前網(wǎng)絡(luò)的輸入映射到[0,1]之間。
nn.functional.xxx是函數(shù)接口,而nn.Xxx是nn.functional.xxx的類封裝,并且nn.Xxx都繼承于一個共同祖先nn.Module。

In [257]: import torch
In [258]: import torch.nn as nn
In [259]: import torch.nn.functional as F

In [260]: true = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
In [261]: pred = torch.rand((2,3))

In [262]: true
Out[262]:
tensor([[1., 0., 1.],
        [1., 0., 0.]])

In [263]: pred
Out[263]:
tensor([[0.0391, 0.7691, 0.1190],
        [0.8846, 0.1628, 0.2641]])

In [264]: F.binary_cross_entropy(torch.sigmoid(pred), true)
Out[264]: tensor(0.7361)

In [265]: F.binary_cross_entropy_with_logits(pred, true)
Out[265]: tensor(0.7361)


In [267]: lf2 = nn.BCELoss()
In [268]: lf2(torch.sigmoid(pred), true)
Out[268]: tensor(0.7361)

In [269]: lf = nn.BCEWithLogitsLoss()
In [270]: lf(pred, true)
Out[270]: tensor(0.7361)

# -(ylog(p)+(1-y)log(1-p))
In [268]: torch.sum(-(true*torch.log(torch.sigmoid(pred))+(1-true)*torch.log(1-torch.sigmoid(pred))))/6  
Out[268]: tensor(0.7361)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容