用驗(yàn)證曲線(xiàn) validation curve 選擇超參數(shù)

本文結(jié)構(gòu):

  • 驗(yàn)證曲線(xiàn)的作用?
  • 驗(yàn)證曲線(xiàn)是什么?
  • 怎么解讀?
  • 怎么畫(huà)?

驗(yàn)證曲線(xiàn)的作用?

我們知道誤差由偏差(bias)、方差(variance)和噪聲(noise)組成。

偏差:模型對(duì)于不同的訓(xùn)練樣本集,預(yù)測(cè)結(jié)果的平均誤差。
方差:模型對(duì)于不同訓(xùn)練樣本集的敏感程度。
噪聲:數(shù)據(jù)集本身的一項(xiàng)屬性。

同樣的數(shù)據(jù)(cos函數(shù)上的點(diǎn)加上噪聲),我們用同樣的模型(polynomial),但是超參數(shù)卻不同(degree = 1, 4 ,15),會(huì)得到不同的擬合效果:

第一個(gè)模型太簡(jiǎn)單,模型本身就擬合不了這些數(shù)據(jù)(高偏差);
第二個(gè)模型可以看成幾乎完美地?cái)M合了數(shù)據(jù);
第三個(gè)模型完美擬合了所有訓(xùn)練數(shù)據(jù),但卻不能很好地?cái)M合真實(shí)的函數(shù),也就是對(duì)于不同的訓(xùn)練數(shù)據(jù)很敏感(高方差)。

對(duì)于這兩個(gè)問(wèn)題,我們可以選擇模型和超參數(shù)來(lái)得到效果更好的配置,也就是可以通過(guò)驗(yàn)證曲線(xiàn)調(diào)節(jié)。


驗(yàn)證曲線(xiàn)是什么?

驗(yàn)證曲線(xiàn)和學(xué)習(xí)曲線(xiàn)的區(qū)別是,橫軸為某個(gè)超參數(shù)的一系列值,由此來(lái)看不同參數(shù)設(shè)置下模型的準(zhǔn)確率,而不是不同訓(xùn)練集大小下的準(zhǔn)確率。

從驗(yàn)證曲線(xiàn)上可以看到隨著超參數(shù)設(shè)置的改變,模型可能從欠擬合到合適再到過(guò)擬合的過(guò)程,進(jìn)而選擇一個(gè)合適的設(shè)置,來(lái)提高模型的性能。

需要注意的是如果我們使用驗(yàn)證分?jǐn)?shù)來(lái)優(yōu)化超參數(shù),那么該驗(yàn)證分?jǐn)?shù)是有偏差的,它無(wú)法再代表模型的泛化能力,我們就需要使用其他測(cè)試集來(lái)重新評(píng)估模型的泛化能力。

不過(guò)有時(shí)畫(huà)出單個(gè)超參數(shù)與訓(xùn)練分?jǐn)?shù)和驗(yàn)證分?jǐn)?shù)的關(guān)系圖,有助于觀察該模型在相應(yīng)的超參數(shù)取值時(shí),是否有過(guò)擬合或欠擬合的情況發(fā)生。


怎么解讀?

如圖是 SVM 在不同的 gamma 時(shí),它在訓(xùn)練集和交叉驗(yàn)證上的分?jǐn)?shù):

gamma 很小時(shí),訓(xùn)練分?jǐn)?shù)和驗(yàn)證分?jǐn)?shù)都很低,為欠擬合。
gamma 逐漸增加,兩個(gè)分?jǐn)?shù)都較高,此時(shí)模型相對(duì)不錯(cuò)。
gamma 太高時(shí),訓(xùn)練分?jǐn)?shù)高,驗(yàn)證分?jǐn)?shù)低,學(xué)習(xí)器會(huì)過(guò)擬合。

本例中,可以選驗(yàn)證集準(zhǔn)確率開(kāi)始下降,而測(cè)試集越來(lái)越高那個(gè)轉(zhuǎn)折點(diǎn)作為 gamma 的最優(yōu)選擇。


怎么畫(huà)?

下面用 SVC 為例,
調(diào)用 validation_curve

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits

from sklearn.svm import SVC
from sklearn.learning_curve import validation_curve

validation_curve 要看的是 SVC() 的超參數(shù) gamma,
gamma 的范圍是取 10^-6 到 10^-1 5 個(gè)值,
評(píng)分用的是 metrics.accuracy_score 的 accuracy:

param_range = np.logspace(-6, -1, 5)
train_scores, test_scores = validation_curve(
    SVC(), X, y, param_name="gamma", param_range=param_range,
    cv=10, scoring="accuracy", n_jobs=1)

畫(huà)圖時(shí),橫軸為 param_range,縱軸為 train_scores_mean,test_scores_mean

plt.semilogx(param_range, train_scores_mean, label="Training score", color="r")
plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",
             color="g")

資料:
http://sklearn.lzjqsdd.com/auto_examples/model_selection/plot_validation_curve.html#example-model-selection-plot-validation-curve-py


推薦閱讀 歷史技術(shù)博文鏈接匯總
http://www.itdecent.cn/p/28f02bb59fe5
也許可以找到你想要的

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容