pytorch 深度學習框架

2021-10-24 02:30:49 字數 2567 閱讀 8769

# 定義網路

net = net(

)# 定義資料

#資料預處理,1.轉為tensor,2.歸一化

transform = transforms.compose(

[transforms.totensor(),

transforms.normalize(

(0.5

,0.5

,0.5),

(0.5

,0.5

,0.5))

])# 訓練集

trainset = torchvision.datasets.cifar10(root=

'./data'

, train=

true

, download=

true

, transform=transform)

trainloader = torch.utils.data.dataloader(trainset, batch_size=4,

shuffle=

true

, num_workers=2)

# 驗證集

testset = torchvision.datasets.cifar10(root=

'./data'

, train=

false

, download=

true

, transform=transform)

testloader = torch.utils.data.dataloader(testset, batch_size=4,

shuffle=

false

, num_workers=2)

# 定義損失函式和優化器

criterion = nn.crossentropyloss(

)optimizer = optim.sgd(net.parameters(

), lr=

0.001

, momentum=

0.9)

# 開始訓練

net.train(

)for epoch in

range(2

):# loop over the dataset multiple times

running_loss =

0.0for i, data in

enumerate

(trainloader,0)

:# get the inputs; data is a list of [inputs, labels]

inputs, labels = data

# 將梯度置為0

# zero the parameter gradients

optimizer.zero_grad(

)# 求loss

# forward + backward + optimize

outputs = net(inputs)

loss = criterion(outputs, labels)

# 梯度反向傳播

loss.backward(

)# 由梯度,更新引數

optimizer.step(

)# 視覺化

# print statistics

running_loss += loss.item(

)if i %

2000

==1999

:# print every 2000 mini-batches

print

('[%d, %5d] loss: %.3f'

%(epoch +

1, i +

1, running_loss /

2000))

running_loss =

0.0# 檢視在驗證集上的效果

dataiter =

iter

(testloader)

images, labels = dataiter.

next()

# print images

imshow(torchvision.utils.make_grid(images)

)print

('groundtruth: '

,' '

.join(

'%5s'

% classes[labels[j]

]for j in

range(4

)))net.

eval()

outputs = net(images)

_, predicted = torch.

max(outputs,1)

print

('predicted: '

,' '

.join(

'%5s'

% classes[predicted[j]

]for j in

range(4

)))

深度學習框架 PyTorch(一)

pytorch是基於python的開源深度學習框架,它包括了支援gpus計算的tensor模組以及自動求導等先進的模組,被廣泛應用於科學研究中,是最流行的動態圖框架。pytorch的運算單元叫作張量tensor。我們可以將張量理解為乙個多維陣列,一階張量即為一位陣列,通常叫作向量vector 二階張...

第04課 深度學習框架 PyTorch

隨著深度學習的研究熱潮持續高漲,各種開源深度學習框架也層出不窮,包括 tensorflow pytorch caffe2 keras cntk mxnet paddle deeplearning4 lasagne neon 等等。其中,谷歌推出的 tensorflow 無疑在關注度和使用者數上都佔據...

pytorch 深度學習

pytorch深度學習實踐 訓練集 開發集 模型評估 測試集。f x wx b f x wx b f x w x bloss 乙個樣本 cost mean square error training set 區域性最優,不一定全域性最優。鞍點 梯度為0,但無法繼續迭代。w w c ost ww w ...