天天看点

简单QA:TF-IDF句子相似度计算

简单介绍一下基于TF-IDF计算句子相似度,并得到问题对应的答案过程:

  1. 准备好问题文件,答案文件,问题与答案一一对应,例如:
    简单QA:TF-IDF句子相似度计算
    简单QA:TF-IDF句子相似度计算
  2. 对问题文件进行分词、去停用词预处理操作
    简单QA:TF-IDF句子相似度计算
  3. 建立TF-IDF模型,计算所提问题与模板问题中相似度,将满足相似度问题对应的答案返回。关键代码如下:
from gensim import corpora, models, similarities
from preprocess_data import cut_stop_words
import numpy as np
import linecache


def similarity(query_path, query):
    """
    :func: 计算问题与知识库中问题的相似度
    :param query_path: 问题文件所在路径
    :param query: 所提问题
    :return: 返回满足阈值要求的问题所在行索引——对应答案所在的行索引

    """
    class MyCorpus():
        def __iter__(self):
            for line in open(query_path, 'r', encoding='utf-8'):
                 yield line.split()

    Corp = MyCorpus()
    # 建立词典
    dictionary = corpora.Dictionary(Corp)

    # 基于词典,将分词列表集转换成稀疏向量集,即语料库
    corpus = [dictionary.doc2bow(text) for text in Corp]

    # 训练TF-IDF模型,传入语料库进行训练
    tfidf = models.TfidfModel(corpus)

    # 用训练好的TF-IDF模型处理被检索文本,即语料库
    corpus_tfidf = tfidf[corpus]

    # # 得到TF-IDF值
    # for temp in corpus_tfidf:
    #     print(temp)

    vec_bow = dictionary.doc2bow(query.split())
    vec_tfidf = tfidf[vec_bow]

    index = similarities.MatrixSimilarity(corpus_tfidf)
    sims = index[vec_tfidf]
    max_loc = np.argmax(sims)
    max_sim = sims[max_loc]
    # 句子相似度阈值
    sup = 0.7
    # row_index默认为-1,即未匹配到满足相似度阈值的问题
    row_index = -1
    if max_sim > sup:
        # 相似度最大值对应文件中问题所在的行索引
        row_index = max_loc + 1

    return row_index

def get_answer(answer_path, row_index):
    """
    :func: 得到问题对应的答案
    :param answer_path: 答案存储所在文件路径
    :param row_index: 答案的行索引
    :return:
    """
    answer = linecache.getline(answer_path, row_index)
    return answer


if __name__ == '__main__':

    answer_path = '../corpus/sourceCorpus/answers.txt'
    query_path = '../corpus/useCorpus/questions.txt'

    question = '姚明的女儿是谁'
    # 对查询的问题进行处理
    query = cut_stop_words(question)
    query = ' '.join(line for line in query)

    # 得到问题(答案)所对应的行索引
    row_index = similarity(query_path, query)
    answer = get_answer(answer_path, row_index)

    print(answer)