9-8 OvR与OvO
Last updated
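逻辑回归本身只能解决二分类问题。将其扩展到多分类的两种常用方法是 OvR(One vs Rest,一对剩余)和 OvO(One vs One,一对一)。
OvR:N 个类别就进行 N 次二分类(每次把一个类别与其余所有类别区分开),选择得分最高的类别。对于逻辑回归,这里的“得分”是指样本属于该类别的概率。
OvO:N 个类别就进行 C(N,2) = N(N-1)/2 次二分类(每次只比较两个类别),选择获胜次数最多的类别。
OvO 需要训练的分类器更多,因此耗时更多,但分类通常更准确。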
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Load the iris data set and keep only the first two feature columns;
# the scatter plots that follow use exactly these two columns as x/y axes.
iris = datasets.load_iris()
X, y = iris.data[:, :2], iris.target

# Fixed random_state so the split (and the recorded scores) are reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# One-vs-Rest logistic regression on the two-feature data.
# LogisticRegression(multi_class='ovr') is deprecated since scikit-learn 1.5;
# wrapping a plain (binary) LogisticRegression in OneVsRestClassifier is the
# supported way to get OvR behaviour.
log_reg = OneVsRestClassifier(LogisticRegression())
log_reg.fit(X_train, y_train)
# 0.6578947368421053 in the original run; the exact value may differ across
# scikit-learn versions because the default solver has changed over time.
log_reg.score(X_test, y_test)

# plot_decision_boundary is defined in section 9-5, not in this file; it must
# already be in scope. NOTE(review): assumed to only need model.predict — confirm.
plot_decision_boundary(log_reg, axis=[4, 8, 1.5, 4.5])
# One scatter per class: label 0 -> red, 1 -> blue, 2 -> green.
for label, color in enumerate(['red', 'blue', 'green']):
    plt.scatter(iris.data[iris.target == label, 0],
                iris.data[iris.target == label, 1],
                color=color)
plt.show()
from sklearn.linear_model import LogisticRegression

# Multinomial (softmax) logistic regression on the same two-feature data.
# Since scikit-learn 0.22 the multinomial loss is the default for 3+ classes
# with solvers that support it (such as 'newton-cg'), and the explicit
# multi_class='multinomial' argument is deprecated since 1.5, so it is omitted.
log_reg2 = LogisticRegression(solver='newton-cg')
log_reg2.fit(X_train, y_train)
log_reg2.score(X_test, y_test)  # 0.7894736842105263

# plot_decision_boundary is defined in section 9-5, not in this file.
plot_decision_boundary(log_reg2, axis=[4, 8, 1.5, 4.5])
# One scatter per class: label 0 -> red, 1 -> blue, 2 -> green.
for label, color in enumerate(['red', 'blue', 'green']):
    plt.scatter(iris.data[iris.target == label, 0],
                iris.data[iris.target == label, 1],
                color=color)
plt.show()
# Repeat the OvR vs multinomial comparison using all four iris features.
X = iris.data
y = iris.target
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)

from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# OvR via the supported wrapper (LogisticRegression(multi_class='ovr') is
# deprecated since scikit-learn 1.5).
log_reg = OneVsRestClassifier(LogisticRegression())
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)  # 0.9473684210526315 (original run; may vary by version)

# Multinomial is the default for 3+ classes with 'newton-cg' (scikit-learn >= 0.22).
log_reg2 = LogisticRegression(solver='newton-cg')
log_reg2.fit(X_train, y_train)
log_reg2.score(X_test, y_test)  # 1.0
from sklearn.multiclass import OneVsRestClassifier, OneVsOneClassifier

# Generic multiclass wrappers around a binary logistic regression.
# fit() returns the wrapper itself, so construction and fitting are chained.

# One-vs-Rest: N binary classifiers, highest score wins.
log_reg = LogisticRegression()
ovr = OneVsRestClassifier(log_reg).fit(X_train, y_train)
ovr.score(X_test, y_test)  # 0.9473684210526315

# One-vs-One: C(N,2) pairwise classifiers, most wins decides.
log_reg = LogisticRegression()
ovo = OneVsOneClassifier(log_reg).fit(X_train, y_train)
ovo.score(X_test, y_test)  # 1.0