DNN識別mnist手寫數字

2022-06-27 04:39:10 字數 2619 閱讀 4983

提取碼:sg3f

導庫

import numpy as np

import matplotlib.pyplot as plt

import tensorflow as tf

from tensorflow import keras

from tensorflow.keras import layers

讀取mnist資料

import numpy as np

path='./mnist.npz'

f = np.load(path)

train_x, train_y = f['x_train'], f['y_train'] # 訓練集

test_x, test_y = f['x_test'], f['y_test'] # 測試集

f.close()

檢視資料格式

將資料以形式輸出

將資料格式改為dnn可接收的一維格式

train_x = train_x.reshape((60000,28*28),order='c')    # 將二維的展開為一維的資料(訓練集)  

test_x = test_x.reshape((10000,28*28),order='c') # 將二維的展開為一維的資料(測試集)

搭建dnn並訓練

model = keras.sequential()

model.add(layers.dense(100,activation='relu',input_dim=28*28))

model.add(layers.dense(10,activation='softmax'))

adam = keras.optimizers.adam(lr=0.01)

model.compile(optimizer=adam,loss='sparse_categorical_crossentropy',metrics=['acc'])

model.fit(train_x,train_y,epochs=50,batch_size=512)

經過50輪訓練後,dnn在訓練集上的loss和準確率如下

dnn在測試集上的loss和準確率如下

model.evaluate(test_x,test_y)
完整的**如下

import numpy as np

import matplotlib.pyplot as plt

import tensorflow as tf

from tensorflow import keras

from tensorflow.keras import layers

path='./mnist.npz'

f = np.load(path)

train_x, train_y = f['x_train'], f['y_train'] # 訓練集

test_x, test_y = f['x_test'], f['y_test'] # 測試集

f.close()

print(train_x.shape)

print(train_y.shape)

print(test_x.shape)

print(test_y.shape)

plt.imshow(train_x[10000])

train_x = train_x.reshape((60000,28*28),order='c') # 將二維的展開為一維的資料(訓練集)

test_x = test_x.reshape((10000,28*28),order='c') # 將二維的展開為一維的資料(測試集)

model = keras.sequential()

model.add(layers.dense(100,activation='relu',input_dim=28*28))

model.add(layers.dense(10,activation='softmax'))

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])

model.fit(train_x,train_y,epochs=50,batch_size=512)

model.evaluate(test_x,test_y)

mnist手寫數字識別

import tensorflow as tf import numpy as np from tensorflow.contrib.learn.python.learn.datasets.mnist import read data sets mnist read data sets f pyth...

MNIST手寫數字識別 tensorflow

神經網路一半包含三層,輸入層 隱含層 輸出層。如下圖所示 現以手寫數字識別為例 輸入為784個變數,輸出為10個節點,10個節點再通過softmax啟用函式轉化為 值。如下,準確率可達0.9226 import tensorflow as tf from tensorflow.examples.tu...

基於MNIST的手寫數字識別

1 mnist 資料資料集獲取 方式一 使用 tf.contrib,learn 模組載入 mnist 資料集 棄用 如下 使用 tf.contrib.learn 模組載入 mnist 資料集 deprecated 棄用 import tensorflow as tf from tensorflow....