tensorflow 分類損失函式問題(有點坑)

2021-09-02 14:14:41 字數 1348 閱讀 7764

tf.nn.softmax_cross_entropy_with_logits(記為f1) 和

tf.nn.sparse_softmax_cross_entropy_with_logits(記為f3),以及

tf.nn.softmax_cross_entropy_with_logits_v2(記為f2)

之間的區別。

f1和f3對於引數logits的要求都是一樣的,即未經處理的,直接由神經網路輸出的數值, 比如 [3.5,2.1,7.89,4.4]。兩個函式不一樣的地方在於labels格式的要求,f1的要求labels的格式和logits類似,比如[0,0,1,0]。而f3的要求labels是乙個數值,這個數值記錄著ground truth所在的索引。以[0,0,1,0]為例,這裡真值1的索引為2。所以f3要求labels的輸入為數字2(tensor)。一般可以用tf.argmax()來從[0,0,1,0]中取得真值的索引。

f1和f2之間很像,實際上官方文件已經標記出f1已經是deprecated 狀態,推薦使用f2。兩者唯一的區別在於f1在進行反向傳播的時候,只對logits進行反向傳播,labels保持不變。而f2在進行反向傳播的時候,同時對logits和labels都進行反向傳播,如果將labels傳入的tensor設定為stop_gradients,就和f1一樣了。

那麼問題來了,一般我們在進行監督學習的時候,labels都是標記好的真值,什麼時候會需要改變label?f2存在的意義是什麼?實際上在應用中labels並不一定都是人工手動標註的,有的時候還可能是神經網路生成的,乙個實際的例子就是對抗生成網路(gan)。

測試用**:

import tensorflow as tf

import numpy as np

truth = np.array([0,0,1,0])

pred_logits = np.array([3.5,2.1,7.89,4.4])

loss = tf.nn.softmax_cross_entropy_with_logits(labels=truth,logits=pred_logits)

loss2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=truth,logits=pred_logits)

loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(truth),logits=pred_logits)

with tf.session() as sess:

print(sess.run(loss))

print(sess.run(loss2))

print(sess.run(loss3))

參考:

TensorFlow 損失函式

import numpy as np import tensorflow as tf sess tf.interactivesession 1.多分類中的softmax函式在多分類的神經網路中,通常在最後一層接乙個softmax層。對於n分類問題,softmax層就有n個結點,每個結點輸出的就是該類...

TensorFlow損失函式

tensorflow損失函式 正如前面所討論的,在回歸中定義了損失函式或目標函式,其目的是找到使損失最小化的係數。本文將介紹如何在 tensorflow 中定義損失函式,並根據問題選擇合適的損失函式。宣告乙個損失函式需要將係數定義為變數,將資料集定義為佔位符。可以有乙個常學習率或變化的學習率和正則化...

tf 損失函式 TensorFlow裡面損失函式

2 交叉熵 交叉熵 crossentropy 也是loss演算法的一種,一般用在分類問題上,表達的意識為 輸入樣本屬於某一類的概率 其表示式如下,其中y代表真實值分類 0或1 a代表 值。在tensorflow中常見的交叉熵函式有 sigmoid交叉熵 softmax交叉熵 sparse交叉熵 加權...