深度學習入門筆記7 手寫數字識別

2021-10-05 11:25:10 字數 4474 閱讀 9484

mnist資料集(修改的國家標準與技術研究所——modified national institute of standards and technology),是乙個大型的包含手寫數字的資料集。該資料集由0-9手寫數字組成,共10個類別。每張的大小為28 * 28。

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

import matplotlib.pyplot as plt

# 通過指定的路徑(第1個引數)獲取(載入)手寫數字資料集。如果指定的路徑中檔案

# 如果檔案已經存在,則直接使用。

mnist = input_data.read_data_sets(

"data/"

, one_hot=

true

)

mnist資料集共有70000張影象,其中訓練集60000張,測試集10000張。訓練集分為55000張訓練影象與5000張驗證影象。

mnist影象為單通道。

display(mnist.train.images.shape)

display(mnist.train.labels.shape)

mnist.train.labels[

0]

可以通過matplotlib庫顯示指定的影象。

plt.imshow(mnist.train.images[1]

.reshape((28

,28))

, cmap=

"gray"

)

我們使用單層神經網路來實現該任務。使用softmax啟用函式。

影象的每個畫素可以看做乙個特徵,而每個畫素點對應著乙個權重,來衡量該畫素點對目標的影響大小。

# 定義輸入。

x = tf.placeholder(dtype=tf.float32, shape=

[none

,784])

y = tf.placeholder(dtype=tf.float32, shape=

[none,10

])# w = tf.variable(tf.random_normal(shape=[784, 10], stddev=0.1))

# 對於單層的神經網路,權重初始化不那麼重要,但是對於多層神經網路,權重的初始化就比較重要了。

w = tf.variable(tf.zeros(shape=

[784,10

]))b = tf.variable(tf.zeros(shape=[1

,10])

)# 計算淨輸入。(logits值)

z = tf.matmul(x, w)

+ b# 多分類,使用softmax。傳遞logits值,返回屬於每個類別的概率。

a = tf.nn.softmax(z)

# 定義交叉熵損失函式。

loss =

-tf.reduce_sum(y * tf.log(a)

)train_step = tf.train.gradientdescentoptimizer(

0.01

).minimize(loss)

# tf.argmax(y, axis=1) 求真實的類別的索引。

# tf.argmax(a, axis=1) 求**的類別的索引。

# correct是乙個布林型別的張量。

correct = tf.equal(tf.argmax(y, axis=1)

, tf.argmax(a, axis=1)

)# 計算準確率。

rate = tf.reduce_mean(tf.cast(correct, tf.float32)

)with tf.session(

)as sess:

sess.run(tf.global_variables_initializer())

for i in

range(1

,3001):

batch_x, batch_y = mnist.train.next_batch(

100)

sess.run(train_step, feed_dict=

)if i %

500==0:

# 傳入測試資料,檢視測試集上的準確率。

print

(sess.run(rate, feed_dict=

))

0.894

0.92

0.9183

0.9031

0.9207

0.9152

採用中間加入一隱藏層(多層神經網路)來實現,檢視準確率是否改善。

x = tf.placeholder(dtype=tf.float32, shape=

[none

,784])

y = tf.placeholder(dtype=tf.float32, shape=

[none,10

])# 如果將權重初始化為0,則準確率非常低。10%左右。

# w = tf.variable(tf.zeros(shape=[784, 256]))

# 如果標準差設定不當,準確率也非常低。

# w = tf.variable(tf.random_normal(shape=[784, 256], stddev=0.05))

# 使用標準正態分佈,標準差0.05,準確率為97%左右。

# w = tf.variable(tf.random_normal(shape=[784, 256], stddev=0.05))

# 也可以使用截斷正態分佈,準確率與標準正態分佈差不多。使用截斷正態分佈時,標準差設定為0.1不會出現問題。

w = tf.variable(tf.truncated_normal(shape=

[784

,256

], stddev=

0.1)

)b = tf.variable(tf.zeros(shape=[1

,256])

)z = tf.matmul(x, w)

+ b# 使用relu啟用函式。a是當前層神經元的輸出值,會作為下一層神經元的輸入值。

a = tf.nn.relu(z)

# w2 = tf.variable(tf.random_normal(shape=[256, 10], stddev=0.05))

w2 = tf.variable(tf.truncated_normal(shape=

[256,10

], stddev=

0.1)

)b2 = tf.variable(tf.zeros(shape=[1

,10])

)z2 = tf.matmul(a, w2)

+ b2

a2 = tf.nn.softmax(z2)

loss =

-tf.reduce_sum(y * tf.log(a2)

)# 這裡不再計算softmax,再計算交叉熵,而是直接用tf.nn.softmax_cross_entropy_with_logits直接計算。

# 但是,使用該方法後,準確率有所下降。

# loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=z2, labels=y))

train_step = tf.train.gradientdescentoptimizer(

0.01

).minimize(loss)

correct = tf.equal(tf.argmax(y, axis=1)

, tf.argmax(a2, axis=1)

)# correct = tf.equal(tf.argmax(y, axis=1), tf.argmax(z2, axis=1))

rate = tf.reduce_mean(tf.cast(correct, tf.float32)

)with tf.session(

)as sess:

sess.run(tf.global_variables_initializer())

for i in

range(1

,3001):

batch_x, batch_y = mnist.train.next_batch(

100)

sess.run(train_step, feed_dict=

)if i %

500==0:

print

(sess.run(rate, feed_dict=

))

0.9619

0.9716

0.9749

0.974

0.9773

0.9746

Pytorch學習筆記(三) 手寫數字識別

1.首先匯入所需要的包,其中torchvision包主要實現資料的處理 匯入和預覽 import torch from torchvision import datasets,transforms from torch.autograd import variable download the da...

scikit learn ID3手寫數字識別

判定樹是乙個類似於流程圖的樹結構 其中,每個內部結點表示在一屬性上的測試,每個分支代表乙個屬性輸出,而每個樹葉結點代表類或類分布。樹的頂層是根結點。id3演算法根據的就是資訊獲取量 information gain gain a info d infor a d coding utf 8 pytho...

tensorflow3 手寫數字識別

28 28個輸入單元,200個中間單元,10個輸出單元 coding utf 8 created on fri may 17 19 39 39 2019 author 666 import tensorflow as tf from tensorflow.examples.tutorials.mni...