nn優化研究(二)

2021-07-31 04:36:03 字數 2318 閱讀 1614

在優化問題中我碰到乙個很奇怪的現象,在起初迭代的時候準確率還能到95%,迭代幾次後就變為了9.8%。不知道是什麼鬼,等有結論了再更新。先復現一下程式

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

sess = tf.interactivesession()

mnist = input_data.read_data_sets('mnist_data',one_hot=true)

#引數(每個批次資料量的大小)

batch_size = 100

#計算共有多少個批次

n_batch = mnist.train.num_examples // batch_size

#定義兩個placeholder

x = tf.placeholder(tf.float32,[none,784])

y = tf.placeholder(tf.float32,[none,10])

#建立神經網路層

w1 = tf.variable(tf.random_normal([784,100]))

b1 = tf.variable(tf.zeros([100]))

l1 = tf.nn.tanh(tf.matmul(x,w1)+b1)

w2 = tf.variable(tf.random_normal([100,10]))

b2 = tf.variable(tf.zeros([10]))

prediction = tf.nn.softmax(tf.matmul(l1,w2)+b2)

#代價函式

#loss = tf.reduce_mean(tf.square(prediction-y))

loss =-tf.reduce_sum(y*tf.log(prediction))

#優化演算法

train_step = tf.train.gradientdescentoptimizer(0.02).minimize(loss)

#準確率結果計算

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#執行模型

with tf.session() as sess:

sess.run(tf.global_variables_initializer())

for epoch in range(100):

for batch in range(n_batch):

batch_xs,batch_ys = mnist.train.next_batch(batch_size)

sess.run(train_step,feed_dict=)

trainacc = sess.run(accuracy,feed_dict=)

acc = sess.run(accuracy,feed_dict=)

if (epoch+1)%10 == 0:

print ("epoch= "+str(epoch)+" accuray= "+str(acc) + " train accuray="+str(trainacc))

epoch= 9 accuray= 0.966564 train accuray=0.9398

epoch= 19 accuray= 0.991164 train accuray=0.9546

epoch= 29 accuray= 0.998891 train accuray=0.954

epoch= 39 accuray= 0.999764 train accuray=0.9523

epoch= 49 accuray= 0.0989818 train accuray=0.098

epoch= 59 accuray= 0.0989818 train accuray=0.098

epoch= 69 accuray= 0.0989818 train accuray=0.098

epoch= 79 accuray= 0.0989818 train accuray=0.098

epoch= 89 accuray= 0.0989818 train accuray=0.098

epoch= 99 accuray= 0.0989818 train accuray=0.098

NgDL 第二週 NN基礎

正向比如說是計算代價函式值,反向就是增大多少a b c對j的影響,也就是導數的意義,這裡講的是求導鏈式法則。簡直是100倍的時間,看來之前實現的那個 根本就不能用好幾層for迴圈來實現,時間太長了啦!第一次知道。使用對列求和,並 原來的矩陣,b將會被複製3份。第乙個對100複製為乙個行向量,第二個複...

torch學習筆記 二 nn類結構 Linear

linear 是module的子類,是引數化module的一種,與其名稱一樣,表示著一種線性變換。建立parent 的init函式 linear的建立需要兩個引數,inputsize 和 outputsize inputsize 輸入節點數 outputsize 輸出節點數 所以linear 有7個...

iOS 開發 記憶體優化研究

what is resident and dirty memory of ios?記憶體的分配 幾個記憶體 crash 的型別 單例避免過於龐大的單例。單例的使用 普通物件 檢查物件屬性的修飾詞,避免不能釋放導致長時間占用記憶體的情況。資料量很大的屬性處理 利用 void didrecievemem...