TFRecord格式介紹

2021-09-23 15:27:31 字數 3017 閱讀 7271

tfrecord檔案中的資料都是通過tf.train.example protocol buffer的格式儲存的,tfrecord格式是一種二進位制檔案,它能夠更好的利用記憶體,更方便複製和移動,並且不需要單獨的標籤檔案;我們可以寫一段**獲取你的資料,然後將資料填入到example協議記憶體塊(protocol buffer),將協議記憶體塊序列化為乙個字串,並且通過tf.python_io.tfrecordwriter寫入到tfrecord檔案中去

從tfrecords檔案中讀取資料, 可以使用tf.tfrecordreadertf.parse_single_example解析器。這個操作可以將example協議記憶體塊(protocol buffer)解析為張量。

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

import numpy as np

# 生成整數型的屬性

def _int64_feature(value):

return tf.train.feature(int64_list=tf.train.int64list(value=[value]))

# 生成字串型的屬性

def _bytes_feature(value):

return tf.train.feature(bytes_list=tf.train.byteslist(value=[value]))

mnist = input_data.read_data_sets('/path/to/mnist_data/', dtype=tf.uint8, one_hot=true)

images = mnist.train.images

labels = mnist.train.labels

# 訓練資料的影象解析度,這裡可以作為乙個屬性儲存在tfrecord中

pixels = images.shape[1]

num_examples = mnist.train.num_examples

# 輸出tfrecord的位址

file_name = '/path/to/output.tfrecords'

# 建立乙個writer來寫tfrecord檔案

writer = tf.python_io.tfrecordwriter(file_name)

# writer = tf.python_io.tfrecordwriter(file_name)

for index in range(num_examples):

# 將影象轉換為乙個字串

image_raw = images[index].tostring()

# 講乙個樣例轉換為example protocol buffer, 並將所有的資訊寫入這個資料結構

example = tf.train.example(features=tf.train.features(feature=))

# 將乙個example寫入tfrecord檔案

writer.write(example.serializetostring())

print('寫入成功')

writer.close()

輸出結果:

extracting /path/to/mnist_data/train-images-idx3-ubyte.gz

extracting /path/to/mnist_data/train-labels-idx1-ubyte.gz

extracting /path/to/mnist_data/t10k-images-idx3-ubyte.gz

extracting /path/to/mnist_data/t10k-labels-idx1-ubyte.gz

寫入成功

import tensorflow as tf

# 建立乙個reader來讀取tfrecord檔案中的樣例

reader = tf.tfrecordreader()

# 建立乙個佇列來維護輸入檔案列表

file_queue = tf.train.string_input_producer(['/path/to/output.tfrecords'])

# 從檔案中讀出乙個樣例,也可以使用read_up_to函式一次性讀取多個樣例

_, serialized_example = reader.read(file_queue)

# 解讀入的乙個樣例,如果需要解析多個樣例,使用parse_example函式

features = tf.parse_single_example(serialized_example, features=)

# tf.decode_raw可以將字串解析為影象對應的畫素陣列

image = tf.decode_raw(features['image_raw'], tf.uint8)

label = tf.cast(features['labels'], tf.int32)

pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.session()

# 啟動多執行緒處理輸入資料

coord = tf.train.coordinator()

threads = tf.train.start_queue_runners(sess=sess, coord=coord)

print('------------')

# 每次執行可以讀取tfrecord檔案中的乙個樣例,當所有的樣例都讀完的時候,在此樣例中程式會重頭讀取

# 讀取前十個資料

for i in range(10):

print(sess.run([image, label, pixels]))

執行結果:

輸出mnist資料集前十個畫素陣列和對應的label、pixels

轉,讀csv生成tf record格式

import timeit 檢視執行開始到結束所用的時間 import tensorflow as tf import os defgenerate tfrecords input filename,output filename print nstart to convert to n forma...

自製資料集轉為tfrecord格式

教程 上面這個教程中轉化格式那部分很好用,親測有效,python程式改為自己的路徑執行的時候,因為是第一次轉化這種格式,也沒有怎麼用過遠端伺服器,所以在csv轉tfrecord的時候,一直出現這個錯誤 網上也找不到,最後問師兄的時候,突然發現是疏忽了下面這兩行 是需要自己輸入路徑的,導致卡了一整天,...

tensorflow製作tfrecord格式資料集

encoding utf 8 import os import tensorflow as tf from pil import image cwd os.getcwd classes 製作二進位制資料 defcreate record writer tf.python io.tfrecordwri...