![](https://img.laitimes.com/img/__Qf2AjLwojIjJCLyojI0JCLicmbw5iM5YDO5MTN3YTYxY2NzEWZ1MmZlJGMwkDZ0QmN4UDZm9CX0JXZ252bj91Ztl2Lc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)
It's automatic
我們知道,深度學習最核心的其中一個步驟,就是求導:根據函數(linear + activation function)求weights相對于loss的導數(還是loss相對于weights的導數?)。然後根據得出的導數,相應的修改weights,讓loss最小化。
各大深度學習架構Tensorflow,Keras,PyTorch都自帶有自動求導功能,不需要我們手動算。
在初步學習PyTorch的時候,看到PyTorch的自動求導過程時,感覺非常的别扭和不直覺。我下面舉個例子,大家自己感受一下。
>>> import torch
>>>
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>>
>>> e.backward() # 執行求導
>>> a.grad # a.grad 即導數 d(e)/d(a) 的值
tensor(4.)
這裡讓人感覺别扭的是,調用
e.backward()
執行求導,為什麼會更新
a
對象的狀态
grad
?對于習慣了OOP的人來說,這是非常不直覺的。因為,在OOP裡面,你要改變一個對象的狀态,一般的做法是,引用這個對象本身,給它的property顯示的指派(比如
user.age = 18
),或者是調用這個對象的方法(
user.setAge(18)
),讓它狀态得以改變。
而這裡的做法是,調用了一個跟它(
a
)本身看起來沒什麼關系的對象(
e
)的方法,結果改變了它的狀态。
每次寫代碼寫到這個地方的時候,我都覺得心裡一驚。是以,就一直想一探究竟,看看這内部的關聯究竟是怎麼樣的。
根據上面的代碼,我們知道的是,
e
的結果,是由
c
和
d
運算得到的,而
c
,又是根據
a
和
b
相加得到的。現在,執行
e
的方法,最終改變了
a
的狀态。是以,我們可以猜測
e
内部可能有某個東西,引用着
c
,然後呢,
c
内部又有些東西,引用着
a
。是以,在運作
e
的
backward()
方法時,通過這些引用,先是改變
c
,在根據
c
内部的引用,最終改變了
a
。如果我們的猜測沒錯的話,那麼這些引用關系到底是什麼呢?在代碼裡是怎麼提現的呢?
想要知道其中原理,最先想到的辦法,自然是去看源代碼。
遺憾的是,
backward()
的實作主要是在C/Cpp層間做的,在Python層面做的事情很少,基本上就是對參數做了一下處理,然後調用native層面的實作。如下:
def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
...more comment
"""
if grad_variables is not None:
warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
if grad_tensors is None:
grad_tensors = grad_variables
else:
raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
"arguments both passed to backward(). Please only "
"use 'grad_tensors'.")
tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
if grad_tensors is None:
grad_tensors = [None] * len(tensors)
elif isinstance(grad_tensors, torch.Tensor):
grad_tensors = [grad_tensors]
else:
grad_tensors = list(grad_tensors)
grad_tensors = _make_grads(tensors, grad_tensors)
if retain_graph is None:
retain_graph = create_graph
Variable._execution_engine.run_backward(
tensors, grad_tensors, retain_graph, create_graph,
allow_unreachable=True) # allow_unreachable flag
說到Cpp。。。
看來隻能通過一頓自行的探索操作,來了解這個執行過程了。
我們先看看
e
裡面有什麼。
由于
e
是一個
Tensor
變量,我們自然想到去看
Tensor
這個類的代碼,看看裡面有哪些成員變量。不幸的是,Python語言聲明成員變量的方式跟Java這些靜态語言不一樣,他們是用到的時候直接用
self.xxx
随時聲明的。不像Java這樣,在某一個地方統一聲明并做初始化。
當然,我們可以用正規表達式
self\.\w+\s+=
搜尋所有類似于
self.xxx =
的地方,于是你會找到一些
data
,
requires_grad
,
_backward_hooks
,
retain_grad
等等。根據已有的知識,這些看起來都不像。看來相關的成員變量應該在其父類
TensorBase
裡面。不幸的是,
TensorBase
是用C/Cpp 實作的。這。。。這就又涉及到我的知識盲區了。。。
不過,Python其實還提供了其他的一些方式,來友善我們檢視這個對象的屬性和狀态。那就是
vars()
方法和
dir()
方法。然而。。。
>>> vars(a)
{}
>>>
>>>
>>>
>>> dir(a)
['__abs__', '__add__', '__and__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__div__', '__doc__', '__eq__', '__float__', '__floordiv__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__iand__', '__idiv__', '__ilshift__', '__imul__', '__index__', '__init__', '__init_subclass__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__', '__long__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__', '__or__', '__pow__', '__radd__', '__rdiv__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__rfloordiv__', '__rmul__', '__rpow__', '__rshift__', '__rsub__', '__rtruediv__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__sub__', '__subclasshook__', '__truediv__', '__weakref__', '__xor__', '_backward_hooks', '_base', '_cdata', '_coalesced_', '_dimI', '_dimV', '_grad', '_grad_fn', '_indices', '_make_subclass', '_nnz', '_values', '_version', 'abs', 'abs_', 'acos', 'acos_', 'add', 'add_', 'addbmm', 'addbmm_', 'addcdiv', 'addcdiv_', 'addcmul', 'addcmul_', 'addmm', 'addmm_', 'addmv', 'addmv_', 'addr', 'addr_', 'all', 'allclose', 'any', 'apply_', 'argmax', 'argmin', 'argsort', 'as_strided', 'as_strided_', 'asin', 'asin_', 'atan', 'atan2', 'atan2_', 'atan_', 'backward', 'baddbmm', 'baddbmm_', 'bernoulli', 'bernoulli_', 'bincount', 'bmm', 'btrifact', 'btrifact_with_info', 'btrisolve', 'byte', 'cauchy_', 'ceil', 'ceil_', 'char', 'cholesky', 'chunk', 'clamp', 'clamp_', 'clamp_max', 'clamp_max_', 'clamp_min', 'clamp_min_', 'clone', 'coalesce', 'contiguous', 'copy_', 'cos', 'cos_', 'cosh', 'cosh_', 'cpu', 'cross', 'cuda', 'cumprod', 'cumsum', 'data', 'data_ptr', 'dense_dim', 'det', 'detach', 'detach_', 'device', 'diag', 'diag_embed', 'diagflat', 'diagonal', 'digamma', 'digamma_', 'dim', 'dist', 'div', 'div_', 'dot', 'double', 'dtype', 'eig', 'element_size', 'eq', 'eq_', 'equal', 'erf', 'erf_', 'erfc', 'erfc_', 'erfinv', 'erfinv_', 'exp', 'exp_', 'expand', 'expand_as', 'expm1', 'expm1_', 'exponential_', 'fft', 'fill_', 'flatten', 'flip', 'float', 'floor', 'floor_', 'fmod', 'fmod_', 'frac', 'frac_', 'gather', 'ge', 'ge_', 'gels', 'geometric_', 'geqrf', 'ger', 'gesv', 'get_device', 'grad', 'grad_fn', 'gt', 'gt_', 'half', 'hardshrink', 'histc', 'ifft', 'index_add', 'index_add_', 'index_copy', 'index_copy_', 'index_fill', 'index_fill_', 'index_put', 'index_put_', 'index_select', 'indices', 'int', 'inverse', 'irfft', 'is_coalesced', 'is_complex', 'is_contiguous', 'is_cuda', 'is_distributed', 'is_floating_point', 'is_leaf', 'is_nonzero', 'is_pinned', 'is_same_size', 'is_set_to', 'is_shared', 'is_signed', 'is_sparse', 'isclose', 'item', 'kthvalue', 'layout', 'le', 'le_', 'lerp', 'lerp_', 'lgamma', 'lgamma_', 'log', 'log10', 'log10_', 'log1p', 'log1p_', 'log2', 'log2_', 'log_', 'log_normal_', 'log_softmax', 'logdet', 'logsumexp', 'long', 'lt', 'lt_', 'map2_', 'map_', 'masked_fill', 'masked_fill_', 'masked_scatter', 'masked_scatter_', 'masked_select', 'matmul', 'matrix_power', 'max', 'mean', 'median', 'min', 'mm', 'mode', 'mul', 'mul_', 'multinomial', 'mv', 'mvlgamma', 'mvlgamma_', 'name', 'narrow', 'narrow_copy', 'ndimension', 'ne', 'ne_', 'neg', 'neg_', 'nelement', 'new', 'new_empty', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'nonzero', 'norm', 'normal_', 'numel', 'numpy', 'orgqr', 'ormqr', 'output_nr', 'permute', 'pin_memory', 'pinverse', 'polygamma', 'polygamma_', 'potrf', 'potri', 'potrs', 'pow', 'pow_', 'prelu', 'prod', 'pstrf', 'put_', 'qr', 'random_', 'reciprocal', 'reciprocal_', 'record_stream', 'register_hook', 'reinforce', 'relu', 'relu_', 'remainder', 'remainder_', 'renorm', 'renorm_', 'repeat', 'requires_grad', 'requires_grad_', 'reshape', 'reshape_as', 'resize', 'resize_', 'resize_as', 'resize_as_', 'retain_grad', 'rfft', 'roll', 'rot90', 'round', 'round_', 'rsqrt', 'rsqrt_', 'scatter', 'scatter_', 'scatter_add', 'scatter_add_', 'select', 'set_', 'shape', 'share_memory_', 'short', 'sigmoid', 'sigmoid_', 'sign', 'sign_', 'sin', 'sin_', 'sinh', 'sinh_', 'size', 'slogdet', 'smm', 'softmax', 'sort', 'sparse_dim', 'sparse_mask', 'sparse_resize_', 'sparse_resize_and_clear_', 'split', 'split_with_sizes', 'sqrt', 'sqrt_', 'squeeze', 'squeeze_', 'sspaddmm', 'std', 'stft', 'storage', 'storage_offset', 'storage_type', 'stride', 'sub', 'sub_', 'sum', 'svd', 'symeig', 't', 't_', 'take', 'tan', 'tan_', 'tanh', 'tanh_', 'to', 'to_dense', 'to_sparse', 'tolist', 'topk', 'trace', 'transpose', 'transpose_', 'tril', 'tril_', 'triu', 'triu_', 'trtrs', 'trunc', 'trunc_', 'type', 'type_as', 'unbind', 'unfold', 'uniform_', 'unique', 'unsqueeze', 'unsqueeze_', 'values', 'var', 'view', 'view_as', 'where', 'zero_']
>>>
可以看到,使用
vars()
方法,傳回的集合是空的。而使用
dir()
,傳回的卻又太多了,你都不知道哪些是有用的哪些是沒用的,哪些又是我們真正關心的。
怎麼辦呢?
看來隻能Google了。經過一頓調查和連猜帶蒙,我得出了一些結論。也不知道是否正确(準确),如果有錯誤或不準确的地方,還希望有大神不吝指出。
為了解釋他們之間的關系,我們先從一個最簡單的例子開始。
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>>
>>> c.backward()
>>> a.grad
tensor(1.)
>>> b.grad
tensor(1.)
>>>
我們的問題是,
c
和
a
是怎麼串聯起來的?為什麼執行
c.backward()
,會更新
a
的狀态(
a.grad
的值)?
其實,我們要找的東西,遠在天邊,近在眼前。
>>> c
tensor(5., grad_fn=<AddBackward0>)
>>>
可以看到,c裡面有一個
gran_fn
變量。這個東西是什麼呢?
>>> c.grad_fn
<AddBackward0 object at 0x10e56d160>
>>> type(c.grad_fn)
<class 'AddBackward0'>
>>>
可見,這是一個
AddBackward0
這個類的對象。遺憾的是,這個類也是用Cpp來寫的。不過,這不代表我們不能在Python層做一些簡單的探索,看看裡面有些什麼東西。
>>> dir(c.grad_fn)
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad']
除去那些特殊方法(以
__
開頭和結束的)和私有方法(以
_
開頭的),範圍縮小到
['metadata', 'name', 'next_functions', 'register_hook', 'requires_grad’]
這其中,看名字,最可疑的是這個
next_functions
。我們看看是什麼:
>>> c.grad_fn.next_functions
((<AccumulateGrad object at 0x10e56d160>, 0), (<AccumulateGrad object at 0x1205b29b0>, 0))
>>>
看起來,這個
next_functions
是一個tuple of tuple of
AccumulateGrad
and
int
。
繼續探索這個
AccumulateGrad
。
>>> ag = c.grad_fn.next_functions[0][0]
>>> dir(ag)
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad', 'variable']
同樣的,去掉那些特殊函數。我們感興趣的範圍縮小到
['metadata', 'name', 'next_functions', 'register_hook', 'requires_grad', 'variable']
這其中,除了前面提高過的
'next_functions'
之外,我們驚訝的發現,還有一個 叫
variable
的屬性。我們分别都看一下:
>>> ag.next_functions
()
>>> ag.variable
tensor(2., requires_grad=True)
>>>
可見,
ag1
的
variable
這個屬性是一個
tensor(2., requires_grad=True)
這個看起來似乎跟我們前面定義的a是同一個啊。是嗎?我們确認一下:
>>> id(a)
4842774104
>>> id(ag.variable)
4842774104
>>>
果然是!
到這裡,謎底基本上就呼之欲出了。
當我們執行
c.backward()
的時候。這個操作将調用c裡面的
grad_fn
這個屬性,執行求導的操作。這個操作将周遊
grad_fn
的
next_functions
,然後分别取出裡面的function(
AccumulateGrad
),執行求導操作。計算出結果以後,将結果儲存到他們對應的
variable
這個變量所引用的對象(
a
和
b
)的
grad
這個屬性裡面。
于是,當我們執行完
c.backward()
之後,
a
和
b
裡面的
grad
值就得到了更新。
再回到我們開篇提到的稍微複雜點的例子:
>>> import torch
>>>
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>>
>>> e.backward()
>>> a.grad
tensor(4.)
>>> b.grad
tensor(4.)
>>> c.grad
>>> d.grad
tensor(5.)
以此類推,
e
到各個節點
a
、
b
、
c
、
d
的關聯也就很容易了解了。
>>> e
tensor(20., grad_fn=<MulBackward0>)
>>> e.grad_fn
<MulBackward0 object at 0x111cb5470>
>>> e.grad_fn.next_functions
((<AddBackward0 object at 0x110501438>, 0), (<AccumulateGrad object at 0x111cb5fd0>, 0))
分别把
next_functions
中的function取出來看看
>>> ((f1, _), (f2, _)) = e.grad_fn.next_functions
>>> f1
<AddBackward0 object at 0x111cb5fd0>
>>> f1.variable
Traceback (most recent call last):
File "", line 1, in <module>
AttributeError: 'AddBackward0' object has no attribute 'variable'
>>> c
tensor(5., grad_fn=<AddBackward0>)
>>> c.grad_fn
<AddBackward0 object at 0x111cb5fd0>
>>> f2
<AccumulateGrad object at 0x1103ee4e0>
>>> f2.variable
tensor(4., requires_grad=True)
可見,
e.grad_fn.next_functions
中的第一個function
f1
,就是
c.grad_fn
。
不過,如果跟着剛剛的思路,你會覺得意外的是,
f1
是沒有
variable
變量的。這是因為,
c
的結果,是由
a
和
b
相加的出來的,這樣的變量是非“葉變量”。
如果我們把
a
、
b
、
c
、
d
、
e
和他們之間的運算過程了解為一棵樹。那麼,
a
、
b
、
d
都是我們自己“new”出來的,這樣的節點叫葉節點。這些葉節點分别有一個
AccumulateGrad
類型的function跟它們對應起來。則像c、e這些,不是我們自己直接建立的,而是通過一些運算得出的,就是非葉節點。對于非葉節點來說,預設情況下他們不需要存儲導數值(當然,如果需要,也是有辦法做到的)。是以,他們的
grad_fn
,不需要有一個變量
variable
引用着他們。
在
e.backward()
執行求導時,系統周遊
e.grad_fn.next_functions
,分别執行求導。如果
e.grad_fn.next_functions
中有哪個是
AccumulateGrad
,則把結果儲存到
AccumulateGrad
的variable引用的變量中。否則,遞歸周遊這個function的
next_functions
,執行求導過程。最終到達所有的葉節點,求導結束。同時,所有的葉節點的
grad
變量都得到了相應的更新。
他們之間的關系如下圖所示:
那麼,還有兩個問題沒有解決:
1. 這些各種function,像
AccumulateGrad
、
AddBackward0
、
MulBackward0
,是怎麼産生的?
2. 這些function,比如上面出現過的
AddBackward0
、
MulBackward0
,具體是怎麼求導的呢?
對于第一個問題,很自然的猜測,是PyTorch重寫了一些操作符,像
+
,
*
等。在這個過程中,建立了這些function,并建立起了引用關系。
對于第二個問題,簡單的說,就是在每個函數定義的時候,都需要自己定義好
forward()
和
backward()
函數。在
forward()
裡面實作這個運算的執行過程。比如,相加、相乘,在
backward()
則實作這個運算的求導過程。
以上就是我對PyTorch的自動求導原理的了解。隻是一個大概的,比較淺顯的了解。對于一些更加細節的,包括一些特殊情況的處理,推薦大家看這個視訊。講得非常清楚。
https://www.youtube.com/watch?v=MswxJw-8PvE
參考:
https://pytorch.org/docs/stable/autograd.html#in-place-operations-on-tensors
https://pytorch.org/docs/stable/notes/extending.html
https://www.youtube.com/watch?v=MswxJw-8PvE
關注小創作