CART树的剪枝

CART树的剪枝算法

输入:剪枝前的CART树 输出:剪枝后的CART树

原理

令: 损失函数为:

Ca(T)=C(T)+aTC_a(T) = C(T) + a|T|

对树上的所有结点计算: 假如该结点不split,该结点的损失为:

Ca(t)=C(t)+aTC_a(t) = C(t) + a|T|

假如该结点split,则该结点的损失为:

Ca(Tt)=C(Tt)+aTtC_a(T_t) = C(T_t) + a|T_t|

当a=0或a足够小时,有

Ca(Tt)<Ca(t)C_a(T_t) < C_a(t)

即split得到的损失更小。这是肯定的,因为CART树的生成算法就是这样的定义的。因为split得到的损失更小,才会split生成子结点。 但当a慢慢增大,对一个特定的结点t来说,当a取到某个值时,会有

Ca(Tt)=Ca(t)C_a(T_t) = C_a(t)

将公式(2)、(3)代码公式(5),得出:

a=C(t)C(Tt)Tt1a = \frac {C(t) - C(T_t)}{|T_t| - 1}

若对于某个结点来说,此时的a使得$C_a(T_t)< C_a(t) $,那么就应该剪枝了。

步骤

  1. 令当前树为$T_0$,表示未做过剪枝的CART树

  2. 自下而上【?】地对所有结点计算g(t),并同时记录最小的g(t)作为当前的alpha

    g(t)=C(t)C(Tt)Tt1g(t) = \frac {C(t) - C(T_t)}{|T_t| - 1}
  3. 对Tree中$g(t)=a$的结点做剪枝,得到的新的树$T_a$

  4. 如果当前的树不只一个根结点,就循环2-3步

  5. 通过交叉验证选出最优的$T_a$

代码

这个算法没想通

Last updated

Was this helpful?