在简书、csdn中的很多帖子中,都没有根据gather函数的官方解释进行理解的,很多同学们根据二维矩阵或三维矩阵的单例去理解后不具备通用性,这是因为不结合三维角度去解释是错误的。
下边给出结合官方解释去理解gather函数的处理过程的思路。
这里是官方文档的解释
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
Example:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
以上是官方给的解释,官方给的样例中,主要是结合三维的数据进行定位用的。
三维矩阵
例如在dim=0时,输出的值out为输入input的[index_value][j][k],也就是索引输入矩阵的指定位置的值,构成输出的值。具体的位置是:index指定了第一维度,第二、第三维度的索引和输出矩阵的索引相同。
例如:
import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.],
[ 10., 28., 22., 27., 0.]],
[[ 26., 10., 20., 29., 18.],
[ 5., 24., 26., 21., 3.],
[ 10., 29., 10., 0., 22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18., 26., 22., 1., 0.],
[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.]],
[[ 5., 29., 10., 0., 22.],
[ 26., 10., 20., 29., 18.],
[ 10., 29., 10., 0., 22.]]])
'''
在例子中,dim=1,a的三维索引分别为矩阵数、行数、列数(a[0][0][0]表示第一个矩阵,第一行、第一列,看表中的值可知为18;如果是a[1][0][0]表示第二个矩阵,第一行、第一列,其值为26)。因此在index进行重建后,即
out[0][0][0]=input[0][index(0,0,0)][0]=a[0][0][0]=18
out[0][0][1]=input[0][index(0,0,1)][1]=a[0][1][1]=26
.
.
.
以此,我们就能得到最后的输出。
二维矩阵
对于二维矩阵来说,由于只有两层的索引,其索引分别为行、列,因此,官方的解释可以直接忽略第三层索引。
out[i][j] = input[index[i][j]][j] # dim=0
out[i][j] = input[i][index[i][j]] # dim=1
很容易看到,此时input第一层索引是行,因此dim=0时index索引是修改了行的值,以此类推。
注意
其中最关键的是python中索引的顺序为深度、行、列,按传统的行、列、深度的三维顺序去理解python三维矩阵,会发现根本解释不同,而且要根据官方给的三维解释理解其二维的不同之处。