Classifying the MNIST dataset with an LSTM


The RNN reads each image row by row, from the first row of pixels to the last, and then makes its classification decision.
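To make the row-by-row reading concrete, here is a minimal sketch (NumPy only, with a dummy image standing in for the flattened 784-pixel vectors that input_data.read_data_sets returns) of how one image becomes a sequence of 28 time steps with 28 inputs each:

import numpy as np

# A single flattened MNIST image: 784 pixel values (dummy data for illustration).
flat_image = np.zeros(784, dtype=np.float32)

# Reshape to (n_steps, n_inputs) = (28, 28):
# at time step t the RNN is fed row t of the image.
sequence = flat_image.reshape(28, 28)

print(sequence.shape)     # (28, 28)
print(sequence[0].shape)  # (28,) -- the row fed at time step 0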

# Import data

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('mnist_data', one_hot=True)

# hyperparameters

lr = 0.001 # learning rate

training_iters = 100000 # upper limit on training steps

batch_size = 128

n_inputs = 28 # mnist data input (img shape: 28*28)

n_steps = 28 # time steps

n_hidden_units = 128 # neurons in hidden layer

n_classes = 10 # mnist classes (0-9 digits)

x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])

y = tf.placeholder(tf.float32, [None, n_classes])

# Shapes follow the matrix multiplications in rnn() below;
# random-normal weights and constant-0.1 biases are one common choice.
weights = {
    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),   # (28, 128)
    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))  # (128, 10)
}

biases = {
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units])),       # (128,)
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes]))            # (10,)
}

The RNN is built from three parts: the input layer, the cell, and the output layer.

# (1) input_layer

def rnn(x, weights, biases):

    # The original x is 3-D; reshape it to 2-D so it can be multiplied with the weight matrix
    # x ==> (128 batches * 28 steps, 28 inputs)
    x = tf.reshape(x, [-1, n_inputs])

    # x_in = w*x + b
    x_in = tf.matmul(x, weights['in']) + biases['in']

    # x_in ==> (128 batches, 28 steps, 128 hidden), back to 3-D
    x_in = tf.reshape(x_in, [-1, n_steps, n_hidden_units])

    # (2) cell

    # Use a basic LSTM cell.
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)  # initialize an all-zero state

    # (3) output_layer

    outputs, states = tf.nn.dynamic_rnn(lstm_cell, x_in, initial_state=init_state, time_major=False)

    # states is an LSTMStateTuple (c, h); states[1] is the final hidden state h
    results = tf.matmul(states[1], weights['out']) + biases['out']

    return results
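A side note on the output layer: the same final prediction can be read from the outputs tensor instead of states[1]. A minimal sketch of this alternative, which would replace the results = ... line inside rnn() (same TF1 API as above):

    # outputs has shape (batch, n_steps, n_hidden_units) because time_major=False.
    # Transpose to (n_steps, batch, hidden), unstack into a list of per-step tensors,
    # and take the last step, which for a BasicLSTMCell equals the final hidden state h.
    outputs = tf.unstack(tf.transpose(outputs, [1, 0, 2]))
    results = tf.matmul(outputs[-1], weights['out']) + biases['out']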

pred = rnn(x, weights, biases)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

train_op = tf.train.AdamOptimizer(lr).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.global_variables_initializer()

with tf.Session() as sess:

    sess.run(init)
    step = 0

    while step * batch_size < training_iters:

        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])

        sess.run([train_op], feed_dict={x: batch_xs, y: batch_ys})

        if step % 20 == 0:
            print(sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys}))

        step += 1
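The values printed below are accuracies on the current training batch, reported every 20 steps. To also check held-out data, a minimal sketch that could be appended inside the session block; it evaluates one batch_size-sized slice of the test set, because init_state above is built with a fixed batch_size:

    # Hypothetical test-set check, reusing the graph and session defined above.
    test_xs = mnist.test.images[:batch_size].reshape([batch_size, n_steps, n_inputs])
    test_ys = mnist.test.labels[:batch_size]
    print('test accuracy:', sess.run(accuracy, feed_dict={x: test_xs, y: test_ys}))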

0.265625

0.7265625

0.828125

0.8828125

0.84375

0.859375

0.8984375

0.890625

0.84375

0.90625

0.921875

0.90625

0.9140625

0.9140625

0.9375

0.9609375

0.953125

0.921875

0.9453125

0.96875

0.9375

0.9609375

0.890625

0.984375

0.953125

0.953125

0.9453125

0.9453125

0.96875

0.9375

0.953125

0.96875

0.9375

0.9921875

0.9609375

0.9609375

0.953125

0.9609375

0.96875

0.96875

Process finished with exit code 0
