天天看点

pytorch中gather函数的官方样例理解

在简书、csdn中的很多帖子中,都没有根据gather函数的官方解释进行理解的,很多同学们根据二维矩阵或三维矩阵的单例去理解后不具备通用性,这是因为不结合三维角度去解释是错误的。

下边给出结合官方解释去理解gather函数的处理过程的思路。

这里是官方文档的解释

torch.gather(input, dim, index, out=None) → Tensor
   Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

    Parameters:	

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

    Example:

    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]
           

以上是官方给的解释,官方给的样例中,主要是结合三维的数据进行定位用的。

三维矩阵

例如在dim=0时,输出的值out为输入input的[index_value][j][k],也就是索引输入矩阵的指定位置的值,构成输出的值。具体的位置是:index指定了第一维度,第二、第三维度的索引和输出矩阵的索引相同。

例如:

import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  28.,  22.,  27.,   0.]],

        [[ 26.,  10.,  20.,  29.,  18.],
         [  5.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18.,  26.,  22.,   1.,   0.],
         [ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.]],

        [[  5.,  29.,  10.,   0.,  22.],
         [ 26.,  10.,  20.,  29.,  18.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''
           

在例子中,dim=1,a的三维索引分别为矩阵数、行数、列数(a[0][0][0]表示第一个矩阵,第一行、第一列,看表中的值可知为18;如果是a[1][0][0]表示第二个矩阵,第一行、第一列,其值为26)。因此在index进行重建后,即

out[0][0][0]=input[0][index(0,0,0)][0]=a[0][0][0]=18
out[0][0][1]=input[0][index(0,0,1)][1]=a[0][1][1]=26
.
.
.
           

以此,我们就能得到最后的输出。

二维矩阵

对于二维矩阵来说,由于只有两层的索引,其索引分别为行、列,因此,官方的解释可以直接忽略第三层索引。

out[i][j] = input[index[i][j]][j]  # dim=0
    out[i][j] = input[i][index[i][j]]  # dim=1
           

很容易看到,此时input第一层索引是行,因此dim=0时index索引是修改了行的值,以此类推。

注意

其中最关键的是python中索引的顺序为深度、行、列,按传统的行、列、深度的三维顺序去理解python三维矩阵,会发现根本解释不同,而且要根据官方给的三维解释理解其二维的不同之处。