Pytorch對多維Tensor按維度操作

2021-10-10 18:57:13 字數 1039 閱讀 7540

記憶要點:

dim = 0 說明是處理行

dim = 1 說明是處理列

keepdim = true  保留處理的行/列的特徵

keepdim = false 不保留處理的行/列的特徵

網上流傳的版本有很多,但是我們根據結果來說話。我的理解是哪個維度發生了變化就是處理的是哪個維度。

if __name__ == "__main__":

#模型引數初始化

num_input = 784

num_output = 10

w = torch.tensor(np.random.normal(0,0.1,(num_input,num_output)),dtype=torch.float32)

b = torch.tensor(num_output,dtype=torch.float32)

w.requires_grad_(requires_grad = true)

b.requires_grad_(requires_grad = true)

#多維tensor按維度操作

x = torch.tensor([[1,2,3],[4,5,6]])

print(x.sum(dim = 0,keepdim = true)) # dim為0,將同一列中所有行相加,並在結果中保留行特徵 1,3

print(x.sum(dim = 1,keepdim = true)) # dim為1,同一行中所有列相加,並在結果中保留列特徵 2,1

print(x.sum(dim = 0,keepdim = false))# dim為0,將同一列中所有行相加,並在結果中不保留行特徵 3

print(x.sum(dim = 1,keepdim = false)) # dim為1,同一行中所有列相加,並在結果中不保留行特徵 2

#tensor([[5, 7, 9]])

tensor([[ 6],

[15]])

tensor([5, 7, 9])

tensor([ 6, 15])

#

pytorch多維篩選

多級篩選 比如結構是2 2 3,只想選第三維的最大的 tx index,best n,g y center,g x center index 01 best n 0,1 最後只取兩個值,第一行,第1列,第二行,第2列的。篩選第3維最大的值,下面的 不對,解決方法 查詢max原始碼 也可以把3維用vi...

pytorch 拓展cuda語言 多維索引問題

四維矩陣索引公式 四維矩陣 n,c,h,w 當前四維索引為 n,c,h,w,out idx n c h w c h w h w w pytorch中的 permute a torch.randn 5,3,4 a的size為 5,3,4 b a.permute 0,2,1 此時b的size會變成 5,...

pytorch實現BP,處理多維資料輸入

import torch import matplotlib.pyplot as plt import numpy as np xy np.loadtxt 000.txt delimiter dtype np.float32 x data torch.from numpy xy 1 取前九列 y d...