syncbn在TensorFlow中的實現

2021-09-02 11:44:39 字數 2767 閱讀 8251

在syncbn之前我們先簡單介紹一下bn層以及多卡機制

bn層介紹

bn層中有兩個可訓練引數(beta, gamma),以及兩個統計引數(moving mean, moving variance)。訓練過程和測試過程,bn層計算方式是不同的。訓練過程,beta和gamma與卷積層中的weight是一樣參與訓練的,然後moving mean與moving variance由當前的batch mean和batch variance統計的,可以在訓練過程的bn層可以用下列公式表示:

y =γ

×x−m

eanb

atch

vari

ance

batc

h+ε+

βy = \gamma \times \frac} + \varepsilon}} + \beta

y=γ×va

rian

ceba

tch​

+ε​x

−mea

nbat

ch​​

+βm ea

nmov

ing=

deca

y×me

anmo

ving

+(1−

deca

y)×m

eanb

atch

mean_=decay \times mean_ + (1-decay) \times mean_

meanmo

ving

​=de

cay×

mean

movi

ng​+

(1−d

ecay

)×me

anba

tch​

v ar

ianc

emov

ing=

deca

y×va

rian

cemo

ving

+(1−

deca

y)×v

aria

nceb

atch

variance_ = decay \times variance_ + (1-decay) \times variance_

varian

cemo

ving

​=de

cay×

vari

ance

movi

ng​+

(1−d

ecay

)×va

rian

ceba

tch​

測試過程(以及bn層不參與訓練時),beta和gamma採用跟weight相同的使用方式,然而使用moving mean和moving variance替代batch mean和batch variance,所以測試過程中的bn層可以用下列公式表示:

y =γ

×x−m

eanm

ovin

gvar

ianc

emov

ing+

ε+

βy = \gamma \times \frac} + \varepsilon}} + \beta

y=γ×va

rian

cemo

ving

​+ε​

x−me

anmo

ving

​​+β

多卡機制

目前的多卡訓練可以分成非同步式和同步式。我們這裡講同步式,同步式是指將模型複製到各個gpu上,資料切分分發到各個gpu上,如resnet50模型複製到8個gpu上,如果batch size=256, 那麼每個gpu上的batch size=256/8=32。由於tensorflow, pytorch, caffe, caffe2等開源網路框架對於效率的考慮,且在傳統影象分類任務中,單卡的batch size就能設定的很大對於bn層的統計影響不大,所以各大網路框架都沒有做bn的同步(moving mean和moving variance的同步)。然而對於semantic segmentaion,object detection等任務來說,單卡統計的數量很小甚至為1,所以bn層中moving mean, mean variance存在很大的擾動,這便造成了bn層的失效。通過以上的分析可以看出bn的同步很有必要。對於bn層的更多實驗結果可以參照曠視的文章《megdet:a large mini-batch object detector》

多卡同步

對於pytorch的多卡bn層同步已經有了一些解決方案,如 syncbn-pytorch 。對於mxnet的多卡同步已經被整合到了官方的api中。對於tensorflow也有一些方法,如將每個操作均同步如:

def relu(list_input):

assert type(list_input) == list

list_output =

for i in range(len(list_input)):

with tf.device('/gpu:%d' % i):

output = tf.nn.relu(list_input[i], name='relu')

return list_output

但是這種方式**量略大。tensorpack大神也提出了他們的方法,但是他們的方法也需要依賴tensorpack。我實現了一下syncbn-tensorflow,具體實現可以看 syncbn-tensorflow, 這種方式可以對原來的**進行非常小的改動,就能實現多卡同步了。

TensorFlow框架 tensorflow基礎

1 圖預設已經註冊,一組表示 tf.operation計算單位的物件和tf.tensor,表示操作之間流動的資料單元的物件 2 獲取呼叫 tf.get default graph op sess或者tensor 的graph屬性 3 圖的建立和使用 執行tensorflow操作圖的類,使用預設註冊的...

基於 Anaconda 安裝 tensorflow

anaconda 是乙個整合許多第三方科學計算庫的 python 科學計算環境,anaconda 使用 conda 作為自己的包管理工具,同時具有自己的計算環境,類似 virtualenv.和 virtualenv 一樣,不同 python 工程需要的依賴包,conda 將他們儲存在不同的地方。te...

基於Anaconda安裝tensorflow

anaconda 是乙個整合許多第三方科學計算庫的 python 科學計算環境,anaconda 使用 conda 作為自己的包管理工具,同時具有自己的計算環境,類似 virtualenv.和 virtualenv 一樣,不同 python 工程需要的依賴包,conda 將他們儲存在不同的地方。te...