Tensorflow的佇列多執行緒讀取資料

2021-09-12 02:07:25 字數 2584 閱讀 8656

在tensorflow中,有三種方式輸入資料

1利用feed_dict送入numpy陣列

2利用佇列從檔案中直接讀取資料

3預載入資料

其中第一種方式很常用,在tensorflow的mnist訓練原始碼中可以看到,通過feed_dict={},可以將任意資料送入tensor中。

第二種方式相比於第一種,速度更快,可以利用多執行緒的優勢把資料送入佇列,再以batch的方式出隊,並且在這個過程中可以很方便地對影象進行隨機裁剪、翻轉、改變對比度等預處理,同時可以選擇是否對資料隨機打亂,可以說是非常方便。該部分的原始碼在tensorflow官方的cifar-10訓練原始碼中可以看到,但是對於剛學習tensorflow的人來說,比較難以理解,本篇部落格就當成我除錯完成後寫的一篇總結,以防自己再忘記具體細節。

path = 'e:\dataset\cifar-10\cifar-10-batches-py'

# extract train examples

num_train_examples = 50000

x_train = np.empty((num_train_examples, 32, 32, 3), dtype='uint8')

y_train = np.empty((num_train_examples), dtype='uint8')

for i in range(1, 6):

fpath = os.path.join(path, 'data_batch_' + str(i))

(x_train[(i - 1) * 10000: i * 10000, :, :, :], y_train[(i - 1) * 10000: i * 10000]) = load_and_decode(fpath)

# extract test examples

fpath = os.path.join(path, 'test_batch')

x_test, y_test = load_and_decode(fpath)

return x_train, y_train, x_test, np.array(y_test)

其中load_and_decode函式只需要按照cifar-10官網給出的方式decode就行,最終返回的x_train是乙個[50000, 32, 32, 3]的ndarray,但對於ndarray來說,進行預處理就要麻煩很多,為了取mini-sgd的batch,還自己寫了乙個類,通過呼叫train_set.next_batch()函式來取,總而言之就是什麼都要自己動手,效率確實不高

但對於第二種方式,讀取起來就要麻煩很多,但使用起來,又快又方便

以讀取cifar10為例。

**#1、讀取檔案,生成檔名列表**

path = 'e:\dataset\cifar-10\cifar-10-batches-py'

filenames = [os.path.join(path, 'data_batch_%d' % i) for i in range(1, 6)]

**#2、利用tf.train.string_input_producer函式生成乙個讀取佇列**

filename_queue = tf.train.string_input_producer(filenames)

def read_cifar10(filename_queue):

label_bytes = 1

image_size = 32

channels = 3

image_bytes = image_size*image_size*3

record_bytes = label_bytes+image_bytes

# **3、定義乙個 reader。**

#若讀取列表中為單獨檔案則用tf.wholefilereader()

reader = tf.fixedlengthrecordreader(record_bytes)

key, value = reader.read(filename_queue)

record_bytes = tf.decode_raw(value, tf.uint8)

label = tf.strided_slice(record_bytes, [0], [label_bytes])

depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],

[label_bytes + image_bytes]),

[channels, image_size, image_size])

image = tf.transpose(depth_major, [1, 2, 0])

return image, label #tensor格式

定義乙個reader,來讀取固定長度的資料,這個固定長度是由cifar-10資料集的儲存格式決定的,1byte的標籤加上32 *32 *3長度的影象,3代表rgb三通道,由於的是按[channel, height, width]的格式儲存的,為了變為常用的[height, width, channel]維度,需要在17行reshape一次影象,最終我們提取出了一副完整的影象與對應的標籤。

ThreadPoolExecutor 多執行緒

from concurrent.futures import threadpoolexecutor,wait,all completed from queue import queue myqueue queue 佇列,用於儲存函式執行結果。多執行緒的問題之一 如何儲存函式執行的結果。def thr...

c 多線例項

using system using system.threading using system.text namespace controlthread 第二個執行緒正在執行,請輸入 s uspend,r esume,i nterrupt,or e xit.datetime.now.tostrin...

CLLocationManager在多執行緒下使用

似乎定位的返回 呼叫 只能有主線程來呼叫,並且這個物件還必須是在主線程建立的。做過以下實驗 1.子執行緒中 self.locationmanager cllocationmanager alloc init autorelease locationmanager.delegate self loca...