3 7 PyTorch中優化器比較

2021-10-14 15:19:07 字數 2911 閱讀 6048

對pytorch中優化器進行乙個簡單的例項進行比較說明:

import torch

import torch.nn as nn

import torch.utils.data as data

import torch.nn.functional as f

import matplotlib.pyplot as plt

# 超引數

lr =

0.1batch_size =

32epoch =

12

# 生成訓練資料

# torch.unsqueeze()的作用是將一維變成二維,torch只能處理二維的資料

x = torch.unsqueeze(torch.linspace(-1

,1,1000

), dim=1)

# 0.1 * torch.normal(torch.zeros(*x.size())為增加噪點

y = x.

pow(2)

+0.1

* torch.normal(torch.zeros(

*x.size())

)# tensordataset是將樣本和標籤打包成dataset

torch_dataset = data.tensordataset(x, y)

# 得到乙個大批量的生成器

# dataloader組合資料集和取樣器

loader = data.dataloader(dataset=torch_dataset, batch_size=batch_size, shuffle=

true

)

class

net(nn.module)

:# 初始化

def__init__

(self)

:super

(net, self)

.__init__(

) self.hidden = nn.linear(1,

20)self.predict = nn.linear(20,

1)# 前向傳播

defforward

(self, x)

: x = f.relu(self.hidden(x)

) x = self.predict(x)

return x

# 使用多種優化器

net_sgd = net(

)net_momentum = net(

)net_rmsprop = net(

)net_adam = net(

)# 裝進乙個列表裡

nets =

[net_sgd, net_momentum, net_rmsprop, net_adam]

opt_sgd = torch.optim.sgd(net_sgd.parameters(

), lr=lr)

opt_momentum = torch.optim.sgd(net_momentum.parameters(

), lr=lr, momentum=

0.9)

opt_rmsprop = torch.optim.rmsprop(net_rmsprop.parameters(

), lr=lr, alpha=

0.9)

opt_adam = torch.optim.adam(net_adam.parameters(

), lr=lr, betas=

(0.9

,0.99))

optimizers =

[opt_sgd, opt_momentum, opt_rmsprop, opt_adam]

# 訓練模型

# 呼叫均方損失函式

loss_func = torch.nn.mseloss(

)loss_his =[[

],,,

]for epoch in

range

(epoch)

:for step,

(batch_x, batch_y)

inenumerate

(loader)

:for net, opt, l_his in

zip(nets, optimizers, loss_his)

:# 從每乙個網路裡獲取輸出

output = net(batch_x)

# 計算每乙個網路的損失

loss = loss_func(output, batch_y)

# 梯度清零

opt.zero_grad(

)# 反向傳播

loss.backward(

)print

(loss)

# 更新引數

opt.step())

)labels =

["sgd"

,"momentum"

,"rmsprop"

,"adam"

]

# 視覺化結果

for i, l_his in

enumerate

(loss_his)

: plt.plot(l_his,label=labels[i]

)# print(l_his)

plt.legend(loc=

"best"

)plt.xlabel(

"steps"

)plt.ylabel(

"loss"

)plt.ylim((0

,0.2))

plt.show(

)

果然還是adam比較好。

PyTorch常見的優化器

用法pytorch學習率調整策略通過torch.optim.lr scheduler介面實現。torch.optim是乙個實現了各種優化演算法的庫。大部分常用的方法得到支援,並且介面具備足夠的通用性,使得未來能夠整合更加複雜的方法。參考連線 首先需要構建乙個optimizer物件。這個物件能夠保持當...

Pytorch中adam優化器的引數問題

之前用的adam優化器一直是這樣的 alpha optim torch.optim.adam model.alphas config.alpha lr,betas 0.5,0.999 weight decay config.alpha weight decay 沒有細想內部引數的問題,但是最近的工作...

PyTorch自定義優化器

簡單粗暴的方法直接更新引數 def myopt pre pre儲存當前梯度與歷史梯度方向是否一致的資訊 lr lr儲存各層各引數學習率 vdw vdw儲存各層各引數動量 y pred net x loss loss func y pred,y net.zero grad loss.backward ...