pytorch中的乘法

2021-10-18 14:29:49 字數 4098 閱讀 2463

總結:按元素相乘用torch.mul,二維矩陣乘法用torch.mm,batch二維矩陣用torch.bmm,batch、廣播用torch.matmul

if __name__==

"__main__"

: a = torch.tensor([1

,2,3

])b = torch.arange(0,

12).reshape((4

,3))

c = torch.tensor([4

,5,6

,7])

d = torch.arange(0,

12).reshape((3

,4))

aa = torch.unsqueeze(a, dim=1)

cc = torch.unsqueeze(c, dim=0)

print

('a:,a.shape:'

.format

(a, a.shape)

)print

('b:,b.shape:'

.format

(b, b.shape)

)print

('c:,c.shape:'

.format

(c, c.shape)

)print

('d:,d.shape:'

.format

(d, d.shape)

)print

('aa:,aa.shape:'

.format

(aa, aa.shape)

)print

('cc:,cc.shape:'

.format

(cc, cc.shape)

)#output:

a:tensor([1

,2,3

]),a.shape:torch.size([3

])b:tensor([[

0,1,

2],[

3,4,

5],[

6,7,

8],[

9,10,

11]])

,b.shape:torch.size([4

,3])

c:tensor([4

,5,6

,7])

,c.shape:torch.size([4

])d:tensor([[

0,1,

2,3]

,[4,

5,6,

7],[

8,9,

10,11]

]),d.shape:torch.size([3

,4])

aa:tensor([[

1],[

2],[

3]])

,aa.shape:torch.size([3

,1])

cc:tensor([[

4,5,

6,7]

]),cc.shape:torch.size([1

,4])

** torch.mul()元素乘:能自動增加維度,並且沿著新維度進行廣播,或者之前有新維度且為1。**

print

(a.mul(b)

)#a.shape:torch.size([3])

print

(aa.mul(d)

)#aa.shape:torch.size([3, 1])

#output:

tensor([[

0,2,

6],[

3,8,

15],[

6,14,

24],[

9,20,

33]])

tensor([[

0,1,

2,3]

,[8,

10,12,

14],[

24,27,

30,33]

])

torch.mm():二維矩陣相乘,並且滿足對應的乘法規則,不能廣播

print

(b.mm(d)

)#output:

tensor([[

20,23,

26,29]

,[56,

68,80,

92],[

92,113,

134,

155],[

128,

158,

188,

218]

])

** torch.matmul():可進行廣播,以及batch乘**

情況1:向量✖向量:點乘

>>

>

# vector x vector

>>

> tensor1 = torch.randn(3)

>>

> tensor2 = torch.randn(3)

>>

> torch.matmul(tensor1, tensor2)

.size(

)torch.size(

)

情況2:矩陣與向量相乘:向量增加乙個新維度1,矩陣相乘後,再將此維度移除。

如 tensor1 @ tensor2 --> (3,4)(4,) --> (3,4)(4,1) --> (3,1) --> (3,)

>>

>

# matrix x vector

>>

> tensor1 = torch.randn(3,

4)>>

> tensor2 = torch.randn(4)

>>

> torch.matmul(tensor1, tensor2)

.size(

)torch.size([3

])

情況3:批量矩陣與向量乘,向量broadcast到dim=1,以及dim=0,之後後兩個維度作矩陣乘法,對於增加的維度1最後移除。

如: tensor1 @ tensor2 --> (10,3,4)(4,) --> (10,3,4)(10,4,1) --> (10,3,1) --> (10,3)

>>

>

# batched matrix x broadcasted vector

>>

> tensor1 = torch.randn(10,

3,4)

>>

> tensor2 = torch.randn(4)

>>

> torch.matmul(tensor1, tensor2)

.size(

)torch.size([10

,3])

情況4:批量矩陣與批量矩陣相乘,後兩維度矩陣乘。

>>

>

# batched matrix x batched matrix

>>

> tensor1 = torch.randn(10,

3,4)

>>

> tensor2 = torch.randn(10,

4,5)

>>

> torch.matmul(tensor1, tensor2)

.size(

)torch.size([10

,3,5

])

情況5:批量矩陣與矩陣乘,矩陣先broadcast到批量數,之後後兩個維度乘。

>>

>

# batched matrix x broadcasted matrix

>>

> tensor1 = torch.randn(10,

3,4)

>>

> tensor2 = torch.randn(4,

5)>>

> torch.matmul(tensor1, tensor2)

.size(

)torch.size([10

,3,5

])

PyTorch 矩陣乘法總結

torch.mm mat1,mat2,out none 其中mat1 n times m mat2 m times d 輸出out的維度是 n times d 該函式一般只用來計算兩個二維矩陣的矩陣乘法,並且不支援broadcast操作。由於神經網路訓練一般採用mini batch,經常輸入的時三維...

Pytorch 中 torchvision的錯誤

在學習pytorch的時候,使用 torchvision的時候發生了乙個小小的問題 安裝都成功了,並且import torch也沒問題,但是在import torchvision的時候,出現了如下所示的錯誤資訊 dll load failed 找不到指定模組。首先,我們得知道torchvision在...

pytorch中index select 的用法

a torch.linspace 1,12,steps 12 view 3,4 print a b torch.index select a,0,torch.tensor 0,2 print b print a.index select 0,torch.tensor 0,2 c torch.inde...