摘要
cmu和google brain联手推出了bert的改进版xlnet。在这之前也有很多公司对bert进行了优化,包括百度、清华的知识图谱融合,微软在预训练阶段的多任务学习等等,但是这些优化并没有把bert致命缺点进行改进。xlnet作为bert的升级模型,主要在以下三个方面进行了优化
采用ar模型替代ae模型,解决mask带来的负面影响
双流注意力机制
引入transformer-xl
今天我们使用xlnet+bilstm实现一个二分类模型。
数据集
数据集如下图:
是顾客对餐厅的正负评价。正面的评论是1,负面的是0。这类的数据集很多,比如电影的正负评论,商品的正负评论。
模型
模型结构如下:
思路:将xlnet做为嵌入层提取特征,然后传入bilstm,最后使用全连接层输出分类。创建xlnet_lstm模型,代码如下:
xlnet_lstm需要的参数功6个,参数说明如下:
--xlnetpath:xlnet预训练模型的路径
--hidden_dim:隐藏层的数量。
--output_size:分类的个数。
--n_layers:lstm的层数
--bidirectional:是否是双向lstm
--drop_prob:dropout的参数
定义xlnet的参数,如下:
batch_size:batchsize的大小,根据显存设置。
output_size:输出的类别个数,本例是2.
hidden_dim:隐藏层的数量。
n_layers:lstm的层数。
bidirectional:是否双向
print_every:输出的间隔。
use_cuda:是否使用cuda,默认使用,不用cuda太慢了。
xlnet_path:预训练模型存放的文件夹。
save_path:模型保存的路径。
下载预训练模型
本例使用的预训练模型是xlnet-base-cased,下载地址:https://huggingface.co/hfl/chinese-xlnet-base/tree/main
将上图画框的文件下载下来,如果下载后的名字和上面显示的名字不一样,则要修改回来。
将下载好的文件放入xlnet-base-chinese文件夹中。
配置环境
需要下载transformers和sentencepiece,执行命令:
训练、验证和预测
训练详见train_model函数,验证详见test_model,单次预测详见predict函数。
代码和模型链接: