Pytorch的資料載入

2021-10-08 23:43:28 字數 2351 閱讀 9678

pytorch將資料集的處理過程標準化,提供了dataset基本的資料 類,並在torchvision中提供了眾多資料變換函式,資料載入的具體過程 主要分為3步:

1.繼承dataset類

對於資料集的處理,pytorch提供了torch.utils.data.dataset這個抽象 類,在使用時只需要繼承該類,並重寫__len__()和__getitem()__函式, 即可以方便地進行資料集的迭代。

from torch.utils.data import dataset

class

my_data

(dataset)

:def

_init_

(self,image_path,annotation_path,transform-

none):

#初始化,讀取資料集

def_len_

(self)

:#獲取資料集的總大小

def_getitem_

(self,id)

:#對於指定的id,讀取資料並返回

對上述初始化的·實際使用:

dataset = my_data(

"your image path"

,"your annotation path"

)# 例項化該類 for data in dataset:

print

(data)

2.資料變換與增強:torchvision.transforms

第一步雖然將資料集載入到了例項中,但在實際應用時,資料集中 的有可能存在大小不一的情況,並且原始畫素rgb值較大 (0~255),這些都不利於神經網路的訓練收斂,因此還需要進行一些 影象變換工作。pytorch為此提供了torchvision.transforms工具包,可以 方便地進行影象縮放、裁剪、隨機翻轉、填充及張量的歸一化等操作, 操作物件是pil的image或者tensor。

如果需要進行多個變換功能,可以利用transforms.compose將多個 變換整合起來,並且在實際使用時,通常會將變換操作整合到dataset的 繼承類中。具體示例如下:

from torchvision import transforms

#將transform整合到dataset類中,使用compose將多個變換整合到一起

dataset = my_data(

"your image path"

,"your annotation path"

,transforms=transforms.compose(

[transforms.resize(

256)

#將影象最短邊縮小至256,寬高比例不變

#以0.5的概率隨即翻轉指定的pil影象

transforms.randomhorizaontalflip(

)#將pil影象轉為tensor,元素區間從[0,255]歸一化到[0,1]

transforms.totensor(

)#進行mean與std為0.5的標準化

transforms.normalize(

[0.5

,0.5

,0.5],

[0.5

,0.5

,0.5])

]))

3.繼承dataloader

經過前兩步已經可以獲取每乙個變換後的樣本,但是仍然無法進行 批量處理、隨機選取等操作,因此還需要torch.utils.data.dataloader類進 一步進行封裝,使用方法如下例所示,該類需要4個引數,第1個引數是 之前繼承了dataset的例項,第2個引數是批量batch的大小,第3個引數是 是否打亂資料引數,第4個引數是使用幾個執行緒來載入資料。

from torch.utils.data import dataloader

# 使用dataloader進一步封裝dataset

dataloader = dataloader(dataset, batch_size=

4, shuffle=

true

, num_workers=

4)

dataloader是乙個可迭代物件,對該例項進行迭代即可用於訓練過程。

data_iter =

iter

(dataloader)

for step in

range

(iters_per_epoch)

: data =

next

(data_iter)

# 將data用於訓練網路即可

pytorch載入資料

參考 pytorch深度學習快速入門教程 絕對通俗易懂!小土堆 可看到說明,dataset是乙個抽象類,我們重寫dataset時要繼承這個類,所有的子類都應該重寫 getitem 方法,這個方法作用是獲取資料及對應的labe。同時我們可以選擇性地去重寫 len 方法,其作用是獲取資料集長度。這裡我使...

pytorch十 資料載入

在pytorch中,資料載入可通過自定義的資料集物件實現。資料及物件被抽象為dataset類,實現自定義的資料集需要繼承dataset,並實現兩個python魔法方法。這裡我們以kaggle經典挑戰比賽 dogs vs cat 的資料為例,詳細講解如何處理資料。這是乙個分類問題,判斷一張是狗還是貓,...

Pytorch資料載入 (一)

在pytorch中,資料載入可以通過自定義的資料集物件實現。資料集物件被抽象為dataset類,實現自定義的資料集需要繼承datase類,並且實現python的兩個魔法方法。a.getitem 返回一條資料或者樣本。如obj index 等價於obj.getitem index 如果定義乙個 cla...