MNIST手寫數字自編碼與分類實驗

2022-09-09 03:51:07 字數 4682 閱讀 1020

import torch

import torch.nn as nn

import torch.nn.functional as f

import random

import numpy as np

import matplotlib.pyplot as plt

import torchvision

class autoencodenet(nn.Module):
    """Fully connected MNIST autoencoder with a classification head on its code.

    The encoder compresses a flattened 28x28 image (28*28 = 784 values) into a
    3-dimensional code, small enough to scatter-plot in 3D. The decoder maps
    that code back to 784 values. The classifier maps the same code to 10
    class scores (logits).

    The class name and the attribute names ``encoder``, ``decoder`` and
    ``classfier`` (sic) are kept unchanged. Existing callers and saved
    checkpoints refer to them.
    """

    def __init__(self):
        super(autoencodenet, self).__init__()

        # Encoder: 784 -> 128 -> 64 -> 12 -> 3.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            # Compress to 3 features so the code can be visualized in 3D.
            nn.Linear(12, 3),
        )

        # Decoder: 3 -> 12 -> 64 -> 128 -> 784.
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            # Keep reconstructions in (0, 1), the range of ToTensor() pixels.
            nn.Sigmoid(),
        )

        # Classifier head on the 3-d code. It deliberately ends at the last
        # Linear layer and outputs raw logits. The training code defines
        # nn.CrossEntropyLoss as the classification loss; that loss applies
        # log-softmax itself and expects unnormalized scores. A trailing
        # Sigmoid would squash every score into (0, 1), keep the softmax close
        # to uniform and weaken the classification gradient. Sigmoid has no
        # parameters, so dropping it leaves the state-dict keys unchanged.
        self.classfier = nn.Sequential(
            nn.Linear(3, 128),
            nn.Tanh(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Encode, reconstruct and classify a batch of flattened images.

        Args:
            x: float tensor whose last dimension is 28*28 (the training loop
               flattens batches with ``x.view(-1, 28*28)``).

        Returns:
            Tuple ``(encoded, decoded, logits)``:
            - ``encoded``: the 3-feature code.
            - ``decoded``: the reconstruction in (0, 1), with the same last
              dimension as ``x``.
            - ``logits``: 10 unnormalized class scores, suitable for
              ``nn.CrossEntropyLoss`` or ``argmax``.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        logits = self.classfier(encoded)
        return encoded, decoded, logits

def train():

# 超引數

epoch = 20

batch_size = 64

lr = 0.005

download_mnist = false # 下過資料的話, 就可以設定成 false

n_test_img = 5 # 到時候顯示 5張看效果, 如上圖一

# mnist digits dataset

train_data = torchvision.datasets.mnist(

root='./mnist/',

train=true, # this is training data

transform=torchvision.transforms.totensor(), # converts a pil.image or numpy.ndarray to

# torch.floattensor of shape (c x h x w) and normalize in the range [0.0, 1.0]

download=download_mnist, # download it if you don't h**e it

)autoencoder = autoencodenet()

# autoencoder = torch.load("autoencoder_115.pkl")

optimizer = torch.optim.adam(autoencoder.parameters(), lr=lr)

# 編碼損失函式

loss_func = nn.mseloss()

# 分類損失函式

loss_func1 = nn.crossentropyloss()

# 資料載入

train_loader = torch.utils.data.dataloader(train_data,batch_size=128,shuffle=true)

losses =

fig,ax=plt.subplots(2,n_test_img)

plt.ion() # continuously plot

# 會出驗證的五張原圖

testimg = train_data.data[:5].view(-1,28,28).type(torch.floattensor)/255.

for i in range(5):

ax[0][i].imshow(testimg[i])

for epoch in range(epoch):

for step, (x, b_label) in enumerate(train_loader):

b_x = x.view(-1, 28*28) # batch x, shape (batch, 28*28)

b_y = x.view(-1, 28*28) # batch y, shape (batch, 28*28)

encoded, decoded ,lable= autoencoder(b_x)

# 求損失

MNIST手寫數字識別 分類應用入門

import tensorflow as tf import numpy as np from tensorflow import placeholder from tensorflow.examples.tutorials.mnist import input_data mnist input_d...

自編碼網路實現MNIST

/usr/bin/python3 coding utf-8 time 2018 3 16 author machuanbin import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data impor...

MNIST手寫數字識別

import tensorflow as tf import numpy as np from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets mnist read_data_sets f pyth...