pytorch 修改預訓練模型

2021-09-25 11:16:04 字數 729 閱讀 2079

torchvision中提供了很多訓練好的模型,這些模型是在1000類,224*224的imagenet中訓練得到的,很多時候不適合我們自己的資料,可以根據需要進行修改。

1、類別不同

#

coding=utf-8

import

torchvision.models as models

#呼叫模型

model = models.resnet50(pretrained=true)

#提取fc層中固定的引數

fc_features =model.fc.in_features

#修改類別為9

model.fc = nn.linear(fc_features, 9)

2、新增層後,載入部分引數

model =...

model_dict =model.state_dict()

#1. filter out unnecessary keys

pretrained_dict =

#2. overwrite entries in the existing state dict

model_dict.update(pretrained_dict)

#3. load the new state dict

model.load_state_dict(model_dict)

參考:

pytorch 載入預訓練模型

pytorch的torchvision中給出了很多經典的預訓練模型,模型的引數和權重都是在imagenet資料集上訓練好的 載入模型 方法一 直接使用預訓練模型中的引數 import torchvision.models as models model models.resnet18 pretrai...

Pytorch 修改預訓練網路結構

我們以 inceptionv3 為例 pytorch裡我們如何使用設計好的網路結構,比如inceptionv3 import torchvision.models as models inception models.inception v3 pretrained true pytorch提供了個叫...

pytorch載入預訓練模型後,訓練指定層

1 有了已經訓練好的模型引數,對這個模型的某些層做了改變,如何利用這些訓練好的模型引數繼續訓練 pretrained params torch.load pretrained model model the new model model.load state dict pretrained par...