深度學習 GAN生成式對抗網路

2021-09-25 18:05:32 字數 2895 閱讀 5314

生成式對抗網路(gan,generative adversarial network)的簡單理解就是,想想一名偽造者試圖偽造一幅畢卡索的畫作。一開始,偽造者非常不擅長這項任務,他隨便畫了幅與畢卡索真跡放在一起,請鑑定商進行評估,鑑定商鑑定後,將結果反饋給偽造者,並告訴他怎樣可以讓❀看起來更像畢卡索的真跡。偽造者學習後回去重新畫,然後再拿給鑑定商鑑定,多次迴圈後,偽造者已經十分熟練的偽造畢卡索的畫作了,鑑定商的鑑定能力也有了很大的提高。最後,他們手上擁有了一些優秀的畢卡索贗品。

下邊的例子是使用gan模擬二次方程:

import torch

import torch.nn as nn

import matplotlib.pyplot as plt

import numpy as np

"""超引數"""

batch_size = 64 # 每批資料個數

lr_g = 0.0001 # 生成器的學習率(偽造者)

lr_d = 0.0001 # 判別器的學習率(鑑定商)

n_ideas = 5 # 隨機想法個數

art_components = 15 # 線段上資料點個數

paint_points = np.vstack([np.linspace(-1, 1, art_components) for _ in range(batch_size)]) # h恩座標範圍

"""建立聖經網路"""

def artist_works():

a = np.random.uniform(1, 2, size=batch_size)[:, np.newaxis]

paintings = a * np.power(paint_points, 2) + (a - 1) # 定義原始一元二次方程(一開始的真品)

paintings = torch.from_numpy(paintings).float() # 轉換為torch形式

return paintings

g = nn.sequential( # 生成器(偽造者)

nn.linear(n_ideas, 128), # 輸入隨即想法

nn.relu(),

nn.linear(128, art_components) # 生成15個點連線(創造乙個贗品)

)d = nn.sequential( # 判別器(鑑定上)

nn.linear(art_components, 128), # 接受生成器生成的資料(獲得贗品)

nn.relu(),

nn.linear(128, 1), # 判別是否和原始資料相似(鑑定贗品是真是假)

nn.sigmoid() # 產生百分比,表示是什麼資料(表示是真品還是贗品)

)opt_d = torch.optim.rmsprop(d.parameters(), lr=lr_d) # 優化判別器

opt_g = torch.optim.rmsprop(g.parameters(), lr=lr_g) # 優化生成器

"""訓練神經網路"""

plt.ion()

for step in range(5000):

artist_paintings = artist_works() # 先獲取原始標準方程(一開始的真品)

g_ideas = torch.randn(batch_size, n_ideas) # 隨機生成資料(想法)

g_paintings = g(g_ideas) # 生成器產生方程(創造贗品)

prob_artist0 = d(artist_paintings) # 計算式標準方程的概率(真品的概率)

prob_artist1 = d(g_paintings) # 計算式偽造方程的概率(贗品的概率)

d_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1 - prob_artist1)) # 增加標準方程的概率

g_loss = torch.mean(torch.log(1 - prob_artist1)) # 增加偽造方程被認為是真方程的概率

opt_d.zero_grad()

d_loss.backward(retain_graph=true)

opt_d.step()

opt_g.zero_grad()

g_loss.backward()

opt_g.step()

"""迴圈列印"""

if step % 100 == 0: # 每100步列印一次

plt.cla() # 清空上一次的

plt.plot(paint_points[0], g_paintings.data.numpy()[0], c='#4ad631', lw=3,

label='generated painting', ) # 隨機生成的方程

plt.plot(paint_points[0], 2 * np.power(paint_points[0], 2) + 1, c='#74bcff', lw=3,

label='upper bound') # 標準方程上界

plt.plot(paint_points[0], 1 * np.power(paint_points[0], 2) + 0, c='#ff9359', lw=3,

label='lower bound') # 標準方程下界

生成式對抗網路GAN

判別式模型和生成是模型的區別 假設研究物件為變數為x,類別變數為y,則 判別式模型 只是對給定的樣本進行分類,不關心資料如何生成。按照一定的判別準則,從資料中直接學習決策函式y f x 或者條件概率分布p y x a 作為 的模型 典型的判別模型包括 k近鄰法,決策樹,最大熵模型,支援向量機等 生成...

生成式對抗網路GAN

一 背景 gan的用途 影象超畫素 背景模糊 影象修復 二 生成式對抗網路gan 生成模型 乙個能夠生成我們想要的資料的模型 圖模型 函式 神經網路 gan 目的就是訓練乙個生成模型,生成我們想要的資料 生成器和判別器是對抗的 三 訓練演算法 隨機初始化生成器和判別器 交替訓練判別器d和生成器g直到...

深度學習之生成對抗網路(Gan)

概念 生成對抗網路 gan,generative adversatial networks 是一種深度學習模型,近年來無監督學習上最具前景的方法之一。模型主要通用框架有 至少 兩個模組 生成模型 generative 和判別模型 discriminative 的互相博弈學習產生的相當好的輸出。原始g...