gather
torch.gather(input, dim, index, out=None) → Tensor
沿給定軸dim, 将輸入索引張量
index
指定位置的值進行聚合,即,指定次元的
大小
會發生變化。
對于一個3維張量,輸出可以定義為:
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3
舉個例子:
import torch
t = torch.Tensor([[,],[,]])
#t_gather在dim=1的次元為2,不變
t_gather = torch.gather(t, , torch.LongTensor([[,],[,]]))
#t_gather1在dim=1的次元為1
t_gather1 = torch.gather(t, , torch.LongTensor([[],[]]))
#列印結果
print(t_gather,t_gather1)
cat
torch.cat(inputs, dimension=0) → Tensor
在給定次元上對輸入的張量序列
seq
進行連接配接操作,合并操作,類似于
np.stack
參數:
inputs (sequence of Tensors) – 可以是任意相同Tensor 類型的python 序列
dimension (int, optional) – 沿着此維連接配接張量序列。
舉個例子:
import torch
x = torch.randn(,)
#類似于vstack
x3 = torch.cat((x,x,x),)
#類似于hstack
y3 = torch.cat((x,x,x),)
#列印結果
print(x3,y3)
unsqueeze
torch.unsqueeze(input, dim, out=None)
傳回一個新的張量,對輸入的制定位置插入次元 1
注意: 傳回張量與輸入張量共享記憶體,是以改變其中一個的内容會改變另一個。
如果
dim
為負,則将會被轉化
dim+input.dim()+1
參數:
tensor (Tensor) – 輸入張量
dim (int) – 插入次元的索引
out (Tensor, optional) – 結果張量
squeeze
torch.squeeze(input, dim, out=None)
将輸入張量形狀中的1 去除并傳回。 如果輸入是形如(A×1×B×1×C×1×D),那麼輸出形狀就為: (A×B×C×D)
當給定dim時,那麼擠壓操作隻在給定次元上。例如,輸入形狀為: (A×1×B), squeeze(input, 0) 将會保持張量不變,隻有用 squeeze(input, 1),形狀會變成 (A×B)。
注意: 傳回張量與輸入張量共享記憶體,是以改變其中一個的内容會改變另一個。
detach
var.detach()
傳回一個新的Tensor,從目前圖中脫離出來,該tensor不會要求更新梯度,也就是梯度在這中斷了。
注意:該新的Tensor與原Tensor共享記憶體。
舉個例子:
y=A(x), y1=B(y)
, y, Z代表兩個網絡,我們希望更新Z網絡,而不希望更新y網絡,是以需要将梯度在y時當機,即不會自動回傳梯度,而y時Z的輸入端,不影響其梯度回傳。
detach_
var.detach_()
從建立它的圖中分離張量,使其成為一片葉子。 視圖不能在原位分離。
與detach()差別在于,分離出的張量與原張量不共享記憶體,創造了一個Tensor。
view
view(*args) → Tensor
傳回一個新的tensor,與原tensor共享記憶體,但size不一樣,類似于reshape
Reference
pytorch tutorial
pytorch 中文版
pytorch新手需要注意的隐晦操作Tensor,max,gather
pytorch: Variable detach 與 detach_