Focal loss函數(shù)及代碼

一、Focal loss損失函數(shù)

Focal Loss的引入主要是為了解決**難易樣本數(shù)量不平衡****(注意,有區(qū)別于正負樣本數(shù)量不平衡)的問題,實際可以使用的范圍非常廣泛。

本文的作者認為,易分樣本(即,置信度高的樣本)對模型的提升效果非常小,模型應(yīng)該主要關(guān)注與那些難分樣本。一個簡單的思想:把高置信度(p)樣本的損失再降低一些不就好了嗎!

focal loss函數(shù)公式:
FL(p) = -a(1-p)^\lambda log(p)
其中,(1)a為類別權(quán)重,用來權(quán)衡正負樣本不均衡問題,倘若負樣本越多,給負樣本的a權(quán)重就越小,這樣就可以降低負樣本的影響。加一個小于1的超參數(shù),相當(dāng)于把Loss曲線整體往下拉一些,使得當(dāng)樣本概率較大的時候影響減小。;
(2)\lambda 表示難分樣本權(quán)重,用來衡量難分樣本和易分樣本,對于正類樣本而言,預(yù)測結(jié)果為0.95肯定是簡單樣本,所以(1-0.95)的gamma次方就會很小,這時損失函數(shù)值就變得更小。而預(yù)測概率為0.3的樣本其損失相對很大。即正樣本:概率越小,表示hard example,損失越大; 負樣本:概率越大,表示hard example,損失越大。γ 起到了平滑的作用,作者的實驗中,論文采用α=0.25,γ=2效果最好。。針對hard example,Pt比較小,則權(quán)重比較大,讓網(wǎng)絡(luò)傾向于利用這樣的樣本來進行參數(shù)的更新

Focal loss缺點(騰訊面試):

(1) 對異常樣本敏感: 假如訓(xùn)練集中有個樣本label標(biāo)錯了,那么focal loss會一直放大這個樣本的loss(模型想矯正回來),導(dǎo)致網(wǎng)絡(luò)往錯誤方向?qū)W習(xí)。
(2)對分類邊界異常點處理不理想:由于邊界樣本表示相似性較高,對于不同異常值表示,每次損失更新時,都會有反復(fù)在分類決策面(0.5)上反復(fù)橫跳的點,導(dǎo)致模型收斂速度下降,退化成交叉熵損失。

二、Focal loss損失函數(shù)代碼

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=0.20, gamma=1.5, use_alpha=False, size_average=True):
        super(FocalLoss, self).__init__()
        self.class_num = class_num
        self.alpha = alpha
        self.gamma = gamma
        if use_alpha:
            self.alpha = torch.tensor(alpha).cuda()
            # self.alpha = torch.tensor(alpha)

        self.softmax = nn.Softmax(dim=1)
        self.use_alpha = use_alpha
        self.size_average = size_average

    def forward(self, pred, target):

        prob = self.softmax(pred.view(-1,self.class_num))
        prob = prob.clamp(min=0.0001,max=1.0)

        target_ = torch.zeros(target.size(0),self.class_num).cuda()
        # target_ = torch.zeros(target.size(0),self.class_num)
        target_.scatter_(1, target.view(-1, 1).long(), 1.)

        if self.use_alpha:
            batch_loss = - self.alpha.double() * torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
        else:
            batch_loss = - torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()

        batch_loss = batch_loss.sum(dim=1)

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss

三、Focal loss損失函數(shù)引用及使用

# 函數(shù)引用(focal_loss為模型文件名)
from focal_loss import FocalLoss

#...

# 損失函數(shù)初始化
criterion = FocalLoss(class_num=3)


#...

# 獲得損失函數(shù)
loss = criterion(outputs, targets)


最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容