【pytorch】StatScores的原理與使用

【pytorch】StatScores的原理與使用

StatScores的原理與使用

Confusion matrix (混淆矩陣)

在介紹StatScores之前,我們先復習以下Confusion matrix。

我們有兩組數(shù)據(jù),分別為真實分布預測分布

預測為真定義為Possitive,預測為假定義為Negetive

四分類定義

  1. 如果預測Possitive與真實一致,則為True Possitive,簡寫為TP

  2. 如果預測Negetive與真實一致,則為True Negetive,簡寫為TN

  3. 如果預測Possitive與真實不一致,則為False Possitive,簡寫為FP

  4. 如果預測Negetive與真實不一致,則為False Negetive,簡寫為FN

關系圖

StatScores類實際上就是統(tǒng)計一組預測數(shù)據(jù)的這四個分類。

額外提一下Precision與Recall

Precision(準確率) 與 Recall (召回率)

P r e c i s i o n = T P T P + F P Precision = \cfrac {TP} {TP+FP} Precision=TP+FPTP

R e c a l l = T P T P + F N Recall = \cfrac {TP} {TP+FN} Recall=TP+FNTP


StatScores類

繼承關系

直接繼承與Metrics

class StatScores(Metric)  
  

四類任務

它將處理的case分為了四類

  1. Binary 二分類
  2. MultiClass 多分類
  3. MultiLabel 多標簽
  4. MultiClass&MultiLabel

沒有入?yún)⒅付ㄋ鶎俚娜蝿誧ase,代碼中是根據(jù)pred張量來判斷的。邏輯如下,

因為筆者暫時只使用第1和2中,所以其他暫不介紹了。

Update與Compute方法

所有繼承Metrics的子類都需要實現(xiàn)Update和Compute方法。

1. update

update方法中調(diào)用內(nèi)部方法 _stat_scores_update

在該方法內(nèi)部,首先將根據(jù)輸入的數(shù)據(jù)做分類 _input_format_classification

該方法主要作用是將preds和target做one hot化,所屬分類任務的case也在該方法中識別的。

_input_format_classification的四個參數(shù)

這里有三個參數(shù)注意以下:

  • threshold
    它僅僅作用與Binary的任務,作用是preds張量中,如果元素大于threshold,則規(guī)整為1,否則規(guī)整為0

  • num_classes
    指明分類種類,如果不指明的話,代碼中根據(jù)元素值的最大值來判斷。這個值同時也會影響one_hot后的數(shù)據(jù)長度。

  • multiclass
    如果multiclass=False,則強制認為所屬任務為Binary。True或者不設置(None)則根據(jù)入?yún)⒆孕信袛?/p>

  • topk
    在多分類任務中,在做one_hot轉換時,需要返回的最大前k個位置。
    比如[0.1,0.5,0.4], 在topk=1(默認時),返回的是 [0,1,0],
    如果topk=2,則返回的是[0,1,1]

_stat_scores

_stat_scores是真實計算tp, fp, tn, fn四個值的地方。

舉個例子

假設我們有如下

preds  = torch.tensor(\[0, 1, 0\])  
target = torch.tensor(\[1, 1, 0\])  

首先,在 _input_format_classification方法處理后,這兩個張量會轉換為one_hot形式如下,

preds = \[\[1,0\], \[0,1\], \[1,0\]\]  
target= \[\[0,1\], \[0,1\], \[1,0\]\]  
  

然后, 進入**_stat_scores**

第64,65行的計算結果如下:

\# 預測true是正確的預測值和預測是false是正確的預測值  
true\_pred, false\_pred = \[\[False,False\], \[True, True\], \[True, True\]\] ,   
                           \[ \[True True\], \[False, False\] \[False, False\]  
\# 預測是Ture的預測值與預測是False的預測值  
pos\_pred, neg\_pred = \[\[False, True\] \[False, True\] \[True, False\]\] ,   
                           \[\[True False\] \[True False\] \[True False\]\]  

這兩者再兩兩相乘,得到tp fp tn fn

tp = (true\_pred \* pos\_pred).sum(dim=dim)  
    fp = (false\_pred \* pos\_pred).sum(dim=dim)  
  
    tn = (true\_pred \* neg\_pred).sum(dim=dim)  
    fn = (false\_pred \* neg\_pred).sum(dim=dim)  

2. compute

compute調(diào)用內(nèi)部方法 _stat_scores_compute

_stat_scores_compute

該方法返回一個數(shù)組, [tp, fp, tn, fn, tp_fn]

這個就是StatScores的返回結果。

?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容