機器學習 利用pickle載入cifar檔案

2022-04-28 01:29:13 字數 2930 閱讀 2449

然後奉獻**

def

load_cifar10(root):

"""載入cifar全部資料

"""xs =

ys =

for b in range(1, 2):

f = os.path.join(root, '

data_batch_%d

' %(b,))

x, y =load_cifar_batch(f)

#將所有batch整合起來

xtr = np.concatenate(xs) #

使變成行向量,最終xtr的尺寸為(50000,32,32,3)

ytr =np.concatenate(ys)

delx, y

xte, yte = load_cifar_batch(os.path.join(root, '

test_batch'))

return xtr, ytr, xte, yte

找到cifar資料夾下面的二進位制檔案:

然後對每次的檔案進行批處理:

def

load_cifar_batch(filename):

"""直接讀入cifar資料集的乙個batch

"""with open(filename, 'rb

') as f:

datadict = p.load(f, encoding='

latin1')

x = datadict['

data']

y = datadict['

labels']

x = x.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("

float")

y =np.array(y)

return x, y

測試:

import

numpy as np

#載入cifar-10資料集

cifar10_dir = '

data\cifar10\cifar-10-batches-py

'x_train, y_train, x_test, y_test =load_cifar10(cifar10_dir)

#看看資料集中的一些樣本:每個類別展示一些

print('

訓練資料的形狀:

', x_train.shape)

print('

訓練集標籤的形狀:

', y_train.shape)

print('

測試資料的形狀:

', x_test.shape)

print('

測試資料的形狀:

', y_test.shape)

import

pickle as p

import

osdef

load_cifar_batch(filename):

"""載入cifar資料集的乙個batch

"""with open(filename, 'rb

') as f:

datadict = p.load(f, encoding='

latin1')

x = datadict['

data']

y = datadict['

labels']

x = x.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("

float")

y =np.array(y)

return

x, y

defload_cifar10(root):

"""載入cifar全部資料

"""xs =

ys =

for b in range(1, 2):

f = os.path.join(root, '

data_batch_%d

' %(b,))

x, y =load_cifar_batch(f)

#將所有batch整合起來

xtr = np.concatenate(xs) #

使變成行向量,最終xtr的尺寸為(50000,32,32,3)

ytr =np.concatenate(ys)

delx, y

xte, yte = load_cifar_batch(os.path.join(root, '

test_batch'))

return

xtr, ytr, xte, yte

if__name__ == '

__main__':

import

numpy as np

#載入cifar-10資料集

cifar10_dir = '

data\cifar10\cifar-10-batches-py

'x_train, y_train, x_test, y_test =load_cifar10(cifar10_dir)

#看看資料集中的一些樣本:每個類別展示一些

print('

training data shape:

', x_train.shape)

print('

training labels shape:

', y_train.shape)

print('

test data shape:

', x_test.shape)

print('

test labels shape:

', y_test.shape)

機器學習之 載入資料

import numpy as np import urllib url with dataset url raw data urllib.request.urlopen url load the csv file as a numpy matrix dataset np.loadtxt raw d...

機器學習 實戰 利用tensorflow識別衣物

fashion mnist 是乙個衣物資料集,整合在keras中可以直接使用。本文記錄了一步一步利用 fashion minst 的資料庫訓練 tensorflow 神經網路。import tensorflow as tf from tensorflow import keras import nu...

利用機器學習檢測惡意活動

研究人員開始使用無監督機器學習演算法來對大量網域名稱資訊資料集進行分析,以發現新的威脅並進行攔截。一旦惡意網域名稱開始活躍,機器學習演算法就可以快速識別出攻擊活動的惡意網域名稱。背景比如在一類 的惡意活動中使用了許多個網域名稱,並持續了一段時間。這些活動一般利用像世界盃這類近期的熱點事件,網域名稱一...