tensorflow教程學習三深入MNIST

2021-08-04 15:37:17 字數 4405 閱讀 1966

#載入資料

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('mnist_data', one_hot=true)

#我們使用interactivesession類可以更加靈活地構建**

#它能讓你在執行圖的時候,插入一些計算圖,這些計算圖是由某些操作(operations)構成的。

import tensorflow as tf

sess = tf.interactivesession()

#為輸入影象和目標輸出類別建立節點來開始構建計算圖

x = tf.placeholder("float", shape=[none, 784])

y_ = tf.placeholder("float", shape=[none, 10])

w = tf.variable(tf.zeros([784,10]))

b = tf.variable(tf.zeros([10]))

sess.run(tf.initialize_all_variables())

#我們把向量化後的x和權重矩陣w相乘,加上偏置b,然後計算每個分類的softmax概率值。

y = tf.nn.softmax(tf.matmul(x,w) + b)

#可以很容易的為訓練過程指定最小化誤差用的損失函式,我們的損失函式是目標類別和**類別之間的交叉熵。

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

#用最速下降法讓交叉熵下降,步長為0.01.

train_step = tf.train.gradientdescentoptimizer(0.01).minimize(cross_entropy)

#每一步迭代,我們都會載入50個訓練樣本,然後執行一次train_step,

#並通過feed_dict將x 和 y_張量佔位符用訓練訓練資料替代。

for i in range(1000):

batch = mnist.train.next_batch(50)

train_step.run(feed_dict=)

#首先讓我們找出那些**正確的標籤。tf.argmax 是乙個非常有用的函式,

#它能給出某個tensor物件在某一維上的其資料最大值所在的索引值。由於標籤向量是由0,1組成,

#因此最大值1所在的索引位置就是類別標籤,比如tf.argmax(y,1)返回的是模型對於任一輸入x**到的標籤值,

#而tf.argmax(y_,1)代表正確的標籤,我們可以用tf.equal來檢測我們的**是否真實標籤匹配(索引位置一樣表示匹配)。

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

#這裡返回乙個布林陣列。為了計算我們分類的準確率,我們將布林值轉換為浮點數來代表對、錯,然後取平均值。

#例如:[true, false, true, true]變為[1,0,1,1],計算出平均值為0.75。

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

print (accuracy.eval(feed_dict=))

#權重初始化

defweight_variable

(shape):

initial = tf.truncated_normal(shape, stddev=0.1)

return tf.variable(initial)

defbias_variable

(shape):

initial = tf.constant(0.1, shape=shape)

return tf.variable(initial)

#卷積和池化

defconv2d

(x, w):

return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='same')

defmax_pool_2x2

(x):

return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],

strides=[1, 2, 2, 1], padding='same')

#第一層卷積

#卷積在每個5x5的patch中算出32個特徵。卷積的權重張量形狀是[5, 5, 1, 32],

#前兩個維度是patch的大小,接著是輸入的通道數目,最後是輸出的通道數目。

w_conv1 = weight_variable([5, 5, 1, 32])

#對於每乙個輸出通道都有乙個對應的偏置量

b_conv1 = bias_variable([32])

#為了用這一層,我們把x變成乙個4d向量,其第2、第3維對應的寬、高,最後一維代表的顏色通道數

#(因為是灰度圖所以這裡的通道數為1,如果是rgb彩色圖,則為3)。

x_image = tf.reshape(x, [-1,28,28,1])

#把x_image和權值向量進行卷積,加上偏置項,然後應用relu啟用函式,最後進行max pooling

h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)

h_pool1 = max_pool_2x2(h_conv1)

#第二層卷積

#把幾個類似的層堆疊起來。第二層中,每個5x5的patch會得到64個特徵。

w_conv2 = weight_variable([5, 5, 32, 64])

b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)

h_pool2 = max_pool_2x2(h_conv2)

#密集連線層

#現在,尺寸減小到7x7,我們加入乙個有1024個神經元的全連線層,用於處理整個。

#我們把池化層輸出的張量reshape成一些向量,乘上權重矩陣,加上偏置,然後對其使用relu

w_fc1 = weight_variable([7 * 7 * 64, 1024])

b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])

h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

#為了減少過擬合,我們在輸出層之前加入dropout

keep_prob = tf.placeholder("float")

h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

#輸出層

w_fc2 = weight_variable([1024, 10])

b_fc2 = bias_variable([10])

#訓練和評估模型

y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))

#用更加複雜的adam優化器來做梯度最速下降,在feed_dict中加入額外的引數keep_prob來控制dropout比例

train_step = tf.train.adamoptimizer(1e-4).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

sess.run(tf.initialize_all_variables())

for i in range(20000):

batch = mnist.train.next_batch(50)

if i%100 == 0:

train_accuracy = accuracy.eval(feed_dict=)

print ("step %d, training accuracy %g"%(i, train_accuracy))

train_step.run(feed_dict=)

print ("test accuracy %g"%accuracy.eval(feed_dict=))

最後附一張結果圖,可以看出準確率上公升速度很快

tensorflow學習筆記三

分布式tensorflow就是多台伺服器參加乙個tensorflow圖的分布式執行,分布式我感覺就是原來在一台計算機上面執行好幾個程序這些程序互動是由os控制的,而分布式就是把這些程序放在了不同的機器上面執行,他們之間的互動是由分布式框架控制的,實際分布式的核心或者說基本點還是執行的程序。一提到分布...

tensorflow學習筆記(三)

1.是否列印裝置分配日誌 sess tf.session config tf.configproto log device placement true 2.如果指定的裝置不存在,是否允許tf自動分配裝置 sess tf.session config tf.configproto allow sof...

莫煩tensorflow系列教程學習

1.普通機器學習 函式係數 y 0.1x 0.3 coding gbk import tensorflow as tf import numpy as np 生成資料,y 0.1x 0.3 x data np.random rand 100 astype np.float32 y data x da...