Theano入門 Logistic回歸

2021-07-05 16:57:47 字數 1984 閱讀 1101

參考鏈結(1)。

(1)輸入權重w

輸入權重矩陣的維度需要設定,行影象的大小,列為標籤one-hot編碼後的大小,cifar-10共有10類,所以為10;mnist標籤為0~9,也是10。

(2)學習率learning_rate

這裡為固定學習率。學習率越大收斂越快,但接近結果時可能會跳過最優解,逼近曲線鋸齒化明顯;學習率越小則相反。

(3)塊大小batch_size

batch_size越小收斂越慢,引入的隨機性越大,效果會更好。batch_size越大則相反。

注釋部分為mnist資料集的內容和設定,有效部分為cifar-10資料集的內容和設定。

(1)載入資料集

mnist函式載入mnist資料集,one-hot設定編碼格式;cifar10設定cifar-10資料集,dtype設定輸出資料的資料型別。

(2)模型初始化

先產生32*32行10列的權重,隨機數產生的例子如下:

(3)logistic回歸模型

假設有x和y矩陣,x為模型輸入,y為模型輸出。model函式先讓x和w相乘,x維度為(樣本數*32*32),w維度為(32*32,10),相乘且經過softmax回歸後的py_x的維度為(樣本數*10)。y_pred為py_x按列尋找出最大值,即每個樣本的最大的概率按行順序組成y_pred。

學習率設定為0.01。

categorical_crossentropy函式計算的近似概率密度分布py_x和實際概率密度分布y的交叉熵。交叉熵為py_x和y每一位的交叉熵,共10個數。這裡用求平均表示的整體交叉熵作為損失cost。熵越大,結果的不確定性越大。

grad函式計算損失的梯度,以固定的學習率learning_rate來更新權重。update為更新值的格式。

(4)交叉熵

同一事件集合上的兩個不同的概率分布間的交叉熵計算事件產生的位的平均值(這裡是指整個資料集標籤經過one-hot編碼後的標籤每位的平均值)。假設計算得到的概率密度分布為q,實際概率密度分布為p,則交叉熵為:h(p,q)=h(p)+dkl(p||q),其中h(p)為p的熵,dkl(p||q)為q相對於p的kl散度。離散的交叉熵為p(x)logq(x)的和的負值。

(5)theano的function函式說明

function函式包含輸入inputs,輸出outputs,更新updates和allow_input_downcast等。輸入和輸出不難理解,allow_input_downcast是允許資料精度降低。以**中的w為例,update表示函式每更新一次,權重w就會被w - gradient * learning_rate替換一次。

function左側的變數train和predict理解為函式名。

(6)模型訓練

模型訓練以塊為單位進行。batch_size決定了塊大小。start為塊的起始索引,end為塊的結束索引。所以train執行的次數小於等於塊的個數。外層執行100次,每次執行一次epoch時權重更新塊的個數次。

(7)模型**

predict函式輸入為x,輸出為y_pred。此時比較tex的輸出的y_pred的**索引和tey中的實際索引,計算相同索引的均值表示準確率accuracy。

logistic回歸對於mnist資料集的準確率大約有91%,而對於cifar-10資料集的準確率大約只有28%(如下圖),cifar-100資料集準確率更小。

(2)交叉熵:

(3)batch_size討論:

theano 入門教程1 7

1.7.1計算導數 使用theano.tensor.grad 函式計算梯度 import theano import theano.tensor as t from theano import pp x t.dscalar x y x 2 gy t.grad y,x pp gy fill x ten...

7 theano 安裝 學習theano的筆記

之前寫過有關theano在macos上的安裝部署,昨天又在windows上安裝了一版,發現還有theano.test 這種玩法,不知道有多少人的機器是全部通過的,6000多個測試程式,不發生問題的概率估計是不大。我自己的機器上現在還有7個failure,不打算找問題了,先用著再說。天天搞深度學習的,...

theano學習筆記

定義函式import theano.tensor as t from theano import function,pp 標量 x t.dscalar x 向量 x t.vector a 矩陣 x t.dmatrix x y t.dscalar y z x y f function x,y z 函式...