天天看点

torch.squeeze()与torch.unsqueeze()

1.torch.squeeze()与torch.unsqueeze()在一维数组上的横向对比

结论:

torch.squeeze()在指定的维度前面面减少一个维度

torch.unsqueeze()在指定的维度后面增加一个维度

import torch
a = torch.randn(5)
print(a)
b = torch.unsqueeze(a, dim=-1)
print(b)
c = torch.squeeze(a, dim=-1)
print(c)
           
tensor([ 0.1454, -1.4046, -0.8206, -0.7080,  0.3535])
tensor([[ 0.1454],
        [-1.4046],
        [-0.8206],
        [-0.7080],
        [ 0.3535]])
tensor([ 0.1454, -1.4046, -0.8206, -0.7080,  0.3535])
           

torch.squeeze与torch.unsqueeze()在二维数组上的横向对比

import torch
a = torch.randn(2,5)
print(a)
b = torch.unsqueeze(a, dim=-1)
print(b)
c = torch.squeeze(a, dim=-1)
print(c)
           
tensor([[-1.6989,  1.3214, -0.4190,  1.4261, -0.4857],
        [ 3.3618, -0.1716, -0.1987, -2.3104,  2.1282]])
tensor([[[-1.6989],
         [ 1.3214],
         [-0.4190],
         [ 1.4261],
         [-0.4857]],

        [[ 3.3618],
         [-0.1716],
         [-0.1987],
         [-2.3104],
         [ 2.1282]]])
tensor([[-1.6989,  1.3214, -0.4190,  1.4261, -0.4857],
        [ 3.3618, -0.1716, -0.1987, -2.3104,  2.1282]])
           

2.torch.unsqueeze()在一维数组上的关于维度的纵向对比

结论:

对于一维数组要注意一点:dim=-1不等与dim=0

即便是一维数组,系统也是认为有两个维度

dim = -1 等价于dim=1

dim = -2 等价于dim=0

当然dim为负号表示的含义始终是不变的,-n就是到数第n个,n即是正数第n-1个

import torch
a = torch.randn(5)
print(a.shape)
print(a)
b = torch.unsqueeze(a, dim=0)
print(b)
bb = torch.unsqueeze(a, dim=-2)
print(bb)
bbb = torch.unsqueeze(a, dim=1)
print(bbb)
bbb = torch.unsqueeze(a, dim=-1)
print(bbb)
           
torch.Size([5])
tensor([-0.5416,  0.2842, -0.0026,  0.8659, -1.1321])
tensor([[-0.5416,  0.2842, -0.0026,  0.8659, -1.1321]])
tensor([[-0.5416,  0.2842, -0.0026,  0.8659, -1.1321]])
tensor([[-0.5416],
        [ 0.2842],
        [-0.0026],
        [ 0.8659],
        [-1.1321]])
tensor([[-0.5416],
        [ 0.2842],
        [-0.0026],
        [ 0.8659],
        [-1.1321]])
           

继续阅读