pytorch學習筆記 回歸問題1

2021-10-24 06:24:47 字數 3643 閱讀 1087

搭建模型

定義計算步驟

輸出運算結果

本節主要針對mnist資料集的數字識別問題,寫出乙個解決回歸問題的方法。初步體會機器學習的工作流程

import  torch

from torch import nn

from torch.nn import functional as f

from torch import optim

import torchvision

from matplotlib import pyplot as plt

#畫圖專用的檔案

from utils import plot_image, plot_curve, one_hot

batch_size =

512# step1. load dataset載入資料集

train_loader = torch.utils.data.dataloader(

torchvision.datasets.mnist(

'mnist_data'

, train=

true

, download=

true

, transform=torchvision.transforms.compose(

[ torchvision.transforms.totensor(),

torchvision.transforms.normalize(

(0.1307,)

,(0.3081,)

)]))

, batch_size=batch_size, shuffle=

true

)test_loader = torch.utils.data.dataloader(

torchvision.datasets.mnist(

'mnist_data/'

, train=

false

, download=

true

, transform=torchvision.transforms.compose(

[ torchvision.transforms.totensor(),

torchvision.transforms.normalize(

(0.1307,)

,(0.3081,)

)]))

, batch_size=batch_size, shuffle=

false

)

x, y =

next

(iter

(train_loader)

)print

(x.shape, y.shape, x.

min(

), x.

max())

plot_image(x, y,

'image sample'

)

class

net(nn.module)

:def

__init__

(self)

:super

(net, self)

.__init__(

)# xw+b

self.fc1 = nn.linear(28*

28,256)

self.fc2 = nn.linear(

256,64)

self.fc3 = nn.linear(64,

10)defforward

(self, x)

:# x: [b, 1, 28, 28]

# h1 = relu(xw1+b1) 公式

x = f.relu(self.fc1(x)

)# h2 = relu(h1w2+b2) 公式

x = f.relu(self.fc2(x)

)# h3 = h2w3+b3 公式

x = self.fc3(x)

return x

net = net(

)# [w1, b1, w2, b2, w3, b3]

#優化器

optimizer = optim.sgd(net.parameters(

), lr=

0.01

, momentum=

0.9)

#記錄loss

train_loss =

for epoch in

range(3

):for batch_idx,

(x, y)

inenumerate

(train_loader)

:# x: [b, 1, 28, 28], y: [512]

# [b, 1, 28, 28] => [b, 784] 從四維變換成二維

x = x.view(x.size(0)

,28*28

)# => [b, 10]

out = net(x)

# [b, 10]

y_onehot = one_hot(y)

# loss = mse(out, y_onehot)

loss = f.mse_loss(out, y_onehot)

# 清零梯度

optimizer.zero_grad(

) loss.backward(

)# w' = w - lr*grad 梯度更新

optimizer.step())

)# 輸出

if batch_idx %

10==0:

print

(epoch+

1, batch_idx, loss.item())

plot_curve(train_loss)

# we get optimal [w1, b1, w2, b2, w3, b3]

plot_curve(train_loss)

# we get optimal [w1, b1, w2, b2, w3, b3]

total_correct =

0for x,y in test_loader:

x = x.view(x.size(0)

,28*28

) out = net(x)

# out: [b, 10] => pred: [b]

pred = out.argmax(dim=1)

correct = pred.eq(y)

.sum()

.float()

.item(

) total_correct += correct

total_num =

len(test_loader.dataset)

acc = total_correct / total_num

print

('test acc:'

, acc)

pytorch學習筆記3 線性回歸

線性回歸線性回歸 是分析乙個變數與另外乙個 多 個變數之間關係的方法 因變數 y 自變數 x 關係 線性 y wx b 求解w,b 求解步驟 1.確定模型 module y wx b 2.選擇損失函式 mse 均方差等 3.求解梯度並更新w,b w w lr w.grad b b lr w.grad...

Pytorch 線性回歸問題

y 4 3x 高斯雜訊 我們利用線性回歸原理,假設y wx b,利用梯度下降法,去求解w,b。驗證w,b是否比較接近w 3,b 4 計算loss function函式 loss sum y w x b 2 定義loss functiondef computre error loss function...

pytorch碎碎念 回歸問題

根據b站莫煩python邊學邊打,只有自己打一遍才能發現容易發生好多錯誤啊 昨晚配置pytorch很順利!一遍就好了 環境 cuda10.0 python3.7 pytorch1.2.0 gpu 1660ti 現在已經有了pytorch1.4.0 似乎tensor和variable的用法有了改變 但...