神經網路更新引數的幾種方法

2021-09-02 14:00:04 字數 2444 閱讀 1882

梯度下降中,計算完各個引數的導數之後就需要更新引數值了,最常用的更新引數方法就是:

【sgd】:

x += - learning_rate * dx
但是這種方法收斂速度非常慢,其實除了這個更新引數的方法,還有很多的方法可以進行引數更新。

【momentum update】:

這個方法對於深度學習的網路引數更新往往有不錯的效果。本質意思是,在更新新的引數的時候需要考慮前乙個時刻的「慣性」,其更新引數如下:

# momentum update

v = mu * v - learning_rate * dx # integrate velocity

x += v # integrate position

上面計算方法和下面的等價(其中的ρ等價於上面的mu

其中一般的,v初始為0,mu是優化引數,一般初始化引數為0.9,當使用交叉驗證的時候,引數mu一般設定成[0.5,0.9,0.95,0.99],在開始訓練的時候,梯度下降較快,可以設定mu為0.5,在一段時間後逐漸變慢了,mu可以設定為0.9、0.99。也正是因為有了「慣性」,這個比sgd會穩定一些。

【nesterov momentum】

這是乙個和上面的momentum update有點不一樣的方法,這種方法最近得到了較為廣泛的運用,對於凸函式,它有更為快的收斂速度。

計算公式:

x_ahead = x + mu * v

# evaluate dx_ahead (the gradient at x_ahead instead of at x)

v = mu * v - learning_rate * dx_ahead

x += v

其基本思路如下:(參考自各種優化方法的比較)

首先,按照原來的更新方向更新一步(x_ahead,也就是棕色線),然後在該位置計算梯度值(也就是

dx_ahead,

紅色線),然後用這個梯度值修正最終的更新方向(綠色線)。上圖中描述了兩步的更新示意圖,其中藍色線是標準momentum更新路徑

【adagrad】

adagrad是一種自適應學習率的更新方法,計算方法如下:

# assume the gradient dx and parameter vector x

cache += dx**2

x += - learning_rate * dx / (np.sqrt(cache) + eps)

這個方法其實是動態更新學習率的方法,其中cache將每個梯度的平方和相加,而更新學習率的本質是,如果求得梯度距離越大,那麼學習率就變慢,而eps是乙個平滑的過程,取值通常在(10^-4~10^-8 之間)

【rmsprop】

rmspro是還沒有發布的方法,但是已經使用的額相當廣泛,其和adagrad的方法差不多,計算方法如下:

cache = decay_rate * cache + (1 - decay_rate) * dx**2

x += - learning_rate * dx / (np.sqrt(cache) + eps)

其中,decay_rate取值通常在[0.9,0.99,0.999]

【adam】

adam現在已經被廣泛運用了,adam的更新引數方法如下:

m = beta1*m + (1-beta1)*dx

v = beta2*v + (1-beta2)*(dx**2)

x += - learning_rate * m / (np.sqrt(v) + eps)

m,v一般初始化為0,而這篇**中,eps取值為1e-8 beta1=0.9 beta2=0.9999【幾種常見引數更新方法的比較】:

神經網路更新引數的幾種方法

梯度下降中,計算完各個引數的導數之後就需要更新引數值了,最常用的更新引數方法就是 sgd x learning rate dx 但是這種方法收斂速度非常慢,其實除了這個更新引數的方法,還有很多的方法可以進行引數更新。momentum update 這個方法對於深度學習的網路引數更新往往有不錯的效果。...

神經網路的幾種引數更新方法

method 3 adagrad 學習率衰減 method 4 adam 融合前兩個的方法 對比 實現 待更 本文預設讀者有深度學習基礎,所以不再解釋公式的具體含義 adagrad 會為引數的每個元素適當地調整學習率,與此同時進行學習 adagrad 的 ada 來自英文單詞 adaptive,即 ...

神經網路相關引數

關於建立 神經網路 段的引數說明 net feedforwardnet n,trainlm n為隱藏層大小,預設為10 trainlm 是被指定的訓練函式 levenberg marquardt 演算法t p net.trainparam.goal 0.001 goal是最小均方誤差的訓練目標 ne...