sklearn交叉驗證3 老魚學sklearn

2022-02-10 04:38:28 字數 2606 閱讀 3139

在上乙個博文中,我們用learning_curve函式來確定應該擁有多少的訓練集能夠達到效果,就像乙個人進行學習時需要做多少題目就能擁有較好的考試成績了。

本次我們來看下如何調整學習中的引數,類似乙個人是在早上7點鐘開始讀書好還是晚上8點鐘讀書好。

資料仍然利用手寫數字識別作為訓練資料:

from sklearn.datasets import load_digits

# 載入資料

digits = load_digits()

x = digits.data

y = digits.target

我們想要調整·svc(gamma=0.001)·svc中的gamma引數,看到底把gamma引數設定成哪個值是最優的。

因此需要定義測試的引數範圍,這裡設定了引數值的範圍為從10的-6次方到10的-2.3次方,總共5個值:

import numpy as np

# 定義gamma引數的可能取值範圍,從10**-6, 到10**-2.3,總共5個引數值

param_range = np.logspace(-6, -2.3, 5)

validation_curve不停嘗試在不同引數值下的損失函式值:

from sklearn.model_selection import validation_curve

from sklearn.svm import svc

# param_name中指定了修改svc中的哪個引數值,這裡修改的是gamma引數值;param_range引數指定了具體引數值的可選範圍

train_loss, test_loss = validation_curve(svc(), x, y, param_name="gamma", param_range=param_range, cv=10, scoring='neg_mean_squared_error')

train_loss_mean = -np.mean(train_loss, axis=1)

test_loss_mean = -np.mean(test_loss, axis=1)

視覺化圖形,橫座標為引數可選值的範圍,縱座標為在各引數下的損失函式值

# 視覺化圖形,橫座標為引數可選值的範圍,縱座標為在各引數下的損失函式值

import matplotlib.pyplot as plt

plt.plot(param_range, train_loss_mean, label="train")

plt.plot(param_range, test_loss_mean, label="test")

plt.legend()

plt.show()

圖形顯示為:

在這個圖形中,我們發現gamma值有乙個轉折點,當其在0.001之後,測試集的誤差值就開始擴大了,因此,從圖形上看,乙個比較好的學習引數值是gamma=0.001或者再往前一點點,大概在0.0007左右。

完整的**如下:

from sklearn.datasets import load_digits

# 載入資料

digits = load_digits()

x = digits.data

y = digits.target

import numpy as np

# 定義gamma引數的可能取值範圍,從10**-6, 到10**-2.3,總共5個引數值

param_range = np.logspace(-6, -2.3, 5)

from sklearn.model_selection import validation_curve

from sklearn.svm import svc

# param_name中指定了修改svc中的哪個引數值,這裡修改的是gamma引數值;param_range引數指定了具體引數值的可選範圍

train_loss, test_loss = validation_curve(svc(), x, y, param_name="gamma", param_range=param_range, cv=10, scoring='neg_mean_squared_error')

train_loss_mean = -np.mean(train_loss, axis=1)

test_loss_mean = -np.mean(test_loss, axis=1)

# 視覺化圖形,橫座標為引數可選值的範圍,縱座標為在各引數下的損失函式值

import matplotlib.pyplot as plt

plt.plot(param_range, train_loss_mean, label="train")

plt.plot(param_range, test_loss_mean, label="test")

plt.legend()

plt.show()

sklearn交叉驗證 老魚學sklearn

交叉驗證 cross validation 有時亦稱迴圈估計,是一種統計學上將資料樣本切割成較小子集的實用方法。於是可以先在乙個子集上做分析,而其它子集則用來做後續對此分析的確認及驗證。一開始的子集被稱為訓練集。而其它的子集則被稱為驗證集或測試集。交叉驗證是一種評估統計分析 機器學習演算法對獨立於訓...

sklearn資料庫 老魚學sklearn

在做機器學習時需要有資料進行訓練,幸好sklearn提供了很多已經標註好的資料集供我們進行訓練。本節就來看看sklearn提供了哪些可供訓練的資料集。這些資料位於datasets中,為 載入波士頓房價資料,可以用於線性回歸用 sklearn.datasets.load boston 載入方式為 fr...

sklearn儲存模型 老魚學sklearn

訓練好了乙個model 以後總需要儲存和再次 所以儲存和讀取我們的sklearn model也是同樣重要的一步。比如,我們根據 樣本資料訓練了一下房價模型,當使用者輸入自己的房子後,我們就需要根據訓練好的房價模型來 使用者房子的 這樣就需要在訓練模型後把模型儲存起來,在使用模型時把模型讀取出來對輸入...