官方文檔對該函數的解釋:
按自己的了解翻譯的,如有錯誤望指出
作用:沿 dim 指定的軸收集值
參數:
input (Tensor) – 要操作的張量
dim (int) – 要索引的軸
index (LongTensor) – 要收集的元素的索引
out (Tensor, optional) – 目标張量-要收集資料得到的張量
sparse_grad (bool,optional) – 如果為真,梯度 w.r.t 輸入為稀疏張量
例子:
t = torch.tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1, 1],
[ 4, 3]])