Processing the MNIST Dataset with an RNN

This post classifies MNIST digits with an LSTM in TensorFlow 1.x: each 28x28 image is treated as a sequence of 28 rows, fed to the network one row per time step.


lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)

X_split = tf.split(XR, time_step_size, 0)

At each time step t, an LSTM cell produces two internal states, c_t and h_t (for an introduction to RNNs and LSTMs, see: Recurrent Neural Networks and LSTM). When state_is_tuple=True, the states c_t and h_t are tracked separately and returned in a 2-tuple; if the argument is left unset or set to False, the two states are concatenated column-wise into a single returned tensor. The official documentation says the concatenated form is about to be deprecated, so we should pass state_is_tuple=True whenever we use an LSTM.
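To make the difference concrete, here is a minimal sketch (TensorFlow 1.x, with the same contrib rnn module used below; the batch size of 4 is an arbitrary example):

import tensorflow as tf
from tensorflow.contrib import rnn

cell_tuple = rnn.BasicLSTMCell(28, forget_bias=1.0, state_is_tuple=True)
state = cell_tuple.zero_state(batch_size=4, dtype=tf.float32)
print(state.c.shape, state.h.shape)  # LSTMStateTuple keeps c_t and h_t separate: (4, 28) (4, 28)

cell_concat = rnn.BasicLSTMCell(28, forget_bias=1.0, state_is_tuple=False)
state_concat = cell_concat.zero_state(batch_size=4, dtype=tf.float32)
print(state_concat.shape)  # c_t and h_t concatenated column-wise: (4, 56)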

To stack several LSTM layers, wrap the cell in a MultiRNNCell:

lstm = tf.nn.rnn_cell.MultiRNNCell([lstm] * num_layers, state_is_tuple=True)
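One caution that is not in the snippet above: [lstm] * num_layers puts the same cell object in every layer, which in later TensorFlow 1.x releases either shares weights across layers or raises a variable-reuse error. A safer sketch (num_layers = 2 is an assumed example value):

num_layers = 2
cells = [tf.nn.rnn_cell.BasicLSTMCell(lstm_size, state_is_tuple=True) for _ in range(num_layers)]
stacked_lstm = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)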

Below is the complete script.

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np
import input_data

# configuration
#                   O * W + b -> 10 labels for each image, O[? 28], W[28 10], B[10]
#                  ^ (O: output 28 vec from 28 vec input)
#                  |
#      +-+  +-+       +--+
#      |1|->|2|-> ... |28| time_step_size = 28
#      +-+  +-+       +--+
#       ^    ^    ...  ^
#       |    |         |
# img1:[28] [28]  ... [28]
# img2:[28] [28]  ... [28]
# img3:[28] [28]  ... [28]
# ...
# img128 or img256 (batch_size or test_size 256)
# each input size = input_vec_size=lstm_size=28

# configuration variables
input_vec_size = lstm_size = 28  # dimension of each input vector
time_step_size = 28              # number of recurrent time steps

batch_size = 128
test_size = 256

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))

def model(X, W, B, lstm_size):
    # X, input shape: (batch_size, time_step_size, input_vec_size)
    # XT shape: (time_step_size, batch_size, input_vec_size)
    XT = tf.transpose(X, [1, 0, 2])  # permute time_step_size and batch_size: [28, 128, 28]
    # XR shape: (time_step_size * batch_size, input_vec_size)
    XR = tf.reshape(XT, [-1, lstm_size])  # each row has input for each lstm cell (lstm_size=input_vec_size)
    # each array shape: (batch_size, input_vec_size)
    X_split = tf.split(XR, time_step_size, 0)  # split into time_step_size (28) arrays, shape = [(128, 28), (128, 28), ...]
    # make lstm with lstm_size (each input vector size); num_units=lstm_size, forget_bias=1.0
    lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)
    # get lstm cell output: time_step_size (28) arrays, each of shape (batch_size, lstm_size)
    # rnn.static_rnn() returns one output per time step; if only the last step matters, take outputs[-1]
    outputs, _states = rnn.static_rnn(lstm, X_split, dtype=tf.float32)  # per-timestep cell outputs: [... shape=(128, 28) ...]
    # linear activation
    # get the last output
    return tf.matmul(outputs[-1], W) + B, lstm.state_size  # state size to initialize the state
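The transpose -> reshape -> split sequence in model() is pure shape bookkeeping: it turns a batch of images into one input array per time step, which is what rnn.static_rnn expects. A NumPy mirror of the three ops (a sketch for illustration, not part of the original script):

import numpy as np

x = np.zeros((128, 28, 28))            # (batch_size, time_step_size, input_vec_size)
xt = np.transpose(x, (1, 0, 2))        # (28, 128, 28)
xr = xt.reshape(-1, 28)                # (28 * 128, 28)
x_split = np.split(xr, 28, axis=0)     # list of 28 arrays, each (128, 28)
print(len(x_split), x_split[0].shape)  # 28 (128, 28)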

mnist = input_data.read_data_sets("mnist_data/", one_hot=True)  # load the data
# mnist.train.images is a 55000 x 784 matrix; mnist.train.labels is a 55000 x 10 matrix
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
# represent each image as a 28x28 matrix: (55000, 28, 28)
trX = trX.reshape(-1, 28, 28)
teX = teX.reshape(-1, 28, 28)

X = tf.placeholder("float", [None, 28, 28])
Y = tf.placeholder("float", [None, 10])

# get lstm_size and output 10 labels
W = init_weights([lstm_size, 10])  # output-layer weight matrix, 28 x 10
B = init_weights([10])             # output-layer bias

py_x, state_size = model(X, W, B, lstm_size)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
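For reference, this cost is plain softmax cross-entropy averaged over the batch. A hand-rolled equivalent (a sketch; py_x must hold raw logits, and the 1e-10 term is only a hypothetical numerical guard):

probs = tf.nn.softmax(py_x)
manual_cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(probs + 1e-10), axis=1))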

train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)

session_conf = tf.ConfigProto()
session_conf.gpu_options.allow_growth = True

# launch the graph in a session
with tf.Session(config=session_conf) as sess:
    # you need to initialize all variables
    tf.global_variables_initializer().run()

    for i in range(100):
        # train in mini-batches of batch_size consecutive images
        for start, end in zip(range(0, len(trX), batch_size),
                              range(batch_size, len(trX) + 1, batch_size)):
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

        # score a random batch of test_size test images
        test_indices = np.arange(len(teX))
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]

        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={X: teX[test_indices]})))
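The loop above scores only a random slice of test_size images each epoch. A sketch of full test-set accuracy, evaluated in chunks to bound memory (it uses the same names as the script and must run inside the tf.Session block):

    # inside the `with tf.Session(...)` block, after training
    full_preds = []
    for start in range(0, len(teX), test_size):
        full_preds.append(sess.run(predict_op, feed_dict={X: teX[start:start + test_size]}))
    full_preds = np.concatenate(full_preds)
    print("full test accuracy:", np.mean(np.argmax(teY, axis=1) == full_preds))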
