8-5 学习曲线
学习曲线
欠拟合、拟合良好与过拟合的学习曲线对比
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
def plot_learning_curve(algo, X_train, X_test, y_train, y_test, ylim=4):
    """Plot train/test RMSE of ``algo`` as the training set grows.

    For every prefix size ``i`` from 1 to ``len(X_train)``, ``algo`` is refit
    on the first ``i`` training samples. The RMSE on those same ``i`` samples
    ("train") and on the full test set ("test") is recorded. Both curves are
    then drawn against ``i`` and the figure is shown.

    Parameters
    ----------
    algo : estimator
        Object exposing ``fit(X, y)`` and ``predict(X)``. It is refit in
        place on every iteration, so after the call (for non-empty
        ``X_train``) it is left fitted on all of ``X_train``.
    X_train, y_train : array-like
        Training data. Sliced by prefix (``X_train[:i]``), so it must
        support slicing.
    X_test, y_test : array-like
        Held-out data, evaluated after every refit.
    ylim : float, default 4
        Upper limit of the y (RMSE) axis. The default reproduces the
        original hard-coded value. Raise it when the target has a larger
        scale, otherwise the curves are clipped.
    """
    import numpy as np

    # NOTE(review): ``plt`` is assumed to be ``matplotlib.pyplot`` imported at
    # module scope by the surrounding notebook/script; confirm.
    sizes = range(1, len(X_train) + 1)
    train_score = []
    test_score = []
    for i in sizes:
        algo.fit(X_train[:i], y_train[:i])
        y_train_predict = algo.predict(X_train[:i])
        train_score.append(mean_squared_error(y_train[:i], y_train_predict))
        y_test_predict = algo.predict(X_test)
        test_score.append(mean_squared_error(y_test, y_test_predict))

    # mean_squared_error yields MSE; sqrt converts it to RMSE (same units as y).
    plt.plot(list(sizes), np.sqrt(train_score), label="train")
    plt.plot(list(sizes), np.sqrt(test_score), label="test")
    plt.legend()
    plt.axis([0, len(X_train) + 1, 0, ylim])
    plt.show()


# 线性回归,欠拟合 (linear regression: underfitting)

2阶多项式回归：拟合良好

20阶多项式回归：过拟合

总结
Last updated


