MNIST資料集分類簡單版本

2022-06-13 07:33:11 字數 3111 閱讀 3735

from tensorflow.examples.tutorials.mnist import input_data

#載入資料集
mnist = input_data.read_data_sets("/data/stu05/mnist_data",one_hot=true)

extracting /data/stu05/mnist_data/train-images-idx3-ubyte.gz

extracting /data/stu05/mnist_data/train-labels-idx1-ubyte.gz

extracting /data/stu05/mnist_data/t10k-images-idx3-ubyte.gz

extracting /data/stu05/mnist_data/t10k-labels-idx1-ubyte.gz

#每個批次的大小
batch_size = 100

#計算一共有多少個批次
n_batch = mnist.train.num_examples // batch_size

#定義兩個placeholder,none=100,28*28=784,即100行,784列
x = tf.placeholder(tf.float32,[none,784])

#0-9個輸出標籤
y = tf.placeholder(tf.float32,[none,10])

#建立乙個簡單的神經網路,只有輸入層和輸出層
w = tf.variable(tf.zeros([784,10]))

b = tf.variable(tf.zeros([1,10]))

#softmax函式轉化為概率值
prediction = tf.nn.softmax(tf.matmul(x,w)+b)

#二次代價函式
loss = tf.reduce_mean(tf.square(y-prediction))

#使用梯度下降法
train_step = tf.train.gradientdescentoptimizer(0.2).minimize(loss)

#初始化變數
init = tf.global_variables_initializer()

#tf.equal()比較函式大小是否相同,相同為true,不同為false;tf.argmax():求y=1在哪個位置,求概率最大在哪個位置
#argmax返回一維張量中最大的值所在的位置,結果存放在乙個布林型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))

#求準確率
#cast轉化型別,將布林型轉化為32位浮點型,true=1.0,false=0.0;再求平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.session() as sess:

sess.run(init)

#將所有訓練21次
for epoch in range(21):

#訓練一次所有的
for batch in range(n_batch):

batch_xs,batch_ys = mnist.train.next_batch(batch_size)

#feed_dict傳入訓練集的和標籤
sess.run(train_step,feed_dict=)

#傳入測試集的和標籤
acc = sess.run(accuracy,feed_dict=)

print("iter"+str(epoch)+",testing accuracy:"+str(acc))

iter0,testing accuracy:0.8303

iter1,testing accuracy:0.8708

iter2,testing accuracy:0.8821

iter3,testing accuracy:0.8885

iter4,testing accuracy:0.8941

iter5,testing accuracy:0.8973

iter6,testing accuracy:0.9001

iter7,testing accuracy:0.9013

iter8,testing accuracy:0.9038

iter9,testing accuracy:0.9048

iter10,testing accuracy:0.9068

iter11,testing accuracy:0.9068

iter12,testing accuracy:0.9084

iter13,testing accuracy:0.9094

iter14,testing accuracy:0.9097

iter15,testing accuracy:0.9107

iter16,testing accuracy:0.9118

iter17,testing accuracy:0.9116

iter18,testing accuracy:0.9127

iter19,testing accuracy:0.9136

iter20,testing accuracy:0.9146

Mnist資料集分類簡單版本

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集 mnist input data.read data sets mnist data one hot true 每個批次的...

MNIST資料集分類簡單版本(詳細)

import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input data 載入mnist資料集 mnist input data.read data sets mnist...

3 3實現MNIST資料集分類

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集 mnist input data.read data sets mnist data one hot true 定義每個批...