pytorch的gather函式的一些粗略的理解

2021-10-05 13:22:29 字數 2562 閱讀 8736

先給出官方文件的解釋,我覺得官方的文件寫的已經很清楚了,四個引數分別是input,dim,index,out,輸出的tensor是以index為大小的tensor。

其中,這就是最關鍵的定義

out[i]

[j][k]

= tensor[index[i]

[j][k]

][j]

[k]# dim=0

out[i]

[j][k]

= tensor[i]

[index[i]

[j][k]

][k]

# dim=1

out[i]

[j][k]

= tensor[i]

[j][index[i]

[j][k]

]# dim=3

主要解釋一下dim,dim=0的時候,把index的元素放入進行索引,有一點需要注意的是,引數index的tensor格式是除了第1維也就是行那一維之外,其他維的格式需與input保持一致!下面給個例子

import torch 

a = torch.arange(0,

16).view(4,

4)index = torch.longtensor([[

0,1,

2,3]

])b = a.gather(

0, index)

print

(a)print

(index)

print

(b)#形象的理解就是在每一列的第index上進行索引

for j in

range(4

):print

(a[index[0]

[j]]

[j].item())

----

----

----

----

----

----

----

----

----

----

----

----

----

----

----

----

----

tensor([[

0,1,

2,3]

,[4,

5,6,

7],[

8,9,

10,11]

,[12,

13,14,

15]])

tensor([[

0,1,

2,3]

])tensor([[

0,5,

10,15]

])05

1015

dim = 1的時候,把index的元素放入進行索引,有一點需要注意的是,引數index的tensor格式是除了第2維也就是列那一維之外,其他維的格式需與input保持一致!下面給個例子

import torch 

a = torch.arange(0,

16).view(4,

4)index = torch.longtensor([[

0],[

1],[

2],[

3]])

b = a.gather(

1, index)

print

(a)print

(index)

print

(b)#形象的理解就是在每一行的第index列上進行索引

for j in

range(4

):print

(a[j]

[index[j][0

]].item())

----

----

----

----

----

----

----

----

----

----

----

----

----

----

----

----

----

tensor([[

0,1,

2,3]

,[4,

5,6,

7],[

8,9,

10,11]

,[12,

13,14,

15]])

tensor([[

0],[

1],[

2],[

3]])

tensor([[

0],[

5],[

10],[

15]])

051015

本人對矩陣的一些概念還有一些模糊不清,以上就是我的一些理解,希望有大佬可以一起交流一下,pytorch 的張量一開始很難處理清楚,還需慢慢來。

pytorch中的gather函式

from 今天剛開始接觸,讀了一下documentation,寫乙個一開始每太搞懂的函式gather b torch.tensor 1,2,3 4,5,6 print bindex 1 torch.longtensor 0,1 2,0 index 2 torch.longtensor 0,1,1 0...

pytorch的gather 方法詳解

首先,先將結果展示出來,後續根據結果來進行分析 t torch.tensor 1,2,3 4,5,6 index a torch.longtensor 0,0 0,1 index b torch.longtensor 0,1,1 1,0,0 print t print torch.gather t,...

我對pytorch中gather函式的一點理解

torch.gather input dim,index,out none tensor torch.gather input dim,index,out none tensor gathers values along an axis specified by dim.for a 3 d tens...