pytorch tensor 獲得中間節點的梯度

2021-10-01 18:39:58 字數 2027 閱讀 7708

我們只能指定計算圖的leaf節點的requires_grad變數來決定改變量是否記錄梯度,而不能指定它們運算產生的節點的requires_grad,它們是否要梯度取決於它們的輸入節點,它們的輸入節點只要有乙個requires_grad是true,那麼它的requires_grad也是true.

x = torch.randn(2,

100)

x.requires_grad =

false

w = torch.randn(10,

100)

w2 = torch.randn(3,

10)w.requires_grad =

true

2w.requires_grad =

true

y = x @ w.t(

)z = y @ w2.t(

)print

(y.requires_grad, z.requires_grad)

z.sum()

.backward(

)

對於葉節點,如果我們指定了梯度,我們可以呼叫v.grad檢視梯度;但是對於中間變數v.grad永遠是none,如果要獲得其梯度,就要使用register_hook,它會在呼叫這個變數的梯度反傳的時候呼叫註冊的函式.以下是乙個簡單的檢視版本

import torch

from torch import nn

defhook

(grad)

:print

(grad)

x = torch.randn(2,

100)

x.requires_grad =

false

w = torch.randn(10,

100)

w2 = torch.randn(3,

10)w.requires_grad =

true

w2.requires_grad =

true

y = x @ w.t(

)z = y @ w2.t(

)y.register_hook(hook)

z.sum()

.backward(

)# invoke get_grad('y') here

改進版

import torch

class

gradcollector

(object):

def__init__

(self)

: self.grads =

def__call__

(self, name:

str)

:def

hook

(grad)

: self.grads[name]

= grad

return hook

x = torch.randn(2,

100)

x.requires_grad =

false

w = torch.randn(10,

100)

w2 = torch.randn(3,

10)w.requires_grad =

true

w2.requires_grad =

true

y = x @ w.t(

)z = y @ w2.t(

)grad_collector = gradcollector(

)y.register_hook(grad_collector(

"y")

)z.register_hook(grad_collector(

'z'))z.

sum(

).backward(

)print

(grad_collector.grads[

'y']

)print

(grad_collector.grads[

'z']

)

Pytorch Tensor和tensor的區別

在pytorch中,tensor和tensor都能用於生成新的張量 a torch.tensor 1 2 a tensor 1 2.a torch.tensor 1 2 a tensor 1 2 首先,我們需要明確一下,torch.tensor 是python類,更明確地說,是預設張量型別torch...

pytorch tensor 篩選排除

篩選排除還沒找到答案 取數運算 正好遇到乙個需求。我有m行k列的乙個表a,和乙個長為m的索引列表b。b中儲存著,取每行第幾列的元素。這種情況下,你用普通的索引是會失效的。import torch a torch.longtensor 1,2,3 4,5,6 b torch.longtensor 0,...

Pytorch tensor的感知機

單層感知機的主要步驟 單層感知機梯度的推導 要進行優化的是w,對w進行梯度下降 a torch.randn 1,10 a是乙個 1,10 的向量 w torch.randn 1,10,requires grad true w是乙個可導的 1,10 的向量 1.2.經過乙個sigmoid啟用函式 o ...