摘要
cmu和google brain聯手推出了bert的改進版xlnet。在這之前也有很多公司對bert進行了優化,包括百度、清華的知識圖譜融合,微軟在預訓練階段的多任務學習等等,但是這些優化并沒有把bert緻命缺點進行改進。xlnet作為bert的更新模型,主要在以下三個方面進行了優化
采用ar模型替代ae模型,解決mask帶來的負面影響
雙流注意力機制
引入transformer-xl
今天我們使用xlnet+bilstm實作一個二分類模型。
資料集
資料集如下圖:
是顧客對餐廳的正負評價。正面的評論是1,負面的是0。這類的資料集很多,比如電影的正負評論,商品的正負評論。
模型
模型結構如下:
思路:将xlnet做為嵌入層提取特征,然後傳入bilstm,最後使用全連接配接層輸出分類。建立xlnet_lstm模型,代碼如下:
xlnet_lstm需要的參數功6個,參數說明如下:
--xlnetpath:xlnet預訓練模型的路徑
--hidden_dim:隐藏層的數量。
--output_size:分類的個數。
--n_layers:lstm的層數
--bidirectional:是否是雙向lstm
--drop_prob:dropout的參數
定義xlnet的參數,如下:
batch_size:batchsize的大小,根據顯存設定。
output_size:輸出的類别個數,本例是2.
hidden_dim:隐藏層的數量。
n_layers:lstm的層數。
bidirectional:是否雙向
print_every:輸出的間隔。
use_cuda:是否使用cuda,預設使用,不用cuda太慢了。
xlnet_path:預訓練模型存放的檔案夾。
save_path:模型儲存的路徑。
下載下傳預訓練模型
本例使用的預訓練模型是xlnet-base-cased,下載下傳位址:https://huggingface.co/hfl/chinese-xlnet-base/tree/main
将上圖畫框的檔案下載下傳下來,如果下載下傳後的名字和上面顯示的名字不一樣,則要修改回來。
将下載下傳好的檔案放入xlnet-base-chinese檔案夾中。
配置環境
需要下載下傳transformers和sentencepiece,執行指令:
訓練、驗證和預測
訓練詳見train_model函數,驗證詳見test_model,單次預測詳見predict函數。
代碼和模型連結: