Machine Learning (7): Decision Trees

1. Basic Principles

1.1 Idea and Workflow

A decision tree is a common method for classification and regression in machine learning; it is a discriminative model with a tree-shaped structure. A decision tree can be viewed as a set of mutually exclusive and collectively exhaustive if-then rules. It can also be read as the conditional probability distribution of the classes given the features: this distribution partitions the feature space into disjoint cells (or regions), and defining a class distribution on each cell yields a conditional probability distribution, written P(Y | X), where X is the random variable ranging over features and Y the random variable ranging over classes.

[Figure: correspondence between a decision tree and a conditional probability distribution]
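For intuition, each root-to-leaf path of a tree is one if-then rule, and the rules are mutually exclusive and jointly cover the whole feature space. A minimal sketch in Python (the loan features here anticipate the example of Section 2.1):

def approve_loan(has_house, has_work):
    # Three root-to-leaf paths = three mutually exclusive if-then rules
    if has_house == 1:
        return 'yes'
    elif has_work == 1:
        return 'yes'
    else:
        return 'no'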

Decision tree learning is essentially the induction of a set of classification rules from the training set, one that contradicts the training data as little as possible while generalizing well to unseen data. The loss function is usually a regularized maximum-likelihood function, and learning seeks to minimize it. A typical learning algorithm recursively selects the optimal feature, splits the training data on that feature, and repeats so that each resulting subset is classified as well as possible. The basic procedure is sketched below:

[Figure: the basic decision-tree learning procedure]

1.2 Feature Selection

Feature selection means picking out the features that actually discriminate between classes in the training data; it is the key step in decision tree learning. The usual selection criteria are the information gain and the information gain ratio.

The information gain of feature A on training set D is g(D, A)=H(D)-H(D | A), where H(D) is the entropy of D and H(D | A) is the conditional entropy of D given A.
The entropy of a random variable X is H(X)=-\sum_{i=1}^{n} p_{i} \log p_{i}, and the conditional entropy is the expectation, over X, of the entropy of the conditional distribution of Y given X: H(Y | X)=\sum_{i=1}^{n} p_{i} H\left(Y | X=x_{i}\right), where p_{i}=P(X=x_{i}).
The information gain ratio divides the information gain by the entropy of D with respect to the values of A (the split information): g_{R}(D, A)=\frac{g(D, A)}{H_{A}(D)}, where H_{A}(D)=-\sum_{i=1}^{n} \frac{|D_{i}|}{|D|} \log_{2} \frac{|D_{i}|}{|D|} and D_{i} is the subset of D on which A takes its i-th value.
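A small worked example of the information gain (a sketch; the split below matches the 'house' feature of the loan dataset used in Section 2.1):

from math import log2

def entropy(labels):
    # Empirical Shannon entropy of a label sequence, in bits
    n = len(labels)
    return -sum(labels.count(c) / n * log2(labels.count(c) / n) for c in set(labels))

# D has 9 'yes' and 6 'no'; the feature splits D into D1 (6 'yes')
# and D2 (3 'yes', 6 'no')
D  = ['yes'] * 9 + ['no'] * 6
D1 = ['yes'] * 6
D2 = ['yes'] * 3 + ['no'] * 6

H_D   = entropy(D)                                                       # about 0.971
H_D_A = len(D1) / len(D) * entropy(D1) + len(D2) / len(D) * entropy(D2)  # about 0.551
print(H_D - H_D_A)                                                       # information gain, about 0.420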

1.3 Pruning

Pruning is the decision tree's main defense against overfitting. Concretely, pruning cuts some subtrees or leaf nodes from an already-grown tree and turns the parent node of each removed part into a new leaf, thereby simplifying the model. Pruning comes in two flavors, pre-pruning and post-pruning. Pre-pruning estimates, before each split during tree growth, whether the split would improve the tree's generalization; if not, the split is abandoned and the node is marked as a leaf. Post-pruning first grows a complete tree on the training set and then examines the non-leaf nodes bottom-up; if replacing a node's subtree with a leaf would improve generalization, the subtree is replaced.
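As a concrete illustration, scikit-learn exposes a form of post-pruning (minimal cost-complexity pruning) through the ccp_alpha parameter. A minimal sketch on the breast cancer data used later in Section 2.2, assuming scikit-learn >= 0.22 (the exact scores depend on the split):

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Candidate pruning strengths computed from the training set;
# a larger ccp_alpha prunes more aggressively
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
for alpha in path.ccp_alphas:
    pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X_train, y_train)
    print('alpha={:.5f}  test accuracy={:.3f}'.format(alpha, pruned.score(X_test, y_test)))

In practice one would pick the alpha with the best validation accuracy rather than printing them all.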

2. Algorithm Implementation

2.1 Implementation from Scratch

1. Module imports and data generation

import pickle
import operator
import matplotlib.pyplot as plt
from math import log

def createDataSet():
    dataSet = [[0, 0, 0, 0, 'no'],                      # each row: age, work, house, credit, class label
            [0, 0, 0, 1, 'no'],
            [0, 1, 0, 1, 'yes'],
            [0, 1, 1, 0, 'yes'],
            [0, 0, 0, 0, 'no'],
            [1, 0, 0, 0, 'no'],
            [1, 0, 0, 1, 'no'],
            [1, 1, 1, 1, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [1, 0, 1, 2, 'yes'],
            [2, 0, 1, 2, 'yes'],
            [2, 0, 1, 1, 'yes'],
            [2, 1, 0, 1, 'yes'],
            [2, 1, 0, 2, 'yes'],
            [2, 0, 0, 0, 'no']]
    labels = ['age', 'work', 'house', 'credit']     # feature labels
    return dataSet, labels  

def splitData(data, axis, value):
    # Return the subset of rows whose feature `axis` equals `value`,
    # with that feature column removed
    newData = []
    for vec in data:
        if vec[axis] == value:
            newData.append(vec[:axis] + vec[axis+1:])
    return newData

2. Computing the information entropy

def calcEntropy(data):
    # Shannon entropy (in bits) of the class labels in the last column
    row = len(data)
    label = {}
    for vec in data:
        current_label = vec[-1]
        if current_label not in label.keys():
            label[current_label] = 0
        label[current_label] += 1
    entropy = 0
    for key in label:
        prob = float(label[key]) / row
        entropy -= prob * log(prob, 2)
    return entropy
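As a quick check, on the loan dataset above (9 'yes' versus 6 'no') the entropy should come out to about 0.971 bits:

data, labels = createDataSet()
print(calcEntropy(data))    # -(9/15)*log2(9/15) - (6/15)*log2(6/15), about 0.971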

3. Choosing the best feature

def chooseFeature(data):
    # Select the feature with the largest information gain g(D, A) = H(D) - H(D|A)
    features = len(data[0]) - 1
    entropy = calcEntropy(data)                    # H(D)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(features):
        feature_list = set([example[i] for example in data])
        temp = 0.0
        for value in feature_list:
            subdata = splitData(data, i, value)
            prob = len(subdata) / len(data)
            temp += prob * calcEntropy(subdata)    # accumulate the conditional entropy H(D|A)
        info_gain = entropy - temp
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature
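On the loan dataset the gains should come out to roughly 0.083 (age), 0.324 (work), 0.420 (house), and 0.363 (credit), so the function should return index 2, the 'house' feature:

data, labels = createDataSet()
print(chooseFeature(data))    # 2, i.e. 'house' has the largest information gain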

4. Finding the majority class label

def majorityClass(class_list):
    # Return the most frequent class label in class_list
    class_count = {}
    for vote in class_list:
        if vote not in class_count.keys():
            class_count[vote] = 0
        class_count[vote] += 1
    sort_class = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sort_class[0][0]
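Equivalently, the standard library offers a one-liner (an alternative sketch, not part of the original listing):

from collections import Counter

def majorityClassAlt(class_list):
    # most_common(1) yields [(label, count)] for the most frequent label
    return Counter(class_list).most_common(1)[0][0]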

5. Building the decision tree

def createTree(data, labels, features):
    class_list = [example[-1] for example in data]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]                    # all samples share one class: make a leaf
    if len(data[0]) == 1 or len(labels) == 0:
        return majorityClass(class_list)        # no features left: take a majority vote
    best_feature = chooseFeature(data)
    best_label = labels[best_feature]
    features.append(best_label)                 # record the order in which features are used

    tree = {best_label: {}}
    del(labels[best_feature])

    feature_list = set([example[best_feature] for example in data])
    for value in feature_list:
        sublabels = labels[:]
        tree[best_label][value] = createTree(splitData(data, best_feature, value), sublabels, features)
    return tree
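On the loan dataset, createTree should yield the nested dictionary below: the root splits on 'house', and the house = 0 branch splits again on 'work' (dictionary key order may vary):

data, labels = createDataSet()
tree = createTree(data, labels, [])
print(tree)    # {'house': {0: {'work': {0: 'no', 1: 'yes'}}, 1: 'yes'}}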

6. Counting leaf nodes and measuring tree depth

def numLeafs(tree):
    # Count leaf nodes recursively
    leafs = 0
    string = next(iter(tree))
    tree_dict = tree[string]
    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__ == 'dict':
            leafs += numLeafs(tree_dict[key])
        else:
            leafs += 1
    return leafs

def treeDepth(tree):
    # Depth is the longest root-to-leaf path, counting splits
    max_depth = 0
    string = next(iter(tree))
    tree_dict = tree[string]
    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__ == 'dict':
            depth = 1 + treeDepth(tree_dict[key])
        else:
            depth = 1
        if depth > max_depth:
            max_depth = depth
    return max_depth
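For the loan tree built above, these helpers should report 3 leaves and a depth of 2:

print(numLeafs(tree), treeDepth(tree))    # 3 2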

7. Plotting the tree, its nodes, and edge labels

def plotNode(node_txt, centerPt, parent, node_type):
    # Draw a node box at centerPt with an arrow coming from its parent
    arrow = dict(arrowstyle='<-')
    createPlot.ax1.annotate(node_txt, xy=parent, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=node_type, arrowprops=arrow)
         
def plotText(cntr, parent, txt):
    # Place the edge label halfway between a node and its parent
    x_mid = (parent[0] - cntr[0]) / 2.0 + cntr[0]
    y_mid = (parent[1] - cntr[1]) / 2.0 + cntr[1]
    createPlot.ax1.text(x_mid, y_mid, txt, va='center', ha='center', rotation=30)
    
def createPlot(tree):
    # Set up the canvas and kick off the recursive tree drawing
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(numLeafs(tree))    # total leaves -> horizontal spacing
    plotTree.totalD = float(treeDepth(tree))   # total depth  -> vertical spacing
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(tree, (0.5, 1.0), '')
    plt.show()
    
def plotTree(tree, parent, txt):
    decision_node = dict(boxstyle='sawtooth', fc='0.8')
    leaf_node = dict(boxstyle='round4', fc='0.8')
    leafs = numLeafs(tree)
    string = next(iter(tree))
    # Center the decision node above the leaves it governs
    cntr = (plotTree.xOff + (1.0 + float(leafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotText(cntr, parent, txt)
    plotNode(string, cntr, parent, decision_node)
    tree_dict = tree[string]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD    # descend one level

    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__ == 'dict':
            plotTree(tree_dict[key], cntr, str(key))         # recurse into the subtree
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(tree_dict[key], (plotTree.xOff, plotTree.yOff), cntr, leaf_node)
            plotText((plotTree.xOff, plotTree.yOff), cntr, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD    # climb back up

8. Classifying with the tree and storing it

def classify(tree, labels, vec):
    # Walk the tree, following the branch that matches vec at each split
    string = next(iter(tree))
    tree_dict = tree[string]
    index = labels.index(string)

    for key in tree_dict.keys():
        if vec[index] == key:
            if type(tree_dict[key]).__name__ == 'dict':
                class_label = classify(tree_dict[key], labels, vec)
            else:
                class_label = tree_dict[key]
    return class_label

def storeTree(tree, filename):
    # Persist the tree to disk with pickle
    with open(filename, 'wb') as f:
        pickle.dump(tree, f)
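A matching loader (a small addition, not in the original listing) restores a saved tree so it can be reused without rebuilding:

def grabTree(filename):
    # Inverse of storeTree: unpickle a tree previously saved to disk
    with open(filename, 'rb') as f:
        return pickle.load(f)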

9. Main program

if __name__ == '__main__':
    data, labels = createDataSet()
    feature_labels = []                       # filled with features in the order they are chosen
    tree = createTree(data, labels, feature_labels)
    createPlot(tree)

    test_vec = [0, 1]                         # house = 0, work = 1, in feature_labels order
    result = classify(tree, feature_labels, test_vec)
    if result == 'yes':
        print('lending')
    if result == 'no':
        print('no lending')

The figure below shows the classification produced by this decision tree and illustrates how the tree operates.

[Figure: classification diagram for the hand-built tree]

2.2 Using the sklearn Library

The mglearn library can be used to illustrate the decision process with an animal-classification example:

import mglearn
mglearn.plots.plot_animal_tree()
[Figure: an animal-classification decision tree]

Applying DecisionTreeClassifier from scikit-learn to the breast cancer dataset gives the following results:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import export_graphviz
import graphviz

cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=42)

tree = DecisionTreeClassifier(max_depth=4,random_state=0)
tree.fit(X_train,y_train)

print('accuracy on training set:{:.3f}'.format(tree.score(X_train,y_train)))
print('accuracy on test set:{:.3f}'.format(tree.score(X_test,y_test)))

accuracy on training set:0.988
accuracy on test set:0.951
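For comparison, growing the tree without a depth limit drives training accuracy to 1.000 while test accuracy typically drops a little; pre-pruning via max_depth is what prevents this overfitting. A quick check (the exact test score depends on the split):

full_tree = DecisionTreeClassifier(random_state=0)
full_tree.fit(X_train, y_train)
print('accuracy on training set:{:.3f}'.format(full_tree.score(X_train, y_train)))   # 1.000
print('accuracy on test set:{:.3f}'.format(full_tree.score(X_test, y_test)))         # typically below the pruned tree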

Its decision process can be rendered as follows:

export_graphviz(tree, out_file='tree.dot', class_names=['malignant', 'benign'],
                feature_names=cancer.feature_names, impurity=False, filled=True)
with open('tree.dot') as f:
    dot_graph = f.read()

graphviz.Source(dot_graph)

[Figure: the fitted tree on the breast cancer dataset]

Likewise, the feature importances can be extracted and plotted:

n_features = cancer.data.shape[1]
plt.barh(range(n_features), tree.feature_importances_, align='center')
plt.yticks(np.arange(n_features), cancer.feature_names)
plt.xlabel('Feature importance')
plt.ylabel('Feature')
plt.show()

[Figure: distribution of feature importances]

3. Further Discussion

(Information) entropy, joint entropy, conditional entropy, relative entropy, and mutual information

Entropy measures the uncertainty of a random variable. Let X be a discrete random variable with probability mass function p(x)=P(X=x); its entropy is H(X)=-\sum_{x} p(x) \log p(x). When the logarithm has base 2 the unit is the bit; when the base is e the unit is the nat.
If (X, Y) \sim p(x, y), the joint entropy is H(X, Y)=-\sum_{x} \sum_{y} p(x, y) \log p(x, y)=-E \log p(X, Y).
Similarly, the conditional entropy is H(Y | X)=\sum_{x} p(x) H(Y | X=x)=-\sum_{x} p(x) \sum_{y} p(y | x) \log p(y | x)=-\sum_{x} \sum_{y} p(x, y) \log p(y | x)=-E \log p(Y | X).
Furthermore, the chain rule H(X, Y)=H(X)+H(Y | X) holds. Proof: H(X, Y)=-\sum_{x} \sum_{y} p(x, y) \log p(x, y)=-\sum_{x} \sum_{y} p(x, y) \log p(x) p(y | x)=-\sum_{x} \sum_{y} p(x, y) \log p(x)-\sum_{x} \sum_{y} p(x, y) \log p(y | x)=-\sum_{x} p(x) \log p(x)+H(Y | X)=H(X)+H(Y | X).
Relative entropy, also called the KL divergence (and distinct from the cross entropy H(p, q)=-\sum_{x} p(x) \log q(x), from which it differs by exactly the constant H(p)), measures the discrepancy between two distributions, though it is not a true metric since it is asymmetric. With true distribution p(x) and assumed distribution q(x): D(p \| q)=\sum_{x} p(x) \log \frac{p(x)}{q(x)}.
Mutual information measures how much information one random variable carries about another; equivalently, it is the reduction in the uncertainty of one variable once the other is known: I(X ; Y)=\sum_{x} \sum_{y} p(x, y) \log \frac{p(x, y)}{p(x) p(y)}=D(p(x, y) \| p(x) p(y)).
Relation between entropy and mutual information: I(X ; Y)=H(Y)-H(Y | X)=H(X)-H(X | Y). Proof: I(X ; Y)=\sum_{x, y} p(x, y) \log \frac{p(x, y)}{p(x) p(y)}=\sum_{x, y} p(x, y) \log \frac{p(x | y)}{p(x)}=-\sum_{x, y} p(x, y) \log p(x)-\left(-\sum_{x, y} p(x, y) \log p(x | y)\right)=H(X)-H(X | Y).
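A quick numeric sanity check of I(X ; Y)=H(Y)-H(Y | X), using a small hypothetical joint distribution:

from math import log2

# Hypothetical joint distribution p(x, y) for X, Y in {0, 1}
p = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.4}
px = {x: sum(v for (a, b), v in p.items() if a == x) for x in (0, 1)}
py = {y: sum(v for (a, b), v in p.items() if b == y) for y in (0, 1)}

H_Y  = -sum(v * log2(v) for v in py.values())                         # H(Y)
H_YX = -sum(v * log2(v / px[x]) for (x, y), v in p.items())           # H(Y|X)
I_XY = sum(v * log2(v / (px[x] * py[y])) for (x, y), v in p.items())  # I(X;Y)

print(abs(I_XY - (H_Y - H_YX)) < 1e-12)    # True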


I urge you to drain one more cup of wine: west of Yang Pass you will meet no old friends. — Wang Wei, "Seeing Off Yuan Er on a Mission to Anxi"
