9 1 mnist softmax 交叉熵多分類器

2022-04-29 14:30:10 字數 2270 閱讀 8885

具體含義不再解釋,這是乙個我們比較常用的乙個多分類器.深度學習的一大優點就是特徵的自動構建,也正是因為該優點,使得分類器層顯得不再那麼重要,在tensorflow的官方原始碼中,softmax是很常見的乙個多分類器.其呼叫也十分的簡單.此處再此單獨拿出來介紹,是為了下一步的學習做準備.

使用方法

cross_entropy = tf.reduce_mean(

tf.nn

.softmax_cross_entropy_with_logits(labels=y_, logits=y))

用於損失函式的定義.

# 引用,官網自帶的原始碼有很多特殊之處,但是沒啥影響,自己寫的時候,完全沒必要這麼多引用

# 額外新增了控制警告訊息等級的code

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

import argparse

import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

import os

os.environ['tf_cpp_min_log_level'] = '2'

flags = none

mnist = input_data.read_data_sets("/home/fonttian/data/mnist_data/", one_hot=true)

# create the model,可以看出此處的model非常簡單,就是一層y=wx+b,你也可以繼續增加層數,或者將其替代為卷積層,但是此處對於展示softmax並沒有什麼意義

x = tf.placeholder(tf.float32, [none, 784])

w = tf.variable(tf.zeros([784, 10]))

b = tf.variable(tf.zeros([10]))

y = tf.matmul(x, w) + b

# define loss and optimizer

y_ = tf.placeholder(tf.float32, [none, 10])

# 這部分**很簡單,一些細節我在之前已經介紹過了.

cross_entropy = tf.reduce_mean(

tf.nn

.softmax_cross_entropy_with_logits(labels=y_, logits=y))

train_step = tf.train

.gradientdescentoptimizer(0.5).minimize(cross_entropy)

sess = tf.interactivesession()

tf.global_variables_initializer().run()

# train

for _ in range(1000):

batch_xs, batch_ys = mnist.train

.next_batch(100)

sess.run(train_step, feed_dict=)

# test trained model

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(sess.run(accuracy, feed_dict=))

關於main的部分之前已經有介紹了:

if __name__ == '__main__':

parser = argparse.argumentparser()

parser.add_argument('--data_dir', type=str, default='/home/fonttian/data/mnist_data',

help='directory for storing input data')

flags, unparsed = parser.parse_known_args()

softMax交叉熵多分類引數調優之批次大小選擇

import tensorflow as tf import os import numpy as np import numpy as np os.environ tf cpp min log level 3 輸入隨機種子 myseed eval input learning rate eval ...

神經網路多分類任務的損失函式 交叉熵

神經網路解決多分類問題最常用的方法是設定n個輸出節點,其中n為類別的個數。對於每乙個樣例,神經網路可以得到的乙個n維陣列作為輸出結果。陣列中的每乙個維度 也就是每乙個輸出節點 對應乙個類別。在理想情況下,如果乙個樣本屬於類別k,那麼這個類別所對應的輸出節點的輸出值應該為1,而其他節點的輸出都為0。以...

神經網路多分類任務的損失函式 交叉熵

神經網路解決多分類問題最常用的方法是設定n個輸出節點,其中n為類別的個數。對於每乙個樣例,神經網路可以得到的乙個n維陣列作為輸出結果。陣列中的每乙個維度 也就是每乙個輸出節點 對應乙個類別。在理想情況下,如果乙個樣本屬於類別k,那麼這個類別所對應的輸出節點的輸出值應該為1,而其他節點的輸出都為0。以...