Pytorch教程 加速神經網路訓練

2021-10-18 09:33:32 字數 2233 閱讀 3063

torch and numpy

變數 (variable)

激勵函式

關係擬合(回歸)

區分型別 (分類)

快速搭建法

批訓練加速神經網路訓練

optimizer優化器

卷積神經網路 cnn

卷積神經網路(rnn、lstm)

rnn 迴圈神經網路 (分類)

rnn 迴圈神經網路 (回歸)

自編碼 (autoencoder)

dqn 強化學習

生成對抗網路 (gan)

為什麼 torch 是動態的

gpu 加速運算

過擬合 (overfitting)

批標準化 (batch normalization)

包括以下幾種模式:

越複雜的神經網路 , 越多的資料 , 我們需要在訓練神經網路的過程上花費的時間也就越多. 原因很簡單, 就是因為計算量太大了. 可是往往有時候為了解決複雜的問題, 複雜的結構和大資料又是不能避免的, 所以我們需要尋找一些方法, 讓神經網路聰明起來, 快起來.

最基礎的方法就是 sgd 啦, 想像紅色方塊是我們要訓練的 data, 如果用普通的訓練方法, 就需要重複不斷的把整套資料放入神經網路 nn訓練, 這樣消耗的計算資源會很大.

我們換一種思路, 如果把這些資料拆分成小批小批的, 然後再分批不斷放入 nn 中計算, 這就是我們常說的 sgd 的正確開啟方式了. 每次使用批資料, 雖然不能反映整體資料的情況, 不過卻很大程度上加速了 nn 的訓練過程, 而且也不會丟失太多準確率.

如果運用上了 sgd, 你還是嫌訓練速度慢, 那怎麼辦?

事實證明, sgd 並不是最快速的訓練方法, 紅色的線是 sgd, 但它到達學習目標的時間是在這些方法中最長的一種. 我們還有很多其他的途徑來加速訓練.

大多數其他途徑是在更新神經網路引數那一步上動動手腳. 傳統的引數 w 的更新是把原始的 w 累加上乙個負的學習率(learning rate)乘以校正值 (dx). 這種方法可能會讓學習過程曲折無比, 看起來像 喝醉的人回家時, 搖搖晃晃走了很多彎路.

所以我們把這個人從平地上放到了乙個斜坡上, 只要他往下坡的方向走一點點, 由於向下的慣性, 他不自覺地就一直往下走, 走的彎路也變少了. 這就是 momentum 引數更新. 另外一種加速方法叫adagrad.

這種方法是在學習率上面動手腳, 使得每乙個引數更新都會有自己與眾不同的學習率, 他的作用和 momentum 類似, 不過不是給喝醉酒的人安排另乙個下坡, 而是給他一雙不好走路的鞋子, 使得他一搖晃著走路就腳疼, 鞋子成為了走彎路的阻力, 逼著他往前直著走. 他的數學形式如上圖所示.

接下來又有什麼方法呢? 如果把下坡和不好走路的鞋子合併起來, 是不是更好呢? 沒錯, 這樣我們就有了 rmsprop 更新方法.

有了 momentum 的慣性原則 , 加上 adagrad 的對錯誤方向的阻力, 我們就能合併成這樣. 讓 rmsprop同時具備他們兩種方法的優勢.

似乎在 rmsprop 中少了些什麼. 原來是我們還沒把 momentum合併完全, rmsprop 還缺少了 momentum 中的 這一部分. 所以, 我們在 adam 方法中補上了這種想法.

計算m 時有 momentum 下坡的屬性, 計算 v 時有 adagrad 阻力的屬性, 然後在更新引數時 把 m 和 v 都考慮進去. 實驗證明, 大多數時候, 使用 adam 都能又快又好的達到目標, 迅速收斂. 所以說, 在加速神經網路訓練的時候, 乙個下坡, 一雙破鞋子, 功不可沒.

pytorch實現神經網路

import torch import torch.nn as nn import torch.nn.functional as f import inspect import torch.optim as optim 自動求導機制 對乙個標量用backward 會反向計算在計算圖中用到的葉節點的梯...

PyTorch分類神經網路

這次我們也是用最簡單的途徑來看看神經網路是怎麼進行事物的分類.我們建立一些假資料來模擬真實的情況.比如兩個二次分布的資料,不過他們的均值都不一樣.import torch import matplotlib.pyplot as plt 假資料 n data torch.ones 100,2 資料的基...

卷積神經網路 pytorch

vocab args.vocab size 已知詞的數量 dim args.embed dim 每個詞向量長度 cla args.class num 類別數 ci 1 輸入的channel數 knum args.kernel num 每種卷積核的數量 ks args.kernel sizes 卷積核...