機器學(xué)習(xí)——決策樹

決策樹(Decision Tree)ID3算法

概述

決策樹是一個預(yù)測模型;他代表的是對象屬性與對象值之間的一種映射關(guān)系。樹中每個節(jié)點表示某個對象,而每個分叉路徑則代表的某個可能的屬性值,而每個葉結(jié)點則對應(yīng)從根節(jié)點到該葉節(jié)點所經(jīng)歷的路徑所表示的對象的值。決策樹僅有單一輸出,若欲有復(fù)數(shù)輸出,可以建立獨立的決策樹以處理不同輸出。數(shù)據(jù)挖掘中決策樹是一種經(jīng)常要用到的技術(shù),可以用于分析數(shù)據(jù),同樣也可以用來作預(yù)測。

優(yōu)點:計算復(fù)雜度不高,輸出結(jié)果易于理解,對中間值的缺失不敏感,可以處理不想管特征數(shù)據(jù)
缺點:可能會產(chǎn)生過擬合問題
使用數(shù)據(jù)類型:數(shù)值型和標(biāo)稱型

相較于KNN,決策樹的主要優(yōu)勢在于數(shù)據(jù)形式非常容易理解

算法流程

  1. 收集數(shù)據(jù):可以使用任何方法
  2. 準(zhǔn)備數(shù)據(jù):樹構(gòu)造算法值適用于標(biāo)稱型數(shù)據(jù),因此數(shù)值型數(shù)據(jù)必須離散化
  3. 分析數(shù)據(jù):可以使用任何方法,構(gòu)造樹完成之后,我們應(yīng)該檢查圖形是否符合預(yù)期
  4. 訓(xùn)練算法:構(gòu)造樹的數(shù)據(jù)結(jié)構(gòu)
  5. 測試算法:使用經(jīng)驗樹計算錯誤率
  6. 使用算法:此步驟可以適用于任何監(jiān)督學(xué)習(xí)算法,而使用決策樹可以更好的理解數(shù)據(jù)的內(nèi)在含義

信息增益(information gain)

劃分?jǐn)?shù)據(jù)集的大原則是,將無序的數(shù)據(jù)變得更加有序。我們可以有多種方法劃分?jǐn)?shù)據(jù)集,但是每種方法都有各自的優(yōu)缺點。
在劃分?jǐn)?shù)據(jù)集之前之后信息發(fā)生的變化稱為信息增益,獲得信息增益最高的特征就是最好的選擇。

熵(entropy)

集合信息的度量方式稱為香農(nóng)熵或簡稱熵,這個名字來源于信息論之父克勞德·香農(nóng)。
熵定義為信息的期望值。如果待分類的事物可能劃分在多個分類之中,則符號xi的信息定義為


p(xi)是選擇該分類的概率。
為了計算熵,需要計算所有類別所有可能值包含的信息期望值,公式如下

n是分類的數(shù)目

from math import log

def calcShannonEnt(dataSet):
    '''
    計算給定數(shù)據(jù)集的熵
    '''
    # 獲取數(shù)據(jù)集示例數(shù)量
    numEntries = len(dataSet)
    # 構(gòu)造分類標(biāo)簽字典
    labelCounts = {}
    
    # 遍歷數(shù)據(jù)集,獲取分類標(biāo)簽數(shù)量
    for featVec in dataSet:
        curLable = featVec[-1]
        if curLable not in labelCounts.keys():
            labelCounts[curLable] = 0
        labelCounts[curLable] += 1
    # 遍歷分類標(biāo)簽,計算熵
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)
    
    return shannonEnt

海洋生物數(shù)據(jù),如下

不浮出水面是否可以生存 是否有腳蹼 屬于魚類
1
2
3
4
5

根據(jù)海洋生物數(shù)據(jù)構(gòu)造數(shù)據(jù)集

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

計算數(shù)據(jù)集的熵

dataSet, labels = createDataSet()
calcShannonEnt(dataSet)
[out] 0.9709505944546686

熵越大,則代表混合的數(shù)據(jù)越多。

dataSet[0][-1]='maybe' # 增加一個新的分類maybe
calcShannonEnt(dataSet)
[out] 1.3709505944546687

分類增加,導(dǎo)致熵變大。

劃分?jǐn)?shù)據(jù)集

分類算法除了需要測量數(shù)據(jù)集的無序程度,還需要劃分?jǐn)?shù)據(jù)集,度量劃分?jǐn)?shù)據(jù)集的無序程度,以便判斷當(dāng)前劃分是否正確。

def splitDataSet(dataSet, axis, value):
    '''
    按照給定特征劃分?jǐn)?shù)據(jù)集
    param dataSet: 待劃分的數(shù)據(jù)集
    param axis: 劃分?jǐn)?shù)據(jù)集的特征
    param value: 需要返回的特征的值
    '''
    retDataSet = []
    # 遍歷數(shù)據(jù)集,返回給定特征等于特定值的示例集
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

例如我們要以特征“不浮出水面是否可以生存”進行劃分,然后返回可以生存的示例

splitDataSet(dataSet, 0, 1)
[out] [[1, 'maybe'], [1, 'yes'], [0, 'no']]

現(xiàn)在需要通過計算熵,找到最好的劃分?jǐn)?shù)據(jù)的方式

def chooseBestFeatureToSplit(dataSet):
    '''
    選擇最好的數(shù)據(jù)劃分方式
    '''
    # 特征值數(shù)量
    numFeatures = len(dataSet[0]) - 1
    # 數(shù)據(jù)集劃分前的熵
    baseEntropy = calcShannonEnt(dataSet)
    # 最優(yōu)的信息增益
    bestInfoGain = 0.0
    # 最優(yōu)的數(shù)據(jù)劃分特征
    bestFeature = -1
    
    # 遍歷特征,對每個特征進行數(shù)據(jù)劃分,找到最優(yōu)信息增益的特征
    for i in range(numFeatures):
        # 創(chuàng)建唯一的特征值列表
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        # 數(shù)據(jù)劃分后的熵
        newEntropy = 0.0
        
        # 按照指定特征進行數(shù)據(jù)劃分,并計算數(shù)據(jù)劃分后的熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        
        # 找到最優(yōu)信息增益所對應(yīng)的特征值
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    
    return bestFeature

用海洋生物數(shù)據(jù)進行測試,發(fā)現(xiàn)第一次最優(yōu)的數(shù)據(jù)劃分特征是“不浮出水面是否可以生存”

chooseBestFeatureToSplit(dataSet)
[out] 0

遞歸構(gòu)造決策樹

由于特征值可能不止一個,因此存在大于兩個分支的數(shù)據(jù)集劃分。劃分一次后,數(shù)據(jù)將被向下傳遞到樹分支節(jié)點,進行再次劃分。因此可以采用遞歸的原則處理數(shù)據(jù)。
偽代碼如下:
if 類別相同
??return 該類別
elif 遍歷完所有特征
??return 返回數(shù)量最多的類別
elif
??尋找劃分?jǐn)?shù)據(jù)的最好特征
??劃分?jǐn)?shù)據(jù)集
??創(chuàng)建分支節(jié)點
??for 每個劃分的子集
????調(diào)用函數(shù)createTree并增加返回結(jié)果到分支節(jié)點中
??return 分支節(jié)點

import operator

def majorityCnt(classList):
    '''
    獲取次數(shù)最多的分類名稱
    '''
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    
    sortedClassCount = sorted(classCount.iteritems(),
                             key=operator.itemgetter(1),
                             reversed=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    # 數(shù)據(jù)集的所有分類
    classList = [example[-1] for example in dataSet]
    # 類別完全相同則停止劃分
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 遍歷完所有特征時返回次數(shù)最多的類別
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    
    # 選擇數(shù)據(jù)劃分最優(yōu)特征并構(gòu)建樹
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    
    # 劃分?jǐn)?shù)據(jù)集,創(chuàng)建分支節(jié)點,并遞歸分支節(jié)點
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
        
    return myTree

用海洋生物數(shù)據(jù)集進行測試??梢园l(fā)現(xiàn)返回值是一個嵌套的字典類型。如果字典的值是數(shù)據(jù)字典,代表這是一個分支節(jié)點;如果字典的值是一個特定值,那么代表這是一個葉節(jié)點。

dataSet, labels = createDataSet()
createTree(dataSet, labels)
[out] {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

使用文本注解繪制樹節(jié)點

# -.- coding:utf-8 -.-
import matplotlib.pyplot as plt
import matplotlib as mpl

# 繪圖中文顯示
mpl.rcParams['font.sans-serif'] = ['KaiTi']
mpl.rcParams['font.serif'] = ['KaiTi']

# 定義文本框和箭頭格式
desisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    '''
    繪制帶箭頭的注解
    param nodeTxt: 注解
    param centerPt: 箭頭所在坐標(biāo)
    param parentPt: 箭尾所在坐標(biāo)
    param nodeType: 節(jié)點樣式
    '''
    createPlot.ax1.annotate(nodeTxt, 
                            xy=parentPt, 
                            xycoords='axes fraction', 
                            xytext=centerPt, 
                            textcoords='axes fraction', 
                            va='center', 
                            ha='center', 
                            bbox=nodeType, 
                            arrowprops=arrow_args)
    
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode(u'決策節(jié)點', (0.5, 0.1), (0.1, 0.5), desisionNode)
    plotNode(u'葉節(jié)點', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

測試注解方法

createPlot()

構(gòu)造注解樹

首先需要知道有多少個葉節(jié)點,以便確定x軸的長度;還需要知道樹有多少層,以便確定y軸的高度。

def getNumLeafs(myTree):
    '''
    獲取葉節(jié)點數(shù)量
    '''
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    # 如果子節(jié)點為字典,繼續(xù)遞歸,否則葉節(jié)點數(shù)量加1
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
            
    return numLeafs

def getTreeDepth(myTree):
    '''
    獲取樹的深度
    '''
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    # 遞歸計算子節(jié)點最大深度
    for key in secondDict.keys():
        if type(secondDict[key]).__name__== 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    
    return maxDepth

def retrieveTree(i):
    '''
    測試用樹信息,樹列表中包含了兩顆樹
    '''
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

測試方法

print getNumLeafs(retrieveTree(0))
print getTreeDepth(retrieveTree(0))
[out] 3
[out] 2

構(gòu)造樹形注解

def plotMidText(cntrPt, parentPt, txtString):
    '''
    父子節(jié)點間填充文本
    '''
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
    
def plotTree(myTree, parentPt, nodeTxt):
    # 計算樹的高和寬
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, desisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    # 遞歸畫出分支節(jié)點
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # 計算樹的高度與寬度,并保存與全局變量中
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
myTree = retrieveTree(1)
createPlot(myTree)

打印結(jié)果如下


使用決策樹進行分類

def classify(inputTree, featLabels, testVec):
    '''
    決策樹的分類函數(shù)
    param inputTree: 訓(xùn)練用樹決策樹
    param featLables: 訓(xùn)練用分類標(biāo)簽
    param testVec: 用于分類的輸入向量
    '''
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    # 根節(jié)點對應(yīng)特征標(biāo)簽列表的索引值
    featIndex = featLabels.index(firstStr)
    # 遞歸遍歷樹,比較testVec變量中的值與樹節(jié)點的值,如果達到葉子節(jié)點,則返回當(dāng)前節(jié)點的分類標(biāo)簽
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    
    return classLabel

輸入測試數(shù)據(jù)[1, 0],即不浮出水面可以生存, 沒有腳蹼的海洋生物,根據(jù)決策樹分類結(jié)果為不屬于魚類

dataSet, labels = createDataSet()
myTree = retrieveTree(0)
print labels
print myTree
print classify(myTree, labels, [1, 0])
[out] ['no surfacing', 'flippers']
[out] {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
[out] no

決策樹的存儲

構(gòu)造決策樹是非常耗時的任務(wù),然后用構(gòu)建耗的決策樹解決分類問題,則可以很快完成。因此,在每次執(zhí)行分類時最好調(diào)用已經(jīng)構(gòu)造好的決策樹。為了保存決策樹,可以使用pickle序列化對象,將其保存在磁盤中,并在需要的時候讀取出來。

def storeTree(inputTree, filename):
    '''
    序列化對象并保存至指定路徑
    '''
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()
    
def grabTree(filename):
    '''
    讀取文件并反序列化
    '''
    import pickle
    fr = open(filename)
    return pickle.load(fr)

序列化并存儲決策樹。執(zhí)行代碼后,可以看到文件夾中多了一個classifierStorage.txt的文件,打開后會看到序列化對象后的字符串

storeTree(myTree, 'classifierStorage.txt')

讀取文件并反序列化,可以看到?jīng)Q策樹被正確讀取出來

grabTree('classifierStorage.txt')
[out] {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

通過對決策樹的存儲,不用在每次對數(shù)據(jù)分類時重新學(xué)習(xí)一遍,這也是決策樹的優(yōu)點之一。而k近鄰算法就無法持久化分類器。

示例:使用決策樹預(yù)測隱形眼鏡類型

  1. 收集數(shù)據(jù):提供的文本文件
  2. 準(zhǔn)備數(shù)據(jù):解析tab分割的數(shù)據(jù)行
  3. 分析數(shù)據(jù):快速檢查數(shù)據(jù),確保正確的解析數(shù)據(jù)內(nèi)容,使用createPlot()函數(shù)繪制最終的屬性圖
  4. 訓(xùn)練算法:使用createTree()函數(shù)
  5. 測試算法:編寫測試函數(shù)驗證決策樹可以正確分類給定的數(shù)據(jù)實例
  6. 使用算法:存儲樹的數(shù)據(jù)結(jié)構(gòu),以便下次使用時無需重新構(gòu)造樹

隱形眼鏡數(shù)據(jù)文件下載

# 讀取數(shù)據(jù)文件
fr = open('lenses.txt')
# 解析tab分割符
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
# 構(gòu)造特征標(biāo)簽
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
# 訓(xùn)練算法
lensesTree = createTree(lenses, lensesLabels)

打印決策樹看一下,發(fā)現(xiàn)已經(jīng)成功構(gòu)造

lensesTree
[out] {'tearRate': {'normal': {'astigmatic': {'no': {'age': {'pre': 'soft',
      'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}},
      'young': 'soft'}},
    'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses',
        'presbyopic': 'no lenses',
        'young': 'hard'}},
      'myope': 'hard'}}}},
  'reduced': 'no lenses'}}

圖形化顯示決策樹,更能直觀的看出劃分情況。沿著決策樹的不同分支,就能得到不同患者需要佩戴的隱形眼鏡類型。

createPlot(lensesTree)

可以看出該決策樹非常好的匹配了實驗數(shù)據(jù),但是匹配項可能太多了,會造成過擬合。為了減少過度匹配的問題,可以裁剪決策樹,去掉一些不必要的葉子節(jié)點。

總結(jié)

ID3算法無法直接處理數(shù)值型數(shù)據(jù),可以用戶劃分標(biāo)稱型數(shù)據(jù)集。構(gòu)造決策樹時,通常使用遞歸的方法將數(shù)據(jù)集轉(zhuǎn)化為決策樹。
除了ID3算法以外,還有其他決策樹的構(gòu)造算法,最流行的是C4.5和CART

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容