天天看点

pytorch中的torch.gather()函数的理解和使用在看pytorch 时,遇到了gather函数,主要对其中dim =0,dim=1有很多困惑不清楚如何进行取数的,后面明白了点作为自己总结也希望可以帮到大家

在看pytorch 时,遇到了gather函数,主要对其中dim =0,dim=1有很多困惑不清楚如何进行取数的,后面明白了点作为自己总结也希望可以帮到大家

# 选取对角线的元素
import torch as t
a = t.arange(0, 16).view(4, 4)
index = t.LongTensor([[0,1,2,3]])
b=a.gather(0, index)
print(b)
           

我们先了解一下这个函数以及其中的参数都代表的什么意思,

torch.gather(input, dim, index, out=None) → Tensor

input:要进行操作的tensor数据,也就是例子中的a
dim:维度
index:索引列表

函数操作就是按照给定轴dim的设置,根据index索引列表取出input数据中相对应位置的数,通过这些参数可能还是不太容易明白,其中最难理解的的就是维度dim这个参数了现在我们看几个例子,相信看完之后少许会明白一些

..............下面是dim = 0时测试........

import torch
a = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1,0],[1,1,1]])
b = torch.gather(a,0,index_1)
print('a=',a)
print('b=',b)
.............下面是输出结果............
a= tensor([[1., 2., 3.],
        [4., 5., 6.]])
b= tensor([[1., 5., 3.],
        [4., 5., 6.]])

..............下面是dim = 1时的测试.............
import torch
a = torch.Tensor([[1,2,3],[4,5,6]])
index_2 = torch.LongTensor([[0,1,0],[1,1,1]])
b = torch.gather(a,1,index_2)
print('a=',a)
print('b=',b)
..................输出............
a= tensor([[1., 2., 3.],
        [4., 5., 6.]])
b= tensor([[1., 2., 1.],
        [5., 5., 5.]])
           

dim = 0表示的是对纵向的操作也就是列的操作

index表示从input数据的所取的数据的位置,我们可以根据下面这个文字说明对应上面的程序来对照理解

pytorch中的torch.gather()函数的理解和使用在看pytorch 时,遇到了gather函数,主要对其中dim =0,dim=1有很多困惑不清楚如何进行取数的,后面明白了点作为自己总结也希望可以帮到大家

dim =1表示是对横向的操作也就是行的操作

pytorch中的torch.gather()函数的理解和使用在看pytorch 时,遇到了gather函数,主要对其中dim =0,dim=1有很多困惑不清楚如何进行取数的,后面明白了点作为自己总结也希望可以帮到大家

dim = 0时对列操作,要求index 的列数和input的列数要相同,否则会报错;同样dim = 1对行操作,要求两者的行数要一致

现在我给出一个dim =0时列数不一致的报错情况,对于dim =1的情况可以自行验证

import torch
a = torch.Tensor([[1,2,3,4],[5,6,7,8]])
index_1 = torch.LongTensor([[0,1,0],[1,1,1]])
b = torch.gather(a,0,index_1)
print('a=',a)
print('b=',b)
...........dim =0会报错,如果在这里改成dim =1则会正确输出,..........

RuntimeError: Expected tensor [2, 3], src [2, 4] and index [2, 3] to have the same size apart from dimension 0

           

自己的一点小总结,希望上面内容可以帮助到大家~~,如果有哪里不对的地方也希望大家可以提出来~

继续阅读