tensorflow 基礎學習十 RNN

2022-07-05 10:00:14 字數 2246 閱讀 5387

rnn網路的結構:

上圖展示了乙個簡單的迴圈神經網路結構,在這個迴圈體中僅使用了乙個類似全連線的神經網路結構。迴圈神經網路中的狀態是通過乙個向量來表示的,這個向量的維度稱為迴圈神經網路隱藏層的大小,設其為h。從上圖可以看出,迴圈體中的神經網路的輸入包含兩部分,一分部為上一時刻的狀態,另一部分為當前時刻的輸入樣本。假設輸入向量的維度為x,則上圖中迴圈體的全連線層神經網路的輸入大小為h+x,即將上一時刻的狀態和當前時刻的輸入拼接成乙個大的向量作為迴圈體中神經網路的輸入。因為該神經網路的輸出為當前時刻的狀態,於是輸出層的節點個數也為h,迴圈體中的引數個數為(h+x)×h+h個。從上圖可以看到,迴圈體中的神經網路輸出不僅提供給下一時刻作為狀態,同時也會提供給當前時刻的輸出。為了將當前時刻的狀態轉換為最終的輸出,迴圈神經網路還需要另外乙個全連線神經網路來完成這個過程。不同時刻用於輸出的全連線神經網路中的引數也是一致的。

下面展示乙個迴圈神經網路前向傳播的具體計算過程:

上圖中,假設狀態的維度為2,輸入、輸出的維度都為1,迴圈體中的全連線層中權重為:$w_=\begin 0.1 & 0.2\\ 0.3 & 0.4\\ 0.5 & 0.6 \end$

偏置項的大小為brnn=[0.1,-0.1],用於輸出的全連線層權重為:$w_=\begin 1.0 \\ 2.0 \end$ ,偏置項大小為$b_=0.1$,那麼在$t_$時刻,因為沒有上一時刻,所以將狀態初始化為[0,0],而當前的輸入為1,所以拼接得到向量[0,0,1],通過迴圈體中的全連線層神經網路得到結果為:

$tanh\left ( [0,0,1] \times \begin 0.1 & 0.2 \\ 0.3 & 0.4 \\ 0.5 & 0.6\end+[0.1,-0.1]  \right )=tanh\left ( [0.6,0.5]\right )=[0.537,0.462]$

這個結果將作為下一時刻的輸入狀態,同時迴圈神經網路也會使用該狀態生成輸出,最終得到$t_$的輸出為:$[0.537,0.462] \times \begin 1.0 \\ 2.0 \end+0.1=1.56$

使用$t_$時刻的狀態可以類似地推導得出$t_$時刻的狀態為[0.860,0.884],而$t_$時刻的輸出為2.73。在得到迴圈神經網路的前向傳播結果後,可以和其他神經網路類似的定義損失函式。迴圈神經網路唯一的區別在於它每個時刻都有乙個輸出,所以迴圈神經網路的總損失為所有時刻上的損失函式的總和,以下**實現了這個迴圈神經網路前向傳播的過程。

import

numpy as np

x=[1,2]

state=[0.0,0.0]

#分開定義不同輸入部分的權重以方便計算

w_cell_state=np.array([[0.1,0.2],[0.3,0.4]])

w_cell_input=np.array([0.5,0.6])

b_cell=np.array([0.1,-0.1])

#定義用於輸出的全連線層引數。

w_output=np.array([[1.0],[2.0]])

b_output=0.1

#按照時間順序執行迴圈神經網路的前向傳播過程。

for i in

range(len(x)):

#計算迴圈體中的全連線層神經網路

before_activation=np.dot(state,w_cell_state)+x[i]*w_cell_input+b_cell

state=np.tanh(before_activation)

#根據當前時刻狀態計算最終輸出

final_output=np.dot(state,w_output)+b_output

#輸出每個時刻的資訊

print('

before activation:

',before_activation)

print('

state:

',state)

print('

output:

',final_output)

在實際應用中,如果序列過長會導致優化時出現梯度消散的問題,所以實際中一般會規定乙個最大長度,當序列長度超過規定長度之後會對序列進行截斷。

tensorflow基礎學習 會話

會話 tensorflow執行模型 一 tensorflow系統結構的概述 從圖中可以看出tensorflow的整個系統在結構上大體可以分為兩個子系統 前端系統和後端系統。其中前端系統提供程式設計模型,負責構造計算圖 後端系統提供執行時環境,負責執行計算圖。我們重點關注系統中client distr...

TensorFlow教程(十) 反向傳播

tensorflow通過宣告優化函式 optimization function 來實現,一旦宣告好優化函式,tensorflow將通過它在計算圖中解決反向傳播的項。當傳入資料,最小化損失函式,tensorflow會在計算圖中根據狀態相應的調節變數。coding utf 8 import tenso...

Tensorflow 基礎概念

g v,e v operation 圖的節點 e tensor 圖的邊 g graph 圖 tensorflow tensor 多維陣列 flow graph 圖 op session回話上下文管理 variable tensor 多維資料變數 placeholder 外部傳入的引數變數 seesi...