【pytorch】StatScores的原理與使用
StatScores的原理與使用
Confusion matrix (混淆矩陣)
在介紹StatScores之前,我們先復習以下Confusion matrix。
我們有兩組數(shù)據(jù),分別為真實分布,預測分布
預測為真定義為Possitive,預測為假定義為Negetive
四分類定義
如果預測Possitive與真實一致,則為True Possitive,簡寫為TP
如果預測Negetive與真實一致,則為True Negetive,簡寫為TN
如果預測Possitive與真實不一致,則為False Possitive,簡寫為FP
-
如果預測Negetive與真實不一致,則為False Negetive,簡寫為FN
關系圖

StatScores類實際上就是統(tǒng)計一組預測數(shù)據(jù)的這四個分類。
額外提一下Precision與Recall
Precision(準確率) 與 Recall (召回率)
P r e c i s i o n = T P T P + F P Precision = \cfrac {TP} {TP+FP} Precision=TP+FPTP
R e c a l l = T P T P + F N Recall = \cfrac {TP} {TP+FN} Recall=TP+FNTP
StatScores類
繼承關系
直接繼承與Metrics
class StatScores(Metric)
四類任務
它將處理的case分為了四類
- Binary 二分類
- MultiClass 多分類
- MultiLabel 多標簽
- MultiClass&MultiLabel
沒有入?yún)⒅付ㄋ鶎俚娜蝿誧ase,代碼中是根據(jù)pred張量來判斷的。邏輯如下,

因為筆者暫時只使用第1和2中,所以其他暫不介紹了。
Update與Compute方法
所有繼承Metrics的子類都需要實現(xiàn)Update和Compute方法。
1. update

update方法中調(diào)用內(nèi)部方法 _stat_scores_update

在該方法內(nèi)部,首先將根據(jù)輸入的數(shù)據(jù)做分類 _input_format_classification
該方法主要作用是將preds和target做one hot化,所屬分類任務的case也在該方法中識別的。
_input_format_classification的四個參數(shù)
這里有三個參數(shù)注意以下:
threshold
它僅僅作用與Binary的任務,作用是preds張量中,如果元素大于threshold,則規(guī)整為1,否則規(guī)整為0num_classes
指明分類種類,如果不指明的話,代碼中根據(jù)元素值的最大值來判斷。這個值同時也會影響one_hot后的數(shù)據(jù)長度。multiclass
如果multiclass=False,則強制認為所屬任務為Binary。True或者不設置(None)則根據(jù)入?yún)⒆孕信袛?/p>-
topk
在多分類任務中,在做one_hot轉換時,需要返回的最大前k個位置。
比如[0.1,0.5,0.4], 在topk=1(默認時),返回的是 [0,1,0],
如果topk=2,則返回的是[0,1,1]
_stat_scores
_stat_scores是真實計算tp, fp, tn, fn四個值的地方。

舉個例子
假設我們有如下
preds = torch.tensor(\[0, 1, 0\])
target = torch.tensor(\[1, 1, 0\])
首先,在 _input_format_classification方法處理后,這兩個張量會轉換為one_hot形式如下,
preds = \[\[1,0\], \[0,1\], \[1,0\]\]
target= \[\[0,1\], \[0,1\], \[1,0\]\]
然后, 進入**_stat_scores**
第64,65行的計算結果如下:
\# 預測true是正確的預測值和預測是false是正確的預測值
true\_pred, false\_pred = \[\[False,False\], \[True, True\], \[True, True\]\] ,
\[ \[True True\], \[False, False\] \[False, False\]
\# 預測是Ture的預測值與預測是False的預測值
pos\_pred, neg\_pred = \[\[False, True\] \[False, True\] \[True, False\]\] ,
\[\[True False\] \[True False\] \[True False\]\]
這兩者再兩兩相乘,得到tp fp tn fn
tp = (true\_pred \* pos\_pred).sum(dim=dim)
fp = (false\_pred \* pos\_pred).sum(dim=dim)
tn = (true\_pred \* neg\_pred).sum(dim=dim)
fn = (false\_pred \* neg\_pred).sum(dim=dim)
2. compute
compute調(diào)用內(nèi)部方法 _stat_scores_compute

_stat_scores_compute
該方法返回一個數(shù)組, [tp, fp, tn, fn, tp_fn]

這個就是StatScores的返回結果。