MNIST資料集的格式轉換

2021-08-09 01:58:51 字數 3359 閱讀 3088

以前直接用的是sklearn或者tensorflow提供的mnist資料集,已經轉換為矩陣形式的資料格式。但是sklearn體用的資料集合並不全,一共只有3000+圖,每個圖是8*8的大小,但是原始資料並不是這樣的。

mnist資料集合的原始**為:

進入官網,發現有4個檔案,分別對應訓練集、測試集的影象和標籤:

官網給的資料集合並不是原始的影象資料格式,而是編碼後的二進位制格式:

影象的編碼為:

典型的head+data模式:前16個位元組分為4個整型資料,每個4位元組,分別代表:資料資訊des、影象數量(img_num),影象行數(row)、影象列數(col),之後的資料全部為畫素,每row*col個畫素構成一張圖,每個色素的值為(0-255)。

標籤的編碼為:

模式和前面的一樣,不同的是head只有8位元組,分別為des和標籤的數量(label_num).之後每乙個位元組代表乙個標籤,值為(0-9)。

弄清楚編碼後,就可以直接上**了:

import numpy as np

import struct

mnist_dir = r'./digit/'

deffetch_mnist

(mnist_dir,data_type):

train_data_path = mnist_dir + 'train-images.idx3-ubyte'

train_label_path = mnist_dir + 'train-labels.idx1-ubyte'

test_data_path = mnist_dir + 't10k-images.idx3-ubyte'

test_label_path = mnist_dir + 't10k-labels.idx1-ubyte'

# train_img

with open(train_data_path, 'rb') as f:

data = f.read(16)

des,img_nums,row,col = struct.unpack_from('>iiii', data, 0)

train_x = np.zeros((img_nums, row*col))

for index in range(img_nums):

data = f.read(784)

if len(data) == 784:

train_x[index,:] = np.array(struct.unpack_from('>' + 'b' * (row * col), data, 0)).reshape(1,784)

f.close()

# train label

with open(train_label_path, 'rb') as f:

data = f.read(8)

des,label_nums = struct.unpack_from('>ii', data, 0)

train_y = np.zeros((label_nums, 1))

for index in range(label_nums):

data = f.read(1)

train_y[index,:] = np.array(struct.unpack_from('>b', data, 0)).reshape(1,1)

f.close()

# test_img

with open(test_data_path, 'rb') as f:

data = f.read(16)

des, img_nums, row, col = struct.unpack_from('>iiii', data, 0)

test_x = np.zeros((img_nums, row * col))

for index in range(img_nums):

data = f.read(784)

if len(data) == 784:

test_x[index, :] = np.array(struct.unpack_from('>' + 'b' * (row * col), data, 0)).reshape(1, 784)

f.close()

# test label

with open(test_label_path, 'rb') as f:

data = f.read(8)

des, label_nums = struct.unpack_from('>ii', data, 0)

test_y = np.zeros((label_nums, 1))

for index in range(label_nums):

data = f.read(1)

test_y[index, :] = np.array(struct.unpack_from('>b', data, 0)).reshape(1, 1)

f.close()

if data_type == 'train':

return train_x, train_y

elif data_type == 'test':

return test_x, test_y

elif data_type == 'all':

return train_x, train_y,test_x, test_y

else:

print('type error')

if __name__ == '__main__':

tr_x, tr_y, te_x, te_y = fetch_mnist(mnist_dir,'all')

import matplotlib.pyplot as plt # plt 用於顯示

img_0 = tr_x[59999,:].reshape(28,28)

plt.imshow(img_0)

print(tr_y[59999,:])

img_1 = te_x[500,:].reshape(28,28)

plt.imshow(img_1)

print(te_y[500,:])

plt.show()

執行結果:

簡單 mnist 資料集轉為csv格式讀取

對於剛入門ai的童鞋來說,mnist 資料集就相當於剛接觸程式設計時的 hello world 一樣,具有別樣的意義,後續許多機器學習的演算法都可以用該資料集來進行簡單測試。也給出了資料集的格式,但是要手動解析這些資料也是有點複雜的。以下 的功能是將訓練集和訓練標籤整合到乙個csv檔案裡 測試檔案同...

MNIST資料集的處理

1 mnist資料集介紹 資料格式介紹 2 資料讀取 mnist資料集的讀取比較複雜,這裡給出兩種讀取方式。2.1 struct包讀取資料 nn網路中使用的讀取方法 2.2 torch.version和torch.utils.data.dataloader處理資料 import torch from...

MNIST資料集介紹

mnist資料集包含了6w張作為訓練資料,1w作為測試資料。在mnist資料集中,每一張都代表了0 9中的乙個數字,的大小都是28 28,且數字都會出現在的正中間。資料集包含了四個檔案 t10k images idx3 ubyte.gz 測試資料 t10k labels idx1 ubyte.gz ...