莫煩 pytorch RNN 回歸

2021-10-03 15:38:34 字數 2461 閱讀 8494

import torch

from torch import nn

import numpy as np

import torchvision.datasets as dsets

import torchvision.transforms as transforms

import matplotlib.pyplot as plt

# hyper parameters

batch_size =

64epoch =

1time_step =

28# 考慮多少個時間點的資料

input_size =

1# 每個時間點給rnn多少個資料點

lr =

0.01

class

rnn(nn.module)

:def

__init__

(self)

:super

(rnn, self)

.__init__(

) self.rnn = nn.rnn(

input_size=input_size,

hidden_size=32,

num_layers=1,

batch_first=

true,)

self.out = nn.linear(32,

1)defforward

(self, x, h_state)

:# x(batch, time_step, input_size)

# h_state(n_layers, batch, hidden_size)

# r_out(batch, time_step, output_size = hidden_size)

r_out, h_state = self.rnn(x, h_state)

outs =

for time_step in

range

(r_out.size(1)

):# size是tensor的形狀是乙個陣列,size(1)就是裡面的第二個值域,

# 就是time_step的值的個數 即第二個維度的大小

:, time_step,:]

))return torch.stack(outs, dim=1)

, h_state

rnn = rnn(

)print

(rnn)

optimizer = torch.optim.adam(rnn.parameters(

), lr=lr)

# optimize all cnn parameters

loss_func = nn.mseloss(

)h_state =

none

plt.figure(

1, figsize=(12

,5))

plt.ion(

)for step in

range(50

):start, end = step * np.pi,

(step +1)

* np.pi

# use sin pre cos

steps = np.linspace(start, end, time_step, dtype=np.float32)

x_np = np.sin(steps)

y_np = np.cos(steps)

x = torch.from_numpy(x_np[np.newaxis,

:, np.newaxis]

)# shape(batch, time_step, input_size)

y = torch.from_numpy(y_np[np.newaxis,

:, np.newaxis]

) prediction, h_state = rnn(x, h_state)

h_state = h_state.data # !!! this step is important

loss = loss_func(prediction, y)

optimizer.zero_grad(

)# clear gradient for next train

loss.backward(

)# back propagation, compute gradient

optimizer.step(

)# plot

plt.plot(steps, y_np.flatten(),

'r-'

) plt.plot(steps, prediction.data.numpy(

).flatten(),

'b-'

) plt.draw(

) plt.pause(

0.5)

plt.ioff(

)plt.show(

)

莫煩 Tensorflow 變數

理解了tensorflow必須通過session去run才能輸出值,如果不通過session,那麼只能得到變數名字和型別,加深理解了tensorflow的機制。import tensorflow as tf state tf.variable 0,name counter print state 輸...

莫煩Python matplotlib基本使用篇

以下 可直接在pycharm下執行,前提是已安裝numpy和matplotlib。中的每個功能都進行了注釋,讀者可自行注釋某一部分 檢視結果,以便檢驗其中某個函式的功能。import matplotlib.pyplot as plt 匯入matplotlib import numpy as np p...

莫煩Tensorflow 入門

tensorflow 初步嘗試 建立資料 搭建模型 計算誤差 傳播誤差 初始會話 不斷訓練 import tensorflow as tf import numpy as np 建立資料 x data np.random.rand 100 astype np.float32 y data x dat...