卷積神經網路分類 MNIST 手寫體數字

2021-09-05 08:05:27 字數 2584 閱讀 7437

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from (or download it into) ./mnist_data with one-hot labels.
mnist = input_data.read_data_sets("mnist_data", one_hot=True)

import matplotlib.pyplot as plt

class net:
    """Two-conv-layer classifier for 28x28x1 images, built as a TF1 graph.

    Call ``forward()`` and then ``backward()`` after construction to build
    the graph. Feed ``x`` and ``y``, then run ``opt``, ``loss`` and
    ``accuracy`` inside a session.
    """

    def __init__(self):
        """Create the input placeholders and every trainable variable."""
        # Inputs: images (batch, 28, 28, 1) and 10-way label vectors.
        self.x = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.y = tf.placeholder(tf.float32, [None, 10])

        # Conv layer 1: 3x3 kernels, 1 -> 16 channels.
        self.conv1_w = tf.Variable(
            tf.random_normal([3, 3, 1, 16], dtype=tf.float32, stddev=0.1))
        self.conv1_b = tf.Variable(tf.zeros([16]))

        # Conv layer 2: 3x3 kernels, 16 -> 32 channels.
        self.conv2_w = tf.Variable(
            tf.random_normal([3, 3, 16, 32], dtype=tf.float32, stddev=0.1))
        self.conv2_b = tf.Variable(tf.zeros([32]))

        # Fully connected head: 7*7*32 -> 128 -> 10.
        self.w1 = tf.Variable(tf.random_normal([7 * 7 * 32, 128], stddev=0.1))
        self.b1 = tf.Variable(tf.zeros([128]))
        self.w2 = tf.Variable(tf.random_normal([128, 10], stddev=0.1))
        self.b2 = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the inference graph; the softmax output is ``self.y2``."""
        # 'SAME' padding keeps 28x28 through the conv; each 2x2/stride-2 pool
        # halves it (28 -> 14 -> 7), which is what the 7*7*32 flatten relies on.
        self.conv1 = tf.nn.relu(
            tf.nn.conv2d(self.x, self.conv1_w,
                         strides=[1, 1, 1, 1], padding='SAME')
            + self.conv1_b)
        self.pool1 = tf.nn.max_pool(
            self.conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
            padding='SAME')

        self.conv2 = tf.nn.relu(
            tf.nn.conv2d(self.pool1, self.conv2_w,
                         strides=[1, 1, 1, 1], padding='SAME')
            + self.conv2_b)
        self.pool2 = tf.nn.max_pool(
            self.conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
            padding='SAME')

        self.flat = tf.reshape(self.pool2, [-1, 7 * 7 * 32])
        self.y1 = tf.nn.relu(tf.matmul(self.flat, self.w1) + self.b1)
        self.y2 = tf.nn.softmax(tf.matmul(self.y1, self.w2) + self.b2)

    def backward(self):
        """Build the loss, the training op and the accuracy metric.

        Requires ``forward()`` to have been called first, since it reads
        ``self.y2``.
        """
        # Mean squared error between the softmax output and the labels.
        self.loss = tf.reduce_mean((self.y2 - self.y) ** 2)
        self.opt = tf.train.AdamOptimizer().minimize(self.loss)

        # Compare the predicted class with the true class, per sample.
        # (Attribute name kept as-is, misspelling included, for compatibility.)
        self.prediction_corect = tf.equal(
            tf.argmax(self.y2, 1), tf.argmax(self.y, 1))
        # Cast the booleans to floats so they can be averaged.
        self.rst = tf.cast(self.prediction_corect, tf.float32)
        # The mean of the 0/1 hits is the accuracy as a fraction.
        self.accuracy = tf.reduce_mean(self.rst)

if __name__ == '__main__':
    # Train for 1000 mini-batches, live-plotting accuracy and loss every 10.
    # The instance gets its own name so it does not shadow the `net` class.
    model = net()
    model.forward()
    model.backward()

    init = tf.global_variables_initializer()
    batch_size = 100

    with tf.Session() as sess:
        sess.run(init)

        # Plot history: step index, accuracy, and loss at each sampled step.
        a = []
        b = []
        c = []

        for i in range(1000):
            x, y = mnist.train.next_batch(batch_size)
            # Reshape the batch to the layout the `x` placeholder expects.
            x = x.reshape([batch_size, 28, 28, 1])

            loss, acc, _ = sess.run(
                [model.loss, model.accuracy, model.opt],
                feed_dict={model.x: x, model.y: y})

            if i % 10 == 0:
                # Record this step; without this the curves stay empty.
                a.append(i)
                b.append(acc)
                c.append(loss)

                # Clear the figure so each redraw does not stack a new line
                # on top of the previous ones.
                plt.clf()

                # 1 row x 2 columns of subplots; draw in the first one.
                plt.subplot(1, 2, 1)
                plt.plot(a, b)
                plt.title('accuracy rate')

                # 1 row x 2 columns of subplots; draw in the second one.
                plt.subplot(1, 2, 2)
                plt.plot(a, c)
                plt.title('loss')

                # A short pause lets the GUI event loop repaint the figure.
                plt.pause(0.0001)

                print(loss, acc)

MNIST 手寫體識別:卷積神經網路

coding utf 8 通過卷積神經網路進行 author elijah 引入資料集 from tensorflow.examples.tutorials.mnist import input data import tensorflow as tf mnist input data.read d...

卷積神經網路應用於MNIST資料集分類

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets mnist data one hot true 每個批次的大小 bat...

tensorflow實現卷積神經網路手寫數字識別

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets mnist data one hot true batch size ...