shortcut和殘差連線

2021-10-24 22:38:23 字數 1459 閱讀 9075

最近搜尋了下這幾兩個概念,記錄一下個人理解。

若沒有加入identity分支,那麼就是用非線性變化函式來描述乙個網路的輸入輸出,即輸入為x,輸出為f(x),f通常包括了卷積,啟用等操作。

但是當我們強行將乙個輸入新增到函式的輸出的時候,雖然我們仍然可以用g(x)來描述輸入輸出的關係,但是這個g(x)卻可以明確的拆分為f(x)和x的線性疊加。將輸出表述為輸入和輸入的乙個非線性變換的線性疊加。它解決了深層網路無法訓練的問題。

首先我們定義殘差單元:

xl和xl+1表示的是第l個殘差單元的輸入和輸出,f是殘差結構,表示學習到的殘差,當h(xl)=xl時表示的就是恒等對映,f是relu啟用函式。

通過遞迴,可以得到任意深層單元l特徵的表達:

反向傳播過程為:

表示損失函式到達l的梯度,小括號裡的1表示短路機制(identity x)可以無損地傳播梯度,而另一項殘差梯度則需要經過帶有weights的層,殘差梯度不會那麼巧全為-1,就算其很小,由於1的存在不會導致梯度消失,所以殘差學習會更容易。

再舉個例子看看殘差網路是如何改善梯度消失現象的:

假設輸入只有乙個特徵,沒有偏置單元,每層只有乙個神經元:

我們先進行前向傳播,這裡將sigmoid激勵函式寫為s(x):

z1 = w1*x

a1 = s(z1)

z2 = w2*a1

a2 = s(z2)

zn = wn*an-1 (這裡n-1是下標)

an = s(zn)

根據鏈式求導和反向傳播,我們很容易得出,其中c是代價函式

那如果在a1和a2之間加入殘差連線,如下所示:

那麼z2=a1*w2+a1

所以z2對a1求導的結果就是(w2+1)

上邊的鏈式求導、反向傳輸的結果中的w2就變成了(w2+1)

所以殘差連線可以有效緩解梯度消失的現象。

最後乙個例子:

resnet網路就是用到了這種殘差連線。

Transformer的殘差連線

在學習transformer的過程中,編碼器和解碼器都用到了殘差連線,下面我總結一下殘差連線。假如我們的輸入為x,要得到的輸出為h x 那麼我們可以通過 h f x x,轉換為學習f。等得到f的輸出後,在此基礎上加上x即可得h的輸出。在transformer中,此時的f即是下圖中的multi hea...

殘差和損失函式

按自己理解的 殘差就是 y f x 就好像殘差網路裡面就是用這個公式。損失函式就是 根據需求定義的 對f x 與y的差異的度量方法,這跟度量空間有關,比如 下面的情況,要你比較兩個學生誰更好,我們可以設定學習成績作為比較方法,或者身高等等,這些就好像歐幾里得距離,或者余弦距離一樣,是設定出來的的度量...

理解誤差和殘差

誤差 所有不同樣本集的均值的均值,與真實總體均值的偏離.由於真實總體均值通常無法獲取或觀測到,因此通常是假設總體為某一分布型別,則有n個估算的均值 表徵的是觀測 測量的精確度 誤差大,由異常值引起.表明資料可能有嚴重的測量錯誤 或者所選模型不合適,殘差 某樣本的均值與所有樣本集均值的均值,的偏離 表...