RNN原理及其解決MNIST手寫數字識別

2021-10-02 05:44:45 字數 3815 閱讀 8550

雙向迴圈神經網路

從這張圖可以看出,在經歷多次傳播之後,前面遠的神經元對現在的影響會變小,即出現梯度消失的情況,我們希望網路有個好的心態——選擇性記憶和遺忘,對他有影響的記下來,對他沒有沒有影響的忘記。為了解決rnn存在的梯度消失問題,科學家們花費7年時間,分析出了lstm(長短時記憶網路),解決了rnn的問題,此前,說人們通過rnn取得了顯著的成果,這些成果基本上都是使用lstm實現的。這足以表明lstm的強大。

假設某輪訓練中,各時刻的梯度以及最終的梯度之和如下圖:

我們就可以看到,從上圖的t-3時刻開始,梯度已經幾乎減少到0了。那麼,從這個時刻開始再往之前走,得到的梯度(幾乎為零)就不會對最終的梯度值有任何貢獻,這就相當於無論t-3時刻之前的網路狀態h是什麼,在訓練中都不會對權重陣列w的更新產生影響,也就是網路事實上已經忽略了t-3時刻之前的狀態。這就是原始rnn無法處理長距離依賴的原因。

lstm核心結構圖

前面描述的開關是怎樣在演算法中實現的呢?這就用到了門(gate)的概念。門實際上就是一層全連線層,它的輸入是乙個向量,輸出是乙個0到1之間的實數向量。假設w是門的權重向量,b是偏置項。

門的使用,就是用門的輸出向量按元素乘以我們需要控制的那個向量。因為門的輸出是0到1之間的實數向量,那麼,當門輸出為0時,任何向量與之相乘都會得到0向量,這就相當於啥都不能通過;輸出為1時,任何向量與之相乘都不會有任何改變,這就相當於啥都可以通過。因為(也就是sigmoid函式)的值域是(0,1),所以門的狀態都是半開半閉的。

lstm結構圖展開

所有門的輸入為3個:本層的cell,上一層的輸出,本層的輸入,門的開閉相當於權值。是訓練出來的。

lstm工作過程(一看就明白)

rnn+lstm的**實現,解決mnist手寫數字識別的問題。核心在於函式rnn,其他的部分和之前用全連線網路實現是一樣的。

import tensorflow as tf

import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

#載入mnist資料集

mnist=input_data.read_data_sets(

"mnist_data"

,one_hot=

true

)#輸入是28*28

max_time=

28#一共有28行

n_inputs=

28#一行有28個資料

lstm_size=

100#隱藏層單元,這裡可不是神經元,而是乙個bloack

n_classes=

10#10個分類

batch_size=

50n_batch=mnist.train.num_examples//batch_size #//是整除的意思。計算一共有多少個批次

x=tf.placeholder(tf.float32,

[none

,784])

#none表示可以取任意值,方便後面傳入。

y=tf.placeholder(tf.float32,

[none,10

])#y就是標籤,正確答案

#建立乙個簡單的神經網路(前向傳播)

weights=tf.variable(tf.truncated_normal(

[lstm_size,n_classes]

,stddev=

0.1)

)biases=tf.variable(tf.constant(

0.1,shape=

[n_classes]))

#定義rnn網路

defrnn

(x,weights,biases)

:#inputs=[batch_size,max_time,n_inputs]

inputs=tf.reshape(x,[-

1,max_time,n_inputs]

)#改變x的shape,使它可以參與運算

#定義lstm基本cell

lstm_cell=tf.contrib.run.core_run_celll.basiclstmcell(lstm_size)

#final_state[0]是cell_state

#final_state[1]是hidden_state

outputs,final_state=tf.nn.dynamic_run(lstm_cell,inputs,dtype=tf.float32)

results=tf.nn.softmax(tf.matmul(final_state[1]

,weights)

+biases)

return results

#計算rnn的返回結果

prediction=rnn(x,weights,biases)

#二次代價函式(反向傳播)

cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y)

)train_step=tf.train.adamoptimizer(

0.0001

).minimize(cross_entropy)

#重點理解這兩句,有新東西。

correct_prediction=tf.equal(tf.argmax(y,1)

,tf.argmax(prediction,1)

)#tf.argmax(input,axis)根據axis取值的不同返回每行或者每列最大值的索引。axis為1表示取行最大值得索引。

#如果兩個值相等,返回true,結果儲存的是布林型的列表

accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)

)#tf.cast()類似強制型別轉換,把布林型變為32位float型,然後求平均。[1,1,1,0,0,0,1,1,1,1],準確率為0.7

init=tf.global_variables_initializer(

)with tf.session(

)as sess:

sess.run(init)

for epoch in

range(6

):for batch in

range

(n_batch)

: batch_xs,batch_ys=mnist.train.next_batch(batch_size)

sess.run(train_step,feed_dict=

) acc=sess.run(accuracy,feed_dict=

)print

("iter"

+str

(epoch)

+",testing accuracy"

+str

(acc)

)

tensorflow實踐 手寫MNIST數字識別

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集,讀取的是壓縮包 mnist input data.read data sets mnist one hot true 每個...

tensorflow實現MNIST手寫數字識別

mnist資料集是由0 9,10個手寫數字組成。訓練影象有60000張,測試影象有10000張。from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets mnist data ...

MNIST資料集手寫體識別 RNN實現

github部落格傳送門 csdn部落格傳送門 tensorflow python基礎 深度學習基礎網路模型 mnist手寫體識別資料集 import tensorflow as tf mnist input data.read data sets mnist data one hot true c...