Machine Learning Beginner Notes 02: The kNN Algorithm (Part 2)

Goals

  • Implement the kNN algorithm with scikit-learn
  • Split the data into a training set and a test set
  • Train, validate, and evaluate accuracy
  • Tune hyperparameters with grid search
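
As a quick aside on the third goal: "accuracy" here is just the fraction of predictions that match the true labels. The sketch below checks that by hand against `accuracy_score`. The label arrays are made up for illustration and are not taken from the iris data.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical labels, invented only for this illustration
y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 1, 2, 1, 1])

# Element-wise comparison gives a boolean array; its mean is the hit rate
manual_accuracy = np.mean(y_true == y_pred)  # 4 of 5 labels match
library_accuracy = accuracy_score(y_true, y_pred)

print(manual_accuracy, library_accuracy)
```

Both values should agree, which is why the script can use either `accuracy_score(y_test, y_predict)` or the classifier's own `score` method.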

Implementation

#!/usr/bin/env python3
import numpy as np
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score

# Load the iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target
print(X.shape)  # > (150, 4)
print(y.shape)  # > (150,)

# Split into training and test sets (default test_size=0.25, i.e. 38 of 150 samples)
X_train, X_test, y_train, y_test = train_test_split(X, y)
print(X_train.shape)  # > (112, 4)
print(X_test.shape)  # > (38, 4)
print(y_train.shape)  # > (112,)
print(y_test.shape)  # > (38,)

# Train and evaluate
knn_clf = KNeighborsClassifier(n_neighbors=3)
knn_clf.fit(X_train, y_train)
y_predict = knn_clf.predict(X_test)
# Compare y_predict against y_test: accuracy is the fraction of matching labels
accuracy = accuracy_score(y_test, y_predict)
print(accuracy)

# Hyperparameter tuning (grid search)
# Define the parameter grid: each dict is one sub-grid to search
param_search = [
    {"weights": ["uniform"], "n_neighbors": list(range(1, 11))},
    {"weights": ["distance"], "n_neighbors": list(range(1, 11)), "p": list(range(1, 6))},
]
knn_clf = KNeighborsClassifier()
# Create the grid search object: the first argument is the classifier to tune,
# the second is the parameter grid to search over
grid_search = GridSearchCV(knn_clf, param_search)
# fit() returns the fitted grid_search itself; do not assign the result to the
# name GridSearchCV, which would shadow the imported class
grid_search.fit(X_train, y_train)
print(grid_search)  # the output below is from the author's (older) scikit-learn version
# > GridSearchCV(cv='warn', error_score='raise-deprecating',
# >  estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30,
# > metric='minkowski',
# > metric_params=None, n_jobs=None,
# > n_neighbors=5, p=2,
# > weights='uniform'),
# >  iid='warn', n_jobs=None,
# >  param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
# >   'weights': ['uniform']},
# >  {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
# >   'p': [1, 2, 3, 4, 5], 'weights': ['distance']}],
# >  pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
# >  scoring=None, verbose=0)
print(grid_search.best_estimator_)
# > KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
# >  metric_params=None, n_jobs=None, n_neighbors=3, p=2,
# >  weights='uniform')
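# Exact scores vary between runs because train_test_split shuffles randomly
# best_score_ is the mean cross-validated accuracy on the training set
print(grid_search.best_score_)  # > 0.9910714285714286
knn_clf = grid_search.best_estimator_
# Accuracy of the refitted best model on the held-out test set
accuracy = knn_clf.score(X_test, y_test)
print(accuracy)  # > 0.9473684210526315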
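
Two things in the script above are easy to miss: how many parameter combinations the grid actually covers, and the fact that every run gives a different split. The following is a minimal sketch of one way to check both, not part of the original notes. The values `random_state=42`, `stratify=y` and `cv=5` are my own assumptions, chosen only to make the run repeatable; `ParameterGrid` is the scikit-learn helper that enumerates a grid.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, ParameterGrid, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = datasets.load_iris(return_X_y=True)

# Assumption: fix the seed and stratify by class so the split is repeatable
# and keeps roughly the same class ratio in both sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, stratify=y
)

# Same grid as in the script above
param_search = [
    {"weights": ["uniform"], "n_neighbors": list(range(1, 11))},
    {"weights": ["distance"], "n_neighbors": list(range(1, 11)), "p": list(range(1, 6))},
]

# Count the candidates: 10 for "uniform" plus 10 * 5 for "distance"
n_candidates = len(ParameterGrid(param_search))
print(n_candidates)  # 60

# Assumption: cv=5 is stated explicitly rather than relying on the default
grid_search = GridSearchCV(KNeighborsClassifier(), param_search, cv=5)
grid_search.fit(X_train, y_train)

print(grid_search.best_params_)  # only the tuned parameters, as a dict
test_accuracy = grid_search.score(X_test, y_test)  # uses the refitted best model
print(test_accuracy)
```

`best_params_` is usually easier to read than `best_estimator_` when you only want to know which combination won. The scores from this sketch will not match the ones printed above, since the split is different.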