pytorch基礎知識整理(四) 模型

2021-10-13 21:16:45 字數 1510 閱讀 4000

torch.nn.module()是所有網路模型的基類,所有網路都需要繼承此類,模板如下:

import torch.nn as nn

import torch.nn.functional as f

class model

(nn.module)

: def __init__

(self)

:super()

.__init__

() #'表示繼承父類的__init__()方法'

self.conv1 = nn.

conv2d(1

,20,5

) self.conv2 = nn.

conv2d(20

,20,5

) def forward

(self, x)

: x = f.

relu

(self.

conv1

(x))

return f.

relu

(self.

conv2

(x))

model.modules()model.children()返回模型所有模組/子模組的迭代器。

model.named_modules(), model.named_children()返回模型所有模組/子模組的名字和模組本身的迭代器。

model.parameters()返回模型所有引數的迭代器。常用來作為optimizer的迭代器。

model.register_parameter(name, param)向模型新增parameter。

model.register_buffer(name, tensor)向模型新增buffer。

model.state_dict(), model.load_state_dict()返回/載入 狀態字典。

model.train(), model.eval()訓練/推理模式,僅影響模型中的dropout和bn層。

model.cpu(),model.cuda()把模型中的所有parameters和buffers賦值到cpu/gpu中。

model.float(), model.half(), model.double()轉換模型的所有parameters和buffers的型別。

model.zero_grad()把模型所有parameters的梯度置0,和optimizer.zero_grad()完全等效。

注:可以用model.layer_name的方法得到子模型,如model.conv1.parameters()就得到了conv1的引數

c 基礎知識整理(四)

一 explicit pragma once include class explicit test explicit test operator const explicit test other 不帶explicit 測試結果 explicit test aa 5 直接隱式轉換,可以傳乙個引數是...

Hibernate基礎知識整理(四)

事務的邊界 開啟事務 transcation tx session.begintransaction 提交事務 tx.commit 回滾事務 tx.rollback hibernate的事務是通過呼叫jdbc來直接實現的,預設hibernate事務是不開啟的。通常事務的邊界控制是放在service層...

pytorch基礎知識整理 一)自動求導機制

torch.autograd是pytorch最重要的元件,主要包括variable類和function類,variable用來封裝tensor,是計算圖上的節點,function則定義了運算操作,是計算圖上的邊。1.tensor tensor張量和numpy陣列的區別是它不僅可以在cpu上執行,還可...