ROC計算流程簡述與實現(xiàn)

計算流程

輸入shape都是N的preds和targets

一、獲取分?jǐn)?shù)的降序索引desc_score_indices:

desc_score_indices = torch.argsort(preds, descending=True)

二、使用分?jǐn)?shù)的降序索引 desc_score_indices 獲取降序排序后的 preds 和 targets :

preds = preds[desc_score_indices]
targets = targets[desc_score_indices]

三、使用降序排序的 preds ,獲取分?jǐn)?shù)下降位置的索引distinct_value_indices:

distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]

四、使用在 distinct_value_indices 末尾添加 N-1,構(gòu)建閾值索引 threshold_idxs :

threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=targets.size(0) - 1)

五、獲取各個閾值對應(yīng)的 tp:
累加 targets,使用 threshold_idxs 獲取各個閾值對應(yīng)的 tp。

tps = torch.cumsum(targets, dim=0)[threshold_idxs]

六、獲取各個閾值對應(yīng)的fp:
方法1: 累加 1 - targets,使用 threshold_idxs 獲取各個閾值對應(yīng)的 fp。

fps = torch.cumsum((1 - targets), dim=0)[threshold_idxs]

方法2: 通過 threshold_idxs 與 tps 計算,因為 threshold_idxs + 1 可以表示有多少樣本大于等于對應(yīng)閾值,即有多少個樣本被判斷為正樣本。

fps = 1 + threshold_idxs - tps

七、獲取閾值列表 thresholds:

thresholds = preds[threshold_idxs]

八、增加額外的閾值位,保證ROC曲線從(0,0)開始:

tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

九、判斷 tps 與 fps 是否有效,并計算 tpr 與 fpr:

if tps[-1] <= 0:
    raise ValueError("No positive samples in targets, true positive value should be meaningless")
tpr = tps / tps[-1]

if fps[-1] <= 0:
    raise ValueError("No negative samples in targets, false positive value should be meaningless")
fpr = fps / fps[-1]

整體實現(xiàn)

使用 pytorch 實現(xiàn),參考 torchmetrics,源代碼中包含 pos_label 和 sample_weights 參數(shù),這里沒有使用只是簡單實現(xiàn)。

import torch
from torch import Tensor
from torch.nn import functional as F

def roc_compute_single_class(preds, targets):

    #獲取分?jǐn)?shù)的降序索引desc_score_indices
    desc_score_indices = torch.argsort(preds, descending=True)

    #使用分?jǐn)?shù)的降序索引 desc_score_indices 獲取降序排序后的 preds 和 targets 
    preds = preds[desc_score_indices]
    targets = targest[desc_score_indices]

    #使用降序排序的 preds ,獲取分?jǐn)?shù)下降位置的索引distinct_value_indices
    distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]

    #使用在 distinct_value_indices 末尾添加 N-1,構(gòu)建閾值索引 threshold_idxs 
    threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=targets.size(0) - 1)

    #獲取各個閾值對應(yīng)的 tp
    tps = torch.cumsum(targets, dim=0)[threshold_idxs]

    #獲取各個閾值對應(yīng)的fp
    fps = 1 + threshold_idxs - tps

    #獲取閾值列表 thresholds
    thresholds = preds[threshold_idxs]

    #增加額外的閾值位,保證ROC曲線從(0,0)開始
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

    #判斷 tps 與 fps 是否有效,并計算 tpr 與 fpr
    if fps[-1] <= 0:
        raise ValueError("No negative samples in targets, false positive value should be meaningless")
    fpr = fps / fps[-1]

    if tps[-1] <= 0:
        raise ValueError("No positive samples in targets, true positive value should be meaningless")
    tpr = tps / tps[-1]

    return fpr, tpr, thresholds
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容