多標(biāo)簽分類與BCEloss

什么是多標(biāo)簽分類

學(xué)習(xí)過(guò)機(jī)器學(xué)習(xí)的你,也許對(duì)分類問(wèn)題很熟悉。比如下圖:


image.png

圖片中是否包含房子?你的回答就是有或者沒有,這就是一個(gè)典型的二分類問(wèn)題。


image.png

同樣,是這幅照片,問(wèn)題變成了,這幅照片是誰(shuí)拍攝的?備選答案你,你的父親,你的母親?這就變成了一個(gè)多分類問(wèn)題。

image.png

但今天談?wù)摰亩鄻?biāo)簽是什么呢?
如果我問(wèn)你上面圖包含一座房子嗎?選項(xiàng)會(huì)是YES或NO。


image.png

你會(huì)發(fā)現(xiàn)圖中所示的答案有多個(gè)yes,而不同于之前的多分類只有一個(gè)yes。

這里思考一下:
一首歌:如果你有四個(gè)類別的音樂(lè),分別為:古典音樂(lè)、鄉(xiāng)村音樂(lè)、搖滾樂(lè)和爵士樂(lè),那么這些類別之間是互斥的。這首歌屬于哪個(gè)類別?這是一個(gè)什么問(wèn)題?

一首歌:如果你有四個(gè)類別的音樂(lè),分別為:人聲音樂(lè)、舞曲、影視原聲、流行歌曲,那么這些類別之間并不是互斥的。這首歌屬于哪個(gè)類別?這是一個(gè)什么問(wèn)題?

多標(biāo)簽的問(wèn)題的損失函數(shù)是什么

這里需要先了解一下softmax 與 sigmoid函數(shù)


image.png

這兩個(gè)函數(shù)最重要的區(qū)別,我們觀察一下:


image.png

區(qū)別還是很明顯的。

綜上,我們可以得出以下結(jié)論:


image.png

pytorch中的實(shí)現(xiàn)

PyTorch提供了兩個(gè)類來(lái)計(jì)算二分類交叉熵(Binary Cross Entropy),分別是BCELoss() 和BCEWithLogitsLoss()

看一下源碼,參考幫助,我們來(lái)玩一下


image.png
from torch import autograd
input = autograd.Variable(torch.randn(3,3), requires_grad=True)
print(input)

輸出

tensor([[ 1.9072,  1.1079,  1.4906],
        [-0.6584, -0.0512,  0.7608],
        [-0.0614,  0.6583,  0.1095]], requires_grad=True)

因?yàn)镹ote that the targets t[i] should be numbers between 0 and 1.
所以需要先sigmoid

from torch import nn
m = nn.Sigmoid()
print(m(input))

輸出

tensor([[0.8707, 0.7517, 0.8162],
        [0.3411, 0.4872, 0.6815],
        [0.4847, 0.6589, 0.5273]], grad_fn=<SigmoidBackward>)

假設(shè)你的target如下:

target = torch.FloatTensor([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
print(target)

輸出

tensor([[0., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.]])

我們先根據(jù)源碼中公式,
image.png

自己計(jì)算一下:

import math

r11 = 0 * math.log(0.8707) + (1-0) * math.log((1 - 0.8707))
r12 = 1 * math.log(0.7517) + (1-1) * math.log((1 - 0.7517))
r13 = 1 * math.log(0.8162) + (1-1) * math.log((1 - 0.8162))

r21 = 1 * math.log(0.3411) + (1-1) * math.log((1 - 0.3411))
r22 = 1 * math.log(0.4872) + (1-1) * math.log((1 - 0.4872))
r23 = 1 * math.log(0.6815) + (1-1) * math.log((1 - 0.6815))

r31 = 0 * math.log(0.4847) + (1-0) * math.log((1 - 0.4847))
r32 = 0 * math.log(0.6589) + (1-0) * math.log((1 - 0.6589))
r33 = 0 * math.log(0.5273) + (1-0) * math.log((1 - 0.5273))

r1 = -(r11 + r12 + r13) / 3
#0.8447112733378236
r2 = -(r21 + r22 + r23) / 3
#0.7260397266631787
r3 = -(r31 + r32 + r33) / 3
#0.8292933181294807
bceloss = (r1 + r2 + r3) / 3 
print(bceloss)

輸出

0.8000147727101611

核心解讀
就是把每一個(gè)標(biāo)簽的預(yù)測(cè)值(sigmoid計(jì)算之后)交給cross_entropy函數(shù)來(lái)進(jìn)行分類計(jì)算。比如樣本1是一張圖片(r1),r11代表某一個(gè)標(biāo)簽(有房子),r12代表某一個(gè)標(biāo)簽(有樹), r13代表某一個(gè)標(biāo)簽(有小狗),最后r1就是樣本1的預(yù)測(cè)值與真實(shí)值之間的loss。

我們?cè)賹?duì)比一下使用torch內(nèi)置的loss函數(shù)

loss = nn.BCELoss()
print(loss(m(input), target))

輸出

tensor(0.8000, grad_fn=<BinaryCrossEntropyBackward>)

和我們自己算的誤差非常小,可以忽略。

我們可以把sigmoid和bce的過(guò)程放到一起,使用內(nèi)建的BCEWithLogitsLoss函數(shù)

loss = nn.BCEWithLogitsLoss()
print(loss(input, target))

輸出

tensor(0.8000, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容