pytorch中attention的兩種實現方式

2021-10-01 13:33:03 字數 3427 閱讀 7184

"""採用seq to seq模型,修改注意力權重的計算方式

"""def __init__(self, hidden_size, output_size, dropout_p=0.1):

super(attentiondecoderv2, self).__init__()

self.hidden_size = hidden_size

self.output_size = output_size

self.dropout_p = dropout_p

self.embedding = nn.embedding(self.output_size, self.hidden_size)

self.attn_combine = nn.linear(self.hidden_size * 2, self.hidden_size)

self.dropout = nn.dropout(self.dropout_p)

self.gru = nn.gru(self.hidden_size, self.hidden_size)

self.out = nn.linear(self.hidden_size, self.output_size)

# test

self.vat = nn.linear(hidden_size, 1)

def forward(self, input, hidden, encoder_outputs):

embedded = self.embedding(input) # 前一次的輸出進行詞嵌入

embedded = self.dropout(embedded)

# test

batch_size = encoder_outputs.shape[1]

alpha = hidden + encoder_outputs # 特徵融合採用+/concat其實都可以

alpha = alpha.view(-1, alpha.shape[-1])

attn_weights = self.vat( torch.tanh(alpha)) # 將encoder_output:batch*seq*features,將features的維度降為1

attn_weights = attn_weights.view(-1, 1, batch_size).permute((2,1,0))

attn_weights = f.softmax(attn_weights, dim=2)

# attn_weights = f.softmax(

# self.attn(torch.cat((embedded, hidden[0]), 1)), dim=1) # 上一次的輸出和隱藏狀態求出權重

encoder_outputs.permute((1, 0, 2))) # 矩陣乘法,bmm(8×1×56,8×56×256)=8×1×256

output = self.attn_combine(output).unsqueeze(0)

output = f.relu(output)

output, hidden = self.gru(output, hidden)

output = f.log_softmax(self.out(output[0]), dim=1) # 最後輸出乙個概率

return output, hidden, attn_weights

def inithidden(self, batch_size):

result = variable(torch.zeros(1, batch_size, self.hidden_size))

return result

配圖是第一種

Pytorch 中 torchvision的錯誤

在學習pytorch的時候,使用 torchvision的時候發生了乙個小小的問題 安裝都成功了,並且import torch也沒問題,但是在import torchvision的時候,出現了如下所示的錯誤資訊 dll load failed 找不到指定模組。首先,我們得知道torchvision在...

Pytorch中建立DataLoader的幾種方法

簡介 這段 是mnist手寫體識別中的部分 此篇 為mnist手寫體識別中的 import torch import torchvision import torchvision.transforms as transforms from torch.utils.data import datalo...

pytorch中的乘法

總結 按元素相乘用torch.mul,二維矩陣乘法用torch.mm,batch二維矩陣用torch.bmm,batch 廣播用torch.matmul if name main a torch.tensor 1 2,3 b torch.arange 0,12 reshape 4 3 c torch...