Pytorch交叉熵損失的理解

2021-10-09 09:23:18 字數 2930 閱讀 5340

# 讓我們康康這個小栗子

>>

>

input

= torch.randn(3,

5, requires_grad=

true

)# 定義輸入(可看成是神經網路最後一層的輸出),需要梯度資訊以被更新

>>

> target = torch.randint(5,

(3,)

, dtype=torch.int64)

# 定義交叉熵類別

>>

> optimizer = torch.optim.adam(

[input

],lr=1)

# 為了效果明顯,學習率設定成1了

>>

>

input

# 還沒被更新前的輸入

tensor([[

0.9230

,1.4724,-

0.3339,-

0.5156,-

0.6469],

[-1.8437

,0.6457,-

1.3924,-

0.2757,-

0.0595],

[-0.6466,-

0.1059,-

0.9250

,0.4411,-

1.6658]]

, requires_grad=

true

)>>

> target # 即第一行的輸入的正確標籤為4,第二行的輸入的正確標籤為4,第三行的輸入的正確標籤為0

tensor([4

,4,0

])>>

> optimizer.zero_grad(

)# 計算損失前,先把優化器的歷史梯度資訊清零

>>

> loss=torch.nn.functional.cross_entropy(

input

,target)

# 使用交叉熵計算損失

>>

> loss.backward(

)# 進行梯度計算

>>

> optimizer.step(

)# 使用優化器對需要被更新的引數進行更新

>>

>

input

# 可見input已經被更新了。輸出層整層被更新體現在一行資料全部都發生了變化

tensor([[

-0.0770

,0.4724,-

1.3339,-

1.5156

,0.3531],

[-2.8437,-

0.3543,-

2.3924,-

1.2757

,0.9405],

[0.3534,-

1.1059,-

1.9250,-

0.5589,-

2.6658]]

, requires_grad=

true

)>>

>

input

.grad # 從input的梯度中可以看到對應一行4列,二行4列,三行0列的指示著input的更新

tensor([[

0.0963

,0.1668

,0.0274

,0.0228,-

0.3133],

[0.0131

,0.1583

,0.0206

,0.0630,-

0.2551],

[-0.2843

,0.0841

,0.0371

,0.1454

,0.0177]]

)>>

> loss # nllloss即negative log-likelihood loss

tensor(

2.0596

, grad_fn=

)

引用官方文件:

input為(n, c),其中的n可看做為一次計算得出的乙個結果,c是類別。即我們的例子是一次計算得到了5個神經元給出的值。而我們總共計算了3次。所以input是(3, 5)。

對於第一行來說,因為4為正確的標籤,而第一行的4列(最後一列)可解釋成神經網路正確標籤對應類別的神經元輸出的概率,即一開始為-0.6469,此時神經元不認為4類時正確的類,而更新後變成0.3531,變大了,讓神經元更多地認為4類時正確的類了,榆次同時,其他類別的值也同時被減小了。

再引用官方文件:

可以看到交叉熵損失可以用-log(正確類所佔的比率)來刻畫,也即為negative log-likelihood(負對數似然)即負的log(裡面算softmax得到概率)。總體使得正確類的對應的指示神經元,在正確類輸入資料輸入進來時,輸出較大的值,使得它佔所有類別的指示神經元的值的比重最大,相應地可解釋成神經網路認為這個正確類的輸入資料是正確類的概率最大。

我們可以粗略看看pytorch中計算的交叉熵損失函式loss(x, class)即對某一條資料的屬於正確那一類計算損失,但我們不要以為模型每次都只更新指示正確的那一類的那個神經元,其他神將元就都不更新了。這種想法是錯的,因為雖說loss(x, class)是對指示著某個class的神經元的那個輸出值計算loss,而在計算的過程中也涉及到了其他非指示正確類的神經元的值,所以在更新正確類神經元的同時也更新非正確類的神經元,即神經網路的最後一層(輸出層)整層被同時更新【可參考上面**的input】。

原因在於loss(x, class)的計算中同時涉及到了正確類的神經元的輸出值與其它非正確類的神經元的輸出值。

交叉熵損失函式理解

交叉熵損失函式的數學原理 我們知道,在二分類問題模型 例如邏輯回歸 logistic regression 神經網路 neural network 等,真實樣本的標籤為 0,1 分別表示負類和正類。模型的最後通常會經過乙個 sigmoid 函式,輸出乙個概率值,這個概率值反映了 為正類的可能性 概率...

交叉熵損失函式 交叉熵損失函式和均方差損失函式引出

交叉熵 均方差損失函式,加正則項的損失函式,線性回歸 嶺回歸 lasso回歸等回歸問題,邏輯回歸,感知機等分類問題 經驗風險 結構風險,極大似然估計 拉普拉斯平滑估計 最大後驗概率估計 貝葉斯估計,貝葉斯公式,頻率學派 貝葉斯學派,概率 統計 記錄被這些各種概念困擾的我,今天終於理出了一些頭緒。概率...

交叉熵損失函式

公式 分類問題中,我們通常使用 交叉熵來做損失函式,在網路的後面 接上一層softmax 將數值 score 轉換成概率。如果是二分類問題,我們通常使用sigmod函式 2.為什麼使用交叉熵損失函式?如果分類問題使用 mse 均方誤差 的方式,在輸出概率接近0 或者 接近1的時候,偏導數非常的小,學...