pytorch學習筆記3 線性回歸

2021-10-11 18:30:34 字數 1778 閱讀 1454

線性回歸線性回歸 是分析乙個變數與另外乙個(多)個變數之間關係的方法

因變數:y

自變數:x

關係:線性

y=wx+b -> 求解w,b

求解步驟:

1.確定模型 module: y=wx+b

2.選擇損失函式 mse: 均方差等

3.求解梯度並更新w,b w=w-lr*w.grad b=b-lr*w.grad

lr:步長,即學習率 -> 迭代更新使損失函式值較小即可

import torch

import matplotlib.pyplot as plt

torch.manual_seed(10) # 為cpu設定種子用於生成隨機數,以使得結果是確定的

lr = 0.05 # 學習率

# 建立訓練資料

x = torch.rand(20, 1) * 10

y = 2*x + (5 + torch.randn(20, 1))

# 構建線性回歸引數

w = torch.randn((1), requires_grad=true)

b = torch.zeros((1), requires_grad=true)

for iteration in range(1000):

# 前向傳播

wx = torch.mul(w, x)

y_pred = torch.add(wx, b)

# 計算mes loss

loss = (0.5 * (y-y_pred) ** 2).mean()

# 反向傳播

loss.backward()

# 更新引數

b.data.sub_(lr * b.grad)

w.data.sub_(lr * w.grad)

# 清零張量的梯度

w.grad.zero_()

b.grad.zero_()

# 繪圖

if iteration % 20 == 0:

plt.scatter(x.data.numpy(), y.data.numpy())

plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-', lw=5)

plt.text(2, 20, 'loss=%.4f' % loss.data.numpy(), fontdict=)

plt.xlim(1.5, 10)

plt.ylim(8, 28)

plt.title("iteration: {}\nw: {} b: {}".format(iteration, w.data.numpy(), b.data.numpy()))

plt.pause(0.5)

if loss.data.numpy() < 1:

break

當loss函式值低於1是結束迭代,

YOLOv3 Pytorch學習筆記

五月一直埋頭鑽研faster r cnn,但苦於電腦不支援gpu,一直連個簡單的結果都沒跑出來 期間還掙扎著安裝cuda,結果就是ubuntu系統一崩再崩 yolo官網 翻譯 按官網上的要求一步一步走,就可得到如下檢測結果 如果想訓練自己的資料,可參考部落格 yolov3 訓練自己的資料,講的非常詳...

Pytorch 學習筆記

本渣的pytorch 逐步學習鞏固經歷,希望各位大佬指正,寫這個部落格也為了鞏固下記憶。a a b c 表示從 a 取到 b 步長為 c c 2 則表示每2個數取1個 若為a 1 1 1 即表示從倒數最後乙個到正數第 1 個 最後乙個 1 表示倒著取 如下 陣列的第乙個為 0 啊,第 0 個!彆扭 ...

Pytorch學習筆記

資料集 penn fudan資料集 在學習pytorch官網教程時,作者對penn fudan資料集進行了定義,並且在自定義的資料集上實現了對r cnn模型的微調。此篇筆記簡單總結一下pytorch如何實現定義自己的資料集 資料集必須繼承torch.utils.data.dataset類,並且實現 ...