用rnn網路訓練mnist資料集

2021-10-07 08:37:44 字數 3197 閱讀 5577

與以往cnn卷積神經網路不同,rnn的思想是資料資訊有順序,所以rnn一般用來訓練文字資料資訊,就像小時候的填空題一樣,我們能根據上下文判斷這個空應該填什麼,這前面的文字和後面的文字都是有順序的。

用rnn訓練分類也有其優點,例如每個人都是頭朝上腳朝下(個別倒立的例外),在這裡我們就用rnn網路訓練mnist資料集。

mnist手寫數字是28*28的,假設每次訓練128張, 則訓練集資料的shape為[128, 28, 28]

這裡主要講一下rnn神經網路:

rnn是個有向迴圈網路,以這個demo為例,是28*28的,以28輪向rnn網路新增xt資訊(x1-x28),每次的資訊量是(batch_size, 28),與u進行矩陣相乘,(每一輪的u,w,v是通用的),運算公式:

g,f都是對應的啟用函式,設:

xt.shape=[b, m]

u.shape=[m, n], n為隱藏節點個數

w.shape=[n, n]

s(t-1) = [b, n]

所以st = tf.matmul(tf.concat([xt, s(t-1)], -1), tf.concat([u, w], 0))

在這裡我們只有n(隱藏節點個數)和s0是未知的 ,在**實現中,我們只需要3句**就能實現rnn網路過程:

# 初始化隱藏節點個數

rnn_cell = tf.nn.rnn_cell.basiclstmcell(hidden_num)

# 初始化s0

init_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 構建完整的rnn網路

out, out_last = tf.nn.dynamic_rnn(rnn_cell, x, initial_state=init_state, time_major=

false

)

完整**如下:

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# 取出資料集

mnist = input_data.read_data_sets(

"mnist_data_bak"

, one_hot=

true

)# 初始化引數+構建rnn神經網路

lr =

0.001

training_iters =

1000000

iter_num =

28height =

28hidden_num =

150batch_size =

128n_classes =

10x = tf.placeholder(dtype=tf.float32, shape=

[none

, iter_num, height]

)y = tf.placeholder(dtype=tf.float32, shape=

[none

, n_classes]

)weight = tf.variable(tf.random_normal(shape=

[hidden_num, n_classes]))

bais = tf.variable(tf.zeros(

[n_classes,])

)rnn_cell = tf.nn.rnn_cell.basiclstmcell(hidden_num)

init_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

out, out_last = tf.nn.dynamic_rnn(rnn_cell, x, initial_state=init_state, time_major=

false

)pred = tf.matmul(out_last[1]

, weight)

+ bais

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)

)train_op = tf.train.adamoptimizer(lr)

.minimize(cost)

acc_temp = tf.cast(tf.equal(tf.argmax(pred,1)

, tf.argmax(y,1)

), tf.float32)

accuracy = tf.reduce_mean(acc_temp)

# 構建訓練

init = tf.global_variables_initializer(

)with tf.session(

)as sess:

sess.run(init)

step =

0while step*batch_size < training_iters:

x_train, y_train = mnist.train.next_batch(batch_size)

x_train = x_train.reshape(

[batch_size, iter_num, height]

) _,

= sess.run(

[train_op]

, feed_dict=

) step +=

1if step %

20==0:

x_test, y_test = mnist.test.next_batch(batch_size)

x_test = x_test.reshape(

[batch_size, iter_num, height]

) acc = sess.run(accuracy, feed_dict=

)print

("acc:{}"

.format

(acc)

)

提取碼:b0c2

積極向上,每天向前一小步,總有一天能實現自己的價值。

用Mnist資料集訓練神經網路

這篇部落格是我在學習用tensorflow搭建神經網路時,所作的一些筆記。搭建的神經網路有兩層隱藏層,和輸入輸出層。採用全連線的方式進行傳輸,優化演算法採用自適用矩估計演算法。1.首先,匯入tensorflow官方提供的庫 import tensorflow as tf from tensorflo...

用rnn進行mnist資料集的處理

lstm rnn.basiclstmcell lstm size,forget bias 1.0,state is tuple true x split tf.split xr,time step size,0 t lstm cell會產生兩個內部狀態 ctct 和htht 關於rnn與lstm的介...

用Caffe 訓練和測試MNIST資料

c affe安裝包自帶mnist的例子。測試步驟如下 1.獲得mnist的資料報,在caffe的根目錄下執行.date mnist get mnist.sh指令碼 2.生成lmdb 執行.example mnist create mnist.sh。將mnist date 轉化成caffe可用的lmd...