Pytorch實現Top1準確率和Top5準確率

2022-04-28 20:57:16 字數 1392 閱讀 3358

之前一直不清楚top1和top5是什麼,其實搞清楚了很簡單,就是兩種衡量指標,其中,top1就是普通的accuracy,top5比top1衡量標準更「嚴格」,

具體來講,比如一共需要分10類,每次分類器的輸出結果都是10個相加為1的概率值,top1就是這十個值中最大的那個概率值對應的分類恰好正確的頻率,而top5則是在十個概率值中從大到小排序出前五個,然後看看這前五個分類中是否存在那個正確分類,再計算頻率。pytorch實現如下:

def

evalutetop1(model, loader):

model.eval()

correct =0

total =len(loader.dataset)

for x,y in

loader:

x,y =x.to(device), y.to(device)

with torch.no_grad():

logits =model(x)

pred = logits.argmax(dim=1)

correct +=torch.eq(pred, y).sum().float().item()

#correct += torch.eq(pred, y).sum().item()

return correct /total

defevalutetop5(model, loader):

model.eval()

correct =0

total =len(loader.dataset)

for x, y in

loader:

x,y =x.to(device),y.to(device)

with torch.no_grad():

logits =model(x)

maxk = max((1,5))

y_resize = y.view(-1,1)

_, pred = logits.topk(maxk, 1, true, true)

correct +=torch.eq(pred, y_resize).sum().float().item()

return correct / total

注意:y_resize = y.view(-1,1)是非常關鍵的一步,在correct的運算中,關鍵就是要pred和y_resize維度匹配,而原來的y是[128],128是batch大小;

pred的維度則是[128,10],假設這裡是cifar10十分類;因此必須把y轉化成[128,1]這種維度,但是不能直接是y.view(128,1),因為遍歷整個資料集的時候,

最後乙個batch大小並不是128,所以view()裡面第乙個size就設為-1未知,而確保第二個size是1就行

topk函式的具體用法參見

TOP 1比不加TOP慢的疑惑

問題描述 有乙個查詢如下,去掉top 1的時候,很快就出來結果了,但加上top 1的時候,一般要2 3秒才出資料,何解?select top 1 a.invno from a,b where a.item b.itemnumber and b.ownercompanycode is notnull ...

TOP 1比不加TOP慢的疑惑

問題描述 有乙個查詢如下,去掉top 1的時候,很快就出來結果了,但加上top 1的時候,一般要2 3秒才出資料,何解?select top 1 a.invno from a,b where a.item b.itemnumber and b.ownercompanycode is notnull ...

Re Selenium新手爬取貓眼Top 100

from selenium import webdriver 引入瀏覽器物件 from selenium.webdriver.common.by import by from selenium.webdriver.common.keys import keys from selenium.webdr...