mini batch 版交叉熵誤差的實現

2021-09-24 14:25:24 字數 1921 閱讀 2297

4.2.4 mini-batch 版交叉熵誤差的實現

one-hot格式的監督資料

def cross_entropy_error(y, t):

if y.ndim == 1: #一維陣列要把第一維放到第二維,表示只有一條資料

t = t.reshape(1, t.size) #reshape函式代兩維的引數

y = y.reshape(1, y.size)

batch_size = y.shape[0] #記下第一維的數值,表示有多少條資料

return -np.sum(t * np.log(y + 1e-7)) / batch_size

t為(batch_size,10), y 為(batch_size,10) 

t*logy   結果還是(batch_size,10) 

np.sum 後成為乙個數值

y=np.array([[0,1,0.1,0,0,0,0,0,0,0],

[0,0,0.2,0.8,0,0,0,0,0,0]])

t=np.array([[0,1,0,0,0,0,0,0,0,0],

[0,0,0,1,0,0,0,0,0,0]])#one-hot

batch_size =y.shape[0] 

print( batch_size ) #batch_size =2

r= -np.sum(t * np.log(y + 1e-7)) / batch_size 

print(r) #0.11157166315711126

非one-hot格式的監督資料

def cross_entropy_error(y, t):

if y.ndim == 1:

t = t.reshape(1, t.size)

y = y.reshape(1, y.size)

batch_size = y.shape[0]

return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7))/ batch_size #

import numpy as np

y=np.array([[0,1,0.1,0,0,0,0,0,0,0],            

[0,0,0.2,0.8,0,0,0,0,0,0]])

t_onehot=np.array([[0,1,0,0,0,0,0,0,0,0],            

[0,0,0,1,0,0,0,0,0,0]])#one-hot

t = t_onehot.argmax(axis=1)#非one-hot

print(t)#[1 3]

batch_size = y.shape[0]

print(batch_size)#2

k=y[np.arange(batch_size), t] # [y[0,1] y[1,3]]

print(k)#[1.  0.8]

r=-np.sum(np.log(y[np.arange(batch_size), t] + 1e-7))/ batch_size

print(r)#0.11157166315711126

t為(batch_size,1), y 為(batch_size,10)

y[np.arange(batch_size), t]   結果是一維的(有batch_size個值)

np.sum 後成為乙個數值

import numpy as np

a=np.array([1,2,3,4])

print(a) # [1 2 3 4]

b=a.reshape(1,a.size)

print(b)#[[1 2 3 4]]

交叉熵誤差函式

概率分布p和q的交叉熵定義為 p,q operatorname log q mathrm p d p parallel q 可以看到,交叉熵可以拆解為兩部分的和,也就是p的熵加上p與q之間的kl距離,對於乙個已知的分布p,它的熵 是乙個已知的常數,所以在這種情況下,使用交叉熵等價於使用kl距離,而且...

均方誤差和交叉熵誤差

均方誤差個交叉熵誤差都是常用的損失函式之一。損失函式是用來表示神經網路效能的 惡劣程度 的指標。即當前神經網路對監督資料在多大程度上不擬合,在多大 程度上不一致。說白了,即所建立的神經網路對輸入資料的 輸出值與監督資料 實際輸出值 的差距。上面是計算公式,其中yk表示神經網路的 輸出值,tk表示監督...

深度學習入門 損失函式 交叉熵誤差

交叉熵誤差 cross entropy error 07 設定乙個微小值,避免對數的引數為0導致無窮大 1 10的負7次方 return np.sum t np.log y delta 注意這個log對應的是ln t 0,0,1,0,0,0,0,0,0,0 設定 2 為正確解標籤 y 0.1,0.0...