pytorch中的hook機制

2021-10-14 06:05:30 字數 4478 閱讀 3226

由於pytorch中,訓練產生的中間變數會在訓練結束後被釋放掉,因此想要將這些變數儲存下來,需要用到hook函式,hook可以理解為乙個外掛程式函式,掛載在原有函式上.

這個用於儲存反向傳播時候的梯度

flag =

1if flag:

#定義網路

w = torch.tensor([1

.], requires_grad=

true

) x = torch.tensor([2

.], requires_grad=

true

) a = torch.add(w, x)

b = torch.add(w,1)

y = torch.mul(a, b)

#定義乙個空列表,用於儲存hook捕捉的梯度

a_grad =

list()

#定義hook函式,

defgrad_hook

(grad)

:#將hook捕捉的梯度儲存到a_grad中

grad *=

2#return為tensor型別時候,會將tensor資料賦給被掛載的變數;return為none的時候則不操作

return grad*

3#掛載hook函式到tensor變數a上

handle = a.register_hook(grad_hook)

#執行反向傳播,這時候在執行反向傳播的過程中會執行a的hook函式

y.backward(

)# 檢視hook儲存的梯度

print

("w.grad: "

, w.grad)

handle.remove(

)

共有三種:

forward_pre_hook:記錄網路前向傳播前的特徵圖

forward_hook:記錄前向傳播後的特徵圖

backward_hook:記錄反向傳播後的梯度資料

flag =

1if flag:

#定義網路

class

net(nn.module)

:def

__init__

(self)

:super

(net, self)

.__init__(

) self.conv1 = nn.conv2d(1,

2,3)

self.pool1 = nn.maxpool2d(2,

2)defforward

(self, x)

: x = self.conv1(x)

x = self.pool1(x)

return x

defforward_hook

(module, data_input, data_output)

:def

forward_pre_hook

(module, data_input)

:print

("forward_pre_hook input:{}"

.format

(data_input)

)def

backward_hook

(module, grad_input, grad_output)

:print

("backward hook input:{}"

.format

(grad_input)

)print

("backward hook output:{}"

.format

(grad_output)

)# 初始化網路

net = net(

) net.conv1.weight[0]

.detach(

).fill_(1)

net.conv1.weight[1]

.detach(

).fill_(2)

net.conv1.bias.data.detach(

).zero_(

)# 註冊hook

fmap_block =

list()

input_block =

list()

net.conv1.register_forward_hook(forward_hook)

net.conv1.register_forward_pre_hook(forward_pre_hook)

net.conv1.register_backward_hook(backward_hook)

# inference

fake_img = torch.ones((1

,1,4

,4))

# batch size * channel * h * w

output = net(fake_img)

loss_fnc = nn.l1loss(

) target = torch.randn_like(output)

loss = loss_fnc(target, output)

loss.backward(

)

在執行

output = net(fake_img)
的時候,實際上是執行了

#---------------------這一段是判斷是否有forward_pre_hook,並執行-----------------

def_call_impl

(self,

*input

,**kwargs)

:for hook in itertools.chain(

_global_forward_pre_hooks.values(),

self._forward_pre_hooks.values())

: result = hook(self,

input

)if result is

notnone:if

notisinstance

(result,

tuple):

result =

(result,

)input

= result

#---------------------這一段是真正執行forward-----------------

if torch._c._get_tracing_state():

result = self._slow_forward(

*input

,**kwargs)

else

: result = self.forward(

*input

,**kwargs)

#---------------------這一段是判斷是否有forward_hook,並執行-----------------

for hook in itertools.chain(

_global_forward_hooks.values(),

self._forward_hooks.values())

: hook_result = hook(self,

input

, result)

if hook_result is

notnone

: result = hook_result

#---------------------這一段是判斷是否有backward_hook,並執行----------------- if(

len(self._backward_hooks)

>0)

or(len(_global_backward_hooks)

>0)

: var = result

while

notisinstance

(var, torch.tensor):if

isinstance

(var,

dict):

var =

next

((v for v in var.values()if

isinstance

(v, torch.tensor)))

else

: var = var[0]

grad_fn = var.grad_fn

if grad_fn is

notnone

:for hook in itertools.chain(

_global_backward_hooks.values(),

self._backward_hooks.values())

:return result

js中的鉤子機制 hook

什麼是鉤子機制?使用鉤子機制有什麼好處?鉤子機制也叫hook機制,或者你可以把它理解成一種匹配機制,就是我們在 中設定一些鉤子,然後程式執行時自動去匹配這些鉤子 這樣做的好處就是提高了程式的執行效率,減少了if else 的使用同事優化 結構。由於js是單執行緒的程式語言,所以程式的執行效率在前端開...

pytorch中的廣播機制

pytorch中的廣播機制和numpy中的廣播機制一樣,因為都是陣列的廣播機制 兩個維度不同的tensor可以相乘,示例a torch.arange 0,6 reshape 6 tensor 0,1,2,3,4,5 shape torch.size 6 ndim 1 b torch.arange 0...

鉤子 HOOK 機制的使用

wh mouse,gethookinfo,hinstance,getcurrentthreadid mymousehook.callbackfun callbackf mymousehook.isrun not mymousehook.isrun end end procedure uninstal...