TensorFlow輸入資料處理框架

2021-08-17 09:53:32 字數 2543 閱讀 3297

如圖,大致為輸入資料處理流程示意圖。輸入資料處理第一步為獲取儲存訓練資料的檔案列表,在該圖中檔案列表為。通過tf.train.string_input_producer函式可以選擇性將檔案順序打亂,並加入輸入佇列。tf.train.string_input_producer函式會生成並維護乙個輸入檔案佇列,不同執行緒中的檔案讀取函式可以共享這個檔案佇列。

在讀取樣例程式後,需要對影象進行預處理。預處理的過程也會通過tf.train.shuffle_batch提供的機制並行的跑在多個執行緒中。輸入資料處理流程的最後通過tf.train.shuffle_batch函式將處理好的單個輸入樣例整理成batch提供給神經網路輸入層。

import tensorflow as tf

#建立檔案列表

files = tf.train.match_filenames_once("records/output.tfrecords")

#建立檔案輸入佇列

filename_queue = tf.train.string_input_producer(files, shuffle=false)

# 讀取檔案。

# 解析資料。假設image是影象資料,label是標籤,height、width、channels給出了的維度

reader = tf.tfrecordreader()

_,serialized_example = reader.read(filename_queue)

# 解析讀取的樣例。

features = tf.parse_single_example(

serialized_example,

features=)

image, label = features['image'], features['label']

height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)

channels = tf.cast(features['channels'], tf.int32)

# 從原始影象中解析出畫素矩陣,並根據畫素尺寸還原影象

decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)

decoded_image.set_shape([height, width, channels])

#定義神經網路輸入層的大小

image_size = 299

# preprocess_for_train函式是對進行預處理的函式

distorted_image = preprocess_for_train(decoded_image, image_size, image_size,

none)

#將處理後的影象和標籤通過tf.train.shuffle_batch整理成神經網路訓練時需要的batch

min_after_dequeue = 10000

batch_size = 100

capacity = min_after_dequeue + 3 * batch_size

image_batch, label_batch = tf.train.shuffle_batch([images, labels],

batch_size=batch_size,

capacity=capacity,

min_after_dequeue=min_after_dequeue)

# 定義神經網路的結構及優化過程。image_batch可以作為輸入提供給神經網路的輸入層

#label_batch則提供了輸入batch中樣例的正確答案

logit = inference(image_batch)

loss = calc_loss(logit, label_batch)

train_step = tf.train.gradientdescentoptimizer(learning_rate).minimize(loss)

#宣告會話並執行神經網路優化過程

with tf.session() as sess:

#神經網路訓練準備工作,這些工作包括變數初始化、執行緒啟動

sess.run(

[tf.global_variables_initializer(),

tf.local_variables_initializer()])

coord = tf.train.coordinator()

threads = tf.train.start_queue_runners(coord=coord)

# 神經網路訓練過程

for i in range(training_rounds):

sess.run(train_step)

#停止所有執行緒

coord.request_stop()

coord.join()

其**如下:

tensorflow資料輸入

tensorflow的資料輸入採用佇列 執行緒的機制,這樣可以使得系統更加輕量。如例項 獲取資料的列表 image list,label list read image label list images tf.convert to tensor self.image list,dtype tf.s...

深度學習 TensorFlow 輸入資料處理框架

將mnist資料集中的所有訓練資料儲存到tfrecord檔案中 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data import numpy as np 生成整數型屬性 轉換型別 將輸入...

9 處理輸入

遊戲有很多輸入,如鍵盤,滑鼠,手柄等.sdl庫將這些處理變得非常簡單,我們這裡將這幾種輸入統一到一起.這裡不講太多,因為目前我們對輸入的處理就是檢測輸入裝置的狀態,來更新遊戲物件的狀態.看下 就行了,然後對於遊戲中的物件怎麼使用輸入裝置的狀態,自己發揮 inputhandler.h ifndef i...