sklearn多分類模型評測(LR, linearSVC, lightgbm)

多分類

背景:多分類是指具有兩類以上的分類任務(wù); 例如,分類一組可能是橘子,蘋果或梨的水果圖像。本文旨在為大家提供一段即寫即用的代碼,跳過對原理的解說,直接上手跑一版baseline。當(dāng)然,后續(xù)的優(yōu)化任務(wù)還是需要一定的算法基礎(chǔ),比如模型參數(shù)以及性能參數(shù)優(yōu)化。

初步結(jié)論

本數(shù)據(jù)集上, 在迭代次數(shù)量級基本一致的情況下,lightgbm表現(xiàn)更優(yōu):樹的固有多分類特性使得不需要OVR或者OVO式的開銷,而且lightgbm本身就對決策樹進(jìn)行了優(yōu)化,因此性能和分類能力都較好。

模型 AUC 精確率 耗時(shí)(s)
linearSVC 0.9169 0.6708 883
LR 0.9226 0.6571 944
lightgbm 0.9332 0.6947 600

數(shù)據(jù)定義

一個樣本僅對應(yīng)一個標(biāo)簽
數(shù)據(jù)量: 800M(32w樣本量 * 929 特征)

數(shù)據(jù)格式

特征1|特征2|...|特征N|label

評測算法

  • LR
  • linearSVC
  • lightgbm
    notice: 樹模型是天生的多分類模型,LR、linearSVC則是基于“One-Vs-The-Rest”,即為N類訓(xùn)練N個模型,為樣本選擇一個最佳類別。
    參考:多分類和多標(biāo)簽算法

版本

系統(tǒng) 64bit centOS
sklearn 0.19.1

代碼走讀

import pandas as pd
import numpy as np
import time
import logging
import os, sys
import psutil
import lightgbm as lgb
from datetime import datetime

from itertools import cycle
from sklearn import svm
from sklearn.metrics import *
from sklearn.cross_validation import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.externals import joblib
from scipy import interp


# 循環(huán)讀取多個文件
path = "./data/d20190416/"
os.chdir(path)
files = os.listdir(path)
files_csv = list(filter(lambda x: x[:4]=='part' , files))[:200]

data_list = []
for file in  files_csv:
    tmp = pd.read_csv(path + file, sep = '|', header=None)
    data_list.append(tmp)
data_set = pd.concat(data_list, axis = 0)
del data_list

#配置列名
sample_cnt, col_cnt = data_set.shape
cols = ["x_%d"%(i) for i in range(col_cnt - 1)]
cols.append("y")
data_set.columns = cols

# 數(shù)據(jù)預(yù)覽,事先準(zhǔn)備好one-hot特征,最后一列為label={0,1,2,3}
# >>  0|1|1|1|0|0....|2 
  1. 模型 OneVsRestClassifier
    元分類器 svm.LinearSVC
    說明:OneVsRestClassifier模塊, 是通過將分類問題分解為二進(jìn)制分類問題來解決,因此構(gòu)建樣本時(shí)需要將label列轉(zhuǎn)為二進(jìn)制格式 e.g. 2 -> [0, 0, 1, 0] 0 ->[1, 0, 0, 0]
    性能:單核883s, 迭代1000次
    AUC : 0.9169
    精確率:0.6708
#########################################################################################
# 模型 OneVsRestClassifier 
# 元分類器 svm.LinearSVC
#########################################################################################
y = label_binarize(data_set["y"], classes=[0,1,2,3])
X = data_set.iloc[:, :-1]
# 隨機(jī)化數(shù)據(jù),并劃分訓(xùn)練數(shù)據(jù)和測試數(shù)據(jù)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,random_state=0)

#訓(xùn)練
model = OneVsRestClassifier(svm.LinearSVC(random_state = 0, verbose = 1))
btime = datetime.now() 
model.fit(X_train, y_train)
print 'all tasks done. total time used:%s s.\n\n'%((datetime.now() - btime).total_seconds())

# 評價(jià)
y_score = model.decision_function(X_test) # 計(jì)算屬于各個類別的概率,返回值的shape = [n_samples, n_classes]
# 1、調(diào)用函數(shù)計(jì)算micro類型的AUC
print '調(diào)用函數(shù)auc:', roc_auc_score(y_test, y_score, average='micro')

# 2、混淆矩陣
y_pred = model.predict(X_test) # 預(yù)測屬于哪個類別
confusion_matrix(y_test.argmax(axis=1), y_pred1.argmax(axis=1)) # 需要0、1、2、3而不是OH編碼格式

# 3、經(jīng)典-精確率、召回率、F1分?jǐn)?shù)
precision_score(y_test, y_pred,average='micro')
recall_score(y_test, y_pred,average='micro')
f1_score(y_test, y_pred,average='micro')

# 4、模型報(bào)告
classification_report(y_test, y_pred, digits=4)
'''             precision    recall  f1-score   support

          0       0.78      0.85      0.81     42276
          1       0.83      0.66      0.74     18960
          2       0.59      0.34      0.44     13591
          3       0.59      0.35      0.44     13170
          4       0.00      0.00      0.00      8151

avg / total       0.67      0.60      0.62     96148
'''

# 保存模型
joblib.dump(model, './model/LinearSVC.pkl')
  1. 模型(分類器) LR
    說明:LogisticRegression模塊,設(shè)置multi_class='ovr',會訓(xùn)練出“類別數(shù)”個分類器,構(gòu)建樣本時(shí)需要原始label即可
    性能:單核944s , 迭代1000次
    AUC : 0.9226
    精確率:0.6571
#########################################################################################
# 模型 LogisticRegression(random_state=0, solver='sag',multi_class='ovr', verbose = 1)
#########################################################################################
from sklearn.linear_model import LogisticRegression

# 準(zhǔn)備數(shù)據(jù)
X = data_set.iloc[:, :-1]
X_train, X_test, y_train, y_test = train_test_split(X, data_set["y"], test_size=0.3,random_state=0)

# 訓(xùn)練
btime = datetime.now() 
lr_clf = LogisticRegression(random_state=0, solver='sag',multi_class='ovr', verbose = 1)
lr_clf.fit(X_train, y_train)
print 'all tasks done. total time used:%s s.\n\n'%((datetime.now() - btime).total_seconds())

# 1、AUC
y_pred_pa = lr_clf.predict_proba(X_test)
y_test_oh = label_binarize(y_test, classes=[0,1,2,3])
print '調(diào)用函數(shù)auc:', roc_auc_score(y_test_oh, y_pred_pa, average='micro')

#  2、混淆矩陣
y_pred = lr_clf.predict(X_test)
confusion_matrix(y_test, y_pred_1)

#  3、經(jīng)典-精確率、召回率、F1分?jǐn)?shù)
precision_score(y_test, y_pred_1,average='micro')
recall_score(y_test, y_pred_1,average='micro')
f1_score(y_test, y_pred_1,average='micro')

# 4、模型報(bào)告
print(classification_report(y_test, y_pred , digits=4))

# 保存模型
joblib.dump(lr_clf, './model/lr_clf.pkl')
  1. 模型(分類器) lightgbm
    說明:樹的輸出本身就可以是多分類,應(yīng)該是操作最簡單的,構(gòu)建樣本時(shí)需要原始label即可
    性能:單核600s, 迭代200次
    AUC : 0.9332
    精確率:0.6947
#########################################################################################
# 模型 lightgbm
#########################################################################################
import lightgbm as lgb

# 準(zhǔn)備數(shù)據(jù)
X = data_set.iloc[:, :-1]
X_train, X_test, y_train, y_test = train_test_split(X, data_set["y"], test_size=0.3,random_state=0)

# 訓(xùn)練
btime = datetime.now() 
train_data=lgb.Dataset(X_train,label=y_train)
validation_data=lgb.Dataset(X_test,label=y_test)
params={
    'learning_rate':0.1,
    'lambda_l1':0.1,
    'lambda_l2':0.2,
    'max_depth':6,
    'objective':'multiclass',
    'num_class':4,  
}
clf=lgb.train(params,train_data,valid_sets=[validation_data])
print 'all tasks done. total time used:%s s.\n\n'%((datetime.now() - btime).total_seconds())

# 1、AUC
y_pred_pa = clf.predict(X_test)  # !!!注意lgm預(yù)測的是分?jǐn)?shù),類似 sklearn的predict_proba
y_test_oh = label_binarize(y_test, classes= [0,1,2,3])
print '調(diào)用函數(shù)auc:', roc_auc_score(y_test_oh, y_pred_pa, average='micro')

#  2、混淆矩陣
y_pred = y_pred_pa .argmax(axis=1)
confusion_matrix(y_test, y_pred )

#  3、經(jīng)典-精確率、召回率、F1分?jǐn)?shù)
precision_score(y_test, y_pred,average='micro')
recall_score(y_test, y_pred,average='micro')
f1_score(y_test, y_pred,average='micro')

# 4、模型報(bào)告
print(classification_report(y_test, y_pred))

# 保存模型
joblib.dump(clf, './model/lgb.pkl')
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容