Pytorch學習 張量的索引切片

2021-10-07 23:56:40 字數 2887 閱讀 5823

張量的索引切片方式和numpy幾乎是一樣的。切片時支援預設引數和省略號。

可以通過索引和切片對部分元素進行修改。

此外,對於不規則的切片提取,可以使用torch.index_select, torch.masked_select, torch.take

如果要通過修改張量的某些元素得到新的張量,可以使用torch.where,torch.masked_fill,torch.index_fill

#第一行

print

(t[-1]

)#倒數第一行

print

(t[1,3

])#第一行第三列

print

(t[1][

3])print

(t[1:4

,:])

#第一到第三行

print

(t[1:4

,:5:

2])#1-3行,0-4列step2

對於不規則的切片提取,可以使用torch.index_select, torch.take, torch.gather, torch.masked_select.

考慮班級成績冊的例子,有4個班級,每個班級10個學生,每個學生7門科目成績。可以用乙個4×10×7的張量來表示。

minval=

0maxval=

100scores = torch.floor(minval +

(maxval-minval)

*torch.rand([4

,10,7

])).

int(

)print

(scores)

#抽取每個班級第0個學生,第5個學生,第9個學生的全部成績

torch.index_select(scores,dim =

1,index = torch.tensor([0

,5,9]))

#抽取每個班級第0個學生,第5個學生,第9個學生的第1門課程,第3門課程,第6門課程成績

q = torch.index_select(torch.index_select(scores,dim =

1,index = torch.tensor([0

,5,9

])),dim=

2,index = torch.tensor([1

,3,6

]))print

(q)

#抽取第0個班級第0個學生的第0門課程,第2個班級的第4個學生的第1門課程,第3個班級的第9個學生第6門課程成績

#take將輸入看成一維陣列,輸出和index同形狀

s = torch.take(scores,torch.tensor([0

*10*7

+0,2

*10*7

+4*7

+1,3

*10*7

+9*7

+6])

)s

#抽取分數大於等於80分的分數(布林索引)

#結果是1維張量

g = torch.masked_select(scores,scores>=80)

print

(g)

以上這些方法僅能提取張量的部分元素值,但不能更改張量的部分元素值得到新的張量。

如果要通過修改張量的部分元素值得到新的張量,可以使用torch.where,torch.index_fill 和 torch.masked_fill

torch.where可以理解為if的張量版本。

torch.index_fill的選取元素邏輯和torch.index_select相同。

torch.masked_fill的選取元素邏輯和torch.masked_select相同。

#如果分數大於60分,賦值成1,否則賦值成0

ifpass = torch.where(scores>

60,torch.tensor(1)

,torch.tensor(0)

)print

(ifpass)

#將每個班級第0個學生,第5個學生,第9個學生的全部成績賦值成滿分

torch.index_fill(scores,dim =

1,index = torch.tensor([0

,5,9

]),value =

100)

#等價於 scores.index_fill(dim = 1,index = torch.tensor([0,5,9]),value = 100)

#將分數小於60分的分數賦值成60分

b = torch.masked_fill(scores,scores<60,

60)#等價於b = scores.masked_fill(scores<60,60)

b

pytorch 張量 張量的生成

張量的生成 import torch import numpy as np 使用tensor.tensor 函式構造張量 a torch.tensor 1.0,1.0 2.2 print a 獲取張量的維度 print 張量的維度 a.shape 獲取張量的形狀大小 print 張量的大小 a.si...

pytorch 張量 張量的資料型別

張量定義 import torch torch.tensor 1.2 3.4 dtype 獲取張量的資料型別,其中torch.tensor 函式生成乙個張量 torch.float32 torch.set default tensor type torch.doubletensor 設定張量的預設資...

PyTorch學習筆記1 張量

pytorch中資料集用tensor來表示,tensor與python中的list類似,但是其內部儲存時以連續記憶體單元儲存,可以通過下標計算出記憶體位址,然後直接讀出數值,因此訪問效率很高,同時由於與numpy的記憶體儲存基本相同,所以numpy的ndarray與tensor之間轉換,不論有多少元...