pytorch載入模型與凍結

2021-10-05 09:45:25 字數 1245 閱讀 8365

weights=torch.load(path)
with open('a.pkl', 'wb') as f:

pickle.dump(score_dict, f)

weights=pickle.load(f)

直接載入

model.load_state_dict(weights)
字典生成式載入

self.load_state_dict()

#self.load_state_dict()

collections.ordereddict()載入

b = collections.ordereddict()

for k,v in weights.items():

b[k]=v

#b[k.replace("module.","")]=v

self.load_state_dict(b)

更新引數

new_weights=model.state_dict()

new_weights.update(weights) # 將weights中引數更新至new_weights中

self.load_state_dict(new_weights)

for key, value in model.named_parameters():# named_parameters()包含網路模組名稱 key為模型模組名稱 value為模型模組值,可以通過判斷模組名稱進行對應模組凍結

value.requires_grad = true

for value in model.parameters()():#不包含網路模組名稱 value為模型模組值

value.requires_grad = true

使用filter過濾需要的模組引數

optimizer = optim.sgd(

filter(lambda p: p.requires_grad, model.parameters()), #只更新 requires_grad=true的引數,即進行反向傳播的引數

lr=,

momentum=,

weight_decay=,

nesterov=

)

Pytorch載入部分引數並凍結

pytorch 模型部分引數的載入 pytorch中,只匯入部分模型引數的做法 correct way to freeze layers pytorch自由載入部分模型引數並凍結 pytorch凍結部分引數訓練另一部分 pytorch更新部分網路,其他不更新 pytorch固定部分引數 只訓練部分層...

Pytorch載入模型時報錯

報錯截圖如下 反覆排查問題沒發現為何如此,檢視pytorch中文文件發現儲存和載入模型方法都完全正確,模型儲存和載入 對比中文文件截圖如下 其中乙個方法是在載入模型時新增引數strict false,可以只保留鍵值相同的引數避免出錯,用法如下 model.load state dict ckpt s...

pytorch 載入預訓練模型

pytorch的torchvision中給出了很多經典的預訓練模型,模型的引數和權重都是在imagenet資料集上訓練好的 載入模型 方法一 直接使用預訓練模型中的引數 import torchvision.models as models model models.resnet18 pretrai...