
How to use BERT with huggingface for named entity recognition on your own dataset

This article is a record of my notes while learning. The transformers library is used very widely, but its documentation is written in English. Basic usage is not hard: huggingface ships many ready-made datasets and models. For Chinese named entity recognition, however, the library offers very few data resources, so the hard part becomes getting your own dataset into the transformers library so you can fine-tune BERT. While working out how to load my own data, I found very few Chinese-language explanations, so I am writing this article to share what I learned, and I welcome discussion. I start with the classic conll2003 NER dataset; depending on interest I may follow up with a version for Chinese data.

I have already shared on Zhihu a method for doing Chinese NER directly. That method was pieced together, code block by code block, from the official huggingface documentation, and it does work. But I noticed that huggingface's examples folder contains a token classification example, and not knowing how to adapt it to my own needs was frustrating. Now I finally know how to use it.


First, take a look at the layout of this example folder; it is very straightforward. Install the libraries listed in the requirements file, then on Linux run `bash run.sh` to train on the default conll2003 dataset. This is also why I chose conll2003 for testing: you can compare against the known-good run to see where your own setup goes wrong.

run.sh simply runs run_ner.py with a set of default arguments (sketched below), so you can ignore the other files. While debugging, as long as your data goes in with the right shape, the model code downstream runs smoothly; the real problem to solve is getting the dataset read in the form the script expects.
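For reference, run.sh amounts to a single command along these lines (the exact model name and flag set may differ between transformers versions, so take this as a sketch rather than the canonical script):

python run_ner.py \
  --model_name_or_path bert-base-uncased \
  --dataset_name conll2003 \
  --output_dir /tmp/test-ner \
  --do_train \
  --do_eval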

For the data I downloaded the conll2003 dataset from https://github.com/davidsbatista/NER-datasets: three txt files, which I simply renamed to conlltrain.txt and so on, and placed in the data folder under this directory.
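For orientation: in these files each non-blank line has four space-separated columns (token, POS tag, chunk tag, NER tag), sentences are separated by blank lines, and documents start with a -DOCSTART- line. The beginning of the training file looks roughly like this:

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O
to TO B-VP O
boycott VB I-VP O
British JJ B-NP B-MISC
lamb NN I-NP O
. . O O

This four-column layout is exactly what the _generate_examples method of the loading script later in this article relies on.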

After a series of errors and fixes while running run_ner.py, here are the places that need to be commented out or changed.

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.10.0.dev0")
           

The snippet above is the minimum-version check; comment it out for now.

def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        self.task_name = self.task_name.lower()
           

This block says your data has to come in as csv or json. The csv/json requirement stumped me for a long time: I kept trying to convert my data into one of those formats, only to find it did not help. Even when the data is in csv form, its type information gets lost, which breaks feature extraction later on. My impression is that the data effectively goes through a lossy conversion: everything becomes plain strings, while the downstream code still expects the columns to carry class/feature attributes. So this block also needs to be commented out; we will feed the data in later, and this keeps the script from raising an error just because we did not point it at a file up front.
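One simple way to neutralize this check while leaving the rest of the dataclass untouched is to comment out the file checks and keep only the task_name normalization, roughly like this (a sketch, not the only way to do it):

    def __post_init__(self):
        # Relaxed version: the data will be supplied later through a custom
        # loading script (conll2003.py below), so we no longer require a
        # csv/json train/validation file here.
        # if self.dataset_name is None and self.train_file is None and self.validation_file is None:
        #     raise ValueError("Need either a dataset name or a training/validation file.")
        self.task_name = self.task_name.lower()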

if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
        extension = data_args.train_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
           

The block above is the most important part. If the data comes from one of the hub's own hosted datasets, everything downstream just works; if it is your own data, this is where it diverges from the expected format. So change it to the form below, which makes the txt files fit huggingface's dataset format.

if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )
    else:
        #data_files = {}
        #if data_args.train_file is not None:
        #    data_files["train"] = data_args.train_file
        #if data_args.validation_file is not None:
        #    data_files["validation"] = data_args.validation_file
        #if data_args.test_file is not None:
        #    data_files["test"] = data_args.test_file
        #extension = data_args.train_file.split(".")[-1]
        
        raw_datasets = load_dataset('conll2003.py')
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    
           

Here you can see a conll2003.py. This is the magic Python file that gives the txt files their class/feature attributes. You may first want to read https://blog.csdn.net/qq_42388742/article/details/114293746 for an explanation of how this kind of loading script works, and then use it as a reference while reading my conll2003.py below.

import csv
import json
import os

import datasets


# TODO: Add BibTeX citation
# Find for instance the citation on arxiv or on the dataset repo/website
_CITATION = """\
@InProceedings{huggingface:dataset,
title = {A great new dataset},
author={huggingface, Inc.
},
year={2020}
}
"""

# TODO: Add description of the dataset here
# You can copy an official description
_DESCRIPTION = """\
This new dataset is designed to solve this great NLP task and is crafted with a lot of care.
"""

# TODO: Add a link to an official homepage for the dataset here
_HOMEPAGE = ""

# TODO: Add the licence for the dataset here if you can find it
_LICENSE = ""

# TODO: Add link to the official dataset URLs here
# The HuggingFace datasets library doesn't host the datasets but only points to the original files
# This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method)
_URLs = {
    'train': "./data/conlltrain.txt",
    'test': "./data/conlltest.txt",
    'dev': "./data/conllvalid.txt",
}

class Conll2003Config(datasets.BuilderConfig):
    """BuilderConfig for Conll2003"""

    def __init__(self, **kwargs):
        """BuilderConfig forConll2003.
        Args:
          **kwargs: keyword arguments forwarded to super.
        """
        super(Conll2003Config, self).__init__(**kwargs)
        

# TODO: Name of the dataset usually match the script name with CamelCase instead of snake_case
class Conll2003(datasets.GeneratorBasedBuilder):
    """TODO: Short description of my dataset."""

    VERSION = datasets.Version("1.1.0")

    # This is an example of a dataset with multiple configurations.
    # If you don't want/need to define several sub-sets in your dataset,
    # just remove the BUILDER_CONFIG_CLASS and the BUILDER_CONFIGS attributes.

    # If you need to make complex sub-parts in the datasets with configurable options
    # You can create your own builder configuration class to store attribute, inheriting from datasets.BuilderConfig
    # BUILDER_CONFIG_CLASS = MyBuilderConfig

    # You will be able to load one or the other configurations in the following list with
    # data = datasets.load_dataset('my_dataset', 'first_domain')
    # data = datasets.load_dataset('my_dataset', 'second_domain')
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="conll2003", version=VERSION, description="This part of my dataset covers a first domain"),
    ]


    def _info(self):
        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=_DESCRIPTION,
            # This defines the different columns of the dataset and their types
            features=datasets.Features(
                {
                    #'unusedid':datasets.Value("string"),
                    "id": datasets.Value("string"),
                    "tokens": datasets.Sequence(datasets.Value("string")),
                    "pos_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=[
                                '"',
                                "''",
                                "#",
                                "$",
                                "(",
                                ")",
                                ",",
                                ".",
                                ":",
                                "``",
                                "CC",
                                "CD",
                                "DT",
                                "EX",
                                "FW",
                                "IN",
                                "JJ",
                                "JJR",
                                "JJS",
                                "LS",
                                "MD",
                                "NN",
                                "NNP",
                                "NNPS",
                                "NNS",
                                "NN|SYM",
                                "PDT",
                                "POS",
                                "PRP",
                                "PRP$",
                                "RB",
                                "RBR",
                                "RBS",
                                "RP",
                                "SYM",
                                "TO",
                                "UH",
                                "VB",
                                "VBD",
                                "VBG",
                                "VBN",
                                "VBP",
                                "VBZ",
                                "WDT",
                                "WP",
                                "WP$",
                                "WRB",
                            ]
                        )
                    ),
                    "chunk_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=[
                                "O",
                                "B-ADJP",
                                "I-ADJP",
                                "B-ADVP",
                                "I-ADVP",
                                "B-CONJP",
                                "I-CONJP",
                                "B-INTJ",
                                "I-INTJ",
                                "B-LST",
                                "I-LST",
                                "B-NP",
                                "I-NP",
                                "B-PP",
                                "I-PP",
                                "B-PRT",
                                "I-PRT",
                                "B-SBAR",
                                "I-SBAR",
                                "B-UCP",
                                "I-UCP",
                                "B-VP",
                                "I-VP",
                            ]
                        )
                    ),
                    "ner_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=[
                                "O",
                                "B-PER",
                                "I-PER",
                                "B-ORG",
                                "I-ORG",
                                "B-LOC",
                                "I-LOC",
                                "B-MISC",
                                "I-MISC",
                            ]
                        )
                    ),
                }
            ),
            # Here we define them above because they are different between the two configurations
            # If there's a common (input, target) tuple from the features,
            # specify them here. They'll be used if as_supervised=True in
            # builder.as_dataset.
            supervised_keys=None,
            # Homepage of the dataset for documentation
            homepage=_HOMEPAGE,
            # License for the dataset if available
            license=_LICENSE,
            # Citation for the dataset
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
        # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name

        # dl_manager is a datasets.download.DownloadManager that can be used to download and extract URLs
        # It can accept any type or nested list/dict and will give back the same structure with the url replaced with path to local files.
        # By default the archives will be extracted and a path to a cached folder where they are extracted is returned instead of the archive
        my_urls = _URLs
        data_dir = dl_manager.download_and_extract(my_urls)
        print(data_dir)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir["train"]),
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir["test"]),
                    "split": "test"
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir["dev"]),
                    "split": "dev",
                },
            ),
        ]

    def _generate_examples(
        self, filepath, split  # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
    ):
        """ Yields examples as (key, example) tuples. """
        # This method handles input defined in _split_generators to yield (key, example) tuples from the dataset.
        # The `key` is here for legacy reason (tfds) and is not important in itself.

        with open(filepath, encoding="utf-8") as f:
            guid = 0
            tokens = []
            pos_tags = []
            chunk_tags = []
            ner_tags = []
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if tokens:
                        yield guid, {
                            "id": str(guid),
                            "tokens": tokens,
                            "pos_tags": pos_tags,
                            "chunk_tags": chunk_tags,
                            "ner_tags": ner_tags,
                        }
                        guid += 1
                        tokens = []
                        pos_tags = []
                        chunk_tags = []
                        ner_tags = []
                else:
                    # conll2003 tokens are space separated
                    splits = line.split(" ")
                    tokens.append(splits[0])
                    pos_tags.append(splits[1])
                    chunk_tags.append(splits[2])
                    ner_tags.append(splits[3].rstrip())
            # last example (skip it if the file ends with a blank line,
            # so we don't emit an empty sentence)
            if tokens:
                yield guid, {
                    "id": str(guid),
                    "tokens": tokens,
                    "pos_tags": pos_tags,
                    "chunk_tags": chunk_tags,
                    "ner_tags": ner_tags,
                }
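
Before plugging this into run_ner.py, it is worth checking the loading script on its own. A minimal sanity check, assuming conll2003.py and the data folder sit in the current working directory, is to load it directly with the datasets library and inspect the result:

from datasets import load_dataset

# Load the three splits through the custom loading script.
raw_datasets = load_dataset("conll2003.py")

# The splits and features should mirror the hub version of conll2003.
print(raw_datasets)
print(raw_datasets["train"].features["ner_tags"].feature.names)
print(raw_datasets["train"][0])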
           

If anything in this file is unclear, feel free to leave a comment below and I can post a more detailed walkthrough later. Reading the official documentation alongside this script should make it reasonably clear what it is doing; what I did was mostly imitation, changing the file paths and a few details to match my own setup.

Finally, you are welcome to get in touch so we can discuss and learn together!
