天天看点

xlnet+bilstm实现菜品正负评价分类

摘要

cmu和google brain联手推出了bert的改进版xlnet。在这之前也有很多公司对bert进行了优化,包括百度、清华的知识图谱融合,微软在预训练阶段的多任务学习等等,但是这些优化并没有把bert致命缺点进行改进。xlnet作为bert的升级模型,主要在以下三个方面进行了优化

采用ar模型替代ae模型,解决mask带来的负面影响

双流注意力机制

引入transformer-xl

今天我们使用xlnet+bilstm实现一个二分类模型。

数据集

数据集如下图:

xlnet+bilstm实现菜品正负评价分类

是顾客对餐厅的正负评价。正面的评论是1,负面的是0。这类的数据集很多,比如电影的正负评论,商品的正负评论。

模型

模型结构如下:

xlnet+bilstm实现菜品正负评价分类

思路:将xlnet做为嵌入层提取特征,然后传入bilstm,最后使用全连接层输出分类。创建xlnet_lstm模型,代码如下:

xlnet_lstm需要的参数功6个,参数说明如下:

​ --xlnetpath:xlnet预训练模型的路径

​ --hidden_dim:隐藏层的数量。

​ --output_size:分类的个数。

​ --n_layers:lstm的层数

​ --bidirectional:是否是双向lstm

​ --drop_prob:dropout的参数

定义xlnet的参数,如下:

batch_size:batchsize的大小,根据显存设置。

output_size:输出的类别个数,本例是2.

hidden_dim:隐藏层的数量。

n_layers:lstm的层数。

bidirectional:是否双向

print_every:输出的间隔。

use_cuda:是否使用cuda,默认使用,不用cuda太慢了。

xlnet_path:预训练模型存放的文件夹。

save_path:模型保存的路径。

下载预训练模型

本例使用的预训练模型是xlnet-base-cased,下载地址:https://huggingface.co/hfl/chinese-xlnet-base/tree/main

xlnet+bilstm实现菜品正负评价分类

将上图画框的文件下载下来,如果下载后的名字和上面显示的名字不一样,则要修改回来。

将下载好的文件放入xlnet-base-chinese文件夹中。

配置环境

需要下载transformers和sentencepiece,执行命令:

训练、验证和预测

训练详见train_model函数,验证详见test_model,单次预测详见predict函数。

代码和模型链接: