通過mnist資料庫學習tfrecords的使用

2021-08-21 16:51:13 字數 3917 閱讀 1893

在用tensorflow跑實驗的時候,我原本資料是用sqlite3存資料,然後再從資料庫中選擇相應的資料出來,但是這樣太耗時了,於是便想要用tfrecord來存資料。於是通過mnist資料來試驗一下。

先載入:

import tensorflow as tf

import numpy as np

import os

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/home/jianyan/data/mnist", one_hot=true)

訓練集包括55000個28×28畫素的影象。這些784(28x28)畫素值以單個維度向量的形式被平坦化。所有這樣的55000個畫素向量(每個影象乙個)的集合被儲存為numpy陣列的形式(55000,784),並被稱為mnist.train.images。

這些55000個訓練影象中的每乙個與表示該影象屬於的類的標籤相關聯。一共有10個這樣的類(0,1,2 … 9)。標籤以一種熱編碼形式的表示。因此標籤被儲存為numpy形狀陣列的形式(55000,10)被稱為mnist.train.labels。

讀tfrecord裡面有多少條資料

tfrecords_filename = 'mnist.tfrecords'

count = 0

for r in tf.python_io.tf_record_iterator(tfrecords_filename):

count += 1

讀tfrecord

filename_queue = tf.train.string_input_producer([tfrecords_filename],num_epochs=none) #讀入流中

reader = tf.tfrecordreader()

_, serialized_example = reader.read(filename_queue)

features = tf.parse_single_example(serialized_example,

features=)

img= tf.decode_raw(features['sample'],tf.float32)

img= tf.reshape(disk, [28,28])

label = tf.decode_raw(features['label'],tf.float64)

label = tf.reshape(label, [10])

init=tf.global_variables_initializer()

with tf.session() as sess:

sess.run(init)

coord = tf.train.coordinator()

threads = tf.train.start_queue_runners(sess=sess, coord=coord)

sample, l = sess.run([img, label]) # 每次讀一條資料

因為上面的**只是讀一次資料,那我們如果想一次讀出batch_size的資料或者讀出全部資料那怎麼辦呢?可以用下面的函式來實現:

def decode_from_tfrecords(filename_queue, is_batch, batch_size):

reader = tf.tfrecordreader()

_, serialized_example = reader.read(filename_queue)

features = tf.parse_single_example(serialized_example,

features=)

img= tf.decode_raw(features['sample'],tf.float32)

img= tf.reshape(disk, [28,28])

label = tf.decode_raw(features['label'],tf.float64)

label = tf.reshape(label, [10])

if is_batch:

min_after_dequeue = 10

capacity = min_after_dequeue+3*batch_size

img, label = tf.train.shuffle_batch([img, label],

batch_size=batch_size,

num_threads=3,

capacity=capacity,

min_after_dequeue=min_after_dequeue)

return img, label

通過 decode_from_tfrecords 函式,可以設定一次讀多少資料:

# 每次隨機讀取讀 batch_size=128 條資料送進去訓練

filename_queue = tf.train.string_input_producer([tfrecords_filename],num_epochs=none) #讀入流中

train_image, train_label = decode_from_tfrecords(filename_queue, true, 128)

# 一次性讀完全部的資料

'''tfrecords_filename = 'mnist.tfrecords'

count = 0

for r in tf.python_io.tf_record_iterator(tfrecords_filename):

count += 1

'''filename_queue = tf.train.string_input_producer([tfrecords_filename],num_epochs=none) #讀入流中

test_image_all, test_label_all = decode_from_tfrecords(filename_queue, true, count)

再用 sess.run 取資料即可。

注意:原先資料是什麼格式的,在讀資料的時候也要設定成什麼格式的,如:

img= tf.decode_raw(features['sample'],tf.float32) # 原先的資料是 float32

機器學習 MATLAB讀取mnist資料庫

最近要做 優化理論基礎 的課程大作業,需要用到mnist這個手寫識別資料庫,在網上查了一下如何使用,分享在這裡,以饗讀者。mnist是紐約大學 nyu yann lecun在上個世紀90年代做的乙個關於手寫數字識別的資料庫。該資料庫提出的motivation是為了解決美國郵政zip code機器識別...

MNIST資料庫格式的解析和生成

該資料格式是bytestream,無論是訓練樣本還是測試樣本,其影象資料檔案均在開頭有乙個2051的標誌,之後便是影象的個數 行值 列值,緊接著按行讀取所有的影象,且影象資料間無間隔 label coding utf 8 from future import absolute import from...

Oracle學習筆記(通過游標操縱資料庫)

不是原創哦 來自天極網 在游標for迴圈中使用查詢 在游標for迴圈中可以定義查詢,由於沒有顯式宣告所以游標沒有名字,記錄名通過游標查詢來定義。decalre v tot salary emp.salary type begin for r dept in select deptno,dname f...