Keras加權損失函式

2021-09-24 03:55:34 字數 2147 閱讀 1455

ref: 

keras提供的損失函式binary_crossentropy和categorical_crossentropy沒有加權,如果想實現樣本的不同權重功能有一種策略是對損失函式加權處理。

二分加權交叉熵損失

class weightedbinarycrossentropy(object):

def __init__(self, pos_ratio):

neg_ratio = 1. - pos_ratio

self.pos_ratio = tf.constant(pos_ratio, tf.float32)

self.weights = tf.constant(neg_ratio / pos_ratio, tf.float32)

self.__name__ = "weighted_binary_crossentropy()".format(pos_ratio)

def __call__(self, y_true, y_pred):

return self.weighted_binary_crossentropy(y_true, y_pred)

def weighted_binary_crossentropy(self, y_true, y_pred):

# transform to logits

epsilon = tf.convert_to_tensor(k.common._epsilon, y_pred.dtype.base_dtype)

y_pred = tf.clip_by_value(y_pred, epsilon, 1 - epsilon)

y_pred = tf.log(y_pred / (1 - y_pred))

cost = tf.nn.weighted_cross_entropy_with_logits(y_true, y_pred, self.weights)

return k.mean(cost * self.pos_ratio, axis=-1)

多分類加權交叉熵

class weightedcategoricalcrossentropy(object):

def __init__(self, weights):

nb_cl = len(weights)

self.weights = np.ones((nb_cl, nb_cl))

for class_idx, class_weight in weights.items():

self.weights[0][class_idx] = class_weight

self.weights[class_idx][0] = class_weight

self.__name__ = 'w_categorical_crossentropy'

def __call__(self, y_true, y_pred):

return self.w_categorical_crossentropy(y_true, y_pred)

def w_categorical_crossentropy(self, y_true, y_pred):

nb_cl = len(self.weights)

final_mask = k.zeros_like(y_pred[..., 0])

y_pred_max = k.max(y_pred, axis=-1)

y_pred_max = k.expand_dims(y_pred_max, axis=-1)

y_pred_max_mat = k.equal(y_pred, y_pred_max)

for c_p, c_t in itertools.product(range(nb_cl), range(nb_cl)):

w = k.cast(self.weights[c_t, c_p], k.floatx())

y_p = k.cast(y_pred_max_mat[..., c_p], k.floatx())

y_t = k.cast(y_pred_max_mat[..., c_t], k.floatx())

final_mask += w * y_p * y_t

return k.categorical_crossentropy(y_pred, y_true) * final_mask

Keras筆記 損失函式的使用

keras中文文件 損失函式 或稱目標函式 優化評分函式 是編譯模型時所需的兩個引數之一 model.compile loss mean squared error optimizer sgd from keras import losses model.compile loss losses.me...

pytorch和keras損失函式區別

學習pytorch首先是要裝pytorch啦!但是這是乙個磨人的小妖精,傳統的用pip可是裝不上的。為此,可以參考我的另一篇部落格,這可是我研究一天的結晶!這篇筆記是 關於機器學習損失函式的,根據不同的應用場景,需要選擇不同的損失函式。線性回歸因為 的數值有具體的意義,所以損失函式一般使用的均方誤差...

keras中損失函式簡要總結

from keras.losses import 以下是正文。方差。差點忘了方差是什麼,丟死人。注重單個巨大偏差。差的絕對值的平均數。平均對待每個偏差。誤差百分數 非負 的平均數。比如50和150對100的誤差百分數都是50。自己感受。對數的方差。會將值先 1再取對數 奇怪。log cosh err...