tensorflow中tfrecords使用介紹

2021-08-09 08:24:38 字數 3523 閱讀 4265

這篇文章主要講一下如何用tensorflow中的標準資料讀取方式簡單的實現對自己資料的讀取操作.

主要分為以下兩個步驟:(1)將自己的資料集轉化為 xx.tfrecords的形式;(2):在自己的程式中讀取並使用.tfrecords進行操作.

資料集轉換:為了便於講解,我們簡單製作了乙個資料,如下圖所示:

程式:[python]view plain

copy

import

tensorflow as tf  

import

numpy as np  

import

os  

from

pil 

import

image  

def_int64_feature(value):  

return

tf.train.feature(int64_list=tf.train.int64list(value=[value]))  

def_bytes_feature(value):  

return

tf.train.feature(bytes_list=tf.train.byteslist(value=[value]))  

defimg_to_tfrecord(data_path):  

rows = 256

cols = 256

depth = 3

writer = tf.python_io.tfrecordwriter('test.tfrecords'

)  labelfile=open("random.txt"

)  lines=labelfile.readlines()  

forline 

inlines:  

#print line

img_name = line.split(" ")[0

]#name

label = line.split(" ")[1

]#label

img_path = data_path+img_name  

img = image.open(img_path)  

img = img.resize((rows,cols))  

#img_raw = img.tostring()    

img_raw = img.tobytes()   

example = tf.train.example(features = tf.train.features(feature = ))  

writer.write(example.serializetostring())      

writer.close()   

if__name__ == 

'__main__'

:  current_dir = os.getcwd()      

data_path = current_dir + '/data/'

#name = current_dir + '/data'

print

('convert start'

)     

img_to_tfrecord(data_path)  

print

('done!'

)  

執行該段程式可以看到在dataset_tfrecord資料夾下面有test.tfrecord檔案生成。

在tf的session中呼叫這個生成的檔案

[python]view plain

copy

#encoding=utf-8 

# 設定utf-8編碼,方便在程式中加入中文注釋.

import

os  

import

scipy.misc  

import

tensorflow as tf  

import

numpy as np  

from

test 

import

*  import

matplotlib.pyplot as plt  

defread_and_decode(filename_queue):  

reader = tf.tfrecordreader()  

_, serialized_example = reader.read(filename_queue)  

features = tf.parse_single_example(serialized_example,features = )  

image = tf.decode_raw(features['image_raw'

], tf.uint8)  

image = tf.reshape(image, [output_size, output_size, 3

])  

image = tf.cast(image, tf.float32)  

#image = image / 255.0

return

image  

data_dir = '/home/sanyuan/dataset_animal/dataset_tfrecords/'

filenames = [os.path.join(data_dir,'train%d.tfrecords'

% ii) 

forii 

inrange(

1)] 

#如果有多個檔案,直接更改這裡即可

filename_queue = tf.train.string_input_producer(filenames)  

image = read_and_decode(filename_queue)  

with tf.session() as sess:      

coord = tf.train.coordinator()  

threads = tf.train.start_queue_runners(coord=coord)  

fori 

inxrange(

2):  

img = sess.run([image])  

print

(img[

0].shape)  

# 設定batch_size等於1.每次讀出來只有一張圖

plt.imshow(img[0

])  

plt.show()  

coord.request_stop()  

coord.join(threads)  

程式到這裡就已經處理完成了,當然在decorde的過程中也是可以進行一些預處理操作的,不過建議還是在製作資料集的時候進行,tfrecord使用的是佇列的方式進行讀取資料,這個對於多執行緒操作來說還是很方便的,只需要設定好格式,每次直接讀取就可以了.

Tensorflow中dynamic rnn的用法

1 api介面dynamic rnn cell,inputs,sequence length none,initial state none,dtype none,parallel iterations none,swap memory false,time major false,scope no...

TensorFlow中遮蔽warning的方法

tensorflow的日誌級別分為以下三種 tf cpp min log level 1 預設設定,為顯示所有資訊 tf cpp min log level 2 只顯示error和warining資訊 tf cpp min log level 3 只顯示error資訊 所以,當tensorflow出...

Tensorflow中TFRecord格式介紹

由於資料的 複雜性以及每乙個樣例中的資訊較為豐富,從而需要一種統一的格式來儲存資料,然而在tensorflow中提供了tfreord的格式來統一輸入資料的格式。tfrecord檔案中的資料是通過tf.train.example protoclo buffer的格式儲存 tf.train.exampl...