天天看點

xlnet+bilstm實作菜品正負評價分類

摘要

cmu和google brain聯手推出了bert的改進版xlnet。在這之前也有很多公司對bert進行了優化,包括百度、清華的知識圖譜融合,微軟在預訓練階段的多任務學習等等,但是這些優化并沒有把bert緻命缺點進行改進。xlnet作為bert的更新模型,主要在以下三個方面進行了優化

采用ar模型替代ae模型,解決mask帶來的負面影響

雙流注意力機制

引入transformer-xl

今天我們使用xlnet+bilstm實作一個二分類模型。

資料集

資料集如下圖:

xlnet+bilstm實作菜品正負評價分類

是顧客對餐廳的正負評價。正面的評論是1,負面的是0。這類的資料集很多,比如電影的正負評論,商品的正負評論。

模型

模型結構如下:

xlnet+bilstm實作菜品正負評價分類

思路:将xlnet做為嵌入層提取特征,然後傳入bilstm,最後使用全連接配接層輸出分類。建立xlnet_lstm模型,代碼如下:

xlnet_lstm需要的參數功6個,參數說明如下:

​ --xlnetpath:xlnet預訓練模型的路徑

​ --hidden_dim:隐藏層的數量。

​ --output_size:分類的個數。

​ --n_layers:lstm的層數

​ --bidirectional:是否是雙向lstm

​ --drop_prob:dropout的參數

定義xlnet的參數,如下:

batch_size:batchsize的大小,根據顯存設定。

output_size:輸出的類别個數,本例是2.

hidden_dim:隐藏層的數量。

n_layers:lstm的層數。

bidirectional:是否雙向

print_every:輸出的間隔。

use_cuda:是否使用cuda,預設使用,不用cuda太慢了。

xlnet_path:預訓練模型存放的檔案夾。

save_path:模型儲存的路徑。

下載下傳預訓練模型

本例使用的預訓練模型是xlnet-base-cased,下載下傳位址:https://huggingface.co/hfl/chinese-xlnet-base/tree/main

xlnet+bilstm實作菜品正負評價分類

将上圖畫框的檔案下載下傳下來,如果下載下傳後的名字和上面顯示的名字不一樣,則要修改回來。

将下載下傳好的檔案放入xlnet-base-chinese檔案夾中。

配置環境

需要下載下傳transformers和sentencepiece,執行指令:

訓練、驗證和預測

訓練詳見train_model函數,驗證詳見test_model,單次預測詳見predict函數。

代碼和模型連結: