1. Basic Principles
1.1 Idea and Procedure
A decision tree is a common machine learning method for classification and regression; it is a discriminative model with a tree-shaped structure. A decision tree can be viewed as a set of mutually exclusive and collectively exhaustive if-then rules. It can also be read as the conditional probability distribution of the class given the features: such a distribution partitions the feature space into disjoint cells (or regions), and assigning a class probability distribution to each cell yields the conditional distribution. It can be written as $P(Y \mid X)$, where $X$ is the random variable ranging over features and $Y$ is the random variable ranging over classes.

Decision tree learning is essentially the induction of a set of classification rules from the training data; the induced rules should contradict the training data as little as possible while still generalizing well. The loss function is usually a regularized maximum-likelihood function, and the goal of learning is to minimize it. The learning algorithm typically proceeds recursively, selecting the optimal feature and splitting the training data by that feature so that each resulting subset is classified as well as possible. Its basic procedure is as follows:

1. If all samples at the current node belong to the same class, mark the node as a leaf of that class.
2. If no features remain (or all samples take identical values on the remaining features), mark the node as a leaf of the majority class.
3. Otherwise select the optimal splitting feature, partition the samples by its values, and recurse on each subset; this is exactly what createTree implements in Section 2.1.

1.2、特征選擇
Feature selection means choosing features that have discriminating power on the training data; it is the key step of decision tree learning. Common selection criteria are information gain and the information gain ratio.
The information gain is defined as
$$g(D, A) = H(D) - H(D \mid A),$$
where $A$ is a feature, $D$ is the training data set, $H(D)$ is the information entropy of $D$, and $H(D \mid A)$ is the conditional entropy of $D$ given $A$. The information entropy of a random variable $X$ with $P(X = x_i) = p_i$ is
$$H(X) = -\sum_{i=1}^{n} p_i \log p_i,$$
and the conditional entropy $H(Y \mid X)$ is the expectation, over $X$, of the entropy of the conditional distribution of $Y$ given $X$:
$$H(Y \mid X) = \sum_{i=1}^{n} p_i \, H(Y \mid X = x_i).$$
The information gain ratio is the information gain divided by an entropy term:
$$g_R(D, A) = \frac{g(D, A)}{H_A(D)},$$
where $H_A(D)$ is the entropy of $D$ computed with respect to the values of feature $A$.
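As a quick worked example on the loan-application data set used in Section 2.1 (15 samples: 9 'yes' and 6 'no'; the age feature takes three values, each covering 5 samples):
$$H(D) = -\tfrac{9}{15}\log_2\tfrac{9}{15} - \tfrac{6}{15}\log_2\tfrac{6}{15} \approx 0.971,$$
$$H(D \mid \text{age}) = \tfrac{5}{15}\times 0.971 + \tfrac{5}{15}\times 0.971 + \tfrac{5}{15}\times 0.722 \approx 0.888,$$
so $g(D, \text{age}) = 0.971 - 0.888 = 0.083$. Repeating this for the other features gives the largest gain for the 'house' feature, which is therefore chosen as the root split.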
1.3 Pruning
Pruning is the main tool by which decision trees cope with overfitting. Concretely, pruning removes some subtrees or leaf nodes from the grown tree and turns their root or parent node into a new leaf, thereby simplifying the model. Pruning comes in two flavors, pre-pruning and post-pruning. Pre-pruning evaluates each node before it is split during tree growth; if the split does not improve the tree's generalization performance, the split is abandoned and the node is marked as a leaf. Post-pruning first grows a complete tree from the training set and then examines the non-leaf nodes bottom-up; if replacing a node's subtree with a leaf improves generalization performance, the subtree is replaced by that leaf.
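For concreteness, here is a minimal sketch of both styles using scikit-learn (the library used in Section 2.2): pre-pruning via constructor limits such as max_depth, and post-pruning via cost-complexity pruning. This assumes scikit-learn >= 0.22, which added cost_complexity_pruning_path; the parameter values are illustrative only, not tuned.

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Pre-pruning: cap growth while the tree is being built
pre = DecisionTreeClassifier(max_depth=4, min_samples_leaf=5, random_state=0)
pre.fit(X_train, y_train)

# Post-pruning: grow a full tree, then cut subtrees by cost-complexity pruning;
# cost_complexity_pruning_path returns the candidate alpha values
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]   # illustrative choice; select by validation in practice
post = DecisionTreeClassifier(ccp_alpha=alpha, random_state=0)
post.fit(X_train, y_train)

print(pre.score(X_test, y_test), post.score(X_test, y_test))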
2. Algorithm Implementation
2.1 Manual Implementation
Step 1: Module imports and data generation
import pickle
import operator
import matplotlib.pyplot as plt
from math import log

def createDataSet():
    # Loan-application data set: each row is [age, work, house, credit, class]
    dataSet = [[0, 0, 0, 0, 'no'],
               [0, 0, 0, 1, 'no'],
               [0, 1, 0, 1, 'yes'],
               [0, 1, 1, 0, 'yes'],
               [0, 0, 0, 0, 'no'],
               [1, 0, 0, 0, 'no'],
               [1, 0, 0, 1, 'no'],
               [1, 1, 1, 1, 'yes'],
               [1, 0, 1, 2, 'yes'],
               [1, 0, 1, 2, 'yes'],
               [2, 0, 1, 2, 'yes'],
               [2, 0, 1, 1, 'yes'],
               [2, 1, 0, 1, 'yes'],
               [2, 1, 0, 2, 'yes'],
               [2, 0, 0, 0, 'no']]
    labels = ['age', 'work', 'house', 'credit']   # feature labels
    return dataSet, labels

def splitData(data, axis, value):
    # Keep the rows whose feature `axis` equals `value`, dropping that column
    newData = []
    for vec in data:
        if vec[axis] == value:
            newData.append(vec[:axis] + vec[axis+1:])
    return newData
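For example, splitting on feature 0 (age) with value 0 keeps the five age = 0 rows and drops the age column:

data, labels = createDataSet()
print(splitData(data, 0, 0))
# [[0, 0, 0, 'no'], [0, 0, 1, 'no'], [1, 0, 1, 'yes'], [1, 1, 0, 'yes'], [0, 0, 0, 'no']]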
Step 2: Computing the information entropy
def calcEntropy(data):
    # Shannon entropy of the class labels in `data` (last column of each row)
    row = len(data)
    label = {}
    for vec in data:
        current_label = vec[-1]
        if current_label not in label.keys():
            label[current_label] = 0
        label[current_label] += 1
    entropy = 0
    for key in label:
        prob = float(label[key]) / row
        entropy -= prob * log(prob, 2)   # base-2 log: entropy in bits
    return entropy
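On the full data set (9 'yes' against 6 'no') this reproduces the value worked out in Section 1.2:

data, _ = createDataSet()
print(calcEntropy(data))   # about 0.971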
Step 3: Choosing the optimal feature
def chooseFeature(data):
    # Select the feature with the largest information gain g(D, A) = H(D) - H(D|A)
    features = len(data[0]) - 1
    entropy = calcEntropy(data)               # H(D)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(features):
        feature_list = set([example[i] for example in data])
        temp = 0.0                            # accumulates H(D|A)
        for value in feature_list:
            subdata = splitData(data, i, value)
            prob = len(subdata) / len(data)
            temp += prob * calcEntropy(subdata)   # entropy of the subset, not of the whole set
        info_gain = entropy - temp
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature
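On this data set the 'house' feature has the largest information gain (0.420, versus 0.083 for 'age' as computed in Section 1.2), so it is selected first:

data, labels = createDataSet()
print(chooseFeature(data))   # 2, i.e. labels[2] == 'house'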
Step 4: Finding the majority class label
def majorityClass(class_list):
    # Return the most frequent class label in `class_list`
    class_count = {}
    for vote in class_list:
        if vote not in class_count.keys():
            class_count[vote] = 0
        class_count[vote] += 1
    sort_class = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sort_class[0][0]
Step 5: Building the decision tree
def createTree(data, labels, features):
    class_list = [example[-1] for example in data]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]                  # all samples share one class: leaf
    if len(data[0]) == 1 or len(labels) == 0:
        return majorityClass(class_list)      # no features left: majority vote
    best_feature = chooseFeature(data)
    best_label = labels[best_feature]
    features.append(best_label)               # record the order in which features are used
    tree = {best_label: {}}
    del(labels[best_feature])
    feature_list = set([example[best_feature] for example in data])
    for value in feature_list:
        sublabels = labels[:]
        tree[best_label][value] = createTree(splitData(data, best_feature, value), sublabels, features)
    return tree
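On the loan data the result is a two-level tree that splits on 'house' first and 'work' second (key order in the printed dict may vary):

data, labels = createDataSet()
tree = createTree(data, labels, [])
print(tree)   # {'house': {0: {'work': {0: 'no', 1: 'yes'}}, 1: 'yes'}}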
Step 6: Counting the leaves and the depth of the tree
def numLeafs(tree):
    # Count the leaf nodes of a nested-dict tree
    leafs = 0
    string = next(iter(tree))                 # feature name at this node
    tree_dict = tree[string]
    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__ == 'dict':
            leafs += numLeafs(tree_dict[key])
        else:
            leafs += 1
    return leafs

def treeDepth(tree):
    # Number of split levels along the longest root-to-leaf path
    max_depth = 0
    string = next(iter(tree))
    tree_dict = tree[string]
    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__ == 'dict':
            depth = 1 + treeDepth(tree_dict[key])
        else:
            depth = 1
        if depth > max_depth:
            max_depth = depth
    return max_depth
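For the tree built above these return 3 leaves and a depth of 2:

tree = {'house': {0: {'work': {0: 'no', 1: 'yes'}}, 1: 'yes'}}
print(numLeafs(tree), treeDepth(tree))   # 3 2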
Step 7: Plotting the tree: node, edge, and text helpers
def plotNode(node_txt, centerPt, parent, node_type):
    # Draw a node box at centerPt with an arrow coming from its parent
    arrow = dict(arrowstyle='<-')
    createPlot.ax1.annotate(node_txt, xy=parent, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=node_type, arrowprops=arrow)

def plotText(cntr, parent, txt):
    # Label the edge between a node and its parent with the feature value
    x_mid = (parent[0] - cntr[0]) / 2.0 + cntr[0]
    y_mid = (parent[1] - cntr[1]) / 2.0 + cntr[1]
    createPlot.ax1.text(x_mid, y_mid, txt, va='center', ha='center', rotation=30)

def createPlot(tree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(numLeafs(tree))   # width of the figure, in leaves
    plotTree.totalD = float(treeDepth(tree))  # height of the figure, in levels
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(tree, (0.5, 1.0), '')
    plt.show()

def plotTree(tree, parent, txt):
    decision_node = dict(boxstyle='sawtooth', fc='0.8')
    leaf_node = dict(boxstyle='round4', fc='0.8')
    leafs = numLeafs(tree)
    string = next(iter(tree))
    cntr = (plotTree.xOff + (1.0 + float(leafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotText(cntr, parent, txt)
    plotNode(string, cntr, parent, decision_node)
    tree_dict = tree[string]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD   # descend one level
    for key in tree_dict.keys():
        if type(tree_dict[key]).__name__ == 'dict':
            plotTree(tree_dict[key], cntr, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(tree_dict[key], (plotTree.xOff, plotTree.yOff), cntr, leaf_node)
            plotText((plotTree.xOff, plotTree.yOff), cntr, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD   # back up one level
Step 8: Classifying with the tree and storing it
def classify(tree, labels, vec):
    # Walk the tree, following the branch that matches vec at each split
    string = next(iter(tree))
    tree_dict = tree[string]
    index = labels.index(string)              # position of this feature in vec
    for key in tree_dict.keys():
        if vec[index] == key:
            if type(tree_dict[key]).__name__ == 'dict':
                class_label = classify(tree_dict[key], labels, vec)
            else:
                class_label = tree_dict[key]
    return class_label

def storeTree(tree, filename):
    # Serialize the tree to disk with pickle
    with open(filename, 'wb') as f:
        pickle.dump(tree, f)
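The original listing only stores the tree; a natural pickle counterpart for reading it back (loadTree is a name introduced here, not part of the original code) would be:

def loadTree(filename):
    # Deserialize a tree previously saved with storeTree
    with open(filename, 'rb') as f:
        return pickle.load(f)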
Step 9: Main routine
if __name__ == '__main__':
    data, labels = createDataSet()
    feature_labels = []                       # filled by createTree with the split order
    tree = createTree(data, labels, feature_labels)
    createPlot(tree)
    test_vec = [0, 1]                         # values for feature_labels, here ['house', 'work']
    result = classify(tree, feature_labels, test_vec)
    if result == 'yes':
        print('lending')
    elif result == 'no':
        print('no lending')
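With the data above, feature_labels ends up as ['house', 'work'], so test_vec = [0, 1] describes an applicant with no house who has a job; the tree classifies the sample as 'yes' and the script prints:

lending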
The figure below shows the tree produced by the code above, illustrating how the classification tree operates.

2.2 Using the sklearn Library
The mglearn library can be used to visualize the process of classifying animals with a decision tree:
import mglearn
mglearn.plots.plot_animal_tree()

Applying DecisionTreeClassifier from sklearn to the breast cancer data set gives the following results:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import export_graphviz
import graphviz
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(cancer.data,cancer.target,stratify=cancer.target,random_state=42)
tree = DecisionTreeClassifier(max_depth=4,random_state=0)
tree.fit(X_train,y_train)
print('accuracy on training set:{:.3f}'.format(tree.score(X_train,y_train)))
print('accuracy on test set:{:.3f}'.format(tree.score(X_test,y_test)))
accuracy on training set:0.988
accuracy on test set:0.951
The classification process can be rendered as follows:
export_graphviz(tree, out_file='tree.dot', class_names=['malignant', 'benign'],
                feature_names=cancer.feature_names, impurity=False, filled=True)
with open('tree.dot') as f:
    dot_graph = f.read()
graphviz.Source(dot_graph)

Likewise, the feature importances can be extracted and plotted:
n_features = cancer.data.shape[1]
plt.barh(range(n_features), tree.feature_importances_, align='center')
plt.yticks(np.arange(n_features), cancer.feature_names)
plt.xlabel('Feature importance')
plt.ylabel('Feature')
plt.show()

3. Further Discussion
(Information) entropy, joint entropy, conditional entropy, relative entropy, and mutual information
Entropy is a measure of the uncertainty of a random variable. Let $X$ be a discrete random variable with probability mass function $p(x) = P(X = x)$; its information entropy is
$$H(X) = -\sum_{x} p(x)\log p(x).$$
When the logarithm has base 2 the unit is the bit; when the base is $e$ the unit is the nat.
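For example, a fair coin flip has entropy $H(X) = -\tfrac{1}{2}\log_2\tfrac{1}{2} - \tfrac{1}{2}\log_2\tfrac{1}{2} = 1$ bit; measured with the natural logarithm, the same distribution has entropy $\ln 2 \approx 0.693$ nat.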
If $(X, Y)$ is a pair of random variables with joint distribution $p(x, y)$, their joint entropy is
$$H(X, Y) = -\sum_{x}\sum_{y} p(x, y)\log p(x, y).$$
Similarly, the conditional entropy is
$$H(Y \mid X) = \sum_{x} p(x)\,H(Y \mid X = x) = -\sum_{x}\sum_{y} p(x, y)\log p(y \mid x).$$
In addition, the chain rule $H(X, Y) = H(X) + H(Y \mid X)$ holds. The proof is as follows:
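Expanding the joint distribution with $p(x, y) = p(x)\,p(y \mid x)$ and marginalizing:
$$\begin{aligned} H(X, Y) &= -\sum_{x}\sum_{y} p(x, y)\log p(x, y) \\ &= -\sum_{x}\sum_{y} p(x, y)\log p(x) - \sum_{x}\sum_{y} p(x, y)\log p(y \mid x) \\ &= -\sum_{x} p(x)\log p(x) + H(Y \mid X) \\ &= H(X) + H(Y \mid X). \end{aligned}$$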
Relative entropy (also known as the KL divergence) is a measure of the distance between two probability distributions; the closely related cross entropy satisfies $H(p, q) = H(p) + D(p \,\|\, q)$, so the two terms are sometimes conflated. When the true distribution is $p$ and the assumed distribution is $q$, the relative entropy is
$$D(p \,\|\, q) = \sum_{x} p(x)\log\frac{p(x)}{q(x)}.$$
Mutual information measures the amount of information one random variable carries about another; equivalently, it is the reduction in the uncertainty of one variable given knowledge of the other:
$$I(X; Y) = \sum_{x}\sum_{y} p(x, y)\log\frac{p(x, y)}{p(x)\,p(y)} = H(X) - H(X \mid Y).$$
The relation between entropy and mutual information is $I(X; Y) = H(X) + H(Y) - H(X, Y)$. The proof is as follows:
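Writing $\frac{p(x, y)}{p(x)\,p(y)} = \frac{p(x \mid y)}{p(x)}$:
$$\begin{aligned} I(X; Y) &= \sum_{x}\sum_{y} p(x, y)\log\frac{p(x \mid y)}{p(x)} \\ &= -\sum_{x}\sum_{y} p(x, y)\log p(x) + \sum_{x}\sum_{y} p(x, y)\log p(x \mid y) \\ &= H(X) - H(X \mid Y), \end{aligned}$$
and combining this with the chain rule $H(X \mid Y) = H(X, Y) - H(Y)$ gives $I(X; Y) = H(X) + H(Y) - H(X, Y)$.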
I urge you to drain one more cup of wine: west of Yang Pass you will meet no old friends. — Wang Wei, "Seeing Off Yuan Er on a Mission to Anxi"