ML 四 RNN實現MNIST的識別

2021-08-28 07:54:10 字數 2935 閱讀 7755

本文使用的**來自網上的開源專案

rnn的理解相較與cnn有一定的難度,本文不做rnn原理的講解,我貼乙個寫的非常好的rnn講解的部落格,有興趣的朋友可以去這裡學習了解下:

有關lstm的了解,可以看這裡:

首先,慣例讀入mnist

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('mnist_data',one_hot=true)

然後定義一些超參

hidden_layer = 128

batch_size = 128

inputs_num = 28

steps_num = 28

class_num = 10

lr = 0.001

第乙個:隱藏層神經元個數;

第五個:類別數,0~9所以10個;

最後:學習率

weights = 

biases =

由於,rnn中權值與偏置值均有兩種,方便呼叫所以用字典的形式定義

x = tf.reshape(x,[batch_size*steps_num,inputs_num])#128*28  28

x_in = tf.matmul(x,weights['in'])+biases['in']#128*28 128

x_in = tf.reshape(x_in,[-1,steps_num,hidden_layer])#_ 28 128

cell = tf.nn.rnn_cell.basiclstmcell(hidden_layer,forget_bias=1.0,state_is_tuple=true)

#遺忘偏置值是1表示全不遺忘,0表示全遺忘

_init_state = cell.zero_state(batch_size,tf.float32)

#通過zero_state得到乙個全0的初始狀態,形狀為(batch_size, state_size),這裡的state_size等於隱層神經元的數目

output,states = tf.nn.dynamic_rnn(cell,x_in,initial_state=_init_state,time_major=false)

result = tf.matmul(states[1],weights['out'])+biases['out']

這就是我們定義的rnn,其中使用的是lstm;

在使用rnn時,比較重要的一點就是我們時刻要記住我們的輸入輸出形狀,每一步的形狀我基本已在**後標註出來

在傳入rnn時,我們需要將轉成128*28,28的形狀,在傳入cell時,要將它轉成三維的形狀

這裡就比較難理解了,rnn難理解的點也就在這——輸入與輸出

首先,要先了解什麼是「rnncell」,它可以說是rnn的基本單元、rnn的必備單元、乙個神經網路是不是rnn的重要識別標誌(這麼說可能比較片面),每個rnncell都有乙個call方法,使用方式是:output, states = call(input, state)。

舉個栗子,假設我們有乙個初始狀態h0,還有輸入x1,呼叫call(x1, h0)後就可以得到(output1, h1): 再呼叫一次call(x2, h1)就可以得到(output2, h2): 也就是說,每呼叫一次rnncell的call方法,就相當於在時間上「推進了一步」。

state_size是隱層的大小,output_size是輸出的大小,比如我們將乙個batch送入模型計算,設輸入資料的形狀為(batch_size, input_size),那麼計算時得到的隱層狀態就是(batch_size, state_size),輸出就是(batch_size, output_size)。

最後我們呼叫tf.nn.dynamic_rnn一次執行多步,output的形狀為(batch_size, time_steps, cell.output_size)。states是最後一步的隱狀態,它的形狀為(batch_size, cell.state_size)

最後,我將rnn中每一步的形狀,寫出來,如果你還是理解不了,請原諒我的笨,我實在不知道還能怎麼解釋了,

x:  128*28 28

x_in:  128*28 128

x_in:   128 28 128

output: 128 28 128

states :  128 128

result : 128 10

pred = rnn(x,weights,biases)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels= y))

train = tf.train.adamoptimizer(lr).minimize(cost)

correct = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))

accuracy = tf.reduce_mean(tf.cast(correct,dtype=tf.float32))

之後就是訓練與計算準確率

batch_xs,batch_ys = mnist.train.next_batch(batch_size)

batch_xs = batch_xs.reshape([batch_size,steps_num,inputs_num])

sess.run(train,feed_dict=)

if step%100 == 0:

print(sess.run(accuracy,feed_dict=))

流程與**在這了,希望對你有幫助

rnn是真的難理解,寫著寫著我自己都開始暈了  =。=!!!

MNIST資料集手寫體識別 RNN實現

github部落格傳送門 csdn部落格傳送門 tensorflow python基礎 深度學習基礎網路模型 mnist手寫體識別資料集 import tensorflow as tf mnist input data.read data sets mnist data one hot true c...

用rnn進行mnist資料集的處理

lstm rnn.basiclstmcell lstm size,forget bias 1.0,state is tuple true x split tf.split xr,time step size,0 t lstm cell會產生兩個內部狀態 ctct 和htht 關於rnn與lstm的介...

用簡單卷積神經網路實現MNIST資料集識別

第一次寫部落格,記錄一下最近的學習經歷吧,最近在學卷積神經網路,自己就寫了乙個比較簡單的卷積神經網路實現了mnist資料集的識別,本來是想用lenet5來實現的,感覺lenet5太老了,所以就寫了乙個差不多的卷積神經網路來實現mnist資料集的識別。希望可以幫助一些剛學習卷積神經網路的朋友,也可以根...