機器學習10 神經網路反向傳播演算法

2021-07-30 06:04:59 字數 2048 閱讀 8512

一,神經網路的代價函式

首先引入一些便於稍後討論的新標記方法:假設神經網路的訓練樣本有 m 個,每個包含一組輸入 x 和一組輸出訊號 y,l 表示神經網路層數,sl 表示每層的 neuron 個數(sl 表示輸出層神經元個數),sl  代表最後一層中處理單元的個數。將神經網路的分類定義為兩種情況:二類分類和多類分類,二類分類:sl=1, y=0 or 1 表示哪一類;k 類分類:sl=k, yi = 1 表示分到第 i 類;(k>2)

我們回顧邏輯回歸問題中我們的代價函式為:

在邏輯回歸中,我們只有乙個輸出變數,又稱標量(scalar),也只有乙個因變數 y,但是在神經網路中,我們可以有很多輸出變數,我們的 hθ(x)是乙個維度為 k 的向量,並且我們訓練集中的因變數也是同樣維度的乙個向量,因此我們的代價函式會比邏輯回歸更加複雜一些,為:

這個看起來複雜很多的代價函式背後的思想還是一樣的,我們希望通過代價函式來觀察演算法**的結果與真實情況的誤差有多大,唯一不同的是,對於每一行特徵,我們都會給出k 個**,基本上我們可以利用迴圈,對每一行特徵都** k 個不同結果,然後在利用迴圈在 k 個**中選擇可能性最高的乙個,將其與 y 中的實際資料進行比較。歸一化的那一項只是排除了每一層 θ0 後,每一層的 θ 矩陣的和。最裡層的迴圈 j 迴圈

所有的行(由 sl +1 層的啟用單元數決定),迴圈 i 則迴圈所有的列,由該層(sl 層)的啟用單元數所決定。即:hθ(x)與真實值之間的距離為每個樣本-每個類輸出的加和,對引數進行regularization 的 bias 項處理所有引數的平方和。

二,反向傳播演算法反向傳播演算法作用:求出神經網路代價函式的偏導數,便於進行梯度下降

之前我們在計算神經網路**結果的時候我們採用了一種正向傳播方法,我們從第一層開始正向一層一層進行計算,直到最後一層的 hθ(x)。現在,為了計算代價函式的偏導數,我們需要採用一種反向傳播演算法,也就是首先計算最後一層的誤差,然後再一層一層反向求出各層的誤差,直到倒數第二層。以乙個例子來說明反向傳播演算法。假設我們的訓練集只有乙個例項(x(1),y(1)),我們的神經網路是乙個四層的神經網路,其中 k=4,sl=4,l=4:

前向傳播演算法:

反向傳播演算法:

我們從最後一層的誤差開始計算,誤差是啟用單元的**(ak)與實際值(yk)之間的誤差,(k=1:k)。我們用 δ 來表示誤差,則:

重要的是清楚地知道上面式子中上下標的含義:

l 代表目前所計算的是第幾層

j 代表目前計算層中的啟用單元的下標,也將是下一層的第 j 個輸入變數的下標。

i 代表下一層中誤差單元的下標,是受到權重矩陣中第 i 行影響的下一層中的誤差單元的下標。

如果我們考慮歸一化處理,並且我們的訓練集是乙個特徵矩陣而非向量。在上面的特殊情況中,我們需要計算每一層的誤差單元來計算代價函式的偏導數。在更為一般的情況中,我們同樣需要計算每一層的誤差單元,但是我們需要為整個訓練集計算誤差單元,此時的誤差單元也是乙個矩陣,我們用 來表示這個誤差矩陣。第 l 層的第 i 個啟用單元受到第 j個引數影響而導致的誤差。

我們的演算法表示為:

即首先用正向傳播方法計算出每一層的啟用單元,利用訓練集的結果與神經網路**的結果求出最後一層的誤差,然後利用該誤差運用反向傳播法計算出直至第二層的所有誤差。在求出了之後,我們便可以計算代價函式的偏導數了,計算方法如下:

機器學習 反向傳播神經網路推導

簡單的反向傳播神經網路可以說就是將基本單元如感知器,組成層級結構,劃分出輸入層 隱含層 輸出層 不同層之間通過連線來形成耦合,從而組成乙個有功用的網狀演算法結構。感知器可以通過迭代計算來逼近想獲取的結果,迭代過程中感知器不斷和計算結果反饋,較為常用的迭代計算方法有梯度下降法。當感知器組成網路之後,每...

神經網路學習引入 反向傳播

反向傳播技術是神經網路的核心技術之一。如果乙個運算我們能夠使用計算圖表式出來,那麼我們就能夠利用反向傳播的方法,即從後往前遞迴地呼叫鏈式法則來計算梯度。乙個簡單的函式f x y,z x y zf x,y,z x y z f x,y,z x y z的計算圖如下所示。假設輸入為 x 2,y 5,z 4,...

機器學習 神經網路引數的反向傳播演算法

分類問題為多元分類和二元分類 delta是沒有1層的,因為第一層是我們的觀測值,沒有誤差 經過後面的推導,忽略了 正則項 就會得到左下角代價函式的偏導項 反向傳播演算法計算後,累計誤差的結果等於對代價函式的偏導。i 指的是第幾個樣本 j 指的是第 l 層的第 j 個節點 反向傳播具體的 error計...