CART树的剪枝
CART树的剪枝算法
输入:剪枝前的CART树 输出:剪枝后的CART树
原理
令: 损失函数为:
对树上的所有结点计算: 假如该结点不split,该结点的损失为:
假如该结点split,则该结点的损失为:
当a=0或a足够小时,有
即split得到的损失更小。这是肯定的,因为CART树的生成算法就是这样的定义的。因为split得到的损失更小,才会split生成子结点。 但当a慢慢增大,对一个特定的结点t来说,当a取到某个值时,会有
将公式(2)、(3)代码公式(5),得出:
若对于某个结点来说,此时的a使得$C_a(T_t)< C_a(t) $,那么就应该剪枝了。
步骤
- 令当前树为$T_0$,表示未做过剪枝的CART树 
- 自下而上【?】地对所有结点计算g(t),并同时记录最小的g(t)作为当前的alpha 
- 对Tree中$g(t)=a$的结点做剪枝,得到的新的树$T_a$ 
- 如果当前的树不只一个根结点,就循环2-3步 
- 通过交叉验证选出最优的$T_a$ 
代码
def isLeaf(Node):
    # return type(Node).__name__ != 'dict'
    return not ('left' in Node.keys() or 'right' in Node.keys())
def calcAlphaList(Node):
    if isLeaf(Node):
        return
    costNotSplit = Node['gini']
    costSplit = Node['left']['gini'] + Node['right']['gini']
    alpha = (costNotSplit-costSplit)/(Node['Tt']-1)
    Node['alpha'] = alpha
    if alpha < calcAlphaList.bestAlpha:
        calcAlphaList.bestAlpha = alpha
    calcAlphaList(Node['left'])
    calcAlphaList(Node['right'])
def calcTt(Node):
    if isLeaf(Node):
        return 1
    Node['Tt'] = calcTt(Node['left']) + calcTt(Node['right'])
    return Node['Tt']
def cut(Node, alpha):
    if Node['alpha'] == alpha:
        Node.pop('left')
        Node.pop('right')
    else:
        cut(Node['left'], alpha)
        cut(Node['right'], alpha)
def prune(Tree):
    i = 0
    print ('i=',i,Tree)
    while not isLeaf(Tree):
        calcTt(Tree)
        calcAlphaList.bestAlpha = np.inf
        calcAlphaList(Tree)
        i += 1
        print ('i=',i,'alpha',calcAlphaList.bestAlpha)
        cut(Tree, calcAlphaList.bestAlpha)
        print (Tree)这个算法没想通

Last updated
Was this helpful?