網(wǎng)格搜索,搜索的是參數(shù),即在指定的參數(shù)范圍內(nèi),按步長依次調整參數(shù),利用調整的參數(shù)訓練學習器,從所有的參數(shù)中找到在驗證集上精度最高的參數(shù),這其實是一個訓練和比較的過程。本節(jié)介紹三種網(wǎng)格搜索方法:簡單網(wǎng)格搜索、與交叉驗證結合的網(wǎng)格搜索、使用GridSearchCV的網(wǎng)格搜索。

網(wǎng)格搜索方法
簡單網(wǎng)格搜索
for循環(huán)遍歷全部的參數(shù)設置,并找出最高分和對應的參數(shù)。
X_train, X_test, y_train, y_test=train_test_split(wine.data,
wine.target,
random_state=38)
best_score = 0
for alpha in [0.01,0.1,1.0,10.0]:
for max_iter in [100,1000,5000,10000]:
lasso = Lasso(alpha=alpha,max_iter=max_iter)
lasso.fit(X_train, y_train)
score = lasso.score(X_test, y_test)
if score > best_score:
best_score = score
best_parameters={'alpha':alpha,'最大迭代次數(shù)':max_iter}
print("模型最高分為:{:.3f}".format(best_score))
print('最佳參數(shù)設置:{}'.format(best_parameters))
模型最高分為:0.889
最佳參數(shù)設置:{'alpha': 0.01, '最大迭代次數(shù)': 100}
與交叉驗證結合的網(wǎng)格搜索
交叉驗證法和網(wǎng)格搜索法結合起來找到模型的最優(yōu)參數(shù)。只用先前拆分好的X_train來進行交叉驗證,以便我們找到最佳參數(shù)之后,再用來擬合X_test來看一下模型的得分。
for alpha in [0.01,0.1,1.0,10.0]:
for max_iter in [100,1000,5000,10000]:
lasso = Lasso(alpha=alpha,max_iter=max_iter)
scores = cross_val_score(lasso, X_train, y_train, cv=6)
score = np.mean(scores)
if score > best_score:
best_score = score
best_parameters={'alpha':alpha, '最大迭代數(shù)':max_iter}
print("模型最高分為:{:.3f}".format(best_score))
print('最佳參數(shù)設置:{}'.format(best_parameters))
模型最高分為:0.865
最佳參數(shù)設置:{'alpha': 0.01, '最大迭代數(shù)': 100}
lasso = Lasso(alpha=0.01, max_iter=100).fit(X_train, y_train)
print('測試數(shù)據(jù)集得分:{:.3f}'.format(lasso.score(X_test,y_test)))
測試數(shù)據(jù)集得分:0.819
使用GridSearchCV的網(wǎng)格搜索
GridSearchCV本身就是將交叉驗證和網(wǎng)格搜索封裝在一起。GridSearchCV需要反復建模,所需要的計算時間往往更長。
from sklearn.model_selection import GridSearchCV
params = {'alpha':[0.01,0.1,1.0,10.0],
'max_iter':[100,1000,5000,10000]}
grid_search = GridSearchCV(lasso,params,cv=6)
grid_search.fit(X_train, y_train)
print('模型最高分:{:.3f}'.format(grid_search.score(X_test, y_test)))
print('最優(yōu)參數(shù):{}'.format(grid_search.best_params_))
模型最高分:0.819
最優(yōu)參數(shù):{'alpha': 0.01, 'max_iter': 100}
print('交叉驗證最高得分:{:.3f}'.format(grid_search.best_score_))
交叉驗證最高得分:0.865
分析:GridSearchCV有一個屬性best_score_,這個屬性會存儲模型在交叉驗證中所得的最高分,而不是在測試數(shù)據(jù)集上的得分。