pytorch中不定長序列補齊方法

2021-10-21 19:30:34 字數 1270 閱讀 3898

第二種方法通常是在load乙個batch資料時, 在collate_fn中進行補齊的. collate_fn使用方法

以下給出兩種思路:

第一種思路是比較容易想到的, 就是對乙個batch的樣本進行遍歷, 然後使用np.pad對每乙個樣本進行補齊.

for unit in data:

mask = np.zeros(max_length)

s_len =

len(unit[0]

)# calculate the length of sequence in each unit

mask[

: s_len]=1

unit[0]

= np.pad(unit[0]

,(0, max_length - s_len)

,'constant'

, constant_values=(0

,0))

但是這種方法在batch size很大的情況下會很慢, 因為使用for迴圈進行了遍歷. 我在實際用的時候, 當batch_size=128時, 乙個batch的載入時間甚至是乙個batch訓練時間的幾倍!

因此, 我想到如何並行地對序列進行補齊. 第二種方法的思路就是使用torch中自帶的pad_sequence來並行補齊.

batch_sequence =

list

(map

(lambda x: torch.tensor(x[findex]

), x_data)

) batch_data[feat]

= torch.nn.utils.rnn.pad_sequence(batch_sequence)

.t

可以看到這裡使用pad_sequence一次性對整個batch進行補齊. 下面對這個函式進行詳細說明.

from torch.utils.rnn import pad_sequence

a = torch.ones(10)

b = torch.ones(6)

c = torch.ones(20)

abc = pad_sequence(

[a,b,c]

)# shape(20, 3)

注意這個函式接收的是乙個元素為tensor的列表, 而不是tensor.

最終, 這個函式會將所有tensor轉換為tensor矩陣#shape(max_length, batch_size). 因此, 在使用完後通常還需要轉置一下.

Java中不定長度的引數

什麼是不定長度的引數呢,就是沒有規定引數的長度,可以用三個小數點意為省略的意思,比如下面 package laojiuxuetangzhixunhuan public class testcanshu public static void paramtest string s,int nums sy...

golang中定義不定長陣列的方法

go語言提供了陣列型別的資料結構。陣列是具有相同唯一型別的一組已編號且長度固定的資料項序列,這種型別可以是任意的原始型別例如整形 字串或者自定義型別。宣告陣列 go 語言陣列宣告需要指定元素型別及元素個數,語法格式如下 var variable name size variable type以上為一...

oracle中建立自增長序列

首先建立序列 create sequence incr stu id seq minvalue 1 start with 1 increment by 1 nomaxvalue nocache 然後建立觸發器 create or replace trigger incr stu id trig be...