from collections import Counter
from math import log

import numpy as np
def entropy(y):
    """Return the entropy of the labels in ``y``, using the natural logarithm.

    Every distinct value in ``y`` contributes ``-p * log(p)``, where ``p`` is
    the fraction of ``len(y)`` taken up by that value.

    Args:
        y: A sized iterable of hashable labels.

    Returns:
        The accumulated entropy as a float; ``0.0`` when ``y`` has no items.
    """
    # Counter maps each distinct label to the number of times it occurs.
    label_counts = Counter(y)
    if not label_counts:
        return 0.0

    total = len(y)
    result = 0.0
    for count in label_counts.values():
        share = count / total
        result -= share * log(share)
    return result
# Find the dimension d and the threshold value that give the lowest entropy.
def try_split(X, y):
    """Exhaustively search for the (dimension, threshold) with the lowest entropy.

    For every column ``d`` of ``X`` the rows are ordered by that column, and
    the midpoint between each pair of neighbouring, unequal values is tried
    as a threshold. Each candidate is scored as
    ``entropy(y_l) + entropy(y_r)`` (a plain, unweighted sum) on the label
    subsets returned by ``split``.

    Args:
        X: Sample matrix; indexed as ``X[:, d]`` and ``X[row, d]``, with the
            number of columns read from ``X.shape[1]``.
        y: Labels, passed through to ``split`` together with ``X``.

    Returns:
        A tuple ``(best_entropy, best_d, best_v)``. When no candidate
        threshold exists, this is ``(float('inf'), -1, -1)``.
    """
    # NOTE(review): `split` is defined outside this block; presumably it
    # partitions X and y by comparing column d against v -- confirm there.
    lowest_entropy = float('inf')
    chosen_d, chosen_v = -1, -1

    for d in range(X.shape[1]):  # try every dimension
        order = np.argsort(X[:, d])
        # Walk neighbouring rows in sorted order; candidate thresholds lie
        # halfway between two adjacent values.
        for prev_row, next_row in zip(order[:-1], order[1:]):
            lower = X[prev_row, d]
            upper = X[next_row, d]
            if lower == upper:
                # Equal neighbours leave no gap to place a threshold in.
                continue
            v = (lower + upper) / 2
            _, _, y_left, y_right = split(X, y, d, v)
            score = entropy(y_left) + entropy(y_right)
            if score < lowest_entropy:
                lowest_entropy, chosen_d, chosen_v = score, d, v

    return lowest_entropy, chosen_d, chosen_v