Tensorflow訓練之得到Nan錯誤的分析

2021-08-20 07:44:56 字數 1625 閱讀 3095

**:

之前在tensorflow中實現不同的神經網路,作為新手,發現經常會出現計算的loss中,出現nan值的情況,總的來說,

tensorflow中出現nan值的情況有兩種,一種是在loss中計算後得到了nan值,另一種是在更新網路權重等等資料的

時候出現了nan值,本文接下來,首先解決計算loss中得到nan值的問題,隨後介紹更新網路時,出現nan值的情況。

loss計算中出現nan值

大致的解決辦法就是,在出現nan值的loss中一般是使用的tensorflow的log函式,然後計算得到的nan,一般是輸入的值中出現了負數值或者0值,在tensorflow的官網上的教程中,使用其偵錯程式除錯nan值的出現,也是查到了計算log的傳參為0;而解決的辦法也很簡單,假設傳參給log的引數為y,那麼在呼叫log前,進行一次數值剪下,修改呼叫如下:

loss = tf.log(tf.clip_by_value(y,1e-8,1.0))

這樣,y的最小值為0的情況就被替換成了乙個極小值,1e-8,這樣就不會出現nan值了,stackoverflow上也給出了相同的解決方案。於是,我就採用了上述的解決方案對於log的引數進行數值限制,但是我更加複雜化了這個限制。

tf.clip_by_value這個函式,是將第乙個引數,限制在第

二、三個引數指定的範圍之內,使用這個函式的原意是要避免0值,並沒有限制最大值,因而我將限制的呼叫修改如下:

loss = tf.log(tf.clip_by_value(y,1e-8,tf.reduce_max(y)))

這樣就確保了對於y值的剪下,不會影響到其數值的上限。但是在實際的神經網路中使用的時候,我發現這樣修改後,雖然loss的數值一直在變化,可是優化後的結果幾乎是保持不變的,這就存在問題了。

經過檢查,其實並不能這麼簡單的為了持續訓練,而修改計算損失函式時的輸入值。這樣修改後,loss的數值很可能(存在0的話確定就是)假的數值,會對優化器優化的過程造成一定的影響,導致優化器並不能正常的工作。

要解決這個假的loss的方法很簡單,就是人為的改造神經網路,來控制輸出的結果,不會存在0。這就需要設計好最後一層輸出層的啟用函式,每個啟用函式都是存在值域的,詳情請見部落格

比如要給乙個在(0,1)之間的輸出(不包含0),那麼顯然sigmoid是最好的選擇。不過需要注意的是,在tensorflow中,tf.nn.sigmoid函式,在輸出的引數非常大,或者非常小的情況下,會給出邊界值1或者0的輸出,這就意味著,改造神經網路的過程,並不只是最後一層輸出層的啟用函式,你必須確保自己大致知道每一層的輸出的乙個範圍,這樣才能徹底的解決nan值的出現。

舉例說明就是tensorflow的官網給的教程,其輸出層使用的是softmax啟用函式,其數值在[0,1],這在設計的時候,基本就確定了會出現nan值的情況,只是發生的時間罷了。

更新網路時出現nan值

更新網路中出現nan值很難發現,但是一般除錯程式的時候,會用summary去觀測權重等網路中的值的更新,因而,此時出現nan值的話,會報錯類似如下:

invalidargumenterror (see above for traceback): nan in summary histogram for: weight_1

這樣的情況,一般是由於優化器的學習率設定不當導致的,而且一般是學習率設定過高導致的,因而此時可以嘗試使用更小的學習率進行訓練來解決這樣的問題。

TensorFlow訓練Logistic回歸

如下圖,可以清晰看到線性回歸和邏輯回歸的關係,乙個線性方程被邏輯方程歸一化後就成了邏輯回歸。對於二分類,輸出假如線性回歸模型為,則要將z轉成y,即y g z 於是最直接的方式是用單位階躍函式來表示,即 如圖,但階躍函式不連續,於是用sigmoid函式替代之,為 如圖,則有,即logistics函式,...

Tensorflow訓練迴圈

def fit loop model,inputs,targets,sample weights none,class weight none,val inputs none,val targets none,val sample weights none,batch size none,epoch...

tensorflow 資料訓練

一 資料訓練遇到問題 excle資料,如何進行訓練?excle資料,如何resize 呢?目前思路 tfrecords 採用 numpy的方法進行處理 學習方法 從檔案中讀取資料 標準化格式tfrecords記錄 二 資料預處理 numpy 不能有中文,要採用decode等方法 不能夠有百分號?目前...