資料集處理 CIFAR10

2021-10-25 06:41:19 字數 1299 閱讀 8879

transform = transforms.compose(

[transforms.totensor(),

transforms.normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.cifar10(root = 'cifar10', train = true,

download = true, transform = transform)

trainloader = torch.utils.data.dataloader(trainset, batch_size = 4,

shuffle = true, num_workers = 2)

#batch_size是每個分組的數量

testset = torchvision.datasets.cifar10(root = 'cifar10', train = false,

download = true, transform = transform)

testloader = torch.utils.data.dataloader(testset, batch_size = 4,

shuffle = false, num_workers = 2)

classes = ('plane', 'car', 'bird', 'cat',

'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

trainset包含所有測試集樣本。每個樣本是乙個元組,乙個元組包括兩個元素

tensor:的畫素值(3通道),3*32*32

testloader是乙個可迭代物件

for data in trainloader
data是乙個列表,包括4個的資料。列表包括兩個tensor

具體使用:

for data in trainloader:

inputs,labels=data

# 將兩個tensor用varible類封裝

inputs,labels=variable(inputs.cuda()),variable(labels.cuda())

# net是乙個類,包括神經網路的結構,前向傳播的過程,等價於outputs=net.forward(inputs)

outputs=net(inputs)

loss=criterion(outputs,labels)

loss.backward()

CIFAR 10資料集讀取

參考 1 使用讀取方式pickle def unpickle file import pickle with open file,rb as fo dict pickle.load fo,encoding bytes return dict 返回的是乙個python字典 2 通過字典的內建函式,獲取...

Pytorch實現CIFAR 10資料集

練習pytorch,做個記錄。寫的有點亂 import torch import torchvision import torch.nn as nn from torchvision import transforms from torch.utils.data.dataloader import ...

利用pytorch對CIFAR 10資料集的分類

步驟如下 1.使用torchvision載入並預處理cifar 10資料集 2.定義網路 3.定義損失函式和優化器 4.訓練網路並更新網路引數 5.測試網路 執行環境 windows python3.6.3 pycharm pytorch0.3.0 import torchvision as tv ...