PyTorch 入門 自定義資料載入

2021-09-19 08:44:53 字數 2337 閱讀 3231

之前學習tensorflow時也學習了它的資料載入,不過在網上看了很多教程後還是有很多小問題,不知道為什麼在別人電腦上可以執行但是我的就不行(把我頭搞暈了),很煩,這時想起之前聽導師說pytorch容易入門上手,所以果斷去學了pytorch,寫這篇博文的目的就是總結學到的,然後記錄下來,也希望以後學到新的知識或技術能夠用寫部落格的方式記錄下來,這樣有助於形成比較好的知識體系,也方便以後溫故知新。

在進行深度學習實驗前,必須準備資料,而在準備資料的時候有乙個步驟就是把資料封裝成符合深度學習模型要求的資料形式,方便模型讀取。

在pytorch中,有乙個torch.utils.data.dataset類,這是乙個抽象類,其他所有不管是公開的官方資料集還是自定義資料集都必須繼承這個抽象類(比如mnist資料集),繼承這個抽象類的同時必須重寫它的兩個函式:__len__()  和    __getitem__()。

具體怎麼定義自定義資料集,**如下:

import torch

from torch.utils.data import dataset #首先匯入這個抽象類

from skimage import io,transform

class mydataset(dataset):

"""這是乙個初始化函式,相當於c++的建構函式,定義類的傳入引數和初始化

root_dir是資料集的路徑,transform是乙個資料處理操作

"""def __init__(self,root_dir,transform=none):

#os.listdir函式讀取路徑下所以檔案的檔名,並組成乙個列表並返回

self.file = os.listdir(root_dir)

self.root_dir = root_dir

self.transform = transform

def __len__(self):

return len(self.file) #返回這給列表的大小

def __getitem__(self,index):

#將傳入路徑和檔名組成乙個新的位址,這個資料就是單個資料的具體位址,方便之後以位址讀取該資料

img_name = os.path.join(self.root_dir, self.file[index])

#標籤)

if img_name[-7:-4] == 'dog':

label = 0

else: label = 1

image = io.imread(img_name)

#對進行縮放為乙個大小,方便深度學習模型處理

image = transform.resize(image,(128,128))

#對的維度進行轉換,

#numpy的三個維度順序為:h * w * c

#而torch的張量維度順序:c * h * w ,所以模型要處理它必須轉換成torch的形式

image = image.transpose((2, 0, 1))

#返回資料和標籤

return image,label

自定義資料集定義好了,那麼怎麼批量載入它呢,pytorch使用多執行緒載入資料,模型需要使用時才載入進記憶體讓模型讀取,而使用批量讀取資料必須使用pytorch的torch.utils.data.dataloader類,使用方法如下:

mydataloader = torch.utils.data.dataloader()

#首先例項化乙個自定義資料類

dataset = mydataset(root_dir='./dog_vs_cat/train/',transform=transforms.compose([

transforms.totensor(),

transforms.normalize((0.1307,), (0.3081,))

]))#然後例項化乙個資料載入類

dataloader = torch.utils.data.dataloader(dataset,batch_size=100,shuffle=true,num_workers=0)

第乙個引數是要載入的資料類

第二個引數是資料載入時每個批次多少資料

第三個引數設定資料載入時是否打亂資料

第四個引數設定多執行緒的個數,預設值是0,表示單個執行緒

然後就可以在迭代器中使用了

for batch_idx, (image, label) in enumerate(dataloader):
batch_idx 表示迭代器返回的自帶序號

(image,label)表示返回的資料和標籤

最後就可以將返回的資料和標籤輸入到模型中訓練了,完美!

Pytorch 自定義資料集

pytorch將資料集的處理過程標準化。繼承dataset類 pytorch中提供了torch.utils.data.dataset抽象類,使用時需要繼承這個類,並重寫 len 和 geiitem 函式。增加資料變換 pytorch提供了torchvision.transforms可以比較方便進行影...

Pytorch自定義引數

如果想要靈活地使用模型,可能需要自定義引數,比如 class net nn.module def init self super net,self init self.a torch.randn 2 3 requires grad true self.b nn.linear 2,2 defforwa...

PyTorch 自定義層

與使用module類構造模型類似。下面的centeredlayer類通過繼承module類自定義了乙個將輸入減掉均值後輸出的層,並將層的計算定義在了forward函式裡。這個層裡不含模型引數。class mydense nn.module def init self super mydense,se...