pytorch 搭建網路步驟

2021-10-22 09:35:04 字數 3225 閱讀 4527

準備資料

定義網路結構model

定義損失函式

定義優化演算法 optimizer

(有是還要定義更新學習率:scheduler=steplr())

訓練5.1 準備好tensor形式的輸入資料和標籤(可選)

5.2 前向傳播計算網路輸出output和 計算損失函式loss

5.3 反向傳播更新引數

以下三句話一句也不能少:

5.3.1optimizer.zero_grad()將上次迭代計算的梯度值清0

5.3.2loss.backward()反向傳播,計算梯度值

5.3.3optimizer.step()更新權值引數

schedule.step(episode)更新學習率)

5.4 儲存訓練集上的loss和驗證集上的loss以及準確率以及列印訓練資訊。(可選

圖示訓練過程中loss和accuracy的變化情況(可選)

在測試集上測試

示例**:

import torch

import torch.nn.functional as f

import matplotlib.pyplot as plt

# 1.準備資料 generate data

x=torch.unsqueeze(torch.linspace(-1

,1,100

),dim=1)

print

(x.shape)

y=x*x+

0.2*torch.rand(x.size())

#顯示資料散點圖

plt.scatter(x.data.numpy(

),y.data.numpy())

# 2.定義網路結構 build net

class

net(torch.nn.module)

:#n_feature:輸入特徵個數 n_hidden:隱藏層個數 n_output:輸出層個數

def__init__

(self,n_feature,n_hidden,n_output)

:# super表示繼承net的父類,並同時初始化父類的引數

super

(net,self)

.__init__(

)# nn.linear代表線性層 代表y=w*x+b 其中w的shape為[n_hidden,n_feature] b的shape為[n_hidden]

# y=w^t*x+b 這裡w的維度是轉置前的維度 所以是反的

self.hidden =torch.nn.linear(n_feature,n_hidden)

self.predict =torch.nn.linear(n_hidden,n_output)

print

(self.hidden.weight)

print

(self.predict.weight)

#定義乙個前向傳播過程函式

defforward

(self, x)

:# n_feature n_hidden n_output

#舉例(2,5,1) 2 5 1

# - ** -

# ** - - - ** - -

# - ** - - - **

# ** - - - ** - -

# - ** -

# 輸入層 隱藏層 輸出層

x=f.relu(self.hidden(x)

) x=self.predict(x)

return x

# 例項化乙個網路為net

net = net(n_feature=

1,n_hidden=

10,n_output=1)

print

(net)

# 3.定義損失函式 這裡使用均方誤差(mean square error)

loss_func=torch.nn.mseloss(

)# 4.定義優化器 這裡使用隨機梯度下降

optimizer=torch.optim.sgd(net.parameters(

),lr=

0.2)

#定義300遍更新 每10遍顯示一次

plt.ion(

)# 5.訓練

for t in

range

(100):

prediction = net(x)

# input x and predict based on x

loss = loss_func(prediction, y)

# must be (1. nn output, 2. target)

# 5.3反向傳播三步不可少

optimizer.zero_grad(

)# clear gradients for next train

loss.backward(

)# backpropagation, compute gradients

optimizer.step(

)if t %

10==0:

# plot and show learning process

plt.cla(

) plt.scatter(x.data.numpy(

), y.data.numpy())

plt.plot(x.data.numpy(

), prediction.data.numpy(),

'r-'

, lw=5)

plt.text(

0.5,0,

'loss=%.4f'

% loss.data.numpy(

), fontdict=

) plt.show(

) plt.pause(

0.1)

plt.ioff(

)

參考:pytorch基礎-搭建網路

pytorch基礎 搭建網路

搭建網路的步驟大致為以下 1.準備資料 2.定義網路結構model 3.定義損失函式 4.定義優化演算法 optimizer 5.訓練 5.1 準備好tensor形式的輸入資料和標籤 可選 5.2 前向傳播計算網路輸出output和計算損失函式loss 5.3 反向傳播更新引數 以下三句話一句也不能...

R seau Donn e 搭建網路

reseu donnnee這門基本處於學一回忘一回的階段,這次,趁還沒忘利索之前,趕緊寫下來,為以後用著的時候存著。網路的組成 client1 communateur routeur1 routeaur2 communateur client2 配置ip 1.sudo ifconfig eth0 1...

搭建網路源

搭建本地源 1.mount o loop home centos 7 x86 64 everything 1708.iso mnt sr0 掛載檔案到mnt下的sr0,如果沒有sr0可以自己建乙個 2.lsblk可以檢視到掛載的資訊 3.vi etc yum.repos.d centos base....