深度學習(一) 簡單神經網路識別手寫數字

2021-08-20 22:28:43 字數 1531 閱讀 6743

mnist資料集相當於深度學習中的「hello world」,用於開始做測試用的簡單的視覺資料集,由幾萬張28*28的手寫數字組成,只包含灰度資訊,分為十類0~9。

1)選擇softmax regression模型

2)定義算是函式,這裡選擇交叉熵

3)選擇優化演算法梯度下降法

4)迭代進行資料訓練

5)進行驗證和準確率評測

#載入資料

from tensorflow.examples

.tutorials

.mnist import input_data

import tensorflow as tf

#訓練集5.5w,測試機1w,驗證集0.5w,每個樣本有乙個label

mnist=input_data.read_data_sets("mnist_data/",one_hot=true)

#label是乙個10維的向量,[1,0,0,0,0,0,0,0,0,0],其代表數字為0

#placeholder資料輸入的地方,第乙個引數資料型別,第二個引數是資料尺寸大小

sess=tf.interactivesession()

x=tf.placeholder(tf.float32,[none,784])

#定義權重和截距b

w=tf.variable(tf.zeros([784,10]))

b=tf.variable(tf.zeros([10]))

#softmax regression演算法公式

y=tf.nn

.softmax(tf.matmul(x,w)+b)

#定義乙個損失函式,多分類問題多用cross_entropy

y_=tf.placeholder(tf.float32,[none,10])

cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))

#定義優化演算法和訓練速率,用梯度下降法和0.5的速率

train_step=tf.train

.gradientdescentoptimizer(0.5).minimize(cross_entropy)

#定義全域性引數初始化器

tf.global_variables_initializer().run()

#進行訓練

for i in range(1000):

batch_xs,batch_ys=mnist.train

.next_batch(100)

train_step.run()

correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))

accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

print(accuracy.eval())

準確率為:0.9156

參考:

《tensorflow實戰》

簡單的神經網路識別手寫數字

為了測試,我就先在紙上手寫了一些 工整 的數字,並用 特殊手段 將進行了處理,得到了幾個用於測試手寫數字識別率的 標準的大小為28 28,確保輸入層的節點數為784,好吧,我承認是懶所以只寫了五個 開始測試資料 負責一維陣列的引數 k 0 製作乙個一維陣列負責儲存二維陣列扁平化以後的資料 input...

使用神經網路識別手寫數字

神經網路和深度學習為影象識別 語音識別 自然語言處理等問題提供了目前最好的解決方案。本書主要會介紹神經網路和深度學習背後關鍵的概念。更多關於本書的細節,請參考這裡。或者您可以直接從第一章開始學習。本專案是neural networks and deep learning的中文翻譯,原文作者 mich...

四 神經網路(識別手寫字)

結合 python神經網路程式設計 這本書實現 個人認為最近幾年出的實戰系列書,給出的 和思路更加貼切現在的技術,吳恩達課程講解很棒但是很多資料和 或者思想比較老化,不便於吸收理解。乙個三層的簡單神經網路實現 import numpy as np import scipy.special impor...