pytorch之資料集構造

2021-09-25 01:49:43 字數 4499 閱讀 8575

這些天看的東西,真的是比較多,相比以前來說,對我的學習方式起到顛覆性作用。我目前覺得,我們學到的東西,更多是孤立的,因此,在吸收一定知識後,需要在腦子裡形成知識體系。需要把自己以前學到的東西進行整理,形成乙個體系,這篇文章講解的是,深度學習中pytorch資料集的構造!!!

pytorch中有兩個自定義管理資料集的類,

torch.utils.data.datasettorvchvision.datasets.imagefolder

這裡主要講解的第一種。

class dataset(object):

"""an abstract class representing a dataset.

all other datasets should subclass it. all subclasses should override

``__len__``, that provides the size of the dataset, and ``__getitem__``,

supporting integer indexing in range from 0 to len(self) exclusive.

"""def __getitem__(self, index):

raise notimplementederror

def __len__(self):

raise notimplementederror

def __add__(self, other):

return concatdataset([self, other])

我們設計自己資料集類的時候, 只需要重寫__getitem__、__len__兩個函式,分別的功能是,通過切片返回具樣例返回樣本個數

以下是voc2012資料集分割的例子:

return len(self.data_list)通過上面的操作,我們構建自己資料集類,接下來,構建乙個dataloader類,這個作用是訓練過程中,返回 batch個樣例。

由於原始碼過於臃腫了,這裡知識摘出對應的建構函式:

def __init__(self, dataset, batch_size=1, shuffle=false, sampler=none,

batch_sampler=none, num_workers=0, collate_fn=default_collate,

pin_memory=false, drop_last=false, timeout=0,

worker_init_fn=none):

建構函式中,每個引數的意思就不一一介紹了,只著重的講解下,可呼叫函式collate_fn。我們首先看乙個構建dataloader的例項:

def build_dataset(cfg, transforms, is_train=true):

datasets = vocsegdataset(cfg, is_train, transforms)

return datasets

def make_data_loader(cfg, is_train=true):

if is_train:

batch_size = cfg.solver.ims_per_batch

shuffle = true

else:

batch_size = cfg.test.ims_per_batch

shuffle = false

transforms = build_transforms(cfg, is_train)

datasets = build_dataset(cfg, transforms, is_train)

num_workers = cfg.dataloader.num_workers

data_loader = data.dataloader(

datasets, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=true

)return data_loader

上面第乙個函式build_dataset返回資料集例項,第二個函式返回dataloader,關於dataloader,我們需要注意的是,有時我們需要根據dataset中的__getitem__修改collate_fn

我們來看下原始碼:

def __next__(self):

if self.num_workers == 0: # same-process loading

indices = next(self.sample_iter) # may raise stopiteration

batch = self.collate_fn([self.dataset[i] for i in indices])

if self.pin_memory:

batch = _utils.pin_memory.pin_memory_batch(batch)

return batch

我們在原始碼中發現,collate_fn的輸入是乙個list,裡面的每個元素是__getitem__的輸出,由此,我們估計,default_collate的作用是將這個list,**變換格式為[batch,c,h,w]**的tensor,我們在來看下原始碼:

if.......

.........

elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple

return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))

由於原始碼均是對型別的判斷,因此,這裡我們知識摘出,與voc2012分割相關的部分,這個語句的意思是, 對[(img1, label1), (img2, label2)],首先返回[img1,img2],[lable1,label2],在繼續返回兩個tensor,乙個是img,[batch,c,h,w],乙個是label:[batch,c,h,w]。

所以,通過上面分析,如果,我們__getitem__不符合collat_fn不符合預設函式的判斷時,需要修改該函式。

好了,先到這,接下來…慢慢聊程式,需要學的太多了

pytorch之建立資料集

import torch import torchvision from torchvision import datasets,transforms dataroot data celeba 資料集所在資料夾 建立資料集 dataset datasets.imagefolder root data...

pytorch 載入資料集

2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...

pytorch批訓練資料構造

這是對莫凡python的學習筆記。1.建立資料 import torch import torch.utils.data as data batch size 8x torch.linspace 1,10,10 y torch.linspace 10,1,10 可以看到建立了兩個一維資料,x 1 1...