PyTorch學習 載入資料集

2021-10-09 02:08:36 字數 3978 閱讀 3480

需要定義diabetesdataset做為載入資料集diabetes的類, 繼承自dataset,dataset是抽象類,需要實現其中的三個方法,__ initgetitemlen __

import torch

from torch.utils.data import dataset # 抽象類

from torch.utils.data import dataloader

import numpy as np

class

diabetesdataset

(dataset)

:# 繼承自dataset

def__init__

(self, filepath)

: xy = np.loadtxt(filepath, delimiter =

',', dtype = np.float32)

self.

len= xy.shape[0]

self.x_data = torch.from_numpy(xy[:,

:-1]

) self.y_data = torch.from_numpy(xy[:,

[-1]

])def__getitem__

(self, index)

:# 支援下標操作,根據索引獲取資料

return self.x_data[index]

, self.y_data[index]

def__len__

(self)

:# 獲取資料條數

return self.

lendataset = diabetesdataset(

'diabetes.csv.gz'

)train_loader = dataloader(dataset = dataset,

# 處理的資料集

batch_size =32,

# 每次處理的資料大小

shuffle =

true

,# 是否打亂

num_workers =0)

# 多執行緒數量,在windows裡需要設定為0, linux可以大於0

class

model

(torch.nn.module)

:def

__init__

(self)

:super

(model, self)

.__init__(

) self.linear1 = torch.nn.linear(8,

6)self.linear2 = torch.nn.linear(6,

4)self.linear3 = torch.nn.linear(4,

1)self.sigmoid = torch.nn.sigmoid(

)# 與nn.function.sigmoid不同,用於構建計算圖

defforward

(self, x)

: x = self.sigmoid(self.linear1(x)

) x = self.sigmoid(self.linear2(x)

) x = self.sigmoid(self.linear3(x)

)return x

model = model(

)criterion = torch.nn.bceloss(reduction=

'mean'

)# 損失函式

optimizer = torch.optim.sgd(model.parameters(

), lr =

0.1)

# 優化器

if __name__ ==

'__main__'

:for epoch in

range

(100):

for i, data in

enumerate

(train_loader,0)

:#1. prepare data

inputs, labels = data

# 2.forward

y_pred = model(inputs)

loss = criterion(y_pred, labels)

print

(epoch, i, loss.item())

# 3.backward

optimizer.zero_grad(

) loss.backward(

)# 4.update

optimizer.step(

)

輸出:

0 0 0.6936783194541931

0 1 0.693471372127533

0 2 0.6917673349380493

0 3 0.6861389875411987

0 4 0.6913132667541504

0 5 0.6789288520812988

0 6 0.6768878698348999

0 7 0.6651645302772522

0 8 0.6861144304275513

0 9 0.6686166524887085

0 10 0.6661809682846069

0 11 0.6636384129524231

0 12 0.6618748307228088

0 13 0.6681938767433167

0 14 0.6153277158737183

0 15 0.6548603773117065

... ...

98 14 0.5910221338272095

98 15 0.6699521541595459

98 16 0.6283824443817139

98 17 0.6495291590690613

98 18 0.6865949630737305

98 19 0.6016601920127869

98 20 0.630635678768158

98 21 0.6044492721557617

98 22 0.6302173137664795

98 23 0.6102578043937683

99 0 0.5284566283226013

99 1 0.6872431039810181

99 2 0.6330350041389465

99 3 0.6103817820549011

99 4 0.6251040697097778

99 5 0.6059320569038391

99 6 0.6281994581222534

99 7 0.6733802556991577

99 8 0.6273549795150757

99 9 0.7067252993583679

99 10 0.6479067802429199

99 11 0.7034580111503601

99 12 0.633543848991394

99 13 0.5920330882072449

99 14 0.6311102509498596

99 15 0.6479007601737976

99 16 0.6280706524848938

99 17 0.6995146870613098

99 18 0.6469420790672302

99 19 0.6414950489997864

99 20 0.5969923734664917

99 21 0.5866757035255432

99 22 0.5923041105270386

99 23 0.524055004119873

pytorch 載入資料集

2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...

pytorch 載入自己的資料集

pytorch 載入自己的資料集,需要寫乙個繼承自torch.utils.data中dataset類,並修改其中的 init 方法 getitem 方法 len 方法。預設載入的都是,init 的目的是得到乙個包含資料和標籤的list,每個元素能找到位置和其對應標籤。然後用 getitem 方法得到...

pytorch載入資料

參考 pytorch深度學習快速入門教程 絕對通俗易懂!小土堆 可看到說明,dataset是乙個抽象類,我們重寫dataset時要繼承這個類,所有的子類都應該重寫 getitem 方法,這個方法作用是獲取資料及對應的labe。同時我們可以選擇性地去重寫 len 方法,其作用是獲取資料集長度。這裡我使...