tfrecords檔案製作

2021-08-28 23:10:46 字數 2638 閱讀 7527

#encoding=utf-8

import os

import tensorflow as tf

from pil import image

cwd ="e:\deep_learing\tensorflow_inception_v3\\retrain\data\\train"

classes =

#製作二進位制資料

def create_record():

writer = tf.python_io.tfrecordwriter("train.tfrecords")

for index, name in enumerate(classes):

class_path = cwd+"\\"+ name+"\\"

for img_name in os.listdir(class_path):

img_path = class_path + img_name

img = image.open(img_path)

img = img.resize((64, 64))

img_raw = img.tobytes() #將轉化為原生bytes

print(index,img_raw)

example = tf.train.example(

features=tf.train.features(feature=))

writer.write(example.serializetostring())

writer.close()

data = create_record()

#讀取二進位制資料

def read_and_decode(filename):

# 建立檔案佇列,不限讀取的數量

filename_queue = tf.train.string_input_producer([filename])

# create a reader from file queue

reader = tf.tfrecordreader()

# reader從檔案佇列中讀入乙個序列化的樣本

_, serialized_example = reader.read(filename_queue)

# get feature from serialized example

# 解析符號化的樣本

features = tf.parse_single_example(

serialized_example,

features=

)label = features['label']

img = features['img_raw']

img = tf.decode_raw(img, tf.uint8)

img = tf.reshape(img, [64, 64, 3])

img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

label = tf.cast(label, tf.int32)

return img, label

if __name__ == '__main__':

if 0:

data = create_record("train.tfrecords")

else:

img, label = read_and_decode("train.tfrecords")

print("tengxing",img,label)

#使用shuffle_batch可以隨機打亂輸入 next_batch挨著往下取

# shuffle_batch才能實現[img,label]的同步,也即特徵和label的同步,不然可能輸入的特徵和label不匹配

# 比如只有這樣使用,才能使img和label一一對應,每次提取乙個image和對應的label

# shuffle_batch返回的值就是randomshufflequeue.dequeue_many()的結果

# shuffle_batch構建了乙個randomshufflequeue,並不斷地把單個的[img,label],送入佇列中

img_batch, label_batch = tf.train.shuffle_batch([img, label],

batch_size=4, capacity=2000,

min_after_dequeue=1000)

# 初始化所有的op

init = tf.initialize_all_variables()

with tf.session() as sess:

sess.run(init)

# 啟動佇列

threads = tf.train.start_queue_runners(sess=sess)

for i in range(5):

print(img_batch.shape,label_batch)

val, l = sess.run([img_batch, label_batch])

# l = to_categorical(l, 12)

print(val.shape, l)

上述**為讀取製作資料集的**,餵入網路訓練等有時間再說,還不知道怎麼處理。明天試一試把tfrecord檔案變成檢視標籤是否處理錯誤。

TFRecords 檔案的生成和讀取

1.tensorflow提供了tfrecords的格式來統一儲存資料,理論上,tfrecords可以儲存任何形式的資料。tfrecords檔案中的資料都是通過tf.train.example protocol buffer的格式儲存的。以下的 給出了tf.train.example的定義。messa...

tensorflow中tfrecords使用介紹

這篇文章主要講一下如何用tensorflow中的標準資料讀取方式簡單的實現對自己資料的讀取操作 主要分為以下兩個步驟 1 將自己的資料集轉化為 xx.tfrecords的形式 2 在自己的程式中讀取並使用.tfrecords進行操作 資料集轉換 為了便於講解,我們簡單製作了乙個資料,如下圖所示 程式...

tfrecords資料的讀寫

tfrecords write.py import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data import numpy as np 定義乙個writer 轉化資料的格式 包裝好乙個examp...