《大資料演算法學習》(三)MNIST手寫數字識別

2021-10-04 12:45:59 字數 3252 閱讀 1106

本次學習使用神經網路識別手寫數字,我們使用的資料集是mnist資料集,mnist資料集的長相如下圖所示。

mnist資料集是由0 到9 的數字影象構成。訓練影象有6 萬張,測試影象有1 萬張,這些影象可以用於學習和推理。mnist資料集的一般使用方法是,先用訓練影象進行學習,再用學習到的模型度量能在多大程度上對測試影象進行正確的分類。 mnist的影象資料是28 畫素 × 28 畫素的灰度影象(1 通道),各個畫素的取值在0 到255 之間。每個影象資料都相應地標有「7」、「2」、「1」等標籤。

load_mnist函式以「( 訓練影象, 訓練標籤),( 測試影象,測試標籤)」的形式返回讀入的mnist資料。

def load_mnist():

train_labels_path = 'train-labels.idx1-ubyte'

test_labels_path = 't10k-labels.idx1-ubyte'

train_images_path = 'train-images.idx3-ubyte'

test_images_path = 't10k-images.idx3-ubyte'

with open(train_labels_path, 'rb') as lpath:

magic, n = struct.unpack('>ii', lpath.read(8))

train_labels = np.fromfile(lpath, dtype = np.uint8).astype(np.float)

with open(train_images_path, 'rb') as ipath:

magic, num, rows, cols = struct.unpack('>iiii', ipath.read(16))

loaded = np.fromfile(train_images_path, dtype=np.uint8)

train_images = loaded[16:].reshape(len(train_labels), 784).astype(np.float)

with open(test_labels_path, 'rb') as lpath:

magic, n = struct.unpack('>ii', lpath.read(8))

test_labels = np.fromfile(lpath, dtype = np.uint8).astype(np.float)

with open(test_images_path, 'rb') as ipath:

magic, num, rows, cols = struct.unpack('>iiii', ipath.read(16))

loaded = np.fromfile(test_images_path, dtype=np.uint8)

test_images = loaded[16:].reshape(len(test_labels), 784)

return train_images, train_labels, test_images, test_labels

img_train,label_train,img_test,label_test = load_mnist()

神經網路的輸入層有784 個神經元,輸出層有10 個神經元。輸入層的784 這個數字**於影象大小的28 × 28 = 784,輸出層的10 這個數字**於10 類別分類(數字0 到9,共10 類別)。此外,這個神經網路有2 個隱藏層,第1 個隱藏層有50 個神經元,第2 個隱藏層有100 個神經元。這個50 和100 可以設定為任何值。

我們本次學習使用的權重是已經訓練好的權重資料,準確率可以達到94%,權重資料檔案名為sample_weight.pkl。

import pickle

def init_network():

with open("sample_weight.pkl","rb") as f:

network = pickle.load(f)

用神經網路進行**:

def sigmoid(x):

return 1/(1+np.exp(-x))

def softmax(a):

exp_a = np.exp(a)

sum_exp_a = np.sum(exp_a)

y = exp_a / sum_exp_a

return y

def predict(network,x):

w1,w2,w3 = network['w1'],network['w2'],network['w3']

b1,b2,b3 = network['b1'],network['b2'],network['b3']

a1 = np.dot(x,w1)+b1

z1 = sigmoid(a1)

a2 = np.dot(z1,w2)+b2

z2 = sigmoid(a2)

a3 = np.dot(z2,w3)+b3

y = softmax(a3)

return y

批處理統計訓練標籤和測試標籤每萬份資料中相同的數字的個數是多少。(訓練標籤有六萬個,測試標籤有一萬個)

batch_szie = 10000

all_same_count = 0

for i in range(0,len(label_train),batch_szie):

label_train_batch = label_train[i:i+batch_szie]

same_count = 0

same_count += np.sum(label_train_batch == label_test)

all_same_count += same_count

print("每萬份資料中相同數字個數:"+str(same_count))

print("總數:"+str(all_same_count))

每萬份資料中相同數字個數:1008

每萬份資料中相同數字個數:1034

每萬份資料中相同數字個數:941

每萬份資料中相同數字個數:1004

每萬份資料中相同數字個數:1018

每萬份資料中相同數字個數:1014

總數:6019

大資料演算法學習筆記

基礎資料結構 線性表 線性表是由相同型別的資料按照一定的順序排成的序列。具體線性表有鍊錶 陣列線性表 棧 形象比喻 從乙個書箱中拿書 和佇列 形象比喻 車站排隊買票 資料概要 概括資料的資料結構叫作資料概要。對於判定問題的嚴格精確解,我們能給出嚴格的是或者否。而對於判定問題的近似演算法,只要給出 是...

大資料演算法學習筆記 七 外存演算法

當資料量巨大時,傳統隨機儲存模型無法適用。一 儲存結構 標準計算理論模型 1 無限記憶體 2 統一訪問代價 3 模型簡單 分層儲存 1 儲存量得到較大提公升,較慢的層次遠離cpu 2 以塊為單位的資料移動 可擴充套件性問題 大多數程式在ram模型中執行,作業系統按需訪問塊。但如果程式分散地訪問磁碟資...

演算法學習(三)

快慢指標 雙指標 兩個指標指向不同元素,從而協同完成任務,主要用於遍歷元素。對撞指標,快慢指標,滑動視窗 對撞指標是指在陣列中,將指向最左側的索引定義為左指標,最右側的定義為右指標,然後從兩頭向中間進行陣列遍歷。leetcode 167 兩數之和 ii 輸入有序陣列 給定乙個已按照 公升序排列 的整...