本地匯入Mnist的資料集的方法

2021-08-28 23:25:49 字數 2172 閱讀 2197

#載入mnist資料集

from keras.datasets import mnist

import gzip

import os

import numpy

local_file = "f:\python\deeplearning"

#(train_images, train_labels),(test_images, test_labels) = mnist.load_data()

train_images = 'train-images-idx3-ubyte.gz' #訓練集影象的檔名

train_labels = 'train-labels-idx1-ubyte.gz' #訓練集label的檔名

test_images = 't10k-images-idx3-ubyte.gz' #測試集影象的檔名

test_labels = 't10k-labels-idx1-ubyte.gz' #測試集label的檔名

#主要是下面的兩個函式實現的:

def extract_images(filename):

def extract_labels(filename, one_hot=false):

train_images = extract_images(os.path.join(local_file,train_images))

train_labels = extract_labels(os.path.join(local_file,train_labels))

test_images = extract_images(os.path.join(local_file,test_images))

test_labels = extract_labels(os.path.join(local_file,test_labels))

#網路架構

'''神經網路的核心元件是layer,它是一種資料處理模組,可以看成是資料過濾器。

'''from keras import models

from keras import layers

network = models.sequential()

network.add(layers.dense(512, activation='relu',input_shape=(28*28,)))

network.add(layers.dense(10, activation='softmax'))

#編譯步驟

'''要想訓練網路,需要選擇變異步驟的三個引數:

(1)損失函式(loss):衡量網路在訓練資料集上的效能;

(2)優化器(optimizer):基於訓練資料和損失函式更新網路的機制;

(3)訓練和測試中的監控指標(metric):如精度

'''network.compile(optimizer='rmsprop',

loss='categorical_crossentropy',

metrics=['accuracy'])

#資料預處理

train_images = train_images.reshape((60000, 28*28))

train_images = train_images.astype('float32')/255

test_images = test_images.reshape((10000, 28*28))

test_images = test_images.astype('float32')/255

#準備標籤

from keras.utils import to_categorical

train_labels = to_categorical(train_labels)

test_labels = to_categorical(test_labels)

#訓練網路

network.fit(train_images, train_labels, epochs = 5, batch_size = 256)

#效能評估

train_loss, train_acc = network.evaluate(test_images, test_labels)

print('test_acc:', train_acc)

print('test_error:', train_loss)

Keras匯入Mnist資料集出錯解決方案

exception url fetch failure on none winerror 10060 由於連線方在一段時間後沒有正確答覆或連線的主機沒有反應,連線嘗試失敗。def load data loads the mnist dataset.arguments path path where ...

消除匯入MNIST資料集發出的警告資訊

原本匯入資料集你僅需這樣 import mnist data from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets mnist data one hot true 但是由於...

MNIST資料集的處理

1 mnist資料集介紹 資料格式介紹 2 資料讀取 mnist資料集的讀取比較複雜,這裡給出兩種讀取方式。2.1 struct包讀取資料 nn網路中使用的讀取方法 2.2 torch.version和torch.utils.data.dataloader處理資料 import torch from...