TensorFlow 資料載入器

2021-10-02 12:53:25 字數 3166 閱讀 1843

自己看的

import numpy as np

import os

from pil import image

import random

# Silence TensorFlow's C++ backend logging. The variable name must be
# upper-case: environment-variable names are case-sensitive on POSIX, and
# TensorFlow looks up exactly "TF_CPP_MIN_LOG_LEVEL". Per TensorFlow's
# convention, '2' hides INFO and WARNING messages. It only takes effect if
# set before TensorFlow itself is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Dataset roots (relative paths; the loader resolves them against the
# current working directory). Each contains one sub-folder per class.
train_file = './data/train'

test_file = './data/test'

class data:
    """Load a 9-class image-folder dataset and serve random mini-batches.

    Expected on-disk layout (paths relative to the current working
    directory)::

        <train_file or test_file>/<class name>/<image files>

    where ``<class name>`` is one of ``class_names``; its position in that
    tuple is the integer label. Images are decoded with Pillow, converted to
    NumPy arrays and divided by 255.0; labels are one-hot lists of length
    ``classific``.

    ``__init__`` eagerly decodes every image (via ``get_epoch``), so the
    whole split is held in memory after construction.
    """

    # Folder name -> label index is the position in this tuple.
    class_names = (
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "horse",
        "ship",
        "truck",
    )

    def __init__(self, batch_size, istrain=True):
        """Index the selected split on disk and preload all of its images.

        Args:
            batch_size: number of samples returned by ``get_epoch_batch``
                and ``get_batch``.
            istrain: read from ``train_file`` when true, otherwise from
                ``test_file``.
        """
        self.batch_size = batch_size
        self.classific = len(self.class_names)  # number of categories (9)
        self.istrain = istrain
        # (image_path, label_index) for every image found on disk.
        self.ls_imgs_labels = self.get_imgs_labels()
        # Every image decoded up front, with matching one-hot labels.
        self.epoch_imgs, self.epoch_labels = self.get_epoch()

    def get_classific(self):
        """Return the number of classes."""
        return self.classific

    def get_imgs_labels(self):
        """Collect ``(image_path, label_index)`` pairs for the selected split.

        Entries whose name is not in ``class_names`` (e.g. stray files such
        as ``.DS_Store``) are reported and skipped instead of aborting the
        whole scan. Directory listings are sorted so the result does not
        depend on ``os.listdir``'s arbitrary ordering.

        Returns:
            A list of ``(absolute_image_path, label_index)`` tuples.
        """
        if self.istrain:
            folder = train_file
            print("訓練檔案")
        else:
            folder = test_file
            print("測試檔案")
        # Same result as joining os.getcwd() with the normalized relative path.
        folder = os.path.abspath(folder)

        ls_imgs = []
        for files in sorted(os.listdir(folder)):
            if files not in self.class_names:
                print("沒有這個種類 :", files)
                continue
            label = self.class_names.index(files)
            curr_folder = os.path.join(folder, files)
            for pic in sorted(os.listdir(curr_folder)):
                ls_imgs.append((os.path.join(curr_folder, pic), label))

        print("全部記錄完成 %d" % len(ls_imgs))
        return ls_imgs

    def _one_hot(self, index):
        """Return a one-hot list of length ``classific`` with 1 at ``index``."""
        label = [0] * self.classific
        label[index] = 1
        return label

    def _load_sample(self, entry):
        """Decode one ``(path, label_index)`` entry into ``(pixels, one_hot)``.

        ``pixels`` is ``np.asarray(image) / 255.0``, a float array (values in
        [0, 1] assuming 8-bit channels -- not checked here).
        """
        # Pillow's package is named ``PIL``; imported locally so this loader
        # does not depend on how the module-level import is spelled.
        from PIL import Image

        # ``with`` closes the underlying file as soon as pixels are read.
        with Image.open(entry[0]) as img:
            pixels = np.asarray(img) / 255.0
        return pixels, self._one_hot(entry[1])

    def _random_indices(self, length):
        """Draw ``batch_size`` indices uniformly from ``range(length)``.

        Sampling is with replacement, so a batch may repeat samples.

        Raises:
            ValueError: if samples are requested from an empty dataset.
        """
        if length == 0 and self.batch_size > 0:
            raise ValueError("cannot draw a batch from an empty dataset")
        return [random.randrange(length) for _ in range(self.batch_size)]

    def get_epoch(self):
        """Decode every image listed in ``ls_imgs_labels``.

        Returns:
            ``(imgs, labels)``: parallel lists of pixel arrays and one-hot
            labels, in ``ls_imgs_labels`` order.
        """
        print("begin load data")
        imgs = []
        labels = []
        for entry in self.ls_imgs_labels:
            img, label = self._load_sample(entry)
            imgs.append(img)
            labels.append(label)
        print("data load successful")
        return imgs, labels

    def get_epoch_batch(self):
        """Pick ``batch_size`` random samples from the preloaded epoch.

        Returns:
            ``(curr_imgs, curr_labels)``: parallel lists taken from
            ``epoch_imgs`` / ``epoch_labels`` (with replacement).

        Raises:
            ValueError: if no images were loaded.
        """
        indices = self._random_indices(len(self.epoch_labels))
        curr_imgs = [self.epoch_imgs[i] for i in indices]
        curr_labels = [self.epoch_labels[i] for i in indices]
        return curr_imgs, curr_labels

    def get_batch(self):
        """Pick ``batch_size`` random image paths and decode them from disk.

        Unlike ``get_epoch_batch`` this re-reads the files on every call.

        Returns:
            ``(imgs, labels)``: parallel lists of pixel arrays and one-hot
            labels (sampled with replacement).

        Raises:
            ValueError: if no image paths were found.
        """
        indices = self._random_indices(len(self.ls_imgs_labels))
        imgs = []
        labels = []
        for i in indices:
            img, label = self._load_sample(self.ls_imgs_labels[i])
            imgs.append(img)
            labels.append(label)
        return imgs, labels

TensorFlow 載入資料

1. Reader：`tf.TextLineReader` 每次讀取一行。閱讀器的 `read` 方法會輸出一個 key 來表徵輸入的檔案和其中的紀錄（對於除錯非常有用），同時得到一個字串標量；這個字串標量可以被一個或多個解析器或轉換操作解碼為張量，並構造成為樣本。file1.csv 內容：10010 1112 ...

TensorFlow 之資料的載入

載入資料：TensorFlow 作為符號程式設計框架，需要先構建資料流圖，再讀取資料，隨後再進行模型的訓練，所以其官網給出了三種載入資料的方式。1. 預載入資料：`x1 = tf.constant([2, 3, 4])`、`x2 = tf.constant([4, 0, 1])`、`y = tf.add(x1, x2)`。這種方法的缺點在於，將資料直...

TensorFlow 2 的資料載入

對於一些小型常用的資料集，TensorFlow 有相關的 API 可以呼叫：`keras.datasets`。經典資料集：1. Boston Housing（波士頓房價）；2. MNIST／Fashion-MNIST（手寫數字集／時尚服飾集）；3. CIFAR-10／CIFAR-100（物體影像分類）；4. IMDB（電影評價）。使用 tf.da...