使用torch進行簡單的手寫數字識別

2021-10-25 06:53:38 字數 4134 閱讀 5715

只有線性部分的計算是很難完成識別任務的,因此我們需要新增非線性部分。

sigmoid函式(啟用函式)是最常用的,我們這裡使用relu函式,當輸入的和大於0就是原來的和,當小於零就忽略掉。

我們只需要乘到前面就好了。

由於relu函式是線性的,因此模型的線性能力增強了,還具有很強的非線性能力。

再使用argmax函式得到索引值。

from torch import nn # 常用網路

from torch import optim # 優化工具包

import torchvision # 視覺資料集

from matplotlib import pyplot as plt

## 載入資料

batch_size=

512train_loader = torch.utils.data.dataloader(

torchvision.datasets.mnist(

'mnist_data'

,train=

true

,download=

true

, transform=torchvision.transforms.compose(

[ torchvision.transforms.totensor(),

torchvision.transforms.normalize(

(0.1307,)

,(0.3081,)

)# 做乙個標準化])

),batch_size=batch_size,shuffle=

true

)test_loader = torch.utils.data.dataloader(

torchvision.datasets.mnist(

'mnist_data/'

,train=

false

,download=

true

, transform=torchvision.transforms.compose(

[ torchvision.transforms.totensor(),

torchvision.transforms.normalize(

(0.1307,)

,(0.3081,)

)]))

, batch_size=batch_size,shuffle=

true

)x,y=

next

(iter

(train_loader)

)print

(x.shape,y.shape,x.

min(

),x.

max())

relu = nn.relu(

)# 如果使用torch.sigmoid作為啟用函式的話正確率只有60%

# 建立網路

class

net(nn.module)

:def

__init__

(self)

:super

(net,self)

.__init__(

)# xw+b 這裡的256,64使我們人根據自己的感覺指定的

self.fc1 = nn.linear(28*

28,256)

self.fc2 = nn.linear(

256,64)

self.fc3 = nn.linear(64,

10)defforward

(self,x)

:# 因為找不到relu函式,就換成了啟用函式

# x:[b,1,28,28]

# h1 = relu(xw1+b1)

x = relu(self.fc1(x)

)# h2 = relu(h1w2+b2)

x = relu(self.fc2(x)

)# h3 = h2*w3+b3

x = self.fc3(x)

return x

# 因為找不到自帶的one_hot函式,就手寫了乙個

defone_hot

(label, depth=10)

: out = torch.zeros(label.size(0)

, depth)

idx = torch.longtensor(label)

.view(-1

,1) out.scatter_(dim=

1, index=idx, value=1)

return out

## 訓練模型

net = net(

)# 返回[w1,b1,w2,b2,w3,b3] 物件,lr是學習過程

optimizer = optim.sgd(net.parameters(

), lr=

0.01

, momentum=

0.9)

train_loss =

mes_loss = nn.mseloss(

)for epoch in

range(3

):for batch_idx,

(x, y)

inenumerate

(train_loader)

:# x:[b,1,28,28],y:[512]

# [b,1,28,28] => [b,784]

x = x.view(x.size(0)

,28*28

)# =>[b,10]

out = net(x)

# [b,10]

y_onehot = one_hot(y)

# loss = mse(out,y_onehot)

loss = mes_loss(out, y_onehot)

# 清零梯度

optimizer.zero_grad(

)# 計算梯度

loss.backward(

)# w' = w -lr*grad

# 更新梯度,得到新的[w1,b1,w2,b2,w3,b3]

optimizer.step())

)if batch_idx %

10==0:

print

(epoch, batch_idx, loss.item())

# plot_curve(train_loss)

# 到現在得到了[w1,b1,w2,b2,w3,b3]

## 準確度測試

total_correct =

0for x,y in test_loader:

x = x.view(x.size(0)

,28*28

) out = net(x)

# out : [b,10] => pred: [b]

pred = out.argmax(dim =1)

correct = pred.eq(y)

.sum()

.float()

.item(

)# .float之後還是tensor型別,要拿到資料需要使用item()

total_correct += correct

total_num =

len(test_loader.dataset)

acc = total_correct/total_num

print

('準確率acc:'

,acc)

用Keras進行手寫字型識別(MNIST資料集)

首先載入資料 from keras.datasets import mnist train images,train labels test images,test labels mnist.load data 接下來,看看這個資料集的基本情況 train images.shape 60000,28...

手寫簡單的陣列

實現查詢,新增,刪除操作 array類 package com.company public class array 無引數的建構函式,預設陣列的容量capacity 10 public array 獲取陣列中元素的個數 public int getsize 獲取陣列的容量 public int g...

Python使用logging進行簡單的日誌處理

將日誌內容輸出到日誌檔案和控制台,先導入相關模組。import os import logging import time import sys設定log的資料夾路徑,並判斷log資料夾是否存在,若不存在則建立。project dir os.path.abspath os.path.join os....