天天看點

花書BPTT公式推導

花書第10.2.2節的計算循環神經網絡的梯度看了好久,總算是把公式的推導給看懂了,記錄一下過程。

首先,對于一個普通的RNN來說,其前向傳播過程為:

$$\textbf{a}^{(t)}=\textbf{b}+\textbf{Wh}^{t-1}+\textbf{Ux}^{(t)}$$

$$\textbf{h}^t=tanh(\textbf{a}^{(t)})$$

$$\textbf{o}^{(t)} = \textbf{c} + \textbf{V}\textbf{h}^{(t)}$$

$$\hat{\textbf{y}}^{(t)} = softmax(\textbf{o}^{(t)})$$

先介紹一下等下計算過程中會用到的偏導數:

$$h = tanh(a) = \frac{e^a-e^{-a}}{e^a+e^{-a}}$$

$$\frac{\partial \textbf{h}}{\partial \textbf{a}} = diag(1-\textbf{h}^2)$$

另一個,當$\textbf{y}$采用one-hot并且損失函數$L$為交叉熵時:

$$\frac{\partial L}{\partial \textbf{o}^{(t)}} = \frac{\partial L}{\partial L^{(t)}}\frac{\partial L^{(t)}}{\partial \textbf{o}^{t}} = \hat{\textbf{y}}^{(t)}-\textbf{y}^{(t)}$$

【注】這裡涉及到softmax求導的規律,如果不懂的話可以看看:傳送門

接下來從RNN的尾部開始,逐漸計算隐藏狀态$\textbf{h}^t$的梯度。如果$\tau$是最後的時間步,$\textbf{h}^{(\tau)}$就是最後的隐藏輸出。

$$\frac{\partial L}{\partial \textbf{h}^{(\tau)}} = \frac{\partial L}{\partial \textbf{o}^{(\tau)}}\frac{\partial \textbf{o}^{(\tau)}}{\partial \textbf{h}^{(\tau)}}= \textbf{V}^T(\hat{\textbf{y}}^{(\tau)}-\textbf{y}^{(\tau)})$$

然後一步步往前計算$\textbf{h}^t$的梯度,注意$\textbf{h}^{(t)}(t<\tau)$同時有$\textbf{o}^{(t)}$和$\textbf{h}^{(t+1)}$兩個後續節點,是以:

$$\frac{\partial L}{\partial \textbf{h}^{(t)}}=(\frac{\partial \textbf{h}^{(t+1)}}{\partial \textbf{h}^{(t)}})^T\frac{\partial L}{\partial \textbf{h}^{(t+1)}}+(\frac{\partial \textbf{o}^{(t)}}{\partial \textbf{h}^{(t)}})^T\frac{\partial L}{\partial \textbf{o}^{(t)}}=(\frac{\partial \textbf{h}^{(t+1)}}{\partial \textbf{a}^{(t+1)}} \frac{\partial \textbf{a}^{(t+1)}}{\partial \textbf{h}^{(t)}})^T \frac{\partial L}{\partial \textbf{h}^{(t+1)}}+\textbf{V}^T(\hat{\textbf{y}}^{(t)}-\textbf{y}^{(t)})= \textbf{W}^T(diag(1-(\textbf{h}^{(t+1)})^2))\frac{\partial L}{\partial \textbf{h}^{(t+1)}}+\textbf{V}^T(\hat{\textbf{y}}^{(t)}-\textbf{y}^{(t)})$$

【注】這裡的結果和花書有點不一樣,不知道是花書有錯誤還是我這裡錯了?

剩下的參數計算起來就簡單多了:

$$\frac{\partial L}{\partial \textbf{W}} = \sum_{t=1}^{\tau}\frac{\partial L}{\partial \textbf{h}^{(t)}}\frac{\partial \textbf{h}^{(t)}}{\partial \textbf{W}} = \sum_{t=1}^{\tau}\frac{\partial L}{\partial \textbf{h}^{(t)}}\frac{\partial \textbf{h}^{(t)}}{\partial \textbf{a}^{(t)}}\frac{\partial \textbf{a}^{(t)}}{\partial \textbf{W}} = \sum_{t=1}^{\tau}diag(1-(\textbf{h}^{(t)})^2)\frac{\partial L}{\partial \textbf{h}^{(t)}}(\textbf{h}^{(t-1)})^T$$

$$\frac{\partial L}{\partial \textbf{b}}= \sum\limits_{t=1}^{\tau}diag(1-(\textbf{h}^{(t)})^2)\frac{\partial L}{\partial \textbf{h}^{(t)}}$$

$$\frac{\partial L}{\partial \textbf{U}} =\sum\limits_{t=1}^{\tau}diag(1-(\textbf{h}^{(t)})^2)\frac{\partial L}{\partial \textbf{h}^{(t)}}(\textbf{x}^{(t)})^T$$

$$\frac{\partial L}{\partial \textbf{c}} = \sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial \textbf{c}}  = \sum\limits_{t=1}^{\tau}\hat{\textbf{y}}^{(t)} - \textbf{y}^{(t)}$$

$$\frac{\partial L}{\partial \textbf{V}} =\sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial \textbf{V}}  = \sum\limits_{t=1}^{\tau}(\hat{\textbf{y}}^{(t)} - \textbf{y}^{(t)}) (\textbf{h}^{(t)})^T$$

參考

【1】RNN前向方向傳播(花書《深度學習》10.2循環神經網絡)

【2】循環神經網絡(RNN)模型與前向反向傳播算法

繼續閱讀