Machine Learning in Action: Decision Trees

The Decision Tree Algorithm

Advantages

Low computational complexity; output that is easy to interpret; insensitive to missing intermediate values; can handle irrelevant feature data.

Disadvantages

Prone to overfitting.

Applicable Data Types

Numeric and nominal data.


Information Gain

The guiding principle when splitting data is to make disordered data more ordered. One way to organize messy data is to measure information using information theory.

The change in information before and after splitting a dataset is called the information gain. By computing the information gain obtained from splitting the dataset on each feature, we can measure whether the split makes the data more ordered; the feature that yields the largest information gain is the best feature for the split. That is, after splitting on that feature, the data has been classified as correctly as possible.

The measure of the information in a dataset is called Shannon entropy. Information gain is simply the difference in Shannon entropy before and after the split.

Entropy is defined as the expected value of information. If the items to be classified can fall into multiple classes, the information of symbol x_i is defined as:

    l(x_i) = -log2 p(x_i)

where p(x_i) is the probability of choosing that class.

Entropy is then the expectation of this information over all n classes:

    H = -sum_{i=1..n} p(x_i) * log2 p(x_i)

Implementation:

from math import log
from collections import defaultdict, Counter

### Sample dataset
def createDataset():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    featLabels = ['no surfacing','flippers']
    return dataSet, featLabels

def calcShannonEnt(dataset):
    ### Compute the Shannon entropy of a dataset, a measure of its disorder:
    ### the more classes and the more mixed the labels, the larger the entropy
    numEntries = len(dataset)
    labels = defaultdict(int) 
    for featVec in dataset:
        label = featVec[-1]
        labels[label] += 1
    shannonEnt = 0.0
    for v in labels.values():
        prob = float(v) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
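
As a sanity check, the entropy of the sample labels can be recomputed directly (a standalone sketch that does not depend on calcShannonEnt): two 'yes' and three 'no' labels give roughly 0.971 bits.

```python
from math import log
from collections import Counter

# Class labels of the sample dataset above: two 'yes', three 'no'
labels = ['yes', 'yes', 'no', 'no', 'no']
counts = Counter(labels)
n = len(labels)
# H = -sum over classes of p * log2(p)
entropy = -sum((c / n) * log(c / n, 2) for c in counts.values())
print(round(entropy, 3))  # 0.971
```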

Once the Shannon entropy can be computed, the best feature to classify on is found by computing the information gain of each candidate split.

def splitDataset(dataset, axis, value):
    ### Split the dataset on a feature value. axis: feature index; value: the feature value to match
    retDataset = []
    for featVec in dataset:
        if featVec[axis] == value:
            reducedFeatvec = featVec[:axis]
            reducedFeatvec.extend(featVec[axis+1:])
            retDataset.append(reducedFeatvec)
    return retDataset
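
A quick check of the split logic (the function is repeated here, condensed to a comprehension, so the snippet runs on its own): splitting the sample dataset on feature 0 with value 1 keeps the three matching rows and drops column 0.

```python
def splitDataset(dataset, axis, value):
    # Keep rows whose feature `axis` equals `value`, with that column removed
    return [row[:axis] + row[axis + 1:] for row in dataset if row[axis] == value]

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(splitDataset(dataSet, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
```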

def chooseBestFeatureToSplit(dataset):
    ### Find the best feature to split the current dataset: for each feature, compute the
    ### weighted entropy of the resulting subsets and pick the largest information gain
    numFeats = len(dataset[0]) - 1
    baseEntropy = calcShannonEnt(dataset)
    bestInfoGain = 0.0
    bestFeat = -1
    for i in range(numFeats):
        featValues = set([data[i] for data in dataset])
        newEntropy = 0.0
        for value in featValues:
            resDataset = splitDataset(dataset, i, value)
            prob = len(resDataset) / float(len(dataset))
            newEntropy += prob * calcShannonEnt(resDataset)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeat = i
    return bestFeat  ## return the feature index
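
Worked out by hand for the sample dataset (a standalone sketch; the helper `H` exists only in this snippet): splitting on 'no surfacing' gains about 0.420 bits versus about 0.171 bits for 'flippers', so chooseBestFeatureToSplit returns index 0.

```python
from math import log

def H(ps):
    # Shannon entropy of a probability distribution
    return -sum(p * log(p, 2) for p in ps if p > 0)

base = H([2 / 5, 3 / 5])  # entropy before any split, ~0.971
# Feature 0 ('no surfacing'): value 1 -> {yes, yes, no}, value 0 -> {no, no}
gain0 = base - (3 / 5) * H([2 / 3, 1 / 3]) - (2 / 5) * H([1.0])
# Feature 1 ('flippers'): value 1 -> {yes, yes, no, no}, value 0 -> {no}
gain1 = base - (4 / 5) * H([1 / 2, 1 / 2]) - (1 / 5) * H([1.0])
print(round(gain0, 3), round(gain1, 3))  # 0.42 0.171
```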

Recursively Building the Decision Tree

At each split, the best feature is used so that, as far as possible, samples of the same class end up in the same group.

Example

no surfacing   flippers   is_fish
     1            1         yes
     1            1         yes
     1            0         no
     0            1         no
     0            1         no

Its decision tree (figure omitted): the root splits on 'no surfacing'; value 0 leads to 'no', and value 1 splits further on 'flippers' (0 leads to 'no', 1 to 'yes').


Tree-construction code:

def majorityCnt(classList):
    ### Return the most frequent class label. classList: list of label values
    return Counter(classList).most_common()[0][0]

def createTree(dataset, inputLabels):
    #### inputLabels: the feature label names
    labels = inputLabels[:] ## copy so the caller's label list is not modified
    classList = [data[-1] for data in dataset]
    ### all labels in the current dataset are identical: classification is done, return that label
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    ### every feature has already been used for splitting: return the majority label
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    bestFeatIndex = chooseBestFeatureToSplit(dataset)
    bestFeatValues = set([data[bestFeatIndex] for data in dataset])
    bestFeatLabel = labels[bestFeatIndex]
    del labels[bestFeatIndex] ### splitting removes this feature column, so remove its label as well
    trees = {bestFeatLabel:{}}
    for value in bestFeatValues:
        new_labels = labels[:] ### a fresh list is required here; passing a reference would let the recursive call mutate labels
        trees[bestFeatLabel][value] = createTree(splitDataset(dataset, bestFeatIndex, value), new_labels)
    return trees
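
Putting the pieces together (a condensed, standalone re-implementation equivalent to the functions above; names like `build` and `best_feature` exist only in this sketch), running on the sample dataset yields the nested-dict tree described earlier.

```python
from math import log
from collections import Counter

def entropy(rows):
    # Shannon entropy of the class labels in `rows`
    n = len(rows)
    counts = Counter(row[-1] for row in rows)
    return -sum((c / n) * log(c / n, 2) for c in counts.values())

def split(rows, axis, value):
    # Rows matching `value` on feature `axis`, with that column removed
    return [row[:axis] + row[axis + 1:] for row in rows if row[axis] == value]

def best_feature(rows):
    # Index of the feature with the largest information gain
    base = entropy(rows)
    gains = []
    for i in range(len(rows[0]) - 1):
        values = set(row[i] for row in rows)
        new = sum(len(split(rows, i, v)) / len(rows) * entropy(split(rows, i, v))
                  for v in values)
        gains.append(base - new)
    return max(range(len(gains)), key=gains.__getitem__)

def build(rows, labels):
    classes = [row[-1] for row in rows]
    if classes.count(classes[0]) == len(classes):  # pure node: stop
        return classes[0]
    if len(rows[0]) == 1:                          # features exhausted: majority vote
        return Counter(classes).most_common(1)[0][0]
    i = best_feature(rows)
    rest = labels[:i] + labels[i + 1:]
    return {labels[i]: {v: build(split(rows, i, v), rest)
                        for v in set(row[i] for row in rows)}}

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(build(dataSet, ['no surfacing', 'flippers']))
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
```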

Plotting the Decision Tree

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeText, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeText, xy=parentPt, xycoords='axes fraction', \
        xytext=centerPt, textcoords='axes fraction', va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)

def getNumLeafs(trees):
    numLeafs = 0
    firstStr = list(trees.keys())[0]  ## list() needed: dict.keys() is not indexable in Python 3
    secondDict = trees[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key],dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(trees):
    maxDepth = 0
    firstStr = list(trees.keys())[0]
    secondDict = trees[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
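
For the sample tree, these helpers should report 3 leaves and a depth of 2. A standalone check (both functions are restated here in condensed form, with the Python 3 list() fix, so the snippet runs on its own):

```python
def getNumLeafs(trees):
    # Count leaf nodes: non-dict children are leaves
    firstStr = list(trees.keys())[0]
    return sum(getNumLeafs(child) if isinstance(child, dict) else 1
               for child in trees[firstStr].values())

def getTreeDepth(trees):
    # Depth = longest chain of decision nodes from the root
    firstStr = list(trees.keys())[0]
    return max(1 + getTreeDepth(child) if isinstance(child, dict) else 1
               for child in trees[firstStr].values())

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(getNumLeafs(tree), getTreeDepth(tree))  # 3 2
```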

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD

def createPlot(tree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(tree))
    plotTree.totalD = float(getTreeDepth(tree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(tree, (0.5, 1.0), '')
    plt.show()

Calling createPlot(tree) with the tree returned by createTree draws the decision tree.

Building the Classifier

Build a classifier from the existing decision tree: starting at the root, follow the branch whose value matches the test vector's feature until a leaf label is reached.

def classify(tree, featLabels, testVec):
### e.g. {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    firstStr = list(tree.keys())[0]
    secondDict = tree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabels = None  ## stays None if no branch matches the test value
    for key in secondDict.keys():
        if key == testVec[featIndex]:
            if isinstance(secondDict[key], dict):
                classLabels = classify(secondDict[key], featLabels, testVec)
            else:
                classLabels = secondDict[key]
    return classLabels
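
A quick use of the classifier on the two-feature tree from earlier (classify is restated here so the snippet runs on its own; the list() call and None default keep it Python 3 safe):

```python
def classify(tree, featLabels, testVec):
    firstStr = list(tree.keys())[0]
    secondDict = tree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None  # stays None if no branch matches
    for key, child in secondDict.items():
        if key == testVec[featIndex]:
            classLabel = classify(child, featLabels, testVec) if isinstance(child, dict) else child
    return classLabel

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
featLabels = ['no surfacing', 'flippers']
print(classify(tree, featLabels, [1, 0]))  # no
print(classify(tree, featLabels, [1, 1]))  # yes
```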