PyTorch實現多項式回歸,並實現Loss視覺化

2021-09-28 18:44:26 字數 2340 閱讀 8253

import torch

import numpy as np

from torch import nn

from torch.autograd import variable

import visdom

import matplotlib.pyplot as plt

import random

viz = visdom.visdom(env='train')

loss_win = viz.line(np.arange(10))

#torch.cat()實現tensor拼接

def make_features(x):

x = x.unsqueeze(1)

return torch.cat([x ** i for i in range(1,4)],1)

#定義實際的函式

#unsqueeze將原來的tensor大小由3變成(3,1)

w_target = torch.floattensor([0.5,3,2.4]).unsqueeze(1)

b_target = torch.floattensor([0.9])

def f(x):

return x.mm(w_target) + b_target[0]

#定義每次的訓練集

#每次取batch_size個資料點,然後轉換成矩陣的形式

def get_batch(batch_size=32,random = none):

if random is none:

random = torch.randn(batch_size)

batch_size = random.size()[0]

x = make_features(random)

y = f(x)

return variable(x),variable(y)

#定義模型

# model = nn.linear(3,1)也可以

n = 3

class poly_model(nn.module) :

def __init__(self, n) :

super().__init__()

self.poly = nn.linear(n, 1)

def forward(self, x) :

return self.poly(x)

model = poly_model(n)

criterion = nn.mseloss()

optimizer = torch.optim.sgd(model.parameters(),lr=1e-3)

print('****')

epoch=0

while true:

batch_x,batch_y = get_batch()

#forward

output = model(batch_x)

loss = criterion(output,batch_y)

print_loss = loss.item()

if(epoch + 1) % 20 == 0:

print('epoch[{}],loss:'.format(epoch+1,loss.item()))

#backward

optimizer.zero_grad()

loss.backward()

optimizer.step()

epoch+=1

if(print_loss<1e-3):

print('loss:{} after {} batchs'.format(print_loss,epoch))

break

x = [random.randint(-200,200)*0.01 for i in range(20)]

x = np.array(sorted(x))

featurn_x,y = get_batch(random = torch.from_numpy(x).float())

y = y.data.numpy()

plt.plot(x,y,'ro',label='original data')

model.eval()

x_sample = np.arange(-2,2,0.01)

x, y = get_batch(random = torch.from_numpy(x_sample).float())

y = model(x)

y_sample = y.data.numpy()

plt.plot(x_sample,y_sample,label='fitting line')

plt.show()

多項式回歸

import numpy as np import matplotlib.pyplot as plt x np.random.uniform 3,3,size 100 x x.reshape 1,1 y 0.5 x 2 x 2 np.random.normal 0,1,100 plt.scatter...

多項式回歸

多項式回歸 import torch import numpy defmake features x 獲取 x,x 2,x 3 的矩陣 x x.unsqueeze 1 將一維資料變為 n,1 二維矩陣形式 return torch.cat x i for i in range 1 4 1 按列拼接 ...

多項式回歸

線性回歸適用於資料成線性分布的回歸問題,如果樣本是非線性分布,線性回歸就不再使用,轉而可以採用非線性模型進行回歸,比如多項式回歸 多項式回歸模型定義 與線性模型,多項式模型引入了高次項 y w 0 w1 x w2 x2 w 3x3 wnxn y w 0 w 1x w 2x 2 w 3x 3 w nx...