PyTorch 1 x 資料IO及預處理

2021-10-10 02:02:53 字數 2696 閱讀 9762

三者關係:

偽**示例

class customdataset(dataset):

# 自定義自己的dataset

dataset = customdataset(

)dataloader = dataloader(dataset, ...)

for data in dataloader:

# training...

使用方法示例

class my_dataset(torch.utils.data.dataset):

def __init__(self, trainingimagedir, bndbox, keypointspixel, keypointsworld, center):

self.trainingimagedir = trainingimagedir

self.mean = img_mean

self.std = img_std

self.bndbox = bndbox

self.keypointspixel = keypointspixel

self.keypointsworld = keypointsworld

self.center = center

self.depth_thres = 0.4

def __getitem__(self, index):

# data4dtemp = scio.loadmat(self.trainingimagedir + str(index+1) + '.mat')['depthnormal']

data4dtemp = scio.loadmat(self.trainingimagedir + str(index) + '.mat')[

'depthnormal'

] depthtemp = data4dtemp[:,:,3]

data, label = datapreprocess(index, depthtemp, self.keypointspixel, self.keypointsworld, self.bndbox, self.center, self.depth_thres)

return data, label

def __len__(self):

return len(self.bndbox)

dataloader(dataset, batch_size=1, shuffle=false, sampler=none,

batch_sampler=none, num_workers=0, collate_fn=none,

pin_memory=false, drop_last=false, timeout=0,

worker_init_fn=none, *, prefetch_factor=2,

persistent_workers=false)

引數

說明dataset

從中載入資料的資料集

batch_size

每批次要載入的樣本數

shuffle

設定為true以使資料在每個訓練epoch都重新洗牌

collate_fn

這個函式用來打包batch

sampler

定義從資料集中抽取樣本的策略。 如果指定,則shuffle必須為false

batch_sampler

類似於取樣器(sampler),但一次返回一批索引。 與batch_size,shuffle,sampler和drop_last互斥

num_workers

用於資料載入的子程序數量。 0表示將在主程序中載入資料。 (預設值:0)

collate_fn

合併樣本列表以形成小批量

pin_memory

如果為true,則資料載入器在將張量返回之前將其複製到cuda固定的記憶體中

drop_last

如果資料集大小不能被批量大小整除,則設定為true以刪除最後乙個不完整的批量。 如果為false並且資料集的大小不能被批次大小整除,則最後一批將較小(預設值:false)

timeout

如果為正,則為從工作程序收集批次的超時值。 應始終為非負數 (預設值:0)

worker_init_fn

如果不為none,則在種子建立之後和資料載入之前,將在每個工作子程序上以工作id([0,num_workers-1]中的int)作為輸入來呼叫此方法 (預設值:無)

def __iter__(self):

return dataloaderiter(self)

示例

test_image_datasets = my_dataset(testingimagedir, bndbox_test, keypointspixeltest, keypointsworldtest, center_test)

test_dataloaders = torch.utils.data.dataloader(test_image_datasets, batch_size = batch_size,

shuffle = false, num_workers = 8)

pytorch訓練MNIST資料集1

本文採用全連線網路對mnist資料集進行訓練,訓練模型主要由五個線性單元和relu啟用函式組成 import torch from torchvision import transforms from torchvision import datasets from torch.utils.data...

Pytorch學習筆記 1 基礎資料結構

參考連線 torch.tensor是乙個多維矩陣,其中包含單個資料型別的元素,使用cpu和gpu及其變體定義了10種張量型別,如下所示 data type dtype cpu tensor gpu tensor 32位浮點 torch.float32 or torch.float torch.flo...

vue 1 x 元件資料傳遞

本文章主要講了元件如何進行資料的傳遞,從簡單的元件裡面的資料如何顯示,子元件裡面的資料顯示,子元件獲取父元件的資料,子元件主動傳送資料給父元件。例子 可以在props 宣告傳遞的資料的型別,如 props 實際 例子 id aaa 我是父級 元件title v cloak box style hea...