pytorch中state dict的理解

2021-10-05 21:49:34 字數 2395 閱讀 4298

在pytorch中,state_dict是乙個python字典物件(在這個有序字典中,key是各層引數名,value是各層引數),包含模型的可學習引數(即權重和偏差,以及bn層的的引數) 優化器物件(torch.optim)也具有state_dict,其中包含有關優化器狀態以及所用超引數的資訊。其實看了如下**的輸出應該就懂了。

import torch

import torch.nn as nn

import torchvision

import numpy as np

from torchsummary import summary

# define model

class

themodelclass

(nn.module)

:def

__init__

(self)

:super

(themodelclass, self)

.__init__(

) self.conv1 = nn.conv2d(3,

6,5)

self.pool = nn.maxpool2d(2,

2)self.conv2 = nn.conv2d(6,

16,5)

self.fc1 = nn.linear(16*

5*5,

120)

self.fc2 = nn.linear(

120,84)

self.fc3 = nn.linear(84,

10)defforward

(self, x)

: x = self.pool(f.relu(self.conv1(x)))

x = self.pool(f.relu(self.conv2(x)))

x = x.view(-1

,16*5

*5) x = f.relu(self.fc1(x)

) x = f.relu(self.fc2(x)

) x = self.fc3(x)

return x

# initialize model

model = themodelclass(

)# initialize optimizer

optimizer = torch.optim.sgd(model.parameters(

), lr=

0.001

, momentum=

0.9)

# print model's state_dict

print

("model's state_dict:"

)for param_tensor in model.state_dict():

print

(param_tensor,

"\t"

, model.state_dict(

)[param_tensor]

.size())

# print optimizer's state_dict

print

("optimizer's state_dict:"

)for var_name in optimizer.state_dict():

print

(var_name,

"\t"

, optimizer.state_dict(

)[var_name]

)

輸出如下:

model's state_dict:

conv1.weight torch.size([6

,3,5

,5])

conv1.bias torch.size([6

])conv2.weight torch.size([16

,6,5

,5])

conv2.bias torch.size([16

])fc1.weight torch.size(

[120

,400])

fc1.bias torch.size(

[120])

fc2.weight torch.size([84

,120])

fc2.bias torch.size([84

])fc3.weight torch.size([10

,84])

fc3.bias torch.size([10

])optimizer's state_dict:

state

param_groups [

]

我是剛接觸深度學西的小白乙個,希望大佬可以為我指出我的不足,此部落格僅為自己的筆記!!!!

Pytorch 中 torchvision的錯誤

在學習pytorch的時候,使用 torchvision的時候發生了乙個小小的問題 安裝都成功了,並且import torch也沒問題,但是在import torchvision的時候,出現了如下所示的錯誤資訊 dll load failed 找不到指定模組。首先,我們得知道torchvision在...

Pytorch中建立DataLoader的幾種方法

簡介 這段 是mnist手寫體識別中的部分 此篇 為mnist手寫體識別中的 import torch import torchvision import torchvision.transforms as transforms from torch.utils.data import datalo...

pytorch中的乘法

總結 按元素相乘用torch.mul,二維矩陣乘法用torch.mm,batch二維矩陣用torch.bmm,batch 廣播用torch.matmul if name main a torch.tensor 1 2,3 b torch.arange 0,12 reshape 4 3 c torch...