MNIST資料集手寫體識別 RNN實現

2022-08-24 13:03:09 字數 3222 閱讀 1190

github部落格傳送門

csdn部落格傳送門

tensorflow

python基礎

深度學習基礎網路模型(mnist手寫體識別資料集)

import tensorflow as tf

mnist = input_data.read_data_sets('../mnist_data/', one_hot=true)

class rnnnet: # 建立乙個rnnnet類

def __init__(self):

self.x = tf.placeholder(dtype=tf.float32, shape=[none, 28, 28], name='input_x') # 建立資料佔位符

self.y = tf.placeholder(dtype=tf.float32, shape=[none, 10], name='input_y') # 建立標籤佔位符

self.fc_w1 = tf.variable(tf.truncated_normal(shape=[128, 10], dtype=tf.float32, stddev=tf.sqrt(1 / 10), name='fc_w1')) # 定義 輸出層/全鏈結層 w

self.fc_b1 = tf.variable(tf.zeros(shape=[10]), dtype=tf.float32, name='fc_b1') # 定義 輸出層/全鏈結層 偏值b

# 前向計算

def forward(self):

cell = tf.nn.rnn_cell.basiclstmcell(128) # 建立128個lstm的rnn結構(細胞結構)

state1 = cell.zero_state(100, dtype=tf.float32) # 初始化細胞的狀態為 0, 傳入初始化批次 和資料型別

self.rnn_ouput, self.state = tf.nn.dynamic_rnn(cell, self.x, initial_state=state1, time_major=false) # 將細胞cell 和資料 self.x 初始化狀態傳入rnn細胞結構 獲得兩個返回值 output 和 狀態state

self.fc1 = tf.matmul(self.rnn_ouput[:, -1, :], self.fc_w1) + self.fc_b1 # 取rnn_output的輸出狀態的 每個輸出的最後一行 進行全鏈結計算

self.output = tf.nn.softmax(self.fc1) # 將全鏈結計算後的結果進行 softmax分類

# 後向計算

def backward(self):

# 求出網路的 cost值(損失)

self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.fc1, name='cost'))

# 使用adamoptimizer優化器優化 self.cost

self.opt = tf.train.adamoptimizer().minimize(self.cost)

# # 計算測試集識別精度

def acc(self):

# 將**值 output 和 標籤值 self.y 進行比較

self.acc1 = tf.equal(tf.argmax(self.output, 1), tf.argmax(self.y, 1))

# 最後對比較出來的bool值 轉換為float32型別後 求均值就可以看到滿值為 1的精度顯示

self.accaracy = tf.reduce_mean(tf.cast(self.acc1, tf.float32))

if __name__ == '__main__':

net = rnnnet() # 啟動tensorflow繪圖的rnnnet

net.forward() # 啟動前向計算

net.backward() # 啟動後向計算

net.acc() # 啟動精度計算

init = tf.global_variables_initializer() # 定義初始化tensorflow所有變數操作

with tf.session() as sess: # 建立乙個session會話

sess.run(init) # 執行init變數內的初始化所有變數的操作

for i in range(10000): # 訓練10000次

ax, ay = mnist.train.next_batch(100) # 從mnist資料集中取資料出來 ax接收 ay接收標籤

ax_batch = ax.reshape([-1, 28, 28]) # 將取出的 資料 reshape成 nsv 結構

loss, output, accaracy, _ = sess.run(fetches=[net.cost, net.output, net.accaracy, net.opt], feed_dict=) # 將資料喂進rnn網路

# print(loss) # 列印損失

# print(accaracy) # 列印訓練精度

if i % 10 == 0: # 每訓練10次

test_ax, test_ay = mnist.test.next_batch(100) # 則使用測試集對當前網路進行測試

test_ax_batch = sess.run(tf.reshape(test_ax, [-1, 28, 28])) # 將取出的 資料 reshape成 nsv 結構

test_output = sess.run(fetches=net.output, feed_dict=) # 注意fetches=[net.output]加了中括號返回值會變為list # 將測試資料喂進網路 接收乙個output值

test_acc = tf.equal(tf.argmax(test_output, 1), tf.argmax(test_ay, 1)) # 對output值和標籤y值進行求比較運算

test_accaracy = sess.run(tf.reduce_mean(tf.cast(test_acc, tf.float32))) # 求出精度的準確率進行列印

print(test_accaracy) # 列印當前測試集的精度

mnist手寫體識別 卷積神經網路

coding utf 8 通過卷積神經網路進行 author elijah 引入資料集 from tensorflow.examples.tutorials.mnist import input data import tensorflow as tf mnist input data.read d...

Keras入門級實戰 MNIST手寫體識別

手寫體識別 這裡要解決的問題是,將手寫數字的灰度影象 28 畫素 28 畫素 劃分到 10 個類別 中 0 9 這個資料集包含 60 000 張訓練影象和 10 000 張測試圖 像,由美國國家標準與技術研究院 national institute of standards and technolo...

kaggle練習 手寫體識別

coding utf 8 created on sun apr 22 10 25 14 2018 author zhangsh import csv import numpy as np from sklearn.neighbors import kneighborsclassifier list ...