多輸出感知機及其梯度

2022-04-29 14:42:05 字數 2619 閱讀 2159

目錄

對於多輸出感知機,每個輸出元只和輸出元上的x和w和σ

'>σ

σ有關。

import tensorflow as tf
x = tf.random.normal([2, 4])

w = tf.random.normal([4, 3])

b = tf.zeros([3])

y = tf.constant([2, 0])

with tf.gradienttape() as tape:

tape.watch([w, b])

# axis=1,表示結果[b,3]中的3這個維度為概率

prob = tf.nn.softmax(x @ w + b, axis=1)

# 2 --> 001; 0 --> 100

loss = tf.reduce_mean(tf.losses.mse(tf.one_hot(y, depth=3), prob))

grads = tape.gradient(loss, [w, b])

grads[0]
id=92,

shape=(4,3),

dtype=float32,

numpy=

array([[

0.00842961

,-0.02221732

,0.01378771

], [ 0.02969089, -0.04625662, 0.01656573],

[ 0.05807886, -0.08139262, 0.02331377],

[-0.06571108, 0.11157083, -0.04585974]],

dtype=float32)>

grads[1]
id=90,

shape=(3,),

dtype=float32,

numpy=array([-0.05913186,

0.09886257, -0.03973071], dtype=float32)>

目錄

e=12∑(

oi1−

ti)2

'>e=1

2∑(o

1i−t

i)2e=12∑(oi1−ti)2

對於多輸出感知機,每個輸出元只和輸出元上的x和w和σ

'>σ

σ有關。

import tensorflow as tf
x = tf.random.normal([2, 4])

w = tf.random.normal([4, 3])

b = tf.zeros([3])

y = tf.constant([2, 0])

with tf.gradienttape() as tape:

tape.watch([w, b])

# axis=1,表示結果[b,3]中的3這個維度為概率

prob = tf.nn.softmax(x @ w + b, axis=1)

# 2 --> 001; 0 --> 100

loss = tf.reduce_mean(tf.losses.mse(tf.one_hot(y, depth=3), prob))

grads = tape.gradient(loss, [w, b])

grads[0]
id=92,

shape=(4,3),

dtype=float32,

numpy=

array([[

0.00842961

,-0.02221732

,0.01378771

], [ 0.02969089, -0.04625662, 0.01656573],

[ 0.05807886, -0.08139262, 0.02331377],

[-0.06571108, 0.11157083, -0.04585974]],

dtype=float32)>

grads[1]
id=90,

shape=(3,),

dtype=float32,

numpy=array([-0.05913186,

0.09886257, -0.03973071], dtype=float32)>

感知機介紹及實現

感知機 perceptron 由rosenblatt於1957年提出,是神經網路與支援向量機的基礎。感知機是最早被設計並被實現的人工神經網路。感知機是一種非常特殊的神經網路,它在人工神經網路的發展史上有著非常重要的地位,儘管它的能力非常有限,主要用於線性分類。感知機還包括多層感知機,簡單的線 知機用...

keras搬磚系列 keras多輸入多輸出模型

使用函式式模型的乙個典型的場景就是搭建多輸入,多輸出模型。考慮這樣乙個模型,希望 一條新聞會被 和點讚多少次。模型的主要輸入是新聞的本身,也就是乙個詞語的序列,但是我們可能還需要額外的輸入,新聞發布的日期等,所以這個模型的損失函式將會由兩個部分組成,輔助的損失函式基於新聞本身做出的 的情況,主損失函...

感知機原理及python實現

感知機python實現 給定乙個資料集 t yi 輸入空間中任意一點x0 到超平面s的距離為 1 w yi w x0 b 這裡 w 是 w的l2 範數 假 設超平面 s的誤分 點集合為 m,那麼 所有誤分 點到超平 面s的總 距離為 1 w xi myi w xi b 在不考慮 1 w 的情 況下得...