4-6 Grid Search
Preparing the data
from sklearn import datasets
digits = datasets.load_digits()
X = digits.data
y = digits.target
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)
Using grid search
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
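param_grid = [
    # Candidate set 1: uniform voting, k = 1..10 (10 combinations)
    {
        'weights': ['uniform'],
        'n_neighbors': [i for i in range(1, 11)]
    },
    # Candidate set 2: distance-weighted voting, k = 1..10, Minkowski exponent p = 1..5 (50 combinations)
    {
        'weights': ['distance'],
        'n_neighbors': [i for i in range(1, 11)],
        'p': [i for i in range(1, 6)]
    }
]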
knn_clf = KNeighborsClassifier()
grid_search = GridSearchCV(knn_clf, param_grid)
%%time
grid_search.fit(X_train, y_train)
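To make clear what `fit` is doing, the sketch below enumerates the same grid with `ParameterGrid` (10 + 10 × 5 = 60 candidates) and reproduces the search loop by hand with `cross_val_score`: every candidate is scored by cross-validation on the training set and the one with the highest mean accuracy wins. The helper `manual_grid_search` is hypothetical, written only for illustration, and the demo runs on a reduced grid with `cv=3` to stay fast; recent scikit-learn versions default to 5-fold cross-validation, so the full grid costs 60 × 5 = 300 fits.

```python
from sklearn import datasets
from sklearn.model_selection import ParameterGrid, cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Same data preparation as above.
digits = datasets.load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=666
)

# Same parameter grid as above: 10 + 10 * 5 = 60 candidate combinations.
param_grid = [
    {"weights": ["uniform"], "n_neighbors": list(range(1, 11))},
    {"weights": ["distance"], "n_neighbors": list(range(1, 11)), "p": list(range(1, 6))},
]
n_candidates = len(ParameterGrid(param_grid))


def manual_grid_search(grid, X, y, cv=5):
    """Hypothetical helper: return (best_params, best_mean_cv_accuracy) over `grid`."""
    best_params, best_score = None, -1.0
    for params in ParameterGrid(grid):
        # Mean accuracy over `cv` stratified folds for this one combination.
        score = cross_val_score(KNeighborsClassifier(**params), X, y, cv=cv).mean()
        if score > best_score:  # strict ">" keeps the first candidate on ties
            best_params, best_score = params, score
    return best_params, best_score


# Reduced grid (an assumption of this sketch) so the demo runs quickly.
small_grid = [{"weights": ["uniform", "distance"], "n_neighbors": [1, 3, 5]}]
best_params, best_score = manual_grid_search(small_grid, X_train, y_train, cv=3)
print(n_candidates, best_params, best_score)
```

Passing `param_grid` instead of `small_grid` searches all 60 candidates; `GridSearchCV` performs the same loop but can also run it in parallel and refits the winner for you.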
Inspecting the results
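After fitting, the search object exposes its results as attributes: `best_estimator_` (the winning classifier, refit on the whole training set because `refit=True` is the default), `best_score_` (its mean cross-validated accuracy), `best_params_` (the winning combination) and `cv_results_` (per-candidate details). The sketch below uses a smaller grid than the tutorial's, chosen only to keep it fast.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

digits = datasets.load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=666
)

# Reduced grid (an assumption of this sketch): 3 + 3 * 2 = 9 candidates.
param_grid = [
    {"weights": ["uniform"], "n_neighbors": [1, 3, 5]},
    {"weights": ["distance"], "n_neighbors": [1, 3, 5], "p": [1, 2]},
]
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid)
grid_search.fit(X_train, y_train)

print(grid_search.best_estimator_)  # best classifier, already refit on X_train
print(grid_search.best_score_)      # its mean cross-validated accuracy
print(grid_search.best_params_)     # the winning parameter combination (a dict)
# One entry per candidate combination.
print(len(grid_search.cv_results_["params"]))
```

Note that `best_score_` is a cross-validation score on the training data, not the accuracy on `X_test`.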
Building a new model from the results
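A sketch of two ways to get a final model and evaluate it on the held-out test set: take `best_estimator_` directly, or construct a fresh `KNeighborsClassifier` from `best_params_` and fit it yourself. Because both are fitted on the same training data with the same parameters, they should make identical predictions. The grid is again reduced for speed.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

digits = datasets.load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=666
)

# Reduced grid (an assumption of this sketch).
param_grid = [{"weights": ["uniform", "distance"], "n_neighbors": [1, 3, 5]}]
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid)
grid_search.fit(X_train, y_train)

# Option 1: reuse the best estimator, which GridSearchCV has already refit on X_train.
knn_clf = grid_search.best_estimator_
best_test_score = knn_clf.score(X_test, y_test)

# Option 2: rebuild a classifier from the winning parameters and fit it explicitly.
rebuilt_clf = KNeighborsClassifier(**grid_search.best_params_)
rebuilt_clf.fit(X_train, y_train)
rebuilt_test_score = rebuilt_clf.score(X_test, y_test)

print(best_test_score, rebuilt_test_score)
```

Option 1 avoids a redundant fit; option 2 is useful when the final model should be trained on different data than the search used.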
Other GridSearchCV parameters
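Two commonly adjusted `GridSearchCV` parameters are `n_jobs`, the number of parallel workers used to fit candidates (`-1` means all CPU cores), and `verbose`, which prints progress during `fit` (larger values print more). The sketch also sets `cv` explicitly; the small grid is an assumption chosen to keep it fast.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

digits = datasets.load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=666
)

# Reduced grid (an assumption of this sketch): 2 * 3 = 6 candidates.
param_grid = [{"weights": ["uniform", "distance"], "n_neighbors": [1, 3, 5]}]

grid_search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid,
    cv=3,       # number of cross-validation folds
    n_jobs=-1,  # fit candidates in parallel on all available CPU cores
    verbose=2,  # print progress information while fitting
)
grid_search.fit(X_train, y_train)
print(grid_search.best_params_)
```

Parallelism pays off mostly for large grids such as the 60-candidate grid above; for tiny grids the worker start-up cost can outweigh the gain.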
More distance definitions
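The `p` searched in the grid above is the exponent of the Minkowski distance, (Σᵢ |aᵢ − bᵢ|^p)^(1/p): p = 1 gives the Manhattan distance and p = 2 the Euclidean distance. The sketch below computes these by hand with NumPy (the helper names are hypothetical) and then selects a different distance in scikit-learn through the `metric` parameter of `KNeighborsClassifier`, whose default is `"minkowski"` with `p=2`.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def minkowski_distance(a, b, p):
    """Hypothetical helper: (sum_i |a_i - b_i|^p)^(1/p)."""
    diff = np.abs(np.asarray(a, dtype=float) - np.asarray(b, dtype=float))
    return float(np.sum(diff ** p) ** (1.0 / p))


def chebyshev_distance(a, b):
    """Hypothetical helper: max_i |a_i - b_i|, the Minkowski limit as p -> infinity."""
    diff = np.abs(np.asarray(a, dtype=float) - np.asarray(b, dtype=float))
    return float(np.max(diff))


x, y = [0, 0], [3, 4]
print(minkowski_distance(x, y, 1))  # 7.0 (Manhattan)
print(minkowski_distance(x, y, 2))  # 5.0 (Euclidean)
print(chebyshev_distance(x, y))     # 4.0 (Chebyshev)

# In scikit-learn, choose the distance with `metric`, e.g. "manhattan" or "chebyshev".
knn_chebyshev = KNeighborsClassifier(n_neighbors=1, metric="chebyshev")
knn_chebyshev.fit([[0, 0], [10, 10]], [0, 1])
pred = knn_chebyshev.predict([[1, 2]])  # nearest point under Chebyshev is [0, 0]
print(pred)
```

Since `metric` is an ordinary constructor parameter, it can also be listed in `param_grid` so that grid search compares distance definitions alongside `n_neighbors` and `weights`.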
