Tensorflow dropout解決過擬合問題

2021-08-18 21:39:16 字數 2408 閱讀 9219

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

#載入資料集;對資料集分batch並計算總共有多少batch

mnist = input_data.read_data_sets("mnist_data",one_hot=true)

batch_size = 100

n_batch = mnist.train.num_examples // batch_size

#定義兩個placeholder, keep_prob控制有多少個網路節點用來訓練模型

x = tf.placeholder(tf.float32,[none,784])

y = tf.placeholder(tf.float32,[none,10])

keep_prob=tf.placeholder(tf.float32)

#建立乙個簡單的神經網路

w1 = tf.variable(tf.truncated_normal([784,2000],stddev=0.1))

b1 = tf.variable(tf.zeros([2000])+0.1)

l1 = tf.nn.tanh(tf.matmul(x,w1)+b1)

l1_drop = tf.nn.dropout(l1,keep_prob)

w2 = tf.variable(tf.truncated_normal([2000,2000],stddev=0.1))

b2 = tf.variable(tf.zeros([2000])+0.1)

l2 = tf.nn.tanh(tf.matmul(l1_drop,w2)+b2)

l2_drop = tf.nn.dropout(l2,keep_prob)

w3 = tf.variable(tf.truncated_normal([2000,1000],stddev=0.1))

b3 = tf.variable(tf.zeros([1000])+0.1)

l3 = tf.nn.tanh(tf.matmul(l2_drop,w3)+b3)

l3_drop = tf.nn.dropout(l3,keep_prob)

w4 = tf.variable(tf.truncated_normal([1000,10],stddev=0.1))

b4 = tf.variable(tf.zeros([10])+0.1)

prediction = tf.nn.softmax(tf.matmul(l3_drop,w4)+b4)

#交叉熵代價函式並用梯度下降法進行訓練

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))

train_step = tf.train.gradientdescentoptimizer(0.2).minimize(loss)

#初始化變數

init = tf.global_variables_initializer()

#結果存放在乙個布林型列表中,並計算準確率 argmax()返回一維向量中最大值所在的位置

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.session() as sess:

sess.run(init)

for epoch in range(31):

for batch in range(n_batch):

batch_xs,batch_ys = mnist.train.next_batch(batch_size)

sess.run(train_step,feed_dict=) #0.7的網路節點訓練

#keep_prob:1.0 表示所有的網路節點用來測試

test_acc = sess.run(accuracy,feed_dict=)

train_acc = sess.run(accuracy,feed_dict=)

print("iter " + str(epoch) + ",testing accuracy " + str(test_acc) +",training accuracy " + str(train_acc))

正常情況下,用訓練集來測試和用測試集來測試的差異並不是太大,但是這個試驗中差異卻很大,這就是過擬合導致的。

過擬合一般出現在這樣的情況下:待訓練的網路結構複雜,這就會使引數太多,但是訓練資料不足,從而出現過擬合。

可以通過增加資料集的方法,或者dropout方法解決。

解決過擬合

獲取和使用更多的資料集 對於解決過擬合的辦法就是給與足夠多的資料集,讓模型在更可能多的資料上進行 觀察 和擬合,從而不斷修正自己。然而事實上,收集無限多的資料集幾乎是不可能的,因此乙個常用的辦法就是調整已有的資料,新增大量的 噪音 或者對影象進行銳化 旋轉 明暗度調整等優化。另外補充一句,cnn在影...

防止過擬合以及解決過擬合

過擬合 為了得到一致假設而使假設變得過度複雜稱為過擬合。乙個過配的模型試圖連誤差 噪音 都去解釋 而實際上噪音又是不需要解釋的 導致泛化能力比較差,顯然就過猶不及了。這句話很好的詮釋了過擬合產生的原因,但我認為這只是一部分原因,另乙個原因是模型本身並不能很好地解釋 匹配 資料,也就是說觀測到的資料並...

dropout解決過擬合

原理就是在第一次學習的過程中,隨即忽略一些神經元和神經的鏈結。使得神經網路變得不完整。一次一次。每一次得出的結果不依賴某乙個引數。這樣就解決了過擬合問題。import tensorflow as tf from sklearn.datasets import load digits from skl...