PyTorch學習總結 四 Utilities

2021-08-11 21:32:38 字數 3165 閱讀 2183

這個類的例項不能手動建立。它們只能被pack_padded_sequence() 例項化。

torch.nn.utils.rnn.pack_padded_sequence()**

輸入:input:[seq_length x batch_size x input_size] 或 [batch_size x seq_length x input_size],input中的seq要按照長度遞減的方式排列。

lengths:seq的長度列表,是乙個遞減的列表,與input裡的seq長度對應。ie. [5,4,1]

batch_first:bool變數,當它為true時,表示input為這種輸入形式[batch_size x seq_length x input_size],否則為另一種。

輸出:

乙個packedsequence物件,包含乙個variable型別的data,和鍊錶型別的batch_sizes。

batch的每乙個元素,代表data中,多少行為乙個batch。

例如:輸入為

input

variable containing:

(0 ,.,.) =

123(1 ,.,.) =

100[torch.floattensor of size 2x3x1]

lengths = [3, 1]

為了實現壓縮編碼,即把填充去除。我們最終的輸出為

packedsequence(data=variable containing: 11

23[torch.floattensor of size 4x1]

, batch_sizes=[2, 1, 1])

這就表明,前兩個1屬於乙個batch,後面兩個分別屬於不同的batch。換句話說,從batch_sizes可以看出,兩個seq的長度分別為1,3。後面的module或function可以根據batch_sizes讀取對應的資料。

**詳解

這裡我們以上面的輸入為例,研究該函式到底是怎麼實現資料壓縮的。

def

pack_padded_sequence

(input, lengths, batch_first=false):

# juge the length is > 0

if lengths[-1] <= 0:

raise valueerror("length of all samples has to be greater than 0, "

"but found an element in 'lengths' that is <=0")

# change the input into the shape of [seq_length x batch_size x input_size]

# here input is [3, 2, 1]

if batch_first:

input = input.transpose(0, 1)

steps =

batch_sizes =

# get the reversed iterator of the lengths

lengths_iter = reversed(lengths)

# here current_length == 1

current_length = next(lengths_iter)

batch_size = input.size(1)

if len(lengths) != batch_size:

raise valueerror("lengths array has incorrect size")

# here 1 indicate the 'step' start from 1

for step, step_value in enumerate(input, 1):

"""step_value == 1

1[torch.floattensor of size 2x1]

"""# juge if step to the end of a short seq

while step == current_length:

try:

new_length = next(lengths_iter)

except stopiteration:

current_length = none

break

# check the lengths if is a decrasing list

if current_length > new_length: # remember that new_length is the preceding length in the array

raise valueerror("lengths array has to be sorted in decreasing order")

# already step over a short seq, so the number of the batch should minus 1.

batch_size -= 1

current_length = new_length

if current_length is

none:

break

# here concat the list along the dim0.

return packedsequence(torch.cat(steps), batch_sizes)

nn.utils.rnn.pad_packed_sequence()

這就是上乙個函式的逆操作。輸入是乙個packedsequence物件,包含batch_sizes,可以根據其對其中的data進行解耦。

pytorch總結學習系列 操作

算術操作 在pytorch中,同一種操作可能有很多種形式,下 用加法作為 加法形式 x torch.tensor 5.5,3 y torch.rand 5,3 print x y 加法形式 print torch.add x,y 還可指定輸出 result torch.empty 5,3 torch...

學習pytorch(四)簡單RNN舉例

import torch 簡單rnn學習舉例。rnn 迴圈神經網路 是把乙個線性層重複使用,適合訓練序列型的問題。單詞是乙個序列,序列的每個元素是字母。序列中的元素可以是任意維度的。實際訓練中,可以首先把序列中的元素變為合適的維度,再交給rnn層。學習 將hello 轉為 ohlol。dict e ...

pytorch總結學習系列 資料操作

在深度學習中,我們通常會頻繁地對資料進 行 操作。作為動 手學深度學習的基礎,本節將介紹如何對內 存中的資料進 行 操作。在pytorch中,torch.tensor 是儲存和變換資料的主要 工具。如果你之前 用過numpy,你會發現 tensor 和numpy的多維陣列 非常類似。然 tensor...