生成對抗網路(二)CGAN

2021-09-24 16:46:13 字數 2978 閱讀 1477

之前介紹了生成式對抗網路(gan),關於gan的變種比較多,我打算將幾種常見的gan做乙個總結,也算是激勵自己學習,分享自己的一些看法和見解。

之前提到的gan是最基本的模型,我們的輸入是隨機雜訊,輸出的是對應的影象,但是我們沒法控制生成影象的型別。比如,我要生成一張數字0的,但是gan生成的卻是數字0-9的,針對這個問題,conditional generative adversarial nets被提了出來,在原有gan的基礎上,新增了類別資訊以便讓模型生成特定的。這裡的條件(conditional),就是這個額外的類別資訊。

由於在gan的生成器和判別器中都加入了額外的類別資訊,模型的目標優化函式也發生了變化。

生成器的輸入變為噪音變數

就是在gan的目標函式上新增了y這一類別變數,x變為了條件分布。

模型的結構圖如下,

gan的結構與這個類似,生成器部分和判別器部分是分開的兩個子網路,單獨進行訓練。類別資訊y是通過embedding層嵌入的。

具體的實現可以看看**:

生成器:

def build_generator(self):

model = sequential()

model.add(dense(256, input_dim=self.latent_dim))

model.add(leakyrelu(alpha=0.2))

model.add(batchnormalization(momentum=0.8))

model.add(dense(512))

model.add(leakyrelu(alpha=0.2))

model.add(batchnormalization(momentum=0.8))

model.add(dense(1024))

model.add(leakyrelu(alpha=0.2))

model.add(batchnormalization(momentum=0.8))

model.add(dense(np.prod(self.img_shape), activation='tanh'))

model.add(reshape(self.img_shape))

model.summary()

noise = input(shape=(self.latent_dim,))

label = input(shape=(1,), dtype='int32')

label_embedding = flatten()(embedding(self.num_classes, self.latent_dim)(label))

model_input = multiply([noise, label_embedding])

img = model(model_input)

return model([noise, label], img)

標籤是通過嵌入層實現的,embedding層可以將類別標籤轉換為對應的向量表示,在此生成器中,類別有10個(0-9),對應embedding中的input_dim, 輸出維度和噪音資料是相同的,之後,再利用multiply層將兩者逐項做乘積,這便是生成器的輸入。

判別器:

def build_discriminator(self):

model = sequential()

model.add(dense(512, input_dim=np.prod(self.img_shape)))

model.add(leakyrelu(alpha=0.2))

model.add(dense(512))

model.add(leakyrelu(alpha=0.2))

model.add(dropout(0.4))

model.add(dense(512))

model.add(leakyrelu(alpha=0.2))

model.add(dropout(0.4))

model.add(dense(1, activation='sigmoid'))

model.summary()

img = input(shape=self.img_shape)

label = input(shape=(1,), dtype='int32')

label_embedding = flatten()(embedding(self.num_classes, np.prod(self.img_shape))(label))

flat_img = flatten()(img)

model_input = multiply([flat_img, label_embedding])

validity = model(model_input)

return model([img, label], validity)

判別器的輸入和生成器是一樣的,輸出是對應的的類別。

訓練:訓練採用的mnist資料集,訓練時需要將資料和對應的標籤輸入模型。

生成器和判決器作為乙個整體進行訓練的時候,判別器是不訓練的,這時只訓練生成器;當判決器作為乙個單獨的模型時,判決器會得到訓練。二者的訓練是交替進行的。

具體的**可以參考github

最後跑出來的效果還是很不錯的,我在台式電腦上跑的,用的是1050ti的顯示卡,訓練速度還比較快,一共20000輪,大概10分鐘左右跑完。

這是最後的訓練效果:

可以與前一篇部落格裡面的內容進行比較,與原始的gan相比,效果要好一些,但是還是不是很清晰。一方面,mnist提供的畫素較低,另一方面,我們採用的是全連線神經網路,對於的處理效果並不是很好。

要生成更加清晰地,可以利用dcgan,這也是我接下來要做的工作。

生成對抗網路 二 cGAN

cgan conditional gan 也是最基礎的gan模型,和gan原文同時發表在nips2014上面。事實上,cgan在gan的基礎上並沒有做很大的改動,下文會主要分析一下cgan的改動。conditional generative adversarial nets 在訓練判別器d的時候,給...

半監督生成對抗網路 生成對抗網路

一 生成對抗網路相關概念 一 生成模型在概率統計理論中,生成模型是指能夠在給定某些隱含引數的條件下,隨機生成觀測資料的模型,它給觀測值和標註資料序列指定乙個聯合概率分布。在機器學習中,生成模型可以用來直接對資料建模,也可以用來建立變數間的條件概率分布。通常可以分為兩個型別,一種是可以完全表示出資料確...

生成對抗網路

我們提出乙個框架來通過對抗方式評估生成模型,我們同時訓練兩個模型 乙個生成模型g捕捉資料分布,乙個鑑別模型d估計乙個樣本來自於訓練資料而不是g的概率。g的訓練過程是最大化d犯錯的概率。這個框架與minmax兩個玩家的遊戲相對應。在任意函式g和d的空間存在乙個唯一解,g恢復訓練資料的分布,d等於1 2...