機(jī)器學(xué)習(xí)一些代碼記錄

計算多分類時的每個類別的F1

  • 接口
sklearn.metrics.classification_report(y_true, y_pred, labels=None, target_names=None, sample_weight=None, digits=2, output_dict=False)

示例:

from sklearn.metrics import classification_report
y_true = [0,0, 1, 2, 2, 2, 0]
y_pred = [0, 1, 0, 2, 2, 1, 0]
target_names = ['dog', 'pig', 'cat']
result = classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
print(result)
image.png

pytorch 使用K-折交叉驗證

pytorch 使用K-折交叉驗證

核心代碼

  # Define the K-fold Cross Validator
  kfold = KFold(n_splits=k_folds, shuffle=True)

  # K-fold Cross Validation model evaluation
  for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset))
    
    # Sample elements randomly from a given list of ids, no replacement.
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    
    # Define data loaders for training and testing data in this fold
    trainloader = torch.utils.data.DataLoader(
                      dataset, 
                      batch_size=10, sampler=train_subsampler)
    testloader = torch.utils.data.DataLoader(
                      dataset,
                      batch_size=10, sampler=test_subsampler)

Pytorch的nn.CrossEntropyLoss()的weight使用

Pytorch的nn.CrossEntropyLoss()的weight使用

  • 大多使用:1/類別出現(xiàn)的次數(shù), 有人建議使用:出現(xiàn)類別最多的數(shù)目/自身類別出現(xiàn)的次數(shù)
    核心代碼
weights = [1/1016, 1/12852, 1/12888, 1/3380, 1/296] #[ 1 / number of instances for each class]
class_weights = torch.FloatTensor(weights).cuda()

criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

BERT模型中車cased 是需要區(qū)分大小寫的,也就是字符不要lower() . uncased 是不區(qū)分大小寫的,也就是此表只有小寫,字符需要lower()

馬氏距離的計算

import numpy as np
from scipy.spatial.distance import mahalanobis

def mahalanobis_distance(p, distr):

    # p: a point
    # distr : a distribution

    # covariance matrix
    cov = np.cov(distr, rowvar=False)

    # average of the points in distr
    avg_distri = np.average(distr, axis=0)

    dis = mahalanobis(p, avg_distri, cov)

    return dis
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容