天天看點

【GNN】硬核!一文梳理經典圖網絡模型

作者 | Chilia       

哥倫比亞大學 nlp搜尋推薦   

整理 | NewBeeNLP

圖神經網絡已經在NLP、CV、搜尋推薦廣告等領域廣泛應用,今天我們就來整體梳理一些經典常用的圖網絡模型:DeepWalk、GCN、Graphsage、GAT!

1. DeepWalk [2014]

DeepWalk是來解決圖裡面節點embedding問題的。Graph Embedding技術将圖中的節點以低維稠密向量的形式進行表達,要求在原始圖中相似(不同的方法對相似的定義不同)的節點其在低維表達空間也接近。得到的表達向量可以用來進行下遊任務,如節點分類(node classification),連結預測(link prediction)等。

1.1 DeepWalk 算法原理

雖然DeepWalk是KDD 2014的工作,但卻是我們了解Graph Embedding無法繞過的一個方法。

我們都知道在NLP任務中,word2vec是一種常用的word embedding方法,word2vec通過語料庫中的句子序列來描述詞與詞的共現關系,進而學習到詞語的向量表示。

DeepWalk的思想類似word2vec,使用圖中節點與節點的共現關系來學習節點的向量表示。那麼關鍵的問題就是如何來描述節點與節點的共現關系,DeepWalk給出的方法是使用**随機遊走(RandomWalk)**的方式在圖中進行節點采樣。

RandomWalk是一種可重複通路visited節點的深度優先周遊算法。給定目前通路起始節點,從其鄰居中随機采樣節點作為下一個通路節點,重複此過程,直到通路序列長度 = K。擷取足夠數量的節點通路序列後,使用skip-gram進行向量學習,這樣能夠把握節點的共現資訊。這樣就得到了每個節點的embedding。

2. GCN [2016]

GCN的概念首次提出于ICLR 2017:SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS。

為什麼要用GCN呢?這是因為對于圖結構的資料,CNN、RNN都無法解決。

對于圖檔來說,我們用卷積核來提取特征,這是因為圖檔有平移不變性:一個小視窗無論移動到圖檔的哪一個位置,其内部的結構都是一模一樣的,是以CNN可以實作參數共享。RNN主要用在NLP這種序列資訊上。圖檔,或者語言,都屬于歐式空間的資料,是以才有次元的概念,歐式空間的資料的特點就是結構很規則。

但是圖結構(拓撲結構)如社交網絡、知識圖譜、分子結構等等是十分不規則的,可以認為是無限維的一種資料,是以它沒有平移不變性。每一個節點的周圍結構可能都是獨一無二的,這種結構的資料,就讓傳統的CNN、RNN瞬間失效。

GCN,圖卷積神經網絡,實際上跟CNN的作用一樣,就是一個特征提取器,隻不過它的對象是圖。GCN精妙地設計了一種從圖資料中提取特征的方法,進而讓我們可以使用這些特征去對圖資料進行:

  • 節點分類(node classification)
  • 圖分類(graph classification)
  • 連結預測(link prediction)
2.1 GCN的核心公式

假設我們手頭有一個圖,其中有N個節點,每個節點都有自己的特征embedding,我們設這些節點的特征組成一個N×D維的矩陣 ,然後各個節點之間的關系也會形成一個N×N維的矩陣A(就是鄰接矩陣)

GCN也是一個神經網絡層,它的層與層之間的傳播方式是:

這個公式中:

  • , 是機關矩陣。
  • 是度矩陣(degree matrix),D[i][i]就是節點i的度。
  • H是每一層的特征,對于第一層(輸入層)的話,就是矩陣 。
  • σ是非線性激活函數

用這個公式就可以很好地提取圖的特征。假設我們構造一個兩層的GCN,激活函數分别采用ReLU和Softmax,則整體的正向傳播的公式為:

402 Payment Required

其中, .

那麼, 為什麼這個公式能提取圖的特征呢?

  • A+I 其實是保證對于每個節點,都能夠關注到其所有鄰居節點和自己的embedding。
  • 左右乘上度矩陣D是為了對A做一個标準化處理,讓A的每一行加起來都是1.

當然,原論文中用非常複雜的數學公式做了很多證明,由于筆者數學不好,隻能如此不求甚解的來粗略了解,感興趣的同學可以自行閱讀原論文。

3. GraphSAGE

3.1. GCN的局限

GCN本身有一個局限,即沒法快速表示新節點。GCN需要把所有節點都參與訓練(整個圖都丢進去訓練)才能得到node embedding,如果新node來了,沒法得到新node的embedding。是以說,GCN是transductive的。(Transductive任務是指:訓練階段與測試階段都基于同樣的圖結構)

而GraphSAGE是inductive的。inductive任務是指:訓練階段與測試階段需要處理的graph不同。通常是訓練階段隻是在子圖(subgraph)上進行,測試階段需要處理未知的頂點。

要想得到新節點的表示,需要讓新的node或者subgraph去和已經優化好的node embedding去“對齊”。然而每個節點的表示都是受到其他節點的影響(牽一發而動全身),是以添加一個節點,意味着許許多多與之相關的節點的表示都應該調整。

3.2 GraphSAGE

針對這種問題,GraphSAGE模型提出了一種算法架構,可以很友善地得到新node的表示。

3.2.1 Embedding generation(前向傳播算法)

Embedding generation算法共聚合K次,總共有K個聚合函數(aggregator),可以認為是K層,來聚合鄰居節點的資訊。假如 用來表示第k層每個節點的embedding,那麼如何 從 得到呢?

  • 就是初始的每個節點embedding。
  • 對于每個節點v,都把它随機采樣的若幹鄰居的k-1層的所有向量表示 、以及節點v自己的k-1層表示聚合成一個向量,這樣就得到了第層的表示 。這個聚合方法具體是怎麼做的後面再詳細介紹。

文中描述如下:

【GNN】硬核!一文梳理經典圖網絡模型

随着層數K的增加,可以聚合越來越遠距離的資訊。這是因為,雖然每次選擇鄰居的時候就是從周圍的一階鄰居中均勻地采樣固定個數個鄰居,但是由于節點的鄰居也聚合了其鄰居的資訊,這樣,在下一次聚合時,該節點就會接收到其鄰居的鄰居的資訊,也就是聚合到了二階鄰居的資訊了。這就像社交圖譜中“朋友的朋友”的概念。

3.2.2 聚合函數選擇

  • Mean Pooling:

這個比較好了解,就是目前節點v本身和它所有的鄰居在k-1層的embedding的mean,然後經過MLP+sigmoid

  • LSTM Aggregator:把目前節點v的鄰居随機打亂,輸入到LSTM中。作者的想法是說LSTM的模型capacity更強。但是節點周圍的鄰居明明是沒有順序的,這樣做似乎有不妥。
  • Pooling Aggregator:

把節點v的所有鄰居節點都單獨經過一個MLP+sigmoid得到一個向量,最後把所有鄰居的向量做一個element-wise的max-pooling。

3.2.3 GraphSAGE的參數學習

GraphSAGE的參數就是聚合函數的參數。為了學習這些參數,需要設計合适的損失函數。

對于無監督學習,設計的損失函數應該讓臨近的節點的擁有相似的表示,反之應該表示大不相同。是以損失函數是這樣的:

其中,節點v是和節點u在一定長度的random walk上共現的節點,是以它們的點積要盡可能大;後面這項是采了Q個負樣本,它們的點積要盡可能小。這個loss和skip-gram中的negative sampling如出一轍。

對于有監督學習,可以直接使用cross-entropy loss等正常損失函數。當然,上面的這個loss也可以作為一個輔助loss。

3.3 和GCN的關系

原始GCN的方法,其實和GraphSAGE的Mean Pooling聚合方法是類似的,即每一層都聚合自己和自己鄰居的歸一化embedding表示。而GraphSAGE使用了其他capacity更大的聚合函數而已。

此外,GCN是一口氣把整個圖都丢進去訓練,但是來了一個新的節點就不免又要把整個圖重新訓一次。而GraphSAGE則是在增加了新的節點之後,來增量更新舊的節點,調整整張圖的embedding表示。隻是生成新節點embedding的過程,實施起來相比于GCN更加靈活友善了。

4. GAT (Graph Attention Network)

4.1 GAT的具體做法

對于每個節點,注意力其在鄰居頂點上的注意力。對于頂點 ,逐個計算它的鄰居們和它自己之間的相似系數:

首先一個共享參數 的線性映射對于頂點的特征進行了增維,當然這是一種常見的特征增強(feature augment)方法;之後,對變換後的特征進行了拼接(concatenate);最後 a(·)把拼接後的高維特征映射到一個實數上,作者是通過單層的MLP來實作的。

然後,再對此相關系數用softmax做歸一化:

402 Payment Required

最後,根據計算好的注意力系數,把特征權重求和一下。這也是一種aggregation,隻是和GCN不同,這個aggregation是帶注意力權重的。

402 Payment Required

就是輸出的節點的embedding,融合了其鄰居和自身帶注意力的權重(這裡的注意力是self-attention)。

為了增強特征提取能力,用multi-head attention來進化增強一下:

4.2 與GCN的聯系

GCN與GAT都是将鄰居頂點的特征聚合到中心頂點上(一種aggregate運算)。不同的是GCN利用了拉普拉斯矩陣,GAT利用attention系數。一定程度上而言,GAT會更強,因為 頂點特征之間的相關性被更好地融入到模型中。

GAT适用于有向圖。這是因為GAT的運算方式是逐頂點的運算(node-wise),每一次運算都需要循環周遊圖上的所有頂點來完成。逐頂點運算意味着,擺脫了拉普利矩陣的束縛,使得有向圖問題迎刃而解。也正因如此,GAT适用于inductive任務。與此相反的是,GCN是一種全圖的計算方式,一次計算就更新全圖的節點特征。

- END -

AI基礎下載下傳機器學習交流qq群955171419