tensflow自定義損失函式

2022-01-22 19:55:27 字數 1883 閱讀 8186

tensflow 不僅支援經典的損失函式,還可以優化任意的自定義損失函式。

**商品銷量時,如果**值比真實銷量大,商家損失的是生產商品的成本;如果**值比真實值小,損失的則是商品的利潤。

比如如果乙個商品的成本是1元,但利潤是10元,那麼少**乙個就少賺9元;而多**乙個才虧1元,為了最大化利潤預期,需要將損失函式和利潤直接聯絡起來。注意損失函式

定義的是損失,所以要將利潤最大化,定義的損失函式應該刻成本或者代價。下面給出了乙個當**多於真實值和**少於真實值時有不同損失係數和損失函式:

import  tensorflow as

tffrom

numpy.random import randomstate

batch_size=8

x=tf.placeholder(tf.float32,shape=(none,2),name='

x-input')

y_=tf.placeholder(tf.float32,shape=(none,1),name='

y-input')

w1=tf.variable(tf.random_normal([2,1],stddev=1,seed=1

))y=tf.matmul(x,w1)

logss_less=10

logss_more=1

# 損失函式

logss=tf.reduce_sum(tf.where

(tf.greater(y,y_),

(y-y_)*logss_more,

(y_-y)*logss_less))

train_step=tf.train.adamoptimizer(0.001

).minimize(logss)

rdm=randomstate(1

)dataset_size=128

x=rdm.rand(dataset_size,2

)y=[[x1 +x2+rdm.rand()/10.0-0.05] for (x1,x2) in

x]with tf.session()

assess:

init_op=tf.global_variables_initializer()

sess.run(init_op)

steps=5000

for i in

range(steps):

start=(i*batch_size) %dataset_size

end=min(start+batch_size,dataset_size)

sess.run(train_step,

feed_dict=

)print sess.run(w1)

結果為...

[[1.0194283]

[1.0428752]]

[[1.0194151]

[1.0428821]]

[[1.019347 ]

[1.0428089]]

所以**函式的值是1.02x1+1.04x2,這要比x1+x2大,因為在損失函式中指定**少了的損失(logss_less>loss_more.如果將log_less的值調整為1,log_more的值調整為10,

那麼結果將會如下

[[0.95491844]

[0.9814671 ]]

[[0.95506585]

[0.98148215]]

[[0.9552581]

[0.9813394]]

也就是說,在這樣的設定下,模型會更加偏向於**少一點,而如果使用均方誤差作為損失函式,那麼w1會是[0.97437561,1.0243336],使用這個損失函式會盡量讓**值

離標準答案更近。通過這個樣例可以看出,對於相同的神經網路,不同的損失函式會對訓練得到的模型產生重要影響。

tensflow自定義損失函式

tensflow 不僅支援經典的損失函式,還可以優化任意的自定義損失函式。商品銷量時,如果 值比真實銷量大,商家損失的是生產商品的成本 如果 值比真實值小,損失的則是商品的利潤。比如如果乙個商品的成本是1元,但利潤是10元,那麼少 乙個就少賺9元 而多 乙個才虧1元,為了最大化利潤預期,需要將損失函...

Keras例項 自定義損失函式 指標函式

在訓練模型的時候,keras提供了許多損失函式供我們使用,但是即便如此,我們也會有遇到需要用自己的損失函式的情況,這樣我們就要自定義乙個損失函式。比如我現在需要定義乙個損失函式,類似於relu函式,低於threshold的loss為0,大於threshold的loss就是他們之間的差。注意我們在定義...

TensorFlow實戰 自定義損失函式完整案例

import tensorflow as tf from numpy.random import randomstate import os os.environ tf cpp min log level 2 batch size 8 兩個輸入節點。x tf.placeholder tf.float...