13 tensorflow資料集操作

2021-10-24 07:18:26 字數 3216 閱讀 1174

功能

函式**

載入資料集

datasets.dataset_name.load_data()

構建 dataset 物件

tf.data.dataset_name.from_tensor_slices((x, y))

隨機打散

dataset_name.shuffle(buffer_size)

批訓練dataset_name.batch(size)

資料預處理

dataset_name.map(func_name)

資料集datatset_name

型別boston housing

波士頓房價趨勢

cifar10/100

資料集mnist/fashion_mnist

手寫數字

imdb

文字分類

資料集快取在使用者目錄下的.keras/datasets 資料夾

import tensorflow as tf

from tensorflow.keras import datasets

(x,y)

,(x_text,y_text)

= datasets.mnist.load_data(

)print

(x.shape)

print

(y.shape)

print

(x_text.shape)

print

(y_text.shape)

out:

(60000,28

,28)(

60000,)

(10000,28

,28)(

10000

,)

資料載入進入記憶體後,需要轉換成 dataset 物件, 才能利用 tensorflow 提供的各種操作

import tensorflow as tf

from tensorflow.keras import datasets

(x,y)

,(x_text,y_text)

= datasets.mnist.load_data(

)print

(x.shape)

print

(y.shape)

print

(x_text.shape)

print

(y_text.shape)

train_db = tf.data.dataset.from_tensor_slices(

(x, y)

)print

(train_db)

out:

(60000,28

,28)(

60000,)

(10000,28

,28)(

10000,)

28,28)

,())

, types:

(tf.uint8, tf.uint8)

>

import tensorflow as tf

from tensorflow.keras import datasets

(x,y)

,(x_text,y_text)

= datasets.mnist.load_data(

)print

(x.shape)

print

(y.shape)

print

(x_text.shape)

print

(y_text.shape)

train_db = tf.data.dataset.from_tensor_slices(

(x, y)

)td = train_db.shuffle(

500)

print

(td)

out:

(60000,28

,28)(

60000,)

(10000,28

,28)(

10000,)

28,28)

,())

, types:

(tf.uint8, tf.uint8)

>

import tensorflow as tf

from tensorflow.keras import datasets

(x,y)

,(x_text,y_text)

= datasets.mnist.load_data(

)train_db = tf.data.dataset.from_tensor_slices(

(x, y)

)train_db = train_db.batch(

100)

print

(train_db)

out:

none,28

,28),

(none,)

), types:

(tf.uint8, tf.uint8)

>

import tensorflow as tf

from tensorflow.keras import datasets

(x,y)

,(x_text,y_text)

= datasets.mnist.load_data(

)train_db = tf.data.dataset.from_tensor_slices(

(x, y)

)def

func_name

(x,y)

: x = tf.cast(x, dtype=tf.float32)

/255

. x = tf.reshape(x,[-

1,28*

28]) y = tf.cast(y, dtype=tf.int32)

y = tf.one_hot(y, depth=10)

return x , y

train_db = train_db.

map(func_name)

print

(train_db)

out:

1,784),(

10,))

, types:

(tf.float32, tf.float32)

>

tensorflow學習筆記13

訓練神經網路3 問題解決 上面定義與下面呼叫的引數不一致,導致出現了錯誤 import numpy as np import tensorflow as tf import matplotlib.pyplot as plt import input data mnist input data.rea...

用TensorFlow實現iris資料集線性回歸

本文將遍歷批量資料點並讓tensorflow更新斜率和y截距。這次將使用scikit learn的內建iris資料集。特別地,我們將用資料點 x值代表花瓣寬度,y值代表花瓣長度 找到最優直線。選擇這兩種特徵是因為它們具有線性關係,在後續結果中將會看到。本文將使用l2正則損失函式。用tensorflo...

Tensorflow中資料集的相關操作

在資料集框架中,每乙個資料集代表乙個資料 其資料 有一下幾種 張量,tfrecord檔案,文字檔案,sharding檔案等等。一.資料集dataset的常用構造方法 1 從乙個tensor中構造資料集 dataset tf.data.dataset.from tensor slices tensor...