pytorch中的迴圈神經網路模組

2021-08-28 20:13:07 字數 884 閱讀 3687

對於最簡單的rnn,我們可以使用以下兩個方法呼叫,分別是torch.nn.rnncell()torch.nn.rnn(),這兩種方式的區別在於rnncell()只能接受序列中單步的輸入,且必須傳入隱藏狀態,而rnn()可以接受乙個序列的輸入,缺省會傳入全 0 的隱藏狀態,也可以自己申明隱藏狀態傳入。

rnn()的引數:

input_size 表示輸入特徵的維度;

hidden_size表示輸出特徵的維度;

num_layers表示網路的層數;

nonlinearity表示選用的是非線性啟用函式,預設是『tanh』;

bias表示是否使用偏置,預設是使用;

batch_first 表示輸入資料的形式,預設是 false,就是這樣形式,(seq, batch, feature),也就是將序列長度放在第一位,batch 放在第二位

dropout 表示是否在輸出層應用 dropout;

bidirectional 表示是否使用雙向的 rnn,預設是 false;

對於rnncell(),裡面的引數就少很多,只有 input_size,hidden_size,bias 以及 nonlinearity;

一般情況下我們都是用nn.rnn()而不是nn.rnncell(),因為nn.rnn()能夠避免我們手動寫迴圈,非常方便,同時如果不特別說明,我們也會選擇使用預設的全 0 初始化隱藏狀態。

lstm 和基本的 rnn 是一樣的,他的引數也是相同的,同時他也有nn.lstmcell()nn.lstm()兩種形式。

RNN 迴圈神經網路 分類 pytorch

import torch from torch import nn import torchvision.datasets as dsets import torchvision.transforms as transforms import matplotlib.pyplot as plt imp...

pytorch(八) RNN迴圈神經網路 分類

import torch import torch.nn as nn import torchvision.transforms as transforms from torch.autograd import variable import matplotlib.pyplot as plt imp...

pytorch迴圈神經網路引數說明

時常遇到迴圈神經網路,偶爾也會使用迴圈神經網路模型,但是很容易將rnn中一些引數含義忘記,既然不能像迴圈神經網路能記憶歷史資訊,那我只好將rnn引數內容整理成文件,方便日後查閱使用。以下是rnn中引數含義 input size 輸入x的特徵維度 hidden size 隱藏層特徵數量 num lay...