對Pytorch 中的contiguous理解說明

2022-09-26 00:24:11 字數 4234 閱讀 2192

最近遇到這個函式,但查的中文部落格裡的解釋貌似不是很到位,這裡翻譯一下stackoverflow上的回答並加上自己的理解。

在pytorch中,只有很少幾個操作是不改變tensor的內容本身,而只是重新定義下標與元素的對應關係的。換句話說,這種操作不進行資料拷貝和資料的改變,變的是元資料。

這些操作是:

舉個栗子,在使用transpose()進行轉置操作時,pytorch並不會建立新的、轉置後的tensor,而是修改了tensor中的一些屬性(也就是元資料),使得此時的offset和stride是與轉置tensor相對應的。

轉置的tensor和原tensor的記憶體是共享的!

為了證明這一點,我們來看下面的**:

x = torch.randn(3, 2)

y = x.transpose(x, 0, 1)

x[0, 0] = 233

print(y[0, 0])

# print 233

可以看到,改變了y的元素的值的同時,x的元素的值也發生了變化。

也就是說,經過上述操作後得到的tensor,它內部資料的布局方式和從頭開始建立乙個這樣的常規的tensor的布局方式是不一樣的!於是…這就有contiguous()的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因為內部資料不是通常的布局方式)。

注意不要被contiguous的字面意思「連續的」誤解,tensor中資料還是在記憶體中一塊區域裡,只是布局的問題!

當呼叫contiguous()時,會強制拷貝乙份tensor,讓它的布局和從頭建立的一毛一樣。

一般來說這一點不用太擔心,如果你沒在需要呼叫contiguous()的地方呼叫contiguous(),執行時會提示你:

runtimeerror: input is not contiguous

只要看到這個錯誤提示,加上contiguous()就好啦~

補充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax

torch.gather(input,dim,index,out=none)。對指定維進行索引。比如4*3的張量,對dim=1進行索引,那麼index的取值範圍就是0~2.

input是乙個張量,index是索引張量。input和index的size要麼全部維度都相同,要麼指定的dim那一維度值不同。輸出為和index大小相同的張量。

import torch

a=torch.tensor([[.1,.2,.3],

[1.1,1.2,1.3],

[2.1,2.2,2.3],

[3.1,3.2,3.3]])

b=torch.longtensor([[1,2,1],

[2,2,2],

[2,2,2],

[1,1,0]])

b=b.view(4,3)

print(a.gather(1,b))

print(a.gather(0,b))

c=torch.longtensor([1,2,0,1])

c=c.view(4,1)

print(a.gather(1,程式設計客棧c))

輸出:tensor([[ 0.2000, 0.3000, 0.2000],

[ 1.3000, 1.3000, 1.3000],

[ 2.3000, 2.3000, 2.3000],

[ 3.2000, 3.2000, 3.1000]])

tensor([[ 1.1000, 2.2000, 1.3000],

[ 2.1000, 2.2000, 2.3000],

[ 2.1000, 2.2000, 2.3000],

[ 1.1000, 1.2000, 0.3000]])

tensor([[ 0.2000],

[ 1.3000],

[ 2.1000],

[ 3.2000]])

將維度為1的壓縮掉。如size為(3,1,1,2),壓縮之後為(3,2)

import torch

a=torch.randn(2,1,1,3)

print(awww.cppcns.com)

print(a.squeeze())

輸出:tensor([[[[-0.2320, 0.9513, 1.1613]]],

[[[ 0.0901, 0.9613, -0.9344]]]])

tensor([[-0.2320, 0.9513, 1.1613],

[ 0.0901, 0.9613, -0.9344]])

擴充套件某個size為1的維度。如(2,2,1)擴充套件為(2,2,3)

import torch

x=torch.randn(2,2,1)

print(x)

y=x.expand(2,2,3)

print(y)

輸出:tensor([[[ 0.0608],

[ 2.2106]],

[[程式設計客棧-1.9287],

[ 0.8748]]])

tensor([[[ 0.0608, 0.0608, 0.0608],

[ 2.2106, 2.2106, 2.2106]],

[[-1.9287, -1.9287, -1.9287],

[ 0.8748, 0.8748, 0.8748]]])

size為(m,n,d)的張量,dim=1時,輸出為size為(m,d)程式設計客棧的張量

import torch

a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])

print(a.sum())

print(a.sum(dim=1))

輸出:tensor(60)

tensor([[ 5, 10, 15],

[ 5, 10, 15]])

返回乙個記憶體為連續的張量,如本身就是連續的,返回它自己。一般用在view()函式之前,因為view()要求呼叫張量是連續的。

可以通過is_contiguous檢視張量記憶體是否連續。

import torch

a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])

print(a.is_contiguous)

print(a.contiguous().view(4,3))

輸出:tensor([[ 1, 2, 3],

[ 4, 8, 12],

[ 1, 2,程式設計客棧 3],

[ 4, 8, 12]])

假設陣列v有c個元素。對其進行softmax等價於將v的每個元素的指數除以所有元素的指數之和。這會使值落在區間(0,1)上,並且和為1。

import torch

import torch.nn.functional as f

a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])

b=f.softmax(a,dim=1)

print(b)

輸出:tensor([[ 0.5000, 0.5000],

[ 0.7311, 0.2689],

[ 0.8808, 0.1192],

[ 0.2689, 0.7311],

[ 0.1192, 0.8808]])

返回最大值,或指定維度的最大值以及index

import torch

a=torch.tensor([[.1,.2,.3],

[1.1,1.2,1.3],

[2.1,2.2,2.3],

[3.1,3.2,3.3]])

print(a.max(dim=1))

print(a.max())

輸出:(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))

tensor(3.3000)

返回最大值的index

import torch

a=torch.tensor([[.1,.2,.3],

[1.1,1.2,1.3],

[2.1,2.2,2.3],

[3.1,3.2,3.3]])

print(a.argmax(dim=1))

print(a.argmax())

輸出:tensor([ 2, 2, 2, 2])

tensor(11)

本文標題: 對pytorch 中的contiguous理解說明

本文位址:

對Pytorch中backward()函式的理解

寫在第一句 這個部落格解釋的也很好,參考了很多 pytorch中的自動求導函式backward 所需引數含義 所以切入正題 backward 函式中的引數應該怎麼理解?官方 如果需要計算導數,可以在tensor上呼叫.backward 1.如果tensor是乙個標量 即它包含乙個元素的資料 則不需要...

oracle中實現break和continue

一 continue 在 oracle 11g 之前無法使用 continue 實現退出當前迴圈的 11g中據說實現了 但是可以用一下方法模擬實現 declare 定義變數 begin fori in 1.10loop 真正的迴圈 forj in 1.1loop 假迴圈,目的是模擬出 continu...

對PyTorch中inplace欄位的全面理解

torch.nn.relu inplace true inplace true 表示進行原地操作,對上一層傳遞下來的tensor直接進行修改,如x x 3 inplace false 表示新建乙個變數儲存操作結果,如y x 3,x y inplace true 可以節省運算記憶體,不用多儲存變數。補...