TF Multi-layer LSTMs and Fusing the States of Two Encoders


First: building a multi-layer LSTM network.

Second: concatenating the states of two LSTM encoders, and analyzing the structure of the state.

For the first problem, I had never paid attention to this before; compare the following examples:


import tensorflow as tf

num_units = [20, 20]

# unit1, OK: one distinct cell object per layer, built with a list comprehension
# x = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# x = tf.reshape(x, [-1, 5, 6])
# multi_rnn = [tf.nn.rnn_cell.BasicLSTMCell(num_units=units) for units in num_units]
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, x, time_major=True, dtype=tf.float32)

# unit2, OK: one distinct cell object per layer, built in a loop
# x = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
# x = tf.reshape(x, [-1, 5, 6])
# multi_rnn = []
# for i in range(2):
#     multi_rnn.append(tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units[i]))
# lstm_cells = tf.contrib.rnn.MultiRNNCell(multi_rnn)
# output, state = tf.nn.dynamic_rnn(lstm_cells, x, time_major=True, dtype=tf.float32)

# unit3 *********ERROR***********: [cell] * 2 puts the SAME cell object in both layers
x = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
x = tf.reshape(x, [-1, 5, 6])
# single_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=20)  # same as below
lstm_cells = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(num_units=20)] * 2)
output, state = tf.nn.dynamic_rnn(lstm_cells, x, time_major=True, dtype=tf.float32)
print(output)
print(state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for var in tf.global_variables():
        print(var.op.name)
    output_run, state_run = sess.run([output, state])

I really had not noticed this before: although the layers usually share the same dimensionality, I had always written it in the unit2 form. The reason unit3 fails is that [cell] * 2 puts the same cell object into both layers, so both layers try to share one set of weights; layer 1 consumes 6-dimensional inputs while layer 2 consumes layer 1's 20-dimensional outputs, so the kernel shapes conflict and TensorFlow raises a variable-sharing ValueError.
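A minimal sketch of my own (not from the original post) that makes the object sharing visible; the identity comparisons need no session to run:

import tensorflow as tf

shared = [tf.nn.rnn_cell.BasicLSTMCell(num_units=20)] * 2
print(shared[0] is shared[1])      # True: one cell object, one set of weights for both layers

separate = [tf.nn.rnn_cell.BasicLSTMCell(num_units=20) for _ in range(2)]
print(separate[0] is separate[1])  # False: each layer gets its own weights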

The second problem: fusing the states of two encoders while keeping the state type (LSTM/GRU).
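Before the fusion code, a small probe of my own (assuming TF 1.x, same as the rest of the post) to show what the per-layer state actually looks like: for a two-layer LSTM, state is a tuple of two LSTMStateTuple(c, h) pairs, while for a two-layer GRU it is a tuple of two plain tensors; this is exactly what the isinstance branches in the function below distinguish.

import tensorflow as tf

inputs = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)

with tf.variable_scope("lstm_probe"):
    lstm_stack = tf.contrib.rnn.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(num_units=20) for _ in range(2)])
    _, lstm_state = tf.nn.dynamic_rnn(lstm_stack, inputs, dtype=tf.float32)

with tf.variable_scope("gru_probe"):
    gru_stack = tf.contrib.rnn.MultiRNNCell(
        [tf.nn.rnn_cell.GRUCell(num_units=20) for _ in range(2)])
    _, gru_state = tf.nn.dynamic_rnn(gru_stack, inputs, dtype=tf.float32)

print(lstm_state)  # (LSTMStateTuple(c=<[3, 20]>, h=<[3, 20]>), LSTMStateTuple(...))
print(gru_state)   # (<tensor [3, 20]>, <tensor [3, 20]>)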

import tensorflow as tf

def concate_rnn_states(num_layers, encoder_state_local, encoder_state_global):
    ''':param num_layers: number of RNN layers
    :param encoder_state_local:
        for LSTM: (LSTMStateTuple(c=..., h=...),
                   LSTMStateTuple(c=..., h=...))
        for GRU: (tensor, tensor)
    :param encoder_state_global: same structure as encoder_state_local
    :return: tuple, one fused state per layer
    '''
    encoder_states = []
    for i in range(num_layers):
        if isinstance(encoder_state_local[i], tf.nn.rnn_cell.LSTMStateTuple):
            # for LSTM: concatenate c and h separately, then rewrap in an
            # LSTMStateTuple so the fused state keeps the LSTM state type
            encoder_state_c = tf.concat(values=(encoder_state_local[i].c, encoder_state_global[i].c), axis=1,
                                        name="concat_layer{}_state_c".format(i))
            encoder_state_h = tf.concat(values=(encoder_state_local[i].h, encoder_state_global[i].h), axis=1,
                                        name="concat_layer{}_state_h".format(i))
            encoder_state = tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
        elif isinstance(encoder_state_local[i], tf.Tensor):
            # for GRU and vanilla RNN: the per-layer state is a plain tensor
            encoder_state = tf.concat(values=(encoder_state_local[i], encoder_state_global[i]), axis=1,
                                      name='gruorrnn_concat')
        encoder_states.append(encoder_state)
    return tuple(encoder_states)

num_units = [20, 20]

# unit1
x = tf.random_normal(shape=[3, 5, 6], dtype=tf.float32)
x = tf.reshape(x, [-1, 5, 6])

with tf.variable_scope("encoder1") as scope:
    local_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    local_lstm_cells = tf.contrib.rnn.MultiRNNCell(local_multi_rnn)
    encoder_output_local, encoder_state_local = tf.nn.dynamic_rnn(local_lstm_cells, x, time_major=False,
                                                                  dtype=tf.float32)

with tf.variable_scope("encoder2") as scope:
    global_multi_rnn = [tf.nn.rnn_cell.GRUCell(num_units=units) for units in num_units]
    global_lstm_cells = tf.contrib.rnn.MultiRNNCell(global_multi_rnn)
    encoder_output_global, encoder_state_global = tf.nn.dynamic_rnn(global_lstm_cells, x, time_major=False,
                                                                    dtype=tf.float32)

print("concat output")
encoder_outputs = tf.concat([encoder_output_local, encoder_output_global], axis=-1)
print(encoder_output_local)
print(encoder_outputs)

print("concat state")
print(encoder_state_local)
print(encoder_state_global)
encoder_states = concate_rnn_states(2, encoder_state_local, encoder_state_global)
print(encoder_states)
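After fusion, each layer's GRU state grows from [batch, 20] to [batch, 40], so a decoder initialized from it needs a matching hidden size. A hypothetical follow-up sketch (my addition, not in the original post; decoder_inputs and its shape are made up):

# Hypothetical decoder: num_units must be 40 (= 20 + 20) per layer so that
# the fused encoder_states can serve as its initial state.
with tf.variable_scope("decoder"):
    decoder_cells = tf.contrib.rnn.MultiRNNCell(
        [tf.nn.rnn_cell.GRUCell(num_units=40) for _ in range(2)])
    decoder_inputs = tf.random_normal(shape=[3, 7, 8], dtype=tf.float32)  # made-up shapes
    decoder_outputs, decoder_state = tf.nn.dynamic_rnn(
        decoder_cells, decoder_inputs, initial_state=encoder_states, dtype=tf.float32)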
