天天看点

对于pytorch,gather()函数的理解

官方文档对该函数的解释:

按自己的理解翻译的,如有错误望指出

作用:沿 dim 指定的轴收集值

参数:

input (Tensor) – 要操作的张量

dim (int) – 要索引的轴

index (LongTensor) – 要收集的元素的索引

out (Tensor, optional) – 目标张量-要收集数据得到的张量

sparse_grad (bool,optional) – 如果为真,梯度 w.r.t 输入为稀疏张量

例子:

t = torch.tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
           
tensor([[ 1,  1],
        [ 4,  3]])
           
对于pytorch,gather()函数的理解

继续阅读