Pytorch 索引與切片

2022-05-01 01:45:10 字數 3163 閱讀 7379

引言

本篇介紹pytorch 的索引與切片

123

4567

in[3]: a = torch.rand(4,3,28,28)

in[4]: a[0].shape # 理解上相當於取第一張

out[4]: torch.size([3, 28, 28])

in[5]: a[0,0].shape # 第0張的第0個通道

out[5]: torch.size([28, 28])

in[6]: a[0,0,2,4] # 第0張第0個通道的第2行第4列的畫素點 標量

out[6]: tensor(0.4133) # 沒有用 包起來就是乙個標量 dim為0

123

4567

8910

in[7]: a.shape

out[7]: torch.size([4, 3, 28, 28])

in[8]: a[:2].shape # 前面兩張的所有資料

out[8]: torch.size([2, 3, 28, 28])

in[9]: a[:2,:1,:,:].shape # 前面兩張的第0通道的資料

out[9]: torch.size([2, 1, 28, 28])

in[11]: a[:2,1:,:,:].shape # 前面兩張,第1,2通道的資料

out[11]: torch.size([2, 2, 28, 28])

in[10]: a[:2,-1:,:,:].shape # 前面兩張,最後乙個通道的資料 從-1到最末尾,就是它本身。

out[10]: torch.size([2, 1, 28, 28])

123

4

a[:,:,0:28,0:28:2].shape    # 隔點取樣

out[12]: torch.size([4, 3, 28, 14])

a[:,:,::2,::2].shape

out[14]: torch.size([4, 3, 14, 14])

123

4567

8910

in[17]: a.shape

out[17]: torch.size([4, 3, 28, 28])

in[19]: a.index_select(0, torch.tensor([0,2])).shape # 當前維度為0,取第0,2張

out[19]: torch.size([2, 3, 28, 28])

in[20]: a.index_select(1, torch.tensor([1,2])).shape # 當前維度為1,取第1,2個通道

out[20]: torch.size([4, 2, 28, 28])

in[21]: a.index_select(2,torch.arange(28)).shape # 第二個引數,只是告訴你取28行

out[21]: torch.size([4, 3, 28, 28])

in[22]: a.index_select(2, torch.arange(8)).shape # 取8行 [0,8)

out[22]: torch.size([4, 3, 8, 28])

123

4567

8910

in[23]: a.shape

out[23]: torch.size([4, 3, 28, 28])

in[24]: a[...].shape # 所有維度

out[24]: torch.size([4, 3, 28, 28])

in[25]: a[0,...].shape # 後面都有,取第0個 = a[0]

out[25]: torch.size([3, 28, 28])

in[26]: a[:,1,...].shape

out[26]: torch.size([4, 28, 28])

in[27]: a[...,:2].shape # 當有...出現時,右邊的索引理解為最右邊,只取兩列

out[27]: torch.size([4, 3, 28, 2])

123

4567

891011

1213

1415

16

in[31]: x = torch.randn(3,4)

in[32]: x

out[32]:

tensor([[ 2.0373, 0.1586, 0.1093, -0.6493],

[ 0.0466, 0.0562, -0.7088, -0.9499],

[-1.2606, 0.6300, -1.6374, -1.6495]])

in[33]: mask = x.ge(0.5) # >= 0.5 的元素的位置上為1,其餘地方為0

in[34]: mask

out[34]:

tensor([[1, 0, 0, 0],

[0, 0, 0, 0],

[0, 1, 0, 0]], dtype=torch.uint8)

in[35]: torch.masked_select(x,mask)

out[35]: tensor([2.0373, 0.6300]) # 之所以打平是因為大於0.5的元素個數是根據內容才能確定的

in[36]: torch.masked_select(x,mask).shape

out[36]: torch.size([2])

123

4567

in[39]: src = torch.tensor([[4,3,5],[6,7,8]])		# 先打平成1維的,共6列

in[40]: src

out[40]:

tensor([[4, 3, 5],

[6, 7, 8]])

in[41]: torch.take(src, torch.tensor([0, 2, 5])) # 取打平後編碼,位置為0 2 5

out[41]: tensor([4, 5, 8])

pytorch索引與切片

torch會自動從左向右索引 例子 a torch.randn 4,3,28,28 表示類似乙個cnn 的的輸入資料,4表示這個batch一共有4張 而3表示的通道數為3 rgb 28,28 表示的大小 基本索引print a 0 shape torch.size 3,28,28 print a 0...

pytorch索引與切片

目錄 torch會自動從左向右索引 例子 a torch.randn 4,3,28,28 表示類似乙個cnn 的的輸入資料,4表示這個batch一共有4張 而3表示的通道數為3 rgb 28,28 表示的大小 基本索引print a 0 shape torch.size 3,28,28 print ...

Pytorch學習 張量的索引切片

張量的索引切片方式和numpy幾乎是一樣的。切片時支援預設引數和省略號。可以通過索引和切片對部分元素進行修改。此外,對於不規則的切片提取,可以使用torch.index select,torch.masked select,torch.take 如果要通過修改張量的某些元素得到新的張量,可以使用to...