pytorch筆記 scatter 的使用

2021-10-13 07:10:41 字數 4072 閱讀 7182

scatter_(input, dim, index, src)將src中資料根據index中的索引按照dim的方向填進input中.

1 >>> x = torch.rand(2, 5)

2 >>> x

3 4 0.4319 0.6500 0.4080 0.8760 0.2355

5 0.2609 0.4711 0.8486 0.8573 0.1029

6 [torch.floattensor of size 2x5]

1) dim = 0,分別對每列填充:

>>> torch.zeros(3, 5).scatter_(0, torch.longtensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

0.4319 0.4711 0.8486 0.8760 0.2355

0.0000 0.6500 0.0000 0.8573 0.0000

0.2609 0.0000 0.4080 0.0000 0.1029

[torch.floattensor of size 3x5]

實現原理:

對於lonetensor內的矩陣,暫且稱為 tmp = [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]];將最終的 3*5的矩陣,暫且稱為result。result初始為全0,需要經過scatter_處理。

舉例:對於tmp[0][0] = 0  -> 取x中x[0][0] = 0.4319,將其插入到result第0列的第0個位置,result[0][0] = 0.4319;

對於tmp[0][1] = 1  -> 取x中x[0][1] = 0.6500,將其插入到result第1列的第1個位置,result[1][1] = 0.6500;

對於tmp[0][2] = 2  -> 取x中x[0][1] = 0.4080,將其插入到result第2列的第2個位置,result[2][2] = 0.4080;

對於tmp[1][0] = 2  -> 取x中x[1][0] = 0.2609,將其插入到result第0列的第2個位置,result[2][0] = 0.2609;

對於tmp[1][1] = 0  -> 取x中x[1][1] = 0.4711,將其插入到result第1列的第0個位置,result[0][1] = 0.4711。

2) dim = 1,分別對每行填充

1 >>> z = torch.zeros(2, 4).scatter_(1, torch.longtensor([[2], [3]]), 1.23)

2 >>> z

3 4 0.0000 0.0000 1.2300 0.0000

5 0.0000 0.0000 0.0000 1.2300

6 [torch.floattensor of size 2x4]

tmp = [[2], [3]]

tmp[0][0] = 2 -> 取x中x[0][0] = 0.4319,將其插入到result第0行的第2個位置,result[0][2] = 0.4319;

pytorch 中,一般函式加下劃線代表直接在原來的 tensor 上修改

scatter(dim, index, src) 的引數有 3 個

這個 scatter  可以理解成放置元素或者修改元素

簡單說就是通過乙個張量 src  來修改另乙個張量,哪個元素需要修改、用 src 中的哪個元素來修改由 dim 和 index 決定

官方文件給出了 3維張量 的具體操作說明,如下所示

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0

self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1

self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

exmaple:

x = torch.rand(2, 5)

#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],

# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],

# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],

# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

具體地說,我們的 index 是 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),乙個二維張量,下面用圖簡單說明

我們是 2維 張量,一開始進行 self[index[0][0]][0]self[index[0][0]][0],其中 index[0][0]index[0][0] 的值是0,所以執行 self[0][0]=x[0][0]=0.1940self[0][0]=x[0][0]=0.1940 

src 除了可以是張量外,也可以是乙個標量

example:

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)

#tensor([[7., 7., 7., 7., 7.],

# [0., 7., 0., 7., 0.],

# [7., 0., 7., 0., 7.]]

scatter()一般可以用來對標籤進行 one-hot 編碼,這就是乙個典型的用標量來修改張量的乙個例子

example:

class_num = 10

batch_size = 4

label = torch.longtensor(batch_size, 1).random_() % class_num

#tensor([[6],

# [0],

# [3],

# [2]])

torch.zeros(batch_size, class_num).scatter_(1, label, 1)

#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],

# [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],

# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],

# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

Pytorch 學習筆記

本渣的pytorch 逐步學習鞏固經歷,希望各位大佬指正,寫這個部落格也為了鞏固下記憶。a a b c 表示從 a 取到 b 步長為 c c 2 則表示每2個數取1個 若為a 1 1 1 即表示從倒數最後乙個到正數第 1 個 最後乙個 1 表示倒著取 如下 陣列的第乙個為 0 啊,第 0 個!彆扭 ...

Pytorch學習筆記

資料集 penn fudan資料集 在學習pytorch官網教程時,作者對penn fudan資料集進行了定義,並且在自定義的資料集上實現了對r cnn模型的微調。此篇筆記簡單總結一下pytorch如何實現定義自己的資料集 資料集必須繼承torch.utils.data.dataset類,並且實現 ...

PyTorch入門筆記

原教程 資料集csv 此處使用numpy來匯入,除此之外還可以使用csv和pandas匯入 資料集鏈結 import csv import numpy as np wine path data chapter3 winequality white.csv 路徑 wineq numpy np.load...