使用交叉驗證對鳶尾花分類模型進行調參 超引數

2021-08-29 02:07:46 字數 4084 閱讀 8804

如圖,

大訓練集分塊,使用不同的分塊方法分成n對小訓練集驗證集

使用小訓練集進行訓練,使用驗證集進行驗證,得到準確率,求n個驗證集上的平均正確率

使用平均正確率最高的超引數,對整個大訓練集進行訓練,訓練出引數。

訓練集上訓練。

十折交叉驗證

諸如你有多個可調節的超引數,那麼選擇超引數的方法通常是網格搜尋,即固定乙個參、變化其他參,像網格一樣去搜尋。

"""任務:鳶尾花識別

"""import

pandas as pd

from sklearn.model_selection import

train_test_split, gridsearchcv

from sklearn.neighbors import

kneighborsclassifier

from sklearn.linear_model import

logisticregression

from sklearn.svm import

svcdata_file = '

./data_ai/iris.csv

'species_label_dict =

#使用的特徵列

feat_cols = ['

sepallengthcm

', '

sepalwidthcm

', '

petallengthcm

', '

petalwidthcm']

defmain():

"""主函式

"""#

讀取資料集

iris_data = pd.read_csv(data_file, index_col='id'

) iris_data[

'label

'] = iris_data['

species

'].map(species_label_dict)

#獲取資料集特徵

x =iris_data[feat_cols].values

#獲取資料標籤

y = iris_data['

label

'].values

#劃分資料集

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=1/3, random_state=10)

model_dict =

),'logistic regression':

(logisticregression(),

),'svm':

(svc(),

)}

#名稱+元組

for model_name, (model, model_params) in

model_dict.items():

#訓練模型

clf = gridsearchcv(estimator=model, param_grid=model_params, cv=5) #

模型、引數、折數

clf.fit(x_train, y_train) #

訓練 best_model = clf.best_estimator_ #

最佳模型的物件#驗證

acc =best_model.score(x_test, y_test)

print('

{}模型的**準確率:%

'.format(model_name, acc * 100))

print('

{}模型的最優引數:{}

'.format(model_name, clf.best_params_)) #

最好的模型名稱和引數

if__name__ == '

__main__':

main()

執行結果:

knn模型的**準確率:96.00%

knn模型的最優引數:

logistic regression模型的**準確率:96.00%

logistic regression模型的最優引數:

svm模型的**準確率:98.00%

svm模型的最優引數:

練習:使用交叉驗證對水果分類模型進行調參

可能的**

import

pandas as pd

from sklearn.model_selection import

gridsearchcv, train_test_split

from sklearn.neighbors import

kneighborsclassifier

from sklearn.linear_model import

logisticregression

from sklearn.svm import

svc#

讀取資料

data = pd.read_csv('

./data_ai/fruit_data.csv')

#資料處理

fruit_dict =

data[

'label

'] = data['

fruit_name

'].map(fruit_dict)

feat_cols = ['

mass

','width

','height

','color_score']

#資料提取

x =data[feat_cols].values

y = data['

label

'].values

x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=1/5, random_state= 3)

model_dict = ),

'logestic regression

': (logisticregression(), ),

'svm

': (svc(), )

}for model_name, (model, model_para) in

model_dict.items():

#訓練clf = gridsearchcv(estimator=model, param_grid=model_para, cv=5) #

模型、引數、折數

clf.fit(x_train,y_train)

best_model =clf.best_estimator_

#驗證acc =best_model.score(x_test, y_test)

print(f'

中選擇為引數的**準確率最好,準確率可達%

')

執行結果:

knn中選擇為引數的**準確率最好,準確率可達66.66666666666666%

logestic regression中選擇為引數的**準確率最好,準確率可達91.66666666666666%

svm中選擇為引數的**準確率最好,準確率可達50.0%

利用KNN對鳶尾花資料進行分類

knn k nearest neighbor 工作原理 存在乙個樣本資料集合,也稱為訓練樣本集,並且樣本集中每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類對應的關係。輸入沒有標籤的資料後,將新資料中的每個特徵與樣本集中資料對應的特徵進行比較,提取出樣本集中特徵最相似資料 最近鄰 的分類標籤...

交叉驗證與網格搜尋(以KNN分類鳶尾花為例)

總結 import pandas as pd pd.set option display.max rows 6 1獲取資料 from sklearn.datasets import load iris iris load iris 新建乙個dataframe,把iris中data方進來,並且data...

使用KNN對鳶尾花資料集進行分類處理

author tao contact 1281538933 qq.com file knn.py time 2020 12 21 software vscode from sklearn.datasets import load iris 匯入資料集iris import matplotlib.py...