Pytorch載入部分引數並凍結

2021-10-09 12:21:01 字數 2971 閱讀 4717

pytorch 模型部分引數的載入

pytorch中,只匯入部分模型引數的做法

correct way to freeze layers

pytorch自由載入部分模型引數並凍結

pytorch凍結部分引數訓練另一部分

pytorch更新部分網路,其他不更新

pytorch固定部分引數(只訓練部分層)

如果載入現有模型的所有引數,我們常使用的是**如下:

torch.load(model.state_dict())
在訓練過程中,我們常常會使用預訓練模型,有時我們是在自己的模型中加入別人的某些模組,或者對別人的模型進行區域性修改,這個時候再使用torch.load(model.state_dict()),就會出現類似這些的錯誤:runtimeerror: error(s) in loading state_dict for net:missing key(s) in state_dict:***。出現這個錯誤就是某些引數缺失或者不匹配。

現有模型中引入的那部分網路結構的網路層的名稱和結構保持不變,這時候載入引數的**很簡單。

# 載入引入的網路模型

model_path = "***"

checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))

pretrained_dict = checkpoint['net']

# 獲取現有模型的引數字典

model_dict = model.state_dict()

# 獲取兩個模型相同網路層的引數字典

state_dict =

# update必不可少,實現相同key的value同步

model_dict.update(state_dict)

# 載入模型部分引數

model.load_state_dict(model_dict)

這個時候再直接使用上面的載入方法,會導致部分key的value無法實現更新。

我就曾在這個位置犯過很嚴重的錯誤。首先我定義了attentionresnet,這是乙個unet來實現影象分割,然後在另乙個模型中我使用了這個模型self.attention_map = attentionresnet(***)。因為我在引用的過程中並沒有對attentionresnet那部分**進行修改,所以本能的覺得這部分網路層的名稱是相同的,所以載入這部分引數時,我直接使用了上面的方法。這個錯誤隱藏了差不多乙個星期。直到我開始凍結這部分引數進行訓練時,發現情況不對。因為我在輸出attention_map的特徵圖時,我發現它是一張全黑圖(畫素全為0),這表示載入的引數不對,然後我嘗試輸出pretrained_dict時,它是乙個空字典。然後繼續輸出pretrained_dict.keys()(未修改之前的pretrained_dict)和model_dict.keys()發現預期相同的那部分key中都多了一部分attention_map.。問題主要出在self.attention_map = attentionresnet(***)這一句,它使原有的網路層名稱都加了個字首attention_map.,知道了錯誤,修改起來很簡單。

# 載入引入的網路模型

model_path = "***"

checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))

pretrained_dict = checkpoint['net']

# 獲取現有模型的引數字典

model_dict = model.state_dict()

# 獲取兩個模型相同網路層的引數字典

state_dict =

# update必不可少,實現相同key的value同步

model_dict.update(state_dict)

# 載入模型部分引數

model.load_state_dict(model_dict)

其實我這個位置的修改有點投機,更加常規的方法是:

引用自pytorch自由載入部分模型引數並凍結

我們看出只要構建乙個字典,使得字典的keys和我們自己建立的網路相同,我們在從各種預訓練網路把想要的引數對著新的keys填進去就可以有乙個新的state_dict了,這樣我們就可以load這個新的state_dict,這是最普適的方法適用於所有的網路變化。

先輸出兩個模型的引數字典,觀察需要載入的那部分引數所處的位置,然後利用for迴圈構建新的字典。

將需要固定的那部分引數的requires_grad置為false.

在優化器中加入filter根據requires_grad進行過濾.

ps: 解決attributeerror: 『nonetype』 object has no attribute 『data』問題的一種思路就是凍結引數,參考部落格

**如下:

# requires_grad置為false

for p in net.***.parameters():

p.requires_grad = false

# filter

optimizer.sgd(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

當需要凍結的那部分引數的網路層名稱不太明確時,可以採用pytorch凍結部分引數訓練另一部分的思路,列印出所有網路層,通過引數名稱進行凍結。

pytorch 模型部分引數的載入

如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet 然後就載...

pytorch 模型部分引數的載入

如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet18 然後...

pytorch 更新部分引數(凍結引數)注意事項

實驗的pytorch版本1.2.0 在訓練過程中可能需要固定一部分模型的引數,只更新另一部分引數。有兩種思路實現這個目標,乙個是設定不要更新引數的網路層為false,另乙個就是在定義優化器時只傳入要更新的引數。當然最優的做法是,優化器中只傳入requires grad true的引數,這樣占用的記憶...