天天看點

從循環神經網絡(RNN)到LSTM網絡

通常,資料的存在形式有語音、文本、圖像、視訊等。因為我的研究方向主要是圖像識别,是以很少用有“記憶性”的深度網絡。懷着對循環神經網絡的興趣,在看懂了有關它的理論後,我又看了Github上提供的tensorflow實作,覺得收獲很大,故在這裡把我的了解記錄下來,也希望對大家能有所幫助。本文将主要介紹RNN相關的理論,并引出LSTM網絡結構(關于對tensorflow實作細節的了解,有時間的話,在下一篇博文中做介紹)。

循環神經網絡

RNN,也稱作循環神經網絡(還有一種深度網絡,稱作遞歸神經網絡,讀者要差別對待)。因為這種網絡有“記憶性”,是以主要是應用在自然語言處理(NLP)和語音領域。與傳統的Neural network不同,RNN能利用上"序列資訊"。從理論上講,它可以利用任意長序列的資訊,但由于該網絡結構存在“消失梯度”問題,是以在實際應用中,它隻能回溯利用與它接近的time steps上的資訊。

1. 網絡結構

常見的神經網絡結構有卷積網絡、循環網絡和遞歸網絡,棧式自編碼器和玻爾茲曼機也可以看做是特殊的卷積網絡,差別是它們的損失函數定義成均方誤差函數。遞歸網絡類似于資料結構中的樹形結構,且其每層之間會有共享參數。而最為常用的循環神經網絡,它的每層的結構相同,且每層之間參數完全共享。RNN的縮略圖和展開圖如下,

從循環神經網絡(RNN)到LSTM網絡

盡管RNN的網絡結構看上去與常見的前饋網絡不同,但是它的展開圖中資訊流向也是确定的,沒有環流,是以也屬于forward network,故也可以使用反向傳播(back propagation)算法來求解參數的梯度。另外,在RNN網絡中,可以有單輸入、多輸入、單輸出、多輸出,視具體任務而定。

2. 損失函數

在輸出層為二分類或者softmax多分類的深度網絡中,代價函數通常選擇交叉熵(cross entropy)損失函數,前面的博文中證明過,在分類問題中,交叉熵函數的本質就是似然損失函數。盡管RNN的網絡結構與分類網絡不同,但是損失函數也是有相似之處的。

  假設我們采用RNN網絡建構“語言模型”,“語言模型”其實就是看“一句話說出來是不是順口”,可以應用在機器翻譯、語音識别領域,從若幹候選結果中挑一個更加靠譜的結果。通常每個sentence長度不一樣,每一個word作為一個訓練樣例,一個sentence作為一個Minibatch,記sentence的長度為T。為了更好地了解語言模型中損失函數的定義形式,這裡做一些推導,根據全機率公式,則一句話是“自然化的語句”的機率為 p ( w 1 , w 2 , . . . , w T ) = p ( w 1 ) × p ( w 2 ∣ w 1 ) × . . . × p ( w T ∣ w 1 , w 2 , . . . , w T − 1 ) p(w_{1}, w_{2}, ..., w_{T})=p(w_{1})\times p(w_{2}|w_{1})\times ...\times p(w_{T}|w_{1},w_{2},...,w_{T-1}) p(w1​,w2​,...,wT​)=p(w1​)×p(w2​∣w1​)×...×p(wT​∣w1​,w2​,...,wT−1​)  是以語言模型的目标就是最大化 P ( w 1 , w 2 , . . . , w T ) P(w_{1}, w_{2}, ..., w_{T}) P(w1​,w2​,...,wT​)。而損失函數通常為最小化問題,是以可以定義 L o s s ( w 1 , w 2 , . . . , w T ∣ θ ) = − l o g P ( w 1 , w 2 , . . . , w T ∣ θ ) Loss(w_{1}, w_{2},...,w_{T}|\theta )=-logP(w_{1}, w_{2},...,w_{T}|\theta) Loss(w1​,w2​,...,wT​∣θ)=−logP(w1​,w2​,...,wT​∣θ)  那麼公式展開可得 L o s s ( w 1 , w 2 , . . . , w T ∣ θ ) = − ( l o g p ( w 1 ) + l o g p ( w 2 ∣ w 1 ) + . . . + l o g p ( w T ∣ w 1 , w 2 , . . . , w T − 1 ) ) Loss(w_{1}, w_{2},...,w_{T}|\theta )=-(logp(w_{1})+logp(w_{2}|w_{1})+ ...+logp(w_{T}|w_{1},w_{2},...,w_{T-1})) Loss(w1​,w2​,...,wT​∣θ)=−(logp(w1​)+logp(w2​∣w1​)+...+logp(wT​∣w1​,w2​,...,wT−1​))  展開式中的每一項為一個softmax分類模型,類别數為所采用的詞庫大小(vocabulary size),相信大家此刻應該就明白了,為什麼使用RNN網絡解決語言模型時,輸入序列和輸出序列錯了一個位置了。

3. 梯度求解

在訓練任何深度網絡模型時,求解損失函數關于模型參數的梯度,應該算是最為核心的一步了。在RNN模型訓練時,采用的是BPTT(back propagation through time)算法,這個算法其實實質上就是樸素的BP算法,也是采用的“鍊式法則”求解參數梯度,唯一的不同在于每一個time step上參數共享。從數學的角度來講,BP算法就是一個單變量求導過程,而BPTT算法就是一個複合函數求導過程。接下來以損失函數展開式中的第3項為例,推導其關于網絡參數U、W、V的梯度表達式(總損失的梯度則是各項相加的過程而已)。

  為了簡化符号表示,記 E 3 = − l o g p ( w 3 ∣ w 1 , w 2 ) E_{3}=-logp(w_{3}|w_{1},w_{2}) E3​=−logp(w3​∣w1​,w2​),則根據RNN的展開圖可得,

(1) s 3 = t a n h ( U × x 3 + W × s 2 ) ;     s 2 = t a n h ( U × x 2 + W × s 1 ) s 1 = t a n h ( U × x 1 + W × s 0 ) ;     s 0 = t a n h ( U × x 0 + W × s − 1 ) s_{3}=tanh(U\times x_{3}+W\times s_{2});  s_{2}=tanh(U\times x_{2}+W\times s_{1})\\ s_{1}=tanh(U\times x_{1}+W\times s_{0});  s_{0}=tanh(U\times x_{0}+W\times s_{-1}) \\\tag{1} s3​=tanh(U×x3​+W×s2​);  s2​=tanh(U×x2​+W×s1​)s1​=tanh(U×x1​+W×s0​);  s0​=tanh(U×x0​+W×s−1​)(1)

是以,

(2) ∂ s 3 W = ∂ s 3 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 W ∂ s 2 W = ∂ s 2 W 1 + ∂ s 2 ∂ s 1 × ∂ s 1 W ∂ s 1 W = ∂ s 1 W 0 + ∂ s 1 ∂ s 0 × ∂ s 0 W ∂ s 0 W = ∂ s 0 W 1 \frac{\partial s_{3}}{W}=\frac{\partial s_{3}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{W}\\ \frac{\partial s_{2}}{W}=\frac{\partial s_{2}}{W_{1}}+\frac{\partial s_{2}}{\partial s_{1}}\times \frac{\partial s_{1}}{W}\\ \frac{\partial s_{1}}{W}=\frac{\partial s_{1}}{W_{0}}+\frac{\partial s_{1}}{\partial s_{0}}\times \frac{\partial s_{0}}{W}\\ \frac{\partial s_{0}}{W}=\frac{\partial s_{0}}{W_{1}}\tag{2} W∂s3​​=W1​∂s3​​+∂s2​∂s3​​×W∂s2​​W∂s2​​=W1​∂s2​​+∂s1​∂s2​​×W∂s1​​W∂s1​​=W0​∂s1​​+∂s0​∂s1​​×W∂s0​​W∂s0​​=W1​∂s0​​(2)

說明一下,為了更好地展現複合函數求導的思想,公式(2)中引入了變量 W 1 W_{1} W1​,可以把 W 1 W_{1} W1​看作關于W的函數,即 W 1 = W W_{1}=W W1​=W。另外,因為 s − 1 s_{-1} s−1​表示RNN網絡的初始狀态,為一個常數向量,是以公式(2)中第4個表達式展開後隻有一項。是以由公式(2)可得,

(3) ∂ s 3 W = ∂ s 3 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 ∂ s 1 × ∂ s 1 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 ∂ s 1 × ∂ s 1 ∂ s 0 × ∂ s 0 W 1 \frac{\partial s_{3}}{W}=\frac{\partial s_{3}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{\partial s_{1}}\times \frac{\partial s_{1}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{\partial s_{1}}\times \frac{\partial s_{1}}{\partial s_{0}}\times \frac{\partial s_{0}}{W_{1}}\tag{3} W∂s3​​=W1​∂s3​​+∂s2​∂s3​​×W1​∂s2​​+∂s2​∂s3​​×∂s1​∂s2​​×W1​∂s1​​+∂s2​∂s3​​×∂s1​∂s2​​×∂s0​∂s1​​×W1​∂s0​​(3)

簡化得下式,

(4) ∂ s 3 W = ∂ s 3 W 1 + ∂ s 3 ∂ s 2 × ∂ s 2 W 1 + ∂ s 3 ∂ S 1 × ∂ s 1 W 1 + ∂ s 3 ∂ s 0 × ∂ s 0 W 1 \frac{\partial s_{3}}{W}=\frac{\partial s_{3}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{W_{1}}+\frac{\partial s_{3}}{\partial S_{1}}\times \frac{\partial s_{1}}{W_{1}}+\frac{\partial s_{3}}{\partial s_{0}}\times \frac{\partial s_{0}}{W_{1}}\tag{4} W∂s3​​=W1​∂s3​​+∂s2​∂s3​​×W1​∂s2​​+∂S1​∂s3​​×W1​∂s1​​+∂s0​∂s3​​×W1​∂s0​​(4)

繼續簡化得下式,

(5) ∂ s 3 W = ∑ i = 0 3 ∂ s 3 ∂ s i × ∂ s i W 1 \frac{\partial s_{3}}{W}=\sum_{i=0}^{3}\frac{\partial s_{3}}{\partial s_{i}}\times \frac{\partial s_{i}}{W_{1}}\tag{5} W∂s3​​=i=0∑3​∂si​∂s3​​×W1​∂si​​(5)

3.1 E 3 E_{3} E3​關于參數V的偏導數

記t=3時刻的softmax神經元的輸入為 a 3 a_{3} a3​,輸出為 y 3 y_{3} y3​,網絡的真實标簽為 y 3 ( 1 ) y_{3}^{(1)} y3(1)​。根據函數求導的“鍊式法則”,是以有下式成立,

(6) ∂ E 3 V = ∂ E 3 ∂ a 3 × ∂ a 3 ∂ V = ( y 3 ( 1 ) − y 3 ) ⨂ s 3 \frac{\partial E_{3}}{V}=\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial V}=(y_{3}^{(1)}-y_{3})\bigotimes s_{3}\tag{6} V∂E3​​=∂a3​∂E3​​×∂V∂a3​​=(y3(1)​−y3​)⨂s3​(6)

3.2 E 3 E_{3} E3​關于參數W的偏導數

關于參數W的偏導數,就要使用到上面關于複合函數的推導過程了,記 z i z_{i} zi​為t=i時刻隐藏層神經元的輸入,則具體的表達式簡化過程如下,

(7) ∂ E 3 W = ∂ E 3 ∂ s 3 × ∂ s 3 ∂ W = ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ∂ s 3 ∂ W = ∑ k = 0 3 ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ∂ s 3 ∂ s k × ∂ s k ∂ W 1 = ∑ k = 0 3 ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ∂ s 3 ∂ s k × ∂ s k ∂ z k × ∂ z k ∂ W 1 = ∑ k = 0 3 ∂ E 3 ∂ z k × ∂ z k ∂ w 1 \frac{\partial E_{3}}{W}=\frac{\partial E_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial W}=\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial W}\\ = \sum_{k=0}^{3}\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial s_{k}}\times \frac{\partial s_{k}}{\partial W_{1}}\\ = \sum_{k=0}^{3}\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial s_{k}}\times \frac{\partial s_{k}}{\partial z_{k}}\times \frac{\partial z_{k}}{\partial W_{1}}\\ =\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial z_{k}}\times \frac{\partial z_{k}}{\partial w_{1}}\tag{7} W∂E3​​=∂s3​∂E3​​×∂W∂s3​​=∂a3​∂E3​​×∂s3​∂a3​​×∂W∂s3​​=k=0∑3​∂a3​∂E3​​×∂s3​∂a3​​×∂sk​∂s3​​×∂W1​∂sk​​=k=0∑3​∂a3​∂E3​​×∂s3​∂a3​​×∂sk​∂s3​​×∂zk​∂sk​​×∂W1​∂zk​​=k=0∑3​∂zk​∂E3​​×∂w1​∂zk​​(7)

類似于标準的BP算法中的表示,定義 δ n m = ∂ E m ∂ z n \delta _{n}^{m}=\frac{\partial E_{m}}{\partial z_{n}} δnm​=∂zn​∂Em​​,那麼可以得到如下遞推公式,

(8) δ 2 3 = ∂ E 3 ∂ z 3 × ∂ z 3 ∂ z 2 = ∂ E 3 ∂ z 3 × ∂ z 3 ∂ s 2 × ∂ s 2 ∂ z 2 = ( δ 3 3 ⨂ W ) ⨀ ( 1 − s 2 2 ) \delta _{2}^{3}=\frac{\partial E_{3}}{\partial z_{3}}\times \frac{\partial z_{3}}{\partial z_{2}}=\frac{\partial E_{3}}{\partial z_{3}}\times \frac{\partial z_{3}}{\partial s_{2}}\times \frac{\partial s_{2}}{\partial z_{2}}=(\delta _{3}^{3}\bigotimes W)\bigodot (1-s_{2}^{2})\tag{8} δ23​=∂z3​∂E3​​×∂z2​∂z3​​=∂z3​∂E3​​×∂s2​∂z3​​×∂z2​∂s2​​=(δ33​⨂W)⨀(1−s22​)(8)

那麼,公式(7)可以轉化為下式,

(9) ∂ E 3 W = ∑ k = 0 3 δ k 3 × ∂ z k ∂ w 1 \frac{\partial E_{3}}{W}=\sum_{k=0}^{3}\delta _{k}^{3}\times \frac{\partial z_{k}}{\partial w_{1}}\tag{9} W∂E3​​=k=0∑3​δk3​×∂w1​∂zk​​(9)

顯然,結合公式(8)中的遞推公式,可以遞推求解出公式(9)中的每一項,那麼 E 3 E_{3} E3​關于參數W的偏導數便迎刃而解了。

3.3 E 3 E_{3} E3​關于參數U的偏導數

關于參數U的偏導數求解過程,跟W的偏導數求解過程非常類似,在這裡就不介紹了,感興趣的讀者可以結合3.2的思路嘗試着自己推導一下。

4. 梯度消失問題

當網絡層數增多時,在使用BP算法求解梯度時,自然而然地就會出現“vanishing gradient“問題(還有一種稱作“exploding gradient”,但這種情況在訓練模型過程中易于被發現,是以可以通過人為控制來解決),下面我們從數學的角度來證明RNN網絡确實存在“vanishing gradient“問題,推導公式如下,

(10) ∂ E 3 W = ∑ k = 0 3 ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ∂ s 3 ∂ s k × ∂ s k ∂ W 1 = ∑ k = 0 3 ∂ E 3 ∂ a 3 × ∂ a 3 ∂ s 3 × ( ∏ i = k + 1 3 ∂ s i ∂ s i − 1 ) × ∂ s k ∂ W 1 \frac{\partial E_{3}}{W}=\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times \frac{\partial s_{3}}{\partial s_{k}}\times \frac{\partial s_{k}}{\partial W_{1}}=\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial a_{3}}\times \frac{\partial a_{3}}{\partial s_{3}}\times (\prod_{i=k+1}^{3}\frac{\partial s_{i}}{\partial s_{i-1}})\times \frac{\partial s_{k}}{\partial W_{1}}\tag{10} W∂E3​​=k=0∑3​∂a3​∂E3​​×∂s3​∂a3​​×∂sk​∂s3​​×∂W1​∂sk​​=k=0∑3​∂a3​∂E3​​×∂s3​∂a3​​×(i=k+1∏3​∂si−1​∂si​​)×∂W1​∂sk​​(10)

大家應該注意到了,上面的式子中有一個連乘式,對于其中的每一項,滿足 s i = a c t i v a t i o n ( U × x i + W × s i − 1 ) s_{i}=activation(U\times x_{i}+W\times s_{i-1}) si​=activation(U×xi​+W×si−1​),當激活函數為tanh時, ∂ s i ∂ s i − 1 \frac{\partial s_{i}}{\partial s_{i-1}} ∂si−1​∂si​​的取值範圍為[0, 1]。當激活函數為sigmoid時, ∂ s i ∂ s i − 1 \frac{\partial s_{i}}{\partial s_{i-1}} ∂si−1​∂si​​的取值範圍為[0, 1/4](簡單的一進制函數求導,這裡就不展開了)。因為這裡我們選擇t=3時刻的輸出損失,是以連乘的式子的個數并不多。但是我們可以設想一下,對于深度的網絡結構而言,若選擇tanh或者sigmoid激活函數,對于公式(10)中k取值較小的那一項,一定滿足 ∏ i = k + 1 3 ∂ s i ∂ s i − 1 \prod_{i=k+1}^{3}\frac{\partial s_{i}}{\partial s_{i-1}} ∏i=k+13​∂si−1​∂si​​趨近于0,進而導緻了消失梯度問題。

我們再從直覺的角度來了解一下消失梯度問題,對于RNN時刻T的輸出,其必定是時刻t=1,…,T-1的輸入綜合作用的結果,也就是說更新模型參數時,要充分利用目前時刻以及之前所有時刻的輸入資訊。但是如果發生了”消失梯度”問題,就會意味着,距離目前時刻非常遠的輸入資料,不能為目前模型參數的更新做貢獻,是以在RNN的程式設計實作中,才會有“truncated gradient”這一概念,“截斷梯度”就是在更新參數時,隻利用較近的時刻的序列資訊,把那些“曆史悠久的資訊”忽略掉了。

解決“消失梯度問題”,我們可以更換激活函數,比如采用Relu(rectified linear units)激活函數,但是更好的辦法是使用LSTM或者GRU架構的網絡。

LSTM網絡

為了解決原始RNN網絡結構存在的“vanishing gradient”問題,前輩們設計了LSTM這種新的網絡結構。但從本質上來講,LSTM是一種特殊的循環神經網絡,其和RNN的差別在于,對于特定時刻t,隐藏層輸出 s t s_{t} st​的計算方式不同。故對LSTM網絡的訓練的思路與RNN類似,僅前向傳播關系式不同而已。值得一提的是,在對LSTM網絡進行訓練時,cell state c[0]和hidden state s[0]都是随機初始化得到的。

  GRU(Gated Recurrent Unit)是2014年提出來的新的RNN架構,它是簡化版的LSTM,在超參數(hyper-parameters)均調優的前提下,這兩種RNN架構的性能相當,但是GRU架構的參數少,是以需要的訓練樣本更少,易于訓練。LSTM和GRU架構的網絡圖如下,

從循環神經網絡(RNN)到LSTM網絡

關于LSTM網絡結構相關的理論,請參見http://colah.github.io/posts/2015-08-Understanding-LSTMs/,相信也隻有這樣的大牛能把LSTM解析的如此淺顯易懂。這裡還需要補充說明一下,關于LSTM網絡的參數求偏微分,如果我們手動求解的話,也是跟RNN類似的思路,但由于LSTM網絡結構比較複雜,手動算的話,式子會變得非常複雜,我們便可以借助深度學習架構的自動微分功能了,現在的架構也都支援自動微分的,比如theano、tensorflow等。

參考資料:http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/

     “A tutorial on training recurrent neural networks”. H. Jaeger, 2002.

繼續閱讀