天天看點

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

特征圖注意力_DGL部落格 | 深入了解圖注意力機制
作者:

張昊、李牧非、王靈活、張峥

圖卷積網絡Graph Convolutional Network (GCN)告訴我們将局部的圖結構和節點特征結合可以在節點分類任務中獲得不錯的表現。美中不足的是GCN結合鄰近節點特征的方式和圖的結構依依相關,這局限了訓練所得模型在其他圖結構上的泛化能力。

Graph Attention Network (GAT)提出了用注意力機制對鄰近節點特征權重求和。鄰近節點特征的權重完全取決于節點特征,獨立于圖結構。

在這個教程裡我們将:

  • 解釋什麼是Graph Attention Network
  • 示範用DGL實作這一模型
  • 深入了解學習所得的注意力權重
  • 初探歸納學習(inductive learning)

難度:★★★★✩ (需要對圖神經網絡訓練和Pytorch有基本了解)

在GCN裡引入注意力機制

GAT和GCN的核心差別在于如何收集并累和距離為1的鄰居節點的特征表示。在GCN裡,一次圖卷積操作包含對鄰節點特征的标準化求和:

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

其中

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

是對節點距離為1鄰節點的集合。我們通常會加一條連接配接節點

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

和它自身的邊使得

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

本身也被包括在

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

裡。

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

是一個基于圖結構的标準化常數;

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

是一個激活函數 (GCN使用了ReLU);

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

是節點特征轉換的權重矩陣,被所有節點共享。由于

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

和圖的機構相關,使得在一張圖上學習到的GCN模型比較難直接應用到另一張圖上。解決這一問題的方法有很多,比如GraphSAGE提出了一種采用相同節點特征更新規則的模型,唯一的差別是他們将

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

設為了

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

圖注意力模型GAT用注意力機制替代了圖卷積中固定的标準化操作。以下圖和公式定義了如何對第

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

層節點特征做更新得到第

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

層節點特征:

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

圖注意力網絡示意圖和更新公式

對于上述公式的一些解釋:

  • 公式(1)對
    特征圖注意力_DGL部落格 | 深入了解圖注意力機制
    層節點嵌入
    特征圖注意力_DGL部落格 | 深入了解圖注意力機制
    做了線性變換,
    特征圖注意力_DGL部落格 | 深入了解圖注意力機制
    是該變換可訓練的參數。
  • 公式(2)計算了成對節點間的原始注意力分數。它首先拼接了兩個節點的
    特征圖注意力_DGL部落格 | 深入了解圖注意力機制
    嵌入,注意
    特征圖注意力_DGL部落格 | 深入了解圖注意力機制
    在這裡表示拼接;随後對拼接好的嵌入以及一個可學習的權重向量
    特征圖注意力_DGL部落格 | 深入了解圖注意力機制

    做點積;最後應用了一個LeakyReLU激活函數。這一形式的注意力機制通常被稱為

    加性注意力,差別于Transformer裡的點積注意力。

  • 公式(3)對于一個節點所有入邊得到的原始注意力分數應用了一個softmax操作,得到了注意力權重。
  • 公式(4)形似GCN的節點特征更新規則,對所有鄰節點的特征做了基于注意力的權重求和。

出于簡潔的考量,在本教程中,我們選擇省略了一些論文中的細節,如dropout, skip connection等等。感興趣的讀者們歡迎參閱文末連結的模型完整實作。本質上,GAT隻是将原本的标準化常數替換為使用注意力權重的鄰居節點特征聚合函數。

GAT的DGL實作

以下代碼給讀者提供了在DGL裡實作一個GAT層的總體印象。别擔心,我們會将以下代碼拆分成三塊,并逐塊講解每塊代碼是如何實作上面的一條公式。

import 
           

實作公式(1)

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

第一個公式相對比較簡單。線性變換非常常見。在PyTorch裡,我們可以通過torch.nn.Linear很友善地實作。

實作公式(2)

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

原始注意力權重

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

是基于一對鄰近節點

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

的表示計算得到。我們可以把注意力權重

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

看成在 i->j 這條邊的資料。是以,在DGL裡,我們可以使用 g.apply_edges 這一API來調用邊上的操作,用一個邊上的使用者定義函數來指定具體操作的内容。我們在使用者定義函數裡實作了公式(2)的操作:

def 
           

公式中的點積同樣借由PyTorch的一個線性變換 attn_fc 實作。注意 apply_edges 會把所有邊上的資料打包為一個張量,這使得拼接和點積可以并行完成。

實作公式(3)和(4)

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

類似GCN,在DGL裡我們使用update_all API來觸發所有節點上的消息傳遞函數。update_all接收兩個使用者自定義函數作為參數。message_function發送了兩種張量作為消息:消息原節點的表示以及每條邊上的原始注意力權重。reduce_function随後進行了兩項操作:

  1. 使用softmax歸一化注意力權重 (公式(3))。
  2. 使用注意力權重聚合鄰節點特征 (公式(4))。

這兩項操作都先從節點的 mailbox 擷取了資料,随後在資料的第二維( dim = 1 ) 上進行了運算。注意資料的第一維代表了節點的數量,第二維代表了每個節點收到消息的數量。

def 
           

多頭注意力 (Multi-head attention)

神似卷積神經網絡裡的多通道,GAT引入了多頭注意力來豐富模型的能力和穩定訓練的過程。每一個注意力的頭都有它自己的參數。如何整合多個注意力機制的輸出結果一般有兩種方式:

  • 拼接:
    特征圖注意力_DGL部落格 | 深入了解圖注意力機制
  • 平均:
    特征圖注意力_DGL部落格 | 深入了解圖注意力機制

以上式子中是注意力頭的數量。作者們建議對中間層使用拼接對最後一層使用求平均。

我們之前有定義單頭注意力的GAT層,它可作為多頭注意力GAT層的組建單元:

class 
           

在Cora資料集上訓練一個GAT模型

Cora是經典的文章引用網絡資料集。Cora圖上的每個節點是一篇文章,邊代表文章和文章間的引用關系。每個節點的初始特征是文章的詞袋(Bag of words)表示。其目标是根據引用關系預測文章的類别(比如機器學習還是遺傳算法)。在這裡,我們定義一個兩層的GAT模型:

class 
           

我們使用DGL自帶的資料子產品加載Cora資料集。

from 
           

模型訓練的流程和GCN教程裡的一樣。

import 
           

可視化并了解學到的注意力

Cora資料集

以下表格總結了GAT論文以及dgl實作的模型在Cora資料集上的表現:

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

可以看到DGL能完全複現原論文中的實驗結果。對比圖卷積網絡GCN,GAT在Cora上有2~3個百分點的提升。

不過,

我們的模型究竟學到了怎樣的注意力機制呢?

由于注意力權重與圖上的邊密切相關,我們可以通過給邊着色來可視化注意力權重。以下圖檔中我們選取了Cora的一個子圖并且在圖上畫出了GAT模型最後一層的注意力權重。我們根據圖上節點的标簽對節點進行了着色,根據注意力權重的大小對邊進行了着色(可參考圖右側的色條)。

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

Cora資料集上學習到的注意力權重

乍看之下模型似乎學到了不同的注意力權重。為了對注意力機制有一個全局觀念,我們衡量了注意力分布的熵。對于節點,

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

構成了一個在鄰節點上的離散機率分布。它的熵被定義為:

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

直覺的說,熵低代表了機率高度集中,反之亦然。熵為則所有的注意力都被放在一個點上。均勻分布具有最高的熵(

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

)。在理想情況下,我們想要模型習得一個熵較低的分布(即某一、兩個節點比其它節點重要的多)。注意由于節點的入度不同,它們注意力權重的分布所能達到的最大熵也會不同。

基于圖中所有節點的熵,我們畫了所有頭注意力的直方圖。

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

Cora資料集上學到的注意力權重直方圖

作為參考,下圖是在所有節點的注意力權重都是均勻分布的情況下得到的直方圖。

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

出人意料的,

模型學到的節點注意力權重非常接近均勻分布

(換言之,所有的鄰節點都獲得了同等重視)。這在一定程度上解釋了為什麼在Cora上GAT的表現和GCN非常接近(在上面表格裡我們可以看到兩者的差距平均下來不到)。由于沒有顯著區分節點,注意力并沒有那麼重要。

這是否說明了注意力機制沒什麼用?

不!

在接下來的資料集上我們觀察到了完全不同的現象。

蛋白質互動網絡 (PPI)

PPI(蛋白質間互相作用)資料集包含了24張圖,對應了不同的人體組織。節點最多可以有121種标簽(比如蛋白質的一些性質、所處位置等)。是以節點标簽被表示為有個121元素的二進制張量。資料集的任務是預測節點标簽。

我們使用了20張圖進行訓練,2張圖進行驗證,2張圖進行測試。平均下來每張圖有2372個節點。每個節點有50個特征,包含定位基因集合、特征基因集合以及免疫特征。至關重要的是,測試用圖在訓練過程中對模型完全不可見。這一設定被稱為歸納學習。

我們比較了dgl實作的GAT和GCN在10次随機訓練中的表現。模型的超參數在驗證集上進行了優化。在實驗中我們使用了micro f1 score來衡量模型的表現。

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

在訓練過程中,我們使用了 BCEWithLogitsLoss 作為損失函數。下圖繪制了GAT和GCN的學習曲線;顯然GAT的表現遠優于GCN。

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

PPI資料集上GCN和GAT學習曲線比較

像之前一樣,我們可以通過繪制節點注意力分布之熵的直方圖來有一個統計意義上的直覺了解。以下我們基于一個3層GAT模型中不同模型層不同注意力頭繪制了直方圖。

第一層學到的注意力

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

第二層學到的注意力

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

最後一層學到的注意力

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

作為參考,下圖是在所有節點的注意力權重都是均勻分布的情況下得到的直方圖。

特征圖注意力_DGL部落格 | 深入了解圖注意力機制

可以很明顯地看到,

GAT在PPI上确實學到了一個尖銳的注意力權重分布

。與此同時,GAT層與層之間的注意力也呈現出一個清晰的模式:在中間層

随着層數的增加注意力權重變得愈發集中

;最後的輸出層由于我們對不同頭結果做了平均,注意力分布再次趨近均勻分布。

不同于在Cora資料集上非常有限的收益,GAT在PPI資料集上較GCN和其它圖模型的變種取得了明顯的優勢(根據原論文的結果在測試集上的表現提升了至少20%)。我們的實驗揭示了GAT學到的注意力顯著差別于均勻分布。雖然這值得進一步的深入研究,一個由此而生的假設是GAT的優勢在于處理更複雜領域結構的能力。

拓展閱讀

到目前為止我們示範了如何用DGL實作GAT。簡介起見,我們忽略了dropout, skip connection等一些細節。這些細節很常見且獨立于DGL相關的概念。有興趣的讀者歡迎參閱完整的代碼實作。

  • 經過優化的完整代碼實作
  • 在下一個教程中我們将介紹如何通過并行多頭注意力和稀疏矩陣向量乘法來加速GAT模型,敬請期待!