我發(fā)現(xiàn),手寫損失函數(shù)一般都會(huì)運(yùn)用到很多稍微復(fù)雜一些的張量操作,很適合用來(lái)學(xué)習(xí)pytorch張量操作,所以這里分析幾個(gè)常用損失函數(shù)練習(xí)一下。
1. Binary Cross Entropy Loss
BCELoss的計(jì)算公式很簡(jiǎn)單:

這里我們按照公式簡(jiǎn)單實(shí)現(xiàn)一下就可以:
class BCELosswithLogits(nn.Module):
def __init__(self, pos_weight=1, reduction='mean'):
super(BCELosswithLogits, self).__init__()
self.pos_weight = pos_weight
self.reduction = reduction
def forward(self, logits, target):
# logits: [N, *], target: [N, *]
logits = F.sigmoid(logits)
loss = - self.pos_weight * target * torch.log(logits) - \
(1 - target) * torch.log(1 - logits)
if self.reduction == 'mean':
loss = loss.mean()
elif self.reduction == 'sum':
loss = loss.sum()
return loss
輸入target為ground truth,logits為未經(jīng)過(guò)sigmoid激活的網(wǎng)絡(luò)輸出,運(yùn)用公式計(jì)算出的loss形狀和logits相同,運(yùn)用mean或sum方法將其變?yōu)橐粋€(gè)數(shù)值。self.pos_weight調(diào)整正樣本的計(jì)算比例。
pytorch官方提供了BCEWithLogitsLoss類,除了二分類之外,還可以用于多標(biāo)簽分類,此時(shí)target形狀為N * C,logits形狀也是N * C。這種用法通常見(jiàn)于multi-label任務(wù)中,類間沒(méi)有競(jìng)爭(zhēng)關(guān)系。
2. Cross Entropy Loss
Cross Entropy Loss一般用于多分類任務(wù),其計(jì)算公式如下圖所示,其中yi等于1(第i個(gè)樣本是否屬于每一類,不屬于的都等于0了,不會(huì)算到loss里),log括號(hào)里一大堆(不想寫了)表示第i個(gè)樣本對(duì)應(yīng)logits中,其gt所屬那一類的分類置信度,比如第i個(gè)樣本是第5類,一共有C個(gè)類別,那么括號(hào)里的一堆就是一個(gè)C維向量里的第5個(gè)元素。
因此,交叉熵?fù)p失計(jì)算的其實(shí)就是每個(gè)樣本所屬實(shí)際類別對(duì)應(yīng)分類置信度的負(fù)對(duì)數(shù),也就是模型分對(duì)的可能性有多高。交叉熵?fù)p失只能用于標(biāo)簽唯一的分類任務(wù),因?yàn)轭愰g是要做softmax歸一化的,那么如果其中一類的置信度很高,對(duì)應(yīng)的其他類別的置信度就變低了,類間存在競(jìng)爭(zhēng)關(guān)系。

下面是我實(shí)現(xiàn)的交叉熵?fù)p失函數(shù),這里用到的一個(gè)平時(shí)不常用的張量操作就是gather操作,利用target將logits中對(duì)應(yīng)類別的分類置信度取出來(lái)。
class CrossEntropyLoss(torch.nn.Module):
def __init__(self, reduction='mean'):
super(CrossEntropyLoss, self).__init__()
self.reduction = reduction
def forward(self, logits, target):
# logits: [N, C, H, W], target: [N, H, W]
# loss = sum(-y_i * log(c_i))
if logits.dim() > 2:
logits = logits.view(logits.size(0), logits.size(1), -1) # [N, C, HW]
logits = logits.transpose(1, 2) # [N, HW, C]
logits = logits.contiguous().view(-1, logits.size(2)) # [NHW, C]
target = target.view(-1, 1) # [NHW,1]
logits = F.log_softmax(logits, 1)
logits = logits.gather(1, target) # [NHW, 1]
loss = -1 * logits
if self.reduction == 'mean':
loss = loss.mean()
elif self.reduction == 'sum':
loss = loss.sum()
return loss
3. Focal BCE Loss
二分類的focal loss計(jì)算公式如下圖所示,與BCE loss的區(qū)別在于,每一項(xiàng)前面乘了(1-pt)^gamma,也就是該樣本的分類難度,值越大,說(shuō)明模型分的越不準(zhǔn),需要增大其loss權(quán)重;并且為了進(jìn)一步平衡正負(fù)樣本,還乘了alpha來(lái)調(diào)節(jié)。

二分類的focal loss代碼實(shí)現(xiàn)跟bceloss差不多。
class BCEFocalLosswithLogits(nn.Module):
def __init__(self, gamma=0.2, alpha=0.6, reduction='mean'):
super(BCEFocalLosswithLogits, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
def forward(self, logits, target):
# logits: [N, H, W], target: [N, H, W]
logits = F.sigmoid(logits)
alpha = self.alpha
gamma = self.gamma
loss = - alpha * (1 - logits) ** gamma * target * torch.log(logits) - \
(1 - alpha) * logits ** gamma * (1 - target) * torch.log(1 - logits)
if self.reduction == 'mean':
loss = loss.mean()
elif self.reduction == 'sum':
loss = loss.sum()
return loss
4. Focal CE Loss
代碼參考:https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
class CrossEntropyFocalLoss(nn.Module):
def __init__(self, alpha=None, gamma=0.2, reduction='mean'):
super(CrossEntropyFocalLoss, self).__init__()
self.reduction = reduction
self.alpha = alpha
self.gamma = gamma
def forward(self, logits, target):
# logits: [N, C, H, W], target: [N, H, W]
# loss = sum(-y_i * log(c_i))
if logits.dim() > 2:
logits = logits.view(logits.size(0), logits.size(1), -1) # [N, C, HW]
logits = logits.transpose(1, 2) # [N, HW, C]
logits = logits.contiguous().view(-1, logits.size(2)) # [NHW, C]
target = target.view(-1, 1) # [NHW,1]
pt = F.softmax(logits, 1)
pt = pt.gather(1, target).view(-1) # [NHW]
log_gt = torch.log(pt)
if self.alpha is not None:
# alpha: [C]
alpha = self.alpha.gather(0, target.view(-1)) # [NHW]
log_gt = log_gt * alpha
loss = -1 * (1 - pt) ** self.gamma * log_gt
if self.reduction == 'mean':
loss = loss.mean()
elif self.reduction == 'sum':
loss = loss.sum()
return loss
一些比較不常用的張量操作
-
torch.gather
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
輸入input,利用index選擇input中的元素,并組成out輸出。這里假設(shè)input是一個(gè)d1 * d2 * d3的張量:
dim=0,即在第0維進(jìn)行選擇,則index的尺寸應(yīng)該為1 * d2 * d3,每次在d1個(gè)元素中選擇一個(gè),輸出out尺寸也為1 * d2 * d3;
dim=1或2也類似,下面是pytorch官方文檔的描述,out和index的形狀是一樣的。
-
torch.cumsum
torch.cumsum(input, dim, *, dtype=None, out=None) → Tensor
輸入input,對(duì)指定維度進(jìn)行累加。比如:
通常 torch.full
torch.full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
創(chuàng)建size大小的張量,張量的每個(gè)元素都為fill_value。torch.empty(size).random_(N):生成size大小的張量,每個(gè)張量值為不超過(guò)N的隨機(jī)int。
torch.diag():求矩陣的對(duì)角元素

