Learning curves: sklearn.model_selection.learning_curve

Part 1: Learning curves

    A learning curve is a way to diagnose a trained model: the number of training samples is increased step by step according to a predefined schedule, and the model's accuracy is plotted at each training-set size.

    Plotting Jtrain(theta) and Jtest(theta) on the vertical axis against the training-set size m gives the learning curve. From it we can directly observe how model accuracy relates to the amount of training data, and see what state the model is in, e.g. overfitting or underfitting.

    For a dataset of size m, the learning curve can be drawn with the following procedure:

1. Split the dataset into a training set and a cross-validation set (which serves as a test set).

2. Take 20% of the training set as training samples and fit the model parameters.

3. Use the cross-validation set to measure the accuracy of the trained model.

4. Plot the training-set accuracy and the cross-validation accuracy on the vertical axis, with the number of training samples on the horizontal axis.

5. Increase the training subset by 10%, go back to step 2, and repeat until 100% of the training set is used.
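The five steps above can be sketched directly in Python. This is a minimal illustration only; `manual_learning_curve` and the synthetic dataset are hypothetical helpers, not part of sklearn's API:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def manual_learning_curve(model, X, y, fractions=(0.2, 0.4, 0.6, 0.8, 1.0)):
    # Step 1: split into a training set and a cross-validation set
    X_tr, X_cv, y_tr, y_cv = train_test_split(X, y, test_size=0.3, random_state=0)
    sizes, train_acc, cv_acc = [], [], []
    for frac in fractions:                    # steps 2-5: grow the training subset
        m = max(2, int(len(X_tr) * frac))
        model.fit(X_tr[:m], y_tr[:m])         # step 2: fit on the first m samples
        sizes.append(m)
        train_acc.append(model.score(X_tr[:m], y_tr[:m]))  # training accuracy
        cv_acc.append(model.score(X_cv, y_cv))             # step 3: CV accuracy
    return sizes, train_acc, cv_acc

X, y = make_classification(n_samples=300, random_state=0)  # made-up toy data
sizes, train_acc, cv_acc = manual_learning_curve(
    LogisticRegression(max_iter=1000), X, y)
```

Plotting `train_acc` and `cv_acc` against `sizes` (step 4) gives the learning curve; `learning_curve` below automates exactly this loop, with cross-validation at every size.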

Part 2: Comparison

Reference: https://blog.csdn.net/u012328159/article/details/79255433

  1. learning_curve(): this function is mainly used to judge (by visualization) whether a model is overfitting. I won't go into overfitting itself here; see the earlier post: Model selection and improvement.
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

X, y = datasets.load_digits(return_X_y=True)

train_sizes, train_score, test_score = learning_curve(
    RandomForestClassifier(), X, y,
    train_sizes=[0.1, 0.2, 0.4, 0.6, 0.8, 1],
    cv=10, scoring='accuracy')

train_error = 1 - np.mean(train_score, axis=1)
test_error = 1 - np.mean(test_score, axis=1)

plt.plot(train_sizes, train_error, 'o-', color='r', label='training')
plt.plot(train_sizes, test_error, 'o-', color='g', label='testing')
plt.legend(loc='best')
plt.xlabel('training examples')
plt.ylabel('error')
plt.show()
```
  2. validation_curve(): this function is mainly used to see how the model performs under different values of a single parameter.
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

X, y = datasets.load_digits(return_X_y=True)
# print(X[:2,:])

param_range = [10, 20, 40, 80, 160, 250]
train_score, test_score = validation_curve(
    RandomForestClassifier(), X, y,
    param_name='n_estimators', param_range=param_range,
    cv=10, scoring='accuracy')

train_score = np.mean(train_score, axis=1)
test_score = np.mean(test_score, axis=1)

plt.plot(param_range, train_score, 'o-', color='r', label='training')
plt.plot(param_range, test_score, 'o-', color='g', label='testing')
plt.legend(loc='best')
plt.xlabel('number of trees')
plt.ylabel('accuracy')
plt.show()
```
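Once validation_curve has run, picking the parameter value with the best mean cross-validation score is a one-liner over its return arrays. A sketch with a made-up score matrix of the same shape (n_params × n_folds) that validation_curve returns:

```python
import numpy as np

param_range = [10, 20, 40, 80, 160, 250]
# hypothetical test scores, shape (n_params, n_folds) like validation_curve's output
test_score = np.array([[0.90] * 10, [0.92] * 10, [0.94] * 10,
                       [0.95] * 10, [0.95] * 10, [0.95] * 10])
mean_test = test_score.mean(axis=1)              # average over the CV folds
best_n = param_range[int(np.argmax(mean_test))]  # first index of the maximum
```

With these made-up scores the curve plateaus from 80 trees onward, so `argmax` picks 80: more trees buy nothing further.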

Part 3: Parameter explanations

from sklearn.model_selection import learning_curve

Parameter explanations; reference: https://blog.csdn.net/gracejpw/article/details/102370364

[image: learning_curve function signature]

X : array-like, shape (n_samples, n_features) Training vector, where n_samples is the number of samples and n_features is the number of features.

An m×n matrix, where m is the number of samples and n the number of features.

y : array-like, shape (n_samples) or (n_samples, n_outputs), optional Target relative to X for classification or regression; None for unsupervised learning.

An m×1 vector, where m is the number of samples: the target relative to X for classification or regression.

groups : array-like, with shape (n_samples,), optional Group labels for the samples used while splitting the dataset into train/test set.

Group labels for the samples, used when splitting the dataset into train/test sets. **[optional]**

**train_sizes** : array-like, shape (n_ticks,), dtype float or int Relative or absolute numbers of training examples that will be used to generate the learning curve. If the dtype is float, it is regarded as a fraction of the maximum size of the training set (that is determined by the selected validation method), i.e. it has to be within (0, 1]. Otherwise it is interpreted as absolute sizes of the training sets. Note that for classification the number of samples usually have to be big enough to contain at least one sample from each class. (default: np.linspace(0.1, 1.0, 5))

Specifies how the number of training samples grows. For example, np.linspace(0.1, 1.0, 5) splits the range 0.1–1 into 5 equal steps, producing the sequence [0.1, 0.325, 0.55, 0.775, 1.0]; each fraction in turn is taken as the training-set size, and the model's accuracy is computed at that size.
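For instance, the default ticks and their conversion to absolute sample counts (`m_max = 400` is a made-up maximum training-set size):

```python
import numpy as np

ticks = np.linspace(0.1, 1.0, 5)            # fractions of the max training-set size
m_max = 400                                 # hypothetical maximum training-set size
absolute = np.rint(ticks * m_max).astype(int)  # absolute sizes actually trained on
```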

**cv** : int, cross-validation generator or an iterable, optional Determines the cross-validation splitting strategy.

The cross-validation splitting strategy; for example, sklearn.model_selection.ShuffleSplit can be passed. Possible inputs for cv are:

    None, to use the default 3-fold cross-validation (changed to 5-fold in v0.22),

    an integer, to specify the number of folds in a (Stratified)KFold,

    a CV splitter,

    an iterable yielding (train, test) splits as arrays of indices.

    For integer/None inputs, if the estimator is a classifier and y is binary or multiclass, StratifiedKFold is used. In all other cases, KFold is used.
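This dispatch can be verified with sklearn's own `check_cv` helper from `sklearn.model_selection` (the toy labels below are made up):

```python
import numpy as np
from sklearn.model_selection import check_cv, KFold, StratifiedKFold

y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
clf_cv = check_cv(cv=5, y=y, classifier=True)   # classifier + binary y -> StratifiedKFold
reg_cv = check_cv(cv=5, y=y, classifier=False)  # everything else -> plain KFold
```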

scoring: string, callable or None, optional, default: None. The metric used to evaluate model performance, e.g. 'accuracy', 'f1', or 'neg_mean_squared_error'.

exploit_incremental_learning: boolean, optional, default: False

If the estimator supports incremental learning, this is used to speed up fitting for different training-set sizes.

n_jobs: int or None, optional (default=None)

Number of jobs to run in parallel. None means 1; -1 means use all processors.

pre_dispatch: int or string, optional

Number of pre-dispatched jobs for parallel execution (default: all). This option can reduce allocated memory. The string can be an expression like '2*n_jobs'.

shuffle: boolean, optional

Whether to shuffle the training data before taking prefixes of it based on train_sizes.

random_state: int, RandomState instance or None, optional (default=None)

If int, random_state is the seed used by the random number generator; if a RandomState instance, it is used as the random number generator; if None, the random number generator is the RandomState instance used by np.random. Used when shuffle is True.

error_score: 'raise' | 'raise-deprecating' or numeric

Value assigned to the score if an error occurs while fitting the estimator. If set to 'raise', the error is raised. If set to 'raise-deprecating', a FutureWarning is printed before the error is raised. If a numeric value is given, a FitFailedWarning is raised instead. This parameter does not affect the refit step, which always raises the error. The default is 'raise-deprecating', but from v0.22 it changes to np.nan.

Return values:

train_sizes_abs (the numbers of training examples actually used to generate the curve), train_scores (scores on the training subsets, shape (n_ticks, n_cv_folds)) and test_scores (scores on the validation sets, same shape).

Part 4: Usage


```python
import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ShuffleSplit, learning_curve
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
X = cancer.data
y = cancer.target

def polynomial_model(degree=1, **kwargs):
    polynomial_features = PolynomialFeatures(degree=degree, include_bias=False)
    # solver='liblinear' supports both the 'l1' and 'l2' penalties used below
    logistic_regression = LogisticRegression(solver='liblinear', **kwargs)
    return Pipeline([("pf", polynomial_features),
                     ("lr", logistic_regression)])

def plot_learning_curve(plt, estimator, title, X, y, ylim=None, cv=None,
                        n_jobs=1, train_size=np.linspace(0.1, 1, 5)):
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlabel("Training examples")
    plt.ylabel("Score")
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_size)
    print("train_sizes:\n", train_sizes,
          "\ntrain_scores:\n", train_scores,
          "\ntest_scores:\n", test_scores)
    train_scores_mean = np.mean(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)
    plt.grid()
    plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
                     train_scores_mean + train_scores_std, alpha=0.1, color="r")
    plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
                     test_scores_mean + test_scores_std, alpha=0.1, color="g")
    plt.plot(train_sizes, train_scores_mean, "o-", color="r", label="Training score")
    plt.plot(train_sizes, test_scores_mean, "o-", color="g", label="Cross-validation score")
    plt.legend(loc="best")
    return plt

cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
title = "Learning Curves (degree={0}, penalty={1})"
degrees = [1, 2]
penalty = ["l1", "l2"]

start = time.time()  # time.clock() was removed in Python 3.8
plt.figure(figsize=(12, 4), dpi=144)
j = 0
for p in penalty:
    for i in range(len(degrees)):
        plt.subplot(len(penalty), len(degrees), j + 1)
        plot_learning_curve(plt, polynomial_model(degree=degrees[i], penalty=p),
                            title.format(degrees[i], p), X, y, ylim=(0.8, 1.01), cv=cv)
        j += 1
plt.tight_layout()
plt.savefig("1.png")
```

The return values of learning_curve look like this:


[figure: learning_curve return values]

Five training sizes were chosen and 10-fold cross-validation was used, so train_sizes is an ndarray with 5 elements, while train_scores and test_scores are 5×10 matrices: each row holds the ten per-fold results for one training size, and the mean of each row is taken as the final accuracy.
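The averaging just described, sketched on a made-up 5×10 score matrix of the same shape learning_curve returns:

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical 5 training sizes x 10 CV folds, like train_scores / test_scores
scores = rng.uniform(0.85, 0.95, size=(5, 10))
mean_scores = scores.mean(axis=1)   # one mean accuracy per training size
std_scores = scores.std(axis=1)     # fold-to-fold spread, useful for error bands
```

The standard deviation per row is what `plot_learning_curve` above shades with `fill_between`.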
Part 5: Performance evaluation
