天天看点

focal loss dice loss源码_机器阅读理解run_squad源码研读(pytorch)(下)

本文大部分引用自transformers,另有部分改写参考CAIL2019比赛的部分代码,针对中文。

由于代码量较大,会分为几部分分别阐述:

上:数据处理部分

中:训练、预测部分

下:整体架构及其他

1. 整体架构

focal loss dice loss源码_机器阅读理解run_squad源码研读(pytorch)(下)

其中,evaluate部分可能每个任务都不同,上述evaluate是源码代码(我自己又加了gold_answers部分)、针对英文,一些比赛中会给出自己的evaluate.py。

一些代码中没提到的部分再说一下吧:

1.1 Transformers

Transformers: State-of-the-art Natural Language Processing for Pytorch and TensorFlow 2.0.

用pytorch玩nlp必须要知道的一个github吧,hugging face升级后,已经包含了几乎所有的pytorch版本的transformers模型。

focal loss dice loss源码_机器阅读理解run_squad源码研读(pytorch)(下)

中文nlp预训练模型的话,transformers官方里面支持的很少,但是hugging face是支持其他开发者上传模型的,可以在huggingface官网模型库 搜索中文预训练模型(搜索词:chinese)。

1.2 BertForQuestionAnswering

加载模型的时候会用到下面的代码(也可以用AutoModelForQuestionAnswering,AutoModel会根据模型路径中的关键词来选择,如bert、roberta):

model 
           

BertForQuestionAnswering类的源代码如下:

class 
           

可以看到,BertForQuestionAnswering在BertModel后加了一层全连接层,全连接层的输入是BertModel的第0个输出(sequence_output)。

loss是start_logits的CrossEntropy与end_logits的CrossEntropy的均值。

2. 其他

2.1 yes/no问题(是否问题)

看了一些CAIL2019的解决方案,对于需要回答判断题类型的数据,有以下两种方法:

  • 端到端模型:修改bert模型,在原文末尾加上两个标识符,分别表示'yes' 和 'no'
  • 后接一个分类模型:在bert模型后再加一个分类模型,如textcnn(据说效果比第一种方法略差)
2.2 tensorboard
try
           
至此全部更新完毕。如果还有其他问题可以私我~