pytorch實現topk剪枝

2021-10-12 20:16:27 字數 3479 閱讀 7256

這篇部落格,以mnist資料集為例,對lstm的權重矩陣實現top-k剪枝(7,2),介紹了如何在pytorch框架下實現top-k剪枝。

可以使用如下**,檢視模型都含有哪些權重矩陣:

for name, _  in model.named_parameters():

print

(name)

矩陣每行含有28個引數,將其分為4組,每組7個元素,只保留最大的2個:

def

topk

(para, k)

:#對parameter,生成掩模矩陣,k表示保留前k個最大的

parameter = torch.

abs(para)

l =int(parameter.size()[

1]/7

) _, b = torch.topk(parameter[:,

0:7]

, k,

1, largest =

true

)for i in

range(1

,l):

_, b1 = torch.topk(parameter[

:,l*7:

(l+1)*

7], k,

1, largest =

true

)#該函式在dim=1上,保留前k個最大值,返回b1為前k個最大值的索引

b1 = b1 + i *

7 b = torch.cat(

(b,b1)

,dim =1)

c = torch.zeros(parameter.size()[

0], parameter.size()[

1],dtype = torch.

int)

#lstm權重矩陣為[4*28,28],所以這裡c也選這麼大

for i in

range

(c.size()[

0]):

for j in

range

(c.size()[

1]):

if j in b[i]

: c[i]

[j]=

1else

: c[i]

[j]=

0return c

c1,c2,c3,c4是根據四個權重矩陣生成的四個掩模矩陣(我定義的雙層lstm有四個權重矩陣),生成的掩模矩陣元素均為0或1

c1 = topk(rnn.lstm.weight_ih_l0.data,2)

c2 = topk(rnn.lstm.weight_hh_l0.data,2)

c3 = topk(rnn.lstm.weight_ih_l1.data,2)

c4 = topk(rnn.lstm.weight_hh_l1.data,

2)

生成的掩模矩陣如圖所示:

pytorch提供的自定義剪枝的模板,這裡分別將c1,c2,c3,c4作為掩模矩陣,這段**的意思就是,rnn模型中的lstm層的權重矩陣weight_ih_l0對應掩模矩陣c1, c1元素為1的位置,保留;c1為0的,weight_ih_l0對應的位置被剪枝掉,以此類推;

class

foobarpruningmethod1

(prune.basepruningmethod)

:"""prune every other entry in a tensor

"""pruning_type =

'unstructured'

defcompute_mask

(self, t, default_mask)

: mask = c1

return mask

class

foobarpruningmethod2

(prune.basepruningmethod)

:"""prune every other entry in a tensor

"""pruning_type =

'unstructured'

defcompute_mask

(self, t, default_mask)

: mask = c2

return mask

class

foobarpruningmethod3

(prune.basepruningmethod)

:"""prune every other entry in a tensor

"""pruning_type =

'unstructured'

defcompute_mask

(self, t, default_mask)

: mask = c3

return mask

class

foobarpruningmethod4

(prune.basepruningmethod)

:"""prune every other entry in a tensor

"""pruning_type =

'unstructured'

defcompute_mask

(self, t, default_mask)

: mask = c4

return mask

deffoobar_unstructured

(model)

: foobarpruningmethod1.

(model.lstm,

'weight_ih_l0'

) foobarpruningmethod2.

(model.lstm,

'weight_hh_l0'

) foobarpruningmethod3.

(model.lstm,

'weight_ih_l1'

) foobarpruningmethod3.

(model.lstm,

'weight_hh_l1'

)return model

rnn = foobar_unstructured(rnn)

#對預訓練完成的模型進行top-k剪枝

剪枝過後再訓練,會發現,剪枝後的訓練速度,明顯快於剪枝前。

剪枝後的矩陣如圖所示:

這篇部落格以mnist資料集為例,搭建了乙個含有雙層lstm,和fc層的模型,預訓練後對其進行top-k剪枝,詳細介紹了pytorch框架下的top-k剪枝過程;

pytorch 中的topk函式

1.函式介紹 最近在 中看到這兩個語句 maxk max topk pred output.topk maxk,1,true,true 這個函式是用來求output中的最大值或最小值,返回兩個引數 其一返回output中的最大值 或最小值 其二返回該值的索引。2.topk 函式原型 具體的用法參考p...

PyTorch中的topk函式詳解

聽名字就知道這個函式是用來求tensor中某個dim的前k大或者前k小的值以及對應的index。用法torch.topk input,k,dim none,largest true,sorted true,out none tensor,longtensor topk最常用的場合就是求乙個樣本被網路...

TOP K問題(c 實現)

top k問題 c 實現 給定乙個陣列,找出陣列中最大的k個數或者最小的k個數,稱為top k問題。這是面試的常考題,解法可以是基於最大堆 最大堆排序 基於快速排序實現等等,文字基於快速排序的思想實現。我們不會採用快速排序的演算法來實現top k問題,但我們可以利用快速排序的思想,在陣列中隨機找乙個...