pytorch fine tuning注意事項

2021-09-02 17:37:02 字數 1867 閱讀 5160

前言

這篇文章算是論壇pytorch forums關於引數初始化和finetune的總結,也是我在寫**中用的算是「最佳實踐」吧。最後希望大家沒事多逛逛論壇,有很多高質量的回答。

引數初始化

引數的初始化其實就是對引數賦值。而我們需要學習的引數其實都是variable,它其實是對tensor的封裝,同時提供了data,grad等藉口,這就意味著我們可以直接對這些引數進行操作賦值了。這就是pytorch簡潔高效所在。 

所以我們可以進行如下操作進行初始化,當然其實有其他的方法,但是這種方法是pytorch作者所推崇的:

def weight_init(m):

# 使用isinstance來判斷m屬於什麼型別

if isinstance(m, nn.conv2d):

n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels

m.weight.data.normal_(0, math.sqrt(2. / n))

elif isinstance(m, nn.batchnorm2d):

# m中的weight,bias其實都是variable,為了能學習引數以及後向傳播

m.weight.data.fill_(1)

m.bias.data.zero_()

finetune

往往在載入了預訓練模型的引數之後,我們需要finetune模型,可以使用不同的方式finetune。

區域性微調:僅僅訓練輸出層,也就是只有fc層有學習率,前面的model的引數保留保持不變因此沒有學習率

有時候我們載入了訓練模型後,只想調節最後的幾層,其他層不訓練。其實不訓練也就意味著不進行梯度計算,pytorch中提供的requires_grad使得對訓練的控制變得非常簡單。

model = torchvision.models.resnet18(pretrained=true)

for param in model.parameters():

param.requires_grad = false

# 替換最後的全連線層, 改為訓練100類

# 新構造的模組的引數預設requires_grad為true

model.fc = nn.linear(512, 100)

# 只優化最後的分類層

optimizer = optim.sgd(model.fc.parameters(), lr=1e-2, momentum=0.9)

全域性微調:前面的base_model使用比較小的學習率繼續訓練,並不freeze,後面的fc層使用較大的學習率

有時候我們需要對全域性都進行finetune,只不過我們希望改換過的層和其他層的學習速率不一樣,這時候我們可以把其他層和新層在optimizer中單獨賦予不同的學習速率。比如:

ignored_params = list(map(id, model.fc.parameters()))

base_params = filter(lambda p: id(p) not in ignored_params,

model.parameters())

optimizer = torch.optim.sgd([

,#fc層使用較大的學習率

], lr=1e-3, momentum=0.9)#base_model使用較小的學習率

base_params因為已經提取到resnet18的特徵了,因此不需要特別大的學習率

model.fc.parameters,模型引數採用隨機初始化因此需要採用比較大的學習率

其中base_params使用較大的學習率:1e-2來訓練,model.fc.parameters使用較小的學習率1e-3來訓練,momentum是二者共有的。

PHP Open Flash Chart注意事項

1.在html頁面必須src正確的swfobject.js的路徑 可以用firebug檢視絕對路徑是否正確 2.在html頁面必須指定正確的swfobject使用時的open flash chart.swf的位置 可以用firebug檢視絕對路徑是否正確 3.在html頁面必須制定正確的data f...

Spring Hibernate整合注意事項

1 spring jar包 需要額外加入 commons pool 和commons dbcp 若包含 spring 自帶的測試,還需要引入 spring test 2 spring beans.xml 如果使用了 spring annotation 則需要加入以下兩項配置 前提是已經匯入了bean...

Protocol Buffers使用注意事項

protocol buffers做為廣泛使用的乙個序列化開源庫,提供了很多語言下的支援,本文就談談msvc c 使用pb遇到的問題,當然這些問題因為每個人的使用模式不同,可能都不一樣,本文也不討論怎麼寫proto及編譯。我們使用pb做序列化可以把pb生成靜態庫或者動態庫 libprotobuf.dl...