RNN 網路中文字的 pack 和 pad 操作

2022-07-08 02:06:07 字數 4659 閱讀 4675

rnn 模型一般設定固定的文字長度(text sequence length,可理解為文字序列在時間維度上的步數 time step),以保證網路輸出

層資料維度的一致性。

但在訓練和測試時,難以保證輸入文字長度的一致性,因此常常需要截斷操作(即將超過預設長度的文字截斷)

和 pad 操作(即對不足預設長度的文字進行補 0 填充)。

pytorch 中,在文字資料的 transfrom 以及 rnn 網路的輸入階段,均充分考慮了 pad 操作。其主要體現在:

(1)rnn、lstm 和 gru 等網路的輸入資料均可為 packedsequence 型別資料;

(2)可通過 pad_sequence、pack_sequence、pack_padded_sequence 和 pad_packed_sequence 等操作,實現 pad 和 pack 操作。

1. pack_sequence

import torch

from torch.nn.utils.rnn import pack_sequence, pad_sequence, pad_packed_sequence, pack_padded_sequence

text1 = torch.tensor([1,2,3,4]) # 可視為有 4 個文字的樣本

text2 = torch.tensor([5,6,7]) # 可視為有 3 個文字的樣本

text3 = torch.tensor([8,9]) # 可視為有 2 個文字的樣本

sequences = [text1, text2, text3] # 三個文字序列拼接

x = pack_sequence(sequences)

print(x)

"""packedsequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 4]), batch_sizes=tensor([3, 3, 2, 1]),

sorted_indices=none, unsorted_indices=none)

"""

輸入資料是由 tensor 列表,每個 tensor 表示乙個序列資料。pack 後的返回值包括兩資料。一類為 data,即壓縮後的資料;

而另一類 batch_sizes 表示每個時間步 batch 中包含的樣本量。

值的注意的是,sequences 列表內的各元素長度必須按照降序排列,也就是越長的文字應放在前面,輸入的 batch * sequence 矩陣為上三角陣。

前文提到的 rnn 網路中可以接收的 input 資料可以為 packedsequence 型別資料,即是類似於這裡的返回值。

text1 = torch.tensor([1,2,3,4]) # 可視為有 4 個文字的樣本

text2 = torch.tensor([5,6,7]) # 可視為有 3 個文字的樣本

text3 = torch.tensor([8,9]) # 可視為有 2 個文字的樣本

sequences = [text1, text2, text3] # 三個文字序列拼接

x = pad_sequence(sequences)

print(x)

"""tensor([[1, 5, 8],

[2, 6, 9],

[3, 7, 0],

[4, 0, 0]])

"""pad 操作即是將不同長度的文字序列進行對齊的填充過程。預設情況下,引數 batch_first=false,這裡指定的是輸出資料的形狀,

有些函式的這個引數是用來指明輸入資料的形狀,注意區分。pad_sequence 輸入資料的形狀和 pack_sequence 是一樣的。

與 pack 操作不同,pad 操作對於 sequences 列表內的各元素長度順序並無要求。

觀察上述 pack 和 pad 操作,返回結果均傾向於按照序列 sequece 的順序進行輸出,而將 batch 的輸出順序後置,其實這是 pytorch 中

整個 rnn 網路的統一推薦用法,觀察 rnn、lstm 和 gru 等網路架構,引數 batch_first 的預設值均為 false!

顧名思義,

這個函式的輸入是 pad_sequence 函式的輸出,也就是填充後的資料。

import torch

from torch.nn.utils.rnn import pack_sequence, pad_sequence, pad_packed_sequence, pack_padded_sequence

text1 = torch.tensor([1,2,3,4]) # 可視為有 4 個文字的樣本

text2 = torch.tensor([5,6,7]) # 可視為有 3 個文字的樣本

text3 = torch.tensor([8,9]) # 可視為有 2 個文字的樣本

sequences = [text1, text2, text3] # 三個文字序列拼接

x = pad_sequence(sequences, batch_first=true) # batch_first 指定輸出資料的形狀

print(x)

y = pack_padded_sequence(x, lengths=[4,3,2], batch_first=true) # batch_first 指明輸入資料的形狀

print(y)

"""tensor([[1, 2, 3, 4],

[5, 6, 7, 0],

[8, 9, 0, 0]])

packedsequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 4]), batch_sizes=tensor([3, 3, 2, 1]),

sorted_indices=none, unsorted_indices=none)

"""

pack_padded_sequence 函式的作用過程可分解為如下步驟:

(1)接收乙個 padded_sequence 資料;

(2)根據 batch_first 引數明確該資料的布局(預設為 batch_first=false);

(3)根據 lengths 引數明確 batch 內各樣本的時間步長,選擇資料;注意列表內的元素必須為降序。

(4)將上述資料按照時間維度進行壓縮,得到目標的 packedsequence 型別資料。

4. pad_packed_sequence

pad_packed_sequence 函式即為 pack_padded_sequence 的逆操作,其在引數設定時也許注意通過 batch_first 控制返回值的維度順序,

同時可通過設定 total_lengths 來控制 pad 後的總步長(該值必須不小於輸入 packedsequence 資料的步長數)。

import torch

from torch.nn.utils.rnn import pack_sequence, pad_sequence, pad_packed_sequence, pack_padded_sequence

text1 = torch.tensor([1,2,3,4]) # 可視為有 4 個文字的樣本

text2 = torch.tensor([5,6,7]) # 可視為有 3 個文字的樣本

text3 = torch.tensor([8,9]) # 可視為有 2 個文字的樣本

sequences = [text1, text2, text3] # 三個文字序列拼接

x = pack_sequence(sequences)

print(x)

y = pad_packed_sequence(x, total_length=5, batch_first=true)

print(y)

"""packedsequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 4]), batch_sizes=tensor([3, 3, 2, 1]),

sorted_indices=none, unsorted_indices=none)

(tensor([[1, 2, 3, 4, 0],

[5, 6, 7, 0, 0],

[8, 9, 0, 0, 0]]), tensor([4, 3, 2]))

"""

簡單的RNN和BP多層網路之間的區別

先來個簡單的多層網路 關於rnn出現的原因,rnn詳細的原理,已經有很多博文講解的非常棒了。如下 多層網路 x tf.placeholder tf.float32,none,256 y tf.placeholder tf.float32,none,10 w1 tf.variable tf.rando...

OpenInventor中文字元的顯示和解決方法

很多人問我如何顯示漢字,總是乙個乙個地說,很麻煩,特此寫在這裡 原因 openinventor支援freetype字型,但漢字採用unicode編碼,故而無法直接正確顯示 解決 使用freetype字型,或者,將unicode編碼轉換為freetype可以識別的字型 使用內建的freetype引擎 ...

VIM中文字的替換和複製

1.替換當前行中的內容 s from to s即substitude s from to 將當前行中的第乙個from,替換成to。如果當前行含有多個from,則只會替換其中的第乙個。s from to g 將當前行中的所有from都替換成to。s from to gc 將當前行中的所有from都替換...