目標
- 使用 scikit-learn 實現 kNN 算法
- 將數據分割為訓練集和測試集
- 訓練並驗證,評價準確率
- 使用網格搜索優化超參數
實現代碼
#!/usr/bin/env python3
import numpy as np
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score
# Load the Iris data set shipped with scikit-learn and pull out the feature
# matrix (X) and the label vector (y).
iris = datasets.load_iris()
X, y = iris.data, iris.target
print(X.shape)  # > (150, 4)
print(y.shape)  # > (150,)
# Split the data into a training set and a test set.
# test_size is written out with the value train_test_split already uses when
# neither test_size nor train_size is given (0.25), so the split is unchanged
# and the expected shapes below are easy to verify: 150 samples -> 112 for
# training and 38 for testing.
# No random_state is passed, so which rows land in each set differs per run.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
print(X_train.shape)  # > (112, 4)
print(X_test.shape)  # > (38, 4)
print(y_train.shape)  # > (112,)
print(y_test.shape)  # > (38,)
# Train a 3-nearest-neighbour classifier and evaluate it on the test set.
# The split created above is reused as-is: drawing a second, unseeded split
# here would silently train and evaluate on a different random partition than
# the one whose shapes were just printed.
knn_clf = KNeighborsClassifier(n_neighbors=3)
knn_clf.fit(X_train, y_train)
y_predict = knn_clf.predict(X_test)
# Compare the predicted labels with the true test labels.
accuracy = accuracy_score(y_test, y_predict)
print(accuracy)
# Hyper-parameter optimisation (grid search).
# Two sub-grids are searched:
#   * uniform voting with n_neighbors = 1..10
#   * distance-weighted voting with n_neighbors = 1..10 and p = 1..5
#     (p is the exponent of the Minkowski metric shown in the output below)
param_search = [
    {"weights": ["uniform"], "n_neighbors": list(range(1, 11))},
    {
        "weights": ["distance"],
        "n_neighbors": list(range(1, 11)),
        "p": list(range(1, 6)),
    },
]
knn_clf = KNeighborsClassifier()
# GridSearchCV takes the estimator to tune as its first argument and the
# parameter grid(s) to explore as its second.
grid_search = GridSearchCV(knn_clf, param_search)
# fit() hands back the search object itself, so its result is not bound to a
# new name -- in particular not to `GridSearchCV`, which would shadow the
# imported class and make it impossible to build another search afterwards.
grid_search.fit(X_train, y_train)
print(grid_search)
# Sample output.
# NOTE(review): 'warn' defaults for cv/iid suggest this was captured on an
# older scikit-learn release; the repr may look different on current versions.
# > GridSearchCV(cv='warn', error_score='raise-deprecating',
# > estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30,
# > metric='minkowski',
# > metric_params=None, n_jobs=None,
# > n_neighbors=5, p=2,
# > weights='uniform'),
# > iid='warn', n_jobs=None,
# > param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
# > 'weights': ['uniform']},
# > {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
# > 'p': [1, 2, 3, 4, 5], 'weights': ['distance']}],
# > pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
# > scoring=None, verbose=0)
# Best parameter combination found and its cross-validated score. The sample
# values below come from one particular run; the train/test split is not
# seeded, so they will vary.
print(grid_search.best_estimator_)
# > KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
# > metric_params=None, n_jobs=None, n_neighbors=3, p=2,
# > weights='uniform')
print(grid_search.best_score_)  # > 0.9910714285714286
# Score the best estimator on the held-out test set.
knn_clf = grid_search.best_estimator_
accuracy = knn_clf.score(X_test, y_test)
print(accuracy)  # > 0.9473684210526315