KD-Tree 算法的 C++ 實現

閱讀本文前,建議查閱相關資料,了解 KNN 算法與 KD 樹。

基礎知識

如圖所示,假設一個點 a 目前的最近鄰點為 b,如果存在相對于 ba 更近的點,那么這個點一定在以 a 為圓心,ab 為半徑的圓內。
現右側的區(qū)域是未知的,如果 a 到分界線的距離 l 大于目前的最近距離 L(圓半徑),則沒有必要在右側的未知區(qū)域繼續(xù)尋找最近鄰點(如圖一),反之,則要繼續(xù)尋找(如圖二)。
相應的,投射到多維空間,假如切分邊界為第 i 維,切分點的值為 v(標量),當前最近鄰點為 y(向量),如果目標點 x(向量) 到切分邊界的距離 |x[i] - v| 滿足以下關系


時,需要在另一側繼續(xù)搜索。

圖1:不需要在右側未知區(qū)域繼續(xù)搜索的情況

圖2:需要在右側未知區(qū)域繼續(xù)搜索的情況

通常地,一個機器學習算法分為 fitpredict 兩個階段,基于線性搜索的 KNN 是一種惰性算法,它將全部的計算任務放到了 predict 階段,predict 的時間復雜度為 O(n),KD 樹之所以比線性搜索快,就是因為它將一部分任務放到了 fit(建立 KD 樹) 階段,從而在搜索時可以略去大量不必搜索的結點(最優(yōu)情況下時間復雜度為 O(1))。
上面說的比較簡單,關于 KNN 算法和 KD 樹的詳細內容,請參考李航博士的《統計學習方法》。

代碼

我們給出部分關鍵性的代碼。

基本數據結構

  • 訓練集用一個一維數組 double *data 表示,它的長度為 n_samples * n_features,標簽集也用一個一維數組 double *labels 表示,它的長度為 n_samples
  • 樹的結點用以下數據結構表示
     struct tree_node
     {
         size_t id;               // 表示訓練集中的第 i 個數據
         size_t split;            // 切分的維度
         tree_node *left, *right; // 左、右子樹
     };
    
  • 一個 KD 樹的模型可用以下結構表示
     struct tree_model
     {
         tree_node *root;        // 根結點
         const double *datas;    // X
         const double *labels;   // y
         size_t n_samples;       // 樣例數
         size_t n_features;      // 每個樣例的特征數
         double p;               // 距離度量
     };
    
  • 求 K-近鄰時需要用到大頂堆,我們直接用 C++ 的優(yōu)先隊列來表示,堆內現有的 n(n <= k) 個近鄰點中,距離測試點最遠的在堆頂
    struct neighbor_heap_cmp {
        bool operator()(const std::tuple<size_t, double> &i, 
                        const std::tuple<size_t, double> &j) {
              return std::get<1>(i) < std::get<1>(j);
          }
      };
    
    typedef std::tuple<size_t, double> neighbor;
    typedef std::priority_queue<neighbor,
            std::vector<neighbor>, neighbor_heap_cmp> neighbor_heap_;
    
    neighbor_heap k_neighbor_heap_;
    

KD-Tree 類

我們用類 KDTree 表示一個 KD 樹類,它應該具有的功能有建樹搜索。

//(簡化的代碼,完整的代碼詳見最后)
class KDTree {
public:
    // 建樹
    KDTree(const double *datas, const double *labels, size_t rows, size_t cols, double p)
    // 返回樹
    tree_node *GetRoot() { return root; }
    // 求一個測試點的 k 鄰
    std::vector<std::tuple<size_t, double>> FindKNearests(const double *coor, size_t k);
private:
    tree_node *root_;
}

尋找切分維和切分點

在建樹之前,我們還要考慮如何選擇切分維度和切分點。切分維度的選擇有許多,一般的,可以取 dim = floor % n_features,即當前樹的層數對特征數取余,我們在這里使用 dim = argmax(nmax - nmin),即選取當前結點集合中極差最大的維度。

(這里是不完整的代碼,有些工具函數的定義請詳見完整源代碼)
size_t KDTree::FindSplitDim(const std::vector<size_t> &points) {
    if (points.size() == 1)
        return 0;
    size_t cur_best_dim = 0;
    double cur_largest_spread = -1;
    double cur_min_val;
    double cur_max_val;
    for (size_t dim = 0; dim < n_features; ++dim) {
        cur_min_val = GetDimVal(points[0], dim);
        cur_max_val = GetDimVal(points[0], dim);
        for (const auto &id : points) {
            if (GetDimVal(id, dim) > cur_max_val)
                cur_max_val = GetDimVal(id, dim);
            else if (GetDimVal(id, dim) < cur_min_val)
                cur_min_val = GetDimVal(id, dim);
        }

        if (cur_max_val - cur_min_val > cur_largest_spread) {
            cur_largest_spread = cur_max_val - cur_min_val;
            cur_best_dim = dim;
        }
    }
    return cur_best_dim;
}

選擇完切分維 k 之后,我們需選取當前結點集合中的結點在第 k 維的值的中位數 x 作為切分點的值,除去該點之外的點,第 k 維的值小于等于 x 的,放入左子樹,反之放入右子樹。
在求中位數時,不要全排序,然后取中間的點,可以采用類似快排的方法,找到中位數時就停止排序,這里我們就不寫算法了,直接用 C++ 的函數。

std::tuple<size_t, double> KDTree::MidElement(const std::vector<size_t> &points, size_t dim) {
    size_t len = points.size();
    for (size_t i = 0; i < points.size(); ++i)
        get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim));
    std::nth_element(get_mid_buf_,
                     get_mid_buf_ + len / 2,
                     get_mid_buf_ + len,
                     [](const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) {
                         return std::get<1>(i) < std::get<1>(j);
                     });
    return get_mid_buf_[len / 2];
}

建樹

建樹直接按照建立二叉樹的方法即可

tree_node *KDTree::BuildTree(const std::vector<size_t> &points) {
    size_t dim = FindSplitDim(points);
    std::tuple<size_t, double> t = MidElement(points, dim);
    size_t arg_mid_val = std::get<0>(t);
    double mid_val = std::get<1>(t);

    tree_node *node = Malloc(tree_node, 1);
    node->left = nullptr;
    node->right = nullptr;
    node->id = arg_mid_val;
    node->split = dim;
    std::vector<size_t> left, right;
    for (auto &i : points) {
        if (i == arg_mid_val)
            continue;
        if (GetDimVal(i, dim) <= mid_val)
            left.emplace_back(i);
        else
            right.emplace_back(i);
    }
    if (!left.empty())
        node->left = BuildTree(left);
    if (!right.empty())
        node->right = BuildTree(right);
    return node;
}

搜索 K-近鄰的規(guī)則

一般書上所講的都是搜索最近鄰,但是我們這里是搜索 K-近鄰,需要對書上的算法做少許的擴充。
搜索最近鄰時,我們一般設置兩個變量 cur_min_idcur_min_dist,如果當前搜索到的點到測試點的距離 l < cur_min_dist 時,我們將上述兩個變量更新為新點的 iddist
相應的,在搜索 K-近鄰時,我們可以設置一個最多有 k 個元素的大頂堆,這樣,在搜索時,當堆滿時,只需比較當前搜索點的 dist 是否小于堆頂點的 dist,如果小于,堆頂出堆,并將當前搜索點壓入,反之,則不變;當堆未滿時,直接將該搜索點壓入。

搜索 K-近鄰的算法

我們直接使用二叉樹深度優(yōu)先遍歷的非遞歸算法(具體的描述詳見《統計學習方法》第 43 頁算法 3.3)。

std::vector<std::tuple<size_t, double>> KDTree::FindKNearests(const double *coor, size_t k) {
    std::memset(visited_buf_, 0, sizeof(bool) * n_samples);
    std::stack<tree_node *> paths;
    tree_node *p = root;

    while (p) {
        HeapStackPush(paths, p, coor, k);
        p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right;
    }
    while (!paths.empty()) {
        p = paths.top();
        paths.pop();

        if (!p->left && !p->right)
            continue;

        if (k_neighbor_heap_.size() < k) {
            if (p->left)
                HeapStackPush(paths, p->left, coor, k);
            if (p->right)
                HeapStackPush(paths, p->right, coor, k);
        } else {
            double node_split_val = GetDimVal(p->id, p->split);
            double coor_split_val = coor[p->split];
            double heap_top_val = std::get<1>(k_neighbor_heap_.top());
            if (coor_split_val > node_split_val) {
                if (p->right)
                    HeapStackPush(paths, p->right, coor, k);
                if ((coor_split_val - node_split_val) < heap_top_val && p->left)
                    HeapStackPush(paths, p->left, coor, k);
            } else {
                if (p->left)
                    HeapStackPush(paths, p->left, coor, k);
                if ((node_split_val - coor_split_val) < heap_top_val && p->right)
                    HeapStackPush(paths, p->right, coor, k);
            }
        }
    }
    std::vector<std::tuple<size_t, double>> res;

    while (!k_neighbor_heap_.empty()) {
        res.emplace_back(k_neighbor_heap_.top());
        k_neighbor_heap_.pop();
    }
    return res;
}

完整代碼

詳見 https://github.com/WiseDoge/libkdtree
完整代碼中除了 KD-Tree 的代碼外,還給出了測試代碼和 Python 接口代碼,以及一些調用第三方庫來加速的手段。

?著作權歸作者所有,轉載或內容合作請聯系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容