Transformer火到不行的今天,做nlp的應該沒人不知道《Attention Is All You Need》這篇論文。文中提出了一種特殊的attention計算機制:scaled dot-product attention。
今天借此來梳理如何用keras自定義這個layer。
自定義層一定要重寫的基本函數
參考官方文檔自定義層
- __init__: 初始化函數肯定是要有的,并且注意 最後要調用父類的初始化函數
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyLayer, self).__init__(**kwargs)
這裡的參數是自己給的,但最後要有**kwargs
- build: 必須 傳入input_shape (因為沒見過不傳的)
def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[1], self.output_dim),
initializer='uniform',
trainable=True)
super(MyLayer, self).build(input_shape) # Be sure to call this at the end
如果這一層沒有自身的trainable parameters,這裡就不需要進行add_weight,隻有需要定義訓練參數的時候,才在build内定義。
另外,在這裡可以做一些checking的檢測,比如我看到ScaledDotProductAttention中就有檢查input_shape是否合法的
def build(self, input_shape):
self._validate_input_shape(input_shape)
super(ScaledDotProductAttention, self).build(input_shape)
并且
最後一定要調用父類的build方法- call: 這個函數就類似于pytorch中的forward函數,把這一次的輸入得到輸出的邏輯都寫在這個函數中
- compute_output_shape: 這個函數 必須傳入input_shape
- get_config: 這個函數在官方教程中沒有重寫,但看到很多人自定義時都會實作這個函數
定義ScaledDotProductAttention層
這裡有幾點需要提前想清楚的:
- 根據paper中的公式(1)可知,該層沒有自身的weights,是以build()中不需要add_weights操作
- 通常來講,attention的層都會得到(1)softmax後attention矩陣本身(2)根據attention權重權重後的output,在這裡output肯定是需要傳回的,attention看自己意願,可以傳回可以不需要傳回
- 根據paper中的一句話: The input consists of queries and keys of dimension dk, and values of dimension dv. 我們知道(1)Q和K的最後一個次元必須相等,(2)這個attention函數必須是Q、K、V三個輸入,(3)這三個輸入的batch size必須相同,(4)這三個輸入的第二次元length必須也相同。以上四點就是驗證input_shape是否合法的依據。
已知Q=[B, L, dk],K=[B, L, dk],V=[B, L, dv]
首先需要計算
但是我們線上代裡學的矩陣相乘都是二維矩陣啊,現在QK都是三維矩陣,怎麼乘呢?keras中有一個叫
batch_dot
的東西就是來算這個的。
現在
=[B, L, L]為了防止dot(共有dk個數字相乘并相加)後的數字太大,除以
。為什麼很大?原文中有這樣解釋到:
再經過softmax,得到attention權重。這樣得到的還是[B, L, L],和V=[B, L, dv]點乘之後得到[B, L, dv],這就是權重後的輸出。
def call(self, x, mask=None):
q, k, v = x
d_k = q.shape.as_list()[2]
weights = K.batch_dot(q, k, axes=[2, 2])
if mask is not None:
# add mask weights
if isinstance(mask, (list, tuple)):
if len(mask) > 0:
raise ValueError("mask can only be a Tensor or a list of length 1 containing a tensor.")
mask = mask[0]
weights -= 1e10 * (1 - mask)
normalized_weights = K.softmax(weights / np.sqrt(d_k))
output = K.batch_dot(normalized_weights, v)
if self._return_attention:
return [output, normalized_weights]
else:
return output
另外,注意到實際實作中會傳入mask。注意,雖然這個layer的輸入有QKV三個,但mask隻傳一個。
mask用在softmax之前,為啥呢?
因為在softmax之前将相對應的位置設定為-1e10,這麼小的數在softmax之後就會變為0,起到mask的作用。
最後說明一下,此處的Q、K、V三者相同時,這個scaled dot-product attention就是
self-attention:
In a self-attention layer all of the keys, values and queries come from the same place