龍良曲pytorch學習筆記 載入寶可夢資料集

2022-07-22 13:21:11 字數 4972 閱讀 4062

1

import

torch

2import

os,glob

3import

random,csv45

from torch.utils.data import

dataset,dataloader67

from torchvision import

transforms

8from pil import

image910

class

pokemon(dataset):

11'''

12@param

13root:儲存的根路徑

1415

mode:train或者test模式

16'''

17def

__init__

(self,root,resize,mode):

18 super(pokemon,self).__init__

()19

20 self.root =root

21 self.resize =resize

2223

#字典型別key:name value:label

24 self.name2label ={}25#

listdir返回順序不固定,用sorted將它固定,因為排序一次之後就固定了

26for name in

sorted(os.listdir(os.path.join(root))):

27if

notos.path.isdir(os.path.join(root,name)):

28continue

2930 self.name2label[name] =len(self.name2label.keys())

3132

#print(self.name2label)

3334

#image_path + image_label

35 self.images,self.labels = self.load_csv('

images.csv')

3637

if mode == '

train

': #

60%38 self.images = self.images[:int(0.6*len(self.images))]

39 self.labels = self.labels[:int(0.6*len(self.labels))]

40elif mode == '

val': #

20%41 self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]

42 self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]

43elif mode == '

test

': #

20% = 80% ->100%

44 self.images = self.images[int(0.8*len(self.images)):]

45 self.labels = self.labels[int(0.8*len(self.labels)):]

4647

defload_csv(self,filename):

4849

#如果不存在再寫入,存在的話直接讀取就可以了

保證images和labels一一對應,長度相等

84assert len(images) ==len(labels)

85return

images,labels

8687

def__len__

(self):

8889

return

len(self.images)

9091

defdenormalize(self,x_hat):

9293 mean=[0.485,0.456,0.406]

94 std=[0.229,0.224,0.225]

9596

#x_hat = (x-mean)/std97#

x = x_hat*std+mean98#

x: [c,h,w]99#

mean: [3] --> [3,1,1]

100 mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)

101 std = torch.tensor(std).unsqueeze(1).unsqueeze(1)

102103 x = x_hat*std +mean

104105

return

x106

107108

def__getitem__

(self,idx):

109#

idx~[0~len(images)]

110#

self.images,self.labels

111#

img: pokemon\\bulbasaur\\00000000.png'

112#

label: 0

113 img,label =self.images[idx],self.labels[idx]

114115 tf =transforms.compose([

116lambda x:image.open(x).convert('

rgb'), #

string path --> image data

117 transforms.resize((int(self.resize*1.25),int(self.resize*1.25))),

118 transforms.randomrotation(15),

119transforms.centercrop(self.resize),

120transforms.totensor(),

121 transforms.normalize(mean=[0.485,0.456,0.406],

122 std=[0.229,0.224,0.225])

123])

124125 img =tf(img)

126 label =torch.tensor(label)

127128

return img,label

Pytorch 學習筆記

本渣的pytorch 逐步學習鞏固經歷,希望各位大佬指正,寫這個部落格也為了鞏固下記憶。a a b c 表示從 a 取到 b 步長為 c c 2 則表示每2個數取1個 若為a 1 1 1 即表示從倒數最後乙個到正數第 1 個 最後乙個 1 表示倒著取 如下 陣列的第乙個為 0 啊,第 0 個!彆扭 ...

Pytorch學習筆記

資料集 penn fudan資料集 在學習pytorch官網教程時,作者對penn fudan資料集進行了定義,並且在自定義的資料集上實現了對r cnn模型的微調。此篇筆記簡單總結一下pytorch如何實現定義自己的資料集 資料集必須繼承torch.utils.data.dataset類,並且實現 ...

Pytorch學習筆記

lesson 1.張量 tensor 的建立和常用方法 一 張量 tensor 的基本建立及其型別 import torch 匯入pytorch包 import numpy as np torch.version 檢視版本號1.張量 tensor 函式建立方法 張量 tensor 函式建立方法 t ...