對pytorch的函式中的group引數的作用介紹

2022-10-04 09:09:11 字數 1904 閱讀 4275

1.當設定group=1時:

conv = nn.conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=1)

conv.weight.data.size()

返回:torch.size([6, 6, 1, 1])

另乙個例子:

conv = nn.conv2d(in_channels=6, out_channels=3, kernel_size=1, groups=1)

conv.weight.data.size()

返回:torch.size([3, 6, 1, 1])

可見第乙個值為out_channels的大小,第二個值為in_channels的大小,後面兩個值為kernel_size

2.當設定為group=2時

conv = nn.conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=2)

conv.weight.data.size()

返回:torch.size([6, 3, 1, 1])

3.當設定group=3時

conv = nn.conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=3)

conv.weight.data.size()

返回:torch.size([6, 2, 1, 1])

4.當設定group=4時

conv = nn.conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=4)

conv.weight.data.size()

報錯:valueerror: in_channels must www.cppcns.combe divisible by groups

groups的值必須能整除in_channels

注意:同樣也要求groups的值必須能整除out_channels,舉例:

conv = nn.conv2d(in_channels=6, out_channels=3, kernel_size=1, groups=2)

conv.weight.data.size()

否則會報程式設計客棧錯:

valueerror: out_channels must be divisible by groups

5.當設定group=in_channels時

conv = nn.conv2d(in_channels=6, out_channels=6, kernel_size=1, groups=6)

conv.weight.data.size()

返回:torch.size([6, 1, www.cppcns.com1, 1])

所以當group=1時,該卷積層需要6*6*1*1=36個引數,即需要6個6*1*1的卷積核

計算時就是6*hwww.cppcns.com_in*w_in的輸入整個乘以乙個6*1*1的卷積核,得到輸出的乙個channel的值,即1*h_out*w_out。這樣經過6次與6個卷積核計算就能夠得到6的結果了

如果將group=3時,卷積核大小為torch.size([6, 2, 1, 1]),即6個2*1*1的卷積核,只需要需要6*2*1*1=12個引數

那麼每組計算就只被in_channels/groups=2個channels的卷積核計算,當然這也會將輸入分為三份大小為2*h_in*w_in的小輸入,分別與2*1*1大小的卷積核進行三次運算,然後將得到的3個2*h_out*w_out的小輸出concat起來得到最後的6*h_out*w_out輸出

在實際實驗中,同樣的網路結構下,這種分組的卷積效果是好於未分組的卷積的效果的。

本文標題: 對pytorch的函式中的group引數的作用介紹

本文位址:

對Pytorch中backward()函式的理解

寫在第一句 這個部落格解釋的也很好,參考了很多 pytorch中的自動求導函式backward 所需引數含義 所以切入正題 backward 函式中的引數應該怎麼理解?官方 如果需要計算導數,可以在tensor上呼叫.backward 1.如果tensor是乙個標量 即它包含乙個元素的資料 則不需要...

我對pytorch中gather函式的一點理解

torch.gather input dim,index,out none tensor torch.gather input dim,index,out none tensor gathers values along an axis specified by dim.for a 3 d tens...

對Pytorch 中的contiguous理解說明

最近遇到這個函式,但查的中文部落格裡的解釋貌似不是很到位,這裡翻譯一下stackoverflow上的回答並加上自己的理解。在pytorch中,只有很少幾個操作是不改變tensor的內容本身,而只是重新定義下標與元素的對應關係的。換句話說,這種操作不進行資料拷貝和資料的改變,變的是元資料。這些操作是 ...