k 近鄰法

k 近鄰法

  • k 近鄰算法
  • k 近鄰模型
  • k 近鄰法的實(shí)現(xiàn):kd 樹(shù)
  • 搜索 kd 樹(shù)

k 近鄰模型實(shí)現(xiàn)

  • k 近鄰模型實(shí)現(xiàn)
  • kd 樹(shù)搜索單個(gè)近鄰點(diǎn)
  • kd 樹(shù)搜索 k 個(gè)近鄰點(diǎn)

k 近鄰法(k-nearest neighbor,k-NN)是一種基本分類與回歸方法。k 近鄰法假設(shè)給定一個(gè)訓(xùn)練數(shù)據(jù)集,其中的實(shí)例類別已定。分類時(shí),對(duì)新的實(shí)例,根據(jù)其 k 個(gè)最近鄰的訓(xùn)練實(shí)例的類別,通過(guò)多數(shù)表決等方式進(jìn)行預(yù)測(cè)。因此,k近鄰法不具有顯式的學(xué)習(xí)過(guò)程。k 近鄰法實(shí)際上利用訓(xùn)練數(shù)據(jù)集對(duì)特征向量空間進(jìn)行劃分,并作為其分類的“模型”。k 值的選擇、距離度量及分類決策規(guī)則是 k 近鄰法的三個(gè)基本要素。

k 近鄰算法

  1. 給定數(shù)據(jù)集 T=\{(x_1,y_1), (x_2,y_2),...,(x_N,y_N)\},其中, x_i \in X \subseteq R^n 為實(shí)例的特征向量, y_i \in Y = \{c_1, c_2, ..., c_K\} 為實(shí)例的類別, i=1,2,...,N;實(shí)例的特征向量 x
    1>> 根據(jù)給定的距離度量,在訓(xùn)練集 T 中找出與 x 最鄰近的 k 個(gè)點(diǎn),涵蓋這 k 個(gè)點(diǎn)的 x 的鄰域記作 N_k(x);
    2>> 在 N_k(x) 中根據(jù)分類決策規(guī)則(如多數(shù)表決)決定 x 的類別 y
    y = arg\ max_{c_j} \sum_{x_i \in N_i(x)}I(y_i=c_j),\ \ \ \ \ i=1,2,..,N; j=1,2,...,K
    式中I 為指示函數(shù)。

  2. k 近鄰法的特殊情況是 k=1 的情形,稱為最近鄰算法。對(duì)于輸入的實(shí)例點(diǎn)(特征向量)x,最近鄰法將訓(xùn)練數(shù)據(jù)集中與 x 最鄰近點(diǎn)的類作為 x 的類。

k 近鄰模型

  1. 模型由三個(gè)基本要素——距離度量、k 值的選擇和分類決策規(guī)則決定。

  2. 特征空間中,對(duì)每個(gè)訓(xùn)練實(shí)例點(diǎn) x_i,距離該點(diǎn)比其他點(diǎn)更近的所有點(diǎn)組成一個(gè)區(qū)域,叫作單元(cell)。


  1. 特征空間中兩個(gè)實(shí)例點(diǎn)的距離是兩個(gè)實(shí)例點(diǎn)相似程度的反映。(歐氏距離、L_p距離、Minkowski 距離)

  2. 設(shè)特征空間 Xn 維實(shí)數(shù)向量空間 R^n,x_i,x_j \in X,x_i=(x_i^{(1)}, x_i^{(2)},...,x_i^{{n}})^T,x_j=(x_j^{(1)}, x_j^{(2)},...,x_j^{{n}})^Tx_i,x_jL_p 距離定義為
    L_p(x_i,x_j) = (\sum_{I=1}^n \mid x_i^{(I)} - x_j^{(I)} \mid ^p)^{\frac{1}{p}}
    這里 p \ge 1。當(dāng) p=2 時(shí), 稱為歐氏距離,即
    L_2(x_i,x_j) = (\sum_{I=1}^n \mid x_i^{(I)} - x_j^{(I)} \mid ^2)^{\frac{1}{2}}
    當(dāng) p=1 時(shí), 稱為哈曼頓距離,即
    L_1(x_i,x_j) = \sum_{I=1}^n \mid x_i^{(I)} - x_j^{(I)} \mid
    當(dāng) p=\infty,它是各個(gè)坐標(biāo)距離的最大值,即
    L_\infty(x_i,x_j) = max \mid x_i^{(I)} - x_j^{(I)} \mid

  3. k 值的選擇會(huì)對(duì) k 近鄰法的結(jié)果產(chǎn)生重大影響。如果選擇較小的 k 值,就相當(dāng)于用較小的鄰域中的訓(xùn)練實(shí)例進(jìn)行預(yù)測(cè),“學(xué)習(xí)”的近似誤差(approximation error)會(huì)減小,只有與輸入實(shí)例較近的(相似的)訓(xùn)練實(shí)例才會(huì)對(duì)預(yù)測(cè)結(jié)果起作用。但缺點(diǎn)是“學(xué)習(xí)”的估計(jì)誤差(estimation error)會(huì)增大,預(yù)測(cè)結(jié)果會(huì)對(duì)近鄰的實(shí)例點(diǎn)非常敏感。如果鄰近的實(shí)例點(diǎn)恰巧是噪聲,預(yù)測(cè)就會(huì)出錯(cuò)。換句話說(shuō),k 值的減小就意味著整體模型變得復(fù)雜,容易發(fā)生過(guò)擬合。
    如果選擇較大的 k 值,就相當(dāng)于用較大鄰域中的訓(xùn)練實(shí)例進(jìn)行預(yù)測(cè)。其優(yōu)點(diǎn)是可以減少學(xué)習(xí)的估計(jì)誤差。但缺點(diǎn)是學(xué)習(xí)的近似誤差會(huì)增大。這時(shí)與輸入實(shí)例較遠(yuǎn)的(不相似的)訓(xùn)練實(shí)例也會(huì)對(duì)預(yù)測(cè)起作用,使預(yù)測(cè)發(fā)生錯(cuò)誤。k 值的增大就意味著整體的模型變得簡(jiǎn)單。
    如果 k=N,那么無(wú)論輸入實(shí)例是什么,都將簡(jiǎn)單地預(yù)測(cè)它屬于在訓(xùn)練實(shí)例中最多的類。這時(shí),模型過(guò)于簡(jiǎn)單,完全忽略訓(xùn)練實(shí)例中的大量有用信息,是不可取的。
    在應(yīng)用中,k 值一般取一個(gè)比較小的數(shù)值。通常采用交叉驗(yàn)證法來(lái)選取最優(yōu)的 k 值。

  4. k 近鄰法中的分類決策規(guī)則往往是多數(shù)表決,即由輸入實(shí)例的 k 個(gè)鄰近的訓(xùn)練實(shí)例中的多數(shù)類決定輸入實(shí)例的類。

k 近鄰法的實(shí)現(xiàn):kd 樹(shù)

  1. 實(shí)現(xiàn) k 近鄰法時(shí),主要考慮的問(wèn)題是如何對(duì)訓(xùn)練數(shù)據(jù)進(jìn)行快速 k 近鄰搜索。k 近鄰法最簡(jiǎn)單的實(shí)現(xiàn)方法是線性掃描(linear scan)。這時(shí)要計(jì)算輸入實(shí)例與每一個(gè)訓(xùn)練實(shí)例的距離。當(dāng)訓(xùn)練集很大時(shí),計(jì)算非常耗時(shí),這種方法是不可行的。

  2. kd 樹(shù)是一種對(duì) k 維空間中的實(shí)例點(diǎn)進(jìn)行存儲(chǔ)以便對(duì)其進(jìn)行快速檢索的樹(shù)形數(shù)據(jù)結(jié)構(gòu)。kd 樹(shù)是二叉樹(shù),表示對(duì) k 維空間的一個(gè)劃分(partition)。構(gòu)造 kd 樹(shù)相當(dāng)于不斷地用垂直于坐標(biāo)軸的超平面將 k 維空間切分,構(gòu)成一系列的 k 維超矩形區(qū)域。kd 樹(shù)的每個(gè)結(jié)點(diǎn)對(duì)應(yīng)于一個(gè) k 維超矩形區(qū)域。

  3. 構(gòu)造 kd 樹(shù)的方法如下:構(gòu)造根結(jié)點(diǎn),使根結(jié)點(diǎn)對(duì)應(yīng)于 k 維空間中包含所有實(shí)例點(diǎn)的超矩形區(qū)域;通過(guò)下面的遞歸方法,不斷地對(duì) k 維空間進(jìn)行切分,生成子結(jié)點(diǎn)。在超矩形區(qū)域(結(jié)點(diǎn))上選擇一個(gè)坐標(biāo)軸和在此坐標(biāo)軸上的一個(gè)切分點(diǎn),確定一個(gè)超平面,這個(gè)超平面通過(guò)選定的切分點(diǎn)并垂直于選定的坐標(biāo)軸,將當(dāng)前超矩形區(qū)域切分為左右兩個(gè)子區(qū)域(子結(jié)點(diǎn));這時(shí),實(shí)例被分到兩個(gè)子區(qū)域。這個(gè)過(guò)程直到子區(qū)域內(nèi)沒(méi)有實(shí)例時(shí)終止(終止時(shí)的結(jié)點(diǎn)為葉結(jié)點(diǎn))。在此過(guò)程中,將實(shí)例保存在相應(yīng)的結(jié)點(diǎn)上。

  4. 通常,依次選擇坐標(biāo)軸對(duì)空間切分,選擇訓(xùn)練實(shí)例點(diǎn)在選定坐標(biāo)軸上的中位數(shù)(median)為切分點(diǎn),這樣得到的 kd 樹(shù)是平衡的。注意,平衡的 kd 樹(shù)搜索時(shí)的效率未必是最優(yōu)的。

  5. 構(gòu)造平衡 kd 樹(shù):
    給定 k 維空間數(shù)據(jù)集 T=\{x_1,x_2,...,x_N\},其中 x_i=(x_i^{(1)}, x_i^{(2)},...,x_i^{(k)})^T,i=1,2,...,N
    1>> 開(kāi)始:構(gòu)造根結(jié)點(diǎn),根結(jié)點(diǎn)對(duì)應(yīng)于包含 T 的 k 維空間的超矩形區(qū)域。選擇 x^{(1)} 為坐標(biāo)軸,以 T 中所有實(shí)例的 x^{(1)} 坐標(biāo)的中位數(shù)為切分點(diǎn),將根結(jié)點(diǎn)對(duì)應(yīng)的超矩形區(qū)域切分為兩個(gè)子區(qū)域。切分由通過(guò)切分點(diǎn)并與坐標(biāo)軸 x^{(1)} 垂直的超平面實(shí)現(xiàn)。
    由根結(jié)點(diǎn)生成深度為 1 的左、右子結(jié)點(diǎn):左子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo) x^{(1)} 小于切分點(diǎn)的子區(qū)域,右子結(jié)點(diǎn)對(duì)應(yīng)于坐標(biāo) x^{(1)} 大于切分點(diǎn)的子區(qū)域。
    將落在切分超平面上的實(shí)例點(diǎn)保存在根結(jié)點(diǎn)。
    2>> 重復(fù):對(duì)深度為 j 的結(jié)點(diǎn),選擇 x^{(p)} 為切分的坐標(biāo)軸,p=j(luò)(mod\ k)+1,以該結(jié)點(diǎn)的區(qū)域中所有實(shí)例的 x^{(p)} 坐標(biāo)的中位數(shù)為切分點(diǎn),將該結(jié)點(diǎn)對(duì)應(yīng)的超矩形區(qū)域切分為兩個(gè)子區(qū)域。切分由通過(guò)切分點(diǎn)并與坐標(biāo)軸 x^{(p)} 垂直的超平面實(shí)現(xiàn)。
    由該結(jié)點(diǎn)生成深度為 j+1 的左、右子結(jié)點(diǎn):左子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo) x^{(p)} 小于切分點(diǎn)的子區(qū)域,右子結(jié)點(diǎn)對(duì)應(yīng)坐標(biāo) x^{(p)} 大于切分點(diǎn)的子區(qū)域。
    將落在切分超平面上的實(shí)例點(diǎn)保存在該結(jié)點(diǎn)。
    3>> 直到兩個(gè)子區(qū)域沒(méi)有實(shí)例存在時(shí)停止。從而形成kd樹(shù)的區(qū)域劃分。

  6. kd 樹(shù)示例:
    給定一個(gè)二維空間的數(shù)據(jù)集:T=\{(2,3)^T,(5,4)^T,(9,6)^T,(4,7)^T,(8,1)^T,(7,2)^T\}。
    根結(jié)點(diǎn)對(duì)應(yīng)包含數(shù)據(jù)集 T 的矩形,選擇 x^{(1)} 軸,6 個(gè)數(shù)據(jù)點(diǎn)的 x^{(1)} 坐標(biāo)的中位數(shù)是 7,以平面 x^{(1)}=7 將空間分為左、右兩個(gè)子矩形(子結(jié)點(diǎn));接著,左矩形以 x^{(2)}=4 分為兩個(gè)子矩形,右矩形以 x^{(2)}=6 分為兩個(gè)子矩形,如此遞歸,得到如下特征空間劃分圖及 kd 樹(shù)圖。

搜索 kd 樹(shù)

  1. 在 kd 樹(shù)中找出包含目標(biāo)點(diǎn) x 的葉結(jié)點(diǎn):從根結(jié)點(diǎn)出發(fā),遞歸地向下訪問(wèn) kd 樹(shù)。若目標(biāo)點(diǎn) x 當(dāng)前維的坐標(biāo)小于切分點(diǎn)的坐標(biāo),則移動(dòng)到左子結(jié)點(diǎn),否則移動(dòng)到右子結(jié)點(diǎn)。直到子結(jié)點(diǎn)為葉結(jié)點(diǎn)為止。

  2. 以此葉結(jié)點(diǎn)為“當(dāng)前最近點(diǎn)”。

  3. 遞歸地向上回退,在每個(gè)結(jié)點(diǎn)進(jìn)行以下操作:
    1>> 如果該結(jié)點(diǎn)保存的實(shí)例點(diǎn)比當(dāng)前最近點(diǎn)距離目標(biāo)點(diǎn)更近,則以該實(shí)例點(diǎn)為“當(dāng)前最近點(diǎn)”。
    2>> 當(dāng)前最近點(diǎn)一定存在于該結(jié)點(diǎn)一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域。檢查該子結(jié)點(diǎn)的父結(jié)點(diǎn)的另一子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否有更近的點(diǎn)。具體地,檢查另一子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否與以目標(biāo)點(diǎn)為球心、以目標(biāo)點(diǎn)與“當(dāng)前最近點(diǎn)”間的距離為半徑的超球體相交。如果相交,可能在另一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域內(nèi)存在距目標(biāo)點(diǎn)更近的點(diǎn),移動(dòng)到另一個(gè)子結(jié)點(diǎn)。接著,遞歸地進(jìn)行最近鄰搜索;如果不相交,向上回退。

  4. 當(dāng)回退到根結(jié)點(diǎn)時(shí),搜索結(jié)束。最后的“當(dāng)前最近點(diǎn)”即為 x 的最近鄰點(diǎn)。

  5. kd 樹(shù)更適用于訓(xùn)練實(shí)例數(shù)遠(yuǎn)大于空間維數(shù)時(shí)的 k 近鄰搜索。當(dāng)空間維數(shù)接近訓(xùn)練實(shí)例數(shù)時(shí),它的效率會(huì)迅速下降,幾乎接近線性掃描。

  6. kd 樹(shù)搜索示例:
    給定一個(gè)如下圖所示的 kd 樹(shù),根結(jié)點(diǎn)為 A,其子結(jié)點(diǎn)為 B,C 等。樹(shù)上共存儲(chǔ) 7 個(gè)實(shí)例點(diǎn);另有一個(gè)輸入目標(biāo)實(shí)例點(diǎn) S,求 S 的最近鄰。
    解:首先在 kd 樹(shù)中找到包含點(diǎn) S 的葉結(jié)點(diǎn) D(圖中的右下區(qū)域),以點(diǎn) D 作為近似最近鄰。真正最近鄰一定在以點(diǎn) S 為中心通過(guò)點(diǎn) D 的圓的內(nèi)部。然后返回結(jié)點(diǎn) D 的父結(jié)點(diǎn) B,在結(jié)點(diǎn) B 的另一子結(jié)點(diǎn) F 的區(qū)域內(nèi)搜索最近鄰。結(jié)點(diǎn) F 的區(qū)域與圓不相交,不可能有最近鄰點(diǎn)。繼續(xù)返回上一級(jí)父結(jié)點(diǎn) A,在結(jié)點(diǎn) A 的另一子結(jié)點(diǎn) C 的區(qū)域內(nèi)搜索最近鄰。結(jié)點(diǎn) C 的區(qū)域與圓相交;該區(qū)域在圓內(nèi)的實(shí)例點(diǎn)有點(diǎn) E,點(diǎn) E 比點(diǎn) D 更近,成為新的最近鄰近似。最后得到點(diǎn) E 是點(diǎn) S 的最近鄰。

k 近鄰模型實(shí)現(xiàn)

k 近鄰模型實(shí)現(xiàn)

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter


class Knn(object):
    def __init__(self, x_train, y_train, k=3, p=2):
        self.k = k
        self.p = p
        self.x_train = x_train
        self.y_train = y_train
    
    # 計(jì)算 Lp 距離
    def lp(self, x, y, p=2):
        if len(x) == len(y) and len(x) > 1:
            _sum = sum([math.pow(abs(x[i] - y[i]), p) for i in range(len(x))])
            return math.pow(_sum, 1/p)
        else:
            return 0
    
    # 預(yù)測(cè)
    def predict(self, x):
        # 計(jì)算 x 到其他節(jié)點(diǎn)的距離
        dists = [(np.linalg.norm(x - self.x_train[i], ord=self.p), self.y_train[i])
                 for i in range(len(x_train))]
        # 排序篩選出最近的 k 個(gè)節(jié)點(diǎn)
        nodes = sorted(dists, key=lambda x: x[0])[:self.k]
        # 對(duì)最近的 k 個(gè)節(jié)點(diǎn)的所屬類別進(jìn)行計(jì)數(shù)
        counter = Counter([node[-1] for node in nodes])
        # 返回計(jì)數(shù)最多的一個(gè)類別
        return counter.most_common(1)[0][0]
    
    # 計(jì)算模型的準(zhǔn)確率
    def score(self, x_test, y_test):
        rights = 0
        for x, y in zip(x_test, y_test):
            label = self.predict(x)
            if label == y:
                rights += 1
        return rights / len(x_test)


if __name__ == '__main__':
    # 獲取鳶尾花數(shù)據(jù)集
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['label'] = iris.target
    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
    # 生成訓(xùn)練樣本,只取 sepal length,sepal width 作為樣本特征
    train = np.array(df.iloc[:100, [0, 1, -1]])
    x, y = train[:, :-1], train[:, -1]
    # 劃分訓(xùn)練集與測(cè)試集
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    # 生成 knn 模型,并使用測(cè)試集進(jìn)行驗(yàn)證
    knn = Knn(x_train, y_train)
    print('測(cè)試集驗(yàn)證模型準(zhǔn)確率為:%s' % knn.score(x_test, y_test))
    # 對(duì)點(diǎn) [6.0, 3.0] 進(jìn)行分類
    print('點(diǎn) [6.0, 3.0] 分類結(jié)果為: %s' % knn.predict([6.0, 3.0]))
    # 繪圖
    plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
    plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
    plt.plot(6.0, 3.0, 'bo', label='[6.0, 3.0]')
    plt.xlabel('sepal length')
    plt.ylabel('sepal width')
    plt.legend()

運(yùn)行結(jié)果

也可以使用 sklearn 實(shí)現(xiàn) knn 模型

from sklearn.neighbors import KNeighborsClassifier
# KNeighborsClassifier 參數(shù)
# n_neighbors:臨近點(diǎn)個(gè)數(shù)
# p:距離度量
# algorithm:近鄰算法,可選{'auto', 'ball_tree', 'kd_tree', 'brute'}
# weights:近鄰的權(quán)重
knn_sk = KNeighborsClassifier()
knn_sk.fit(x_train, y_train)
knn_sk.score(x_test, y_test)

kd 樹(shù)搜索單個(gè)近鄰點(diǎn)

from math import sqrt
from collections import namedtuple


class KdNode(object):
    # 節(jié)點(diǎn)數(shù)據(jù)結(jié)構(gòu)
    def __init__(self, node, split, left, right):
        self.node = node # k 維向量節(jié)點(diǎn)
        self.split = split # 分割維度序號(hào)
        self.left = left # 左樹(shù)
        self.right = right # 右樹(shù)


class KdTree(object):
    def __init__(self, data):
        self.k = len(data[0]) # 數(shù)據(jù)維度
        self.root = self.create_kdnode(0, data) # 生成 kd 樹(shù)
        
    def create_kdnode(self, split, data):
        if not data: return None
        data.sort(key=lambda x: x[split])
        split_index = len(data) // 2
        median = data[split_index] # 中位數(shù)分割點(diǎn) 
        split_next = (split + 1) % self.k
        return KdNode(
            median,
            split,
            self.create_kdnode(split_next, data[:split_index]),
            self.create_kdnode(split_next, data[split_index + 1:])
        )
    
    # 對(duì)于構(gòu)建好的 kd 樹(shù) tree,尋找離 node 最近的節(jié)點(diǎn)
    def nearest(self, node):
        k = len(node)
        nearest_tuple = namedtuple("nearest_tuple", "nearest_node nearest_dist visited")
        
        def travel(kd_node, target, nearest_dist):
            # python中用float("inf")和float("-inf")表示正負(fù)無(wú)窮
            if kd_node is None: return nearest_tuple([0] * k, float("inf"), 0)
            
            visited = 1
            # 分割維度
            split = kd_node.split
            # 分割節(jié)點(diǎn)
            node = kd_node.node

            # 如果目標(biāo)點(diǎn)第 split 維小于分割節(jié)點(diǎn)的對(duì)應(yīng)值,則目標(biāo)離左子樹(shù)更近,否則離右子樹(shù)更近
            if target[split] <= node[split]:
                nearer_node = kd_node.left
                further_node = kd_node.right
            else:
                nearer_node = kd_node.right
                further_node = kd_node.left

            # 進(jìn)行遍歷找到包含目標(biāo)點(diǎn)的區(qū)域
            nearest_tuple_1 = travel(nearer_node, target, nearest_dist)
            nearest = nearest_tuple_1.nearest_node
            dist = nearest_tuple_1.nearest_dist
            visited += nearest_tuple_1.visited
            # 更新最近距離
            if dist < nearest_dist:
                nearest_dist = dist
            # 第 split 維上目標(biāo)點(diǎn)與分割超平面的距離
            temp_dist = abs(node[split] - target[split])
            # 判斷超球體是否與超平面相交
            if  temp_dist > nearest_dist:
                # 相交則可以直接返回,不用繼續(xù)判斷
                return nearest_tuple(nearest, dist, visited)
            
            # 計(jì)算目標(biāo)點(diǎn)與分割點(diǎn)的歐氏距離
            temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(node, target)))
            # 如果更近,更新最近距離
            if temp_dist < nearest_dist: 
                nearest = node
                nearest_dist = temp_dist

            # 檢查另一個(gè)子結(jié)點(diǎn)對(duì)應(yīng)的區(qū)域是否有更近的點(diǎn)
            nearest_tuple_2 = travel(further_node, target, nearest_dist) 
            visited += nearest_tuple_2.visited
            if nearest_tuple_2.nearest_dist < nearest_dist:  
                nearest = nearest_tuple_2.nearest_node
                nearest_dist = nearest_tuple_2.nearest_dist
            return nearest_tuple(nearest, nearest_dist, visited)
        # 從根節(jié)點(diǎn)開(kāi)始遞歸
        return travel(self.root, node, float("inf"))


if __name__ == '__main__':
    data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    kd = KdTree(data)
    print(kd.nearest([3, 5]))

運(yùn)行結(jié)果

nearest_tuple(nearest_node=[4, 7], nearest_dist=2.23606797749979, visited=4)

kd 樹(shù)搜索 k 個(gè)近鄰點(diǎn)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import sqrt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter


class KdNode(object):
    def __init__(self, data, depth=0, lchild=None, rchild=None):
        self.data = data
        self.depth = depth
        self.lchild = lchild
        self.rchild = rchild


class KdTree(object):
    def __init__(self):
        self.n = 0
        self.tree = None
        self.nearest = None

    def create(self, data_set, depth=0):
        if len(data_set) > 0:
            m, n = np.shape(data_set)
            self.n, axis, mid = n - 1, depth % (n - 1), int(m / 2)
            data_set_sorted = sorted(data_set, key=lambda x: x[axis])
            node = KdNode(data_set_sorted[mid], depth)
            if depth == 0: self.tree = node
            node.lchild = self.create(data_set_sorted[:mid], depth+1)
            node.rchild = self.create(data_set_sorted[mid+1:], depth+1)
            return node
        return None

    # 搜索 kdtree 的前 k 個(gè)最近點(diǎn)
    def search(self, x, k=1):
        nearest = []
        for i in range(k):
            nearest.append([-1, None])
        # 初始化 n 個(gè)點(diǎn),nearest 是按照距離遞減的方式
        self.nearest = np.array(nearest)

        def travel(node):
            if node is not None:
                # 當(dāng)前點(diǎn)的維度 axis
                axis = node.depth % self.n
                # x 點(diǎn)和當(dāng)前點(diǎn)在 axis 維度上的差
                daxis = x[axis] - node.data[axis]
                # 如果小于進(jìn)左子樹(shù),大于進(jìn)右子樹(shù)
                if daxis < 0:
                    travel(node.lchild)
                else:
                    travel(node.rchild)

                # 計(jì)算 x 點(diǎn)到當(dāng)前點(diǎn)的距離 dist
                dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
                for i, d in enumerate(self.nearest):
                    # 如果有比現(xiàn)在最近的 n 個(gè)點(diǎn)更近的點(diǎn),更新最近的點(diǎn)
                    if d[0] < 0 or dist < d[0]:
                        # 插入第 i 個(gè)位置的點(diǎn)
                        self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
                        # 刪除最后一個(gè)多出來(lái)的點(diǎn)
                        self.nearest = self.nearest[:-1]
                        break
                
                # 統(tǒng)計(jì)距離為 -1 的個(gè)數(shù) n
                n = list(self.nearest[:, 0]).count(-1)
                # self.nearest[-n-1, 0] 是當(dāng)前 nearest 中已經(jīng)有的最近點(diǎn)中,距離最大的點(diǎn)
                # self.nearest[-n-1, 0] > abs(daxis) 代表以 x 點(diǎn)為圓心,self.nearest[-n-1, 0]為半徑的圓
                # 與 axis 相交,說(shuō)明在左右子樹(shù)里面可能有比 self.nearest[-n-1, 0] 更近的點(diǎn)
                if self.nearest[-n-1, 0] > abs(daxis):
                    if daxis < 0:
                        travel(node.rchild)
                    else:
                        travel(node.lchild)

        travel(self.tree)
        # nodes 就是最近 k 個(gè)點(diǎn)
        nodes = self.nearest[:, 1]
        counter = Counter([node.data[-1] for node in nodes])
        return self.nearest, counter.most_common(1)[0][0]


if __name__ == '__main__':
    # 獲取鳶尾花數(shù)據(jù)集

    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['label'] = iris.target
    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
    # 生成訓(xùn)練樣本,只取 sepal length,sepal width 作為樣本特征
    data = np.array(df.iloc[:100, [0, 1, -1]])
    # 劃分訓(xùn)練集與測(cè)試集
    train, test = train_test_split(data, test_size=0.1)
    x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0])
    x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1])
        
    # 生成 knn 模型,并使用測(cè)試集進(jìn)行驗(yàn)證
    kdt = KdTree()
    kdt.create(train)

    score = 0
    for x in test:
        # 繪制訓(xùn)練點(diǎn)
        plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]')
        plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]')
        plt.xlabel('sepal length')
        plt.ylabel('sepal width')
         # 繪制測(cè)試點(diǎn)
        plt.scatter(x[0], x[1], c='red', marker='x')
        # 設(shè)置臨近點(diǎn)的個(gè)數(shù)
        nearest, belong = kdt.search(x[:-1], 5)
        if belong == x[-1]: score += 1
        print("test:", x, "predict:", belong)
        print("nearest:", nearest)
        for near in nearest:
            # k 個(gè)近鄰點(diǎn)
            plt.scatter(near[1].data[0], near[1].data[1], c='green', marker='+')
        plt.legend()
        plt.show()

    score /= len(test)
    print("score:", score)

運(yùn)行結(jié)果(部分截圖)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容