pytorch中引數dim的含義 正負,零,不傳

2022-06-02 03:18:13 字數 2878 閱讀 9115

總結:

torch.function(x, dim)

1.if 不傳: 依照預設引數決定

2.if dim >=0 and dim <= x.dim()-1: 0是沿最粗資料粒度的方向進行操作,x.dim()-1是按最細粒度的方向。

3.if dim <0: dim的最小取值(此按照不同function而定)到最大取值(-1)之間。與情況2正好相反,最大的取值(-1)代表按最細粒度的方向,最小的取值按最粗粒度的方向。

實驗**:(使用torch.max(x, dim)為例子)

1.dim=2

mout[77]:

tensor([[1, 2, 3],

[4, 5, 6]])

torch.max(m,)

out[85]: tensor(6)

不傳:預設引數的設定是對整個傳入的資料進行操作

torch.max(m, dim=0)

out[79]:

torch.return_types.max(

values=tensor([4, 5, 6]),

indices=tensor([1, 1, 1]))

此處最粗粒度是兩行之間[1, 2, 3]->[4, 5, 6]的方向,也就是常說是縱向進行操作。

torch.max(m, dim=1)

out[78]:

torch.return_types.max(

values=tensor([3, 6]),

indices=tensor([2, 2]))

此處最細粒度是一行之內[1, 2, 3]的方向,也就是常說是橫向進行操作。

torch.max(m, dim=2)

traceback (most recent call last):

file "/home/xutianfan/anaconda3/lib/python3.8/site-packages/ipython/core/interactiveshell.py", line 3418, in run_code

exec(code_obj, self.user_global_ns, self.user_ns)

file "", line 1, in

torch.max(m, dim=2)

indexerror: dimension out of range (expected to be in range of [-2, 1], but got 2)

torch.max(m, dim=-1)

out[86]:

torch.return_types.max(

values=tensor([3, 6]),

indices=tensor([2, 2]))

-1+2=1,同torch.max(m, dim=1)結果。

torch.max(m, dim=-2)

out[87]:

torch.return_types.max(

values=tensor([4, 5, 6]),

indices=tensor([1, 1, 1]))

2.dim=3(tensor)

t1out[89]:

tensor([[[0, 1, 2, 3],

[1, 2, 3, 4]],

[[2, 3, 4, 5],

[4, 5, 6, 7]],

[[5, 6, 7, 8],

[6, 7, 8, 9]]])

torch.max(t1)

out[94]: tensor(9)

torch.max(t1, dim=0)

out[91]:

torch.return_types.max(

values=tensor([[5, 6, 7, 8],

[6, 7, 8, 9]]),

indices=tensor([[2, 2, 2, 2],

[2, 2, 2, 2]]))

最粗粒度是在各個矩陣之間的方向,所以對各個矩陣的每個位置分別取最大。

torch.max(t1, dim=1)

out[92]:

torch.return_types.max(

values=tensor([[1, 2, 3, 4],

[4, 5, 6, 7],

[6, 7, 8, 9]]),

indices=tensor([[1, 1, 1, 1],

[1, 1, 1, 1],

[1, 1, 1, 1]]))

其次粗的粒度是矩陣中各行之間的方向

torch.max(t1, dim=2)

out[93]:

torch.return_types.max(

values=tensor([[3, 4],

[5, 7],

[8, 9]]),

indices=tensor([[3, 3],

[3, 3],

[3, 3]]))

最細粒度是各行之內的方向。所以取出了各行中最大的元素。

torch.max(t1, dim=-1)

out[97]:

torch.return_types.max(

values=tensor([[3, 4],

[5, 7],

[8, 9]]),

indices=tensor([[3, 3],

[3, 3],

[3, 3]]))

雖然我們這裡只使用了max函式,但是這對於torch中其他函式(例如softmax)也有效。

可以有這種寫法:mean = x.mean(-1, keepdim=true)

這樣無論是對於2維還是3維的輸入,都自動dim=input.dim()-1,也就是從最細粒度取平均。

Pytorch 二維矩陣中關於dim使用的樣例

import torch a torch.tensor 1.0,2.0,3.0 2.0,2.0,2.0 3.0,2.0,1.0 print a.shape print a torch.size 3,3 tensor 1.2.3.2.2.2.3.2.1.使用softmax函式,當dim 1時,矩陣中每...

C main函式中引數argv,argc的含義

argc 是 argument count的縮寫,表示傳入main函式的引數個數。argv 是 argument vector的縮寫,表示傳入main函式的引數序列或指標。第乙個引數argv 0 一定是程式的名稱,並且包含了程式所在的完整路徑,所以輸入main函式的引數個數實際是argc 1個。in...

對pytorch的函式中的group引數的作用介紹

1.當設定group 1時 conv nn.conv2d in channels 6,out channels 6,kernel size 1,groups 1 conv.weight.data.size 返回 torch.size 6,6,1,1 另乙個例子 conv nn.conv2d in c...