Tensorflow 權重衰減的使用

2021-07-31 11:56:08 字數 1790 閱讀 7245

在 tf.get_variable 這個函式中有乙個命名引數為 regularizer,顧名思義,這個引數可用於正則化。在 tensorflow 官網中,regularizer 的描述如下:

get_variable(

name,

shape=none,

dtype=none,

initializer=none,

regularizer=none,

trainable=true,

collections=none,

caching_device=none,

partitioner=none,

validate_shape=true,

use_resource=none,

custom_getter=none

)

在使用的時候,我們可以自定義乙個函式,只要這個函式的輸入引數和返回值都是 tensor,這個函式就能作為 regularizer 這個命名引數傳入 get_variable 函式。例如我們可以定義如下的損失函式:

def

regularizer

(tensor):

with tf.name_scope(scope, default_name='l2_regularizer', values=[tensor]):

l2_weight = tf.convert_to_tensor(weight_decay, \

dtype=tensor.dtype.base_dtype, \

name='weight_decay')

return tf.multiply(l2_weight, tf.nn.l2_loss(tensor), name='value')

我還沒找到說明 tensor 這個引數的文件,猜測這個 tensor 傳過來的變數就是 get_variable 這個函式的返回值,當我們用 get_variable 獲取權重時,這個 tensor 就是權重,所以 regularizer 這個函式返回的就是乙個 l2 正則項。

目前為止,我們僅僅是知道了這個引數的含義,但並不知道 tensorflow 底層會怎樣去使用這個引數。需要注意,tensorflow 沒有自動把這個正則項自動加到優化當中去,這意味著在構建 loss 階段我們必須把正則項加入進去。

由官方文件可以知道,我們可以通過 tf.get_collection(tf.graphkeys.regularization_losses) 獲取正則項,這個函式得到的是由若干個型別為 tensor 的元素構成的 list。具體來說,如果我們為網路中的 n 個權重變數都加上 regularizer 這個引數,那麼在我們應該得到 n 個對應的正則項,從而通過函式 tf.get_collection(tf.graphkeys.regularization_losses) 能獲取乙個 n 個元素構成的 list,這裡每乙個元素分別對應不同的權重的正則項。

由ufldl 的 backpropagation algorithm 可以知道在優化問題中,損失函式即為均方差以及所有權重的正則項之和,所以在構建 loss 時,我們只要把均方差與上述 tf.get_collection(tf.graphkeys.regularization_losses) 的所有元素之和相加即可得到最終的 loss。

regularization_losses = tf.get_collection(tf.graphkeys.regularization_losses)

loss = tf.add_n(regularization_losses) + loss

正則化與權重衰減

1.權重衰減 weight decay l2正則化的目的就是為了讓權重衰減到更小的值,在一定程度上減少模型過擬合的問題,所以權重衰減也叫l2正則化。其中c0代表原始的代價函式,後面那一項就是l2正則化項,它是這樣來的 所有引數w的平方的和,除以訓練集的樣本大小n。就是正則項係數,權衡正則項與c0項的...

權重衰減 L2正則化

正則化方法 防止過擬合,提高泛化能力 避免過擬合的方法有很多 early stopping 資料集擴增 data augmentation 正則化 regularization 包括l1 l2 l2 regularization也叫weight decay dropout。權重衰減 weight d...

Tensorflow中檢視權重

剛開始學習tensorflow,還不太會用,開個博記錄,今天遇到乙個問題是用tf.layers.dense建立的全連線層,如何檢視權重?知道kernel表示了權重,但是如何提示成變數?我分成兩步 1 檢視tensor tf.trainable variables 命令列裡中執行即可,如下圖 可以看到...