Matlab深度學習實踐之手寫體識別(含詳細注釋)

2021-10-14 05:36:42 字數 3455 閱讀 1238

matlab這幾年在人工智慧這塊兒也越做越好了,最近為了熟悉matlab如何搭建神經網路,自己做了乙個手寫體識別實驗,記錄一下。

實驗任務非常簡單,網路搭的也非常隨意,不合理的地方也懶得改,旨在走通matlab搭建神經網路的流程。

首先,資料集為mnist資料集

我已經把資料按類別分好,分為train和test,底下又都有十個子資料夾存放手寫體影象。

網路訓練**如下:

clear;close all;clc;

%% 資料讀取、增強

%讀取訓練集

path_train =

'd:\work\過期檔案\手寫體識別\mnist\train'

;%訓練集路徑

folders_train =

fullfile

(path_train,);

%讀取子目錄

%讀取所有影象路徑

[imdstrain,imdsvalidation]

=spliteachlabel

(imds_train,

0.9,

0.1)

;%拆分出驗證集

%讀取測試集

path_test =

'd:\work\過期檔案\手寫體識別\mnist\test'

%影象增強

pixelrange =[-

22];

%平移範圍

scalerange =

[0.9

1.1]

;%縮放範圍

imageaugmenter =

imagedataaugmenter(.

..'randxtranslation'

,pixelrange,..

.'randytranslation'

,pixelrange,..

.'randxscale'

,scalerange,..

.'randyscale'

,scalerange)

;%定義影象增強器

augimdstrain =

augmentedimagedatastore([

28,28]

,imds_train,..

.'dataaugmentation'

,imageaugmenter)

;%影象增強

%% 設計(或者讀取)網路

layers =

[imageinputlayer([

28281]

,"name"

,"imageinput"

)convolution2dlayer([

55],

32,"name"

,"conv_1"

,"padding"

,"same"

,"stride",[

22])

relulayer

("name"

,"relu_1"

)batchnormalizationlayer

("name"

,"batchnorm_1"

)convolution2dlayer([

33],

32,"name"

,"conv_2"

,"padding"

,"same"

)relulayer

("name"

,"relu_2"

)fullyconnectedlayer

(512

,"name"

,"fc_1"

)batchnormalizationlayer

("name"

,"batchnorm_2"

)relulayer

("name"

,"relu_3"

)fullyconnectedlayer(10

,"name"

,"fc_2"

)softmaxlayer

("name"

,"softmax"

)classificationlayer

("name"

,"classoutput")]

;%analyzenetwork

(layers)

%分析網路

%% 訓練網路

options =

trainingoptions

('sgdm',.

..'minibatchsize'

,512,.

..'maxepochs',1

,...

'initiallearnrate'

,1e-2,.

..'shuffle'

,'every-epoch',.

..'validationdata'

,imdsvalidation,..

.'validationfrequency',3

,...

'verbose',1

,...

'plots'

,'training-progress');

%設定訓練策略

trainednet =

trainnetwork

(augimdstrain,layers,options)

;%訓練

%% 測試模型

[ypred,probs]

=classify

(trainednet,imds_test)

; accuracy =

mean

(ypred == imds_test.labels)

這裡面,用到了一些函式,一些重要的用法我都寫在其他部落格裡了,這兒只大致說一下有什麼用

訓練結果:

如果需要處理好的資料集,可以留下郵箱~

最後在說明一下,網路是隨便搭的,不要用!!只是學習matlab用的

以上這些希望會對你有所幫助

深度學習之手寫數字識別

mnist是乙個入門級的計算機視覺資料集,它包含各種手寫數字 它也包含每一張對應的標籤,告訴我們這個是數字幾。比如,上面這四張的標籤分別是5,0,4,1。mnist資料集的官網是 yann lecun s website 這份 然後用下面的 匯入到你的專案裡面,也可以直接複製貼上到你的 檔案裡面。i...

Python基礎學習之手寫識別演算法

k 近鄰演算法 from numpy import python裡的計算包numpy import operator 運算子模組 import os 資料準備所需的函式 def createdataset group array 1.0,1.1 1.0,1.0 0,0 0,0.1 labels a ...

深度學習實踐

選擇合適的損失函式 mini batch 選擇不同的啟用函式 改變學習速度 momentum early stopping 正則化 dropout 改變網路架構 選擇合適的損失函式 mini batch當資料集很大時,訓練演算法是非常慢的,和 batch 梯度下降相比,使用 mini batch 梯...