天天看點

python求導_PyTorch的自動求導功能(Autograd)原了解析

python求導_PyTorch的自動求導功能(Autograd)原了解析

It's automatic

我們知道,深度學習最核心的其中一個步驟,就是求導:根據函數(linear + activation function)求weights相對于loss的導數(還是loss相對于weights的導數?)。然後根據得出的導數,相應的修改weights,讓loss最小化。

各大深度學習架構Tensorflow,Keras,PyTorch都自帶有自動求導功能,不需要我們手動算。

在初步學習PyTorch的時候,看到PyTorch的自動求導過程時,感覺非常的别扭和不直覺。我下面舉個例子,大家自己感受一下。

>>> import torch
>>>
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>>
>>> e.backward() # 執行求導
>>> a.grad  # a.grad 即導數 d(e)/d(a) 的值
tensor(4.)
           

這裡讓人感覺别扭的是,調用 

e.backward()

執行求導,為什麼會更新 

a

 對象的狀态

grad

?對于習慣了OOP的人來說,這是非常不直覺的。因為,在OOP裡面,你要改變一個對象的狀态,一般的做法是,引用這個對象本身,給它的property顯示的指派(比如 

user.age = 18

),或者是調用這個對象的方法(

user.setAge(18)

),讓它狀态得以改變。

而這裡的做法是,調用了一個跟它(

a

)本身看起來沒什麼關系的對象(

e

)的方法,結果改變了它的狀态。

每次寫代碼寫到這個地方的時候,我都覺得心裡一驚。是以,就一直想一探究竟,看看這内部的關聯究竟是怎麼樣的。

根據上面的代碼,我們知道的是,

e

的結果,是由

c

d

運算得到的,而

c

,又是根據

a

b

相加得到的。現在,執行

e

的方法,最終改變了

a

的狀态。是以,我們可以猜測

e

内部可能有某個東西,引用着

c

,然後呢,

c

内部又有些東西,引用着

a

。是以,在運作

e

backward()

方法時,通過這些引用,先是改變

c

,在根據

c

内部的引用,最終改變了

a

。如果我們的猜測沒錯的話,那麼這些引用關系到底是什麼呢?在代碼裡是怎麼提現的呢?

想要知道其中原理,最先想到的辦法,自然是去看源代碼。

遺憾的是,

backward()

的實作主要是在C/Cpp層間做的,在Python層面做的事情很少,基本上就是對參數做了一下處理,然後調用native層面的實作。如下:

def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
    r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
    ...more comment
    """
if grad_variables is not None:
        warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
if grad_tensors is None:
            grad_tensors = grad_variables
else:
raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
"arguments both passed to backward(). Please only "
"use 'grad_tensors'.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

if grad_tensors is None:
        grad_tensors = [None] * len(tensors)
elif isinstance(grad_tensors, torch.Tensor):
        grad_tensors = [grad_tensors]
else:
        grad_tensors = list(grad_tensors)

    grad_tensors = _make_grads(tensors, grad_tensors)
if retain_graph is None:
        retain_graph = create_graph

    Variable._execution_engine.run_backward(
        tensors, grad_tensors, retain_graph, create_graph,
        allow_unreachable=True)  # allow_unreachable flag
           

說到Cpp。。。

python求導_PyTorch的自動求導功能(Autograd)原了解析

看來隻能通過一頓自行的探索操作,來了解這個執行過程了。

我們先看看

e

裡面有什麼。

由于

e

是一個

Tensor

變量,我們自然想到去看

Tensor

這個類的代碼,看看裡面有哪些成員變量。不幸的是,Python語言聲明成員變量的方式跟Java這些靜态語言不一樣,他們是用到的時候直接用

self.xxx

随時聲明的。不像Java這樣,在某一個地方統一聲明并做初始化。

當然,我們可以用正規表達式 

self\.\w+\s+=

 搜尋所有類似于 

self.xxx = 

的地方,于是你會找到一些

data

requires_grad

_backward_hooks

retain_grad

等等。根據已有的知識,這些看起來都不像。看來相關的成員變量應該在其父類

TensorBase

裡面。不幸的是,

TensorBase

是用C/Cpp 實作的。這。。。這就又涉及到我的知識盲區了。。。

不過,Python其實還提供了其他的一些方式,來友善我們檢視這個對象的屬性和狀态。那就是

vars()

 方法和 

dir()

方法。然而。。。

>>> vars(a)
{}
>>>
>>>
>>>
>>> dir(a)
['__abs__', '__add__', '__and__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__div__', '__doc__', '__eq__', '__float__', '__floordiv__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__iand__', '__idiv__', '__ilshift__', '__imul__', '__index__', '__init__', '__init_subclass__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__', '__long__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__', '__or__', '__pow__', '__radd__', '__rdiv__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__rfloordiv__', '__rmul__', '__rpow__', '__rshift__', '__rsub__', '__rtruediv__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__sub__', '__subclasshook__', '__truediv__', '__weakref__', '__xor__', '_backward_hooks', '_base', '_cdata', '_coalesced_', '_dimI', '_dimV', '_grad', '_grad_fn', '_indices', '_make_subclass', '_nnz', '_values', '_version', 'abs', 'abs_', 'acos', 'acos_', 'add', 'add_', 'addbmm', 'addbmm_', 'addcdiv', 'addcdiv_', 'addcmul', 'addcmul_', 'addmm', 'addmm_', 'addmv', 'addmv_', 'addr', 'addr_', 'all', 'allclose', 'any', 'apply_', 'argmax', 'argmin', 'argsort', 'as_strided', 'as_strided_', 'asin', 'asin_', 'atan', 'atan2', 'atan2_', 'atan_', 'backward', 'baddbmm', 'baddbmm_', 'bernoulli', 'bernoulli_', 'bincount', 'bmm', 'btrifact', 'btrifact_with_info', 'btrisolve', 'byte', 'cauchy_', 'ceil', 'ceil_', 'char', 'cholesky', 'chunk', 'clamp', 'clamp_', 'clamp_max', 'clamp_max_', 'clamp_min', 'clamp_min_', 'clone', 'coalesce', 'contiguous', 'copy_', 'cos', 'cos_', 'cosh', 'cosh_', 'cpu', 'cross', 'cuda', 'cumprod', 'cumsum', 'data', 'data_ptr', 'dense_dim', 'det', 'detach', 'detach_', 'device', 'diag', 'diag_embed', 'diagflat', 'diagonal', 'digamma', 'digamma_', 'dim', 'dist', 'div', 'div_', 'dot', 'double', 'dtype', 'eig', 'element_size', 'eq', 'eq_', 'equal', 'erf', 'erf_', 'erfc', 'erfc_', 'erfinv', 'erfinv_', 'exp', 'exp_', 'expand', 'expand_as', 'expm1', 'expm1_', 'exponential_', 'fft', 'fill_', 'flatten', 'flip', 'float', 'floor', 'floor_', 'fmod', 'fmod_', 'frac', 'frac_', 'gather', 'ge', 'ge_', 'gels', 'geometric_', 'geqrf', 'ger', 'gesv', 'get_device', 'grad', 'grad_fn', 'gt', 'gt_', 'half', 'hardshrink', 'histc', 'ifft', 'index_add', 'index_add_', 'index_copy', 'index_copy_', 'index_fill', 'index_fill_', 'index_put', 'index_put_', 'index_select', 'indices', 'int', 'inverse', 'irfft', 'is_coalesced', 'is_complex', 'is_contiguous', 'is_cuda', 'is_distributed', 'is_floating_point', 'is_leaf', 'is_nonzero', 'is_pinned', 'is_same_size', 'is_set_to', 'is_shared', 'is_signed', 'is_sparse', 'isclose', 'item', 'kthvalue', 'layout', 'le', 'le_', 'lerp', 'lerp_', 'lgamma', 'lgamma_', 'log', 'log10', 'log10_', 'log1p', 'log1p_', 'log2', 'log2_', 'log_', 'log_normal_', 'log_softmax', 'logdet', 'logsumexp', 'long', 'lt', 'lt_', 'map2_', 'map_', 'masked_fill', 'masked_fill_', 'masked_scatter', 'masked_scatter_', 'masked_select', 'matmul', 'matrix_power', 'max', 'mean', 'median', 'min', 'mm', 'mode', 'mul', 'mul_', 'multinomial', 'mv', 'mvlgamma', 'mvlgamma_', 'name', 'narrow', 'narrow_copy', 'ndimension', 'ne', 'ne_', 'neg', 'neg_', 'nelement', 'new', 'new_empty', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'nonzero', 'norm', 'normal_', 'numel', 'numpy', 'orgqr', 'ormqr', 'output_nr', 'permute', 'pin_memory', 'pinverse', 'polygamma', 'polygamma_', 'potrf', 'potri', 'potrs', 'pow', 'pow_', 'prelu', 'prod', 'pstrf', 'put_', 'qr', 'random_', 'reciprocal', 'reciprocal_', 'record_stream', 'register_hook', 'reinforce', 'relu', 'relu_', 'remainder', 'remainder_', 'renorm', 'renorm_', 'repeat', 'requires_grad', 'requires_grad_', 'reshape', 'reshape_as', 'resize', 'resize_', 'resize_as', 'resize_as_', 'retain_grad', 'rfft', 'roll', 'rot90', 'round', 'round_', 'rsqrt', 'rsqrt_', 'scatter', 'scatter_', 'scatter_add', 'scatter_add_', 'select', 'set_', 'shape', 'share_memory_', 'short', 'sigmoid', 'sigmoid_', 'sign', 'sign_', 'sin', 'sin_', 'sinh', 'sinh_', 'size', 'slogdet', 'smm', 'softmax', 'sort', 'sparse_dim', 'sparse_mask', 'sparse_resize_', 'sparse_resize_and_clear_', 'split', 'split_with_sizes', 'sqrt', 'sqrt_', 'squeeze', 'squeeze_', 'sspaddmm', 'std', 'stft', 'storage', 'storage_offset', 'storage_type', 'stride', 'sub', 'sub_', 'sum', 'svd', 'symeig', 't', 't_', 'take', 'tan', 'tan_', 'tanh', 'tanh_', 'to', 'to_dense', 'to_sparse', 'tolist', 'topk', 'trace', 'transpose', 'transpose_', 'tril', 'tril_', 'triu', 'triu_', 'trtrs', 'trunc', 'trunc_', 'type', 'type_as', 'unbind', 'unfold', 'uniform_', 'unique', 'unsqueeze', 'unsqueeze_', 'values', 'var', 'view', 'view_as', 'where', 'zero_']
>>>
           

可以看到,使用

vars()

方法,傳回的集合是空的。而使用

dir()

,傳回的卻又太多了,你都不知道哪些是有用的哪些是沒用的,哪些又是我們真正關心的。

怎麼辦呢?

看來隻能Google了。經過一頓調查和連猜帶蒙,我得出了一些結論。也不知道是否正确(準确),如果有錯誤或不準确的地方,還希望有大神不吝指出。

為了解釋他們之間的關系,我們先從一個最簡單的例子開始。

>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>>
>>> c.backward()
>>> a.grad
tensor(1.)
>>> b.grad
tensor(1.)
>>>
           

我們的問題是,

c

a

是怎麼串聯起來的?為什麼執行

c.backward()

,會更新

a

的狀态(

a.grad

的值)?

其實,我們要找的東西,遠在天邊,近在眼前。

>>> c
tensor(5., grad_fn=<AddBackward0>)
>>>
           

可以看到,c裡面有一個

gran_fn

變量。這個東西是什麼呢?

>>> c.grad_fn
<AddBackward0 object at 0x10e56d160>
>>> type(c.grad_fn)
<class 'AddBackward0'>
>>>
           

可見,這是一個

AddBackward0

這個類的對象。遺憾的是,這個類也是用Cpp來寫的。不過,這不代表我們不能在Python層做一些簡單的探索,看看裡面有些什麼東西。

>>> dir(c.grad_fn)
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad']
           

除去那些特殊方法(以

__

開頭和結束的)和私有方法(以

_

開頭的),範圍縮小到

['metadata', 'name', 'next_functions', 'register_hook', 'requires_grad’]

這其中,看名字,最可疑的是這個

next_functions

。我們看看是什麼:

>>> c.grad_fn.next_functions
((<AccumulateGrad object at 0x10e56d160>, 0), (<AccumulateGrad object at 0x1205b29b0>, 0))
>>>
           

看起來,這個

next_functions

 是一個tuple of tuple of 

AccumulateGrad

 and 

int

繼續探索這個 

AccumulateGrad

>>> ag = c.grad_fn.next_functions[0][0]
>>> dir(ag)
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad', 'variable']
           

同樣的,去掉那些特殊函數。我們感興趣的範圍縮小到 

['metadata', 'name', 'next_functions', 'register_hook', 'requires_grad', 'variable']

 這其中,除了前面提高過的 

'next_functions'

 之外,我們驚訝的發現,還有一個 叫 

variable

 的屬性。我們分别都看一下:

>>> ag.next_functions
()
>>> ag.variable
tensor(2., requires_grad=True)
>>>
           

可見,

ag1

variable

這個屬性是一個

tensor(2., requires_grad=True)
           

這個看起來似乎跟我們前面定義的a是同一個啊。是嗎?我們确認一下:

>>> id(a)
4842774104
>>> id(ag.variable)
4842774104
>>>
           

果然是!

到這裡,謎底基本上就呼之欲出了。

當我們執行

c.backward()

的時候。這個操作将調用c裡面的

grad_fn

這個屬性,執行求導的操作。這個操作将周遊

grad_fn

next_functions

,然後分别取出裡面的function(

AccumulateGrad

),執行求導操作。計算出結果以後,将結果儲存到他們對應的

variable

 這個變量所引用的對象(

a

b

)的 

grad

這個屬性裡面。

于是,當我們執行完

c.backward()

之後,

a

b

裡面的

grad

值就得到了更新。

再回到我們開篇提到的稍微複雜點的例子:

>>> import torch
>>>
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>>
>>> e.backward()
>>> a.grad
tensor(4.)
>>> b.grad
tensor(4.)
>>> c.grad
>>> d.grad
tensor(5.)
           

以此類推,

e

到各個節點

a

b

c

d

的關聯也就很容易了解了。

>>> e
tensor(20., grad_fn=<MulBackward0>)
>>> e.grad_fn
<MulBackward0 object at 0x111cb5470>
>>> e.grad_fn.next_functions
((<AddBackward0 object at 0x110501438>, 0), (<AccumulateGrad object at 0x111cb5fd0>, 0))
           

分别把

next_functions

中的function取出來看看

>>> ((f1, _), (f2, _)) = e.grad_fn.next_functions

>>> f1
<AddBackward0 object at 0x111cb5fd0>
>>> f1.variable
Traceback (most recent call last):
  File "", line 1, in <module>
AttributeError: 'AddBackward0' object has no attribute 'variable'

>>> c
tensor(5., grad_fn=<AddBackward0>)
>>> c.grad_fn
<AddBackward0 object at 0x111cb5fd0>

>>> f2
<AccumulateGrad object at 0x1103ee4e0>
>>> f2.variable
tensor(4., requires_grad=True)
           

可見,

e.grad_fn.next_functions

中的第一個function 

f1

,就是

c.grad_fn

不過,如果跟着剛剛的思路,你會覺得意外的是,

f1

是沒有

variable

 變量的。這是因為,

c

的結果,是由

a

b

相加的出來的,這樣的變量是非“葉變量”。

如果我們把

a

b

c

d

e

和他們之間的運算過程了解為一棵樹。那麼,

a

b

d

都是我們自己“new”出來的,這樣的節點叫葉節點。這些葉節點分别有一個

AccumulateGrad

 類型的function跟它們對應起來。則像c、e這些,不是我們自己直接建立的,而是通過一些運算得出的,就是非葉節點。對于非葉節點來說,預設情況下他們不需要存儲導數值(當然,如果需要,也是有辦法做到的)。是以,他們的

grad_fn

,不需要有一個變量

variable

 引用着他們。

e.backward()

執行求導時,系統周遊

e.grad_fn.next_functions

,分别執行求導。如果

e.grad_fn.next_functions

中有哪個是

AccumulateGrad

,則把結果儲存到

AccumulateGrad

的variable引用的變量中。否則,遞歸周遊這個function的

next_functions

,執行求導過程。最終到達所有的葉節點,求導結束。同時,所有的葉節點的

grad

變量都得到了相應的更新。

他們之間的關系如下圖所示:

python求導_PyTorch的自動求導功能(Autograd)原了解析

那麼,還有兩個問題沒有解決:

1. 這些各種function,像

AccumulateGrad

AddBackward0

MulBackward0

,是怎麼産生的?

2. 這些function,比如上面出現過的

AddBackward0

 、

MulBackward0

,具體是怎麼求導的呢?

對于第一個問題,很自然的猜測,是PyTorch重寫了一些操作符,像

+

*

等。在這個過程中,建立了這些function,并建立起了引用關系。

對于第二個問題,簡單的說,就是在每個函數定義的時候,都需要自己定義好

forward()

backward()

函數。在

forward()

裡面實作這個運算的執行過程。比如,相加、相乘,在

backward()

則實作這個運算的求導過程。

以上就是我對PyTorch的自動求導原理的了解。隻是一個大概的,比較淺顯的了解。對于一些更加細節的,包括一些特殊情況的處理,推薦大家看這個視訊。講得非常清楚。

https://www.youtube.com/watch?v=MswxJw-8PvE

參考:

https://pytorch.org/docs/stable/autograd.html#in-place-operations-on-tensors

https://pytorch.org/docs/stable/notes/extending.html

https://www.youtube.com/watch?v=MswxJw-8PvE

python求導_PyTorch的自動求導功能(Autograd)原了解析

關注小創作