天天看點

keras 自定義層input_keras定義ScaledDotProductAttention層

Transformer火到不行的今天,做nlp的應該沒人不知道《Attention Is All You Need》這篇論文。文中提出了一種特殊的attention計算機制:scaled dot-product attention。

keras 自定義層input_keras定義ScaledDotProductAttention層

今天借此來梳理如何用keras自定義這個layer。

自定義層一定要重寫的基本函數

參考官方文檔自定義層

  • __init__: 初始化函數肯定是要有的,并且注意 最後要調用父類的初始化函數
def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
           

這裡的參數是自己給的,但最後要有**kwargs

  • build: 必須 傳入input_shape (因為沒見過不傳的)
def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                                  shape=(input_shape[1], self.output_dim),
                                  initializer='uniform',
                                  trainable=True)
    super(MyLayer, self).build(input_shape)  # Be sure to call this at the end
           

如果這一層沒有自身的trainable parameters,這裡就不需要進行add_weight,隻有需要定義訓練參數的時候,才在build内定義。

另外,在這裡可以做一些checking的檢測,比如我看到ScaledDotProductAttention中就有檢查input_shape是否合法的

def build(self, input_shape):
    self._validate_input_shape(input_shape)
    super(ScaledDotProductAttention, self).build(input_shape)
           

并且

最後一定要調用父類的build方法
  • call: 這個函數就類似于pytorch中的forward函數,把這一次的輸入得到輸出的邏輯都寫在這個函數中
  • compute_output_shape: 這個函數 必須傳入input_shape
  • get_config: 這個函數在官方教程中沒有重寫,但看到很多人自定義時都會實作這個函數

定義ScaledDotProductAttention層

這裡有幾點需要提前想清楚的:

  1. 根據paper中的公式(1)可知,該層沒有自身的weights,是以build()中不需要add_weights操作
  2. 通常來講,attention的層都會得到(1)softmax後attention矩陣本身(2)根據attention權重權重後的output,在這裡output肯定是需要傳回的,attention看自己意願,可以傳回可以不需要傳回
  3. 根據paper中的一句話: The input consists of queries and keys of dimension dk, and values of dimension dv. 我們知道(1)Q和K的最後一個次元必須相等,(2)這個attention函數必須是Q、K、V三個輸入,(3)這三個輸入的batch size必須相同,(4)這三個輸入的第二次元length必須也相同。以上四點就是驗證input_shape是否合法的依據。
keras 自定義層input_keras定義ScaledDotProductAttention層

已知Q=[B, L, dk],K=[B, L, dk],V=[B, L, dv]

首先需要計算

keras 自定義層input_keras定義ScaledDotProductAttention層

但是我們線上代裡學的矩陣相乘都是二維矩陣啊,現在QK都是三維矩陣,怎麼乘呢?keras中有一個叫

batch_dot

的東西就是來算這個的。

現在

keras 自定義層input_keras定義ScaledDotProductAttention層

=[B, L, L]為了防止dot(共有dk個數字相乘并相加)後的數字太大,除以

keras 自定義層input_keras定義ScaledDotProductAttention層

。為什麼很大?原文中有這樣解釋到:

keras 自定義層input_keras定義ScaledDotProductAttention層

再經過softmax,得到attention權重。這樣得到的還是[B, L, L],和V=[B, L, dv]點乘之後得到[B, L, dv],這就是權重後的輸出。

def call(self, x, mask=None):
    q, k, v = x
    d_k = q.shape.as_list()[2]

    weights = K.batch_dot(q, k, axes=[2, 2])

    if mask is not None:
        # add mask weights
        if isinstance(mask, (list, tuple)):
            if len(mask) > 0:
                raise ValueError("mask can only be a Tensor or a list of length 1 containing a tensor.")

            mask = mask[0]

        weights -= 1e10 * (1 - mask)

    normalized_weights = K.softmax(weights / np.sqrt(d_k))
    output = K.batch_dot(normalized_weights, v)

    if self._return_attention:
        return [output, normalized_weights]
    else:
        return output
           

另外,注意到實際實作中會傳入mask。注意,雖然這個layer的輸入有QKV三個,但mask隻傳一個。

mask用在softmax之前,為啥呢?

因為在softmax之前将相對應的位置設定為-1e10,這麼小的數在softmax之後就會變為0,起到mask的作用。

最後說明一下,此處的Q、K、V三者相同時,這個scaled dot-product attention就是

self-attention

In a self-attention layer all of the keys, values and queries come from the same place