一、Focal loss損失函數(shù)
Focal Loss的引入主要是為了解決**難易樣本數(shù)量不平衡****(注意,有區(qū)別于正負樣本數(shù)量不平衡)的問題,實際可以使用的范圍非常廣泛。
本文的作者認為,易分樣本(即,置信度高的樣本)對模型的提升效果非常小,模型應(yīng)該主要關(guān)注與那些難分樣本。一個簡單的思想:把高置信度(p)樣本的損失再降低一些不就好了嗎!
focal loss函數(shù)公式:
其中,(1)為類別權(quán)重,用來權(quán)衡正負樣本不均衡問題,倘若負樣本越多,給負樣本的
權(quán)重就越小,這樣就可以降低負樣本的影響。加一個小于1的超參數(shù),相當(dāng)于把Loss曲線整體往下拉一些,使得當(dāng)樣本概率較大的時候影響減小。;
(2) 表示難分樣本權(quán)重,用來衡量難分樣本和易分樣本,對于正類樣本而言,預(yù)測結(jié)果為0.95肯定是簡單樣本,所以(1-0.95)的gamma次方就會很小,這時損失函數(shù)值就變得更小。而預(yù)測概率為0.3的樣本其損失相對很大。即正樣本:概率越小,表示hard example,損失越大; 負樣本:概率越大,表示hard example,損失越大。γ 起到了平滑的作用,作者的實驗中,論文采用α=0.25,γ=2效果最好。。針對hard example,Pt比較小,則權(quán)重比較大,讓網(wǎng)絡(luò)傾向于利用這樣的樣本來進行參數(shù)的更新
Focal loss缺點(騰訊面試):
(1) 對異常樣本敏感: 假如訓(xùn)練集中有個樣本label標(biāo)錯了,那么focal loss會一直放大這個樣本的loss(模型想矯正回來),導(dǎo)致網(wǎng)絡(luò)往錯誤方向?qū)W習(xí)。
(2)對分類邊界異常點處理不理想:由于邊界樣本表示相似性較高,對于不同異常值表示,每次損失更新時,都會有反復(fù)在分類決策面(0.5)上反復(fù)橫跳的點,導(dǎo)致模型收斂速度下降,退化成交叉熵損失。
二、Focal loss損失函數(shù)代碼
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
def __init__(self, class_num, alpha=0.20, gamma=1.5, use_alpha=False, size_average=True):
super(FocalLoss, self).__init__()
self.class_num = class_num
self.alpha = alpha
self.gamma = gamma
if use_alpha:
self.alpha = torch.tensor(alpha).cuda()
# self.alpha = torch.tensor(alpha)
self.softmax = nn.Softmax(dim=1)
self.use_alpha = use_alpha
self.size_average = size_average
def forward(self, pred, target):
prob = self.softmax(pred.view(-1,self.class_num))
prob = prob.clamp(min=0.0001,max=1.0)
target_ = torch.zeros(target.size(0),self.class_num).cuda()
# target_ = torch.zeros(target.size(0),self.class_num)
target_.scatter_(1, target.view(-1, 1).long(), 1.)
if self.use_alpha:
batch_loss = - self.alpha.double() * torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
else:
batch_loss = - torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
batch_loss = batch_loss.sum(dim=1)
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
三、Focal loss損失函數(shù)引用及使用
# 函數(shù)引用(focal_loss為模型文件名)
from focal_loss import FocalLoss
#...
# 損失函數(shù)初始化
criterion = FocalLoss(class_num=3)
#...
# 獲得損失函數(shù)
loss = criterion(outputs, targets)