關於Tensorflow 的資料讀取環節

2021-09-25 04:55:20 字數 4111 閱讀 4820

tensorflow讀取資料的一般方式有下面3種:

preloaded方法的簡單例子

1

import

tensorflow as tf23

"""定義常量

"""4 const_var = tf.constant([1, 2, 3])

5"""

定義變數

"""6 var = tf.variable([1, 2, 3])78

with tf.session() as sess:

9sess.run(tf.global_variables_initializer())

10print

(sess.run(var))

11print(sess.run(const_var))

feed方法

可以在tensorflow運算圖的過程中,將資料傳遞到事先定義好的placeholder中。方法是在呼叫session.run函式時,通過feed_dict引數傳入。簡單例子:

1

import

tensorflow as tf

2"""

定義placeholder

"""3 x1 =tf.placeholder(tf.int16)

4 x2 =tf.placeholder(tf.int16)

5 result = x1 +x2

6"""

定義feed_dict

"""7 feed_dict =

11"""

執行圖"""

12with tf.session() as sess:

13print(sess.run(result, feed_dict=feed_dict))

上面的兩個方法在面對大量資料時,都存在效能問題。這時候就需要使用到第3種方法,檔案讀取,讓tensorflow自己從檔案中讀取資料

從檔案中讀取資料

步驟:獲取檔名列表list

建立檔名佇列,呼叫tf.train.string_input_producer,引數包含:檔名列表,num_epochs【定義重複次數】,shuffle【定義是否打亂檔案的順序】

定義對應檔案的閱讀器》* tf.readerbase >* tf.tfrecordreader >* tf.textlinereader >* tf.wholefilereader >* tf.identityreader >* tf.fixedlengthrecordreader

解析器 >* tf.decode_csv >* tf.decode_raw >* tf.image.decode_image >* …

預處理,對原始資料進行處理,以適應network輸入所需

生成batch,呼叫tf.train.batch() 或者 tf.train.shuffle_batch()

prefetch【可選】使用預載入佇列slim.prefetch_queue.prefetch_queue()

啟動填充佇列的執行緒,呼叫tf.train.start_queue_runners

圖引用自

讀取檔案格式舉例

tensorflow支援讀取的檔案格式包括:csv檔案,二進位制檔案,tfrecords檔案,影象檔案,文字檔案等等。具體使用時,需要根據檔案的不同格式,選擇對應的檔案格式閱讀器,再將檔名隊列傳為引數,傳入閱讀器的read方法中。方法會返回key與對應的record value。將value交給解析器進行解析,轉換成網路能進行處理的tensor。

csv檔案讀取:

解析器:tf.decode_csv

1 filename_queue = tf.train.string_input_producer(["

file0.csv

", "

file1.csv"])

2"""

閱讀器"""

3 reader =tf.textlinereader()

4 key, value =reader.read(filename_queue)

5"""

解析器"""

6 record_defaults = [[1], [1], [1], [1]]

7 col1, col2, col3, col4 = tf.decode_csv(value, record_defaults=record_defaults)

8 features = tf.concat([col1, col2, col3, col4], axis=0)910

with tf.session() as sess:

11 coord =tf.train.coordinator()

12 threads = tf.train.start_queue_runners(coord=coord)

13for i in range(100):

14 example =sess.run(features)

15coord.request_stop()

16 coord.join(threads)

二進位制檔案讀取:

解析器:tf.decode_raw

影象檔案讀取:

解析器:tf.image.decode_image, tf.image.decode_gif, tf.image.decode_jpeg, tf.image.decode_png

tfrecords檔案讀取

tfrecords檔案是tensorflow的標準格式。要使用tfrecords檔案讀取,事先需要將資料轉換成tfrecords檔案,具體可察看:convert_to_records.py 在這個指令碼中,先將資料填充到tf.train.example協議記憶體塊(protocol buffer),將協議記憶體塊序列化為字串,再通過tf.python_io.tfrecordwriter寫入到tfrecords檔案中去。

解析器:tf.parse_single_example

又或者使用slim提供的簡便方法:slim.dataset.data以及slim.dataset_data_provider.datasetdataprovider方法

1

defget_split(record_file_name, num_sampels, size):

2 reader =tf.tfrecordreader

34 keys_to_features =

1011 items_to_handlers =

1617 decoder =slim.tfexample_decoder.tfexampledecoder(

18keys_to_features, items_to_handlers19)

20return

slim.dataset.dataset(

21 data_sources=record_file_name,

22 reader=reader,

23 decoder=decoder,

24 items_to_descriptions={},

25 num_samples=num_sampels26)

2728

29def get_image(num_samples, resize, record_file="

image.tfrecord

", shuffle=false):

30 provider =slim.dataset_data_provider.datasetdataprovider(

31get_split(record_file, num_samples, resize),

32 shuffle=shuffle33)

34 [data_image] = provider.get(["

image"])

35return data_image

關於tensorflow的學習

import tensorflow as tf import numpy as np x tf.placeholder tf.int32,shape 3,3 y tf.matmul x,x z tf.reduce sum y,1 with tf.session as sess rand array ...

關於tensorflow的碎片

1 突然間視覺化tensorflow報錯 importerror cannot import name monitoring 解決 pip uninstall tensorflow estimator pip install iv tensorflow estimator 1.13.02 tenso...

關於TensorFlow安裝

筆者今日對tensorflow產生濃厚興趣,但在安轉過程中遇到一系列問題,去找資料無奈發現基本上都是複製別人的答案,到最後並沒有解決問題,於是去逛了一些國外的 捯飭了許久終於安裝成功了,下面介紹一下。一開始是檢查一下自己的python版本,在terminal輸入 python顯示 筆者為2.7.12...