pytorch學(xué)習(xí)經(jīng)驗(yàn)(五)手動(dòng)實(shí)現(xiàn)交叉熵?fù)p失及Focal Loss

我發(fā)現(xiàn),手寫損失函數(shù)一般都會(huì)運(yùn)用到很多稍微復(fù)雜一些的張量操作,很適合用來(lái)學(xué)習(xí)pytorch張量操作,所以這里分析幾個(gè)常用損失函數(shù)練習(xí)一下。

1. Binary Cross Entropy Loss

BCELoss的計(jì)算公式很簡(jiǎn)單:


BCE公式

這里我們按照公式簡(jiǎn)單實(shí)現(xiàn)一下就可以:

class BCELosswithLogits(nn.Module):
    def __init__(self, pos_weight=1, reduction='mean'):
        super(BCELosswithLogits, self).__init__()
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, target):
        # logits: [N, *], target: [N, *]
        logits = F.sigmoid(logits)
        loss = - self.pos_weight * target * torch.log(logits) - \
               (1 - target) * torch.log(1 - logits)
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

輸入target為ground truth,logits為未經(jīng)過(guò)sigmoid激活的網(wǎng)絡(luò)輸出,運(yùn)用公式計(jì)算出的loss形狀和logits相同,運(yùn)用mean或sum方法將其變?yōu)橐粋€(gè)數(shù)值。self.pos_weight調(diào)整正樣本的計(jì)算比例。
pytorch官方提供了BCEWithLogitsLoss類,除了二分類之外,還可以用于多標(biāo)簽分類,此時(shí)target形狀為N * C,logits形狀也是N * C。這種用法通常見(jiàn)于multi-label任務(wù)中,類間沒(méi)有競(jìng)爭(zhēng)關(guān)系。

2. Cross Entropy Loss

Cross Entropy Loss一般用于多分類任務(wù),其計(jì)算公式如下圖所示,其中yi等于1(第i個(gè)樣本是否屬于每一類,不屬于的都等于0了,不會(huì)算到loss里),log括號(hào)里一大堆(不想寫了)表示第i個(gè)樣本對(duì)應(yīng)logits中,其gt所屬那一類的分類置信度,比如第i個(gè)樣本是第5類,一共有C個(gè)類別,那么括號(hào)里的一堆就是一個(gè)C維向量里的第5個(gè)元素。
因此,交叉熵?fù)p失計(jì)算的其實(shí)就是每個(gè)樣本所屬實(shí)際類別對(duì)應(yīng)分類置信度的負(fù)對(duì)數(shù),也就是模型分對(duì)的可能性有多高。交叉熵?fù)p失只能用于標(biāo)簽唯一的分類任務(wù),因?yàn)轭愰g是要做softmax歸一化的,那么如果其中一類的置信度很高,對(duì)應(yīng)的其他類別的置信度就變低了,類間存在競(jìng)爭(zhēng)關(guān)系。


下面是我實(shí)現(xiàn)的交叉熵?fù)p失函數(shù),這里用到的一個(gè)平時(shí)不常用的張量操作就是gather操作,利用target將logits中對(duì)應(yīng)類別的分類置信度取出來(lái)。

class CrossEntropyLoss(torch.nn.Module):
    def __init__(self, reduction='mean'):
        super(CrossEntropyLoss, self).__init__()
        self.reduction = reduction
    def forward(self, logits, target):
        # logits: [N, C, H, W], target: [N, H, W]
        # loss = sum(-y_i * log(c_i))
        if logits.dim() > 2:
            logits = logits.view(logits.size(0), logits.size(1), -1)  # [N, C, HW]
            logits = logits.transpose(1, 2)   # [N, HW, C]
            logits = logits.contiguous().view(-1, logits.size(2))    # [NHW, C]
        target = target.view(-1, 1)    # [NHW,1]

        logits = F.log_softmax(logits, 1)
        logits = logits.gather(1, target)   # [NHW, 1]
        loss = -1 * logits

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

3. Focal BCE Loss

二分類的focal loss計(jì)算公式如下圖所示,與BCE loss的區(qū)別在于,每一項(xiàng)前面乘了(1-pt)^gamma,也就是該樣本的分類難度,值越大,說(shuō)明模型分的越不準(zhǔn),需要增大其loss權(quán)重;并且為了進(jìn)一步平衡正負(fù)樣本,還乘了alpha來(lái)調(diào)節(jié)。



二分類的focal loss代碼實(shí)現(xiàn)跟bceloss差不多。

class BCEFocalLosswithLogits(nn.Module):
    def __init__(self, gamma=0.2, alpha=0.6, reduction='mean'):
        super(BCEFocalLosswithLogits, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, target):
        # logits: [N, H, W], target: [N, H, W]
        logits = F.sigmoid(logits)
        alpha = self.alpha
        gamma = self.gamma
        loss = - alpha * (1 - logits) ** gamma * target * torch.log(logits) - \
               (1 - alpha) * logits ** gamma * (1 - target) * torch.log(1 - logits)
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

4. Focal CE Loss

代碼參考:https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py

class CrossEntropyFocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=0.2, reduction='mean'):
        super(CrossEntropyFocalLoss, self).__init__()
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, target):
        # logits: [N, C, H, W], target: [N, H, W]
        # loss = sum(-y_i * log(c_i))
        if logits.dim() > 2:
            logits = logits.view(logits.size(0), logits.size(1), -1)  # [N, C, HW]
            logits = logits.transpose(1, 2)   # [N, HW, C]
            logits = logits.contiguous().view(-1, logits.size(2))    # [NHW, C]
        target = target.view(-1, 1)    # [NHW,1]

        pt = F.softmax(logits, 1)
        pt = pt.gather(1, target).view(-1)   # [NHW]
        log_gt = torch.log(pt)

        if self.alpha is not None:
            # alpha: [C]
            alpha = self.alpha.gather(0, target.view(-1))   # [NHW]
            log_gt = log_gt * alpha
            
        loss = -1 * (1 - pt) ** self.gamma * log_gt

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

一些比較不常用的張量操作

  • torch.gather
    torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
    輸入input,利用index選擇input中的元素,并組成out輸出。這里假設(shè)input是一個(gè)d1 * d2 * d3的張量:
    dim=0,即在第0維進(jìn)行選擇,則index的尺寸應(yīng)該為1 * d2 * d3,每次在d1個(gè)元素中選擇一個(gè),輸出out尺寸也為1 * d2 * d3;
    dim=1或2也類似,下面是pytorch官方文檔的描述,out和index的形狀是一樣的。

  • torch.cumsum
    torch.cumsum(input, dim, *, dtype=None, out=None) → Tensor
    輸入input,對(duì)指定維度進(jìn)行累加。比如:


    通常

  • torch.full
    torch.full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
    創(chuàng)建size大小的張量,張量的每個(gè)元素都為fill_value。

  • torch.empty(size).random_(N):生成size大小的張量,每個(gè)張量值為不超過(guò)N的隨機(jī)int。

  • torch.diag():求矩陣的對(duì)角元素

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容