mnist手寫數字識別

2021-10-01 08:20:27 字數 2989 閱讀 5148

import tensorflow as tf

import numpy as np

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets(

"f:/python/mnist_data/"

, one_hot=

true

)x=tf.placeholder(tf.float32,

[none

,784])

y_ = tf.placeholder(

"float",[

none,10

])defweight_variable

(shape)

:

initial=tf.truncated_normal(shape=shape,stddev=

0.1)

return tf.variable(initial)

defbias_variable

(shape)

: initial = tf.constant(

0.1, shape=shape)

return tf.variable(initial)

defconv2d

(x, w)

:return tf.nn.conv2d(x, w, strides=[1

,1,1

,1], padding=

'same'

)def

max_pool_2x2

(x):

return tf.nn.max_pool(x, ksize=[1

,2,2

,1],strides=[1

,2,2

,1], padding=

'same'

)#第一層卷積

w_conv1 = weight_variable([5

,5,1

,32])

b_conv1 = bias_variable([32

])x_image = tf.reshape(x,[-

1,28,

28,1]

)h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1)

+ b_conv1)

h_pool1 = max_pool_2x2(h_conv1)

#第二層卷積

w_conv2 = weight_variable([5

,5,32

,64])

b_conv2 = bias_variable([64

])h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2)

+ b_conv2)

h_pool2 = max_pool_2x2(h_conv2)

#全連線層

w_fc1 = weight_variable([7

*7*64

,1024])

b_fc1 = bias_variable(

[1024])

h_pool2_flat = tf.reshape(h_pool2,[-

1,7*

7*64]

)h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1)

+ b_fc1)

#dropout層取消一些神經元,防止過擬合

keep_prob = tf.placeholder(

"float"

)h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

#輸出層

w_fc2 = weight_variable(

[1024,10

])b_fc2 = bias_variable([10

])y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2)

+ b_fc2)

#訓練cross_entropy =

-tf.reduce_sum(y_*tf.log(y_conv)

)#使用adam優化器

train_step = tf.train.adamoptimizer(1e-

4).minimize(cross_entropy)

#計算準確率,equal比較大小,argmax返回最大值下標

correct_prediction = tf.equal(tf.argmax(y_conv,1)

, tf.argmax(y_,1)

)#轉換成浮點數並計算平均值

accuracy = tf.reduce_mean(tf.cast(correct_prediction,

"float"))

sess = tf.session(

)sess.run(tf.global_variables_initializer())

for i in

range

(2000):

#其中的n代表返回多少個訓練資料集和對應的標籤

batch = mnist.train.next_batch(50)

if i%

100==0:

#batch[0]:訓練,batc[1]:訓練標籤,餵給x,y_

train_accuracy = sess.run(accuracy, feed_dict=

)#列印訓練輪數和準確率

print

("step %d, training accuracy %g"

%(i, train_accuracy)

) sess.run(train_step,feed_dict=

)#測試資料

print

("test accuracy %g"

%sess.run(accuracy,feed_dict=

))

MNIST手寫數字識別 tensorflow

神經網路一半包含三層,輸入層 隱含層 輸出層。如下圖所示 現以手寫數字識別為例 輸入為784個變數,輸出為10個節點,10個節點再通過softmax啟用函式轉化為 值。如下,準確率可達0.9226 import tensorflow as tf from tensorflow.examples.tu...

DNN識別mnist手寫數字

提取碼 sg3f 導庫import numpy as np import matplotlib.pyplot as plt import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers...

基於MNIST的手寫數字識別

1 mnist 資料資料集獲取 方式一 使用 tf.contrib,learn 模組載入 mnist 資料集 棄用 如下 使用 tf.contrib.learn 模組載入 mnist 資料集 deprecated 棄用 import tensorflow as tf from tensorflow....