tensorflow神經網路訓練流程

2021-08-28 05:06:14 字數 2065 閱讀 2935

from __future__ import print_function

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

mnist=input_data.read_data_sets('/data/machine_learning/mnist/',one_hot=true)

x=tf.placeholder("float",[none,num_input])

y=tf.placeholder("float",[none,num_classes])

#基本引數設定

learning_rate=0.1

num_steps=500

batch_size=128

display_step=100

#網路引數設定

n_hidden_1=256 #第乙個隱藏層的神經元

n_hidden_2=256 #第二個隱藏層的神經元

num_input=784 #輸入的特徵數量

num_classes=10 #標籤數

#權重和偏置

weights=

biases=

#建立模型

def neural_net(x):

#第乙個隱藏的全連線層

layer_1=tf.add(tf.matmul(x,weights['h1']),biases['b1'])

#第二個隱藏的全連線層

layer_2=tf.add(tf.matmul(layer_1,weights['h2']),biases['b2'])

#輸出層

out_layer=tf.matmul(layer_2,weights['out'])+biases['out']

return out_layer

logits=neural_net(x)

prediction=tf.nn.softmax(logits)

#定義損失和優化器

loss_op=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=y))#交叉熵損失

optimizer=tf.train.adamoptimizer(learning_rate=learning_rate) #優化器

train_op=optimizer.minimize(loss_op)#最小化損失

correct_pred=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))

accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))

init=tf.global_variables_initializer()

with tf.session() as sess:

sess.run(init)

for step in range(1,num_steps+1):

batch_x,batch_y=mnist.train.next_batch(batch_size)

sess.run(train_op,feed_dict=)

if step % display_step == 0 or step == 1:

# calculate batch loss and accuracy

loss, acc = sess.run([loss_op, accuracy], feed_dict=)

print("step " + str(step) + ", minibatch loss= " + \

"".format(loss) + ", training accuracy= " + \

"".format(acc))

print("optimization finished!")

# calculate accuracy for mnist test images

print("testing accuracy:", \

sess.run(accuracy, feed_dict=))

Tensorflow卷積神經網路

卷積神經網路 convolutional neural network,cnn 是一種前饋神經網路,在計算機視覺等領域被廣泛應用.本文將簡單介紹其原理並分析tensorflow官方提供的示例.關於神經網路與誤差反向傳播的原理可以參考作者的另一篇博文bp神經網路與python實現.卷積是影象處理中一種...

Tensorflow 深層神經網路

維基百科對深度學習的定義 一類通過多層非線性變換對高複雜性資料建模演算法的合集.tensorflow提供了7種不同的非線性啟用函式,常見的有tf.nn.relu,tf.sigmoid,tf.tanh.使用者也可以自己定義啟用函式.3.1.1 交叉熵 用途 刻畫兩個概率分布之間的距離,交叉熵h越小,兩...

Tensorflow(三) 神經網路

1 前饋傳播 y x w1 b1 w2 b2 import tensorflow as tf x tf.constant 0.9,0.85 shape 1,2 w1 tf.variable tf.random normal 2,3 stddev 1,seed 1 name w1 w2 tf.vari...