keras構建用於MNIST分類的多層感知機

2021-10-04 03:47:15 字數 2486 閱讀 7059

和tensorflow一樣,用keras庫建立mlp神經網路。構建網路載入資料集,進行建模,然後訓練模型,進行模型測試以及對模型的評估。

從keras中匯入模組

import tensorflow as tf

import keras

import os

from keras.models import sequential

from keras.layers import dense,activation

from keras.optimizers import sgd

from tensorflow.examples.tutorials.mnist import input_data

載入資料集,可以用注釋的部分,由於我的還是報錯,所以我換了這一種方法,效果也還是差不多。

# mnist_home=os.path.join(datasetslib.datasets_root,'mnist')

# mnist=input_data.read_data_sets(mnist_home,one_hot=true)

mnist = input_data.read_data_sets(

'mnist_data'

,one_hot=

true

)x_train=mnist.train.images

x_test=mnist.test.images

y_train=mnist.train.labels

y_test=mnist.test.labels

#print(x_train.shape,y_train.shape)

#print(x_test.shape,x_test.shape)

定義超引數

num_inputs=

784num_outputs=

10num_layers=

2num_neurons=

for i in

range

(num_layers)

:256

)learning_rate=

0.01

n_epochs=

50batch_size=

100

建立模型

#建立乙個順序模型

model=sequential(

)#新增第乙個隱藏層,在第乙個隱藏層中必須新增張量的形狀

model.add(dense(units=num_neurons[0]

,activation=

'relu'

, input_shape=

(num_inputs,))

)#新增第二個隱藏層

model.add(dense(units=num_neurons[1]

,activation=

'relu'))

#新增具有啟用函式softmax的輸出層

model.add(dense(units=num_outputs,activation=

'softmax'))

#輸出模型的詳細資訊

model.summary(

)

模型建立好以後得到結果:

使用sgd優化器編譯模型和訓練模型

#使用sgd優化器編譯模型

model.

compile

(loss=

'categorical_crossentropy'

, optimizer=sgd(lr=learning_rate)

, metrics=

['accuracy'])

#訓練模型

model.fit(x_train,y_train,

batch_size=batch_size,

epochs=n_epochs)

在模型訓練的過程中可以看到每次訓練迭代的損失函式值和分類精度:

10-40省略

評估模型並輸出損失函式值和分類精度

score=model.evaluate(x_test,y_test)

print

('\n test loss:'

,score[0]

)print

('test accuracy:'

,score[1]

)

Keras入門實戰(1) MNIST手寫數字分類

目錄 1 首先我們載入keras中的資料集 2 網路架構 3 選擇編譯 compile引數 4 準備影象資料 5 訓練模型 6 測試資料 前面的部落格中已經介紹了如何在ubuntu下安裝keras深度學習框架。現在我們使用 keras 庫來學習手寫數字分類。我們這裡要解決的問題是 將手寫數字的灰度影...

Keras 淺嚐之MNIST手寫數字識別

最近關注了一陣keras,感覺這個東西挺方便的,今天嘗試了一下發現確實還挺方便。不但提供了常用的layers normalization regularation activation等演算法,甚至還包括了幾個常用的資料庫例如cifar 10和mnist等等。下面的 算是keras的hellowor...

keras 實現mnist手寫數字集識別

coding utf 8 classifier mnist import numpy as np np.random.seed 1337 from keras.datasets import mnist from keras.utils import np utils from keras.mode...