pytorch 模型部分引數的載入

2021-10-25 12:22:07 字數 1765 閱讀 8966

如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new_conv1,將分類層fc的1000類輸出改為2類輸出的new_fc,注意:要改一下名字與原來的不同。

匯入模型

mynet=resnet18()

然後就載入模型的引數,參考pytorch 如何載入部分預訓練模型

pretrained_dict=torch.load(model_weight)

model_dict=mynet.state_dict(

)# 1. filter out unnecessary keys

pretrained_dict =

# 2. overwrite entries in the existing state dict

model_dict.update(pretrained_dict)

mynet.load_state_dict(model_dict)

也可以通過pretrained model.state_dict()提取需要的模型引數。

mynet.load_state_dict(torch.load(model_weight)

, strict=

false

)

這一句話就搞定了,key相同(key可以理解為模組名字)的載入進去,不相同的就丟棄掉了。 注意,若你更改了比如某個conv的output_channel, 此時key還是相同的,當你使用load_state_dict載入時就會報錯。以下是使用coco(184類)訓練的deeplab, 嘗試載入預訓練權重到用於訓練voc資料集時(只有21類)就會出現:

error(s) in loading state_dict for dataparallelwithcallback:

1. size mismatch for module.decoder.output.7.weight: copying a param with shape torch.size([184, 256, 1, 1])from checkpoint,

the shape in current model is torch.size([91, 256, 1, 1]).

2. size mismatch for module.decoder.output.7.bias: copying a param with shape torch.size([184]) from checkpoint, the shape in current model is torch.size([91]).

如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new_conv1,將分類層fc的1000類輸出改為2類輸出的new_fc,注意:要改一下名字與原來的不同。

匯入模型

mynet=resnet18()

然後就載入模型的引數,參考pytorch 如何載入部分預訓練模型

pytorch 模型部分引數的載入

如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet 然後就載...

檢視模型各層引數(Pytorch

這個實驗用到的資料集是mnist資料集,維度是1 28 28 import torch.nn as nn class cnn nn.module def init self super cnn,self init 卷積層 self.conv1 nn.sequential in channels 1,...

額外引數 Pytorch獲取模型引數情況的方法

分享人工智慧技術乾貨,專注深度學習與計算機視覺領域!相較於tensorflow,pytorch一開始就是以動態圖構建神經網路圖的,其獲取模型引數的方法也比較容易,既可以根據其內建介面自己寫 獲取模型引數情況,也可以借助第三方庫來獲取模型引數情況,下面,就讓我們一起來了解pytorch獲取模型引數情況...