pytorch重寫DataLoader載入本地資料

2021-10-21 02:05:00 字數 4465 閱讀 3380

前兩天學習了huggingface datasets來寫乙個資料載入指令碼,但是,在實驗中發現,使用dataloader載入資料的便捷性,這兩天查資料勉強重寫dataloader載入本地資料,在這裡記錄下,如果有錯誤,歡迎指正。 總結

在pytorch官網搜尋dataloader,返回的一篇教程是

writing custom datasets, dataloaders and transforms講了編寫自定義資料集,資料載入器和轉換,但是講的是影象資料,在這裡我還使用蘇劍林老師在《基於cnn的閱讀理解式問答模型:dgcnn》裡提供的webqa和sogouqa資料集來重寫dataloader,因為蘇劍林老師是用bert4keras庫 來寫資料集載入,我看的是task_reading_comprehension_by_mlm.py的資料載入格式,有部分**我沒有看懂,我按照自己的理解來重寫的。

torch.utils.data.dataset是表示資料集的抽象類。自定義資料集應繼承資料集並覆蓋以下方法:

對自己的資料集首先要建立乙個dataset 類。在__len__中讀取檔案,要將讀取的檔案載入到__getitem__中,這樣的話所有資料不會立即儲存到記憶體中,二十根據需要讀取,因此可以提高記憶體效率。視覺化演示可以參考這裡。

torch.utils.data.dataset 是乙個抽象類,它只能被繼承。在b站上講解的有兩個我參考的比較好的教程,飯客帆和劉二大人的pytorch教程都很好。

首先是導包。

import re

import torch

import tokenizers

from torch.utils.data import dataset, dataloader

import numpy as np

這裡有幾個要預定義的超引數:

max_length =

384max_p_len =

256max_q_len =

64max_a_len =

32

在蘇神的示例中

輸入:[cls][mask][mask][sep]問題[sep]篇章[sep]

輸出:答案

先要說明的是,蘇神的這個資料載入的**我沒有完全看懂,全部的**我暫時還沒看完,我先按照自己的理解來寫自己的,等我看完所有的**會回來重新修改這篇文章。蘇神的原始碼等我看懂了會回來修改的。

class

mydataset

(dataset)

:# 繼承dataset模組的dataset類

# 初始化定義,得到資料內容

def__init__

(self, data_set)

:super

(mydataset, self)

.__init__(

) self.data_set = data_set # 載入資料集

self.length =

len(data_set)

# 資料集長度

# 返回資料集大小

def__len__

(self)

:return self.length

# 資料預處理,這部分根據自己的資料集進行處理

def__getitem__

(self, index)

:# index(或item)不能少,這個引數是來挑選某條資料的

# d = self.data_set[index]

# 從data_set中取樣乙個資料

token_ids, segment_ids, a_token_ids =,,

question = d[

'question'

] answers =

[p['answer'

]for p in d[

'passages'

]if p[

'answer']]

# 蘇神的**,我沒有看懂是否挑選了無答案的,這裡是自己改的。挑選帶答案的文章

passage =

""# 先宣告再使用是個好習慣,不然會報錯

for pre_passage in d[

'passages']:

if pre_passage[

'answer']:

passage = pre_passage[

'passage'

]break

passage = re.sub(u' |、|;|,'

,','

, passage)

# 清洗資料

final_answer =

''for answer in answers:

# 選擇答案

ifall

([a in passage[

:max_p_len -2]

for a in answer.split(

' ')])

: final_answer = answer.replace(

' ',

',')

break

# print(question)

a_token_ids = tokenizer.encode(final_answer, max_length=max_a_len +

1, padding=

"max_length"

, truncation=

true

)# 答案編碼

q_token_ids = tokenizer.encode(question, max_length=max_q_len +

1, truncation=

true

)# 對問題進行截斷

p_token_ids = tokenizer.encode(passage, max_length=max_p_len +

1, truncation=

true

) token_ids +=

[tokenizer.mask_token_id]

* max_a_len

token_ids +=

[tokenizer.sep_token_id]

token_ids +=

(q_token_ids[1:

]+ p_token_ids[1:

-1])

# [mask][mask][sep]問題[sep]篇章

token_ids = tokenizer.encode(tokenizer.convert_ids_to_tokens(token_ids)

, max_length=max_length ,

padding=

"max_length"

, truncation=

true

)# [cls][mask][mask][sep]問題[sep]篇章[sep]

segment_ids =[0

]*len(token_ids)

token_ids = torch.as_tensor(token_ids)

segment_ids = torch.as_tensor(segment_ids)

a_token_ids = torch.as_tensor(a_token_ids)

return

[token_ids, segment_ids]

, a_token_ids

**如下(示例):

首先,例項化

data = mydataset(train_data)
輸出一下結果

這裡自己重寫了dataloader,有需要學習dataloader載入本地資料的,可以仿照寫就可以了。

PyTorch學習 安裝PyTorch

例如,使用的是 windows 系統,想用 pip 安裝,python 是 3.6 版的,沒有 gpu 加速,那就按上面的選,然後根據上面的提示,在 terminal 中輸入以下指令就好了 pip3 install torch 1.3.1 cpu torchvision 0.4.2 cpu ftor...

Pytorch 通過pytorch實現線性回歸

linear regression 線性回歸是分析乙個變數與另外乙個 多個 變數之間關係的方法 因變數 y 自變數 x 關係 線性 y wx b 分析 求解w,b 求解步驟 1.確定模型 2.選擇損失函式 3.求解梯度並更新w,b 此題 1.model y wx b 下為 實現 import tor...

PyTorch入門(三)PyTorch常用操作

def bilinear kernel in channels,out channels,kernel size return a bilinear kernel tensor tensor in channels,out channels,kernel size,kernel size 返回雙線性...