pytorch loss function 總結(jié)

以下是從PyTorch 的損失函數(shù)文檔整理出來(lái)的損失函數(shù):
值得注意的是,很多的 loss 函數(shù)都有 size_averagereduce 兩個(gè)布爾類型的參數(shù),需要解釋一下。因?yàn)橐话銚p失函數(shù)都是直接計(jì)算 batch 的數(shù)據(jù),因此返回的 loss 結(jié)果都是維度為 (batch_size, ) 的向量。

  • 如果 reduce = False,那么 size_average 參數(shù)失效,直接返回向量形式的 loss;
  • 如果 reduce = True,那么 loss 返回的是標(biāo)量
    • 如果 size_average = True,返回 loss.mean();
    • 如果 size_average = True,返回 loss.sum();

所以下面講解的時(shí)候,一般都把這兩個(gè)參數(shù)設(shè)置成 False,這樣子比較好理解原始的損失函數(shù)定義。

下面是常見(jiàn)的損失函數(shù)。

nn.L1Loss

這里表述的還是不太清楚,其實(shí)要求 x和 y的維度要一樣(可以是向量或者矩陣),得到的 loss 維度也是對(duì)應(yīng)一樣的。這里用下標(biāo) i表示第 i 個(gè)元素。

    loss_fn = torch.nn.L1Loss(reduce=False, size_average=False)
    input = torch.autograd.Variable(torch.randn(3,4))
    target = torch.autograd.Variable(torch.randn(3,4))
    loss = loss_fn(input, target)
    print(input); print(target); print(loss)
    print(input.size(), target.size(), loss.size())

nn.SmoothL1Loss

也叫作 Huber Loss,誤差在 (-1,1) 上是平方損失,其他情況是 L1 損失。

這里很上面的 L1Loss 類似,都是 element-wise 的操作,下標(biāo) i 是 x的第 iii 個(gè)元素。

loss_fn = torch.nn.SmoothL1Loss(reduce=False, size_average=False)
input = torch.autograd.Variable(torch.randn(3,4))
target = torch.autograd.Variable(torch.randn(3,4))
loss = loss_fn(input, target)
print(input); print(target); print(loss)
print(input.size(), target.size(), loss.size())

nn.MSELoss

均方損失函數(shù),用法和上面類似,這里 loss, x, y 的維度是一樣的,可以是向量或者矩陣,iii 是下標(biāo)。

    loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
    input = torch.autograd.Variable(torch.randn(3,4))
    target = torch.autograd.Variable(torch.randn(3,4))
    loss = loss_fn(input, target)
    print(input); print(target); print(loss)
    print(input.size(), target.size(), loss.size())

nn.BCELoss

二分類用的交叉熵,用的時(shí)候需要在該層前面加上 Sigmoid 函數(shù)。交叉熵的定義參考 wikipedia 頁(yè)面: Cross Entropy

因?yàn)殡x散版的交叉熵定義是

,其中 p,q都是向量,且都是概率分布。如果是二分類的話,因?yàn)橹挥姓头蠢?,且兩者的概率和?1,那么只需要預(yù)測(cè)一個(gè)概率就好了,因此可以簡(jiǎn)化成


注意這里 x,y可以是向量或者矩陣,i 只是下標(biāo);x_i表示第 i 個(gè)樣本預(yù)測(cè)為 正例 的概率,y_i 表示第 i 個(gè)樣本的標(biāo)簽,w_i 表示該項(xiàng)的權(quán)重大小。可以看出,loss, x, y, w 的維度都是一樣的。

    import torch.nn.functional as F
    loss_fn = torch.nn.BCELoss(reduce=False, size_average=False)
    input = Variable(torch.randn(3, 4))
    target = Variable(torch.FloatTensor(3, 4).random_(2))
    loss = loss_fn(F.sigmoid(input), target)
    print(input); print(target); print(loss)

這里比較奇怪的是,權(quán)重的維度不是 2,而是和 x, y 一樣,有時(shí)候遇到正負(fù)例樣本不均衡的時(shí)候,可能要多寫一句話

    class_weight = Variable(torch.FloatTensor([1, 10])) # 這里正例比較少,因此權(quán)重要大一些
    target = Variable(torch.FloatTensor(3, 4).random_(2))
    weight = class_weight[target.long()] # (3, 4)
    loss_fn = torch.nn.BCELoss(weight=weight, reduce=False, size_average=False)
    # balabala...

其實(shí)這樣子做的話,如果每次 batch_size 長(zhǎng)度不一樣,只能每次都定義 loss_fn 了,不知道有沒(méi)有更好的解決方案。

nn.BCEWithLogitsLoss

上面的 nn.BCELoss 需要手動(dòng)加上一個(gè) Sigmoid 層,這里是結(jié)合了兩者,這樣做能夠利用 log_sum_exp trick,使得數(shù)值結(jié)果更加穩(wěn)定(numerical stability)。建議使用這個(gè)損失函數(shù)。

值得注意的是,文檔里的參數(shù)只有 weight, size_average 兩個(gè),但是實(shí)際測(cè)試 reduce 參數(shù)也是可以用的。此外兩個(gè)損失函數(shù)的 target 要求是 FloatTensor,而且不一樣是只能取 0, 1 兩種值,任意值應(yīng)該都是可以的。

nn.CrossEntropyLoss

多分類用的交叉熵?fù)p失函數(shù),用這個(gè) loss 前面不需要加 Softmax 層。

這里損害函數(shù)的計(jì)算,按理說(shuō)應(yīng)該也是原始交叉熵公式的形式,但是這里限制了 target 類型為 torch.LongTensr,而且不是多標(biāo)簽意味著標(biāo)簽是 one-hot 編碼的形式,即只有一個(gè)位置是 1,其他位置都是 0,那么帶入交叉熵公式中化簡(jiǎn)后就成了下面的簡(jiǎn)化形式。參考 cs231n 作業(yè)里對(duì) Softmax Loss 的推導(dǎo)。

這里的 x∈?^N,是沒(méi)有經(jīng)過(guò) Softmax 的激活值,N是 x的維度大?。ɑ蛘呓刑卣骶S度); \text{label} ∈[0,C?1] 是標(biāo)量,是對(duì)應(yīng)的標(biāo)簽,可以看到兩者維度是不一樣的。C 是要分類的個(gè)數(shù)。w∈?^C 是維度為 C 的向量,表示標(biāo)簽的權(quán)重,樣本少的類別,可以考慮把權(quán)重設(shè)置大一點(diǎn)。

    weight = torch.Tensor([1,2,1,1,10])
    loss_fn = torch.nn.CrossEntropyLoss(reduce=False, size_average=False, weight=weight)
    input = Variable(torch.randn(3, 5)) # (batch_size, C)
    target = Variable(torch.FloatTensor(3).random_(5))
    loss = loss_fn(input, target)
    print(input); print(target); print(loss)

nn.NLLLoss

用于多分類的負(fù)對(duì)數(shù)似然損失函數(shù)(Negative Log Likelihood)

在前面接上一個(gè) nn.LogSoftMax 層就等價(jià)于交叉熵?fù)p失了。事實(shí)上,nn.CrossEntropyLoss 也是調(diào)用這個(gè)函數(shù)。注意這里的x_{\text{label}}和上個(gè)交叉熵?fù)p失里的不一樣(雖然符號(hào)我給寫一樣了),這里是經(jīng)過(guò) logSoftMax運(yùn)算后的數(shù)值,

nn.NLLLoss2d

和上面類似,但是多了幾個(gè)維度,一般用在圖片上。現(xiàn)在的 pytorch 版本已經(jīng)和上面的函數(shù)合并了。

  • input, (N, C, H, W)
  • target, (N, H, W)

比如用全卷積網(wǎng)絡(luò)做 Semantic Segmentation 時(shí),最后圖片的每個(gè)點(diǎn)都會(huì)預(yù)測(cè)一個(gè)類別標(biāo)簽。

nn.KLDivLoss

KL 散度,又叫做相對(duì)熵,算的是兩個(gè)分布之間的距離,越相似則越接近零。

注意這里的 x_i是 log概率,剛開始還以為 API 弄錯(cuò)了。

nn.MarginRankingLoss

評(píng)價(jià)相似度的損失


這里的三個(gè)都是標(biāo)量,y 只能取 1 或者 -1,取 1 時(shí)表示 x1 比 x2 要大;反之 x2 要大。參數(shù) margin 表示兩個(gè)向量至少要相聚 margin 的大小,否則 loss 非負(fù)。默認(rèn) margin 取零。

nn.MultiMarginLoss

多分類(multi-class)的 Hinge 損失,


其中 1≤y≤N 表示標(biāo)簽,p 默認(rèn)取 1,margin默認(rèn)取 1,也可以取別的值。參考 cs231n 作業(yè)里對(duì) SVM Loss 的推導(dǎo)。

nn.MultiLabelMarginLoss

多類別(multi-class)多分類(multi-classification)的 Hinge 損失,是上面 MultiMarginLoss 在多類別上的拓展。同時(shí)限定 p = 1,margin = 1.

這個(gè)接口有點(diǎn)坑,是直接從 Torch 那里抄過(guò)來(lái)的,見(jiàn) MultiLabelMarginCriterion 的描述。而 Lua 的下標(biāo)和 Python 不一樣,前者的數(shù)組下標(biāo)是從 1 開始的,所以用 0 表示占位符。有幾個(gè)坑需要注意,

  1. 這里的 x,y都是大小為 N 的向量,如果 y不是向量而是標(biāo)量,后面的 \sum_j就沒(méi)有了,因此就退化成上面的 MultiMarginLoss.
  2. 限制 y的大小為 N,是為了處理多標(biāo)簽中標(biāo)簽個(gè)數(shù)不同的情況,用 0 表示占位,該位置和后面的數(shù)字都會(huì)被認(rèn)為不是正確的類。如 y=[5,3,0,0,4]那么就會(huì)被認(rèn)為是屬于類別 5 和 3,而 4 因?yàn)樵诹愫竺妫虼藭?huì)被忽略。
  3. 上面的公式和說(shuō)明只是為了和文檔保持一致,其實(shí)在調(diào)用接口的時(shí)候,用的是 -1 做占位符,而 0 是第一個(gè)類別。

舉個(gè)梨子,

import torch
loss = torch.nn.MultiLabelMarginLoss()
x = torch.autograd.Variable(torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]]))
y = torch.autograd.Variable(torch.LongTensor([[3, 0, -1, 1]]))
print loss(x, y) # will give 0.8500

按照上面的理解,第 3, 0 個(gè)是正確的類,1, 2 不是,那么,

*注意這里推導(dǎo)的第二行,我為了簡(jiǎn)短,都省略了 max(0, x) 符號(hào)。

nn.SoftMarginLoss

多標(biāo)簽二分類問(wèn)題,這 NNN 項(xiàng)都是二分類問(wèn)題,其實(shí)就是把 NNN 個(gè)二分類的 loss 加起來(lái),化簡(jiǎn)一下。其中 yy\mathbf{y} 只能取 1,?11,?11, -1 兩種,代表正類和負(fù)類。和下面的其實(shí)是等價(jià)的,只是 yy\mathbf{y} 的形式不同。

nn.MultiLabelSoftMarginLoss

上面的多分類版本,根據(jù)最大熵的多標(biāo)簽 one-versue-all 損失,其中 y只能取 1,01,01, 0 兩種,代表正類和負(fù)類。

nn.CosineEmbeddingLoss

余弦相似度的損失,目的是讓兩個(gè)向量盡量相近。注意這兩個(gè)向量都是有梯度的。

margin 可以取 [?1,1][?1,1][-1, 1],但是比較建議取 0-0.5 較好。

nn.HingeEmbeddingLoss

不知道做啥用的。另外文檔里寫錯(cuò)了,x,y的維度應(yīng)該是一樣的。


nn.TripleMarginLoss

其中 d(x_i,y_i)=\|x_i-y_i\|^2

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容