即 split 得到的损失更小。这是肯定的,因为 CART 树的生成算法就是这样定义的:只有当 split 得到的损失更小时,才会 split 生成子结点。
但当 α 慢慢增大,对一个特定的结点 t 来说,当 α 取到某个值时,会有 $C_\alpha(T_t) = C_\alpha(t)$,即 $\alpha = \dfrac{C(t) - C(T_t)}{|T_t| - 1}$。
def isLeaf(Node):
    """Return True when Node has neither a 'left' nor a 'right' key.

    Node must expose a ``keys()`` method (e.g. a dict); anything else
    raises AttributeError.
    """
    nodeKeys = Node.keys()
    return 'left' not in nodeKeys and 'right' not in nodeKeys
def calcAlphaList(Node):
    """Annotate every internal node under Node with its pruning alpha.

    For an internal node t the value stored in ``t['alpha']`` is

        (C(t) - C(T_t)) / (|T_t| - 1)

    where C(t) is ``t['gini']``, C(T_t) is the sum of ``'gini'`` over all
    leaves of the subtree rooted at t, and |T_t| is ``t['Tt']`` (the leaf
    count, which must already be present on every internal node).

    Side effects:
        * sets ``'alpha'`` on every internal node;
        * lowers ``calcAlphaList.bestAlpha`` to the smallest alpha seen.
          The caller is expected to reset ``calcAlphaList.bestAlpha`` before
          each pass; if it was never set it is treated as +infinity.

    Returns None. A leaf passed as Node is left untouched.
    """
    if isLeaf(Node):
        return
    _annotateAlpha(Node)


def _annotateAlpha(Node):
    """Set 'alpha' below Node and return the summed 'gini' of its leaves."""
    if isLeaf(Node):
        return Node['gini']
    # The subtree cost must cover ALL leaves below this node, not only its
    # two direct children, so that it matches the leaf count in the divisor.
    costSplit = _annotateAlpha(Node['left']) + _annotateAlpha(Node['right'])
    costNotSplit = Node['gini']
    alpha = (costNotSplit - costSplit) / (Node['Tt'] - 1)
    Node['alpha'] = alpha
    # Fall back to +inf so a first call without prior initialisation works.
    if alpha < getattr(calcAlphaList, 'bestAlpha', float('inf')):
        calcAlphaList.bestAlpha = alpha
    return costSplit
def calcTt(Node):
    """Count the leaves under Node and cache the count on internal nodes.

    A leaf counts as 1 and is not modified. For an internal node the sum of
    its 'left' and 'right' counts is stored in ``Node['Tt']`` and returned.
    """
    if not isLeaf(Node):
        leafCount = calcTt(Node['left']) + calcTt(Node['right'])
        Node['Tt'] = leafCount
        return leafCount
    return 1
def cut(Node, alpha):
    """Collapse, in place, every internal node whose 'alpha' equals alpha.

    A matching node has its 'left' and 'right' entries removed, which turns
    it into a leaf; its other keys are kept. Non-matching internal nodes are
    searched recursively. Leaves are skipped.

    Args:
        Node: tree node (dict). Internal nodes must carry an 'alpha' key.
        alpha: the alpha value identifying the node(s) to collapse. It is
            compared with ``==``, so pass the exact stored value.
    """
    # Check for a leaf BEFORE looking at 'alpha': leaves have no 'alpha',
    # and a node collapsed in an earlier round still carries a stale one.
    if not isinstance(Node, dict) or ('left' not in Node and 'right' not in Node):
        return
    if Node['alpha'] == alpha:
        Node.pop('left')
        Node.pop('right')
    else:
        cut(Node['left'], alpha)
        cut(Node['right'], alpha)
def prune(Tree):
    """Prune Tree in place, round by round, until the root is a leaf.

    Every round recomputes the leaf counts (calcTt), recomputes the alphas
    and their minimum (calcAlphaList), and then collapses the node(s) whose
    alpha equals that minimum (cut). The starting tree, each round's minimum
    alpha and the tree after each cut are printed. Returns None.
    """
    step = 0
    print ('i=',step,Tree)
    while True:
        if isLeaf(Tree):
            break
        calcTt(Tree)
        # Reset the running minimum before each alpha pass.
        calcAlphaList.bestAlpha = np.inf
        calcAlphaList(Tree)
        step += 1
        smallestAlpha = calcAlphaList.bestAlpha
        print ('i=',step,'alpha',smallestAlpha)
        cut(Tree, smallestAlpha)
        # NOTE(review): the source's indentation was lost; this per-round
        # print is assumed to belong inside the loop - confirm.
        print (Tree)