Tensorflow實現MNIST資料自編碼 1

2021-08-18 22:11:37 字數 2445 閱讀 7985

自編碼網路能夠自學習樣本特徵的網路,屬於無監督學習模型的網路,可以從無標註的資料中學習特徵,它可以給出比原始資料更好的特徵描述,具有較強的特徵學習能力。

主要的網路結構就是高維特徵樣本---》編碼成---》低維特徵---》解碼回---》高維特徵,下面以mnist資料集為示例進行演示:

import

tensorflow as tf  

#匯入資料集合

from

tensorflow.examples.tutorials.mnist 

import

input_data  

mnist = input_data.read_data_sets('/data/'

,one_hot=

true

)  #整體流程,原始畫素28*28-784

#784-》256-》128-》128-》256-》784

learning_rate = 0.01

n_hidden_1 = 256

#第一層256個結點

n_hidden_2 = 128

#第二層128個結點

n_input = 784

x = tf.placeholder('float'

,[none

,n_input])  

y = x  

weights =   

biases =   

defencoder(x):  

layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_h1'

]),biases[

'encoder_b1'

]))  

layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['encoder_h2'

]),biases[

'encoder_b2'

]))  

return

layer_2  

defdecoder(x):  

layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x,weights['decoder_h1'

]),biases[

'decoder_b1'

]))  

layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['decoder_h2'

]),biases[

'decoder_b2'

]))  

return

layer_2  

pred = decoder(encoder(x))  

cost = tf.reduce_mean(tf.pow(y-pred,2

))  

optimizer = tf.train.gradientdescentoptimizer(learning_rate).minimize(cost)  

training_epochs = 20

#共迭代20次

batch_size = 256

#每次取256個樣本

display_step = 5

#迭代5次輸出一次資訊

#啟動會話

with tf.session() as sess:  

sess.run(tf.global_variables_initializer())  

total_batch = int(mnist.train.num_examples/batch_size)  

#開始訓練

forepoch 

inrange(training_epochs):  

fori 

inrange(total_batch):  

batch_xs,batch_ys = mnist.train.next_batch(batch_size)#取資料

_,c = sess.run([optimizer,cost],feed_dict=)#訓練模型

ifepoch % display_step == 0:

#輸出日誌資訊

print

("epoch:"

,'%4d'

% (epoch+1),

'cost=',""

.format(c))  

print

('training finished!'

)  correct_prediction = tf.equal(tf.argmax(pred,1

),tf.argmax(y,

1))  

accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'

))  

print

('accuracy:',1

-accuracy.eval())  

tensorflow教程學習三深入MNIST

載入資料 from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets mnist data one hot true 我們使用interactivesession類可以更加靈活地...

Python keras神經網路識別mnist

上次用matlab寫過乙個識別mnist的神經網路,位址在 這次又用keras做了乙個差不多的,畢竟,現在最流行的專案都是python做的,我也跟一下潮流 資料是從本地解析好的影象和標籤載入的。神經網路有兩個隱含層,都有512個節點。import numpy as np from keras.pre...

Python keras神經網路識別mnist

上次用matlab寫過乙個識別mnist的神經網路,位址在 這次又用keras做了乙個差不多的,畢竟,現在最流行的專案都是python做的,我也跟一下潮流 資料是從本地解析好的影象和標籤載入的。神經網路有兩個隱含層,都有512個節點。import numpy as np from keras.pre...