文章作者:Tyan
部落格:noahsnail.com | CSDN | 簡書
0. 測試環境
Python 3.6.9, Pytorch 1.5.0
1. 基本概念
Tensor
是一個多元矩陣,其中包含所有的元素為同一資料類型。預設資料類型為
torch.float32
。
- 示例一
>>> a = torch.tensor([1.0])
>>> a.data
tensor([1.])
>>> a.grad
>>> a.requires_grad
False
>>> a.dtype
torch.float32
>>> a.item()
1.0
>>> type(a.item())
<class 'float'>
複制
Tensor
中隻有一個數字時,使用
torch.Tensor.item()
可以得到一個Python數字。
requires_grad
為
True
時,表示需要計算
Tensor
的梯度。
requires_grad=False
可以用來當機部分網絡,隻更新另一部分網絡的參數。
- 示例二
>>> a = torch.tensor([1.0, 2.0])
>>> b = a.data
>>> id(b)
139808984381768
>>> id(a)
139811772112328
>>> b.grad
>>> a.grad
>>> b[0] = 5.0
>>> b
tensor([5., 2.])
>>> a
tensor([5., 2.])
複制
a.data
傳回的是一個新的
Tensor
對象
b
,
a, b
的
id
不同,說明二者不是同一個
Tensor
,但
b
與
a
共享資料的存儲空間,即二者的資料部分指向同一塊記憶體,是以修改
b
的元素時,
a
的元素也對應修改。
2. requires_grad_()與detach()
>>> a = torch.tensor([1.0, 2.0])
>>> a.data
tensor([1., 2.])
>>> a.grad
>>> a.requires_grad
False
>>> a.requires_grad_()
tensor([1., 2.], requires_grad=True)
>>> c = a.pow(2).sum()
>>> c.backward()
>>> a.grad
tensor([2., 4.])
>>> b = a.detach()
>>> b.grad
>>> b.requires_grad
False
>>> b
tensor([1., 2.])
>>> b[0] = 6
>>> b
tensor([6., 2.])
>>> a
tensor([6., 2.], requires_grad=True)
複制
-
requires_grad_()
requires_grad_()
函數會改變
Tensor
的
requires_grad
屬性并傳回
Tensor
,修改
requires_grad
的操作是原位操作(in place)。其預設參數為
requires_grad=True
。
requires_grad=True
時,自動求導會記錄對
Tensor
的操作,
requires_grad_()
的主要用途是告訴自動求導開始記錄對
Tensor
的操作。
-
detach()
detach()
函數會傳回一個新的
Tensor
對象
b
,并且新
Tensor
是與目前的計算圖分離的,其
requires_grad
屬性為
False
,反向傳播時不會計算其梯度。
b
與
a
共享資料的存儲空間,二者指向同一塊記憶體。
注:共享記憶體空間隻是共享的資料部分,
a.grad
與
b.grad
是不同的。
3. torch.no_grad()
torch.no_grad()
是一個上下文管理器,用來禁止梯度的計算,通常用來網絡推斷中,它可以減少計算記憶體的使用量。
>>> a = torch.tensor([1.0, 2.0], requires_grad=True)
>>> with torch.no_grad():
... b = n.pow(2).sum()
...
>>> b
tensor(5.)
>>> b.requires_grad
False
>>> c = a.pow(2).sum()
>>> c.requires_grad
True
複制
上面的例子中,當
a
的
requires_grad=True
時,不使用
torch.no_grad()
,
c.requires_grad
為
True
,使用
torch.no_grad()
時,
b.requires_grad
為
False
,當不需要進行反向傳播時(推斷)或不需要計算梯度(網絡輸入)時,
requires_grad=True
會占用更多的計算資源及存儲資源。
4. 總結
requires_grad_()
會修改
Tensor
的
requires_grad
屬性。
detach()
會傳回一個與計算圖分離的新
Tensor
,新
Tensor
不會在反向傳播中計算梯度,會在特定場合使用。
torch.no_grad()
更節省計算資源和存儲資源,其作用域範圍内的操作不會建構計算圖,常用在網絡推斷中。
References
- https://pytorch.org/docs/stable/tensors.html
- https://pytorch.org/docs/stable/tensors.html#torch.Tensor.requires_grad_
- https://pytorch.org/docs/stable/autograd.html#torch.Tensor.detach
- https://pytorch.org/docs/master/generated/torch.no_grad.html