tensorflow的三種資料輸入

2021-08-24 20:40:42 字數 4546 閱讀 5613

tensorflow的資料讀取一共有三種方式

供給資料(feeding):在tensorflow程式執行的每一步,讓python**來供給資料

從檔案讀取資料:在tensorflow圖的起始,讓乙個輸入管線從檔案中讀取資料

預載入資料:在tensorflow圖中定義常量或變數來儲存所有資料(僅僅適用於資料量比較小的情況)

tensorflow的資料供給機制允許你在tensorflow運算圖中將資料注入到任意張量中,因此,python運算可以把資料直接設定到tensorflow圖中。然而卻需要設定placeholder節點,通過run()函式輸入feed_dict引數,可以啟動運算過程。placeholder節點被宣告的時候是未初始化的,也不包含資料,如果沒有為它供給資料,則tensorflow運算的時候會產生錯誤。

在訓練mnist手寫字型識別時就使用到了feed_dict輸入資料,部分**如下。

完整**見:

首先,使用tf.train.string_input_producer()函式產生乙個先入先出的佇列queueu, 如上圖所示,此操作是將檔名堆入佇列中。函式格式為tf.train.string_input_producer(string_tensor,num_epochs=none,shuffle=true),num_epochs和shuffer兩個可配置引數設定最大的訓練迭代次數和檔名亂序,shuffle預設為true,會對檔名進行亂序處理。

filename = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]

# 建立檔案佇列,不限制讀取的數量,所以沒有設定num_epochs

filename_queue = tf.train.string_input_producer(filename)

其次,建立檔案閱讀器reader從佇列中取檔名並讀取資料, 不同reader對應不同的檔案結構。我們以cifar-10二進位制資料集為例,使用tf.fixedlengthrecordreader函式從二進位制檔案中讀取固定長度資料。接下來,使用reader的read方法從上述建立的檔案佇列filename_queue中讀取資料,並用tf.decode_raw()函式將讀取的value值轉換成乙個uint8的張量,然後就可以通過切片和轉換得到需要的格式。即最後得到上述**的example queue。

reader = tf.fixedlengthrecordreader(record_bytes=record_bytes)

result.key, value = reader.read(filename_queue)

# decode_raw操作將乙個字串轉換成乙個uint8的張量

record_bytes = tf.decode_raw(value, tf.uint8)

# tf.strides_slice(input, begin, end, strides=none)擷取[begin, end)之間的資料

result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],[label_bytes+image_byte]), [result.depth, result.height, result.width])

# convert from [depth, height, width] to [height, width, depth]

result.uint8image = tf.transpose(depth_major, [1, 2, 0])

最後,進行批處理,從example queue中批量取出樣本,使用tf.train.shuffle_batch來實現,返回乙個batch_size大小的樣本和樣本標籤。

min_fraction_of_examples_in_queue = 0.4

min_queue_examples =int(num_examples_per_epoch_for_train * min_fraction_of_examples_in_queue)

image_batch, label_batch = tf.train.shuffle_batch([float_image, read_input.label], batch_size=batch_size, capacity= min_queue_examples + 3 * batch_size, min_after_dequeue=min_queue_examples)

在訓練步驟執行之前,需要呼叫tf.train.start_queue_runner()函式啟動輸入管道的執行緒,填充樣本到樣本佇列中,以便出隊操作可以從佇列中拿到樣本。和tf.train.coordinator()配合使用,當有錯誤時,它會完全關閉掉開啟的threads。

with tf.session() as session:

session.run(tf.global_variables_initializer())

# 建立乙個執行緒協調器,用來管理session中啟動的所有執行緒

coord = tf.train.coordinator()

threads = tf.train.start_queue_runners(sess=session, coord=coord)

for index in range(epoches):

_, loss_value, accuracy_value, summary = session.run([t_optimizer, t_loss, t_accuracy, merged])

if index % 1000 == 0:

print('index:', index, ' loss_value:', loss_value, ' accuracy_value:', accuracy_value)

# 終止所有執行緒的命令

coord.request_stop()

# 把threads加入主線程,等到threads結束

coord.join(threads)

coordinator類用來管理session中的多個執行緒,可以用來同時停止多個工作執行緒並且向等待所有工作程序終止的執行緒報告異常, 此執行緒捕獲到異常之後就會終止所有的執行緒。

cifar-10的完整程式在:

Tensorflow載入資料的三種方式

tensorflow作為符號程式設計框架,需要先構建資料流圖,再讀取資料,然後再進行訓練。tensorflow提供了以下三種方式來載入資料 預載入資料 preloaded data 在tensorflow圖中定義常量或變數來儲存所有資料 填充資料 feeding python產生資料,再把資料填充到...

EF的三種資料載入方式

ef的關聯實體載入有三種方式 lazy loading,eager loading,explicit loading,其中lazy loading和explicit loading都是延遲載入。一 延遲載入 預設 lazy loading使用的是動態 預設情況下,如果poco類滿足以下兩個條件,ef...

WPF ListView 的三種資料繫結方式

1.最原始的繫結方式 public observablecollectionobservableobj public mainwindow observableobj.add new observableobj.add new observableobj.add new observableobj....