tensorflow 2 1 自定義訓練

2021-10-10 08:00:53 字數 854 閱讀 8266

常常會遇到自定義網路結構的情況,自定結構後往往會有多個輸入,或者還需要自定義loss或者accuracy函式,那麼keras的fit就無法使用了,

因此需要自定義訓練步驟

下面則自定義一次batch的訓練步驟,包含了計算loss,accuracy和梯度下降。

tensorflow2.0 主推eager模式,那麼tf.gradienttape則是eager模式下的利器,自動計算梯度並傳遞

最後別忘了加@tf.function進行封裝,使train函式在tensorflow框架下加速執行

然後就可以把封裝好的train函式應用到每個batch data上,開始訓練

@tf.function

def train_step(input1, input2, target):

loss = 0

acc = 0

with tf.gradienttape() as tape: # 開啟自動梯度

predictions = model(input1, input2) # 獲取model中call函式的輸出

loss += loss_func(target, predictions) # 計算loss

acc += train_acc(target, predictions) # 計算accuracy

variables = model.trainable_variables # 獲取model的所有可訓練引數,好進行梯度更新

gradients = tape.gradient(loss, variables) # 將loss函式及可訓練引數傳入得到梯度

return loss, acc

tensorflow2 1安裝指南

開啟anconda prompt 建立conda虛擬環境 用create n 新建乙個名叫tf2.1的環境用python3.7版本 conda create n tf2.1 python 3.7 進入tensorflow2.1環境 conda activate tf2.1 安裝英偉達的sdk10.1...

tensorflow2 1的維度變換

函式的作用是將tensor變換為引數shape的形式,其中shape為乙個列表形式,特殊的一點是列表中可以存在 1 1代表的含義是不用我們自己指定這一維的大小,函式會自動計算,但列表只能存在乙個 1。如果存在多個 1,就是乙個存在多解的方程 a tf.random.normal 4 28,28 3 ...

TensorFlow2 1張量排序

排序函式tf.sort 用法 tf.sort values,axis 1 direction ascending name none 引數說明 排序的座標tf.argsort 返回張量的索引,該張量給出沿軸的排序順序。用法 tf.argsort values,axis 1 direction asc...