通常,資料的存在形式有語音、文本、圖像、視訊等。因為我的研究方向主要是圖像識别,是以很少用有“記憶性”的深度網絡。懷着對循環神經網絡的興趣,在看懂了有關它的理論後,我又看了Github上提供的tensorflow實作,覺得收獲很大,故在這裡把我的了解記錄下來,也希望對大家能有所幫助。本文将主要介紹RNN相關的理論,并引出LSTM網絡結構(關于對tensorflow實作細節的了解,有時間的話,在下一篇博文中做介紹)。
循環神經網絡
RNN,也稱作循環神經網絡(還有一種深度網絡,稱作遞歸神經網絡,讀者要差別對待)。因為這種網絡有“記憶性”,是以主要是應用在自然語言處理(NLP)和語音領域。與傳統的Neural network不同,RNN能利用上"序列資訊"。從理論上講,它可以利用任意長序列的資訊,但由于該網絡結構存在“消失梯度”問題,是以在實際應用中,它隻能回溯利用與它接近的time steps上的資訊。
1. 網絡結構
常見的神經網絡結構有卷積網絡、循環網絡和遞歸網絡,棧式自編碼器和玻爾茲曼機也可以看做是特殊的卷積網絡,差別是它們的損失函數定義成均方誤差函數。遞歸網絡類似于資料結構中的樹形結構,且其每層之間會有共享參數。而最為常用的循環神經網絡,它的每層的結構相同,且每層之間參數完全共享。RNN的縮略圖和展開圖如下,
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsIyZwpmLxkTN28lN4ATMxIDM4QTMvw1Ny8CXxEjNxAjMvw1ckF2bsBXdvwFdl5mLuR2cj5Set1yZtl2Lc9CX6MHc0RHaiojIsJye.jpg)
盡管RNN的網絡結構看上去與常見的前饋網絡不同,但是它的展開圖中資訊流向也是确定的,沒有環流,是以也屬于forward network,故也可以使用反向傳播(back propagation)算法來求解參數的梯度。另外,在RNN網絡中,可以有單輸入、多輸入、單輸出、多輸出,視具體任務而定。
2. 損失函數
在輸出層為二分類或者softmax多分類的深度網絡中,代價函數通常選擇交叉熵(cross entropy)損失函數,前面的博文中證明過,在分類問題中,交叉熵函數的本質就是似然損失函數。盡管RNN的網絡結構與分類網絡不同,但是損失函數也是有相似之處的。
假設我們采用RNN網絡建構“語言模型”,“語言模型”其實就是看“一句話說出來是不是順口”,可以應用在機器翻譯、語音識别領域,從若幹候選結果中挑一個更加靠譜的結果。通常每個sentence長度不一樣,每一個word作為一個訓練樣例,一個sentence作為一個Minibatch,記sentence的長度為T。為了更好地了解語言模型中損失函數的定義形式,這裡做一些推導,根據全機率公式,則一句話是“自然化的語句”的機率為 p ( w 1 , w 2 , . . . , w T ) = p ( w 1 ) × p ( w 2 ∣ w 1 ) × . . . × p ( w T ∣ w 1 , w 2 , . . . , w T − 1 ) p(w_{1}, w_{2}, ..., w_{T})=p(w_{1})\times p(w_{2}|w_{1})\times ...\times p(w_{T}|w_{1},w_{2},...,w_{T-1}) p(w1,w2,...,wT)=p(w1)×p(w2∣w1)×...×p(wT∣w1,w2,...,wT−1) 是以語言模型的目标就是最大化 P ( w 1 , w 2 , . . . , w T ) P(w_{1}, w_{2}, ..., w_{T}) P(w1,w2,...,wT)。而損失函數通常為最小化問題,是以可以定義 L o s s ( w 1 , w 2 , . . . , w T ∣ θ ) = − l o g P ( w 1 , w 2 , . . . , w T ∣ θ ) Loss(w_{1}, w_{2},...,w_{T}|\theta )=-logP(w_{1}, w_{2},...,w_{T}|\theta) Loss(w1,w2,...,wT∣θ)=−logP(w1,w2,...,wT∣θ) 那麼公式展開可得 L o s s ( w 1 , w 2 , . . . , w T ∣ θ ) = − ( l o g p ( w 1 ) + l o g p ( w 2 ∣ w 1 ) + . . . + l o g p ( w T ∣ w 1 , w 2 , . . . , w T − 1 ) ) Loss(w_{1}, w_{2},...,w_{T}|\theta )=-(logp(w_{1})+logp(w_{2}|w_{1})+ ...+logp(w_{T}|w_{1},w_{2},...,w_{T-1})) Loss(w1,w2,...,wT∣θ)=−(logp(w1)+logp(w2∣w1)+...+logp(wT∣w1,w2,...,wT−1)) 展開式中的每一項為一個softmax分類模型,類别數為所采用的詞庫大小(vocabulary size),相信大家此刻應該就明白了,為什麼使用RNN網絡解決語言模型時,輸入序列和輸出序列錯了一個位置了。
3. 梯度求解
在訓練任何深度網絡模型時,求解損失函數關于模型參數的梯度,應該算是最為核心的一步了。在RNN模型訓練時,采用的是BPTT(back propagation through time)算法,這個算法其實實質上就是樸素的BP算法,也是采用的“鍊式法則”求解參數梯度,唯一的不同在于每一個time step上參數共享。從數學的角度來講,BP算法就是一個單變量求導過程,而BPTT算法就是一個複合函數求導過程。接下來以損失函數展開式中的第3項為例,推導其關于網絡參數U、W、V的梯度表達式(總損失的梯度則是各項相加的過程而已)。
為了簡化符号表示,記 E 3 = − l o g p ( w 3 ∣ w 1 , w 2 ) E_{3}=-logp(w_{3}|w_{1},w_{2}) E3=−logp(w3∣w1,w2),則根據RNN的展開圖可得,
(1) s 3 = t a n h ( U × x 3 + W × s 2 ) ; s 2 = t a n h ( U × x 2 + W × s 1 ) s 1 = t a n h ( U × x 1 + W × s 0 ) ; s 0 = t a n h ( U × x 0 + W × s − 1 ) s_{3}=tanh(U\times x_{3}+W\times s_{2}); s_{2}=tanh(U\times x_{2}+W\times s_{1})\\ s_{1}=tanh(U\times x_{1}+W\times s_{0}); s_{0}=tanh(U\times x_{0}+W\times s_{-1}) \\\tag{1} s3=tanh(U×x3+W×s2); s2=tanh(U×x2+W×s1)s1=tanh(U×x1+W×s0); s0=tanh(U×x0+W×s−1)(1)
是以,
(2) ∂ s 3 W = ∂ s 3 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 W ∂ s 2 W = ∂ s 2 W 1 + ∂ s 2 ∂ s 1 × ∂ s 1 W ∂ s 1 W = ∂ s 1 W 0 + ∂ s 1 ∂ s 0 × ∂ s 0 W ∂ s 0 W = ∂ s 0 W 1 \frac{\partial s_{3}}{W}=\frac{\partial s_{3}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{W}\\ \frac{\partial s_{2}}{W}=\frac{\partial s_{2}}{W_{1}}+\frac{\partial s_{2}}{\partial s_{1}}\times \frac{\partial s_{1}}{W}\\ \frac{\partial s_{1}}{W}=\frac{\partial s_{1}}{W_{0}}+\frac{\partial s_{1}}{\partial s_{0}}\times \frac{\partial s_{0}}{W}\\ \frac{\partial s_{0}}{W}=\frac{\partial s_{0}}{W_{1}}\tag{2} W∂s3=W1∂s3+∂s2∂s3×W∂s2W∂s2=W1∂s2+∂s1∂s2×W∂s1W∂s1=W0∂s1+∂s0∂s1×W∂s0W∂s0=W1∂s0(2)
說明一下,為了更好地展現複合函數求導的思想,公式(2)中引入了變量 W 1 W_{1} W1,可以把 W 1 W_{1} W1看作關于W的函數,即 W 1 = W W_{1}=W W1=W。另外,因為 s − 1 s_{-1} s−1表示RNN網絡的初始狀态,為一個常數向量,是以公式(2)中第4個表達式展開後隻有一項。是以由公式(2)可得,
(3) ∂ s 3 W = ∂ s 3 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 ∂ s 1 × ∂ s 1 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 ∂ s 1 × ∂ s 1 ∂ s 0 × ∂ s 0 W 1 \frac{\partial s_{3}}{W}=\frac{\partial s_{3}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{\partial s_{1}}\times \frac{\partial s_{1}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{\partial s_{1}}\times \frac{\partial s_{1}}{\partial s_{0}}\times \frac{\partial s_{0}}{W_{1}}\tag{3} W∂s3=W1∂s3+∂s2∂s3×W1∂s2+∂s2∂s3×∂s1∂s2×W1∂s1+∂s2∂s3×∂s1∂s2×∂s0∂s1×W1∂s0(3)
簡化得下式,
(4) ∂ s 3 W = ∂ s 3 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 W 1 + ∂ s 3 ∂ S 1 × ∂ s 1 W 1 + ∂ s 3 ∂ s 0 × ∂ s 0 W 1 \frac{\partial s_{3}}{W}=\frac{\partial s_{3}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{W_{1}}+\frac{\partial s_{3}}{\partial S_{1}}\times \frac{\partial s_{1}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{0}}\times \frac{\partial s_{0}}{W_{1}}\tag{4} W∂s3=W1∂s3+∂s2∂s3×W1∂s2+∂S1∂s3×W1∂s1+∂s0∂s3×W1∂s0(4)
繼續簡化得下式,
(5) ∂ s 3 W = ∑ i = 0 3 ∂ s 3 ∂ s i × ∂ s i W 1 \frac{\partial s_{3}}{W}=\sum_{i=0}^{3}\frac{\partial s_{3}}{\partial s_{i}}\times \frac{\partial s_{i}}{W_{1}}\tag{5} W∂s3=i=0∑3∂si∂s3×W1∂si(5)
3.1 E 3 E_{3} E3關于參數V的偏導數
記t=3時刻的softmax神經元的輸入為 a 3 a_{3} a3,輸出為 y 3 y_{3} y3,網絡的真實标簽為 y 3 ( 1 ) y_{3}^{(1)} y3(1)。根據函數求導的“鍊式法則”,是以有下式成立,
(6) ∂ E 3 V = ∂ E 3 ∂ a 3 × ∂ a 3 ∂ V = ( y 3 ( 1 ) − y 3 ) ⨂ s 3 \frac{\partial E_{3}}{V}=\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial V}=(y_{3}^{(1)}-y_{3})\bigotimes s_{3}\tag{6} V∂E3=∂a3∂E3×∂V∂a3=(y3(1)−y3)⨂s3(6)
3.2 E 3 E_{3} E3關于參數W的偏導數
關于參數W的偏導數,就要使用到上面關于複合函數的推導過程了,記 z i z_{i} zi為t=i時刻隐藏層神經元的輸入,則具體的表達式簡化過程如下,
(7) ∂ E 3 W = ∂ E 3 ∂ s 3 × ∂ s 3 ∂ W = ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ∂ s 3 ∂ W = ∑ k = 0 3 ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ∂ s 3 ∂ s k × ∂ s k ∂ W 1 = ∑ k = 0 3 ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ∂ s 3 ∂ s k × ∂ s k ∂ z k × ∂ z k ∂ W 1 = ∑ k = 0 3 ∂ E 3 ∂ z k × ∂ z k ∂ w 1 \frac{\partial E_{3}}{W}=\frac{\partial E_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial W}=\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial W}\\ = \sum_{k=0}^{3}\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial s_{k}}\times \frac{\partial s_{k}}{\partial W_{1}}\\ = \sum_{k=0}^{3}\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial s_{k}}\times \frac{\partial s_{k}}{\partial z_{k}}\times \frac{\partial z_{k}}{\partial W_{1}}\\ =\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial z_{k}}\times \frac{\partial z_{k}}{\partial w_{1}}\tag{7} W∂E3=∂s3∂E3×∂W∂s3=∂a3∂E3×∂s3∂a3×∂W∂s3=k=0∑3∂a3∂E3×∂s3∂a3×∂sk∂s3×∂W1∂sk=k=0∑3∂a3∂E3×∂s3∂a3×∂sk∂s3×∂zk∂sk×∂W1∂zk=k=0∑3∂zk∂E3×∂w1∂zk(7)
類似于标準的BP算法中的表示,定義 δ n m = ∂ E m ∂ z n \delta _{n}^{m}=\frac{\partial E_{m}}{\partial z_{n}} δnm=∂zn∂Em,那麼可以得到如下遞推公式,
(8) δ 2 3 = ∂ E 3 ∂ z 3 × ∂ z 3 ∂ z 2 = ∂ E 3 ∂ z 3 × ∂ z 3 ∂ s 2 × ∂ s 2 ∂ z 2 = ( δ 3 3 ⨂ W ) ⨀ ( 1 − s 2 2 ) \delta _{2}^{3}=\frac{\partial E_{3}}{\partial z_{3}}\times \frac{\partial z_{3}}{\partial z_{2}}=\frac{\partial E_{3}}{\partial z_{3}}\times \frac{\partial z_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{\partial z_{2}}=(\delta _{3}^{3}\bigotimes W)\bigodot (1-s_{2}^{2})\tag{8} δ23=∂z3∂E3×∂z2∂z3=∂z3∂E3×∂s2∂z3×∂z2∂s2=(δ33⨂W)⨀(1−s22)(8)
那麼,公式(7)可以轉化為下式,
(9) ∂ E 3 W = ∑ k = 0 3 δ k 3 × ∂ z k ∂ w 1 \frac{\partial E_{3}}{W}=\sum_{k=0}^{3}\delta _{k}^{3}\times \frac{\partial z_{k}}{\partial w_{1}}\tag{9} W∂E3=k=0∑3δk3×∂w1∂zk(9)
顯然,結合公式(8)中的遞推公式,可以遞推求解出公式(9)中的每一項,那麼 E 3 E_{3} E3關于參數W的偏導數便迎刃而解了。
3.3 E 3 E_{3} E3關于參數U的偏導數
關于參數U的偏導數求解過程,跟W的偏導數求解過程非常類似,在這裡就不介紹了,感興趣的讀者可以結合3.2的思路嘗試着自己推導一下。
4. 梯度消失問題
當網絡層數增多時,在使用BP算法求解梯度時,自然而然地就會出現“vanishing gradient“問題(還有一種稱作“exploding gradient”,但這種情況在訓練模型過程中易于被發現,是以可以通過人為控制來解決),下面我們從數學的角度來證明RNN網絡确實存在“vanishing gradient“問題,推導公式如下,
(10) ∂ E 3 W = ∑ k = 0 3 ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ∂ s 3 ∂ s k × ∂ s k ∂ W 1 = ∑ k = 0 3 ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ( ∏ i = k + 1 3 ∂ s i ∂ s i − 1 ) × ∂ s k ∂ W 1 \frac{\partial E_{3}}{W}=\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial s_{k}}\times \frac{\partial s_{k}}{\partial W_{1}}=\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times (\prod_{i=k+1}^{3}\frac{\partial s_{i}}{\partial s_{i-1}})\times \frac{\partial s_{k}}{\partial W_{1}}\tag{10} W∂E3=k=0∑3∂a3∂E3×∂s3∂a3×∂sk∂s3×∂W1∂sk=k=0∑3∂a3∂E3×∂s3∂a3×(i=k+1∏3∂si−1∂si)×∂W1∂sk(10)
大家應該注意到了,上面的式子中有一個連乘式,對于其中的每一項,滿足 s i = a c t i v a t i o n ( U × x i + W × s i − 1 ) s_{i}=activation(U\times x_{i}+W\times s_{i-1}) si=activation(U×xi+W×si−1),當激活函數為tanh時, ∂ s i ∂ s i − 1 \frac{\partial s_{i}}{\partial s_{i-1}} ∂si−1∂si的取值範圍為[0, 1]。當激活函數為sigmoid時, ∂ s i ∂ s i − 1 \frac{\partial s_{i}}{\partial s_{i-1}} ∂si−1∂si的取值範圍為[0, 1/4](簡單的一進制函數求導,這裡就不展開了)。因為這裡我們選擇t=3時刻的輸出損失,是以連乘的式子的個數并不多。但是我們可以設想一下,對于深度的網絡結構而言,若選擇tanh或者sigmoid激活函數,對于公式(10)中k取值較小的那一項,一定滿足 ∏ i = k + 1 3 ∂ s i ∂ s i − 1 \prod_{i=k+1}^{3}\frac{\partial s_{i}}{\partial s_{i-1}} ∏i=k+13∂si−1∂si趨近于0,進而導緻了消失梯度問題。
我們再從直覺的角度來了解一下消失梯度問題,對于RNN時刻T的輸出,其必定是時刻t=1,…,T-1的輸入綜合作用的結果,也就是說更新模型參數時,要充分利用目前時刻以及之前所有時刻的輸入資訊。但是如果發生了”消失梯度”問題,就會意味着,距離目前時刻非常遠的輸入資料,不能為目前模型參數的更新做貢獻,是以在RNN的程式設計實作中,才會有“truncated gradient”這一概念,“截斷梯度”就是在更新參數時,隻利用較近的時刻的序列資訊,把那些“曆史悠久的資訊”忽略掉了。
解決“消失梯度問題”,我們可以更換激活函數,比如采用Relu(rectified linear units)激活函數,但是更好的辦法是使用LSTM或者GRU架構的網絡。
LSTM網絡
為了解決原始RNN網絡結構存在的“vanishing gradient”問題,前輩們設計了LSTM這種新的網絡結構。但從本質上來講,LSTM是一種特殊的循環神經網絡,其和RNN的差別在于,對于特定時刻t,隐藏層輸出 s t s_{t} st的計算方式不同。故對LSTM網絡的訓練的思路與RNN類似,僅前向傳播關系式不同而已。值得一提的是,在對LSTM網絡進行訓練時,cell state c[0]和hidden state s[0]都是随機初始化得到的。
GRU(Gated Recurrent Unit)是2014年提出來的新的RNN架構,它是簡化版的LSTM,在超參數(hyper-parameters)均調優的前提下,這兩種RNN架構的性能相當,但是GRU架構的參數少,是以需要的訓練樣本更少,易于訓練。LSTM和GRU架構的網絡圖如下,
關于LSTM網絡結構相關的理論,請參見http://colah.github.io/posts/2015-08-Understanding-LSTMs/,相信也隻有這樣的大牛能把LSTM解析的如此淺顯易懂。這裡還需要補充說明一下,關于LSTM網絡的參數求偏微分,如果我們手動求解的話,也是跟RNN類似的思路,但由于LSTM網絡結構比較複雜,手動算的話,式子會變得非常複雜,我們便可以借助深度學習架構的自動微分功能了,現在的架構也都支援自動微分的,比如theano、tensorflow等。
參考資料:http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/
“A tutorial on training recurrent neural networks”. H. Jaeger, 2002.