pytorch載入自定義網路權重的實現

2022-09-28 21:24:17 字數 1971 閱讀 4856

在將自定義的網路權重載入到網路中時,報錯:

attributeerror: 'dict程式設計客棧' object has no attribute 'seek'. you can only torch.load from a file that is seekable. please pre-load the data into a buffer like io.bytesio and try to load from it instead.

我們一步一步分析。

模型網路權重儲存額**是:torch.s**e(net.state_dict程式設計客棧(),'net.pkl')

(1)檢視獲取模型權重的原始碼:

pytorch原始碼:net.state_dict()

def state_dict(self, destination=none, prefix='', keep_vars=false):

r"""returns a dictionary containing a whole state of the module.

both parameters and persistent buffers (e.g. running **erages) are

included. keys are corresponding parameter and buffer names.

returns:

dict:

a dictionary containing a whole state of the module

example::

>>> module.state_dict().keys()

['bias', 'weight']

"""將網路中所有的狀態儲存到乙個字典中了,我自己構建的就是乙個字典,沒問題!

(2)檢視儲存模型權重的原始碼:

pytorch原始碼:torch.s**e()

def s**e(obj, f, pickle_module=pickle, pickle_protocol=default_protocol):

"""s**es an object to a disk file.

see also: :ref:`recommend-s**ing-models`

args:

obj: s**ed object

f: a file-like object (has to implement write and flush) or a string

containing a file name

pickle_module: module used for pickling and objects

pickle_protocol: can be specified to override the default protocol

.. warning::

if you are using python 2, torch.s**e does not程式設計客棧 support stringio.stringio

as a valid file-like object. this is because the write method should return

the number of bytes written; stringio does not do this.

please use something like io.bytesio instead.

函式功能是將字典儲存為磁碟檔案(二進位制資料),那麼我們在torch.load()時,就是在記憶體中載入二進位制資料,這就是報錯點。

解決方案:將字典儲存為bytesio檔案之後,模型再net.load_state_dict()

#b為自定義的字典

torch.s**e(b,'new.pkl')

net.load_state_dict(torch.load(b))

解決方法很簡單,主要記錄解決思路。

本文標題: pytorch載入自定義網路權重的實現

本文位址:

PyTorch 入門 自定義資料載入

之前學習tensorflow時也學習了它的資料載入,不過在網上看了很多教程後還是有很多小問題,不知道為什麼在別人電腦上可以執行但是我的就不行 把我頭搞暈了 很煩,這時想起之前聽導師說pytorch容易入門上手,所以果斷去學了pytorch,寫這篇博文的目的就是總結學到的,然後記錄下來,也希望以後學到...

pytorch 自定義資料集載入方法

更多python教程請到 菜鳥教程 pytorch 官網給出的例子中都是使用了已經定義好的特殊資料集介面來載入資料,而且其使用的資料都是官方給出的資料。如果我們有自己收集的資料集,如何用來訓練網路呢?此時需要我們自己定義好資料處理介面。幸運的是pytroch給出了乙個資料集介面類 torch.uti...

Pytorch自定義引數

如果想要靈活地使用模型,可能需要自定義引數,比如 class net nn.module def init self super net,self init self.a torch.randn 2 3 requires grad true self.b nn.linear 2,2 defforwa...