tensorflow2 0的手寫數字識別

2021-10-25 03:01:55 字數 3914 閱讀 8583

import tensorflow as tf

import tensorflow.keras as keras

import tensorflow.keras.layers as layers

from tensorflow.keras import sequential

import matplotlib.pyplot as plt

import pandas as pd

from tensorflow.keras.utils import normalize

from tensorflow.keras.utils import to_categorical

import tensorflow.keras.optimizers as optimizers

import tensorflow.keras.losses as losses

import tensorflow.keras.metrics as metrics

(x_train, y_train)

,(x_test, y_test)

= tf.keras.datasets.mnist.load_data(

)x1_train = x_train.reshape(-1

,784

)#改變形狀,變一維

x1_test = x_test.reshape(-1

,784

)x_train_nor = normalize(x1_train, axis=1)

#歸一化

x_test_nor = normalize(x1_test, axis=1)

y_train_onehot = to_categorical(y_train)

#標籤轉為onehot型別

y_test_onthot = to_categorical(y_test)

network = sequential(

)network.add(layers.dense(units=

256, kernel_initializer=

'normal'

, activation=

'relu'))

network.add(layers.dropout(

0.5)

)network.add(layers.dense(units=

10,kernel_initializer=

'normal'

,activation=

'softmax'))

network.build(input_shape=(4

,784))

#此處的build是為了建立引數,所以需要設定

#設定優化方法, 損失函式, 以及評價指標

network.

compile

(optimizer=optimizers.adam(

), loss=losses.categoricalcrossentropy(

), metrics=

['accuracy'])

#設定訓練資料以及標籤, 驗證集佔比重, 迴圈次數, 批大小, 日誌級別

train_history = network.fit(x=x_train_nor, y=y_train_onehot, validation_split=

0.2, epochs=

5, batch_size=

200, verbose=

2)

其中verbose是日誌等級: verbose:日誌顯示

verbose = 0 為不在標準輸出流輸出日誌資訊

verbose = 1 為輸出進度條記錄

verbose = 2 為每個epoch輸出一行記錄

注意: 預設為 1

**錯誤的樣本位置:

df = pd.dataframe(

)print

(df[

(df.label==5)

&(df.predict==3)

])#檢視是5**錯成3的資料位置

tensorflow2 0視訊記憶體設定

遇到乙個問題 新買顯示卡視訊記憶體8g但是tensorflow執行的時候介面顯示只有約6.3g的視訊記憶體可用,如下圖 即限制了我的視訊記憶體,具體原因為什麼我也不知道,但原來的視訊記憶體小一些的顯示卡就沒有這個問題。目前的解決辦法是 官方文件解決 然後對應的中文部落格 總結一下,就是下面的兩個辦法...

Tensorflow2 0 啟用函式

常用啟用函式及對應特點 神經網路結構的輸出為所有輸入的加權和,這導致整個神經網路是乙個線性模型。而線性模型不能解決異或問題,且面對多分類問題,也顯得束手無策。所以為了解決非線性的分類或回歸問題,啟用函式必須是非線性函式。神經網路中啟用函式的主要作用是提供網路的非線性建模能力。這是因為反向傳播演算法就...

初步了解TensorFlow2 0

為什麼要學習tensorflow?深度學習能夠更好地抽取資料中的規律,從而給公司帶來更大的價值 tensorflow是強大且靈活的開源框架 使用廣泛 2.0更加強大 易用 成熟 tensorflow是什麼?是google的開源軟體庫 採用資料流圖,用於數值計算 支援多平台 gpu cpu 移動裝置 ...