天天看點

torch之gather,detach,cat,squeeze,unqueeze, view

gather

torch.gather(input, dim, index, out=None) → Tensor

沿給定軸dim, 将輸入索引張量

index

指定位置的值進行聚合,即,指定次元的

大小

會發生變化。

對于一個3維張量,輸出可以定義為:

out[i][j][k] = tensor[index[i][j][k]][j][k]  # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k]  # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]]  # dim=3
           

舉個例子:

import torch 
t = torch.Tensor([[,],[,]])
#t_gather在dim=1的次元為2,不變
t_gather = torch.gather(t, , torch.LongTensor([[,],[,]])) 
#t_gather1在dim=1的次元為1
t_gather1 = torch.gather(t, , torch.LongTensor([[],[]])) 
#列印結果
print(t_gather,t_gather1) 
           

cat

torch.cat(inputs, dimension=0) → Tensor

在給定次元上對輸入的張量序列

seq

進行連接配接操作,合并操作,類似于

np.stack

參數:

inputs (sequence of Tensors) – 可以是任意相同Tensor 類型的python 序列

dimension (int, optional) – 沿着此維連接配接張量序列。

舉個例子:

import torch 

x = torch.randn(,)
#類似于vstack
x3 = torch.cat((x,x,x),)
#類似于hstack
y3 = torch.cat((x,x,x),)
#列印結果
print(x3,y3)
           

unsqueeze

torch.unsqueeze(input, dim, out=None)

傳回一個新的張量,對輸入的制定位置插入次元 1

注意: 傳回張量與輸入張量共享記憶體,是以改變其中一個的内容會改變另一個。

如果

dim

為負,則将會被轉化

dim+input.dim()+1

參數:

tensor (Tensor) – 輸入張量

dim (int) – 插入次元的索引

out (Tensor, optional) – 結果張量

squeeze

torch.squeeze(input, dim, out=None)

将輸入張量形狀中的1 去除并傳回。 如果輸入是形如(A×1×B×1×C×1×D),那麼輸出形狀就為: (A×B×C×D)

當給定dim時,那麼擠壓操作隻在給定次元上。例如,輸入形狀為: (A×1×B), squeeze(input, 0) 将會保持張量不變,隻有用 squeeze(input, 1),形狀會變成 (A×B)。

注意: 傳回張量與輸入張量共享記憶體,是以改變其中一個的内容會改變另一個。

detach

var.detach()

傳回一個新的Tensor,從目前圖中脫離出來,該tensor不會要求更新梯度,也就是梯度在這中斷了。

注意:該新的Tensor與原Tensor共享記憶體。

舉個例子:

y=A(x), y1=B(y)

, y, Z代表兩個網絡,我們希望更新Z網絡,而不希望更新y網絡,是以需要将梯度在y時當機,即不會自動回傳梯度,而y時Z的輸入端,不影響其梯度回傳。

detach_

var.detach_()

從建立它的圖中分離張量,使其成為一片葉子。 視圖不能在原位分離。

與detach()差別在于,分離出的張量與原張量不共享記憶體,創造了一個Tensor。

view

view(*args) → Tensor

傳回一個新的tensor,與原tensor共享記憶體,但size不一樣,類似于reshape

Reference

pytorch tutorial

pytorch 中文版

pytorch新手需要注意的隐晦操作Tensor,max,gather

pytorch: Variable detach 與 detach_

繼續閱讀