二分類問題中,大量的負樣本會影響到網路的訓練嗎

2021-09-26 20:24:51 字數 4016 閱讀 3174

在分類問題中,比如我們訓練乙個網路,讓它識別出這張**是否為人,我們的資料集假設有張1000張,其中100張為人,900張不是人(裡面包含了車,飛機,花朵等亂七八糟的),可以看到這個資料集出現了負樣本遠大於正樣本的情況.現在假設我們把這1000張一次性送入網路進行訓練,那麼得到的損失值如何計算?負樣本對損失值有貢獻嗎?

答案:(1)損失值loss計算公式:

假設真實標籤值y_true=[y1,y2,y3,...yn]

**值y_pred=[z1,z2,z3,...zn]

(2) 負樣本對損失值有貢獻,主要體現到上面的loss計算上,雖然負樣本標籤值=0,分子中遇到負樣本0,y_i*log(z_i)=0,但是在分母中,n的值是正負樣本樣本的總數量,所以負樣本對損失值是有貢獻的,特別當正負樣本比例失衡,負樣本遠大於正樣本時,會造成訓練進行不下去.

還是用一段程式來證明吧

import tensorflow as tf

y_true=[[0], [0],[0],[0],[1], [1]]#真實標籤

y_pre=[[0.9],[0.9],[0.9],[0.9],[0.6],[0.5]]#**值

cross_entropy=-tf.reduce_mean(y_true*tf.log(tf.clip_by_value(y_pre,1e-10,1.0)))#-ylogp

a=-tf.log(0.6)-tf.log(0.5)#就算所有樣本的損失和,可以看出這裡負樣本因為標籤值=0,所以無論**值是多少,對損失和是沒有影響的

with tf.session() as sess:

print(sess.run(a))#1.20397

print(sess.run(a / 2))#0.601986

print(sess.run(a / 6))#0.200662

print(sess.run(cross_entropy))#0.200662#這裡可以看出經過reduce_mean之後,這個是按照所有的樣本數量來求平均值的,雖然前面正樣本只有2個,

# 但是算平均損失的話,這裡是除以樣本總數的,所以隨著負樣本的增多,當其數量遠遠大於正樣本數量時,無論網路對於正樣本是否**正確,最後的總的損失值會很小,

#即由負樣本主導了損失函式值,造成損失函式值非常小,無法反向傳播,訓練失敗

完整例子,**於tensorflow實戰google深度學習框架

import tensorflow as tf

from numpy.random import randomstate

batch_size=8

w1=tf.variable(tf.random_normal([2,3],stddev=1,seed=1))#第乙個權重矩陣

#tf.randdom_normal(shape,mean=0.0,stddev=1.0,dtype=tf.float32,seed=none,name=none)

#shape:輸出張量的形狀,必選,mean正態分佈的均值,預設=0,stddev標準差,預設=1,dtype,輸出的型別,預設=tf.float32

#seed隨機數種子,是乙個整數,當設定之後,每次生成的隨機數都一樣,name操作的名稱

w2=tf.variable(tf.random_normal([3,1],stddev=1,seed=1))#生成3行1列矩陣

x=tf.placeholder(tf.float32,shape=(none,2),name='x-input')#訓練樣本,?x2的矩陣

y_=tf.placeholder(tf.float32,shape=(none,1),name='y-input')#訓練樣本的標籤,?x1的矩陣

a=tf.matmul(x,w1)#計算x*w1,[?,2]*[2,3]=[?,3]

y=tf.matmul(a,w2)#計算a*w2=x*w1*w2,[?,2]*[2,3]*[3,1=[?,3]*[3,1]=[?,1]

cross_entropy=-tf.reduce_mean(y_*tf.log(tf.clip_by_value(y,1e-10,1.0)))#-ylogp

#tf.reduce_mean(input_tensor,axis=none,)函式用於計算張量tensor沿著指定的數軸tensor的某一維度)上的平均值,主要用作將維或者計算tensor

# (影象)的平均值,

#第乙個引數input_tensor:輸入的待降維的tensor

#axis: 指定的軸,如果不指定,則計算所有元素的均值

train_step=tf.train.adamoptimizer(0.001).minimize(cross_entropy)

rdm=randomstate(1)#因為這裡固定了隨機數種子,所以每次執行程式,生成的訓練樣本都是一樣的

dataset_size=3

x=rdm.rand(dataset_size,2)#生成128x2的訓練樣本

y=[[int(x1+x2<1)] for (x1,x2) in x]#對訓練樣本生成標籤,值=0或1

with tf.session() as sess:

init_op=tf.initialize_all_variables()

sess.run(init_op)

print(sess.run(w1))

print(sess.run(w2))

# print('x', sess.run(x))

# print('y', sess.run(y))

steps=5000

for i in range(steps):

start=(i*batch_size)%dataset_size

end=min(start+batch_size,dataset_size)

sess.run(train_step,feed_dict=)

if i%1000==0:

total_cross_entropy=sess.run(cross_entropy,feed_dict=)

print('after %d training steps, cross entropy on all data is %g'%(i,total_cross_entropy))#%g啥意思?

print(sess.run(w1))

print(sess.run(w2))

解決方案:

挖坑1: 從等式(1)可以看出負樣本對損失函式值的影響主要體現在分母上,那我乾脆對於y_true, 多加一列,標記一下其是負樣本,這樣我就不把它統計到樣本總量中不久可以了嗎?

具體操作:把y_true中值=1的個數統計出來,假設=m,求均值時,分母=m即可.這個操作可以在retinanet或其他有構造anchor的**中發現.

cls_loss = focal_weight * keras.backend.binary_crossentropy(target=labels, output=classification)

# compute the normalizer: the number of positive anchors

normalizer = backend.where(keras.backend.equal(anchor_state, 1))#只計算正樣本的數量,忽略負樣本的數量

normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx())

normalizer = keras.backend.maximum(1.0, normalizer)#儘管anchor中負樣本特別多,這裡算平均值時只除以正樣本的數量,可有效避免負樣本主導損失函式造成loss很小,

#導致的訓練失敗

挖坑2:在實際的應用場景中,做目標檢測時,同一張中會出現多個目標,而且這多個目標可能屬於不同的類別中,比如在一張**現了乙個人和一群狗,我們的網路任務為在這中找出人和狗,並識別出是狗還是人,所以這裡我們對同一張,就會

後面問題不知道如何描述了,讓我想想再寫,2019-09-04

二分類問題中的混淆矩陣 ROC以及AUC評估指標

本篇博文簡要討論機器學習二分類問題中的混淆矩陣 roc以及auc評估指標 作為評價模型的重要參考,三者在模型選擇以及評估中起著指導性作用。按照循序漸進的原則,依次討論混淆矩陣 roc和auc 設定乙個機器學習問題情境 給定一些腫瘤患者樣本,構建乙個分類模型來 腫瘤是良性還是惡性,顯然這是乙個二分類問...

二分類問題中混淆矩陣 PR以及AP評估指標

仿照上篇博文對於混淆矩陣 roc和auc指標的 本文簡要討論機器學習二分類問題中的混淆矩陣 pr以及ap評估指標 實際上,roc,auc 與 pr,ap 指針對具有某種相似性。按照循序漸進的原則,依次討論混淆矩陣 pr和ap 設定乙個機器學習問題情境 給定一些腫瘤患者樣本,構建乙個分類模型來 腫瘤是...

SVM的二分類問題(fitcsvm)

非線性二分類 matlab2018及之後的版本取消了svmtrain和svmclassify函式,取而代之的更多的fitcsvm和predict函式。但是也使得決策邊界無法直接生成,需要根據fitcsvm中得到的引數畫出決策邊界。問題1 編寫程式,用svm方法解決線性二分類問題 資料集 testse...