額外引數 Pytorch獲取模型引數情況的方法

2021-10-14 18:15:06 字數 1910 閱讀 2088

分享人工智慧技術乾貨,專注深度學習與計算機視覺領域!

相較於tensorflow,pytorch一開始就是以動態圖構建神經網路圖的,其獲取模型引數的方法也比較容易,既可以根據其內建介面自己寫**獲取模型引數情況,也可以借助第三方庫來獲取模型引數情況,下面,就讓我們一起來了解pytorch獲取模型引數情況的這兩種方法!

pytorch依據其內建介面自己寫**獲取模型引數情況,我們主要是借助該框架提供的模型parameters()介面並獲取對應引數的size來實現的,對於該引數是否屬於可訓練引數,那麼我們可以依據pytorch提供的requires_grad標誌位來進行判斷,具體方法如下**所示:

# 定義總參數量、可訓練參數量及非可訓練參數量變數

total_params = 0

trainable_params = 0

nontrainable_params = 0

# 遍歷model.parameters()返回的全域性引數列表

for param in model.parameters():

mulvalue = np.prod(param.size()) # 使用numpy prod介面計算引數陣列所有元素之積

total_params += mulvalue # 總參數量

if param.requires_grad:

trainable_params += mulvalue # 可訓練參數量

else:

nontrainable_params += mulvalue # 非可訓練參數量

print(f'total params: ')

print(f'trainable params: ')

print(f'non-trainable params: ')

如無特殊設定,一般來說,因為我們是直接獲取的model網路引數,因此很少有不可訓練引數,往往nontrainable_params輸出結果是0。

這裡的第三方庫是指torchsummary,欲要使用該庫,首先我們得安裝它,命令如下:

pip install torchsummary
然後,引入該庫的summary方法:

from torchsummary import summary
最後,直接呼叫一條命令即可獲取到pytorch模型引數情況:

summary(model, input_size=(ch, h, w), batch_size=-1)
這裡的ch是指輸入張量的channel數量,h表示輸入張量的高,w表示輸入張量的寬。

我們從以上**可以看到,借助第三方庫torchsummary來獲取pytorch的模型引數情況非常之簡便,只需確認好輸入影象shape即可,那麼,torchsummary的輸出是如何的呢?

上圖是應用torchsummary獲得輸出結果的乙個示例,這與tensorflow v2.x及其之後的版本的模型summary()輸出是差不多的,輸出資訊裡也是有各個類別的參數量情況、每層網路的參數量、額外的層名稱及其輸出shape大小,此外,torchsummary庫還為我們計算了輸入大小、模型引數大小及前向/反向傳播參數量大小,可謂資訊非常細緻,這極大地方便了我們檢視pytorch模型的構造情況。

除了上述兩種獲取pytorch模型引數情況的方法,我們當然也可以直接使用model.state_dict()介面獲取pytorch網路引數,但是此種方法列印出來的資訊結構非常混亂,也沒有為我們進行有效的資訊整理,因此很不建議該方法。

RequestBody怎麼獲取額外的引數

有這麼乙個情況,我的controller 層接收 這麼乙個json串 但是我的req中卻沒有pageno和pagesize這兩個字段,但是,往後面新增引數卻不能接收到 接收不到後面的引數 public object getfunction requestbody reqparam req,reque...

檢視模型各層引數(Pytorch

這個實驗用到的資料集是mnist資料集,維度是1 28 28 import torch.nn as nn class cnn nn.module def init self super cnn,self init 卷積層 self.conv1 nn.sequential in channels 1,...

pytorch 模型部分引數的載入

如果對預訓練模型的結構進行了一些改動,在訓練的開始前希望載入未改動部分的引數,如將resnet18的第一層卷積層conv1的輸入由3通道改為6通道的new conv1,將分類層fc的1000類輸出改為2類輸出的new fc,注意 要改一下名字與原來的不同。匯入模型 mynet resnet 然後就載...