Pytorch 多程序在單卡上測試

2021-10-10 22:55:08 字數 1925 閱讀 6197

有些煉丹師可能機器不足,只有一張卡,然後訓練完成了,想要測試的時候,受限於影象樣本size不一致,不能合併到乙個batch中。當然解決方案很多,但有一種更自然的辦法,既使用任意解析度的輸入,同樣使用多程序在單卡上並行執行資料的測試,從而加快測試速度,趕上dideline。

大家都知道pytorch官方推廣大家使用分布式多卡平行計算,其原理是每個程序都使用乙個gpu。那我們現在要做的就是多程序都在乙個gpu上,每個程序處理一批資料,從而加快處理速度。

執行這套流程的大致思路是,我們可以先把要處理的資料加入到容器中(比如字典或者元祖列表),然後把容器中的資料分段,分段的數目就是程序數,然後每個程序處理自己對應的那些資料。由於輸入的size不一樣,所以我們要一張一張的處理,但因為開了多程序,實際的batchsize等於程序數目。

假設我們現在有一批資料,我們先要把這批資料的目錄位址加到乙個容器中,

target_loc = make_dataset(os.path.join(img_root, training_subfix)

,true

) target_loc.extend(make_dataset(os.path.join(img_root, val_subfix)

,false))

target_loc.sort(key=

lambda x: x[0]

)# 保證一致

print

(len

(target_loc)

)

make_dataset的作用就是把所有的資料的位址加入到target_loc中。target_loc是乙個list。

然後開始多程序啟動

import multiprocessing as mp # 這裡不需要使用pytorch的multiprocessing

mp = mp.get_context(

'spawn'

)pool =

n_proc =

4# 開4程序

total_num =

len(target_loc)

phrase = total_num // n_proc +

1for i in

range

(n_proc)

: divied_target = target_loc[i*phrase:

(i+1

)*phrase]

# 所有的資料分成4份

process = mp.process(target=prediction_by_dfsd, args=

(divied_target,i)

)# 每個程序都執行prediction_by_dfsd

process.start(

)for p in pool:

p.join(

)# 等待所有程序執行完畢

prediction_by_dfsd這個函式就是測試資料的函式,這個函式需要完成的是,從分段的容器中讀取影象,然後一張張的**,然後儲存結果。

另外我需要說明的是,spawn的啟動程序的方式和我們平常使用的map_async等不一樣,從spawn啟動的程序,是無法訪問到主程序的任何變數的。所有我們需要在prediction_by_dfsd完成模型定義,匯入引數等操作。當然一些變數可以通過process的args傳進去。

所以和使用分布式訓練一樣,每個程序都會定義一次模型。只不過我們這個方法,是在單卡上使用多程序,一樣能使用並行**,滿足了在一些情況,比如樣本size不一樣,不能合併到batch中,而一張一張**又太慢的問題。

當然我後面反應過來了。只用使用dataloader,batchsize為1,一樣可以使用標準的pytorch分布式流程使用多程序在單卡上的測試,只不過不需要設定gpu id就行了。所有的模型都往乙個gpu上搬運。不過呢,對於簡單的任務,沒有必要再寫乙個dataloader,所以我這個辦法還是有價值的。

pytorch多程序最佳實踐

torch.multiprocessing是 python 的multiprocessing多程序模組的替代品。它支援完全相同的操作,但對其進行了擴充套件,以便所有通過多程序佇列multiprocessing.queue傳送的張量都能將其資料移入共享記憶體,而且僅將其控制代碼傳送到另乙個程序。注意 ...

關於pytorch在windows上編輯的問題集合

cmake在windows上自動尋找v140 vs2015 的編譯器,現在只有vs2013的ide,所以要修改編譯器 修改掉vs2015的編譯器名稱,報錯提示引數cmake c compiler和cmake cxx compiler引數的對應位址找不到 在cmakelists.txt裡顯式設定這兩個...

Python多程序在Windows作業系統下的坑

筆者是乙個python初學者,因為windows有圖形化介面寫 方便,基本 都是在windows下寫的,這就導致了出現很多問題,比如使用建立多程序來實現伺服器併發會出現一些很難想象到的錯誤,如 import socket import multiprocessing defsend new data...