> For the complete documentation index, see [llms.txt](https://windmising.gitbook.io/lihang-tongjixuexifangfa/llms.txt). Markdown versions of documentation pages are available by appending `.md` to page URLs; this page is available as [Markdown](https://windmising.gitbook.io/lihang-tongjixuexifangfa/cart/7.md).

# CART树的剪枝

## CART树的剪枝算法

输入：剪枝前的CART树\
输出：剪枝后的CART树

### 原理

令：\
损失函数为：

$$
C\_a(T) = C(T) + a|T|
$$

对树上的所有结点计算：\
假如该结点不split，该结点的损失为：

$$
C\_a(t) = C(t) + a|T|
$$

假如该结点split，则该结点的损失为：

$$
C\_a(T\_t) = C(T\_t) + a|T\_t|
$$

当a=0或a足够小时，有

$$
C\_a(T\_t) < C\_a(t)
$$

即split得到的损失更小。这是肯定的，因为CART树的生成算法就是这样的定义的。因为split得到的损失更小，才会split生成子结点。\
但当a慢慢增大，对一个特定的结点t来说，当a取到某个值时，会有

$$
C\_a(T\_t) = C\_a(t)
$$

将公式（2）、（3）代码公式（5），得出:

$$
a = \frac {C(t) - C(T\_t)}{|T\_t| - 1}
$$

若对于某个结点来说，此时的a使得$C\_a(T\_t)< C\_a(t) $，那么就应该剪枝了。

### 步骤

1. 令当前树为$T\_0$，表示未做过剪枝的CART树 &#x20;
2. **自下而上【？】**&#x5730;对所有结点计算g(t)，并同时记录最小的g(t)作为当前的alpha &#x20;

   $$
   g(t) = \frac {C(t) - C(T\_t)}{|T\_t| - 1}
   $$
3. 对Tree中$g(t)=a$的结点做剪枝，得到的新的树$T\_a$ &#x20;
4. 如果当前的树不只一个根结点，就循环2-3步&#x20;
5. 通过交叉验证选出最优的$T\_a$  &#x20;

## 代码

```python
def isLeaf(Node):
    # return type(Node).__name__ != 'dict'
    return not ('left' in Node.keys() or 'right' in Node.keys())

def calcAlphaList(Node):
    if isLeaf(Node):
        return
    costNotSplit = Node['gini']
    costSplit = Node['left']['gini'] + Node['right']['gini']
    alpha = (costNotSplit-costSplit)/(Node['Tt']-1)
    Node['alpha'] = alpha
    if alpha < calcAlphaList.bestAlpha:
        calcAlphaList.bestAlpha = alpha
    calcAlphaList(Node['left'])
    calcAlphaList(Node['right'])

def calcTt(Node):
    if isLeaf(Node):
        return 1
    Node['Tt'] = calcTt(Node['left']) + calcTt(Node['right'])
    return Node['Tt']

def cut(Node, alpha):
    if Node['alpha'] == alpha:
        Node.pop('left')
        Node.pop('right')
    else:
        cut(Node['left'], alpha)
        cut(Node['right'], alpha)

def prune(Tree):
    i = 0
    print ('i=',i,Tree)
    while not isLeaf(Tree):
        calcTt(Tree)
        calcAlphaList.bestAlpha = np.inf
        calcAlphaList(Tree)
        i += 1
        print ('i=',i,'alpha',calcAlphaList.bestAlpha)
        cut(Tree, calcAlphaList.bestAlpha)
        print (Tree)
```

## 这个算法没想通

![](http://windmissing.github.io/images_for_gitbook/LiHang-TongJiXueXiFangFa/6.png)
