Pytorch視訊記憶體不斷增長問題的解決思路

2021-10-25 17:49:59 字數 3637 閱讀 5532

這個問題,我先後遇到過兩次,每次都異常艱辛的解決了。

在網上,關於這個問題,你可以找到各種看似不同的解決方案,但是都沒能解決我的問題。所以只能自己摸索,在摸索的過程中,有了乙個排查問題點的思路。

下面舉個例子說一下我的思路。

其實思路很簡單,就是在**的執行階段輸出視訊記憶體佔用量,觀察在哪一塊存在視訊記憶體劇烈增加或者視訊記憶體異常變化的情況。但是在這個過程中要分級確認問題點,也即如果存在三個檔案main.py、train.py、model.py。在此種思路下,應該先在main.py中確定問題點,然後,從main.py中進入到train.py中,再次輸出視訊記憶體佔用量,確定問題點在哪。隨後,再從train.py中的問題點,進入到model.py中,再次確認。如果還有更深層次的呼叫,可以繼續追溯下去。

在main.py

def

train

(model,epochs,data)

:for e in

range

(epochs)

:print

("1:{}"

.format

(torch.cuda.memory_allocated(0)

))train_epoch(model,data)

print

("2:{}"

.format

(torch.cuda.memory_allocated(0)

))eval

(model,data)

print

("3:{}"

.format

(torch.cuda.memory_allocated(0)

))

假設1與2之間視訊記憶體增加極為劇烈,說明問題出在train_epoch中,進一步進入到train.py中。

train.py

def

train_epoch

(model,data)

: model.train(

) optim=torch.optimizer(

)for batch_data in data:

print

("1:{}"

.format

(torch.cuda.memory_allocated(0)

))output=model(batch_data)

print

("2:{}"

.format

(torch.cuda.memory_allocated(0)

))loss=loss(output,data.target)

print

("3:{}"

.format

(torch.cuda.memory_allocated(0)

))optim.zero_grad(

)print

("4:{}"

.format

(torch.cuda.memory_allocated(0)

))loss.backward(

)print

("5:{}"

.format

(torch.cuda.memory_allocated(0)

))utils.func(model)

print

("6:{}"

.format

(torch.cuda.memory_allocated(0)

))

如果在1,2之間,5,6之間同時出現視訊記憶體增加異常的情況。此時需要使用控制變數法,例如我們先讓5,6之間的**失效,然後執行,觀察是否仍然存在視訊記憶體**。如果沒有,說明問題就出在5,6之間下一級的**中。進入到下一級**,進行除錯:

utils.py

1

deffunc

(model):2

print

("1:{}"

.format

(torch.cuda.memory_allocated(0)

))3 a=f1(model)

4print

("2:{}"

.format

(torch.cuda.memory_allocated(0)

))5 b=f2(a)

6print

("3:{}"

.format

(torch.cuda.memory_allocated(0)

))7 c=f3(b)

8print

("4:{}"

.format

(torch.cuda.memory_allocated(0)

))9 d=f4(c)

10print

("5:{}"

.format

(torch.cuda.memory_allocated(0)

))

此時我們再展示另一種除錯思路,先注釋第5行之後的**,觀察視訊記憶體是否存在先訓**,如果沒有,則注釋掉第7行之後的,直至確定哪一行的**出現導致了視訊記憶體**。假設第9行起作用後,**出現視訊記憶體**,說明問題出在第九行,視訊記憶體**的問題鎖定。

def

pro_weight

(p, x, w, alpha=

1.0, cnn=

true

, stride=1)

:if cnn:

_, _, h, w = x.shape

f, _, hh, ww = w.shape

s = stride # stride

ho =

int(1+

(h - hh)

/ s)

wo =

int(1+

(w - ww)

/ s)

for i in

range

(ho)

:for j in

range

(wo)

:# n*c*hh*ww, c*hh*ww = n*c*hh*ww, sum -> n*1

r = x[:,

:, i * s: i * s + hh, j * s: j * s + ww]

.contiguous(

).view(1,

-1)# r = r[:, range(r.shape[1] - 1, -1, -1)]

k = torch.mm(p, torch.t(r)

) p.sub_(torch.mm(k, torch.t(k))/

(alpha + torch.mm(r, k)))

w.grad.data = torch.mm(w.grad.data.view(f,-1

), torch.t(p.data)

).view_as(w)

else

: r = x

k = torch.mm(p, torch.t(r)

) p.sub_(torch.mm(k, torch.t(k))/

(alpha + torch.mm(r, k)))

w.grad.data = torch.mm(w.grad.data, torch.t(p.data)

)

TensorFlow訓練內(顯)存不斷增長

在使用tensorflow過程中,乙個不標準的操作,就可能導致程式出各種bug,今天我們的豬腳就是 tensorflow訓練內 顯 存不斷增長 此問題並不是我遇到的,是公司一位同事遇到的,我把 翻了一下,看出了問題所在,由於一些保密原因,我就不在這裡展示那個 但可以用其他 來替代。import te...

日誌檔案不斷增長

原文 日誌檔案不斷增長 sqlserver定時執行 checkpoint 保證 髒頁 被寫入硬碟。沒做checkpoint的,可能是只在記憶體中修改,資料檔案還沒同步。sqlserver要在硬碟的日誌檔案中有記錄,一邊異常重啟後重新修改。所有日誌都有嚴格順序,不能有跳躍。如果恢復模式不是簡單模式,那...

日誌檔案不斷增長

sqlserver定時執行 checkpoint 保證 髒頁 被寫入硬碟。沒做checkpoint的,可能是只在記憶體中修改,資料檔案還沒同步。sqlserver要在硬碟的日誌檔案中有記錄,一邊異常重啟後重新修改。所有日誌都有嚴格順序,不能有跳躍。如果恢復模式不是簡單模式,那麼sqlserver會認...