我對pytorch中gather函式的一點理解

2021-10-03 10:07:56 字數 2159 閱讀 7525

torch.gather(

input

,dim,index,out=

none

) → tensor

torch.gather(

input

, dim, index, out=

none

) → tensor

gathers values along an axis specified by dim.

for a 3

-d tensor the output is specified by:

out[i]

[j][k]

=input

[index[i]

[j][k]

][j]

[k]# dim=0

out[i]

[j][k]

=input

[i][index[i]

[j][k]

][k]

# dim=1

out[i]

[j][k]

=input

[i][j]

[index[i]

[j][k]

]# dim=2

parameters:

input

(tensor) – the source tensor

dim (

int) – the axis along which to index

index (longtensor) – the indices of elements to gather

out (tensor, optional) – destination tensor

example:

>>

> t = torch.tensor([[

1,2]

,[3,

4]])

>>

> torch.gather(t,

1, torch.longtensor([[

0,0]

,[1,

0]])

)114

3[torch.floattensor of size 2x2]

import torch

a = torch.tensor([[

1,2]

,[3,

4]])

b = torch.gather(a,

1,torch.longtensor([[

0,0]

,[1,

0]])

)#1. 取各個元素行號:[(0,y)(0,y)][(1,y)(1,y)]

#2. 取各個元素值做行號:[(0,0)(0,0)][(1,1)(1,0)]

#3. 根據得到的索引在輸入中取值

#[1,1],[4,3]

c = torch.gather(a,

0,torch.longtensor([[

0,0]

,[1,

0]])

)#1. 取各個元素列號:[(x,0)(x,1)][(x,0)(x,1)]

#2. 取各個元素值做行號:[(0,0)(0,1)][(1,0)(0,1)]

#3. 根據得到的索引在輸入中取值

#[1,2],[3,2]

假設輸入與上同;index=b;輸出為c

b中每個元素分別為b(0,0)=0,b(0,1)=0

b(1,0)=1,b(1,1)=0

如果dim=0(列)

則取b中元素的列號,如:b(0,1)的1

b(0,1)=0,所以c中的c(0,1)=輸入的(0,1)處元素2

如果dim=1(行)

則取b中元素的列號,如:b(0,1)的0

b(0,1)=0,所以c中的c(0,1)=輸入的(0,0)處元素1

總結如下:

輸出 元素 在 輸入張量 中的位置為:

輸出元素位置取決與同位置的index元素

dim=1時,取同位置的index元素的行號做行號,該位置處index元素做列號

dim=0時,取同位置的index元素的列號做列號,該位置處index元素做行號。

最後根據得到的索引在輸入中取值

index型別必須為longtensor

gather最終的輸出變數與index同形。

對Pytorch 中的contiguous理解說明

最近遇到這個函式,但查的中文部落格裡的解釋貌似不是很到位,這裡翻譯一下stackoverflow上的回答並加上自己的理解。在pytorch中,只有很少幾個操作是不改變tensor的內容本身,而只是重新定義下標與元素的對應關係的。換句話說,這種操作不進行資料拷貝和資料的改變,變的是元資料。這些操作是 ...

對Pytorch中backward()函式的理解

寫在第一句 這個部落格解釋的也很好,參考了很多 pytorch中的自動求導函式backward 所需引數含義 所以切入正題 backward 函式中的引數應該怎麼理解?官方 如果需要計算導數,可以在tensor上呼叫.backward 1.如果tensor是乙個標量 即它包含乙個元素的資料 則不需要...

對PyTorch中inplace欄位的全面理解

torch.nn.relu inplace true inplace true 表示進行原地操作,對上一層傳遞下來的tensor直接進行修改,如x x 3 inplace false 表示新建乙個變數儲存操作結果,如y x 3,x y inplace true 可以節省運算記憶體,不用多儲存變數。補...