pytorch學習筆記五 批訓練

2021-10-04 16:11:42 字數 1263 閱讀 2399

學自莫凡python

一批5個資料(batch_size=5),15個資料總共被分成3批訓練(step=3)。並將所有資料整體訓練了3遍。

# 1.匯入模組

import torch

import torch.utils.data as data #data是用來批訓練的模組

# 2.一批訓練5個資料

batch_size = 5

# 3.使用torch構建資料集

x = torch.linspace(1, 15, 15)

y = torch.linspace(15, 1, 15)

# x用來資料訓練,y用來誤差計算

torch_dataset = data.tensordataset(x, y)

# 4.使用loader將訓練變成一批一批的

loader = data.dataloader(

dataset = torch_dataset, # 匯入資料集

batch_size= batch_size, # 設定一批的樣本數

shuffle=false, # 打亂資料順序再分批進行下輪訓練(false不打亂)

num_workers=2 # 多執行緒,使用雙程序提取資料

)# 5.分批訓練

# 將所有資料整體訓練3次(epoch=3)

# 一批5個資料(batch_size=5),將15個資料分成3批訓練(step=3)

# 使用print檢視訓練過程

for epoch in range(3):

for step, (batch_x, batch_y) in enumerate(loader):

print('epoch:', epoch, '|step:', step,

'|batch_x', batch_x.numpy(), '|batch_y', batch_y.numpy())

注意:如果所有資料的個數不能被batch_size整除,即step不為整數時,那麼最後一批剩餘幾個就訓練幾個。比如在上面的實驗中一批處理7個資料(batch_size=7),15個資料的分批情況就是:第一批7個,第二批7個,第三批1個。

pytorch(五) 批訓練

import torch import torch.utils.data as data 虛構要訓練的資料 x torch.linspace 11,20,10 在 11,20 裡取出10個間隔相等的數 torch tensor y torch.linspace 20,11,10 batch size...

PyTorch學習筆記5 批訓練

1 torch.utils.data.tensordataset 和torch.utils.data.dataloader pytorch提供了乙個資料讀取的方法,其由兩個類構成 torch.utils.data.dataset和dataloader,我們要自定義自己資料讀取的方法,就需要繼承tor...

Pytorch教程 批訓練

torch and numpy 變數 variable 激勵函式 關係擬合 回歸 區分型別 分類 快速搭建法 批訓練加速神經網路訓練 optimizer優化器 卷積神經網路 cnn 卷積神經網路 rnn lstm rnn 迴圈神經網路 分類 rnn 迴圈神經網路 回歸 自編碼 autoencoder...