TensorFlow經典案例2 實現最近鄰演算法

2022-05-24 16:09:06 字數 1879 閱讀 9381

本次案例需要大家了解關於手寫數字識別(mnist)的資料集的特點和結構:

#tensorflow實現最近鄰演算法

#次案例的前提是了解mnist資料集(手寫數字識別)

import tensorflow as tf

import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

#匯入mnist資料集

mnist = input_data.read_data_sets("/tmp/data/", one_hot=true)

#5000樣本作為訓練集 每乙個訓練和測試樣本的資料都是1*784的矩陣,標籤是1*10的矩陣並且採用one-hot編碼

x_train , y_train = mnist.train.next_batch(5000)

#600樣本作為測試集

x_test , y_test = mnist.test.next_batch(200)

#建立佔位符 none代表將來可以選多個樣本的,如:[60,784]代表選取60個樣本,每乙個樣本的是784列

x_train = tf.placeholder("float",[none,784])

x_test = tf.placeholder("float",[784])#x_test代表只用乙個樣本

#計算距離

#tf.negative(-2)的輸出的結果是2

#tf.negative(2)的輸出的結果是-2

#reduce_sum的引數reduction_indices解釋見下圖

#計算乙個測試樣本和訓練樣本的的距離

#distance 返回的是n個訓練樣本的和單個測試樣本的距離

distance = tf.reduce_sum(tf.abs(tf.add(x_train,tf.negative(x_test))),reduction_indices=1)

#的到距離最短的訓練樣本的索引

prediction = tf.arg_min(distance,0)

accuracy = 0

#初始化變數

init = tf.global_variables_initializer()

with tf.session() as sess:

sess.run(init)

for i in range(len(x_test)):#遍歷整個測試集,每次用乙個的測試樣本和整個訓練樣本的做距離運算

#獲得最近鄰

# 獲得訓練集中與本次參與運算的測試樣本最近的樣本編號

nn_index = sess.run(prediction,feed_dict=)

#列印樣本編號的**類別和準確類別

print("test",i,"prediction:",np.argmax(y_train[nn_index]),"true class:",np.argmax(y_test[i]))

if np.argmax(y_train[nn_index]) == np.argmax(y_test[i]):

#如果**正確。更新準確率

accuracy += 1./len(x_test)

print("完成!")

print("準確率:",accuracy)

輸出:test 196 prediction: 7 true class: 9

test 197 prediction: 9 true class: 9

test 198 prediction: 1 true class: 9

test 199 prediction: 9 true class: 9

完成!準確率: 0.9150000000000007

(分享大量ai大資料資源)

tensorflow 入門經典例項

import tensorflow as tf 發起會話 sess tf.session 兩行都可以執行 具體意思見下方注釋 a tf.variable tf.truncated normal 2,3 0,1,dtype tf.float32,seed 3 a tf.variable tf.rand...

mysql經典案例

使用sql語句建立資料庫,名稱為customdb 答 create database if not exists customdb 建立資料表customer 客戶 deposite 存款 bank 銀行 表結構如下 建立表,如下 答 建立顧客表 create table ifnot exists ...

補碼經典案例

在審核下屬提交的 的時候,發現有這樣一條修改,修改內容為下面參考 的帶 號的兩行,僅是參考,從我們的工程 中擷取了部分 static inline intlm75 temp from reg u16 reg inttmp75 temp get struct i2c client client 那麼 ...