Building a Decision Tree
1. Overall approach
Helper functions
- Split the data into set_1 and set_2 on one feature --> divide_set
- Count the class labels in set_1 and in set_2 --> unique_counts
- Compute the entropy from those counts --> entropy
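Algorithm
- Loop over the feature columns (and every value in each column) and pick the split with the largest information gain
- Split the data set on that feature and value, then recurse on set_1 and set_2
- Stop when no split gives a positive gain, or when the subset to be split is empty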
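For reference, the quantities behind the steps above, written out with the standard Shannon entropy in bits. The numbers are a hand calculation (rounded to three decimals) for the root split 3:21? that appears in the output of Section 3: S is the whole test set (6 None, 6 Basic, 3 Premium), S_1 the 6 rows with column 3 >= 21 (3 Premium, 3 Basic), and S_2 the other 9 rows (6 None, 3 Basic). The gain formula matches the gain line in build_tree.

```latex
\begin{aligned}
H(S) &= -\sum_{c} p_c \log_2 p_c, \qquad p_c = \frac{|\{r \in S : \mathrm{label}(r) = c\}|}{|S|} \\
\mathrm{gain} &= H(S) - p\,H(S_1) - (1 - p)\,H(S_2), \qquad p = \frac{|S_1|}{|S|} \\[4pt]
H(S) &= -2 \cdot 0.4 \log_2 0.4 - 0.2 \log_2 0.2 \approx 1.522 \\
H(S_1) &= 1, \qquad H(S_2) = -\tfrac{2}{3}\log_2\tfrac{2}{3} - \tfrac{1}{3}\log_2\tfrac{1}{3} \approx 0.918 \\
\mathrm{gain} &\approx 1.522 - \tfrac{6}{15} \cdot 1 - \tfrac{9}{15} \cdot 0.918 \approx 0.571
\end{aligned}
```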
2. Python implementation
Main function
# Note: entropy (defined below) must already exist when this def runs,
# because the default argument scoref=entropy is evaluated at definition time.
def build_tree(rows, scoref=entropy):
    # Base case: empty data set
    if len(rows) == 0:
        return DecisionNode()
    current_score = scoref(rows)  # score before splitting
    best_gain = 0.0
    best_criteria = None
    best_sets = None
    column_count = len(rows[0]) - 1  # number of features (last column is the label)
    for col in range(column_count):
        # Collect the distinct values in this column (dict keeps insertion order)
        column_values = {}
        for row in rows:
            column_values[row[col]] = 1
        # Try splitting on each value
        for value in column_values.keys():
            set_1, set_2 = divide_set(rows, col, value)
            p = float(len(set_1)) / len(rows)
            gain = current_score - p * scoref(set_1) - (1 - p) * scoref(set_2)
            if gain > best_gain and len(set_1) > 0 and len(set_2) > 0:
                best_gain = gain
                best_criteria = (col, value)
                best_sets = (set_1, set_2)
    # Create the sub-branches
    if best_gain > 0:
        # Internal node: results stays None; the node tests column col against value.
        # Forward scoref so every level of the tree uses the same score function.
        true_branch = build_tree(best_sets[0], scoref)
        false_branch = build_tree(best_sets[1], scoref)
        return DecisionNode(col=best_criteria[0], value=best_criteria[1], tb=true_branch, fb=false_branch)
    else:
        # No split improves the score: make a leaf holding the class counts
        return DecisionNode(results=unique_counts(rows))
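build_tree takes its score function as the scoref parameter, so any function with the same signature as entropy can be plugged in. As one illustration (not from the original post), here is a sketch of Gini impurity, 1 minus the sum of squared class proportions, which is a commonly used alternative; gini_impurity is a hypothetical name.

```python
# Gini impurity as an alternative scoref (NOT part of the original post).
# Same contract as entropy(rows): the class label is the last column,
# a pure set scores 0.0, and an empty set scores 0.0.


def gini_impurity(rows):
    if not rows:
        return 0.0
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    total = len(rows)
    # 1 - sum of squared class proportions
    return 1.0 - sum((n / total) ** 2 for n in counts.values())


print(gini_impurity([['google', 'Premium'], ['google', 'Premium']]))  # 0.0
print(gini_impurity([['google', 'Premium'], ['digg', 'Basic']]))      # 0.5
```

It could then be passed as build_tree(my_data, scoref=gini_impurity). For it to apply below the root as well, the recursive build_tree calls must forward scoref; the resulting tree is not guaranteed to match the entropy-based one.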
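DecisionNode class
class DecisionNode:
    def __init__(
        self,
        col=-1,
        value=None,
        results=None,
        tb=None,
        fb=None
    ):
        self.col = col          # index of the column (feature) tested at this node
        self.value = value      # value the column is compared with (>= for numbers, == otherwise)
        self.results = results  # leaf: dict of class counts; None for internal nodes
        self.tb = tb            # subtree for rows where the test is true
        self.fb = fb            # subtree for rows where the test is false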
divide_set: splitting the data
def divide_set(rows, column, value):
    # Split rows in two on value: set_1 gets rows where the test is true, set_2 the rest
    if isinstance(value, (int, float)):
        # Numeric feature: test row[column] >= value
        split_function = lambda row: row[column] >= value
    else:
        # Categorical feature: test row[column] == value
        split_function = lambda row: row[column] == value
    set_1 = [row for row in rows if split_function(row)]
    set_2 = [row for row in rows if not split_function(row)]
    return set_1, set_2
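To see the two kinds of test divide_set applies, here it is run on four rows copied from the test data in Section 3 (the rows and variable names are only for illustration): integers and floats are compared with >=, anything else with ==.

```python
# Minimal demo of divide_set (copied from above) on a few test-data rows.


def divide_set(rows, column, value):
    # Numeric values are compared with ">=", everything else with "=="
    if isinstance(value, (int, float)):
        split_function = lambda row: row[column] >= value
    else:
        split_function = lambda row: row[column] == value
    set_1 = [row for row in rows if split_function(row)]
    set_2 = [row for row in rows if not split_function(row)]
    return set_1, set_2


# Four rows taken from the test data in Section 3
rows = [
    ['slashdot', 'USA', 'yes', 18, 'None'],
    ['google', 'France', 'yes', 23, 'Premium'],
    ['digg', 'USA', 'yes', 24, 'Basic'],
    ['google', 'UK', 'no', 21, 'Premium'],
]

# Column 3 (age) is numeric: set_1 = rows with age >= 21
older, younger = divide_set(rows, 3, 21)
print([r[3] for r in older], [r[3] for r in younger])  # [23, 24, 21] [18]

# Column 0 (referrer) is a string: set_1 = rows equal to 'google'
google, other = divide_set(rows, 0, 'google')
print([r[0] for r in google], [r[0] for r in other])  # ['google', 'google'] ['slashdot', 'digg']
```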
unique_counts: counting the class labels
def unique_counts(rows):
    results = {}
    for row in rows:
        r = row[-1]  # class label (last column): None, Basic, Premium
        if r not in results:
            results[r] = 0
        results[r] += 1
    return results
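The same counts can also be produced with collections.Counter from the standard library. A minimal sketch (unique_counts_counter is a hypothetical name); wrapping the result in dict() keeps str() output in the plain {...} form that print_tree shows:

```python
from collections import Counter


def unique_counts(rows):
    # Loop-based version from above (last column = class label)
    results = {}
    for row in rows:
        r = row[-1]
        if r not in results:
            results[r] = 0
        results[r] += 1
    return results


def unique_counts_counter(rows):
    # Same counts via collections.Counter, converted back to a plain dict
    return dict(Counter(row[-1] for row in rows))


rows = [
    ['slashdot', 'USA', 'yes', 18, 'None'],
    ['google', 'France', 'yes', 23, 'Premium'],
    ['digg', 'USA', 'yes', 24, 'Basic'],
    ['google', 'UK', 'no', 21, 'Premium'],
]
print(unique_counts(rows))                                 # {'None': 1, 'Premium': 2, 'Basic': 1}
print(unique_counts_counter(rows) == unique_counts(rows))  # True
```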
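entropy: computing the entropy
from math import log2

def entropy(rows):
    # Shannon entropy (in bits) of the class labels: -sum(p * log2(p))
    results = unique_counts(rows)
    ent = 0.0
    for r in results.keys():
        p = float(results[r]) / len(rows)
        ent -= p * log2(p)
    return ent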
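A quick numeric check of entropy (Shannon entropy, in bits because of log2): a pure set scores 0, a 50/50 two-class set scores 1, and the label mix of the full test data (6 None, 6 Basic, 3 Premium) scores about 1.522. The sketch copies both functions so it runs on its own; the one-element rows are an assumption that works because only the last column is read.

```python
from math import log2


def unique_counts(rows):
    results = {}
    for row in rows:
        results[row[-1]] = results.get(row[-1], 0) + 1
    return results


def entropy(rows):
    # Shannon entropy (in bits) of the class labels in the last column
    results = unique_counts(rows)
    ent = 0.0
    for count in results.values():
        p = count / len(rows)
        ent -= p * log2(p)
    return ent


# Only the label column matters, so one-element rows are enough here.
# Label mix of the whole test set: 6 None, 6 Basic, 3 Premium.
root = [['None']] * 6 + [['Basic']] * 6 + [['Premium']] * 3
print(round(entropy(root), 3))                       # 1.522
print(entropy([['Basic']] * 4))                      # 0.0 (pure set)
print(entropy([['Premium']] * 3 + [['Basic']] * 3))  # 1.0 (50/50 split)
```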
3. Test run
Test data
my_data = [['slashdot', 'USA', 'yes', 18, 'None'],
['google', 'France', 'yes', 23, 'Premium'],
['digg', 'USA', 'yes', 24, 'Basic'],
['kiwibotes', 'France', 'yes', 23, 'Basic'],
['google', 'UK', 'no', 21, 'Premium'],
['(direct)', 'New Zealand', 'no', 12, 'None'],
['(direct)', 'UK', 'no', 21, 'Basic'],
['google', 'USA', 'no', 24, 'Premium'],
['slashdot', 'France', 'yes', 19, 'None'],
['digg', 'USA', 'no', 18, 'None'],
['google', 'UK', 'no', 18, 'None'],
['kiwitobes', 'UK', 'no', 19, 'None'],
['digg', 'New Zealand', 'yes', 12, 'Basic'],
['google', 'UK', 'yes', 18, 'Basic'],
['kiwitobes', 'France', 'yes', 19, 'Basic']]
Displaying the tree
def print_tree(tree, indent=''):
    # A leaf has its class counts in results; an internal node has results None
    if tree.results is not None:
        print(str(tree.results))
    else:
        # Print the test as column:value?
        print(str(tree.col) + ':' + str(tree.value) + "?")
        # Print the true and false branches
        print(indent + "T->", end='')
        print_tree(tree.tb, indent + ' ')
        print(indent + "F->", end='')
        print_tree(tree.fb, indent + ' ')
Output of print_tree(build_tree(my_data))
3:21?
T->0:google?
 T->{'Premium': 3}
 F->{'Basic': 3}
F->2:yes?
 T->0:slashdot?
  T->{'None': 2}
  F->{'Basic': 3}
 F->{'None': 4}
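Reading the output: 3:21? is the test "column 3 >= 21?" (a numeric value), while 0:google? means "column 0 == 'google'?"; T-> and F-> lead to the true and false branches. Every leaf contains a single class, so the tree classifies the training set with 100% accuracy.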
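The 100% figure can also be checked mechanically. The sketch below rebuilds the printed tree by hand with DecisionNode, walks each training row down it with a small classify helper (hypothetical, not from the original post; it applies the same >= / == rule as divide_set), and counts correct predictions. The leaf helper is likewise only a shorthand for this example.

```python
# Sketch: check the printed tree against the training data.
# classify() and leaf() are hypothetical helpers, not part of the original post.


class DecisionNode:
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        self.col = col
        self.value = value
        self.results = results
        self.tb = tb
        self.fb = fb


def classify(observation, tree):
    # Walk down until a leaf, applying the same test as divide_set
    while tree.results is None:
        v = observation[tree.col]
        if isinstance(tree.value, (int, float)):
            go_true = v >= tree.value
        else:
            go_true = v == tree.value
        tree = tree.tb if go_true else tree.fb
    # At a leaf, predict the most frequent class
    return max(tree.results, key=tree.results.get)


def leaf(counts):
    return DecisionNode(results=counts)


# The tree from the output above, rebuilt by hand
tree = DecisionNode(
    col=3, value=21,
    tb=DecisionNode(col=0, value='google',
                    tb=leaf({'Premium': 3}), fb=leaf({'Basic': 3})),
    fb=DecisionNode(col=2, value='yes',
                    tb=DecisionNode(col=0, value='slashdot',
                                    tb=leaf({'None': 2}), fb=leaf({'Basic': 3})),
                    fb=leaf({'None': 4})),
)

my_data = [['slashdot', 'USA', 'yes', 18, 'None'],
           ['google', 'France', 'yes', 23, 'Premium'],
           ['digg', 'USA', 'yes', 24, 'Basic'],
           ['kiwibotes', 'France', 'yes', 23, 'Basic'],
           ['google', 'UK', 'no', 21, 'Premium'],
           ['(direct)', 'New Zealand', 'no', 12, 'None'],
           ['(direct)', 'UK', 'no', 21, 'Basic'],
           ['google', 'USA', 'no', 24, 'Premium'],
           ['slashdot', 'France', 'yes', 19, 'None'],
           ['digg', 'USA', 'no', 18, 'None'],
           ['google', 'UK', 'no', 18, 'None'],
           ['kiwitobes', 'UK', 'no', 19, 'None'],
           ['digg', 'New Zealand', 'yes', 12, 'Basic'],
           ['google', 'UK', 'yes', 18, 'Basic'],
           ['kiwitobes', 'France', 'yes', 19, 'Basic']]

# Feed each row without its label (last column) and compare
correct = sum(classify(row[:-1], tree) == row[-1] for row in my_data)
print(f'{correct}/{len(my_data)} correct')  # 15/15 correct
```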