原理
- 通過選擇最好的特征來劃分數據集,對數據子集繼續劃分,直到數據子集中是相同的類別;劃分數據集的特征可以通過計算信息增益的方法來選擇。
優點
- 計算複雜度不高,輸出結果容易理解,可以處理不相關特征數據。
缺點
- 可能會產生過度匹配問題。
適用數據類型:
- 標稱型和數值型數據。
import math
#計(jì)算熵
def calEntropy(data_2d_array):
classDict = {}
for item in data_2d_array:
classDiff = item[-1]
classDict[classDiff] = classDict.get(classDiff,0) + 1
#計(jì)算熵的公式
h = 0.0
for key in classDict:
prob = classDict[key] / len(data_2d_array) #計(jì)算概率
h -= prob * math.log(prob,2) #計(jì)算熵
return h
calEntropy([[1,1,'yes'],[1,0,'no'],[1,1,'yes'],[0,1,'no'],[0,1,'no']])
0.9709505944546686
# Filter the data set on one feature value and drop that feature column.
def splitData(data_2d_array, feature, value):
    """Return the rows where column `feature` equals `value`,
    with that column removed from each returned row.
    """
    return [row[:feature] + row[feature + 1:]
            for row in data_2d_array
            if row[feature] == value]
data_2d_array=[[1,1,'yes'],[1,0,'no'],[1,1,'yes'],[0,1,'no'],[0,1,'no']]
splitData(data_2d_array,0,1)
[[1, 'yes'], [0, 'no'], [1, 'yes']]
def chooseBestFeatureToSplit(data_2d_array):
    """Return the index of the feature with the largest information gain.

    Gain = entropy(before split) - weighted entropy(after split); a lower
    post-split entropy means a more ordered partition and a larger gain.
    (See Zhou's "Machine Learning" book, p.35, for the formula.)
    """
    numSamples = len(data_2d_array)
    numFeatures = len(data_2d_array[0]) - 1   # last column is the label
    baseEntropy = calEntropy(data_2d_array)
    bestGain = 0.0
    bestIndex = -1
    for idx in range(numFeatures):
        values = set(row[idx] for row in data_2d_array)
        newEntropy = 0
        for v in values:
            subset = splitData(data_2d_array, idx, v)
            weight = len(subset) / numSamples   # proportion of rows in this branch
            newEntropy += weight * calEntropy(subset)
        gain = baseEntropy - newEntropy
        # `>=` keeps the original tie behavior: the LAST feature with the
        # maximal gain wins.
        if gain >= bestGain:
            bestGain = gain
            bestIndex = idx
    return bestIndex
chooseBestFeatureToSplit(data_2d_array)
0
'''
構(gòu)建決策樹偽代碼:
if 樣本類別相同:
return 類別
if 遍歷完所有特征:
return 投票結(jié)果
else:
選擇最優(yōu)特征
劃分?jǐn)?shù)據(jù)集
創(chuàng)建分支節(jié)點(diǎn)
for 每個(gè)數(shù)據(jù)集:
分支節(jié)點(diǎn) += 遞歸構(gòu)建決策樹
return 分支節(jié)點(diǎn)
'''
#少數(shù)服從多數(shù)投票
def vote(classList):
classCount = {}
for i in classList:
classCount[i] += classCount.get(i,0)
sortClass = sorted(classCount.items(), key= lambda x: x[1])
return sortClass[0][-1]
#構(gòu)建決策樹
def createTree(data_2d_array,feature):
classList = [i[-1] for i in data_2d_array]
if classList.count(classList[0]) == len(data_2d_array):
return classList[0]
if len(data_2d_array[0]) == 1:
return vote(classList)
#選擇最優(yōu)特征
bestFeatureIndex = chooseBestFeatureToSplit(data_2d_array)
bestFeature = feature[bestFeatureIndex]
#創(chuàng)建分支節(jié)點(diǎn)
myTree = {feature[bestFeatureIndex]:{}}
del(feature[bestFeatureIndex])
#數(shù)據(jù)子集繼續(xù)劃分
uniFeatureValueSet = set([i[bestFeatureIndex] for i in data_2d_array])
for val in uniFeatureValueSet:
subFeature = feature[:]
subData = splitData(data_2d_array, bestFeatureIndex, val)
myTree[bestFeature][val] = createTree(subData,subFeature)
return myTree
data_2d_array=[[1,1,'yes'],[1,0,'no'],[1,1,'yes'],[0,1,'no'],[0,1,'no']]
createTree(data_2d_array,feature=['no surfacing','flippers'])
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
#使用文本注解繪制樹節(jié)點(diǎn)
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle='sawtooth',fc='0.8')
leafNode = dict(boxstyle='round4',fc='.8')
arrow_args = dict(arrowstyle='<-')
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated tree node with an arrow from parentPt to centerPt.

    Uses the axes stored on the `createPlot.ax1` function attribute by
    createPlot(); both points are in axes-fraction coordinates (0..1).
    `nodeType` is a matplotlib bbox style dict (decisionNode / leafNode).
    """
    createPlot.ax1.annotate(nodeTxt, xy = parentPt, \
        xycoords = 'axes fraction', xytext = centerPt, \
        textcoords = 'axes fraction', va = "center", \
        ha = "center", bbox = nodeType, arrowprops = arrow_args)
def createPlot():
    """Demo: open a blank figure and draw one decision node and one leaf node."""
    fig = plt.figure(facecolor='white')
    # fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon = False) # frameon: whether to draw the axes frame
    plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
createPlot()

output_5_0.png
#獲取葉子節(jié)點(diǎn)個(gè)數(shù)和樹的深度
def getNumLeafs(myTree):
    """Count the leaf (non-dict) nodes of a nested-dict decision tree."""
    root = list(myTree)[0]
    total = 0
    for child in myTree[root].values():
        # A dict child is a subtree; anything else is a leaf label.
        if type(child) is dict:
            total += getNumLeafs(child)
        else:
            total += 1
    return total
def getTreeDepth(myTree):
    """Return the depth of the tree: the longest chain of decision nodes
    below the root (a tree whose children are all leaves has depth 1).
    """
    root = list(myTree)[0]
    deepest = 0
    # Walk every branch; a subtree contributes 1 (this level) + its own
    # depth, a leaf contributes exactly 1. Keep the maximum.
    for child in myTree[root].values():
        branchDepth = 1 + getTreeDepth(child) if type(child) is dict else 1
        if branchDepth > deepest:
            deepest = branchDepth
    return deepest
'''自己做的錯(cuò)誤解法:
def getTreeDepth(myTree):
numDepth = 0
firstNode = list(myTree.keys())[0]
secondDict = myTree[firstNode]
for key in secondDict:
if type(secondDict[key]).__name__ == 'dict':
numDepth += 1
else:
numDepth = 1
return numDepth
'''
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1:{'flippers': {0: 'no', 1: 'yes'}}}}}}
print('number of leafs is: %d'%getNumLeafs(myTree))
print('depth of tree is: %d' %getTreeDepth(myTree))
number of leafs is: 4
depth of tree is: 3
#使用文本注解繪制樹節(jié)點(diǎn)
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle='sawtooth',fc='0.8')
leafNode = dict(boxstyle='round4',fc='.8')
arrow_args = dict(arrowstyle='<-')
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated tree node with an arrow from parentPt to centerPt.

    Uses the axes stored on the `createPlotTree.ax1` function attribute by
    createPlotTree(); both points are in axes-fraction coordinates (0..1).
    `nodeType` is a matplotlib bbox style dict (decisionNode / leafNode).
    """
    createPlotTree.ax1.annotate(nodeTxt, xy = parentPt, \
        xycoords = 'axes fraction', xytext = centerPt, \
        textcoords = 'axes fraction', va = "center", \
        ha = "center", bbox = nodeType, arrowprops = arrow_args)
#獲取葉子節(jié)點(diǎn)個(gè)數(shù)和樹的深度
def getNumLeafs(myTree):
    """Count the leaf (non-dict) nodes of a nested-dict decision tree."""
    node = next(iter(myTree))
    children = myTree[node].values()
    # A dict child is a subtree (recurse); anything else is one leaf.
    return sum(getNumLeafs(c) if type(c) is dict else 1 for c in children)
def getTreeDepth(myTree):
    """Return the depth of the tree: the longest chain of decision nodes
    below the root (a tree whose children are all leaves has depth 1).
    """
    node = next(iter(myTree))
    children = list(myTree[node].values())
    if not children:
        # No branches at all — matches the original's maxDepth init of 0.
        return 0
    return max((1 + getTreeDepth(c)) if type(c) is dict else 1
               for c in children)
def plotMidText(centerPt, parentPt, txtString):
    """Write txtString (the branch's feature value) at the midpoint of the
    edge between centerPt and parentPt, on the createPlotTree axes.
    """
    xMid = centerPt[0] + (parentPt[0] - centerPt[0])/2
    yMid = centerPt[1] + (parentPt[1] - centerPt[1])/2
    createPlotTree.ax1.text(xMid,yMid,txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively plot a subtree, connecting it to parentPt with edge
    label nodeTxt.

    Relies on function attributes initialized by createPlotTree():
    plotTree.totalW / plotTree.totalD (leaf count / depth of the WHOLE
    tree) and the plotting cursor plotTree.xOff / plotTree.yOff, which
    this function mutates as it recurses (xOff advances one slot per
    leaf; yOff steps down one level per recursion and back up on return).
    """
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)  # NOTE(review): computed but never used below
    firstNode = list(myTree.keys())[0]
    # Center this node horizontally over the span of leaves it owns.
    centerPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(centerPt, parentPt,nodeTxt)
    plotNode(firstNode, centerPt, parentPt,decisionNode)
    secondDict = myTree[firstNode]
    # Step one level down before drawing children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        if type(secondDict[key]).__name__ == 'dict':
            # Internal node: recurse into the subtree.
            plotTree(secondDict[key], centerPt, str(key))
        else:
            # Leaf: advance the x cursor one slot and draw the label box.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff, plotTree.yOff),centerPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),centerPt,str(key))
    # Restore the y cursor for our caller.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlotTree(inTree):
    """Render a decision tree (nested dict) as a matplotlib figure.

    Sets up the shared axes (createPlotTree.ax1) and the layout state
    consumed/mutated by plotTree(), then kicks off the recursive plot.
    """
    fig = plt.figure(facecolor='white')
    axprops = dict(xticks=[],yticks=[])  # hide the tick marks
    createPlotTree.ax1 = plt.subplot(frameon=False,**axprops)
    # Layout state for plotTree(): total leaves / depth of the whole tree.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # x cursor starts half a leaf-slot to the left so leaves center nicely.
    plotTree.xOff = -.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree,(.5,1.0),'')
    plt.show()
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1:{'flippers': {0: 'no', 1: 'yes'}}}}}}
createPlotTree(myTree)

output_7_0.png
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1:'yes'}}}}
featureName = ['no surfacing','flippers']
testVector = [1,1]
# Classify one sample by walking the decision tree.
def classify(myTree, featureName, testVector):
    """Return the predicted label for `testVector`.

    Parameters:
        myTree: nested dict {feature_name: {value: subtree_or_label}}.
        featureName: feature names, indexed like testVector.
        testVector: one feature value per name in featureName.

    Raises KeyError if the sample's value has no branch in the tree.
    (The original looped over the branch keys and left `res` unbound on
    a miss, raising a confusing UnboundLocalError; a direct dict lookup
    fails with an informative KeyError instead.)
    """
    firstNode = list(myTree.keys())[0]
    # Index of the root feature inside the test vector.
    featIndex = featureName.index(firstNode)
    value = testVector[featIndex]
    branch = myTree[firstNode][value]  # KeyError here = value not in tree
    if isinstance(branch, dict):
        # Internal node: keep walking down.
        return classify(branch, featureName, testVector)
    return branch  # leaf label
classify(myTree, featureName, testVector)
'yes'
#存儲(chǔ)決策樹
def dumpTree(myTree,filename):
import pickle
fw = open(filename,'wb') #以二進(jìn)制的方式打開
pickle.dump(myTree,fw) #pickle存儲(chǔ)方式默認(rèn)是二進(jìn)制方式
fw.close()
def loadTree(filename):
    """Load and return a pickled decision tree from `filename`.

    Uses `with` so the handle is closed even if pickle.load() raises
    (the original leaked the handle on error).

    NOTE(review): unpickling untrusted files can execute arbitrary code;
    only load files this program wrote itself.
    """
    import pickle
    with open(filename,'rb') as fr:
        return pickle.load(fr)
dumpTree(myTree,'classifierStorage.txt')
loadTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
使用決策樹預(yù)測(cè)隱形眼鏡類型
def classifyLenses():
    """Train a decision tree on the lenses data set and plot it.

    Reads the tab-separated lenses file from the book's reference code,
    builds a tree with createTree(), and renders it with createPlotTree()
    (which returns None after plt.show()).
    """
    # `with` guarantees the data file is closed (the original leaked the handle).
    with open('../../Reference Code/Ch03/lenses.txt') as fr:
        lenses = [line.strip().split('\t') for line in fr]
    # Fixed the book's typo 'tearTate' -> 'tearRate' in the feature labels.
    feature = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenseTree = createTree(lenses, feature)
    showTree = createPlotTree(lenseTree)
    return showTree
classifyLenses()

output_11_0.png