tensoflow 識別數字

2021-08-29 01:15:33 字數 3217 閱讀 7437

tensoflow 的確抽象,不過也是很不錯的

現在貼段識別文體的** 有注釋,慢慢體會

#樣本的準備

#通過最近臨域法來判斷識別數字 待檢測的和樣本進行比較 k個中我們找到個相似度最大的

#當前描繪的是哪些點,就需要解析中解析的點,需要通過lable標籤獲得

#將當前的lable轉換為具體的數字

#檢測概率統計

import tensorflow as tf

import numpy as ny

import random

from tensorflow.examples.tutorials.mnist import input_data

#load data 第乙個文字路徑 第二個表示 第乙個引數為1 其他為0

#完成測試和訓練的距離計算

#如何根據knn 中k個最近的5張 和500張做差,在500張中找到4張最接近的測試

mnist = input_data.read_data_sets('mnist_data',one_hot = true)

# 屬性設定 總共訓練 55000張

trainnum = 55000

#測試 10000

testnum = 10000

#訓練的時候使用的500張

trainsize = 500

#測試5張

testsize = 5

#以下資料的分解

#訓練資料的下標 生存了 trainsize 這麼多個隨機數 範圍是0到trainnum 之間 replace=false 表示不可以重複

trainindex = ny.random.choice(trainnum,trainsize,replace=false)

testindex = ny.random.choice(testnum,testsize,replace=false)

#當前的訓練資料

traindata = mnist.train.images[trainindex]

#獲取當前訓練標籤

trainlable = mnist.train.labels[trainindex]

testdata = mnist.train.images[testindex]

testlable = mnist.train.labels[testindex]

'''如何計算兩張的距離,可以用兩張對應元素相減

'''print('traindata.shape=',traindata.shape)

print('trainlable.shape=',trainlable.shape)

print('testdata.shape=',testdata.shape)

print('tetstlable.shape=',testlable.shape)

traindatainput = tf.placeholder(shape=[none,784],dtype=tf.float32)

trainlableinput = tf.placeholder(shape=[none,10],dtype=tf.float32)

testdatainput = tf.placeholder(shape=[none,784],dtype=tf.float32)

tetstlableinput = tf.placeholder(shape=[none,10],dtype=tf.float32)

#兩張的距離

#完成維度的轉換

f1 = tf.expand_dims(testdata,1)#維度的擴充套件

f2 = tf.subtract(traindatainput,f1) #維度相減

f3 = tf.reduce_sum(tf.abs(f2),reduction_indices=2) #完成資料累加

f4 = tf.negative(f3) #取反

f5,f6 = tf.nn.top_k(f4,k=4) #選取f4中最大的4個值 對f3來說是最小的四個值

#f6 儲存的是最近的4張的index,根據下標索引訓練出標籤

f7 = tf.gather(trainlableinput,f6)

#數字的獲取

f8 = tf.reduce_sum(f7,reduction_indices=1)

#選取在某乙個緯度上最大的值 並記錄當前的x下標

f9 = tf.argmax(f8,dimension=1)

#所有的檢測的最大值

with tf.session() as sess:

p1 = sess.run(f1,feed_dict=)

print('p1=',p1.shape)

p2 = sess.run(f2,feed_dict=)

print('p2=',p2.shape)

p3 = sess.run(f3,feed_dict=)

print('p3=',p3.shape)

print('p3[0,0]=',p3[0,0])

p4 = sess.run(f4,feed_dict=)

print('p4=',p4.shape)

print('p4[0,0]=',p4[0,0])

#每一張測試分別對應4張對應的訓練

p5,p6 = sess.run((f5,f6),feed_dict=)

print('p5=',p5.shape)

print('p5[0,0]=',p5[0,0])

print('p6=',p6.shape)

print('p6[0,0]=',p6[0,0])

p7 = sess.run(f7,feed_dict=)

print('p7=',p7.shape)

print('p7=',p7)

p8 = sess.run(f8,feed_dict=)

print('p8=',p8.shape)

print('p8=',p8)

p9 = sess.run(f9,feed_dict=)

print('p9=',p9.shape)

print('p9=',p9)

#找到測試標籤中的所有內容

p10 = ny.argmax(testlable[0:5],axis=1)

print('p10=',p9)

j = 0

for i in range(0,5):

if p10[i] == p9[i]:

j = j+1

print(j)

mnist手寫數字識別資料初探

mnist手寫數字識別為學習tensorflow等深度學習框架的入門經典資料集,tensorflow有直接載入mnist資料庫相關的模組,其地位類似於使用r語言作資料探勘中的iris資料集。網路上關於使用mnist資料集實現各類深度學習演算法的 非常多,但是對於初學者而言,依著葫蘆畫瓢雖然將網上do...

python 機器學習knn識別數字1

train images idx3 ubyte.gz training set images 9912422 bytes train labels idx1 ubyte.gz training set labels 28881 bytes t10k images idx3 ubyte.gz test...

呼叫海康攝像頭實時識別數字牌數字

呼叫海康攝像頭實時識別數字牌數字 專案所需,呼叫網路攝像頭來完成對乙個數字牌的識別,用模板匹配的方法分離出數字 用vs2015 opencv3.4.3完成 include include include 使用命名空間 using namespace cv videocapture cap rtsp ...