pytorch rnn 變長輸入序列問題

2021-10-16 17:48:35 字數 3660 閱讀 6595

輸入資料是長度不固定的序列資料,主要講解兩個部分

data.dataloader的collate_fn用法,以及按batch進行padding資料

pack_padded_sequence和pad_packed_sequence來處理變長序列

dataloader的collate_fn引數,定義資料處理和合併成batch的方式。

由於pack_padded_sequence用到的tensor必須按照長度從大到小排過序的,所以在collate_fn中,需要完成兩件事,一是把當前batch的樣本按照當前batch最大長度進行padding,二是將padding後的資料從大到小進行排序。

def pad_tensor(vec, pad):

"""args:

vec - tensor to pad

pad - the size to pad to

return:

a new tensor padded to 'pad'

"""return torch.cat([vec, torch.zeros(pad - len(vec), dtype=torch.float)], dim=0).data.numpy()

class collate:

"""a variant of callate_fn that pads according to the longest sequence in

a batch of sequences

"""def __init__(self):

pass

def _collate(self, batch):

"""args:

batch - list of (tensor, label)

reutrn:

xs - a tensor of all examples in 'batch' before padding like:

'''[tensor([1,2,3,4]),

tensor([1,2]),

tensor([1,2,3,4,5])]

'''ys - a longtensor of all labels in batch like:

'''[1,0,1]

'''"""

xs = [torch.floattensor(v[0]) for v in batch]

ys = torch.longtensor([v[1] for v in batch])

# 獲得每個樣本的序列長度

seq_lengths = torch.longtensor([v for v in map(len, xs)])

max_len = max([len(v) for v in xs])

# 每個樣本都padding到當前batch的最大長度

xs = torch.floattensor([pad_tensor(v, max_len) for v in xs])

# 把xs和ys按照序列長度從大到小排序

seq_lengths, perm_idx = seq_lengths.sort(0, descending=true)

xs = xs[perm_idx]

ys = ys[perm_idx]

return xs, seq_lengths, ys

def __call__(self, batch):

return self._collate(batch)

定義完collate類以後,在dataloader中直接使用

train_data = data.dataloader(dataset=train_dataset, batch_size=32, num_workers=0, collate_fn=collate())
pack_padded_sequence將乙個填充過的變長序列壓緊。輸入引數包括

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# x是填充過後的batch資料,seq_lengths是每個樣本的序列長度

packed_input = pack_padded_sequence(x, seq_lengths, batch_first=true)

定義了乙個單向的lstm模型,因為處理的是變長序列,forward函式傳入的值是乙個packedsequence物件,返回值也是乙個packedsequence物件

class model(nn.module):

def __init__(self, in_size, hid_size, n_layer, drop=0.1, bi=false):

super(model, self).__init__()

self.lstm = nn.lstm(input_size=in_size,

hidden_size=hid_size,

num_layers=n_layer,

batch_first=true,

dropout=drop,

bidirectional=bi)

# 分類類別數目為2

self.fc = nn.linear(in_features=hid_size, out_features=2)

def forward(self, x):

''':param x: 變長序列時,x是乙個packedsequence物件

:return: packedsequence物件

'''# lstm_out: tensor of shape (batch, seq_len, num_directions * hidden_size)

lstm_out, _ = self.lstm(x)

return lstm_out

model = model()

lstm_out = model(packed_input)

這個操作和pack_padded_sequence()是相反的,把壓緊的序列再填充回來。因為前面提到的lstm模型傳入和返回的都是packedsequence物件,所以我們如果想要把返回的packedsequence物件轉換回tensor,就需要用到pad_packed_sequence函式。

引數說明:

返回值: 乙個tuple,包含被填充後的序列,和batch中序列的長度列表。

用法:

# 此處lstm_out是乙個packedsequence物件

output, _ = pad_packed_sequence(lstm_out)

返回的output是乙個形狀為(batch_size,seq_len,input_size)的tensor。

pytorch在自定義dataset時,可以在dataloader的collate_fn引數中定義對資料的變換,操作以及合成batch的方式。

處理變長rnn問題時,通過pack_padded_sequence()將填充的batch資料轉換成packedsequence物件,直接傳入rnn模型中。通過pad_packed_sequence()來將rnn模型輸出的packedsequence物件轉換回相應的tensor。

pytorch rnn 變長輸入序列問題

輸入資料是長度不固定的序列資料,主要講解兩個部分 data.dataloader的collate fn用法,以及按batch進行padding資料 pack padded sequence和pad packed sequence來處理變長序列 dataloader的collate fn引數,定義資料...

可變長字串

目錄stringbuilder 其他可變長字串,jdk1.0提供,執行效率慢,執行緒安全字串緩衝區 執行緒安全的可變字串 字串行 字串 如果字串需要頻繁修改,可用stringbuffer構造方法stringbuffer 初始容量為16個字元 stringbuffer int capacity 構造乙...

struct 封裝變長字串

使用struct,可以非常方便的處理二進位制資料,將常用的int,string等型別的資料轉成二進位制資料,它有兩個重要函式,乙個是pack,乙個是unpack 先看一張表 struct中支援的格式如下表 format c type python 位元組數x pad byte no value1c ...