天天看點

對于pytorch,gather()函數的了解

官方文檔對該函數的解釋:

按自己的了解翻譯的,如有錯誤望指出

作用:沿 dim 指定的軸收集值

參數:

input (Tensor) – 要操作的張量

dim (int) – 要索引的軸

index (LongTensor) – 要收集的元素的索引

out (Tensor, optional) – 目标張量-要收集資料得到的張量

sparse_grad (bool,optional) – 如果為真,梯度 w.r.t 輸入為稀疏張量

例子:

t = torch.tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
           
tensor([[ 1,  1],
        [ 4,  3]])
           
對于pytorch,gather()函數的了解

繼續閱讀