計算流程
輸入shape都是N的preds和targets
一、獲取分?jǐn)?shù)的降序索引desc_score_indices:
desc_score_indices = torch.argsort(preds, descending=True)
二、使用分?jǐn)?shù)的降序索引 desc_score_indices 獲取降序排序后的 preds 和 targets :
preds = preds[desc_score_indices]
targets = targets[desc_score_indices]
三、使用降序排序的 preds ,獲取分?jǐn)?shù)下降位置的索引distinct_value_indices:
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
四、使用在 distinct_value_indices 末尾添加 N-1,構(gòu)建閾值索引 threshold_idxs :
threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=targets.size(0) - 1)
五、獲取各個閾值對應(yīng)的 tp:
累加 targets,使用 threshold_idxs 獲取各個閾值對應(yīng)的 tp。
tps = torch.cumsum(targets, dim=0)[threshold_idxs]
六、獲取各個閾值對應(yīng)的fp:
方法1: 累加 1 - targets,使用 threshold_idxs 獲取各個閾值對應(yīng)的 fp。
fps = torch.cumsum((1 - targets), dim=0)[threshold_idxs]
方法2: 通過 threshold_idxs 與 tps 計算,因為 threshold_idxs + 1 可以表示有多少樣本大于等于對應(yīng)閾值,即有多少個樣本被判斷為正樣本。
fps = 1 + threshold_idxs - tps
七、獲取閾值列表 thresholds:
thresholds = preds[threshold_idxs]
八、增加額外的閾值位,保證ROC曲線從(0,0)開始:
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])
九、判斷 tps 與 fps 是否有效,并計算 tpr 與 fpr:
if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")
tpr = tps / tps[-1]
if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")
fpr = fps / fps[-1]
整體實現(xiàn)
使用 pytorch 實現(xiàn),參考 torchmetrics,源代碼中包含 pos_label 和 sample_weights 參數(shù),這里沒有使用只是簡單實現(xiàn)。
import torch
from torch import Tensor
from torch.nn import functional as F
def roc_compute_single_class(preds, targets):
#獲取分?jǐn)?shù)的降序索引desc_score_indices
desc_score_indices = torch.argsort(preds, descending=True)
#使用分?jǐn)?shù)的降序索引 desc_score_indices 獲取降序排序后的 preds 和 targets
preds = preds[desc_score_indices]
targets = targest[desc_score_indices]
#使用降序排序的 preds ,獲取分?jǐn)?shù)下降位置的索引distinct_value_indices
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
#使用在 distinct_value_indices 末尾添加 N-1,構(gòu)建閾值索引 threshold_idxs
threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=targets.size(0) - 1)
#獲取各個閾值對應(yīng)的 tp
tps = torch.cumsum(targets, dim=0)[threshold_idxs]
#獲取各個閾值對應(yīng)的fp
fps = 1 + threshold_idxs - tps
#獲取閾值列表 thresholds
thresholds = preds[threshold_idxs]
#增加額外的閾值位,保證ROC曲線從(0,0)開始
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])
#判斷 tps 與 fps 是否有效,并計算 tpr 與 fpr
if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")
fpr = fps / fps[-1]
if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")
tpr = tps / tps[-1]
return fpr, tpr, thresholds