GAN網路之入門教程(二)之GAN原理

2022-01-10 20:59:53 字數 3253 閱讀 3523

在一篇部落格gan網路從入門教程(一)之gan網路介紹中,簡單的對gan網路進行了一些介紹,介紹了其是什麼,然後大概的流程是什麼。

在這篇部落格中,主要是介紹其數學公式,以及其演算法流程。當然數學公式只是簡單的介紹,並不會設計很複雜的公式推導。如果想詳細的了解gan網路的原理,推薦去看李巨集毅老師的課程。b站和youtube上面都有。

首先我們是可以知道真實的分布函式\(p_(x)\),同時我們把假的也看成乙個概率分布,稱之為\(p_g = (x,\theta)\)。那麼我們的目標是什麼呢?我們的目標就是使得\(p_g(x,\theta)\)盡量的去逼近\(p_(x)\)。在gan中,我們使用神經網路去逼近\(p_g = (x,\theta)\)。

在生成器中,我們有如下模型:

其中\(z \sim p_(z)\),因此\(g(z)\)也是乙個針對於\(z\)概率密度分布函式。

針對於判別器,我們有\(d(x,\theta)\),其代表某一張z\(x\)為真的概率。

在generative adversarial nets**中給出了以下的目標函式,也就是gan網路需要優化的東西。

\[\begin\min _ \max _ v(d, g)=\mathbb_ \sim p_}(\boldsymbol)}[\log d(\boldsymbol)]+\mathbb_ \sim p_}(\boldsymbol)}[\log (1-d(g(\boldsymbol)))]\end

\]公式看起來很複雜,但是我們分開來看還是比較簡單的。

\(d\)網路的目標是什麼?能夠辨別真假,也就是說,給定一張真的\(x\),\(d\)網路能夠給出乙個高分,也就是\(d(x)\)盡量大一點。而針對於生成器\(g\)生成的\(g(z)\),我們希望判別器\(d\)盡量給低分,也就是\(d(g(z))\)盡量的小一點。因此\(d\)網路的目標函式如下所示:

\[\begin\max _ v(d, g)=\mathbb_ \sim p_}(\boldsymbol)}[\log d(\boldsymbol)]+\mathbb_ \sim p_}(\boldsymbol)}[\log (1-d(g(\boldsymbol)))]\end

\]在目標函式中,\(x\)代表的是真實資料(也就是真的),\(g(z)\)代表的是生成器生成的。

\(g\)網路的目標就是使得\(d(g(z))\)盡量得高分,因此其目標函式可以寫成:

\[\begin\max _ v(d, g)=\mathbb_ \sim p_}(\boldsymbol)}[\log (d(g(\boldsymbol)))]\end

\]\(d(g(z))\)盡量得高分(分數在\([0,1]\)之間),等價於\(1 - d(g(z))\)盡量的低分,因此,上述目標函式等價於:

\[\begin\min _ v(d, g)=\mathbb_ \sim p_}(\boldsymbol)}[\log (1-d(g(\boldsymbol)))]\end

\]因此我們優化\(d^*\)和優化\(g^*\)結合起來,也就是變成了**中的目標函式:

\[\begin\min _ \max _ v(d, g)=\mathbb_ \sim p_}(\boldsymbol)}[\log d(\boldsymbol)]+\mathbb_ \sim p_}(\boldsymbol)}[\log (1-d(g(\boldsymbol)))]\end

\]上面的公式看起來很合理,但是如果不存在最優解的話,一切也都是無用功。

首先,我們固定g,來優化d,目標函式為:

\(\begin v(g, d)=\mathbb_ \sim p_}(\boldsymbol)}[\log d(\boldsymbol)]+\mathbb_ \sim p_}(\boldsymbol)}[\log (1-d(g(\boldsymbol)))]\end\)

我們可以寫做:

\[\begin\begin

v(g, d) &=\int_} p_}(\boldsymbol) \log (d(\boldsymbol)) d x+\int_} p_}(\boldsymbol) \log (1-d(g(\boldsymbol))) d z \\

&=\int_} [ p_}(\boldsymbol) \log (d(\boldsymbol))+p_(\boldsymbol) \log (1-d(\boldsymbol))] d x

\end\end

\]我們設(\(d\)代表\(d(x)\),可以代表任何函式):

\[f(d) = p_(x) log d + p_g(x)log(1-d)

\]對於每乙個固定的\(x\)而言,為了使\(v\)最大,我們當然是希望\(f(d)\)越大越好,這樣積分後的值也就越大。因為固定了\(g\),因此\(p_g(x)\)是固定的,而\(p_\)是客觀存在的,則值也是固定的。我們對\(f(d)\)求導,然後令\(f'(d) = 0\),可得:

\[\begind^=\frac(x)}(x)+p_(x)}\end

\]下圖表示了,給定三個不同的 \(g1,g3,g3\) 分別求得的令 \(v(g,d)\)最大的那個$ d^∗\(,橫軸代表了\)p_$,藍色曲線代表了可能的 \(p_g\),綠色的距離代表了 \(v(g,d)\):

同理,我們可以求\(\underset\ v(g,d)\),我們將前面的得到的\(d^=\frac(x)}(x)+p_(x)}\)帶入可得:

\[%

\]其中\(jsd ( p_(x) || p_g(x))\)的取值範圍是從 \(0\)到\(log2\),其中當\(p_ = p_g\)是,\(jsd\)取最小值0。也就是說$ v(g,d)$的取值範圍是\(0\)到\(-2log2\),也就是說$ v(g,d)\(存在最小值,且此時\)p_ = p_g$。

上述我們從理論上討論了全域性最優值的可行性,但實際上樣本空間是無窮大的,也就是我們沒辦法獲得它的真實期望(\(\mathbb_ \sim p_}(\boldsymbol)}\)和\(\mathbb_ \sim p_}}(\boldsymbol)\)是未知的),因此我們使用估測的方法來進行。

\[\tilde v = \frac\sum_^ log d(x^i) + \frac\sum_^ log (1-d(\tilde x^i))

\]演算法流程圖如下所示(來自生成對抗網路——原理解釋和數學推導):

深度學習之生成對抗網路(Gan)

概念 生成對抗網路 gan,generative adversatial networks 是一種深度學習模型,近年來無監督學習上最具前景的方法之一。模型主要通用框架有 至少 兩個模組 生成模型 generative 和判別模型 discriminative 的互相博弈學習產生的相當好的輸出。原始g...

GAN生成對抗網路之生成模型

什麼是生成模型?在開始講生成對抗網路之前,我們先看一下什麼是生成模型。在概率統計理論中,生成模型是指能夠在給定某些隱含引數的條件下,隨機生成觀測資料的模型,它給觀測值和標註資料序列指定乙個聯合概率分布。在機器學習中,生成模型可以用來直接對資料建模,如根據某個變數的概率密度函式進行資料取樣,也可以用來...

生成對抗網路GAN基本入門

2.深度gan dcgan 3.條件gan 4.infogan 5.wasserstein gan 6.例項 生成器 1.1 生成對抗網路 本質 生成器 組成 1.2 數學原理 初始狀態 生成資料同真實資料差距明顯,容易判別 訓練過程 對是否真實判斷得到的loss引導生成模型更新,差距減少 最終狀態...