理解GAN對抗神經網路的損失函式和訓練過程

2021-10-06 08:40:03 字數 4950 閱讀 6623

gan最不好理解的就是loss函式的定義和訓練過程,這裡用一段**來輔助理解,就能明白到底是怎麼回事。其實gan的損失函式並沒有特殊之處,就是常用的binary_crossentropy,關鍵在於訓練過程中存在兩個神經網路和兩個損失函式。

np.random.seed(42)

tf.random.set_seed(42)

codings_size =

30generator = keras.models.sequential(

[ keras.layers.dense(

100, activation=

"selu"

, input_shape=

[codings_size]),

keras.layers.dense(

150, activation=

"selu"),

keras.layers.dense(28*

28, activation=

"sigmoid"),

keras.layers.reshape([28

,28])

])discriminator = keras.models.sequential(

[ keras.layers.flatten(input_shape=[28

,28])

, keras.layers.dense(

150, activation=

"selu"),

keras.layers.dense(

100, activation=

"selu"),

keras.layers.dense(

1, activation=

"sigmoid")]

)gan = keras.models.sequential(

[generator, discriminator]

)discriminator.

compile

(loss=

"binary_crossentropy"

, optimizer=

"rmsprop"

)discriminator.trainable =

false

gan.

compile

(loss=

"binary_crossentropy"

, optimizer=

"rmsprop"

)batch_size =

32dataset = tf.data.dataset.from_tensor_slices(x_train)

.shuffle(

1000

)dataset = dataset.batch(batch_size, drop_remainder=

true

).prefetch(

1)

這裡generator並不用compile,因為gan網路已經compile了。具體原因見下文。

訓練過程的**如下

def

train_gan

(gan, dataset, batch_size, codings_size, n_epochs=50)

: generator, discriminator = gan.layers

for epoch in

range

(n_epochs)

:print

("epoch {}/{}"

.format

(epoch +

1, n_epochs)

)# not shown in the book

for x_batch in dataset:

# phase 1 - training the discriminator

noise = tf.random.normal(shape=

[batch_size, codings_size]

) generated_images = generator(noise)

x_fake_and_real = tf.concat(

[generated_images, x_batch]

, axis=0)

y1 = tf.constant([[

0.]]

* batch_size +[[

1.]]

* batch_size)

discriminator.trainable =

true

discriminator.train_on_batch(x_fake_and_real, y1)

# phase 2 - training the generator

noise = tf.random.normal(shape=

[batch_size, codings_size]

) y2 = tf.constant([[

1.]]

* batch_size)

discriminator.trainable =

false

gan.train_on_batch(noise, y2)

plot_multiple_images(generated_images,8)

# not shown

plt.show(

)# not shown

第一階段(discriminator訓練)

# phase 1 - training the discriminator

noise = tf.random.normal(shape=

[batch_size, codings_size]

)generated_images = generator(noise)

x_fake_and_real = tf.concat(

[generated_images, x_batch]

, axis=0)

y1 = tf.constant([[

0.]]

* batch_size +[[

1.]]

* batch_size)

discriminator.trainable =

true

discriminator.train_on_batch(x_fake_and_real, y1)

這個階段首先生成數量相同的真實和假,concat在一起,即x_fake_and_real = tf.concat([generated_images, x_batch], axis=0)。然後是label,真的label是1,假的label是0。

然後是迅速階段,首先將discrinimator設定為可訓練,discriminator.trainable = true,然後開始階段。第乙個階段的訓練過程只訓練discriminator,discriminator.train_on_batch(x_fake_and_real, y1),而不是整個gan網路gan

第二階段(generator訓練)

# phase 2 - training the generator

noise = tf.random.normal(shape=

[batch_size, codings_size]

)y2 = tf.constant([[

1.]]

* batch_size)

discriminator.trainable =

false

gan.train_on_batch(noise, y2)

在第二階段首先生成假,但是不再生成真。把假的label全部設定為1,並把discriminator的權重凍結,即discriminator.trainable = false。這一步很關鍵,應該這麼理解:

前面第一階段的是discriminator的訓練,使真的**值盡量接近1,假的**值盡量接近0,以此來達到優化損失函式的目的。現在將discrinimator的權重凍結,網路中輸入假,並故意把label設定為1。

注意,在整個gan網路中,從上向下的順序是先通過geneartor,再通過discriminator,即gan = keras.models.sequential([generator, discriminator])。第二個階段將discrinimator凍結,並訓練網路gan.train_on_batch(noise, y2)。如果generator生成的足夠真實,經過discrinimator後label會盡可能接近1。由於故意把y2的label設定為1,所以如果genrator生成的足夠真實,此時generator訓練已經達到最優狀態,不會大幅度更新權重;如果genrator生成的不夠真實,經過discriminator之後,**值會接近0,由於y2的label是1,相當於**值不準確,這時候gan網路的損失函式較大,generator會通過更新generator的權重來降低損失函式。

之後,重新回到第一階段訓練discriminator,然後第二階段訓練generator。假設整個gan網路達到理想狀態,這時候generator產生的假,經過discriminator之後,**值應該是0.5。假如這個值小於0.5,證明generator不是特別準確,在第二階段訓練過程中,generator的權重會被繼續更新。假如這個值大於0.5,證明discriminator不是特別準確,在第一階段訓練中,discriminator的權徵會被繼續更新。

簡單說,對於一張generator生成的假,discriminator會盡量把**值拉下拉,generator會盡量把**值往上扯,類似乙個拔河的過程,最後達到均衡狀態,例如0.6, 0.4, 0.55, 0.45, 0.51, 0.49, 0.50。

對抗神經網路(GAN)

對抗神經網路其實是兩個網路的組合,可以理解為乙個網路生成模擬資料,另乙個網路判斷生成的資料是真實的還是模擬的。生成模擬資料的網路要不斷優化自己讓判別的網路判斷不出來,判別的網路也要優化自己讓自己判斷得更準確。二者關係形成對抗,因此叫對抗神經網路。實驗證明,利用這種網路間的對抗關係所形成的網路,在無監...

GAN生成對抗神經網路原理(一)

1.基本原理 此處以生成為例進行說明 假設有2個網格,g generator 和d discriminator 功能分別是 g 生成的網格 接收乙個隨機的雜訊z,通過這個雜訊生成,記作g z d 判別網格,判別一張是不是 真實的 它的輸入引數是x,x代表一張,輸出d x 代表x真實的概率 若為1,代...

對抗神經網路的應用

接下來,我們要為你介紹一款能夠偽造人臉影象的ai neural face。neural face使用了facebook 人工智慧研究團隊開發的深度卷積神經網路 dcgan 研發團隊用由100個0到1的實數組成的1個向量z來代表每一張影象。通過計算出人類影象的分布,生成器就可以用高斯分布 gaussi...