本篇文章内容基于Shusen Wang老師的《RNN模型與NLP應用》系列課程。
課程視訊連結:https://www.youtube.com/playlist?list=PLvOO0btloRnuTUGN4XqO85eKPeFSZsEqK
課件:https://github.com/wangshusen/DeepLearning
目錄
- 1. Tokenization & Build dictionary
- 2. One-hot Encoding
- 3. Seq2Seq
- 3.1 Encoder
- 3.2 Decoder
- 3.3 Inference
- 4. 總結
- 參考
token是“符号”的意思,那tokenization簡單了解就是分詞,比如 “我是中國人”可以分解成['我', '是', '中國人']。
假設我們需要把英語翻譯成德語,那麼我們首先要做的是對不同語種做tokenization(分詞)。常用的分詞做法是以“詞”為機關,這裡為友善介紹,就以字元為機關:
- 英語有26個字母,考慮大小寫的話就有52個字元。
- 德語也有26個字母,還有4個特殊字元。
分詞後就可以得到不同語種對應的字典,結果如下圖所示。
注意不同語種之前分詞的結果是不一樣的,比如字元 a 在英文中的編号是1,而在德文中是3。另外因為現在的任務是講英文翻譯成德文,是以德文裡額外添加了兩個特殊符号:
\t
和
\n
分别表示起始和結束符号,你也可以用其他特殊符号,隻要不與其他符号重複就可以了。
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsIyZuBnLwQjMxQTN4UzN4EjNx8FOx8FNfFjMwIzLc9mYnd2cyFWbvwlclR3ch12LcRWZCNWaQ9CXvJ2ZnNnch12Lc12bj5CduVGdu92YyV2c1JWdoRXan5ydhJ3Lc9CX6MHc0RHaiojIsJye.png)
分詞和建構字典後,我們需要把每個字元轉化成one-hot格式,如下圖示。是以,每個字元有相同長度的列向量表示,其中隻有一個位置的值為1,其餘位置均為0。一個詞或者一句話由一個矩陣表示。
Seq2Seq由兩部分組成:Encoder和Decoder。這兩部分都是LSTM結構,下面分别介紹兩個部分。
如果對LSTM不熟悉,可以看看最後給出的參考文獻和下面兩個示意圖:
LSTM結構示意圖:
LSTM Cell結構示意圖:![]()
NLP系列筆記-機器翻譯之Sequence-to-Sequence模型 ![]()
NLP系列筆記-機器翻譯之Sequence-to-Sequence模型
Encoder主要用來對輸入的英語句子進行特征提取,它最後的輸出會作為輸入傳給Decoder。
如下圖示,傳入的是英文"go away"的one-hot編碼,encoder會生成很多hidden states,隻不過隻會保留最後的hidden state (h, c)。
拿到了Encoder的編碼結果後,此時Decoder就會開始做預測。
Decoder最開始的輸入是起始符号
\t
,初始狀态是Encoder傳入的
(h,c)
,基于這些會輸出預測機率向量
p
,這和圖像分類的softmax輸出類似。真實的label
y
就是下一個字元
m
的one-hot編碼,之後我們可以通過計算交叉熵來更新Encoder和Decoder的參數權重。
上面隻是預測了一個字元,我們還需要不斷預測。我們假設前一次預測的機率向量中機率最大的索引剛好就是
m
這個字元的索引, 即
argmax(p)
等于
index(m)
,那麼Decoder下一個輸入值就是
m
的one-hot編碼序列。Decoder按照前面介紹的方式不斷做預測,直到預測的字元為終止符号
\n
。
假設Seq2Seq模型訓練好了,那它的inference流程是什麼樣的呢?
- Step 1: 首先Decoder接收Encoder的輸入 \((h_0,c_0)\),輸入為
,其輸出為\((h_1,c_1)\)和\t
m
- Step 2:之後Decoder的初始狀态不再是Encoder的輸出,而是上一時刻的\((h_1,c_1)\),其輸入也變成了上一時刻預測的
m
- ... 重複上面的操作知道輸出為
就停止inference。\n
下面給出了利用Keras實作Seq2Seq的示意圖,每次給Decoder傳入新的輸入,計算loss并更新Decoder和Encoder。
下圖給出了Seq2Seq更加直覺的網絡結構示意圖。
-
NLP領域中的token和tokenization到底指的是什麼? - 周鳥的回答 - 知乎
https://www.zhihu.com/question/64984731/answer/292860859
-
了解Pytorch中LSTM的輸入輸出參數含義 - marsggbo的文章 - 知乎
https://zhuanlan.zhihu.com/p/100360301
- LSTM結構詳解:https://colah.github.io/posts/2015-08-Understanding-LSTMs/
微信公衆号:AutoML機器學習 MARSGGBO ♥原創
如有意合作或學術讨論歡迎私戳聯系~
2021-04-18 22:57:17