天天看點

Pytorch中requires_grad_(), detach(), torch.no_grad()的差別

文章作者:Tyan

部落格:noahsnail.com | CSDN | 簡書

0. 測試環境

Python 3.6.9, Pytorch 1.5.0

1. 基本概念

Tensor

是一個多元矩陣,其中包含所有的元素為同一資料類型。預設資料類型為

torch.float32

  • 示例一
>>> a = torch.tensor([1.0])
>>> a.data
tensor([1.])
>>> a.grad
>>> a.requires_grad
False
>>> a.dtype
torch.float32
>>> a.item()
1.0
>>> type(a.item())
<class 'float'>           

複制

Tensor

中隻有一個數字時,使用

torch.Tensor.item()

可以得到一個Python數字。

requires_grad

True

時,表示需要計算

Tensor

的梯度。

requires_grad=False

可以用來當機部分網絡,隻更新另一部分網絡的參數。

  • 示例二
>>> a = torch.tensor([1.0, 2.0])
>>> b = a.data
>>> id(b)
139808984381768
>>> id(a)
139811772112328
>>> b.grad
>>> a.grad
>>> b[0] = 5.0
>>> b
tensor([5., 2.])
>>> a
tensor([5., 2.])           

複制

a.data

傳回的是一個新的

Tensor

對象

b

a, b

id

不同,說明二者不是同一個

Tensor

,但

b

a

共享資料的存儲空間,即二者的資料部分指向同一塊記憶體,是以修改

b

的元素時,

a

的元素也對應修改。

2. requires_grad_()與detach()

>>> a = torch.tensor([1.0, 2.0])
>>> a.data
tensor([1., 2.])
>>> a.grad
>>> a.requires_grad
False
>>> a.requires_grad_()
tensor([1., 2.], requires_grad=True)
>>> c = a.pow(2).sum()
>>> c.backward()
>>> a.grad
tensor([2., 4.])
>>> b = a.detach()
>>> b.grad
>>> b.requires_grad
False
>>> b
tensor([1., 2.])
>>> b[0] = 6
>>> b
tensor([6., 2.])
>>> a
tensor([6., 2.], requires_grad=True)           

複制

  • requires_grad_()

requires_grad_()

函數會改變

Tensor

requires_grad

屬性并傳回

Tensor

,修改

requires_grad

的操作是原位操作(in place)。其預設參數為

requires_grad=True

requires_grad=True

時,自動求導會記錄對

Tensor

的操作,

requires_grad_()

的主要用途是告訴自動求導開始記錄對

Tensor

的操作。

  • detach()

detach()

函數會傳回一個新的

Tensor

對象

b

,并且新

Tensor

是與目前的計算圖分離的,其

requires_grad

屬性為

False

,反向傳播時不會計算其梯度。

b

a

共享資料的存儲空間,二者指向同一塊記憶體。

注:共享記憶體空間隻是共享的資料部分,

a.grad

b.grad

是不同的。

3. torch.no_grad()

torch.no_grad()

是一個上下文管理器,用來禁止梯度的計算,通常用來網絡推斷中,它可以減少計算記憶體的使用量。

>>> a = torch.tensor([1.0, 2.0], requires_grad=True)
>>> with torch.no_grad():
...     b = n.pow(2).sum()
...
>>> b
tensor(5.)
>>> b.requires_grad
False
>>> c = a.pow(2).sum()
>>> c.requires_grad
True           

複制

上面的例子中,當

a

requires_grad=True

時,不使用

torch.no_grad()

c.requires_grad

True

,使用

torch.no_grad()

時,

b.requires_grad

False

,當不需要進行反向傳播時(推斷)或不需要計算梯度(網絡輸入)時,

requires_grad=True

會占用更多的計算資源及存儲資源。

4. 總結

requires_grad_()

會修改

Tensor

requires_grad

屬性。

detach()

會傳回一個與計算圖分離的新

Tensor

,新

Tensor

不會在反向傳播中計算梯度,會在特定場合使用。

torch.no_grad()

更節省計算資源和存儲資源,其作用域範圍内的操作不會建構計算圖,常用在網絡推斷中。

References

  1. https://pytorch.org/docs/stable/tensors.html
  2. https://pytorch.org/docs/stable/tensors.html#torch.Tensor.requires_grad_
  3. https://pytorch.org/docs/stable/autograd.html#torch.Tensor.detach
  4. https://pytorch.org/docs/master/generated/torch.no_grad.html