pytorch 自定義資料集載入方法

2021-10-17 05:26:09 字數 3100 閱讀 8161

更多python教程請到: 菜鳥教程

pytorch 官網給出的例子中都是使用了已經定義好的特殊資料集介面來載入資料,而且其使用的資料都是官方給出的資料。如果我們有自己收集的資料集,如何用來訓練網路呢?此時需要我們自己定義好資料處理介面。幸運的是pytroch給出了乙個資料集介面類(torch.utils.data.dataset),可以方便我們繼承並實現自己的資料集介面。

torch.utils.data

torch的這個檔案包含了一些關於資料集處理的類。

class torch.utils.data.dataset: 乙個抽象類, 所有其他類的資料集類都應該是它的子類。而且其子類必須過載兩個重要的函式:len(提供資料集的大小)、getitem(支援整數索引)。

class torch.utils.data.tensordataset: 封裝成tensor的資料集,每乙個樣本都通過索引張量來獲得。

class torch.utils.data.concatdataset: 連線不同的資料集以構成更大的新資料集。

class torch.utils.data.subset(dataset, indices): 獲取指定乙個索引序列對應的子資料集。

class torch.utils.data.dataloader(dataset, batch_size=1, shuffle=false, sampler=none, batch_sampler=none, num_workers=0, collate_fn=, pin_memory=false, drop_last=false, timeout=0, worker_init_fn=none): 資料載入器。組合了乙個資料集和取樣器,並提供關於資料的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照給定的長度將資料集劃分成沒有重疊的新資料集組合。

class torch.utils.data.sampler(data_source):所有取樣的器的基類。每個取樣器子類都需要提供 __iter__ 方法以方便迭代器進行索引 和乙個 len方法 以方便返回迭代器的長度。

class torch.utils.data.sequentialsampler(data_source):順序取樣樣本,始終按照同乙個順序。

class torch.utils.data.randomsampler(data_source):無放回地隨機取樣樣本元素。

class torch.utils.data.subsetrandomsampler(indices):無放回地按照給定的索引列表取樣樣本元素。

class torch.utils.data.weightedrandomsampler(weights, num_samples, replacement=true): 按照給定的概率來取樣樣本。

class torch.utils.data.batchsampler(sampler, batch_size, drop_last): 在乙個batch中封裝乙個其他的取樣器。

class torch.utils.data.distributed.distributedsampler(dataset, num_replicas=none, rank=none):取樣器可以約束資料載入進資料集的子集。

自定義資料集

自己定義的資料集需要繼承抽象類class torch.utils.data.dataset,並且需要過載兩個重要的函式:__len__ 和__getitem__。

整個**僅供參考。在__init__中是初始化了該類的一些基本引數;__getitem__中是真正讀取資料的地方,迭代器通過索引來讀取資料集中資料,因此只需要這乙個方法中加入讀取資料的相關功能即可;__len__給出了整個資料集的尺寸大小,迭代器的索引範圍是根據這個函式得來的。

import torch

class mydataset(torch.nn.data.dataset):

definit(self, datasource)

self.datasource = datasource

defgetitem(self, index):

element = self.datasource[index]

return element

deflen(self):

return len(self.datasource)

train_data = mydataset(datasource)

自定義資料集載入器

class torch.utils.data.dataloader(dataset, batch_size=1, shuffle=false, sampler=none, batch_sampler=none, num_workers=0, collate_fn=, pin_memory=false, drop_last=false, timeout=0, worker_init_fn=none): 資料載入器。組合了乙個資料集和取樣器,並提供關於資料的迭代器。

dataset (dataset) – 需要載入的資料集(可以是自定義或者自帶的資料集)。

batch_size – batch的大小(可選項,預設值為1)。

shuffle – 是否在每個epoch中shuffle整個資料集, 預設值為false。

sampler – 定義從資料中抽取樣本的策略. 如果指定了, shuffle引數必須為false。

num_workers – 表示讀取樣本的執行緒數, 0表示只有主線程。

collate_fn – 合併乙個樣本列表稱為乙個batch。

pin_memory – 是否在返回資料之前將張量拷貝到cuda。

drop_last (bool, optional) – 設定是否丟棄最後乙個不完整的batch,預設為false。

timeout – 用來設定資料讀取的超時時間的,但超過這個時間還沒讀取到資料的話就會報錯。應該為非負整數。

train_loader=torch.utils.data.dataloader(dataset=train_data, batch_size=64, shuffle=true)

Pytorch 自定義資料集

pytorch將資料集的處理過程標準化。繼承dataset類 pytorch中提供了torch.utils.data.dataset抽象類,使用時需要繼承這個類,並重寫 len 和 geiitem 函式。增加資料變換 pytorch提供了torchvision.transforms可以比較方便進行影...

PyTorch 入門 自定義資料載入

之前學習tensorflow時也學習了它的資料載入,不過在網上看了很多教程後還是有很多小問題,不知道為什麼在別人電腦上可以執行但是我的就不行 把我頭搞暈了 很煩,這時想起之前聽導師說pytorch容易入門上手,所以果斷去學了pytorch,寫這篇博文的目的就是總結學到的,然後記錄下來,也希望以後學到...

pytorch 載入資料集

2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...