
Graph Neural Networks: Message Passing Graph Neural Networks

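The message passing paradigm updates the representation of a central node by aggregating information from its neighboring nodes. It generalizes the convolution operator to irregular data domains and thereby connects graphs with neural networks.

The paradigm consists of three steps: (1) transforming the information of neighboring nodes, (2) aggregating the transformed neighbor information at the central node, and (3) transforming the aggregated information.

1.1 Introduction to the message passing paradigm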

Let $\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F$ denote the features of node $i$ in layer $(k-1)$, and let $\mathbf{e}_{j,i} \in \mathbb{R}^D$ denote the features of the edge from node $j$ to node $i$. A message passing graph neural network can then be described as

$$\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right) \right),$$

where $\square$ denotes a differentiable, permutation-invariant function (its output does not depend on the order of its inputs), such as sum, mean, or max, and $\gamma$ and $\phi$ denote differentiable functions such as MLPs (multilayer perceptrons). This material comes from "Creating Message Passing Networks" in the PyTorch Geometric documentation.
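The formula can be traced with a small pure-Python sketch. It is only an illustration under simplifying assumptions: edge features $\mathbf{e}_{j,i}$ are left out, $\square$ is fixed to a sum, and `message_passing_step`, `phi`, and `gamma` are hypothetical names chosen here, not part of any library.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def message_passing_step(
    x: List[Vector],
    edges: List[Tuple[int, int]],
    phi: Callable[[Vector, Vector], Vector],
    gamma: Callable[[Vector, Vector], Vector],
) -> List[Vector]:
    """One layer: x_i' = gamma(x_i, sum over j in N(i) of phi(x_i, x_j))."""
    dim = len(x[0])  # assumes phi returns vectors of the same length as x_i
    aggregated = [[0.0] * dim for _ in x]
    for j, i in edges:  # each pair (j, i) is an edge from node j to node i
        msg = phi(x[i], x[j])
        aggregated[i] = [a + m for a, m in zip(aggregated[i], msg)]
    return [gamma(x_i, agg_i) for x_i, agg_i in zip(x, aggregated)]


def phi(x_i: Vector, x_j: Vector) -> Vector:
    return x_j  # the message is simply the neighbor's features


def gamma(x_i: Vector, agg: Vector) -> Vector:
    return [a + b for a, b in zip(x_i, agg)]  # add the aggregate to the node's own features


x = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]

out = message_passing_step(x, edges, phi, gamma)
# Permutation invariance: visiting the edges in another order gives the same result.
out_reordered = message_passing_step(x, list(reversed(edges)), phi, gamma)
print(out)
print(out == out_reordered)
```

Swapping the sum for a mean or a max would keep the result independent of edge order, which is exactly the property required of $\square$.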

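1.2 The MessagePassing base class

This base class encapsulates the message passing workflow:

  • aggr: the aggregation scheme; flow: the direction of message passing; node_dim: the axis along which to propagate
  • MessagePassing.propagate(): the initial call that starts message passing
  • MessagePassing.message(): implements the $\phi$ function
  • MessagePassing.aggregate(): aggregates the messages sent by source nodes at their target nodes, using sum, mean, or max
  • MessagePassing.update(): implements the $\gamma$ function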
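To make the division of labor between these methods concrete, here is a pure-Python stand-in. `ToyMessagePassing` is a hypothetical class written for this illustration; it only mimics the call order propagate → message → aggregate → update and is not the PyTorch Geometric implementation. Nodes carry a single scalar feature, and nodes without incoming messages are assumed to aggregate to 0.0.

```python
from typing import List, Tuple


class ToyMessagePassing:
    """Pure-Python stand-in for the workflow described above (not the PyG class)."""

    def __init__(self, aggr: str = "sum"):
        self.aggr = aggr

    def propagate(self, edges: List[Tuple[int, int]], x: List[float]) -> List[float]:
        # edges are (source j, target i) pairs; x holds one scalar per node.
        messages = [self.message(x[j]) for j, _ in edges]
        targets = [i for _, i in edges]
        aggregated = self.aggregate(messages, targets, len(x))
        return self.update(aggregated)

    def message(self, x_j: float) -> float:
        return x_j  # plays the role of phi

    def aggregate(self, messages: List[float], targets: List[int], num_nodes: int) -> List[float]:
        buckets: List[List[float]] = [[] for _ in range(num_nodes)]
        for m, i in zip(messages, targets):
            buckets[i].append(m)
        if self.aggr == "sum":
            return [sum(b, 0.0) for b in buckets]
        if self.aggr == "mean":
            return [sum(b, 0.0) / len(b) if b else 0.0 for b in buckets]
        if self.aggr == "max":
            return [max(b) if b else 0.0 for b in buckets]
        raise ValueError(f"unknown aggr: {self.aggr}")

    def update(self, aggregated: List[float]) -> List[float]:
        return aggregated  # plays the role of gamma (identity here)


edges = [(0, 1), (2, 1), (1, 0)]
x = [1.0, 5.0, 3.0]
results = {aggr: ToyMessagePassing(aggr).propagate(edges, x) for aggr in ("sum", "mean", "max")}
print(results)
```

Node 1 receives messages from nodes 0 and 2, so the three aggregation schemes give different values for it, while node 0 (one incoming message) is unaffected by the choice.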

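GCNConv is mathematically defined as

$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),$$

where neighboring node features are first transformed by the weight matrix $\mathbf{\Theta}$, then normalized by node degree, and finally summed. The formula splits into five steps, matching the step numbers in the code comments below: (1) add self-loops to the adjacency matrix, (2) linearly transform the node feature matrix, (3) compute the normalization coefficients, (4) normalize the node features, and (5) sum up the neighboring node features.

Steps 1-3 are typically computed before message passing takes place. Steps 4-5 can be handled easily with the MessagePassing base class. The full implementation of the layer is shown below.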

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
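        # Use "add" aggregation (Step 5).
        # flow='source_to_target' means messages propagate from source nodes to target nodes.
        # Linear transformation applied to the node features.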
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
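        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform the node feature matrix.
        # self.lin computes x @ W.T + b, which changes the feature dimension;
        # the weight W is randomly initialized and learned during training.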
        x = self.lin(x)

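        # Step 3: Compute the normalization coefficients.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # norm has shape [E]: one coefficient 1/sqrt(deg(row) * deg(col)) per edge.
        # It is multiplied with the source node features x_j later, in message().
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Call propagate to start message passing.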
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: Normalize x_j.
        return norm.view(-1, 1) * x_j
           

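With the above, we have learned how to build a graph neural network that contains a single message passing step.

# Reduce the 1433-dimensional node features of the Cora dataset to 64 dimensions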
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)
#torch.Size([2708, 64])
           

Forward-pass demo

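# Random seed
torch.manual_seed(0)

# Define the edges
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# Define the node features; each node has a 2-dimensional feature vector
x = torch.tensor([[-1, 2], [0, 4], [1, 5]], dtype=torch.float)

# Create one GCN layer that reduces the feature dimension from 2 to 1
conv = GCNConv(2, 1)

# Forward pass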
x = conv(x, edge_index)
print(x)
print(conv.lin.weight)

           
tensor([[0.4728],
        [0.9206],
        [1.0365]], grad_fn=<ScatterAddBackward>)
Parameter containing:
tensor([[-0.0053,  0.3793]], requires_grad=True)
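
The Step 3 coefficients for this three-node demo graph can also be worked out by hand. The sketch below redoes that arithmetic in plain Python, without PyTorch, so the degrees and per-edge coefficients can be inspected directly. The variable names are chosen for this illustration; the layer's bias is not printed above, so the sketch stops at the coefficients and does not try to reproduce the output tensor.

```python
import math

# (source, target) pairs taken from the demo's edge_index
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
num_nodes = 3

# Step 1: add one self-loop per node
edges_with_loops = edges + [(n, n) for n in range(num_nodes)]

# Step 3: in-degree of every node, counted over the target end of each edge
deg = [0] * num_nodes
for _, target in edges_with_loops:
    deg[target] += 1

# One coefficient per edge: 1 / sqrt(deg(source) * deg(target))
norm = {(s, t): 1.0 / math.sqrt(deg[s] * deg[t]) for s, t in edges_with_loops}

print(deg)
for edge, coeff in norm.items():
    print(edge, round(coeff, 4))
```

With self-loops included, the middle node has degree 3 and the two outer nodes have degree 2, so an edge between an outer node and the middle node is weighted by 1/sqrt(6).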
           

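1.3 Dissecting the MessagePassing base class

In the __init__() method, we see that the program checks whether the subclass implements the message_and_aggregate() method and assigns the result of this check to the fuse attribute.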

class MessagePassing(torch.nn.Module):
    def __init__(self, aggr: Optional[str] = "add", flow: str = "source_to_target", node_dim: int = -2):
        super(MessagePassing, self).__init__()
        # ... (n lines omitted)
        # Support for "fused" message passing.
        self.fuse = self.inspector.implements('message_and_aggregate')
        # ... (n lines omitted)
           

Message passing starts when the propagate method is called.

class MessagePassing(torch.nn.Module):
    # ... (n lines omitted)
    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        # ... (n lines omitted)
        # Run "fused" message and aggregation (if applicable).
        if (isinstance(edge_index, SparseTensor) and self.fuse and not self.__explain__):
            coll_dict = self.__collect__(self.__fused_user_args__, edge_index, size, kwargs)

            msg_aggr_kwargs = self.inspector.distribute('message_and_aggregate', coll_dict)
            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)
        # Otherwise, run both functions in separation.
        elif isinstance(edge_index, Tensor) or not self.fuse:
            coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)

            msg_kwargs = self.inspector.distribute('message', coll_dict)
            out = self.message(**msg_kwargs)
            # ... (n lines omitted)
            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
            out = self.aggregate(out, **aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)
           

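Parameters:

  • edge_index: the indices of the edge endpoints; it can be a Tensor or a SparseTensor.
    • When flow="source_to_target", information from node edge_index[0] is passed to node edge_index[1].
    • When flow="target_to_source", information from node edge_index[1] is passed to node edge_index[0].
  • size: the number of neighboring nodes and the number of central nodes.
    • For an ordinary graph, both numbers are N, so we can omit the size argument and let it take its default value None.
    • For a bipartite graph with M neighboring nodes and N central nodes, we need to pass the tuple (M, N) as size.
  • kwargs: other graph attributes or additional data.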
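
The effect of flow and size can be illustrated with plain lists. This is a sketch under stated assumptions: `sender_receiver_pairs` and `sum_from_neighbors` are hypothetical helpers written for this example, edge_index is a plain two-row list rather than a tensor, and aggregation is a simple sum.

```python
from typing import List, Tuple


def sender_receiver_pairs(edge_index: List[List[int]], flow: str = "source_to_target") -> List[Tuple[int, int]]:
    """Return (sender, receiver) pairs according to the flow direction."""
    if flow == "source_to_target":
        senders, receivers = edge_index[0], edge_index[1]
    elif flow == "target_to_source":
        senders, receivers = edge_index[1], edge_index[0]
    else:
        raise ValueError(f"unknown flow: {flow}")
    return list(zip(senders, receivers))


def sum_from_neighbors(x_senders: List[float], edge_index: List[List[int]], num_receivers: int,
                       flow: str = "source_to_target") -> List[float]:
    """Sum the sender values at each receiver; the output has one entry per receiver."""
    out = [0.0] * num_receivers
    for s, r in sender_receiver_pairs(edge_index, flow):
        out[r] += x_senders[s]
    return out


# Ordinary graph: the same N nodes act as senders and receivers.
edge_index = [[0, 0, 1], [1, 2, 2]]
forward_pairs = sender_receiver_pairs(edge_index, "source_to_target")
backward_pairs = sender_receiver_pairs(edge_index, "target_to_source")

# Bipartite graph: M = 2 sender nodes, N = 3 receiver nodes, i.e. size = (M, N).
bipartite_edge_index = [[0, 1, 1], [0, 1, 2]]
bipartite_out = sum_from_neighbors([10.0, 1.0], bipartite_edge_index, num_receivers=3)

print(forward_pairs, backward_pairs, bipartite_out)
```

In the bipartite case the sender features have M entries while the output has N entries, which is why the two sizes have to be given separately.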

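The propagate() method first checks whether edge_index is a SparseTensor and whether the subclass implements message_and_aggregate(). If both hold, it runs the subclass's message_and_aggregate() method; otherwise it runs the subclass's message() and aggregate() methods in turn. In both cases update() is called last.

1.4 Overriding the message method

As mentioned earlier, any node attribute passed to propagate() can be split into a part belonging to the central nodes and a part belonging to the neighboring nodes, simply by appending _i or _j to the variable name. Now suppose we have an additional node attribute, the node degree deg, and we want the message method to also receive the degree of the central node. We modify the message method of the previous GCNConv to obtain a new GCNConv class.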
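
The branching can be mimicked with a toy class that only records which methods run. Everything here is hypothetical scaffolding for illustration: `Base`, `Plain`, and `Fused` are not PyG classes, and the override check via `type(self)` is a stand-in for the inspector-based check shown in the excerpt above.

```python
from typing import List


class Base:
    def __init__(self) -> None:
        # Stand-in for the `fuse` attribute: did the subclass override the fused hook?
        self.fuse = type(self).message_and_aggregate is not Base.message_and_aggregate
        self.calls: List[str] = []

    def propagate(self, edge_index_is_sparse: bool) -> List[str]:
        if edge_index_is_sparse and self.fuse:
            self.message_and_aggregate()
        else:
            self.message()
            self.aggregate()
        self.update()  # update runs on both paths
        return self.calls

    def message(self) -> None:
        self.calls.append("message")

    def aggregate(self) -> None:
        self.calls.append("aggregate")

    def message_and_aggregate(self) -> None:
        raise NotImplementedError

    def update(self) -> None:
        self.calls.append("update")


class Plain(Base):
    pass  # does not implement the fused hook


class Fused(Base):
    def message_and_aggregate(self) -> None:
        self.calls.append("message_and_aggregate")


fused_sparse = Fused().propagate(edge_index_is_sparse=True)
fused_dense = Fused().propagate(edge_index_is_sparse=False)
plain_sparse = Plain().propagate(edge_index_is_sparse=True)
print(fused_sparse, fused_dense, plain_sparse)
```

Only the combination of a sparse adjacency input and an overridden fused method takes the fused path; every other combination falls back to the separate message and aggregate calls.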

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' means messages propagate from source nodes to target nodes.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm, deg_i):
        # x_j has shape [E, out_channels]
        # deg_i has shape [E, 1]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j * deg_i


from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]

net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)
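
What the _i and _j suffixes select can be sketched with plain indexing. The helper `lift` below is a hypothetical name for this illustration, and the degrees are made-up example values; the sketch assumes the convention described above, where with flow="source_to_target" the _j part is gathered along edge_index[0] and the _i part along edge_index[1].

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def lift(values: Sequence[T], edge_index: List[List[int]], suffix: str,
         flow: str = "source_to_target") -> List[T]:
    """Turn a per-node sequence into a per-edge list for the `_i` or `_j` end."""
    if suffix not in ("_i", "_j"):
        raise ValueError(f"unknown suffix: {suffix}")
    # With source_to_target, `_j` is the source row (0) and `_i` is the target row (1).
    row = 1 if suffix == "_i" else 0
    if flow == "target_to_source":
        row = 1 - row
    return [values[n] for n in edge_index[row]]


edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
x = [[-1.0, 2.0], [0.0, 4.0], [1.0, 5.0]]  # per-node features, shape [N, 2]
deg = [1.0, 2.0, 1.0]                      # per-node degrees, shape [N]

x_j = lift(x, edge_index, "_j")      # shape [E, 2]: features of each edge's source node
deg_i = lift(deg, edge_index, "_i")  # shape [E]: degree of each edge's target node
print(x_j)
print(deg_i)
```

Both results have one entry per edge, which is why message() can combine x_j, norm, and deg_i elementwise.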
           

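1.5 Overriding the aggregate method

Building on the previous example, we add the aggregate method below. The output shows that our overridden aggregate method is called, and that the value passed to the aggr parameter in super(GCNConv, self).__init__(aggr='add') is stored in the self.aggr attribute.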

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' means messages propagate from source nodes to target nodes.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm, deg_i):
        # x_j has shape [E, out_channels]
        # deg_i has shape [E, 1]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j * deg_i

    def aggregate(self, inputs, index, ptr, dim_size):
        print('self.aggr:', self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]

net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)
           

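1.6 Overriding the message_and_aggregate method

In some cases, computing the messages and aggregating them can be fused together. For such cases we can override the message_and_aggregate method and implement both in one place, which makes the program run more efficiently.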

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_sparse import SparseTensor

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' means messages propagate from source nodes to target nodes.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
        # Here we no longer pass edge_index but the adjacency matrix as a SparseTensor.
        return self.propagate(adjmat, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm, deg_i):
        # x_j has shape [E, out_channels]
        # deg_i has shape [E, 1]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j * deg_i

    def aggregate(self, inputs, index, ptr, dim_size):
        print('self.aggr:', self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

    def message_and_aggregate(self, adj_t, x, norm):
        print('`message_and_aggregate` is called')
        # No actual message passing or aggregation is implemented here, so this returns None.
 
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
# print(h_nodes.shape)
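
The placeholder above only prints a line. As a rough idea of what a real fused implementation computes, the pure-Python sketch below compares the separate path (build every per-edge message, then sum) with a fused path that treats the weighted adjacency as a sparse matrix and multiplies it with the feature matrix in one pass. The functions `separate` and `fused`, the edge weights, and the list-of-rows sparse format are assumptions made for this illustration, not the torch_sparse API.

```python
from typing import List, Tuple

Vector = List[float]


def separate(edges: List[Tuple[int, int]], weights: List[float], x: List[Vector]) -> List[Vector]:
    """message() then aggregate(): materialize one message per edge, then sum per target."""
    dim = len(x[0])
    messages = [[w * v for v in x[j]] for (j, _), w in zip(edges, weights)]  # shape [E, dim]
    out = [[0.0] * dim for _ in x]
    for (_, i), msg in zip(edges, messages):
        out[i] = [a + m for a, m in zip(out[i], msg)]
    return out


def fused(adj_rows: List[List[Tuple[int, float]]], x: List[Vector]) -> List[Vector]:
    """message_and_aggregate(): compute A @ x row by row, without the [E, dim] intermediate."""
    dim = len(x[0])
    return [[sum(w * x[j][d] for j, w in row) for d in range(dim)] for row in adj_rows]


edges = [(0, 1), (1, 0), (1, 2), (2, 1)]  # (source j, target i)
weights = [0.5, 0.5, 0.25, 0.25]          # one weight per edge, e.g. a normalization coefficient
x = [[1.0, 2.0], [4.0, 8.0], [2.0, 6.0]]

# Row i of the sparse matrix lists the (source, weight) entries of all edges that end in node i.
adj_rows: List[List[Tuple[int, float]]] = [[] for _ in x]
for (j, i), w in zip(edges, weights):
    adj_rows[i].append((j, w))

out_separate = separate(edges, weights, x)
out_fused = fused(adj_rows, x)
print(out_separate)
print(out_fused)
```

Both paths give the same node outputs; the fused one avoids storing a message for every edge, which is where the efficiency gain comes from.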
           

1.7 Overriding the update method

from torch_geometric.datasets import Planetoid
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_sparse import SparseTensor


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' means messages propagate from source nodes to target nodes.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
        # Here we no longer pass edge_index but the adjacency matrix as a SparseTensor.
        return self.propagate(adjmat, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm, deg_i):
        # x_j has shape [E, out_channels]
        # deg_i has shape [E, 1]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j * deg_i

    def aggregate(self, inputs, index, ptr, dim_size):
        print('self.aggr:', self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

    def message_and_aggregate(self, adj_t, x, norm):
        print('`message_and_aggregate` is called')
        # No actual message passing or aggregation is implemented here, so this returns None.

    def update(self, inputs, deg):
        print(deg)
        return inputs


dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]

net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
# print(h_nodes.shape)
           

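The update method receives the aggregated output as its first argument and can additionally receive any argument passed to propagate. In the code above, our overridden update method receives the aggregated output as its first argument, plus the deg argument that was passed to propagate.

Homework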
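
One way to picture how update can pick up deg by name is signature inspection. The sketch below is an assumption-laden illustration, not PyG's inspector: `distribute` is a simplified helper written here with the standard inspect module, and `update` is a free function whose first parameter is the aggregated output.

```python
import inspect
from typing import Any, Callable, Dict, List


def distribute(func: Callable[..., Any], collected: Dict[str, Any]) -> Dict[str, Any]:
    """Select from `collected` only the keyword arguments that `func` declares."""
    names = list(inspect.signature(func).parameters)[1:]  # skip the first (aggregated) input
    return {name: collected[name] for name in names if name in collected}


def update(inputs: List[float], deg: List[float]) -> List[float]:
    # Example update: divide each aggregated value by the node degree.
    return [value / d for value, d in zip(inputs, deg)]


# Everything that was passed to propagate() as keyword arguments (example values).
collected = {"x": [1.0, 2.0], "norm": [0.5, 0.5], "deg": [2.0, 4.0]}
aggregated = [6.0, 8.0]

update_kwargs = distribute(update, collected)
result = update(aggregated, **update_kwargs)
print(update_kwargs)
print(result)
```

Because selection is by parameter name, adding another parameter such as norm to update would make it available there as well, without changing the propagate call.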

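  1. Summarize the execution flow of the MessagePassing class and the conventions for subclassing MessagePassing.
  2. Subclass MessagePassing to define the following graph neural network classes, and test them:
    1. First class: override the message function so that it receives the source node attribute x and the target node degree d.
    2. Second class: building on the first class, also override the aggregate function, without calling the super class's aggregate function and without directly copying its contents.
    3. Third class: building on the second class, also override the update function so that it applies one linear transformation to the node information.
    4. Fourth class: building on the third class, also override the message_and_aggregate function so that this single function implements the functionality of the preceding message and aggregate functions.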
from torch_geometric.datasets import Planetoid
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_sparse import SparseTensor


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
        # Here we no longer pass edge_index but the adjacency matrix as a SparseTensor.
        return self.propagate(adjmat, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm, deg_i):
        # x_j has shape [E, out_channels]
        # deg_i has shape [E, 1]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j * deg_i

    def aggregate(self, inputs, index, ptr, dim_size):
        print('self.aggr:', self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

    def message_and_aggregate(self, adj_t, x, norm):
        print('`message_and_aggregate` is called')
        # No actual message passing or aggregation is implemented here, so this returns None.

    def update(self, inputs, deg):
        print(deg)
        return inputs


dataset = Planetoid(root='./codes/data/Planetoid/', name='Cora')
data = dataset[0]

net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
           

Reference for understanding the full GCN algorithm flow:

https://blog.csdn.net/qq_41987033/article/details/103377561
