天天看點

pytorch張量次元的擴增、壓縮、廣播

1,張量的擴增和壓縮

對于pytorch張量,有一個比較常用的操作就是沿着某個方向對張量做擴增和對張量進行壓縮,這兩種情況與張量的大小等于1的次元有關,對于一個張量來說,可以任意添加一個次元,該次元的大小為,而不改變張量的資料,因為張量的大小等于所有次元大小的乘積,那些為1的次元不改變張量的大小。具體見下面代碼:

import torch

a = torch.rand(3, 4) # 随機生成一個3*4的張量
print(a)
print(a.shape)

b = a.unsqueeze(-1) # 擴增最後一個次元,生成3*4*1的張量
print(b)
print(b.shape)

c = a.unsqueeze(-1).unsqueeze(-1) # 繼續擴增最後一個次元,生成3*4*1*1的張量
print(c)
print(c.shape)

d = torch.rand(1,3,4,1) # 随機生成一個1*3*4*1的張量,其中有兩個次元大小為1
print(d)
e = d.squeeze()  # 壓縮所有大小為1的次元,成為了3*4的張量
print(e)
           

2,張量的廣播

在張量的運算中會碰到一種情況,即兩個不同次元張量之間做四則運算,且兩個張量某些次元相等。顯然,如果按照張量的四則運算的定義,兩個不同次元的張量不能進行四則運算。為了能夠讓它們進行計算,首先需要把次元數目比較小的張量擴增到和次元數目比較大的張量一緻。具體見下面代碼:

d1 = torch.randn(3, 4, 5)  # 定義3*4*5的張量1
print(d1)
d2 = torch.randn(3, 5)  # 定義3*5的張量2
print(d2)
d2 = d2.unsqueeze(1)  # 擴增第一個次元,将張量2的形狀變為3*1*5
print(d2)
d3 = d1 + d2  # 廣播求和,最後結果為3*4*5的張量,相當于是将d2沿着第二個次元複制4次,使之成為3*4*5的張量,這樣與d1就能進行元素一一對應的計算
print(d3)
           

Done!!!

繼續閱讀