PyTorch中的topk函式詳解

2021-10-01 18:37:54 字數 2372 閱讀 4153

聽名字就知道這個函式是用來求tensor中某個dim的前k大或者前k小的值以及對應的index。

用法

torch.topk(input, k, dim=none, largest=true, sorted=true, out=none) -> (tensor, longtensor)

topk最常用的場合就是求乙個樣本被網路認為前k個最可能屬於的類別。我們就用這個場景為例,說明函式的使用方法。

假設乙個tensor f∈r

df \in r^

f∈rn×d

,n是樣本數目,一般等於batch size, d是類別數目。我們想知道每個樣本的最可能屬於的那個類別,其實可以用torch.max得到。如果要使用topk,則k應該設定為1。

import torch

pred = torch.randn((4

,5))

print

(pred)

values, indices = pred.topk(

1, dim=

1, largest=

true

,sorted

=true

)print

(indices)

# 用max得到的結果,設定keepdim為true,避免降維。因為topk函式返回的index不降維,shape和輸入一致。

_, indices_max = pred.

max(dim=

1, keepdim=

true

)print

(indices_max == indices)

# pred

tensor([[

-0.1480,-

0.9819,-

0.3364

,0.7912,-

0.3263],

[-0.8013,-

0.9083

,0.7973

,0.1458,-

0.9156],

[-0.2334,-

0.0142,-

0.5493

,0.0673

,0.8185],

[-0.4075,-

0.1097

,0.8193,-

0.2352,-

0.9273]]

)# indices, shape為 【4,1】,

tensor([[

3],#【0,0】代表 第乙個樣本最可能屬於第一類別[2

],# 【1, 0】代表第二個樣本最可能屬於第二類別[4

],[2

]])# indices_max等於indices

tensor([[

true],

[true],

[true],

[true]]

)

現在在嘗試一下k=2

import torch

pred = torch.randn((4

,5))

print

(pred)

values, indices = pred.topk(

2, dim=

1, largest=

true

,sorted

=true

)# k=2

print

(indices)

# pred

tensor([[

-0.2203,-

0.7538

,1.8789

,0.4451,-

0.2526],

[-0.0413

,0.6366

,1.1155

,0.3484

,0.0395],

[0.0365

,0.5158

,1.1067,-

0.9276,-

0.2124],

[0.6232

,0.9912,-

0.8562

,0.0148

,1.6413]]

)# indices

tensor([[

2,3]

,[2,

1],[

2,1]

,[4,

1]])

可以發現indices的shape變成了【4, k】,k=2。

其中indices[0] = [2,3]。其意義是說明第乙個樣本的前兩個最大概率對應的類別分別是第3類和第4類。

大家可以自行print一下values。可以發現values的shape和indices的shape是一樣的。indices描述了在values中對應的值在pred中的位置。

pytorch 中的topk函式

1.函式介紹 最近在 中看到這兩個語句 maxk max topk pred output.topk maxk,1,true,true 這個函式是用來求output中的最大值或最小值,返回兩個引數 其一返回output中的最大值 或最小值 其二返回該值的索引。2.topk 函式原型 具體的用法參考p...

pytorch實現topk剪枝

這篇部落格,以mnist資料集為例,對lstm的權重矩陣實現top k剪枝 7,2 介紹了如何在pytorch框架下實現top k剪枝。可以使用如下 檢視模型都含有哪些權重矩陣 for name,in model.named parameters print name 矩陣每行含有28個引數,將其分...

pytorch中的gather函式

from 今天剛開始接觸,讀了一下documentation,寫乙個一開始每太搞懂的函式gather b torch.tensor 1,2,3 4,5,6 print bindex 1 torch.longtensor 0,1 2,0 index 2 torch.longtensor 0,1,1 0...