pytorch的gather 方法詳解

2021-10-13 08:47:44 字數 1088 閱讀 5744

首先,先將結果展示出來,後續根據結果來進行分析

t = torch.tensor([[1,2,3],[4,5,6]])

index_a = torch.longtensor([[0,0],[0,1]])

index_b = torch.longtensor([[0,1,1],[1,0,0]])

print(t)

print(torch.gather(t,dim=1,index=index_a))

print(torch.gather(t,dim=0,index=index_b))

執行結果如下:

>>tensor([[1., 2., 3.],

[4., 5., 6.]])

>>tensor([[1., 1.],

[4., 5.]])

>>tensor([[1., 5., 6.],

[4., 2., 3.]])

結果分析

首先,三個引數介紹

輸入的變數input、指定在某一維上聚合的dim、聚合使用的索引index,輸出為tensor型別的結果(index必須為longtensor型別

1.當dim = 1 時,是在第二個維度進行融合,index是所需值在原矩陣中的索引。即按照列來進行融合,第幾行列號為[x,x]的值。[0, 0]表示第一行第1列的值,兩個都為1.[0, 1]表示對應原陣列第二行第一列和第二列的值,分別為4,5.

因此最終結果如下:

tensor([[1., 1.],

[4., 5.]])

當dim = 0時,是在第乙個維度進行融合,按行來進行融合。第幾列行號為[x,x,x]的值。[0,1,1]表示對應原陣列第一列第一行(0)的值1、第二列第二行(1)的值5,第三列第二行(1)的值6.

最終取值如下

tensor([[1., 5., 6.],

[4., 2., 3.]])

pytorch中的gather函式

from 今天剛開始接觸,讀了一下documentation,寫乙個一開始每太搞懂的函式gather b torch.tensor 1,2,3 4,5,6 print bindex 1 torch.longtensor 0,1 2,0 index 2 torch.longtensor 0,1,1 0...

我對pytorch中gather函式的一點理解

torch.gather input dim,index,out none tensor torch.gather input dim,index,out none tensor gathers values along an axis specified by dim.for a 3 d tens...

pytorch的gather函式的一些粗略的理解

先給出官方文件的解釋,我覺得官方的文件寫的已經很清楚了,四個引數分別是input,dim,index,out,輸出的tensor是以index為大小的tensor。其中,這就是最關鍵的定義 out i j k tensor index i j k j k dim 0 out i j k tensor...