天天看點

簡單QA:TF-IDF句子相似度計算

簡單介紹一下基于TF-IDF計算句子相似度,并得到問題對應的答案過程:

  1. 準備好問題檔案,答案檔案,問題與答案一一對應,例如:
    簡單QA:TF-IDF句子相似度計算
    簡單QA:TF-IDF句子相似度計算
  2. 對問題檔案進行分詞、去停用詞預處理操作
    簡單QA:TF-IDF句子相似度計算
  3. 建立TF-IDF模型,計算所提問題與模闆問題中相似度,将滿足相似度問題對應的答案傳回。關鍵代碼如下:
from gensim import corpora, models, similarities
from preprocess_data import cut_stop_words
import numpy as np
import linecache


def similarity(query_path, query):
    """
    :func: 計算問題與知識庫中問題的相似度
    :param query_path: 問題檔案所在路徑
    :param query: 所提問題
    :return: 傳回滿足門檻值要求的問題所在行索引——對應答案所在的行索引

    """
    class MyCorpus():
        def __iter__(self):
            for line in open(query_path, 'r', encoding='utf-8'):
                 yield line.split()

    Corp = MyCorpus()
    # 建立詞典
    dictionary = corpora.Dictionary(Corp)

    # 基于詞典,将分詞清單集轉換成稀疏向量集,即語料庫
    corpus = [dictionary.doc2bow(text) for text in Corp]

    # 訓練TF-IDF模型,傳入語料庫進行訓練
    tfidf = models.TfidfModel(corpus)

    # 用訓練好的TF-IDF模型處理被檢索文本,即語料庫
    corpus_tfidf = tfidf[corpus]

    # # 得到TF-IDF值
    # for temp in corpus_tfidf:
    #     print(temp)

    vec_bow = dictionary.doc2bow(query.split())
    vec_tfidf = tfidf[vec_bow]

    index = similarities.MatrixSimilarity(corpus_tfidf)
    sims = index[vec_tfidf]
    max_loc = np.argmax(sims)
    max_sim = sims[max_loc]
    # 句子相似度門檻值
    sup = 0.7
    # row_index預設為-1,即未比對到滿足相似度門檻值的問題
    row_index = -1
    if max_sim > sup:
        # 相似度最大值對應檔案中問題所在的行索引
        row_index = max_loc + 1

    return row_index

def get_answer(answer_path, row_index):
    """
    :func: 得到問題對應的答案
    :param answer_path: 答案存儲所在檔案路徑
    :param row_index: 答案的行索引
    :return:
    """
    answer = linecache.getline(answer_path, row_index)
    return answer


if __name__ == '__main__':

    answer_path = '../corpus/sourceCorpus/answers.txt'
    query_path = '../corpus/useCorpus/questions.txt'

    question = '姚明的女兒是誰'
    # 對查詢的問題進行處理
    query = cut_stop_words(question)
    query = ' '.join(line for line in query)

    # 得到問題(答案)所對應的行索引
    row_index = similarity(query_path, query)
    answer = get_answer(answer_path, row_index)

    print(answer)