torch讀取資料

2021-10-10 10:51:30 字數 2609 閱讀 1279

使用torchvision.datasets.imagefolder來讀取

torchvision.datasets.imagefolder(root=

"root folder path"

,[transform, target_transform]

)

target_transform - 乙個函式,輸入為target,輸出對其的轉換。例子,輸入的是標註的string,輸出為word的索引。

有幾個變數

self.classes - 用乙個list儲存 類名

self.class_to_idx - 類名對應的 索引

self.imgs - 儲存(img-path, class) tuple的list

定義乙個torch.utils.data.dataset資料類

dataset有兩個函式

class

dataset_1

(torch.utils.data.dataset)

:def

__init__

(self,root,is_resize=

false

,is_transfrom=

false):

self.root=root

self.is_resize=is_resize

self.is_transfrom=is_transfrom

self.imgs_list=..

.#儲存路徑節省記憶體

self.labs_list=..

.def

__getitem__

(self, index)

: img_path,lab=self.imgs_list[index]

,self.labs_list[index]

img_data = image.

open

(img_path)

if self.is_transfrom:

img_data=self.is_transfrom(img_data)

return img_data,lab

def__len__

(self)

:return

len(self.imgs_list)

定義好dataset資料類,之後使用dataloader匯入

torch.utils.data.dataloader(dataset=dataset_1, batch_size=args.batchsize, shuffle=

true

, num_workers=args.nthreads)

有時需要對資料進行處理

train_transforms = torchvision.transforms.compose(

[torchvision.transforms.resize(

256)

,torchvision.transforms.centercrop(

224)

,torchvision.transforms.randomhorizontalflip(),

torchvision.transforms.totensor()]

)img = image.

open

('test.png'

)train_transforms(img)

訓練資料
train_loader = torch.utils.data.dataloader(dataset=dataset_1, batch_size=args.batchsize, shuffle=

true

, num_workers=args.nthreads)

model = torchvision.models.__dict__[

'resnet101'

](pretrained=

true

)model.load_state_dict(torch.load(

'...pth'))

model.to(device)

# 訓練

optimizer = torch.optim.sgd(model.parameters(

),lr =

0.01

)loss_fn = torch.nn.mseloss(

)model.train(

)for batch_index,

(data,target)

inenumerate

(train_loader)

: data, target = data.to(device)

, target.to(device)

output = model(data)

loss = loss_fn(output, target)

optimizer.zero_grad(

) loss.backward(

) optimizer.step(

)# 推理

model.

eval()

output = model(data)

Torch之讀取梯度

讀取gradient z.backward torch.ones like x 我們的返回值不是乙個標量,所以需要輸入乙個大小相同的張量作為引數,這裡我們用ones like函式根據x生成乙個張量。個人認為,因為要對x和y分別求導數,所以函式z必須是求得的乙個值,即標量。然後開始對x,y分別求偏導數...

torch學習筆記2 資料整理

官方入門資料 getting started with torch torch自帶package說明文件 torch package reference manual torch tensor運算說明文件 torch tensor torch使用常見問題 torch7 faq torch wiki ...

Torch學習 開始

到目前為止出現了各種各樣的深度學習的解決方案框架,其中包括caffe,cuda convnet,pylearn2,theano,torch以及tensorflow等。caffe和cuda convnet沒用過,不過一直很火的樣子,pylearn2包含了cuda covnet,入門用theano,te...