TensorFlow入門MNIST手寫識別

2021-09-23 08:04:42 字數 3069 閱讀 3830

#匯入mnist資料集:

#訓練集有55000個樣本;測試集有10000個樣本;同時驗證集有5000個樣本(每個樣本都有它應標註資訊,即lable)

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnist1_data/", one_hot=true)

#檢視mnist資料集的情況

print(mnist.train.images.shape, mnist.train.labels.shape)

print(mnist.test.images.shape, mnist.test.labels.shape)

print(mnist.validation.images.shape, mnist.validation.labels.shape)

#載入tensorflow庫,建立interactivesession,之後的運算預設跑在這個session裡

#接下來建立placeholder,即輸入資料的地方,第乙個引數是資料型別,第二個引數代表tensor的shape,就是資料的尺寸

#none代表不限條數的輸入,784=28×28輸入的784維向量

import tensorflow as tf

sess = tf.interactivesession()

x = tf.placeholder(tf.float32, [none, 784])

#variable用來儲存模型引數的,它可以長期存在且在每輪迭代中更新

w = tf.variable(tf.zeros([784, 10])) #把weights和biases全部初始化為0

b = tf.variable(tf.zeros([10]))

# 實現softmax regression演算法 y=softmax(wx+b)

# softmax是tf.nn下面的乙個函式,tf.nn包含了大量神經網路元件,tf.matmul是tensorflow中的矩陣乘法

y = tf.nn.softmax(tf.matmul(x, w) + b)

#為了訓練模型定義乙個loss function描述模型對問題的分類精度:通常使用cross-entropy作為loss function

# hy'(y)=-∑y'ilog(yi) :用來判斷模型對真實概率分布的估計準確度,y是**的概率分布,y'是真實的概率分布

#先定義placeholder,輸入真實的label,用來計算cross-entropy,這裡y_* tf.log(y)公式中的y'ilog(yi),reduce_sum求和

#reduce_mean用來對每個batch資料結果求均值

y_ = tf.placeholder(tf.float32, [none, 10])

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) #損失函式

#下面定義乙個優化演算法即可開始訓練,採用常見的隨機梯度下降sgd,tensorflow自動構成計算圖,並根據反向傳播演算法訓練

#在每一輪迭代時更新引數來減小loss,tensorflow提供我們封裝好的優化器,只需每輪迭代時feed資料給它就可以

#直接呼叫tf.train.gradientdescentoptimizer,設定學習速率為0.5,優化目標設定為cross-entropy

train_step = tf.train.gradientdescentoptimizer(0.5).minimize(cross_entropy)

#下一步使用tensorflow的全域性引數初始化器tf.global_variables_initializer,並直接使用它的run方法

tf.global_variables_initializer().run()

#最後一步開始迭代執行訓練操作 train_step,這裡每次都隨機從訓練集中抽取100條樣本構成乙個mini-batch,並feed給

#placeholder,然後呼叫 train_step對這些樣本進行訓練。使用一小部分樣本進行訓練稱為隨機梯度下降

for i in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)

train_step.run()

#現在已完成了訓練,接下來對模型的準確率進行驗證,tf.argmax是從乙個tensor中尋找最大值的序號

#tf.argmax(y, 1)就是求各個**的數字中概率最大的乙個

#tf.equal方法用來判斷**的數字類別是否是正確的類別

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

#統計全部樣本**的accuracy,這裡需要選用tf.cast將之前correct_prediction輸出的bool值轉換為float32,再求平均

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

#計算模型在測試集上的準確率,再將結果列印出來

print(accuracy.eval())

#本演算法平均準確率可達92%左右

使用tensorflow實現了乙個簡單的機器學習演算法 softmax regression,這算作乙個沒有隱含層的最淺的神經網路。

四個流程:

1.定義演算法公式,也就是神經網路forward時的計算

2.定義loss,選定優化器,並指定優化器優化loss

3.迭代地對資料進行訓練

4.在測試集火驗證集上對準確率進行評測

在定義的各個公式其實只是computation graph,在執行這行**時,計算還沒有實際發生,只有等呼叫run方法,並feed資料時計算才真正執行。比如cross_entropy,train_step,accuracy等都是計算圖中的節點,並不是資料結果,可以通過呼叫run 方法執行這些節點或者運算操作來獲取結果。

參考:《tensorflow實戰》黃文堅  唐源

tensorflow教程學習三深入MNIST

載入資料 from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets mnist data one hot true 我們使用interactivesession類可以更加靈活地...

Tensorflow 入門記錄

tensor為張量,flow為流圖。tensorflow內含有很多寫好的工具,如梯度下降演算法,卷積操作等。在使用tensorflow時,先導入包import tensorflow as tf,在進行定義tensorflow變數時,使用tf.variable 引數 有趣的是乙個叫做佔位符的工具,tf...

tensorflow入門例子

import tensorflow as tf import numpy as np 使用 numpy 生成假資料 phony data 總共 100 個點.100,2 x data np.float32 np.random.rand 100,2 隨機輸入 y data np.dot x data,...