cifar10資料的讀取

2021-08-21 20:38:55 字數 1794 閱讀 1504

cifar10資料集檔案結構如圖所示,其中data_batch_1~5.bin是訓練集,每個檔案包含10000個樣本,test_batch.bin是測試集,包含10000個樣本。

開啟任意乙個檔案,發現是一堆二進位制資料,

其中乙個樣本由3037個位元組組成,其中第乙個位元組是label,剩餘3036(32*32*3)個位元組是image,每個檔案由連續的10000個樣本組成,具體的讀取過程參考下面**及注釋。

#獲取image和label

defget_input

():#檔名佇列

filenames = tf.train.match_filenames_once(data_dir+'/data_batch_*')

filename_queue = tf.train.string_input_producer(filenames)

#cifar10的資料格式:

#乙個樣本由3037個位元組組成,其中第乙個位元組是label,剩餘3036(32*32*3)個位元組是image

#每個檔案由連續的10000個樣本組成,共5個檔案

image_bytes = image_size * image_size * image_depth

record_bytes = image_bytes + label_bytes

#使用fixedlengthrecordreader讀取樣本,每次讀取乙個

reader = tf.fixedlengthrecordreader(record_bytes=record_bytes)

#獲取樣本的值

_,value = reader.read(filename_queue)

#讀出來的樣本為二進位制的字串格式,轉化為uint8的格式

raw_value = tf.decode_raw(value,tf.uint8)

#劃分label和image

labels = tf.cast(tf.strided_slice(raw_value,[0],[1]),tf.int32)

#由於image是按照(depth,height,width)的格式儲存的,因此讀出來後還要將其轉化為(height,width,depth)的格式

images = tf.reshape(

tf.strided_slice(raw_value,[label_bytes],[label_bytes+image_bytes]),

[image_depth,image_size,image_size]

)images = tf.transpose(images,[1,2,0])

images = tf.cast(images,tf.float32)

#資料型別:label是int32,image是範圍為0-1的float32

#標準化處理:減去平均值並除以方差,使得樣本均值為0,方差為1

standard_images = tf.image.per_image_standardization(images)

#官方bug,得加上

standard_images.set_shape([resize_size,resize_size,3])

labels.set_shape([1])

return standard_images,labels

CIFAR 10資料集讀取

參考 1 使用讀取方式pickle def unpickle file import pickle with open file,rb as fo dict pickle.load fo,encoding bytes return dict 返回的是乙個python字典 2 通過字典的內建函式,獲取...

資料集處理 CIFAR10

transform transforms.compose transforms.totensor transforms.normalize 0.5,0.5,0.5 0.5,0.5,0.5 trainset torchvision.datasets.cifar10 root cifar10 train...

讀取和歸一化CIFAR10

讀取和歸一化cifar10 torchvision.datasets.cifar10 root,train true,transform none,target transform none,download false 引數說明 root cifar 10 batches py 的根目錄 trai...