Tensorflow訓練過程中validation

2021-10-14 07:41:58 字數 1591 閱讀 5073

tensorflow因為靜態圖的原因,邊train邊validation的過程相較於pytorch來說複雜一些。

分別獲取訓練集和驗證集的資料。我這裡使用的是從tfrecoed讀入資料。

# training data

img_name_batch_train, img_batch_train, gtboxes_and_label_batch_train, num_objs_batch_train, img_h_batch_train, img_w_batch_train = \

next_batch(dataset_name = ***,..

., is_training =

true

)# validation data

img_name_batch_val, img_batch_val, gtboxes_and_label_batch_val, num_objs_batch_val, img_h_batch_val, img_w_batch_val = \

next_batch(dataset_name = ***,..

., is_training =

false

)

注意is training

is_trainging = tf.placeholder(tf.

bool

, shape=()

)

用乙個tf.placeholder來控制是否訓練、驗證。

使用這種方式就可以在乙個graph裡建立乙個分支條件,從而通過控制placeholder來控制是否進行驗證。

img_name_batch, img_batch, gtboxes_and_label_batch, num_objs_batch, img_h_batch, img_w_batch = \

tf.cond(is_training,

lambda

:(img_name_batch_train, img_batch_train, gtboxes_and_label_batch_train, num_objs_batch_train, img_h_batch_train, img_w_batch_train)

,lambda

:(img_name_batch_val, img_batch_val, gtboxes_and_label_batch_val, num_objs_batch_val, img_h_batch_val, img_w_batch_val)

)

如果不適用tf.cond(),會在原圖上新增上許多新的結點,這些結點的引數都是需要重新初始化的,也是就是說,驗證的時候並不是使用訓練的權重。

_, global_stepnp, total_loss_dict_ = sess.run(

[train_op, global_step, total_loss_dict]

, feed_dict =

)val_loss_list =

total_loss_dict_ = sess.run(total_loss_dict_, feed_dict=

)

的訓練過程 模型訓練過程中累計auc

在平時計算auc的時候,大都是使用 sklearn.metrics.roc auc score 來計算。一般操作是將每個batch 出來的結果 拼接起來,然後扔到該函式中計算。但是如果測試集量級過大 比如 10億量級 每個樣本的 結果拼接起來之後至少需要 3g記憶體。這個開銷顯然不是我們想要的。有什...

TensorFlow訓練過程遇到的問題

第一次自己實現乙個完整程式,遇到不少坑。等我這個程式搞完,我要入坑pytorch。我的程式是乙個分類程式,剛開始的時候訓練不收斂,accuracy基本為零,輸出只輸出一類。愁的我腦仁疼。解決辦法 對輸入加乙個正則化。不得不說,微調對訓練過程太重要啦。只是我訓練過程中遇到的第二大坑!為了解決這個問題,...

Tensorboard 訓練過程中的資料視覺化

作業系統 centos7.0 gpu quadro p5000 tensorflow 1.10 cuda version 10.1 python3.6.2 使用tensorboard的基本思想是 在tensorflow程式裡輸出smmary到某個目錄,然後用啟動tensorboard,指定剛才的目錄...