pytorch中的gather函式

2021-08-20 21:39:26 字數 1812 閱讀 6030

from:

今天剛開始接觸,讀了一下documentation,寫乙個一開始每太搞懂的函式gather

b = torch.tensor([[1,2,3],[4,5,6]]

)print

bindex_1 = torch.longtensor([[0,1],[2,0]]

)index_2 = torch.longtensor([[0,1,1],[0,0,0]]

)print

torch.gather(b, dim=1

, index=index_1)

print

torch.gather(b, dim=0

, index=index_2)

觀察它的輸出結果:

1

2 3

4 5

6[torch.floattensor of size 2

x3] 1

2 64[torch.floattensor of size 2

x2] 1

5 6

1 2

3[torch.floattensor of size 2

x3]

這裡是官方文件的解釋

torch.gather(input, dim, index, out=none) → tensor

gathers values along an axis specified by dim.

for a 3

-d tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k] # dim=0

out[i][j][k] = input[i][index[i][j][k]][k] # dim=1

out[i][j][k] = input[i][j][index[i][j][k]] # dim=2

parameters:

input (tensor) – the source tensor

dim (int) – the axis along which to index

index (longtensor) – the indices of elements to gather

out (tensor, optional) – destination tensor

example:

>>> t = torch.tensor([[1,2],[3,4]]

) >>> torch.gather(t, 1

, torch.longtensor([[0,0],[1,0]]

)) 1

1 4

3 [torch.floattensor of size 2

x2]

可以看出,gather的作用是這樣的,index實際上是索引,具體是行還是列的索引要看前面dim 的指定,比如對於我們的栗子,【1,2,3;4,5,6,】,指定dim=1,也就是橫向,那麼索引就是列號。index的大小就是輸出的大小,所以比如index是【1,0;0,0】,那麼看index第一行,1列指的是2, 0列指的是1,同理,第二行為4,4 。這樣就輸入為【2,1;4,4】,參考這樣的解釋看上面的輸出結果,即可理解gather的含義。

gather在one-hot為輸出的多分類問題中,可以把最大值座標作為index傳進去,然後提取到每一行的正確**結果,這也是gather可能的乙個作用。

2023年05月30日20:05:01

春去夏來,溫情演為慾望。 —— 作家, 安德烈莫羅阿

pytorch的gather 方法詳解

首先,先將結果展示出來,後續根據結果來進行分析 t torch.tensor 1,2,3 4,5,6 index a torch.longtensor 0,0 0,1 index b torch.longtensor 0,1,1 1,0,0 print t print torch.gather t,...

我對pytorch中gather函式的一點理解

torch.gather input dim,index,out none tensor torch.gather input dim,index,out none tensor gathers values along an axis specified by dim.for a 3 d tens...

pytorch的gather函式的一些粗略的理解

先給出官方文件的解釋,我覺得官方的文件寫的已經很清楚了,四個引數分別是input,dim,index,out,輸出的tensor是以index為大小的tensor。其中,這就是最關鍵的定義 out i j k tensor index i j k j k dim 0 out i j k tensor...