天天看點

tensorflow中crf子產品函數解析

        這篇部落客要想解釋一下tensorflow中crf子產品的幾個函數的輸入輸出是什麼意思。作為預備知識,建議英文好的同學先看下這篇部落格,這篇部落格有8個小節,前5個小節比較通俗易懂,後3個小節感覺不太了解。當然我也會先講一下bilstm+crf的基本原理,主要講一下模型的損失函數。

一、預備知識

        首先說一下crf的輸入是什麼,crf的輸入就是bilstm的輸出,是一個三維矩陣,[batch_size,max_seq_len,num_tags],其中num_tags是命名實體識别中的标簽個數。也就是每個sequence中的每個單詞都會輸出一個未經歸一化的機率向量,向量長度是num_tags。

        再說一下crf層的參數是什麼,crf層的參數是一個狀态轉移矩陣(num_tags*num_tags),矩陣中的每個元素代表标簽之間轉移的機率。

        最後說一下模型的損失函數,模型的損失函數就是由bilstm輸出的三維矩陣P和狀态轉移矩陣A計算得到。公式如下。

公式1:

tensorflow中crf子產品函數解析

公式2:

tensorflow中crf子產品函數解析

s代表由sequence輸入經過模型計算,得到某個标簽序列的機率,最小化損失函數就是為了最大化真實标簽序列的機率。

二、tensorflow中crf子產品函數解析

<函數輸入的解釋>

tag_indices:[batch_size, max_seq_len],真實的标簽序列;

sequence_lengths:[batch_size],每個序列的長度;

transition_params:[num_tags, num_tags] ,狀态轉移矩陣;

inputs/potentials:[batch_size, max_seq_len, num_tags],bilstm的輸出,同時也是crf層的輸入;

<函數的輸入輸出解釋>

crf_binary_score(二進制機率):

輸入:tag_indices, sequence_lengths, transition_params

輸出:[batch_size],向量中的每個元素是一個sequence中所有的轉移機率加和,公式1中紅色方框中的公式;

crf_unary_score(一進制機率):

輸入:tag_indices, sequence_lengths, inputs

輸出:[batch_size],向量中的每個元素是一個sequence中所有的真實标簽機率加和,公式1中綠色方框中的公式;

crf_sequence_score:

輸入:inputs,tag_indices,sequence_lengths,transition_params

輸出:[batch_size],向量中的每個元素是一個sequence的真實标簽機率加和+所有的轉移機率加和,即:crf_unary_score+crf_binary_score,公式2中紅色方框中的公式;

crf_log_norm: 這個函數的實作比較複雜,tensorflow中的實作原理可參考這裡,有興趣的同學可以看一下。

輸入:inputs,sequence_lengths,transition_params

輸出:[batch_size],公式2中綠色方框中的公式;

crf_log_likelihood:

輸入:inputs,tag_indices,sequence_lengths,

輸出:

log_likelihood:[batch_size],向量中的每個元素為圖二中紅色方框中的公式減去綠色方框中的公式,也就是-logloss;

transition_params:[num_tags, num_tags] ,狀态轉移矩陣;

crf_decode: 這個函數用到了維特比算法。

輸入:potentials,transition_params,sequence_length

輸出:

decode_tags:[batch_size, max_seq_len],所有可能的輸出序列中機率最大的那個序列;

best_score:[batch_size] ,機率最大的序列對應的序列機率;

繼續閱讀