Pytorch scatter 理解軸的含義

2021-08-18 06:20:32 字數 1109 閱讀 1726

scatter_(input, dim, index, src)將src中資料根據index中的索引按照dim的方向填進input中。

>>> x = torch.rand(2, 5)

>>> x

0.4319 0.6500 0.4080 0.8760 0.2355

0.2609 0.4711 0.8486 0.8573 0.1029

[torch.floattensor of size 2x5]

longtensor的shape剛好與x的shape對應,也就是longtensor每個index指定x中乙個資料的填充位置。dim=0,表示按行填充,主要理解按行填充。舉例longtensor中的第0行第2列index=2,表示在第2行(從0開始)進行填充填充,對應到zeros(3, 5)中就是位置(2,2)。所以此處要求zeros(3, 5)的列數要與x列數相同,而longtensor中的index最大值應與zeros(3, 5)行數相一致。

>>> torch.zeros(3, 5).scatter_(0, torch.longtensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

0.4319 0.4711 0.8486 0.8760 0.2355

0.0000 0.6500 0.0000 0.8573 0.0000

0.2609 0.0000 0.4080 0.0000 0.1029

[torch.floattensor of size 3x5]

同上理,可以把1.23看成[[1.23], [1.23]]。此處按列填充,longtensor中的index=2對應zeros(2, 4)的(0,2)位置。

>>> z = torch.zeros(2, 4).scatter_(1, torch.longtensor([[2], [3]]), 1.23)

>>> z

0.0000 0.0000 1.2300 0.0000

0.0000 0.0000 0.0000 1.2300

[torch.floattensor of size 2x4]

pytorch scatter直觀理解

簡單記錄一下pytorch scatter 的理解,官方解釋在 官方的例子如下,下面說說使用層面的直觀理解。src torch.arange 1,11 reshape 2,5 src tensor 1,2,3,4,5 6,7,8,9,10 index torch.tensor 0,1,2,0 tor...

25 理一理關於tensorflow的各種騷操作

1.tf.squeeze 2.tf.cast 3.tf.expand dims 4.tf.slice 按照指定的下標範圍抽取連續區域的子集 講的不錯 5.tf.gather 按照指定的下標集合從axis 0中抽取子集,適合抽取不連續區域的子集 6.tf.one hot 7.tf.transpose ...

理專案思緒,

下一步本來想專攻,想選擇如下兩個之一 1 日立cobol,2 英語日語都用得上的軟體開發相關工作,專攻的話,自己會的東西既有市場又有一定的門檻,這樣自己就會穩步公升值,再成功開發三五個專案後,正常情況下就單幹這兩個專攻之一,可以在沒有找到自己想專攻的兩個中的任何乙個,出現乙個機遇,就是得到從美國朋友...