KNN的實現(xiàn):kd樹(python)

在尋找輸入樣本的k個近鄰的時候,若進行線性掃描,對于大數(shù)據(jù)集來說耗時太久,為了加快搜索速度,提出了用kd樹實現(xiàn)k個近鄰的搜索,此時復雜度為O(logN)。

首先是建樹

這里假設輸入數(shù)據(jù)一個N×K的矩陣,N代表實例點的個數(shù),K代表樣本空間的維度。每一行代表一個實例點。

每個節(jié)點包含六個屬性:

  • SamplePoints:實例點的行號,表示該節(jié)點對應區(qū)域包含的所有實例點
  • SplitDim:切割對應的區(qū)域時選擇的特征(維度)
  • MidPoint:是一個元組,(切分點的行號,切分特征的中位數(shù))
  • left:指向左子節(jié)點
  • right:指向右子節(jié)點
  • father:指向父節(jié)點
  • visited:該節(jié)點是否已被訪問的標志

包含兩個方法:

  • get_median():獲取切割特征的中位數(shù)
  • get_dim():獲取方差最大的特征作為切割特征

過程如下:

  1. 構造根節(jié)點,使根節(jié)點對應于k維空間中包含所有實例點的超矩形區(qū)域;
  2. 在超矩形區(qū)域上選擇一個坐標軸和在一個切分點,確定一個超平面,這個超平面通過選定的切分點并垂直于選定的坐標軸,將當前超矩形區(qū)域切分為左右兩個子區(qū)域(子節(jié)點);這時,實例被分到兩個子區(qū)域。
  3. 將切分點保存在根節(jié)點上。
  4. 重復步驟2、3,直到子區(qū)域內(nèi)只含有不含實例時終止。
import numpy as np
from collections import Counter

class KdTreeNode(object):
    def __init__(self, SamplePoints):
        self.SamplePoints = SamplePoints  
        self.SplitDim = self.get_dim() 
        self.MidPoint = self.get_MidPoint() 
        self.left = None 
        self.right = None 
        self.father = None 
        self.visited = False 

    def get_dim(self):
        variance = np.var(X[self.SamplePoints, :], axis = 0) #計算該節(jié)點包含的實例點每個特征的方差
        #print(variance)
        return np.argmax(variance) #選擇方差最大的特征

    def get_MidPoint(self):
        tmp = X[self.SamplePoints, self.SplitDim]
        length = len(tmp)
        index = np.argsort(tmp) #該函數(shù)返回的是數(shù)組值從小到大的索引值
        return (self.SamplePoints[index[int(length/2)]], tmp[index[int(length/2)]])  #(中位數(shù)所在的行號,中位數(shù)的值)
        

def build_tree(SamplePoints, father = None):  #構建kd樹
    if len(SamplePoints) == 0: #子區(qū)域不含實例點時停止
        return None
    root = KdTreeNode(SamplePoints)
    LeftPoints = []  #分割區(qū)域依據(jù)的特征小于或等于median的實例點
    RightPoints = [] #分割區(qū)域依據(jù)的特征大于median的實例點
    for x in SamplePoints:
        if x == root.MidPoint[0]:
            continue
        if X[x, root.SplitDim] <= root.MidPoint[1]:
            LeftPoints.append(x)
        else:
            RightPoints.append(x)
    root.father = father
    if len(SamplePoints) > 1: #子區(qū)域只含一個點時停止
        root.left = build_tree(LeftPoints, root) #構建左子樹
        root.right = build_tree(RightPoints, root) #構建右子樹
    return root

最近鄰搜索

  1. 從根節(jié)點出發(fā),遞歸地向下訪問kd樹。若目標點x當前維(即切割根節(jié)點對應區(qū)域時選擇的維度)的坐標小于或等于切分點的坐標,則移動到左子節(jié)點,否則移動到右子節(jié)點。直到子節(jié)點為葉節(jié)點為止,記此葉節(jié)點為L。
  2. 以此葉節(jié)點L上的切分點為“當前最近點Ncur”,記錄Ncur與目標點的距離為Dcur。
  3. 判斷L的父節(jié)點是否已被訪問。
    3.1. 若未被訪問,檢查L的父節(jié)點的另一子節(jié)點(即L的兄弟節(jié)點)對應的區(qū)域是否與以目標點為球心以Dcur為半徑的超球體相交。具體做法是在分割L的父節(jié)點區(qū)域時選擇的維度上計算目標點與切分點的坐標差值的絕對值,然后將其與Dcur比較。
    a) 若大于Dur,說明不相交。則標記L的父節(jié)點已被訪問,回到此步驟的開頭。
    b) 若小于或等于Dcur,說明相交。先計算L的父節(jié)點上的切分點與目標點的距離,檢查是否要更新Pcur與Dcur,完成后標記L的父節(jié)點已被訪問。從L的兄弟節(jié)點出發(fā),按照步驟1找到一個新的葉節(jié)點L。計算L上的切分點與目標點的距離,檢查是否要更新Pcur與Dcur,完成后回到此步驟的開頭。
    3.2 若已被訪問,判斷L的父節(jié)點是否為根節(jié)點。
    a) 若是,則停止整個程序。Pcur即為目標點的最近鄰。
    b) 若不是,則回退到L的父節(jié)點,具做法為令L=L的父節(jié)點,然后回到此步驟的開頭。
def approx_nearest_neighbor(root, TargetPoint): #尋找樹中與目標點的近似最近鄰點,該最似最近鄰僅僅是與目標點在同一分區(qū)中,不一定是最近鄰
    if root.left == None and root.right == None:
        return root
    if TargetPoint[root.SplitDim] <= root.MidPoint[1]:
        if root.left == None: #若應往左子樹走時發(fā)現(xiàn)左子樹為空,轉向右子樹搜尋,保證最后返回的是一個葉節(jié)點
            return approx_nearest_neighbor(root.right, TargetPoint) 
        return approx_nearest_neighbor(root.left, TargetPoint)
    else:
        if root.right == None: #若應往右子樹走時發(fā)現(xiàn)左子樹為空,轉向左子樹搜尋
            return approx_nearest_neighbor(root.left, TargetPoint)
        return approx_nearest_neighbor(root.right, TargetPoint)


def nearest_neighbor_search(root, TargetPoint): #搜索與目標點的歐氏距離最小的樣本點
    Vis = approx_nearest_neighbor(root, TargetPoint) #表示以該節(jié)點為根節(jié)點的子樹已被搜索完成
    Ncur =  X[Vis.MidPoint[0], :]#開始時直接用近似最近鄰點作為當前最近鄰點
    Dcur = np.sqrt(np.sum(np.square(Ncur - TargetPoint))) #目標點與當前最近鄰的歐式距離
    if Vis == root: #當樣本空間中只有一個點則直接輸出該點,注意Vis是一個節(jié)點,Ncur是一個點向量
        return (Ncur, Dcur)
    while True:
        if not Vis.father.visited: #若Vis的父節(jié)點未被訪問
            VerticalDis = abs(TargetPoint[Vis.father.SplitDim] - Vis.father.MidPoint[1]) #目標點到以Vis父節(jié)點為切分點的分割超平面的垂直距離
            #若Vis的兄弟節(jié)點代表的區(qū)域與以目標點為圓心Dcur為半徑的圓相交
            if VerticalDis <= Dcur:
                EuclideanDis = np.sqrt(np.sum(np.square(X[Vis.father.MidPoint[0], :] - TargetPoint))) #Vis的父節(jié)點與目標點的距離
                if EuclideanDis < Dcur: #若比Dcur小,則將其作為當前最近鄰
                    Dcur = EuclideanDis
                    Ncur = X[Vis.father.MidPoint[0], :]
                Vis.father.visited = True #此節(jié)點已被訪問
                #尋找Vis的兄弟節(jié)點
                if Vis.father.left == Vis:
                    brother = Vis.father.right 
                else:
                    brother = Vis.father.left
                #若無兄弟節(jié)點,直接爬升到Vis的父節(jié)點
                if brother == None:
                    continue
                #若有兄弟節(jié)點
                Vis = approx_nearest_neighbor(brother, TargetPoint)
                EuclideanDis = np.sqrt(np.sum(np.square(X[Vis.MidPoint[0], :] - TargetPoint)))
                if EuclideanDis < Dcur:
                    Dcur = EuclideanDis
                    Ncur = X[Vis.MidPoint[0], :]
                continue
            #若不相交
            else:
                Vis.father.visited = True
        else: #若Vis的父節(jié)點已被訪問
            if Vis.father == root: #若根節(jié)點已被訪問,則結束搜索
                break
            else:
                Vis = Vis.father #向上爬升到Vis的父節(jié)點
    return (Ncur, Dcur)

K近鄰搜索

k近鄰的搜索與最近鄰搜索類似,不過程序中的“當前最近鄰Ncur”要改為“當前K近鄰Kcur”,它是一個二維列表,里面的每一行代表了K個近鄰點中的一個。在每次比較一個新的節(jié)點時,都需判斷是否要對它進行更新,用離目標點更近的點代替更遠的點。

def compare_dis(CurrentPoint, TargetPoint, Ncur, K): #計算樣本點與目標點的距離,若有必要的話對Ncur進行更新
    EuclideanDis = np.sqrt(np.sum(np.square(CurrentPoint - TargetPoint))) #計算歐式距離
    Ncur = sorted(Ncur, key = lambda x : -x[1]) #對Ncur中的K個點按照到目標點的距離從遠到近排序
    if EuclideanDis < Ncur[0][1]:  #如果當前目標點到目標點的距離比Ncur中最遠的點要近,則對Ncur進行更新
        Ncur = Ncur[1:K]
        Ncur.append((CurrentPoint, EuclideanDis))
    return Ncur

def k_neighbor_search(root, TargetPoint, K): #搜索與目標點的歐氏距離最小的K個樣本點
    Vis = approx_nearest_neighbor(root, TargetPoint) #Vis表示以該節(jié)點為根節(jié)點的子樹已被搜索完成
    Ncur = [] #存儲當前K個近鄰點
    for i in range(K): #用K個離目標點無窮遠的點作為Ncur的初始值
        Ncur.append((X[i,:], float('inf')))
    Ncur = compare_dis(X[Vis.MidPoint[0], :], TargetPoint, Ncur, K)
    if Vis == root: #當K=1時, 若樣本空間中只有一個點,則直接輸出該點
        return Ncur
    while True:
        if not Vis.father.visited: #若Vis的父節(jié)點未被訪問
            VerticalDis = abs(TargetPoint[Vis.father.SplitDim] - Vis.father.MidPoint[1]) #目標點到以Vis父節(jié)點為切分點的分割超平面的垂直距離
            #若Vis的兄弟節(jié)點代表的區(qū)域與以目標點為圓心Dcur為半徑的圓相交
            if VerticalDis <= sorted(Ncur, key = lambda x : -x[1])[0][1]:
                Ncur = compare_dis(X[Vis.father.MidPoint[0], :], TargetPoint, Ncur, K) #判斷Vis的父節(jié)點是否要加入到Ncur中
                Vis.father.visited = True #此節(jié)點已被訪問
                brother = Vis.father.right if Vis.father.left == Vis else Vis.father.left #尋找Vis的兄弟節(jié)點
                #若無兄弟節(jié)點,直接爬升到Vis的父節(jié)點
                if brother == None:
                    continue
                #若有兄弟節(jié)點
                Vis = approx_nearest_neighbor(brother, TargetPoint)
                Ncur = compare_dis(X[Vis.MidPoint[0], :], TargetPoint, Ncur, K)
                continue
            #若不相交
            else:
                Vis.father.visited = True
        else: #若Vis的父節(jié)點已被訪問
            if Vis.father == root: #若根節(jié)點已被訪問,則結束搜索
                break
            else:
                Vis = Vis.father #向上爬升到Vis的父節(jié)點
    return Ncur

測試程序

下圖中的紅色叉叉代表目標點。

#主程序
X = np.array([[2,3], 
    [5,4], 
    [9,6],
    [4,7],
    [8,1],
    [7,2]]) #存儲樣本向量
TargetPoint = np.array([8, 0]) #輸入目標點
root = build_tree(range(len(X))) #建樹
while True:
    K = int(input('Input K:').strip())  #若樣本點的個數(shù)沒有K個,需重新設定K
    if len(X) < K:
        print('Retry')
        continue
    break
Ncur = k_neighbor_search(root, TargetPoint, K)
for point in Ncur:
    print(point[0]) #輸出K個近鄰點的坐標
特征空間劃分
最后編輯于
?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

  • 一.樸素貝葉斯 1.分類理論 樸素貝葉斯是一種基于貝葉斯定理和特征條件獨立性假設的多分類的機器學習方法,所...
    wlj1107閱讀 3,402評論 0 5
  • 1 KD-Tree 實現(xiàn)kNN算法時,最簡單的實現(xiàn)方法就是線性掃描,正如我們上一章節(jié)內(nèi)容介紹的一樣->K近鄰算法,...
    壯少Bryant閱讀 3,655評論 0 1
  • 保留初心,砥礪前行 k-nearest neighbor, k-NN是一種可以用于多分類和回歸的方法。knn是一...
    加勒比海鮮王閱讀 1,549評論 3 7
  • k 近鄰是什么 k 近鄰法是機器學習中最基本的分類和回歸方法,也稱為kNN算法。通常k近鄰法用于分類問題。k近鄰法...
    程序員Morgan閱讀 1,083評論 0 1
  • 雖然,題目在邏輯關系中,并沒有因果聯(lián)系。但我相信站在巨人的肩膀上,你一定比別人看的遠!兩個孩子的媽媽,更多了一份義...
    王玲玲Casey閱讀 415評論 1 3

友情鏈接更多精彩內(nèi)容