pytorch訓練MNIST資料集1

2021-10-24 09:49:40 字數 2735 閱讀 3069

本文採用全連線網路對mnist資料集進行訓練,訓練模型主要由五個線性單元和relu啟用函式組成

import torch

from torchvision import transforms

from torchvision import datasets

from torch.utils.data import dataloader

import torch.nn.functional as f

import torch.optim as optim

import os

import sys

batch_size = 64

transform = transforms.compose(

[transforms.totensor(), #將0-255變成0-1

transforms.normalize((0.1307,),(0.3081,)) #正則化

])train_dataset =datasets.mnist(root='../dataset/mnist',

train = true,

download = true,

transform = transform)

train_loader = dataloader(train_dataset,

shuffle=true,

batch_size=batch_size)

test_dataset =datasets.mnist(root='../dataset/mnist',

train = false,

download = true,

transform = transform)

test_loader = dataloader(test_dataset,

shuffle = false,

batch_size=batch_size)

class net(torch.nn.module):

def __init__(self):

super(net,self).__init__()

self.f1 = torch.nn.linear(784,512)

self.f2 = torch.nn.linear(512,256)

self.f3 = torch.nn.linear(256,128)

self.f4 = torch.nn.linear(128,64)

self.f5 = torch.nn.linear(64,10)

def forward(self,x):

#這裡將

x = x.view(-1,784) #展成1*784

x = f.relu(self.f1(x))

x = f.relu(self.f2(x))

x = f.relu(self.f3(x))

x = f.relu(self.f4(x))

return self.f5(x)

model = net()

#loss--交叉熵

criterion = torch.nn.crossentropyloss()

#帶衝量

optimzer = optim.sgd(model.parameters(),lr=0.01,momentum = 0.5)

#訓練def train(epoch):

running_loss =0.0

for batch_idx,data in enumerate(train_loader):

inputs,target = data

optimzer.zero_grad()

outputs = model(inputs)

loss = criterion(outputs,target)

loss.backward()

optimzer.step()

running_loss += loss.item()

if batch_idx%300==299:

print('[%d,%5d] loss:%.3f' % (epoch+1,batch_idx+1,running_loss/300))

running_loss = 0.0

def test():

correct = 0

total = 0

with torch.no_grad():

for batch_idx,data in enumerate(test_loader):

images,labels = data

outputs = model(images)

_,predicted = torch.max(outputs.data,dim=1)

total += labels.size(0)

correct += (predicted==labels).sum().item()

print('accuracy on test set:%d %%' % (100*correct/total))

if __name__== '__main__':

for epoch in range(7):

train(epoch)

test()

#儲存網路引數

結果: 經過7論訓練測試集可以達到97%

pytorch使用GPU訓練MNIST資料集

參考莫凡部落格進行mnist資料集的訓練,臨時記錄所使用的 import torch import torch.nn as nn import torch.utils.data as data import torchvision import matplotlib.pyplot as plt to...

使用matlab訓練mnist模型

前面的博文是通過命令進行mnist模型訓練與測試的,由於實驗需要,想要通過matlab語句來實現mnist模型的訓練,從而把這種方式用於其他問題模型的訓練與測試。1 準備資料與引數 因為matlab程式檔案是在matlab demo下,為了方便,直接把需要的檔案拷貝到demo下 mnist data...

TensorFlow 訓練 MNIST 資料(二)

輸入層 卷積層 卷積層 密集連線層 輸出層。其中每乙個卷積層中還有max pooling,用來進行降維,輸出層中是乙個softmax層。首先這次構建的神經網路相較上篇的神經網路來說,上次的權重矩陣和偏置矩陣直接設定為0,但是存在乙個問題就是容易導致神經元輸出恒為零的情況出現,由於是對稱的容易導致0梯...