基於TensorFlow的MNIST資料集的實驗

2022-03-03 05:34:25 字數 2519 閱讀 3142

一、mnist實驗內容

**如下所示:

import

tensorflow as tf

from tensorflow.examples.tutorials.mnist import

input_data

import

matplotlib.pyplot as plt

import

numpy as np

%matplotlib inline

mnist = input_data.read_data_sets('

/home/ubuntu-mm/tensorflow/learning/mnist_data

', one_hot=true) #

#表示輸入任意數量的mnist影象,每一張圖展平成784維的向量

#placeholder是佔位符,在訓練時指定

x = tf.placeholder(tf.float32, [none, 784])

#初始化w,b矩陣

w = tf.variable(tf.zeros([784,10]))

b = tf.variable(tf.zeros([10]))

#tf.matmul(x,w)表示x乘以w

y = tf.nn.softmax(tf.matmul(x, w) +b)

#為了計算交叉熵,我們首先需要新增乙個新的佔位符用於輸入正確值

y_ = tf.placeholder("

float

", [none,10])

#交叉熵損失函式

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

#模型的訓練,不斷的降低成本函式

#要求tensorflow用梯度下降演算法(gradient descent algorithm)以0.01的學習速率最小化交叉熵

train_step = tf.train.gradientdescentoptimizer(0.01).minimize(cross_entropy)

#在執行計算之前,需要新增乙個操作來初始化我們建立的變數

init =tf.global_variables_initializer()

#在session裡面啟動我模型,並且初始化變數

with tf.session() as sess:

#sess = tf.session()

#sess.run(init)

sess.run(init)

#開始訓練模型,迴圈訓練1000次

for i in range(50):

#隨機抓取訓練資料中的100個批處理資料點

batch_xs, batch_ys = mnist.train.next_batch(100)

#然後我們用這些資料點作為引數替換之前的佔位符來執行train_step

sess.run(train_step, feed_dict=)

#檢驗真實標籤與**標籤是否一致

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

#計算精確度,將true和false轉化成相應的浮點數,求和取平均

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "

float"))

#計算所學習到的模型在測試資料集上面的正確率

print(sess.run(accuracy, feed_dict=))

print

'w is:

',w.eval()[10]

print

'b is:

',b.eval()

batch_xs, batch_ys = mnist.train.next_batch(100)

#print 'batch_xs[1]=',batch_xs[1]

print

'batch_ys[1]=

',batch_ys[1]

x_in = tf.reshape(batch_xs[1],[1,784])

y_predict = tf.nn.softmax(tf.matmul(x_in, w) +b)

print

'y_predict is :

',y_predict.eval()

ori_pic = np.zeros([28,28])

for m in range(784):

i = m%28j = (m-i)/28ori_pic[j][i] = batch_xs[1][m]

plt.figure(1)

plt.imshow(ori_pic)

實驗執行的結果如下所示:

由結果顯示的可知:對應為6的概率是99.56%

二、交叉熵損失函式的基本原理:

深度入門學習 tensorflow讀取mnist

import tensorflow as tf import matplotlib.pyplot as plt 讀取mnist資料方法一 from tensorflow.examples.tutorials.mnist import input data mnist input data.read ...

基於Tensorflow的Keras安裝

平台 ubuntu14.04 版本 python3.5,anaconda3 4.1.1,tensorflow1.4.0 bash anaconda sh 檔名 步驟2 回到剛開啟命令框時的目錄,依此輸入 mkdir pip cd pip vim pipconfig 然後會彈出一些配置 拉到最下面,點...

tensorflow實現基於LSTM的文字分類方法

學習一段時間的tensor flow之後,想找個專案試試手,然後想起了之前在看theano教程中的乙個文字分類的例項,這個星期就用tensorflow實現了一下,感覺和之前使用的theano還是有很大的區別,有必要總結mark一下 這個分類的模型其實也是很簡單,主要就是乙個單層的lstm模型,當然也...