半精度訓練pytorch Apex

2022-08-19 01:27:14 字數 680 閱讀 4837

想起乙個關於運維的段子:很多問題可以通過重啟解決,想說演算法工(diao)程(bao)師(xia)的很多問題可以通過換版本解決。

起因是白嫖到乙個tensorflow的架子跑bert,自己花一上午時間搞定了單機多卡訓練,之後花了兩個下午也沒有搞定半精度,症狀是不報錯,但是視訊記憶體不降,速度不漲(32g v100)。於是開始懷念我熟悉的pytorch+apex,又斷斷續續花了兩天多的時間把整個訓練框架用pytorch實現了一遍,基於huggingface的transformers。

看到單卡loss正常下降就開始了多卡+apex半精度,結果發現fp16o1雖然視訊記憶體降了,速度卻比fp32還要慢2倍多,期間也參考了下其他人遇到的問題,最終懷疑了一下是不是自己的pytorch版本太老,pytorch版本從1.1.0切換到1.5.1,重新編譯apex,果然速度上來了....前後版本如下(右側是正常的,fp16o1速度是左側版本的10倍),python版本都是3.7.6:

apex按readme quick start安裝即可,可能需要指定載入路徑。

export pythonpath=/你的apex路徑/:$pythonpath
apex半精度訓練可以參考這裡, transformers裡面已經呼叫的很好了,不必自己改什麼。

雙精度,單精度和半精度

浮點數是計算機上最常用的資料型別之一,有些語言甚至數值只有浮點型 perl,lua同學別跑,說的就是你 常用的浮點數有雙精度和單精度。除此之外,還有一種叫半精度的東東。雙精度64位,單精度32位,半精度自然是16位了。半精度是英偉達在2002年搞出來的,雙精度和單精度是為了計算,而半精度更多是為了降...

d precision 混合精度訓練

意思是使用 進行訓練,同時有乙份 的引數主副本用於引數更新 那麼實現上其實就很簡單,只需要在每次迭代之前,將每個 或者 的引數輸入都確保是從 拉取到的,然後轉換成 輸入 而最後將計算得到梯度,則是更新到 的主副本上面 這樣做的好處在於可以避免兩種情況下的溢位,第一次就是當梯度特別小,超出 表達範圍後...

半精度浮點數到單精度的python 實現

實現原理可參考3.參考文獻部分。輸入引數s是字串形式的16位二進數,如 0011010101010101 def halfpre2spre s s代表16位二進數,sign int s 0 res0 pow 1 sign 符號位 exp int 0b s 1 6 2 指數字 endpre s 6 尾...