决策树的剪枝算法
递归版本
def isTree(Node):
    """Return True if *Node* is an internal node, i.e. it has a 'child' entry.

    Nodes are plain dicts; a node without a 'child' key is treated as a leaf.
    """
    # Membership on the dict itself checks its keys; no .keys() view needed.
    return 'child' in Node
def Clip(Node):
    """Collapse an internal node into a leaf, in place.

    The node takes the 'label' of the child with the largest sample count
    'Nt', and its 'child' entry is removed. On ties the first child in
    iteration order wins.

    If the node has no children at all, any existing 'label' is left
    untouched and only the 'child' entry is removed.

    Args:
        Node: dict with a 'child' mapping of branch value -> child node dict;
            each child must provide 'Nt' and 'label'.
    """
    children = Node['child']
    if children:
        # max() returns the first maximal element, so ties resolve as a
        # strict '>' scan would. Unlike a scan seeded with 0, it still picks
        # a label when every child has Nt == 0.
        majority = max(children.values(), key=lambda child: child['Nt'])
        Node['label'] = majority['label']
    Node.pop('child')
def Merge(Node, alpha):
    """Prune *Node* if collapsing it does not increase the regularised cost.

    Compares the cost-complexity C_alpha(T) = sum_t Nt * H_t + alpha * |T|
    of the subtree before pruning (one leaf per child) with the cost after
    pruning (the node itself as a single leaf). Calls Clip(Node) when the
    pruned cost is not larger.

    Args:
        Node: internal node dict with 'entropy' and a 'child' mapping whose
            values each provide 'Nt' and 'entropy'. The node's own sample
            count is read from 'Nt' when present; otherwise it is taken as
            the sum of the children's 'Nt'.
        alpha: penalty added per leaf.
    """
    children = Node['child'].values()

    # Cost before pruning: every child is a leaf, each contributing its
    # weighted entropy plus one alpha penalty.
    cost_before = sum(child['Nt'] * child['entropy'] for child in children)
    cost_before += alpha * len(children)

    # Cost after pruning: the node becomes a single leaf. Its entropy must be
    # weighted by its sample count, exactly like the children's above,
    # otherwise the two costs are not on the same scale.
    if 'Nt' in Node:
        node_size = Node['Nt']
    else:
        node_size = sum(child['Nt'] for child in children)
    cost_after = node_size * Node['entropy'] + alpha

    # Pruning condition
    if cost_after <= cost_before:
        Clip(Node)
def prune(Node, alpha):
    """Recursively post-prune the tree rooted at *Node*, in place (bottom-up).

    Every child that is itself a subtree is pruned first. If afterwards all
    children are leaves, Merge decides whether this node should be collapsed
    as well. A node that is already a leaf is left unchanged.

    Args:
        Node: node dict; internal nodes carry a 'child' mapping of
            branch value -> child node dict.
        alpha: per-leaf penalty forwarded to Merge.
    """
    # A leaf (no 'child' entry) has nothing to prune; without this guard a
    # single-node tree would raise KeyError below.
    if 'child' not in Node:
        return

    children = Node['child'].values()

    # Prune every child that is still a subtree first.
    for child in children:
        if isTree(child):
            prune(child, alpha)

    # Only when all children have become leaves can this node be a pruning
    # candidate.
    if not any(isTree(child) for child in children):
        Merge(Node, alpha)

# DP version
Last updated