Pytorch框架下的Kaggle貓狗識別

2022-09-12 19:09:15 字數 3879 閱讀 7621

kaggle貓狗識別,訓練集為25000張打好標籤的貓狗,測試集為劃分出來的5000張貓和狗的**。

使用的網路為le-net,其結構圖如下

**如下:

import

torch

import

numpy

import

matplotlib.pylab as plt

from torch.autograd import

variable

from torchvision import

transforms,datasets

import

cv2import

torch.nn as nn

from tensorboardx import

summarywriter

print

(torch.cuda.is_**ailable())

print

(torch.cuda.current_device())

print

(torch.cuda.get_device_capability(),torch.cuda.get_device_name())

image_transform =transforms.compose([

transforms.resize(84),

transforms.centercrop(84),

transforms.totensor(),

transforms.normalize([0.5,0.5,0.5],[0.5,0.5,0.5])

])train_dataset = datasets.imagefolder(root =r'

d:\python\project\venv\include\kaggle資料集\train/

',transform=image_transform)

train_loader = torch.utils.data.dataloader(train_dataset,4,true)

test_dataset = datasets.imagefolder(root = r'

d:\python\project\venv\include\kaggle資料集\selftest/

',transform=image_transform)

test_loader = torch.utils.data.dataloader(test_dataset,4,true)

use_dataset = datasets.imagefolder(root = r'

d:\python\project\venv\include\kaggle資料集\use/

',transform=image_transform)

use_loader = torch.utils.data.dataloader(use_dataset,1,true)

class

net(nn.module):

def__init__

(self):

super(net, self).

__init__

() self.conv1 = nn.sequential(nn.conv2d(3,6,5),nn.relu(),nn.maxpool2d(2,2))

self.conv2 = nn.sequential(nn.conv2d(6,16,5),nn.relu(),nn.maxpool2d(2,2))

self.fc1 = nn.sequential(nn.linear(16*18*18,1024),nn.relu())

self.fc2 = nn.sequential(nn.linear(1024,512),nn.relu(),nn.linear(512,2))

defforward(self,input):

x =self.conv1(input)

x =self.conv2(x)

x = x.view(-1,16*18*18)

x =self.fc1(x)

x =self.fc2(x)

x =torch.sigmoid(x)

return

xnet =net()

net.cuda()

net.load_state_dict(torch.load(r

'd:\python\project\venv\include\nural_net/torch_net10(74%)'))

writer = summarywriter('

figure')

opti = torch.optim.sgd(net.parameters(),lr = 0.0005,momentum=0.9)

loss_fun =nn.crossentropyloss()

stepp =0

for epoch in range(5):

run_loss = 0.0

for i,data in

enumerate(train_loader,0):

images,labels =data

images,labels =images.cuda(),labels.cuda()

images,labels =variable(images),variable(labels)

out =net(images)

loss =loss_fun(out,labels)

opti.zero_grad()

loss.backward()

opti.step()

print

(i)

print

(loss.item())

run_loss +=loss.item()

if i % 10 == 9:

stepp += 10writer.add_scalar(

'train_loss

',loss.item(),stepp)

if i%2000 == 1999:

print('

[%d %d]: %.4f

'%(epoch+1,i+1,run_loss))

run_loss = 0.0

print('

finish train')

torch.s**e(net.state_dict(),r

'd:\python\project\venv\include\nural_net/torch_net0')

print('

s**e finish')

net.eval()

correct = 0.0total =0

for data in

test_loader:

images, labels =data

images, labels =images.cuda(), labels.cuda()

images, labels =variable(images), variable(labels)

out =net(images)

_,prediction = torch.max(out.data,1)

print('

prediction

',prediction)

print('

labels

',labels)

total +=labels.size(0)

correct += ((prediction ==labels).data).sum()

print('

the acc is %.4f

'%(correct*100/total))

pytorch框架下語義分割訓練實踐(一)

目錄 環境準備 開始訓練 torch 1.1.0 torchvision 0.3.0 tqdm 4.32.2 tensorboard 1.14.0 pillow 6.2.0 opencv python 4.1.0.25 這裡面幾個只有torch比較大,其他都很小,很快就裝完,安裝庫前務必裝下pip,...

Java集合框架(下)

上篇博文介紹了collection集合 這篇博文將介紹map集合。首先map和collection都是乙個介面,具體的實現都由下面的實現類實現功能。它們最大的區別就是collection是單列集合,map是雙列集合 泛型引數是乙個鍵 值對 map集合與set類似,主要有hashmap treemap...

Foundation框架下的基本類

功能 將oc和c語言當中的基本資料型別轉換成例項物件 oc中的字串具有強大的功能,即封裝性極強,我們只需要找到相應的api,就可以對字串做相應操作。oc中字串分為 不可變字串 和 可變字串 其中 可變字串 是 不可變字串 的子類。在ios開發中 字串通常用作顯示文字,即作為 uilable uite...