Implementing a Decision Tree by Hand (Part 1)

Building the decision tree

1. Overall approach

Helper functions

  • Split the data into set_1 and set_2 on a given feature --> divide_set
  • Count the class labels in set_1 and in set_2 --> unique_counts
  • Compute the entropy from those counts --> entropy

Algorithm outline

  • Loop over the feature columns and pick the feature (and split value) with the largest gain
  • Split the data set on that feature, then recurse on set_1 and set_2
  • Stop when the gain is 0 or the subset to be examined is empty
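
The gain used in the outline is the information gain of a binary split. Written out, it takes the form below, which matches the expression computed in build_tree. The notation is an assumption introduced here and is not in the original: S is the data set at a node, p_k is the fraction of rows of S that belong to class k, and S_1, S_2 are the two subsets returned by divide_set.

```latex
H(S) = -\sum_{k} p_k \log_2 p_k, \qquad
\mathrm{gain} = H(S) - p\,H(S_1) - (1 - p)\,H(S_2), \qquad
p = \frac{|S_1|}{|S|}
```

A gain of 0 means the split leaves the class distribution unchanged in both subsets, which is why it serves as the stopping condition.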

2. Python implementation

Main code

# Note: entropy must already be defined when this function is defined,
# because default argument values are evaluated at definition time.
def build_tree(rows, scoref=entropy):
    # Base case: an empty data set yields an empty node
    if len(rows) == 0:
        return DecisionNode()

    current_score = scoref(rows)      # score before splitting
    best_gain = 0.0
    best_criteria = None
    best_sets = None

    column_count = len(rows[0]) - 1    # number of features (the last column is the label)
    for col in range(column_count):
        # Collect the distinct values that appear in this column
        column_values = {}
        for row in rows:
            column_values[row[col]] = 1
        # Try a split on each distinct value
        for value in column_values.keys():
            set_1, set_2 = divide_set(rows, col, value)

            p = float(len(set_1)) / len(rows)
            gain = current_score - p * scoref(set_1) - (1 - p) * scoref(set_2)
            if gain > best_gain and len(set_1) > 0 and len(set_2) > 0:
                best_gain = gain
                best_criteria = (col, value)
                best_sets = (set_1, set_2)
    # Create the child branches
    if best_gain > 0:
        # Not a leaf: recurse on both subsets. results stays None; the node tests
        # column col against value. scoref is passed on so that a custom scoring
        # function is also used in the subtrees.
        true_branch = build_tree(best_sets[0], scoref)
        false_branch = build_tree(best_sets[1], scoref)
        return DecisionNode(col=best_criteria[0], value=best_criteria[1], tb=true_branch, fb=false_branch)
    else:
        # No split improves the score: return a leaf holding the class counts
        return DecisionNode(results=unique_counts(rows))
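
To make the gain line in build_tree concrete, here is a small worked calculation. It is only a sketch: entropy_from_counts is a helper name introduced here (not part of the original code), and it works on class counts rather than on rows. The counts are those of the test data in section 3, split on column 3 at the value 21.

```python
from math import log2


def entropy_from_counts(counts):
    """Shannon entropy (base 2) of a list of class counts (hypothetical helper)."""
    total = sum(counts)
    return -sum(c / total * log2(c / total) for c in counts if c > 0)


# Class counts in the order 'None', 'Premium', 'Basic'
parent = [6, 3, 6]   # all 15 rows
set_1 = [0, 3, 3]    # rows whose column 3 is >= 21
set_2 = [6, 0, 3]    # the remaining rows

p = sum(set_1) / sum(parent)
gain = (entropy_from_counts(parent)
        - p * entropy_from_counts(set_1)
        - (1 - p) * entropy_from_counts(set_2))
print(round(gain, 3))   # 0.571
```

The parent entropy is about 1.522 bits and the weighted entropy of the two subsets is about 0.951 bits, so this split removes roughly 0.571 bits of uncertainty.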

The DecisionNode class

class DecisionNode:
    def __init__(
            self,
            col=-1,
            value=None,
            results=None,
            tb=None,
            fb=None
    ):
        self.col = col              # index of the column (feature) to test
        self.value = value          # value the column is compared against
        self.results = results      # class counts; None for every non-leaf node
        self.tb = tb                # subtree taken when the test is true
        self.fb = fb                # subtree taken when the test is false

divide_set: splitting the data

def divide_set(rows, column, value):
    # Split rows in two on value: set_1 holds the rows that pass the test, set_2 the rest
    if isinstance(value, (int, float)):
        # Numeric value: threshold test
        split_function = lambda row: row[column] >= value
    else:
        # Any other value: equality test
        split_function = lambda row: row[column] == value

    set_1 = [row for row in rows if split_function(row)]
    set_2 = [row for row in rows if not split_function(row)]

    return set_1, set_2
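
A short usage sketch may help show the two kinds of test. The three rows are made up for illustration (they only mimic the shape of the test data), and divide_set is restated in compact form so the snippet runs on its own.

```python
def divide_set(rows, column, value):
    # Same rule as above: >= for numbers, == for everything else
    if isinstance(value, (int, float)):
        test = lambda row: row[column] >= value
    else:
        test = lambda row: row[column] == value
    return [r for r in rows if test(r)], [r for r in rows if not test(r)]


# Hypothetical rows: [source, number, label]
rows = [['google', 23, 'Premium'],
        ['digg', 18, 'None'],
        ['google', 18, 'Basic']]

# String value: equality test on column 0
is_google, not_google = divide_set(rows, 0, 'google')
print(len(is_google), len(not_google))   # 2 1

# Numeric value: threshold test (>=) on column 1
at_least_21, below_21 = divide_set(rows, 1, 21)
print(at_least_21)   # [['google', 23, 'Premium']]
```

Because the numeric test is >=, a row whose value equals the threshold goes to set_1.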

unique_counts: counting the class labels

def unique_counts(rows):
    results = {}
    for row in rows:
        r = row[-1]    # class label (last column): 'None', 'Basic' or 'Premium'
        if r not in results:
            results[r] = 0
        results[r] += 1
    return results
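
For comparison, the same counting can be written with collections.Counter from the standard library. This is an alternative sketch and not the version used in the article; the sample rows are made up.

```python
from collections import Counter


def unique_counts(rows):
    """Count how often each class label (last column) occurs."""
    return dict(Counter(row[-1] for row in rows))


# Hypothetical rows: [source, label]
rows = [['slashdot', 'None'], ['google', 'Premium'], ['digg', 'None']]
print(unique_counts(rows))   # {'None': 2, 'Premium': 1}
```

Converting back to a plain dict keeps the return type identical to the hand-written loop.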

entropy: computing the entropy

from math import log2

def entropy(rows):
    results = unique_counts(rows)
    ent = 0.0
    for r in results.keys():
        p = float(results[r]) / len(rows)
        ent -= p * log2(p)
    return ent
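
A few boundary cases show what the entropy measures. The sketch below restates entropy in compact form so that it runs on its own; the tiny data sets are made up for illustration.

```python
from collections import Counter
from math import log2


def entropy(rows):
    # Entropy (base 2) of the labels in the last column
    counts = Counter(row[-1] for row in rows)
    return sum(-(c / len(rows)) * log2(c / len(rows)) for c in counts.values())


pure = [['a', 'Basic'], ['b', 'Basic']]                   # one class only
even = [['a', 'Basic'], ['b', 'None']]                    # two classes, 50/50
four = [['a', 'A'], ['b', 'B'], ['c', 'C'], ['d', 'D']]   # four equally likely classes

print(entropy(pure) == 0)   # True
print(entropy(even))        # 1.0
print(entropy(four))        # 2.0
```

A pure set has entropy 0, and n equally likely classes give log2(n) bits, so the tree prefers splits that push the subsets toward purity.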

3. Running a test

Test data

my_data = [['slashdot', 'USA', 'yes', 18, 'None'],
           ['google', 'France', 'yes', 23, 'Premium'],
           ['digg', 'USA', 'yes', 24, 'Basic'],
           ['kiwitobes', 'France', 'yes', 23, 'Basic'],
           ['google', 'UK', 'no', 21, 'Premium'],
           ['(direct)', 'New Zealand', 'no', 12, 'None'],
           ['(direct)', 'UK', 'no', 21, 'Basic'],
           ['google', 'USA', 'no', 24, 'Premium'],
           ['slashdot', 'France', 'yes', 19, 'None'],
           ['digg', 'USA', 'no', 18, 'None'],
           ['google', 'UK', 'no', 18, 'None'],
           ['kiwitobes', 'UK', 'no', 19, 'None'],
           ['digg', 'New Zealand', 'yes', 12, 'Basic'],
           ['google', 'UK', 'yes', 18, 'Basic'],
           ['kiwitobes', 'France', 'yes', 19, 'Basic']]

Displaying the result

def print_tree(tree, indent=''):
    # A leaf stores the class counts in results; for any other node results is None
    if tree.results is not None:
        print(str(tree.results))
    else:
        # Print the test
        print(str(tree.col) + ':' + str(tree.value) + "?")
        # Print the branches
        print(indent + "T->", end='')
        print_tree(tree.tb, indent+'  ')
        print(indent + "F->", end='')
        print_tree(tree.fb, indent+'  ')

Output

3:21?
T->0:google?
  T->{'Premium': 3}
  F->{'Basic': 3}
F->2:yes?
  T->0:slashdot?
    T->{'None': 2}
    F->{'Basic': 3}
  F->{'None': 4}

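It can be verified that the decision tree reaches 100% accuracy on the training set: every leaf in the output above contains a single class.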
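
One way to check this is to classify every training row with the tree and compare the prediction with the row's label. The sketch below restates the article's functions in compact form so that it runs on its own. classify and is_true are helper names introduced here and are not part of the original code.

```python
from math import log2


class DecisionNode:
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col, self.value, self.results = col, value, results
        self.tb, self.fb = tb, fb


def is_true(row, col, value):
    # Same test as divide_set: >= for numbers, == for everything else
    if isinstance(value, (int, float)):
        return row[col] >= value
    return row[col] == value


def divide_set(rows, col, value):
    return ([r for r in rows if is_true(r, col, value)],
            [r for r in rows if not is_true(r, col, value)])


def unique_counts(rows):
    results = {}
    for row in rows:
        results[row[-1]] = results.get(row[-1], 0) + 1
    return results


def entropy(rows):
    return sum(-(c / len(rows)) * log2(c / len(rows))
               for c in unique_counts(rows).values())


def build_tree(rows):
    current_score = entropy(rows)
    best_gain, best_criteria, best_sets = 0.0, None, None
    for col in range(len(rows[0]) - 1):
        # dict.fromkeys keeps the distinct values in first-seen order
        for value in dict.fromkeys(row[col] for row in rows):
            set_1, set_2 = divide_set(rows, col, value)
            p = len(set_1) / len(rows)
            gain = current_score - p * entropy(set_1) - (1 - p) * entropy(set_2)
            if gain > best_gain and set_1 and set_2:
                best_gain, best_criteria, best_sets = gain, (col, value), (set_1, set_2)
    if best_gain > 0:
        return DecisionNode(col=best_criteria[0], value=best_criteria[1],
                            tb=build_tree(best_sets[0]), fb=build_tree(best_sets[1]))
    return DecisionNode(results=unique_counts(rows))


def classify(row, tree):
    # Hypothetical helper: walk down to a leaf and return its most frequent label
    if tree.results is not None:
        return max(tree.results, key=tree.results.get)
    branch = tree.tb if is_true(row, tree.col, tree.value) else tree.fb
    return classify(row, branch)


# Test data from section 3, with 'kiwitobes' spelled the same way in every row
my_data = [['slashdot', 'USA', 'yes', 18, 'None'],
           ['google', 'France', 'yes', 23, 'Premium'],
           ['digg', 'USA', 'yes', 24, 'Basic'],
           ['kiwitobes', 'France', 'yes', 23, 'Basic'],
           ['google', 'UK', 'no', 21, 'Premium'],
           ['(direct)', 'New Zealand', 'no', 12, 'None'],
           ['(direct)', 'UK', 'no', 21, 'Basic'],
           ['google', 'USA', 'no', 24, 'Premium'],
           ['slashdot', 'France', 'yes', 19, 'None'],
           ['digg', 'USA', 'no', 18, 'None'],
           ['google', 'UK', 'no', 18, 'None'],
           ['kiwitobes', 'UK', 'no', 19, 'None'],
           ['digg', 'New Zealand', 'yes', 12, 'Basic'],
           ['google', 'UK', 'yes', 18, 'Basic'],
           ['kiwitobes', 'France', 'yes', 19, 'Basic']]

tree = build_tree(my_data)
correct = sum(classify(row, tree) == row[-1] for row in my_data)
print(correct, len(my_data))   # 15 15
```

Perfect accuracy on the training rows says nothing about unseen data; a tree grown until the gain reaches 0 simply memorizes any training set in which no two rows share all features but differ in label.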

