k 近鄰法
- k 近鄰算法
- k 近鄰模型
- k 近鄰法的實(shí)現(xiàn):kd 樹(shù)
- 搜索 kd 樹(shù)
k 近鄰模型實(shí)現(xiàn)
- k 近鄰模型實(shí)現(xiàn)
- kd 樹(shù)搜索單個(gè)近鄰點(diǎn)
- kd 樹(shù)搜索 k 個(gè)近鄰點(diǎn)
k 近鄰法(k-nearest neighbor,k-NN)是一種基本分類與回歸方法。k 近鄰法假設(shè)給定一個(gè)訓(xùn)練數(shù)據(jù)集,其中的實(shí)例類別已定。
分類時(shí),對(duì)新的實(shí)例,根據(jù)其 k 個(gè)最近鄰的訓(xùn)練實(shí)例的類別,通過(guò)多數(shù)表決等方式進(jìn)行預(yù)測(cè)。因此,k近鄰法不具有顯式的學(xué)習(xí)過(guò)程。k 近鄰法實(shí)際上利用訓(xùn)練數(shù)據(jù)集對(duì)特征向量空間進(jìn)行劃分,并作為其分類的“模型”。k 值的選擇、距離度量及分類決策規(guī)則是 k 近鄰法的三個(gè)基本要素。
k 近鄰算法
- 給定數(shù)據(jù)集
,其中,
為實(shí)例的特征向量,
為實(shí)例的類別,
;實(shí)例的特征向量
1>> 根據(jù)給定的距離度量,在訓(xùn)練集中找出與
最鄰近的
個(gè)點(diǎn),涵蓋這
個(gè)點(diǎn)的
的鄰域記作
;
2>> 在中根據(jù)分類決策規(guī)則(如多數(shù)表決)決定
的類別
:
式中為指示函數(shù)。
- k 近鄰法的特殊情況是
的情形,稱為最近鄰算法。對(duì)于輸入的實(shí)例點(diǎn)(特征向量)
,最近鄰法將訓(xùn)練數(shù)據(jù)集中與
最鄰近點(diǎn)的類作為
的類。
k 近鄰模型
- 模型由三個(gè)基本要素——距離度量、k 值的選擇和分類決策規(guī)則決定。
- 特征空間中,對(duì)每個(gè)訓(xùn)練實(shí)例點(diǎn)
,距離該點(diǎn)比其他點(diǎn)更近的所有點(diǎn)組成一個(gè)區(qū)域,叫作單元(cell)。
- 特征空間中兩個(gè)實(shí)例點(diǎn)的距離是兩個(gè)實(shí)例點(diǎn)相似程度的反映。(歐氏距離、
距離、Minkowski 距離)
- 設(shè)特征空間
是
維實(shí)數(shù)向量空間
,
,
,
,
的
距離定義為
這里。當(dāng)
時(shí), 稱為
歐氏距離,即
當(dāng)時(shí), 稱為
哈曼頓距離,即
當(dāng),它是各個(gè)坐標(biāo)距離的最大值,即
- k 值的選擇會(huì)對(duì) k 近鄰法的結(jié)果產(chǎn)生重大影響。如果選擇較小的 k 值,就相當(dāng)于用較小的鄰域中的訓(xùn)練實(shí)例進(jìn)行預(yù)測(cè),“學(xué)習(xí)”的近似誤差(approximation error)會(huì)減小,只有與輸入實(shí)例較近的(相似的)訓(xùn)練實(shí)例才會(huì)對(duì)預(yù)測(cè)結(jié)果起作用。但缺點(diǎn)是“學(xué)習(xí)”的估計(jì)誤差(estimation error)會(huì)增大,預(yù)測(cè)結(jié)果會(huì)對(duì)近鄰的實(shí)例點(diǎn)非常敏感。如果鄰近的實(shí)例點(diǎn)恰巧是噪聲,預(yù)測(cè)就會(huì)出錯(cuò)。
換句話說(shuō),k 值的減小就意味著整體模型變得復(fù)雜,容易發(fā)生過(guò)擬合。
如果選擇較大的 k 值,就相當(dāng)于用較大鄰域中的訓(xùn)練實(shí)例進(jìn)行預(yù)測(cè)。其優(yōu)點(diǎn)是可以減少學(xué)習(xí)的估計(jì)誤差。但缺點(diǎn)是學(xué)習(xí)的近似誤差會(huì)增大。這時(shí)與輸入實(shí)例較遠(yuǎn)的(不相似的)訓(xùn)練實(shí)例也會(huì)對(duì)預(yù)測(cè)起作用,使預(yù)測(cè)發(fā)生錯(cuò)誤。k 值的增大就意味著整體的模型變得簡(jiǎn)單。
如果 k=N,那么無(wú)論輸入實(shí)例是什么,都將簡(jiǎn)單地預(yù)測(cè)它屬于在訓(xùn)練實(shí)例中最多的類。這時(shí),模型過(guò)于簡(jiǎn)單,完全忽略訓(xùn)練實(shí)例中的大量有用信息,是不可取的。
在應(yīng)用中,k 值一般取一個(gè)比較小的數(shù)值。通常采用交叉驗(yàn)證法來(lái)選取最優(yōu)的 k 值。
- k 近鄰法中的分類決策規(guī)則往往是
多數(shù)表決,即由輸入實(shí)例的 k 個(gè)鄰近的訓(xùn)練實(shí)例中的多數(shù)類決定輸入實(shí)例的類。
k 近鄰法的實(shí)現(xiàn):kd 樹(shù)
- 實(shí)現(xiàn) k 近鄰法時(shí),主要考慮的問(wèn)題是如何對(duì)訓(xùn)練數(shù)據(jù)進(jìn)行快速 k 近鄰搜索。k 近鄰法最簡(jiǎn)單的實(shí)現(xiàn)方法是線性掃描(linear scan)。這時(shí)要計(jì)算輸入實(shí)例與每一個(gè)訓(xùn)練實(shí)例的距離。當(dāng)訓(xùn)練集很大時(shí),計(jì)算非常耗時(shí),這種方法是不可行的。
- kd 樹(shù)是一種對(duì) k 維空間中的實(shí)例點(diǎn)進(jìn)行存儲(chǔ)以便對(duì)其進(jìn)行快速檢索的樹(shù)形數(shù)據(jù)結(jié)構(gòu)。kd 樹(shù)是二叉樹(shù),表示對(duì) k 維空間的一個(gè)劃分(partition)。構(gòu)造 kd 樹(shù)相當(dāng)于不斷地用垂直于坐標(biāo)軸的超平面將 k 維空間切分,構(gòu)成一系列的 k 維超矩形區(qū)域。kd 樹(shù)的每個(gè)結(jié)點(diǎn)對(duì)應(yīng)于一個(gè) k 維超矩形區(qū)域。
- 構(gòu)造 kd 樹(shù)的方法如下:構(gòu)造根結(jié)點(diǎn),使根結(jié)點(diǎn)對(duì)應(yīng)于 k 維空間中包含所有實(shí)例點(diǎn)的超矩形區(qū)域;通過(guò)下面的遞歸方法,不斷地對(duì) k 維空間進(jìn)行切分,生成子結(jié)點(diǎn)。在超矩形區(qū)域(結(jié)點(diǎn))上選擇一個(gè)坐標(biāo)軸和在此坐標(biāo)軸上的一個(gè)切分點(diǎn),確定一個(gè)超平面,這個(gè)超平面通過(guò)選定的切分點(diǎn)并垂直于選定的坐標(biāo)軸,將當(dāng)前超矩形區(qū)域切分為左右兩個(gè)子區(qū)域(子結(jié)點(diǎn));這時(shí),實(shí)例被分到兩個(gè)子區(qū)域。這個(gè)過(guò)程直到子區(qū)域內(nèi)沒(méi)有實(shí)例時(shí)終止(終止時(shí)的結(jié)點(diǎn)為葉結(jié)點(diǎn))。在此過(guò)程中,將實(shí)例保存在相應(yīng)的結(jié)點(diǎn)上。
- 通常,依次選擇坐標(biāo)軸對(duì)空間切分,選擇訓(xùn)練實(shí)例點(diǎn)在選定坐標(biāo)軸上的中位數(shù)(median)為切分點(diǎn),這樣得到的 kd 樹(shù)是平衡的。注意,
平衡的 kd 樹(shù)搜索時(shí)的效率未必是最優(yōu)的。
- 構(gòu)造平衡 kd 樹(shù):
給定 k 維空間數(shù)據(jù)集,其中
,
1>> 開(kāi)始:構(gòu)造根結(jié)點(diǎn),根結(jié)點(diǎn)對(duì)應(yīng)于包含的 k 維空間的超矩形區(qū)域。選擇
為坐標(biāo)軸,以
中所有實(shí)例的
坐標(biāo)的中位數(shù)為切分點(diǎn),將根結(jié)點(diǎn)對(duì)應(yīng)的超矩形區(qū)域切分為兩個(gè)子區(qū)域。切分由通過(guò)切分點(diǎn)并與坐標(biāo)軸
垂直的超平面實(shí)現(xiàn)。
由根結(jié)點(diǎn)生成深度為的左、右子結(jié)點(diǎn):左子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo)
小于切分點(diǎn)的子區(qū)域,右子結(jié)點(diǎn)對(duì)應(yīng)于坐標(biāo)
大于切分點(diǎn)的子區(qū)域。
將落在切分超平面上的實(shí)例點(diǎn)保存在根結(jié)點(diǎn)。
2>> 重復(fù):對(duì)深度為的結(jié)點(diǎn),選擇
為切分的坐標(biāo)軸,
,以該結(jié)點(diǎn)的區(qū)域中所有實(shí)例的
坐標(biāo)的中位數(shù)為切分點(diǎn),將該結(jié)點(diǎn)對(duì)應(yīng)的超矩形區(qū)域切分為兩個(gè)子區(qū)域。切分由通過(guò)切分點(diǎn)并與坐標(biāo)軸
垂直的超平面實(shí)現(xiàn)。
由該結(jié)點(diǎn)生成深度為的左、右子結(jié)點(diǎn):左子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo)
小于切分點(diǎn)的子區(qū)域,右子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo)
大于切分點(diǎn)的子區(qū)域。
將落在切分超平面上的實(shí)例點(diǎn)保存在該結(jié)點(diǎn)。
3>> 直到兩個(gè)子區(qū)域沒(méi)有實(shí)例存在時(shí)停止。從而形成kd樹(shù)的區(qū)域劃分。
-
kd 樹(shù)示例:
給定一個(gè)二維空間的數(shù)據(jù)集:。
根結(jié)點(diǎn)對(duì)應(yīng)包含數(shù)據(jù)集的矩形,選擇
軸,6 個(gè)數(shù)據(jù)點(diǎn)的
坐標(biāo)的中位數(shù)是 7,以平面
將空間分為左、右兩個(gè)子矩形(子結(jié)點(diǎn));接著,左矩形以
分為兩個(gè)子矩形,右矩形以
分為兩個(gè)子矩形,如此遞歸,得到如下特征空間劃分圖及 kd 樹(shù)圖。
搜索 kd 樹(shù)
- 在 kd 樹(shù)中找出包含目標(biāo)點(diǎn)
的葉結(jié)點(diǎn):從根結(jié)點(diǎn)出發(fā),遞歸地向下訪問(wèn) kd 樹(shù)。若目標(biāo)點(diǎn)
當(dāng)前維的坐標(biāo)小于切分點(diǎn)的坐標(biāo),則移動(dòng)到左子結(jié)點(diǎn),否則移動(dòng)到右子結(jié)點(diǎn)。
直到子結(jié)點(diǎn)為葉結(jié)點(diǎn)為止。
- 以此葉結(jié)點(diǎn)為“當(dāng)前最近點(diǎn)”。
- 遞歸地向上回退,在每個(gè)結(jié)點(diǎn)進(jìn)行以下操作:
1>>如果該結(jié)點(diǎn)保存的實(shí)例點(diǎn)比當(dāng)前最近點(diǎn)距離目標(biāo)點(diǎn)更近,則以該實(shí)例點(diǎn)為“當(dāng)前最近點(diǎn)”。
2>>當(dāng)前最近點(diǎn)一定存在于該結(jié)點(diǎn)一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域。檢查該子結(jié)點(diǎn)的父結(jié)點(diǎn)的另一子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否有更近的點(diǎn)。具體地,檢查另一子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否與以目標(biāo)點(diǎn)為球心、以目標(biāo)點(diǎn)與“當(dāng)前最近點(diǎn)”間的距離為半徑的超球體相交。如果相交,可能在另一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域內(nèi)存在距目標(biāo)點(diǎn)更近的點(diǎn),移動(dòng)到另一個(gè)子結(jié)點(diǎn)。接著,遞歸地進(jìn)行最近鄰搜索;如果不相交,向上回退。
- 當(dāng)回退到根結(jié)點(diǎn)時(shí),搜索結(jié)束。最后的“當(dāng)前最近點(diǎn)”即為
的最近鄰點(diǎn)。
- kd 樹(shù)更適用于訓(xùn)練實(shí)例數(shù)遠(yuǎn)大于空間維數(shù)時(shí)的 k 近鄰搜索。當(dāng)空間維數(shù)接近訓(xùn)練實(shí)例數(shù)時(shí),它的效率會(huì)迅速下降,幾乎接近線性掃描。
-
kd 樹(shù)搜索示例:
給定一個(gè)如下圖所示的 kd 樹(shù),根結(jié)點(diǎn)為 A,其子結(jié)點(diǎn)為 B,C 等。樹(shù)上共存儲(chǔ) 7 個(gè)實(shí)例點(diǎn);另有一個(gè)輸入目標(biāo)實(shí)例點(diǎn) S,求 S 的最近鄰。
解:首先在 kd 樹(shù)中找到包含點(diǎn) S 的葉結(jié)點(diǎn) D(圖中的右下區(qū)域),以點(diǎn) D 作為近似最近鄰。真正最近鄰一定在以點(diǎn) S 為中心通過(guò)點(diǎn) D 的圓的內(nèi)部。然后返回結(jié)點(diǎn) D 的父結(jié)點(diǎn) B,在結(jié)點(diǎn) B 的另一子結(jié)點(diǎn) F 的區(qū)域內(nèi)搜索最近鄰。結(jié)點(diǎn) F 的區(qū)域與圓不相交,不可能有最近鄰點(diǎn)。繼續(xù)返回上一級(jí)父結(jié)點(diǎn) A,在結(jié)點(diǎn) A 的另一子結(jié)點(diǎn) C 的區(qū)域內(nèi)搜索最近鄰。結(jié)點(diǎn) C 的區(qū)域與圓相交;該區(qū)域在圓內(nèi)的實(shí)例點(diǎn)有點(diǎn) E,點(diǎn) E 比點(diǎn) D 更近,成為新的最近鄰近似。最后得到點(diǎn) E 是點(diǎn) S 的最近鄰。
k 近鄰模型實(shí)現(xiàn)
k 近鄰模型實(shí)現(xiàn)
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter
class Knn(object):
def __init__(self, x_train, y_train, k=3, p=2):
self.k = k
self.p = p
self.x_train = x_train
self.y_train = y_train
# 計(jì)算 Lp 距離
def lp(self, x, y, p=2):
if len(x) == len(y) and len(x) > 1:
_sum = sum([math.pow(abs(x[i] - y[i]), p) for i in range(len(x))])
return math.pow(_sum, 1/p)
else:
return 0
# 預(yù)測(cè)
def predict(self, x):
# 計(jì)算 x 到其他節(jié)點(diǎn)的距離
dists = [(np.linalg.norm(x - self.x_train[i], ord=self.p), self.y_train[i])
for i in range(len(x_train))]
# 排序篩選出最近的 k 個(gè)節(jié)點(diǎn)
nodes = sorted(dists, key=lambda x: x[0])[:self.k]
# 對(duì)最近的 k 個(gè)節(jié)點(diǎn)的所屬類別進(jìn)行計(jì)數(shù)
counter = Counter([node[-1] for node in nodes])
# 返回計(jì)數(shù)最多的一個(gè)類別
return counter.most_common(1)[0][0]
# 計(jì)算模型的準(zhǔn)確率
def score(self, x_test, y_test):
rights = 0
for x, y in zip(x_test, y_test):
label = self.predict(x)
if label == y:
rights += 1
return rights / len(x_test)
if __name__ == '__main__':
# 獲取鳶尾花數(shù)據(jù)集
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
# 生成訓(xùn)練樣本,只取 sepal length,sepal width 作為樣本特征
train = np.array(df.iloc[:100, [0, 1, -1]])
x, y = train[:, :-1], train[:, -1]
# 劃分訓(xùn)練集與測(cè)試集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
# 生成 knn 模型,并使用測(cè)試集進(jìn)行驗(yàn)證
knn = Knn(x_train, y_train)
print('測(cè)試集驗(yàn)證模型準(zhǔn)確率為:%s' % knn.score(x_test, y_test))
# 對(duì)點(diǎn) [6.0, 3.0] 進(jìn)行分類
print('點(diǎn) [6.0, 3.0] 分類結(jié)果為: %s' % knn.predict([6.0, 3.0]))
# 繪圖
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.plot(6.0, 3.0, 'bo', label='[6.0, 3.0]')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
運(yùn)行結(jié)果
也可以使用 sklearn 實(shí)現(xiàn) knn 模型
from sklearn.neighbors import KNeighborsClassifier
# KNeighborsClassifier 參數(shù)
# n_neighbors:臨近點(diǎn)個(gè)數(shù)
# p:距離度量
# algorithm:近鄰算法,可選{'auto', 'ball_tree', 'kd_tree', 'brute'}
# weights:近鄰的權(quán)重
knn_sk = KNeighborsClassifier()
knn_sk.fit(x_train, y_train)
knn_sk.score(x_test, y_test)
kd 樹(shù)搜索單個(gè)近鄰點(diǎn)
from math import sqrt
from collections import namedtuple
class KdNode(object):
# 節(jié)點(diǎn)數(shù)據(jù)結(jié)構(gòu)
def __init__(self, node, split, left, right):
self.node = node # k 維向量節(jié)點(diǎn)
self.split = split # 分割維度序號(hào)
self.left = left # 左樹(shù)
self.right = right # 右樹(shù)
class KdTree(object):
def __init__(self, data):
self.k = len(data[0]) # 數(shù)據(jù)維度
self.root = self.create_kdnode(0, data) # 生成 kd 樹(shù)
def create_kdnode(self, split, data):
if not data: return None
data.sort(key=lambda x: x[split])
split_index = len(data) // 2
median = data[split_index] # 中位數(shù)分割點(diǎn)
split_next = (split + 1) % self.k
return KdNode(
median,
split,
self.create_kdnode(split_next, data[:split_index]),
self.create_kdnode(split_next, data[split_index + 1:])
)
# 對(duì)于構(gòu)建好的 kd 樹(shù) tree,尋找離 node 最近的節(jié)點(diǎn)
def nearest(self, node):
k = len(node)
nearest_tuple = namedtuple("nearest_tuple", "nearest_node nearest_dist visited")
def travel(kd_node, target, nearest_dist):
# python中用float("inf")和float("-inf")表示正負(fù)無(wú)窮
if kd_node is None: return nearest_tuple([0] * k, float("inf"), 0)
visited = 1
# 分割維度
split = kd_node.split
# 分割節(jié)點(diǎn)
node = kd_node.node
# 如果目標(biāo)點(diǎn)第 split 維小于分割節(jié)點(diǎn)的對(duì)應(yīng)值,則目標(biāo)離左子樹(shù)更近,否則離右子樹(shù)更近
if target[split] <= node[split]:
nearer_node = kd_node.left
further_node = kd_node.right
else:
nearer_node = kd_node.right
further_node = kd_node.left
# 進(jìn)行遍歷找到包含目標(biāo)點(diǎn)的區(qū)域
nearest_tuple_1 = travel(nearer_node, target, nearest_dist)
nearest = nearest_tuple_1.nearest_node
dist = nearest_tuple_1.nearest_dist
visited += nearest_tuple_1.visited
# 更新最近距離
if dist < nearest_dist:
nearest_dist = dist
# 第 split 維上目標(biāo)點(diǎn)與分割超平面的距離
temp_dist = abs(node[split] - target[split])
# 判斷超球體是否與超平面相交
if temp_dist > nearest_dist:
# 相交則可以直接返回,不用繼續(xù)判斷
return nearest_tuple(nearest, dist, visited)
# 計(jì)算目標(biāo)點(diǎn)與分割點(diǎn)的歐氏距離
temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(node, target)))
# 如果更近,更新最近距離
if temp_dist < nearest_dist:
nearest = node
nearest_dist = temp_dist
# 檢查另一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否有更近的點(diǎn)
nearest_tuple_2 = travel(further_node, target, nearest_dist)
visited += nearest_tuple_2.visited
if nearest_tuple_2.nearest_dist < nearest_dist:
nearest = nearest_tuple_2.nearest_node
nearest_dist = nearest_tuple_2.nearest_dist
return nearest_tuple(nearest, nearest_dist, visited)
# 從根節(jié)點(diǎn)開(kāi)始遞歸
return travel(self.root, node, float("inf"))
if __name__ == '__main__':
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
kd = KdTree(data)
print(kd.nearest([3, 5]))
運(yùn)行結(jié)果
nearest_tuple(nearest_node=[4, 7], nearest_dist=2.23606797749979, visited=4)
kd 樹(shù)搜索 k 個(gè)近鄰點(diǎn)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import sqrt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter
class KdNode(object):
def __init__(self, data, depth=0, lchild=None, rchild=None):
self.data = data
self.depth = depth
self.lchild = lchild
self.rchild = rchild
class KdTree(object):
def __init__(self):
self.n = 0
self.tree = None
self.nearest = None
def create(self, data_set, depth=0):
if len(data_set) > 0:
m, n = np.shape(data_set)
self.n, axis, mid = n - 1, depth % (n - 1), int(m / 2)
data_set_sorted = sorted(data_set, key=lambda x: x[axis])
node = KdNode(data_set_sorted[mid], depth)
if depth == 0: self.tree = node
node.lchild = self.create(data_set_sorted[:mid], depth+1)
node.rchild = self.create(data_set_sorted[mid+1:], depth+1)
return node
return None
# 搜索 kdtree 的前 k 個(gè)最近點(diǎn)
def search(self, x, k=1):
nearest = []
for i in range(k):
nearest.append([-1, None])
# 初始化 n 個(gè)點(diǎn),nearest 是按照距離遞減的方式
self.nearest = np.array(nearest)
def travel(node):
if node is not None:
# 當(dāng)前點(diǎn)的維度 axis
axis = node.depth % self.n
# x 點(diǎn)和當(dāng)前點(diǎn)在 axis 維度上的差
daxis = x[axis] - node.data[axis]
# 如果小于進(jìn)左子樹(shù),大于進(jìn)右子樹(shù)
if daxis < 0:
travel(node.lchild)
else:
travel(node.rchild)
# 計(jì)算 x 點(diǎn)到當(dāng)前點(diǎn)的距離 dist
dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
for i, d in enumerate(self.nearest):
# 如果有比現(xiàn)在最近的 n 個(gè)點(diǎn)更近的點(diǎn),更新最近的點(diǎn)
if d[0] < 0 or dist < d[0]:
# 插入第 i 個(gè)位置的點(diǎn)
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
# 刪除最后一個(gè)多出來(lái)的點(diǎn)
self.nearest = self.nearest[:-1]
break
# 統(tǒng)計(jì)距離為 -1 的個(gè)數(shù) n
n = list(self.nearest[:, 0]).count(-1)
# self.nearest[-n-1, 0] 是當(dāng)前 nearest 中已經(jīng)有的最近點(diǎn)中,距離最大的點(diǎn)
# self.nearest[-n-1, 0] > abs(daxis) 代表以 x 點(diǎn)為圓心,self.nearest[-n-1, 0]為半徑的圓
# 與 axis 相交,說(shuō)明在左右子樹(shù)里面可能有比 self.nearest[-n-1, 0] 更近的點(diǎn)
if self.nearest[-n-1, 0] > abs(daxis):
if daxis < 0:
travel(node.rchild)
else:
travel(node.lchild)
travel(self.tree)
# nodes 就是最近 k 個(gè)點(diǎn)
nodes = self.nearest[:, 1]
counter = Counter([node.data[-1] for node in nodes])
return self.nearest, counter.most_common(1)[0][0]
if __name__ == '__main__':
# 獲取鳶尾花數(shù)據(jù)集
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
# 生成訓(xùn)練樣本,只取 sepal length,sepal width 作為樣本特征
data = np.array(df.iloc[:100, [0, 1, -1]])
# 劃分訓(xùn)練集與測(cè)試集
train, test = train_test_split(data, test_size=0.1)
x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0])
x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1])
# 生成 knn 模型,并使用測(cè)試集進(jìn)行驗(yàn)證
kdt = KdTree()
kdt.create(train)
score = 0
for x in test:
# 繪制訓(xùn)練點(diǎn)
plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]')
plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
# 繪制測(cè)試點(diǎn)
plt.scatter(x[0], x[1], c='red', marker='x')
# 設(shè)置臨近點(diǎn)的個(gè)數(shù)
nearest, belong = kdt.search(x[:-1], 5)
if belong == x[-1]: score += 1
print("test:", x, "predict:", belong)
print("nearest:", nearest)
for near in nearest:
# k 個(gè)近鄰點(diǎn)
plt.scatter(near[1].data[0], near[1].data[1], c='green', marker='+')
plt.legend()
plt.show()
score /= len(test)
print("score:", score)
運(yùn)行結(jié)果(部分截圖)



