pytorch五 用Variable實現線性回歸

2021-09-19 04:27:53 字數 1010 閱讀 3426

#匯入相關包

import torch as t

import matplotlib.pyplot as plt

#構造資料

def get_fake_data(batch_size = 8):

#設定隨機種子數,這樣每次生成的隨機數都是一樣的

t.manual_seed(10)

#產生隨機資料:y = 2*x+3,加上了一些雜訊

x = t.rand(batch_size,1) * 20

#randn生成期望為0方差為1的正態分佈隨機數

y = x * 2 + (1 + t.randn(batch_size,1)) * 3

return x,y

#檢視生成資料的分布

x,y = get_fake_data()

plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())

#線性回歸

#隨機初始化引數

w = t.rand(1,1)

b = t.zeros(1,1)

#學習率

lr = 0.001

for i in range(10000):

x,y = get_fake_data()

#forward:計算loss

y_pred = x.mm(w) + b.expand_as(y)

#均方誤差作為損失函式

loss = 0.5 * (y_pred - y)**2

loss = loss.sum()

#backward:手動計算梯度

dloss = 1

dy_pred = dloss * (y_pred - y)

dw = x.t().mm(dy_pred)

db = dy_pred.sum()

#更新引數

w.sub_(lr * dw)

b.sub_(lr * db)

if

Pytorch學習筆記(五)

9 在pytorch中使用lstm 學習pytorch的rnn使用時最好去官方文件看一下api是如何使用的 乙個需要注意的地方是在pytorch中rnn的輸入input的shape的三維分別是 seq len,batch,input size 隱藏層h 0的shape三維分別是 num layers...

pytorch(五) 批訓練

import torch import torch.utils.data as data 虛構要訓練的資料 x torch.linspace 11,20,10 在 11,20 裡取出10個間隔相等的數 torch tensor y torch.linspace 20,11,10 batch size...

用PyTorch實現多層網路

task4 2天 用pytorch實現多層網路 1.引入模組,讀取資料 2.構建計算圖 構建網路模型 3.損失函式與優化器 4.開始訓練模型 5.對訓練的模型 結果進行評估 參考 import torch import torchvision import torchvision.transform...