生成對抗網路GAN的keras例項

2021-10-19 10:42:01 字數 3506 閱讀 2996

匯入一些需要的包

from keras.layers import input, dense, reshape, flatten, dropout

from keras.layers import batchnormalization, activation, zeropadding2d

from keras.layers.advanced_activations import leakyrelu

from keras.layers.convolutional import upsampling2d, conv2d

from keras.models import sequential, model

from keras.optimizers import adam

import matplotlib.pyplot as plt

import sys

import numpy as np

定義優化器

optimizer = adam(

0.0002

,0.5

)

構建鑑別器並編譯

n_y_value =

20d = sequential(

)d.add(dense(

512)

)d.add(leakyrelu(alpha=

0.2)

)d.add(dense(

256)

)d.add(dense(

1, activation=

'sigmoid'))

# d.summary()

img = input(shape=

(n_y_value,))

validity = d(img)

discriminator = model(img,validity)

discriminator.

compile

(loss=

'binary_crossentropy'

, optimizer=optimizer,

metrics=

['accuracy'

])

構建生成器,並組合生成器和鑑別器成gan

n_ideas =

5g = sequential(

)g.add(dense(

512,input_dim=n_ideas)

)g.add(leakyrelu(alpha=

0.2)

)g.add(batchnormalization(momentum=

0.8)

)g.add(dense(

512)

)g.add(leakyrelu(alpha=

0.2)

)g.add(batchnormalization(momentum=

0.8)

)g.add(dense(

1024))

g.add(leakyrelu(alpha=

0.2)

)g.add(batchnormalization(momentum=

0.8)

)g.add(dense(n_y_value, activation=

'tanh'))

# g.add(reshape(n_y_value))

# g.summary()

noise = input(shape=

(n_ideas,))

g_img = g(noise)

generator = model(noise,g_img)

z = input(shape=

(n_ideas,))

g_img = generator(z)

discriminator.trainable =

false

validity = discriminator(g_img)

gan = model(z,validity)

gan.

compile

(loss=

'binary_crossentropy'

, optimizer=optimizer)

訓練過程

batch_size=

64x= np.vstack(

[np.linspace(-1

,1,n_y_value)

for _ in

range

(batch_size)])

true_imgs =np.power(x,2)

valid = np.ones(

(batch_size,1)

)fake = np.zeros(

(batch_size,1)

)plt.ion(

)for i in

range

(1600):

noise = np.random.normal(0,

1,(batch_size, n_ideas)

) g_imgs = generator.predict(noise)

d_loss_fake = discriminator.train_on_batch(g_imgs,fake)

d_loss_real = discriminator.train_on_batch(true_imgs,valid)

d_loss =

0.5* np.add(d_loss_real, d_loss_fake)

g_loss = gan.train_on_batch(noise,valid)

print

("%d [d loss: %f, acc.: %.2f%%] [g loss: %f]"

%(i, d_loss[0]

,100

* d_loss[1]

, g_loss)

)# print("g_imgs.shape:",g_imgs.shape) #(64,20)

plt.cla(

) plt.xlim((-

1.2,

1.2)

) plt.ylim((-

0.2,

1.2)

) plt.plot(x[0]

, true_imgs[0]

, lw=

2, c=

'#11aaaa'

) plt.plot(x[0]

,g_imgs[0]

, lw=

2, c=

'#b62a2a'

) plt.pause(

0.01

)plt.ioff(

)plt.show(

)

最終網路生成的虛假與真實如下

GAN 生成對抗網路

原理 假設我們有兩個網路 乙個生g generator 乙個判別d discriminator g是乙個生成的的網路,它接受乙個隨機的雜訊z,通過這個雜訊生成,記做g z d是乙個判別網路,判斷一張是不是 真實的 它的輸入引數是x,x代表一張的。輸出d x 代表x為真實的概率,如果為1,就代表100...

生成對抗網路 GAN

原文 generative adversarial networks 模型組成 核心公式 演算法圖示化描述 全域性最優點 pg pdata 效果與對比展望 ming maxdv d,g exp data x logd x exp x x log 1 d g z 分析 上方為 gan 網路的核心演算法...

GAN(生成對抗網路)

gan,generative adversarial network.起源於2014年,nips的一篇文章,generative adversarial net.gan,是一種二人博弈的思想,雙方利益之和是乙個常數,是固定的。你的利益多點,對方利益就少點。gan裡面,博弈雙方是 乙個叫g 生成模型 ...