在優化問題中我碰到乙個很奇怪的現象,在起初迭代的時候準確率還能到95%,迭代幾次後就變為了9.8%。不知道是什麼鬼,等有結論了再更新。先復現一下程式
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
sess = tf.interactivesession()
mnist = input_data.read_data_sets('mnist_data',one_hot=true)
#引數(每個批次資料量的大小)
batch_size = 100
#計算共有多少個批次
n_batch = mnist.train.num_examples // batch_size
#定義兩個placeholder
x = tf.placeholder(tf.float32,[none,784])
y = tf.placeholder(tf.float32,[none,10])
#建立神經網路層
w1 = tf.variable(tf.random_normal([784,100]))
b1 = tf.variable(tf.zeros([100]))
l1 = tf.nn.tanh(tf.matmul(x,w1)+b1)
w2 = tf.variable(tf.random_normal([100,10]))
b2 = tf.variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(l1,w2)+b2)
#代價函式
#loss = tf.reduce_mean(tf.square(prediction-y))
loss =-tf.reduce_sum(y*tf.log(prediction))
#優化演算法
train_step = tf.train.gradientdescentoptimizer(0.02).minimize(loss)
#準確率結果計算
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#執行模型
with tf.session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(100):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict=)
trainacc = sess.run(accuracy,feed_dict=)
acc = sess.run(accuracy,feed_dict=)
if (epoch+1)%10 == 0:
print ("epoch= "+str(epoch)+" accuray= "+str(acc) + " train accuray="+str(trainacc))
epoch= 9 accuray= 0.966564 train accuray=0.9398
epoch= 19 accuray= 0.991164 train accuray=0.9546
epoch= 29 accuray= 0.998891 train accuray=0.954
epoch= 39 accuray= 0.999764 train accuray=0.9523
epoch= 49 accuray= 0.0989818 train accuray=0.098
epoch= 59 accuray= 0.0989818 train accuray=0.098
epoch= 69 accuray= 0.0989818 train accuray=0.098
epoch= 79 accuray= 0.0989818 train accuray=0.098
epoch= 89 accuray= 0.0989818 train accuray=0.098
epoch= 99 accuray= 0.0989818 train accuray=0.098
NgDL 第二週 NN基礎
正向比如說是計算代價函式值,反向就是增大多少a b c對j的影響,也就是導數的意義,這裡講的是求導鏈式法則。簡直是100倍的時間,看來之前實現的那個 根本就不能用好幾層for迴圈來實現,時間太長了啦!第一次知道。使用對列求和,並 原來的矩陣,b將會被複製3份。第乙個對100複製為乙個行向量,第二個複...
torch學習筆記 二 nn類結構 Linear
linear 是module的子類,是引數化module的一種,與其名稱一樣,表示著一種線性變換。建立parent 的init函式 linear的建立需要兩個引數,inputsize 和 outputsize inputsize 輸入節點數 outputsize 輸出節點數 所以linear 有7個...
iOS 開發 記憶體優化研究
what is resident and dirty memory of ios?記憶體的分配 幾個記憶體 crash 的型別 單例避免過於龐大的單例。單例的使用 普通物件 檢查物件屬性的修飾詞,避免不能釋放導致長時間占用記憶體的情況。資料量很大的屬性處理 利用 void didrecievemem...