莫煩pytorch批訓練

2021-09-29 04:15:21 字數 1948 閱讀 9943

import torch

import torch.utils.data as data

包裝資料類:tensordataset

【包裝資料和目標張量的資料集,通過沿著第乙個維度索引兩個張量來】

class torch.utils.data.tensordataset(data_tensor, target_tensor)

#data_tensor (tensor) - 包含樣本資料

#target_tensor (tensor) - 包含樣本目標(標籤)

載入資料類:dataloader【資料載入器。組合資料集和取樣器,並在資料集上提供單程序或多程序迭代器。】

class torch.utils.data.dataloader(dataset, batch_size=1, shuffle=false, sampler=none, num_workers=0, collate_fn=, pin_memory=false, drop_last=false)

#num_workers (int, optional) – 用多少個子程序載入資料

#drop_last (bool, optional) – 如果資料集大小不能被batch size整除,則設定為true後可刪除最後乙個不完整的batch。如果設為false並且資料集的大小不能被batch size整除,則最後乙個batch將更小。(預設: false)

首先介紹一下import torch.utils.data as data,這在訓練過程中基本都會用到。該介面大多用來讀取資料和把資料封裝成tensor,之後的dataloader用來做mini—batch訓練。

import torch

import torch.utils.data as data

batch_size=5

x=torch.linspace(1,10,10)

y=torch.linspace(10,1,10)

# 先轉換成 torch 能識別的 dataset

torch_dataset=data.tensordataset(x,y) #變成tensor

loader=data.dataloader(

dataset=torch_dataset,

batch_size=batch_size,

shuffle=true,

num_workers=2 # 多執行緒來讀資料

注意dataloader裡面的dataset引數必須要是tensor型別,batch_size是你需要的mini-batch的大小,shuffle是否打亂,true就是打亂(打亂效果比較好,一般都打亂),num_workers=2是多執行緒讀取資料。

def show_batch():

for epoch in range(3):# 訓練所有!整套!資料 3 次

for step,(batch_x,batch_y) in enumerate(loader):

# 打出來一些資料

print('epoch: ', epoch, '| step: ', step, '| batch x: ',

batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':

show_batch()

enumerate就是可以把乙個list變成索引-元素對,這樣可以在for迴圈中同時迭代索引和元素本身。

參考:

莫煩學習筆記4 批訓練

批訓練就是把你的資料分批訓練,像之前回歸的時候100個點可以分成兩批訓練,可以是50,50,也可以是80,20。分批有啥好處呢?我想到的就是可以用多執行緒平行計算。莫煩這篇部落格就是說在torch.utils.data這個庫裡面 暫且稱這個庫為data 有乙個函式,也就是data.dataloade...

莫煩pytorch學習筆記

此處x,y為資料集的tensor torch dataset data.tensordataset data tensor x,target tensor y loader data.dataloader dataset torch dataset,batch size batch size,shu...

莫煩 pytorch筆記 variable是什麼

variable型別是什麼 variable tensor1 torch.floattensor 1,2 3,4 建立tensor variable variable tensor1,requires grad true 建立variable。其中requires grad是誤差反向傳播 計算梯度的...