神經網路高維互資訊計算Python實現(MINE)

2021-10-12 08:39:28 字數 3788 閱讀 8588

belghazi, mohamed ishmael, et al. 「mutual information neural estimation.」 international conference on machine learning. 2018.

利用神經網路的梯度下降法可以實現快速高維連續隨機變數之間互資訊的估計,上述**提出了mutual information neural estimator (mine)。nn在維度和樣本量上都是線性可伸縮的,mi的計算可以通過反向傳播進行訓練。

現有github上的**無法計算和估計高維隨機變數,只能計算一維隨機變數,下面的**給出的修改方案能夠計算真實和估計高維隨機變數的真實互資訊。

其中,為了計算理論的真實互資訊,我們不直接暴力求解矩陣(耗時,這也是為什麼要有mine的原因),我們採用給定生成隨機變數的引數計算理論互資訊。

signal_noise = 0.2

signal_power = 3

完整**基於pytorch

# name: mine_******

# author: reacubeth

# time: 2020/12/15 18:49

# mail: [email protected]

# site: www.omegaxyz.com

# *_*coding:utf-8 *_*

import numpy as np

import torch

import torch.nn as nn

from tqdm import tqdm

import matplotlib.pyplot as plt

signal_noise =

0.2signal_power =

3data_dim =

3num_instances =

20000

defgen_x

(num, dim)

:return np.random.normal(0.

, np.sqrt(signal_power)

,[num, dim]

)def

gen_y

(x, num, dim)

:return x + np.random.normal(0.

, np.sqrt(signal_noise)

,[num, dim]

)def

true_mi

(power, noise, dim)

:return dim *

0.5* np.log2(

1+ power/noise)

mi = true_mi(signal_power, signal_noise, data_dim)

print

('true mi:'

, mi)

hidden_size =

10n_epoch =

500class

mine

(nn.module)

:def

__init__

(self, hidden_size=10)

:super

(mine, self)

.__init__(

) self.layers = nn.sequential(nn.linear(

2* data_dim, hidden_size)

, nn.relu(),

nn.linear(hidden_size,1)

)def

forward

(self, x, y)

: batch_size = x.size(0)

tiled_x = torch.cat(

[x, x,

], dim=0)

idx = torch.randperm(batch_size)

shuffled_y = y[idx]

concat_y = torch.cat(

[y, shuffled_y]

, dim=0)

inputs = torch.cat(

[tiled_x, concat_y]

, dim=1)

logits = self.layers(inputs)

pred_xy = logits[

:batch_size]

pred_x_y = logits[batch_size:

] loss =

- np.log2(np.exp(1)

)*(torch.mean(pred_xy)

- torch.log(torch.mean(torch.exp(pred_x_y)))

)# compute loss, you'd better scale exp to bit

return loss

model = mine(hidden_size)

optimizer = torch.optim.adam(model.parameters(

), lr=

0.01

)plot_loss =

all_mi =

for epoch in tqdm(

range

(n_epoch)):

x_sample = gen_x(num_instances, data_dim)

y_sample = gen_y(x_sample, num_instances, data_dim)

x_sample = torch.from_numpy(x_sample)

.float()

y_sample = torch.from_numpy(y_sample)

.float()

loss = model(x_sample, y_sample)

model.zero_grad(

) loss.backward(

) optimizer.step(

)-loss.item())

fig, ax = plt.subplots(

)ax.plot(

range

(len

(all_mi)

), all_mi, label=

'mine estimate'

)ax.plot([0

,len

(all_mi)],

[mi, mi]

, label=

'true mutual information'

)ax.set_xlabel(

'training steps'

)ax.legend(loc=

'best'

)plt.show(

)

結果

變數維度為1

變數維度為3

需要指出的是在計算最終的互資訊時需要將基數e轉為基數2。如果只是求得乙個比較值,在真實使用的過程中可以省略。

更多互資訊公式

互資訊建立基因網路(一)

之前一直摸不著頭腦,現在我心裡有點道道了。寫在前面的話 我一直以為研究生有時候之所以很難,是因為沒人帶你玩了,要靠你自己去學習。既然是自己探索,基礎就很重要。一句話,沒有基礎,談不上創新。但是很多時候,我們接觸的都是新領域的知識,怎麼才能很快入手呢?我就是特別喜歡找最簡單的最有趣的材料去學,總之越容...

一維卷積神經網路 卷積神經網路中的計算

卷積的基本介紹 卷積操作後張量的大小計算 卷積參數量的計算 卷積flops的計算 感受野的計算 卷積神經網路中的卷積是指定義好卷積核 kernel 並對影象 或者特徵圖,feature map 進行滑動匹配,即對應位置相乘再相加。其特點就在於能夠捕捉區域性的空間特徵。具體過程如下圖所示 圖1 二維卷...

一維卷積神經網路的理解

設輸入的資料維度是b x s x t 一維卷積神經網路在維度s上進行卷積 如下,設定一維卷積網路的輸入通道為16維,輸出通道為33維,卷積核大小為3,步長為2 in channels 16 out channels 33 kernel size 3 m nn.conv1d 16,33,3,strid...