pytorch中的廣播機制

2021-10-01 21:12:07 字數 3333 閱讀 2253

pytorch中的廣播機制和numpy中的廣播機制一樣, 因為都是陣列的廣播機制

兩個維度不同的tensor可以相乘, 示例

a = torch.arange(0,

6).reshape((6

,))'''

tensor([0, 1, 2, 3, 4, 5])

shape: torch.size([6])

ndim: 1

'''b = torch.arange(0,

12).reshape((2

,6))

'''tensor([[ 0, 1, 2, 3, 4, 5],

[ 6, 7, 8, 9, 10, 11]])

shape: torch.size([2, 6])

ndim: 2

'''# a和b的ndim不同, 但是可以element-wise相乘, 因為用到了廣播機制

res = torch.mul(a,b)

'''tensor([[ 0, 1, 4, 9, 16, 25],

[ 0, 7, 16, 27, 40, 55]])

shape: torch.size([2, 6])

ndim: 2

'''

如何理解陣列的廣播機制

以陣列a和陣列b的相加為例, 其餘數**算同理

核心:如果相加的兩個陣列的shape不同, 就會觸發廣播機制, 1)程式會自動執行操作使得a.shape==b.shape, 2)對應位置進行相加

運算結果的shape是:a.shape和b.shape對應位置的最大值,比如:a.shape=(1,9,4),b.shape=(15,1,4),那麼a+b的shape是(15,9,4)

有兩種情況能夠進行廣播

a.ndim > b.ndim, 並且a.shape最後幾個元素包含b.shape, 比如下面三種情況, 注意不要混淆ndim和shape這兩個基本概念

a.ndim == b.ndim, 並且a.shape和b.shape對應位置的元素要麼相同要麼其中乙個是1, 比如

下面分別進行舉例

a.ndim 大於 b.ndim

# a.shape=(2,2,3,4)

a = np.arange(1,

49).reshape((2

,2,3

,4))

# b.shape=(3,4)

b = np.arange(1,

13).reshape((3

,4))

# numpy會將b.shape調整至(2,2,3,4), 這一步相當於numpy自動實現np.tile(b,[2,2,1,1])

res = a + b

print

('***********************************'

)print

(a)print

(a.shape)

print

('***********************************'

)print

(b)print

(b.shape)

print

('***********************************'

)print

(res)

print

(res.shape)

print

('***********************************'

)print

(a+b == a + np.tile(b,[2

,2,1

,1])

)

a.ndim 等於 b.ndim
#示例1

# a.shape=(4,3)

a = np.arange(12)

.reshape(4,

3)# b.shape=(4,1)

b = np.arange(4)

.reshape(4,

1)# numpy會將b.shape調整至(4,3), 這一步相當於numpy自動實現np.tile(b,[1,3])

res = a + b

print

('***********************************'

)print

(a)print

(a.shape)

print

('***********************************'

)print

(b)print

(b.shape)

print

('***********************************'

)print

(res)

print

(res.shape)

print

('***********************************'

)print

((a+b == a + np.tile(b,[1

,3])

))# 列印結果都是true

#示例2

# a.shape=(1,9,4)

a = np.arange(1,

37).reshape((1

,9,4

))# b.shape=(15,1,4)

b = np.arange(1,

61).reshape((15

,1,4

))res = a + b

print

('***********************************'

)# print(a)

print

(a.shape)

print

('***********************************'

)# print(b)

print

(b.shape)

print

('***********************************'

)# print(res)

print

(res.shape)

print

('***********************************'

)q = np.tile(a,[15

,1,1

])+ np.tile(b,[1

,9,1

])print

(q == res)

# 列印結果都是true

pytorch的廣播機制

廣播機制,就是將不同維度 不同長度的tensor,在滿足一定規則的前提下能夠自動進行長度和維度的擴充,從而使不同維度 不同長度的tensor之間正確的進行運算。自動廣播規則 兩個tensor能夠進行自動廣播需要滿足以下幾個規則 對應相等 其中乙個tensor的大小等於1 其中乙個tensor的某個維...

pytorch的廣播機制

廣播機制,就是將不同維度 不同長度的tensor,在滿足一定規則的前提下能夠自動進行長度和維度的擴充,從而使不同維度 不同長度的tensor之間正確的進行運算。自動廣播規則 兩個tensor能夠進行自動廣播需要滿足以下幾個規則 對應相等 其中乙個tensor的大小等於1 其中乙個tensor的某個維...

numpy中的廣播機制

numpy兩個陣列的相加 相減以及相乘都是對應元素之間的操作。import numpy as np x np.array 2,2,3 1,2,3 y np.array 1,1,3 2,2,4 print x y numpy當中的陣列相乘是對應元素的乘積,與線性代數當中的矩陣相乘不一樣 輸入結果如下 ...