對Pytorch中backward()函式的理解

2021-09-16 12:24:59 字數 1980 閱讀 6887

寫在第一句:

這個部落格解釋的也很好,參考了很多:pytorch中的自動求導函式backward()所需引數含義

所以切入正題:

backward()函式中的引數應該怎麼理解?

官方:如果需要計算導數,可以在tensor上呼叫.backward()。		

1. 如果tensor是乙個標量(即它包含乙個元素的資料),則不需要為backward()指定任何引數

2. 但是如果它有更多的元素,則需要指定乙個gradient引數,它是形狀匹配的張量。

先看看官方怎麼解釋的
backward(gradient=none, retain_variables=false)[source]
當前variable(理解成函式y)對leaf variable(理解成變數x=[x1,x2,x3])求偏導。

計算圖可以通過鏈式法則求導。如果variable是 非標量(non-scalar)的(即是說y中有不止乙個y,即y=[y1,y2,…]),且requires_grad=true。那麼此函式需要指定gradient,它的形狀應該和variable的長度匹配(這個就很好理解了,gradient的長度體與y的長度一直才能儲存每乙個yi的梯度值啊),裡面儲存了variable的梯度。

此函式累積leaf variable的梯度。你可能需要在呼叫此函式之前將variable的梯度置零。(梯度不置零的話為出現累加

引數:gradient (tensor) – 其他函式對於此variable的導數。僅當variable不是標量的時候使用,型別和位形狀應該和self.data一致。

(補充:這裡說的時其他函式對variable的導數!)retain_variables (bool) – true, 計算梯度所必要的buffer在經歷過一次backward過程後不會被釋放。如果你想多次計算某個子圖的梯度的時候,設定為true。在某些情況下,使用autograd.backward()效率更高。

情況1:out是乙個標量(就是說乙個輸出值)

'''情況1:out是乙個標量(就是說乙個輸出值)'''

b=a+3

c=b*3

out=c.mean()

out.backward()

print('input:',a.data)

print('input gradients are:',a.grad)

print(a.numel()) #返回元素個數,所以c關於a的導數應該是[(a+3)*3]/6 就是0.5

輸出結果:

input: tensor([[1., 1., 1.],

[1., 1., 1.]])

input gradients are: tensor([[0.5000, 0.5000, 0.5000],

[0.5000, 0.5000, 0.5000]])

6

這就是最簡單的情況

情況2:out是乙個向量(就是說輸出一列值)

a=t.ones(2,1,requires_grad=true)

b=t.zeros(2,1)

b[0,0]=a[0,0]**2+a[1,0]*5

b[1,0]=a[0,0]**3+a[1,0]*4

接下來開始嘗百草:

'''try 1'''

b.backward() #什麼引數也不說

print(a.grad)

不出意外,得到錯誤:

runtimeerror: grad can be implicitly created only for scalar outputs

'''try 2'''
未完待續…

對Pytorch 中的contiguous理解說明

最近遇到這個函式,但查的中文部落格裡的解釋貌似不是很到位,這裡翻譯一下stackoverflow上的回答並加上自己的理解。在pytorch中,只有很少幾個操作是不改變tensor的內容本身,而只是重新定義下標與元素的對應關係的。換句話說,這種操作不進行資料拷貝和資料的改變,變的是元資料。這些操作是 ...

對PyTorch中inplace欄位的全面理解

torch.nn.relu inplace true inplace true 表示進行原地操作,對上一層傳遞下來的tensor直接進行修改,如x x 3 inplace false 表示新建乙個變數儲存操作結果,如y x 3,x y inplace true 可以節省運算記憶體,不用多儲存變數。補...

我對pytorch中gather函式的一點理解

torch.gather input dim,index,out none tensor torch.gather input dim,index,out none tensor gathers values along an axis specified by dim.for a 3 d tens...