Tensor索引操作

2022-07-20 02:51:09 字數 3091 閱讀 1876

#tensor索引操作  

''''' 

tensor支援與numpy.ndarray類似的索引操作,語法上也類似 

如無特殊說明,索引出來的結果與原tensor共享記憶體,即修改乙個,另乙個會跟著修改 

'''  

import torch as t  

a = t.randn(3,4)  

'''''tensor([[ 0.1986,  0.1809,  1.4662,  0.6693], 

[-0.8837, -0.0196, -1.0380,  0.2927], 

[-1.1032, -0.2637, -1.4972,  1.8135]])'''  

print(a[0])         #第0行  

'''''tensor([0.1986, 0.1809, 1.4662, 0.6693])'''  

print(a[:,0])       #第0列  

'''''tensor([ 0.1986, -0.8837, -1.1032])'''  

print(a[0][2])      #第0行第2個元素,等價於a[0,2]  

'''''tensor(1.4662)'''  

print(a[0][-1])     #第0行最後乙個元素  

'''''tensor(0.6693)'''  

print(a[:2,0:2])    #前兩行,第0,1列  

'''''tensor([[ 0.1986,  0.1809], 

[-0.8837, -0.0196]])'''  

print(a[0:1,:2])    #第0行,前兩列  

'''''tensor([[0.1986, 0.1809]])'''  

print(a[0,:2])      #注意兩者的區別,形狀不同  

'''''tensor([0.1986, 0.1809])'''  

print(a>1)  

'''''tensor([[0, 0, 1, 0], 

[0, 0, 0, 0], 

[0, 0, 0, 1]], dtype=torch.uint8)'''  

print(a[a>1])        #等價於a.masked_select(a>1),選擇結果與原tensor不共享記憶體空間  

print(a.masked_select(a>1))  

'''''tensor([1.4662, 1.8135]) 

tensor([1.4662, 1.8135])'''  

print(a[t.longtensor([0,1])])  

'''''tensor([[ 0.1986,  0.1809,  1.4662,  0.6693], 

[-0.8837, -0.0196, -1.0380,  0.2927]])'''  

''''' 

常用的選擇函式 

index_select(input,dim,index)   在指定維度dim上選取,列如選擇某些列、某些行 

masked_select(input,mask)       例子如上,a[a>0],使用bytetensor進行選取 

non_zero(input)                 非0元素的下標 

gather(input,dim,index)         根據index,在dim維度上選取資料,輸出size與index一樣 

gather是乙個比較複雜的操作,對乙個二維tensor,輸出的每個元素如下: 

out[i][j] = input[index[i][j]][j]   #dim = 0 

out[i][j] = input[i][index[i][j]]   #dim = 1 

'''  

b = t.arange(0,16).view(4,4)  

'''''tensor([[ 0,  1,  2,  3], 

[ 4,  5,  6,  7], 

[ 8,  9, 10, 11], 

[12, 13, 14, 15]])'''  

index = t.longtensor([[0,1,2,3]])  

print(b.gather(0,index))            #取對角線元素  

'''''tensor([[ 0,  5, 10, 15]])'''  

index = t.longtensor([[3,2,1,0]]).t()       #取反對角線上的元素  

print(b.gather(1,index))  

'''''tensor([[ 3], 

[ 6], 

[ 9], 

[12]])'''  

index = t.longtensor([[3,2,1,0]])           #取反對角線的元素,與上面不同  

print(b.gather(0,index))  

'''''tensor([[12,  9,  6,  3]])'''  

index = t.longtensor([[0,1,2,3],[3,2,1,0]]).t()  

print(b.gather(1,index))  

'''''tensor([[ 0,  3], 

[ 5,  6], 

[10,  9], 

[15, 12]])'''  

''''' 

與gather相對應的逆操作是scatter_,gather把資料從input中按index取出,而 

scatter_是把取出的資料再放回去,scatter_函式時inplace操作 

out = input.gather(dim,index) 

out = tensor() 

out.scatter_(dim,index) 

'''  

x = t.rand(2, 5)  

print(x)  

c = t.zeros(3, 5).scatter_(0, t.longtensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)  

print(c)  

2018-10-23 20:30:30      

tensorflow中tensor的索引

tensorflow中tensor的索引 1.print sess.run outputs 0 2,0 2,2.print sess.run tf.slice outputs,0,0,0 2,2,2 3.print sess.run tf.gather outputs,0,2 1和2是等效的,不難看...

Tensor 逐元素操作

逐元素操作 這部分操作會對tensor的每乙個元素 point wise,又名element wise 進行操作,此類操作的輸入與輸出形狀一致。常用的操作如表3 4所示。表3 4 常見的逐元素操作 函式功能 abs sqrt div exp fmod log pow.絕對值 平方根 除法 指數 求餘...

常用的Tensor操作

1 通過tensor.view方法可以調整tensor的形狀,但必須保證調整去前後元素總數一致。view不會修改自身的資料,返回新的tensor與原tensor共享記憶體,即更改其中的乙個,另乙個也會跟這改變。2 實際中經常需要新增或減少某一維度,可用squeeze和unsqueeze這兩函式。im...