RNN中網路結構的理解

2021-09-29 11:48:24 字數 3323 閱讀 8356

在使用tensorflow對構建rnn模型的時候,有幾個引數一直不能很好的理解它本身的結構,這對後續網路的修改產生了很大的問題,在網上查閱資料後對其中一些引數結構進行總結。

例子**如下:

#!/usr/bin/env python3

# -*- coding:utf-8 -*-

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

defweight_variable

(shape)

: initial = tf.truncated_normal(shape=shape, stddev=

0.1)

return tf.variable(initial)

defbias_variable

(shape)

: initial = tf.constant(

0.1, shape=shape)

return tf.variable(initial)

defrnn

(x, weights, bias, n_times, n_inputs, n_hidden_units)

:# inputs shape:(100,28,28)

inputs = tf.reshape(x,[-

1, n_times, n_inputs]

) lstm_cell = tf.contrib.rnn.basiclstmcell(n_hidden_units)

# todo tensorflow 刪除了core_rnn_cell

# todo lstm_cell = tf.contrib.rnn.core_rnn_cell.basiclstmcell

# output shape:(100,28,100) finall_state為乙個包含兩個元素的tuple,其中每個元素的shape都為(100,100)

output, finall_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)

# predicton shape:(100,10)

prediction = tf.matmul(finall_state[1]

, weights)

+ bias

return prediction

defmain()

: mnist = input_data.read_data_sets(

'mnist_data'

, one_hot=

true

) batch_size =

100 n_batch = mnist.train.num_examples // batch_size

n_inputs =

28 n_times =

28 n_hidden_units =

100 n_classes =

10 x = tf.placeholder(tf.float32,

[none

, n_inputs * n_times]

) y = tf.placeholder(tf.float32,

[none

, n_classes]

) weights = weight_variable(

[n_hidden_units, n_classes]

) bias = bias_variable(

[n_classes]

) prediction = rnn(x, weights, bias, n_times, n_inputs, n_hidden_units)

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction)

) optimizer = tf.train.adamoptimizer(1e-

4)train = optimizer.minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(prediction,1)

, tf.argmax(y,1)

) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)

)with tf.session(

)as sess:

sess.run(tf.global_variables_initializer())

for epoch in

range(11

):for batch in

range

(n_batch)

: batch_xs, batch_ys = mnist.train.next_batch(batch_size)

# todo batch_xs.shape=(batch_size,n_input*n_times)

sess.run(train, feed_dict=

) acc = sess.run(accuracy, feed_dict=

)print

('iter:'

+str

(epoch +1)

+'testing accuracy='

+str

(acc)

)if __name__ ==

'__main__'

: main(

)

以mnist手寫數字識別資料集為例,原資料集中每個大小均為[28,28],我們設定其中的batch_size為100,原始輸入資料就是[100,28*28],進行reshape操作為模型可以處理的資料[100,28,28]。

最關鍵的是tf.contrib.rnn.basiclstmcell(n_hidden_units)中n_hidden_units的含義,查閱資料後得知為網路輸出的向量維數。rnn中我們每個step輸入28維的資料,每個step輸出100維的資料,output輸出每個step的結果,所以最後output的shape為[100,28,100],finall_state為包含兩個狀態(c_state,h_state)的元組,其中每個狀態僅含有最後乙個step的資料,所以兩者的shape都為[100,100],其中m_state的內容和output中每個batch最後一行(也就是最後乙個step)一樣。

參考鏈結原理詳解以及tensorflow中的rnn實現/

caffe中網路結構的視覺化

實驗工具是 ubuntu系統下的caffe 視覺化方法很多,一種是用caffe自帶的draw net.py 來實現網路結構的視覺化。具體實現如下 python caffe python draw net.py train.prototxt net.png用python命令執行draw net.py ...

深度學習之神經網路結構 RNN 理解LSTM

本篇部落格移動到中。rnn 我們不是在大腦一片空白的情況下開始思考。當你讀這篇文章的時候,你是基於對前面單詞的理解來理解當前的單詞。你不會把所有的東西丟開,讓大腦每次都一片空白地思考。我們的思想是持久的。傳統的神經網路做不到這一點,這看起來是它的主要缺點。舉個例子,假設你正在看電影,你想對每個時間點...

RNN 迴圈神經的網路結構 特點及應用例項

迴圈神經網路 recurrent neural network,rnn 是用來建模序列化資料的一種主流深度學習模型。rnn將神經元序列起來,每個神經元能用它的內部變數儲存之間輸入的序列資訊來把整個序列濃縮成抽象表示,並據此進行分類或生成新的序列,解決了傳統前饋神經網路無法處理變長序列和難以捕捉序列中...