閱讀本文前,建議查閱相關資料,了解 KNN 算法與 KD 樹。
基礎知識
如圖所示,假設一個點 a 目前的最近鄰點為 b,如果存在相對于 b 離 a 更近的點,那么這個點一定在以 a 為圓心,ab 為半徑的圓內。
現右側的區(qū)域是未知的,如果 a 到分界線的距離 l 大于目前的最近距離 L(圓半徑),則沒有必要在右側的未知區(qū)域繼續(xù)尋找最近鄰點(如圖一),反之,則要繼續(xù)尋找(如圖二)。
相應的,投射到多維空間,假如切分邊界為第 i 維,切分點的值為 v(標量),當前最近鄰點為 y(向量),如果目標點 x(向量) 到切分邊界的距離 |x[i] - v| 滿足以下關系

時,需要在另一側繼續(xù)搜索。


通常地,一個機器學習算法分為
fit 和 predict 兩個階段,基于線性搜索的 KNN 是一種惰性算法,它將全部的計算任務放到了 predict 階段,predict 的時間復雜度為 O(n),KD 樹之所以比線性搜索快,就是因為它將一部分任務放到了 fit(建立 KD 樹) 階段,從而在搜索時可以略去大量不必搜索的結點(最優(yōu)情況下時間復雜度為 O(1))。上面說的比較簡單,關于 KNN 算法和 KD 樹的詳細內容,請參考李航博士的《統計學習方法》。
代碼
我們給出部分關鍵性的代碼。
基本數據結構
- 訓練集用一個一維數組
double *data表示,它的長度為n_samples * n_features,標簽集也用一個一維數組double *labels表示,它的長度為n_samples。 - 樹的結點用以下數據結構表示
struct tree_node { size_t id; // 表示訓練集中的第 i 個數據 size_t split; // 切分的維度 tree_node *left, *right; // 左、右子樹 }; - 一個 KD 樹的模型可用以下結構表示
struct tree_model { tree_node *root; // 根結點 const double *datas; // X const double *labels; // y size_t n_samples; // 樣例數 size_t n_features; // 每個樣例的特征數 double p; // 距離度量 }; - 求 K-近鄰時需要用到大頂堆,我們直接用 C++ 的優(yōu)先隊列來表示,堆內現有的
n(n <= k)個近鄰點中,距離測試點最遠的在堆頂struct neighbor_heap_cmp { bool operator()(const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) { return std::get<1>(i) < std::get<1>(j); } }; typedef std::tuple<size_t, double> neighbor; typedef std::priority_queue<neighbor, std::vector<neighbor>, neighbor_heap_cmp> neighbor_heap_; neighbor_heap k_neighbor_heap_;
KD-Tree 類
我們用類 KDTree 表示一個 KD 樹類,它應該具有的功能有建樹和搜索。
//(簡化的代碼,完整的代碼詳見最后)
class KDTree {
public:
// 建樹
KDTree(const double *datas, const double *labels, size_t rows, size_t cols, double p)
// 返回樹
tree_node *GetRoot() { return root; }
// 求一個測試點的 k 鄰
std::vector<std::tuple<size_t, double>> FindKNearests(const double *coor, size_t k);
private:
tree_node *root_;
}
尋找切分維和切分點
在建樹之前,我們還要考慮如何選擇切分維度和切分點。切分維度的選擇有許多,一般的,可以取 dim = floor % n_features,即當前樹的層數對特征數取余,我們在這里使用 dim = argmax(nmax - nmin),即選取當前結點集合中極差最大的維度。
(這里是不完整的代碼,有些工具函數的定義請詳見完整源代碼)
size_t KDTree::FindSplitDim(const std::vector<size_t> &points) {
if (points.size() == 1)
return 0;
size_t cur_best_dim = 0;
double cur_largest_spread = -1;
double cur_min_val;
double cur_max_val;
for (size_t dim = 0; dim < n_features; ++dim) {
cur_min_val = GetDimVal(points[0], dim);
cur_max_val = GetDimVal(points[0], dim);
for (const auto &id : points) {
if (GetDimVal(id, dim) > cur_max_val)
cur_max_val = GetDimVal(id, dim);
else if (GetDimVal(id, dim) < cur_min_val)
cur_min_val = GetDimVal(id, dim);
}
if (cur_max_val - cur_min_val > cur_largest_spread) {
cur_largest_spread = cur_max_val - cur_min_val;
cur_best_dim = dim;
}
}
return cur_best_dim;
}
選擇完切分維 k 之后,我們需選取當前結點集合中的結點在第 k 維的值的中位數 x 作為切分點的值,除去該點之外的點,第 k 維的值小于等于 x 的,放入左子樹,反之放入右子樹。
在求中位數時,不要全排序,然后取中間的點,可以采用類似快排的方法,找到中位數時就停止排序,這里我們就不寫算法了,直接用 C++ 的函數。
std::tuple<size_t, double> KDTree::MidElement(const std::vector<size_t> &points, size_t dim) {
size_t len = points.size();
for (size_t i = 0; i < points.size(); ++i)
get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim));
std::nth_element(get_mid_buf_,
get_mid_buf_ + len / 2,
get_mid_buf_ + len,
[](const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) {
return std::get<1>(i) < std::get<1>(j);
});
return get_mid_buf_[len / 2];
}
建樹
建樹直接按照建立二叉樹的方法即可
tree_node *KDTree::BuildTree(const std::vector<size_t> &points) {
size_t dim = FindSplitDim(points);
std::tuple<size_t, double> t = MidElement(points, dim);
size_t arg_mid_val = std::get<0>(t);
double mid_val = std::get<1>(t);
tree_node *node = Malloc(tree_node, 1);
node->left = nullptr;
node->right = nullptr;
node->id = arg_mid_val;
node->split = dim;
std::vector<size_t> left, right;
for (auto &i : points) {
if (i == arg_mid_val)
continue;
if (GetDimVal(i, dim) <= mid_val)
left.emplace_back(i);
else
right.emplace_back(i);
}
if (!left.empty())
node->left = BuildTree(left);
if (!right.empty())
node->right = BuildTree(right);
return node;
}
搜索 K-近鄰的規(guī)則
一般書上所講的都是搜索最近鄰,但是我們這里是搜索 K-近鄰,需要對書上的算法做少許的擴充。
搜索最近鄰時,我們一般設置兩個變量 cur_min_id 和 cur_min_dist,如果當前搜索到的點到測試點的距離 l < cur_min_dist 時,我們將上述兩個變量更新為新點的 id 和 dist。
相應的,在搜索 K-近鄰時,我們可以設置一個最多有 k 個元素的大頂堆,這樣,在搜索時,當堆滿時,只需比較當前搜索點的 dist 是否小于堆頂點的 dist,如果小于,堆頂出堆,并將當前搜索點壓入,反之,則不變;當堆未滿時,直接將該搜索點壓入。
搜索 K-近鄰的算法
我們直接使用二叉樹深度優(yōu)先遍歷的非遞歸算法(具體的描述詳見《統計學習方法》第 43 頁算法 3.3)。
std::vector<std::tuple<size_t, double>> KDTree::FindKNearests(const double *coor, size_t k) {
std::memset(visited_buf_, 0, sizeof(bool) * n_samples);
std::stack<tree_node *> paths;
tree_node *p = root;
while (p) {
HeapStackPush(paths, p, coor, k);
p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right;
}
while (!paths.empty()) {
p = paths.top();
paths.pop();
if (!p->left && !p->right)
continue;
if (k_neighbor_heap_.size() < k) {
if (p->left)
HeapStackPush(paths, p->left, coor, k);
if (p->right)
HeapStackPush(paths, p->right, coor, k);
} else {
double node_split_val = GetDimVal(p->id, p->split);
double coor_split_val = coor[p->split];
double heap_top_val = std::get<1>(k_neighbor_heap_.top());
if (coor_split_val > node_split_val) {
if (p->right)
HeapStackPush(paths, p->right, coor, k);
if ((coor_split_val - node_split_val) < heap_top_val && p->left)
HeapStackPush(paths, p->left, coor, k);
} else {
if (p->left)
HeapStackPush(paths, p->left, coor, k);
if ((node_split_val - coor_split_val) < heap_top_val && p->right)
HeapStackPush(paths, p->right, coor, k);
}
}
}
std::vector<std::tuple<size_t, double>> res;
while (!k_neighbor_heap_.empty()) {
res.emplace_back(k_neighbor_heap_.top());
k_neighbor_heap_.pop();
}
return res;
}
完整代碼
詳見 https://github.com/WiseDoge/libkdtree
完整代碼中除了 KD-Tree 的代碼外,還給出了測試代碼和 Python 接口代碼,以及一些調用第三方庫來加速的手段。