Pytorch入門之softmax回歸

2021-10-24 13:04:53 字數 3503 閱讀 4917

主要用於分類問題,通過將前幾層網路的輸出得分轉為相應分類的概率,然後選取概率最高的類別作為此次分類的的結果。比如經過四分類問題經過softmax處理後的結果為[0.1,0.3,0.1,0.5],則最後屬於第四類,因為0.5最大了。

其實softmax處理分類問題和用線性回歸處理回歸問題大致相同

只不過線性回歸在全連線層後用的恒等函式relu函式輸出,然後利用均方誤差計算loss,然後反向傳播利用sgd訓練引數。

而softmax處理分類問題呢,是在全連線層後的輸出利用全連線的形式計算每個類別的概率值,然後利用交叉熵(cross entropy)計算loss,然後反向傳播利用sgd訓練引數。

關於基於pytorch處理線性回歸,可以參考我的另一篇部落格

那就開始了,下面這個是我整個網路的結構圖(連線線太麻煩了,請忽略)

資料集:採用fashion-mnist資料,這是乙個每個類別有6,000張訓練集和1,000張,一共10個類別,分別是t-shirt(

t恤)、

trouser

(褲⼦)、

pullover

(套衫)、 dress(連⾐裙)、

coat

(外套)、

sandal

(涼鞋)、

shirt

(襯衫)、

sneaker

(運動鞋)、 bag(包)和

ankle boot(短靴)

讀入fashion-mnist資料集,需要匯入torchvision包,用於構建計算機模型

# 或者 root='./fashionmnist',即儲存在當前目錄下

transform = transforms.totensor() # 將所有資料轉為tensor

mnist_train = torchvision.datasets.fashionmnist(root=root, train=true, download=true, transform=transform)

mnist_test = torchvision.datasets.fashionmnist(root=root, train=false, download=true, transform=transform)

torchvision.transforms常用於做一些變換,例如裁剪、旋轉等。transforms.totensor()

將尺⼨為 (h x w x c)

且資料位於

[0, 255]

的 pil 圖⽚或者資料型別為

np.uint8

的 numpy

陣列轉換為尺寸為

(c x h x w)

且資料型別為 torch.float32

且位於[0.0, 1.0]

的 tensor 。

這裡讀取的資料mnist_train、mnist_test格式為:[(x1,y1),(x2,y2)...] ,

其中x1:(1, 28, 28),表示1通道(灰度圖)解析度為28*28的影象

y:tensor標量,表示標籤,如tensor(9),他其實對應了查詢表中的 'ankle boot'。

note:

1、torchvison.models包包含常用的模型結構,常用於預訓練,如vgg、alexnet等

匯入資料集:

train_iter = torch.utils.data.dataloader(mnist_train, batch_size=batch_size, shuffle=true, num_workers=num_workers)

test_iter = torch.utils.data.dataloader(mnist_test, batch_size=batch_size, shuffle=false, num_workers=num_workers)

mnist_train、mnist_test為torch.utils.data.dataset的子類,故可以直接用torch.utils.data.dataloader直接匯入

note:

train_iter格式: [(x1,y1),(x2,y2)...] x1:(batch, 1, 28, 28) y:(batch,),標籤不是one-hot形式

到了正式開始的時候了

根據圖中所畫的,網路類中只需要乙個全連線層

自行建立模型類,繼承nn.module (也可以用nn.sequential(nn.module的子類)

class net(nn.module):

def __init__(self, feature_in, feature_out):

super().__init__()

self.linear = nn.linear(feature_in, feature_out)

def forward(self, datain):

out = self.linear(datain.view(datain.shape[0], -1)) # (batchsize, 10)

return out

法二:利用

# 搭建網路

network = net(784, 10)

note:需要注意的是,全連線層的輸入必須是個二維tensor,故需要進行轉換,即(batchsize,1,28,28)轉為(batchsize,784)

這樣全連線的輸出y_hat就是(batchsize, 10)

# 初始化已經在類的例化中完成了

loss = nn.crossentropyloss()

optimizer = optim.sgd(network.parameters(), lr=0.5)

note:分類問題常呼叫的nn.crossentropyloss()是乙個包含了softmax函式和cross_entropy函式的乙個函式。

for epoch in range(100):

for x, y in data_iter:

output = net(x) # batch

ls = loss(output, y).sum

optimizer.zero_grad()

ls.backward()

optimizer.step() # 更新以及對梯度求平均

在測試集上對每乙個epoch計算損失、正確率。由於沒有其他的優化措施,所以最終識別率在83%左右

選取了前十張圖的識別效果(僅有乙個判斷錯誤)

Pytorch入門之線性回歸

這裡定義乙個簡單的神經網路來做乙個線性回歸問題 神經元之間的線就不連了,大家知道是個全連線層就好 搭建這樣乙個網路,首先就是需要定義乙個class,class必須得繼承nn.module類,常用來被繼承,然後使用者去編寫自己的網路 層。類中的初始化部分需要去例化自己的層。這裡需要定義2個全連線層,因...

Pytorch入門 安裝

pytorch目前支援的平台有linux和osx,在pytorch官網上每種平台提供了conda pip source三種安裝方式,同時也可以根據有無gpu進行cuda安裝,在這裡以ubuntu14.04進行安裝學習。1.anaconda安裝配置 安裝過程參考我之前的anaconda tensorf...

PyTorch快速入門

詳細的pytorch教程可以去pytorch官網的學習指南進一步學習,下面主要對pytorch做簡單的介紹,能夠快速入門。首先pytorch是基於python的科學計算類庫,主要有以下兩個方面的應用 作為numpy的替代者,充分利用gpu的計算能力。提供乙個靈活 快速的深度學習平台。tensor 與...