一、引言
結(jié)點(diǎn)表征的生成是圖結(jié)點(diǎn)預(yù)測和邊預(yù)測任務(wù)成功的關(guān)鍵。
基于圖神經(jīng)網(wǎng)絡(luò)的結(jié)點(diǎn)表征學(xué)習(xí)可以理解為對(duì)圖神經(jīng)網(wǎng)絡(luò)進(jìn)行基于監(jiān)督學(xué)習(xí)的訓(xùn)練,使得圖神經(jīng)網(wǎng)絡(luò)學(xué)會(huì)產(chǎn)生高質(zhì)量的結(jié)點(diǎn)表征。
在結(jié)點(diǎn)預(yù)測任務(wù)中,一個(gè)圖,圖上有很多節(jié)點(diǎn),部分節(jié)點(diǎn)的標(biāo)簽已知,剩余節(jié)點(diǎn)的標(biāo)簽未知。將節(jié)點(diǎn)的屬性(x)、邊的端點(diǎn)信息(edge_index)、邊的屬性(edge_attr)輸入到多層圖神經(jīng)網(wǎng)絡(luò),經(jīng)過圖神經(jīng)網(wǎng)絡(luò)每一層的一次結(jié)點(diǎn)間信息傳遞,圖神經(jīng)網(wǎng)絡(luò)為結(jié)點(diǎn)生成結(jié)點(diǎn)表征。
任務(wù)為:根據(jù)結(jié)點(diǎn)的屬性(可以是類別型、也可以是數(shù)值型)、邊的信息、邊的屬性、已知的結(jié)點(diǎn)預(yù)測標(biāo)簽,對(duì)未知標(biāo)簽的結(jié)點(diǎn)做預(yù)測。
具體舉例:以Cora數(shù)據(jù)集為例進(jìn)行說明,Cora是一個(gè)論文引用網(wǎng)絡(luò),結(jié)點(diǎn)代表論文,如果兩篇論文存在引用關(guān)系,那么認(rèn)為對(duì)應(yīng)的兩個(gè)結(jié)點(diǎn)之間存在邊,每個(gè)結(jié)點(diǎn)由一個(gè)1433維的詞包特征向量描述。任務(wù)是推斷每個(gè)文檔的類別(共7類)。
通過結(jié)點(diǎn)分類任務(wù)來比較MLP和GCN, GAT三者的節(jié)點(diǎn)表征學(xué)習(xí)能力。
二、準(zhǔn)備工作
獲取并分析數(shù)據(jù)集
實(shí)現(xiàn)代碼如下:
from torch_geometric.datasetsimport Planetoid
from torch_geometric.transformsimport NormalizeFeatures
dataset=Planetoid(root='data/Planetoid',name='Cora',transform=NormalizeFeatures())
print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0]# Get the first graph object.
print()
print(data)
print('======================')
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
運(yùn)行結(jié)果如下:

根據(jù)運(yùn)行結(jié)果可以得出:Cora圖擁有2,708個(gè)結(jié)點(diǎn)和10,556條邊,平均結(jié)點(diǎn)度為3.9,共140個(gè)有真實(shí)標(biāo)簽的節(jié)點(diǎn)(每類20個(gè))用于訓(xùn)練,有標(biāo)簽的結(jié)點(diǎn)的比例占5%。進(jìn)一步可以看到,這個(gè)圖是無向圖,不存在孤立的節(jié)、結(jié)點(diǎn)(即每個(gè)文檔至少有一個(gè)引文)。
數(shù)據(jù)轉(zhuǎn)換在將數(shù)據(jù)輸入到神經(jīng)網(wǎng)絡(luò)之前修改數(shù)據(jù),這一功能可用于實(shí)現(xiàn)數(shù)據(jù)規(guī)范化或數(shù)據(jù)增強(qiáng),使用NormalizeFeatures()方法。
可視化結(jié)點(diǎn)表征分布的方法
為了實(shí)現(xiàn)結(jié)點(diǎn)表征分布的可視化,我們先利用TANE將高維結(jié)點(diǎn)表征嵌入到二維平面空間,然后在二維平面空間畫出節(jié)點(diǎn)。
import matplotlib.pyplotas plt
from sklearn.manifoldimport TSNE
def visualize(h, color):
? ? ? ? ? ? z = TSNE(n_components=2).fit_transform(h.output.detach().cpu().numpy())
? ? ? ? ? ? plt.figure(figsize=(10,10))
? ? ? ? ? ? plt.xticks([])
? ? ? ? ? ? plt.yticks([])
? ? ? ? ? ? plt.scatter(z[:,0], z[:,1],s=70,c=color,cmap="Set2")
? ? ? ? ? ? plt.show()
三、MLP在圖結(jié)點(diǎn)分類中的應(yīng)用
構(gòu)建一個(gè)簡單的MLP,該網(wǎng)絡(luò)只對(duì)輸入結(jié)點(diǎn)的特征進(jìn)行操作,在所有節(jié)點(diǎn)之間共享權(quán)重。
MLP圖結(jié)點(diǎn)分類器:
import torch
from torch.nnimport Linear
import torch.nn.functionalas F
from torch_geometric.datasetsimport Planetoid
from torch_geometric.transformsimport NormalizeFeatures
dataset=Planetoid(root='data/Planetoid',name='Cora',transform=NormalizeFeatures())
data = dataset[0]
class MLP(torch.nn.Module):
? ? ? ? ? ? ?def __init__(self, hidden_channels):
? ? ? ? ? ? ? ? ? ? ?super(MLP,self).__init__()
? ? ? ? ? ? ? ? ? ? ?torch.manual_seed(12345)
? ? ? ? ? ? ? ? ? ? ?self.lin1 = Linear(dataset.num_features, hidden_channels)
? ? ? ? ? ? ? ? ? ? ?self.lin2 = Linear(hidden_channels, dataset.num_classes)
? ? ? ? ? ? def forward(self, x):
? ? ? ? ? ? ? ? ? ? ? x =self.lin1(x)
? ? ? ? ? ? ? ? ? ? ? x = x.relu()
? ? ? ? ? ? ? ? ? ? ? ?x = F.dropout(x,p=0.5,training=self.training)
? ? ? ? ? ? ? ? ? ? ? ?x =self.lin2(x)
? ? ? ? ? ? ? ? ? ? ? ?return x
model = MLP(hidden_channels=16)
print(model)
運(yùn)行結(jié)果如下:

根據(jù)運(yùn)行結(jié)果可以得出:MLP由兩個(gè)線程層、一個(gè)ReLU非線性層和一個(gè)dropout操作組成。第一線程層將1433維的特征向量嵌入(embedding)到低維空間中(hidden_channels=16),第二個(gè)線性層將節(jié)點(diǎn)表征嵌入到類別空間中(num_classes=7)。
利用交叉熵?fù)p失和Adam優(yōu)化器來訓(xùn)練MLP網(wǎng)絡(luò)
model= MLP(hidden_channels=16)
criterion = torch.nn.CrossEntropyLoss()# Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)
def train():
? ? ? ? ? ? model.train()
? ? ? ? ? ? optimizer.zero_grad()# Clear gradients.
? ? ? ? ? ? ?out =model(data.x)# Perform a single forward pass.
? ? ? ? ? ? ?loss = criterion(out[data.train_mask], data.y[data.train_mask])
? ? ? ? ? ? ?# Compute the loss solely based on the training nodes.
? ? ? ? ? ? ?loss.backward()# Derive gradients.
? ? ? ? ? ? ?optimizer.step()# Update parameters based on gradients.
? ? ? ? ? ? ? return loss
for epochin range(1,201):
? ? ? ? ? ? ?loss = train()
? ? ? ? ? ? ?print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
運(yùn)行結(jié)果如下:

MLP測試
測試這個(gè)MLP神經(jīng)網(wǎng)絡(luò)在測試集上的表現(xiàn)

MLP只有大約59%的測試準(zhǔn)確性。不準(zhǔn)確的一個(gè)重要原因是,用于訓(xùn)練此神經(jīng)網(wǎng)絡(luò)的有標(biāo)簽結(jié)點(diǎn)數(shù)量過少,此神經(jīng)網(wǎng)絡(luò)被過擬合,對(duì)未見過的節(jié)點(diǎn)泛化性很差。
四、GCN及其在圖節(jié)點(diǎn)分類任務(wù)中的應(yīng)用
GCN的定義
GCN神經(jīng)網(wǎng)絡(luò)的數(shù)學(xué)的定義為:
其中表示插入自環(huán)的鄰接矩陣,
表示其對(duì)角線度矩陣。鄰接矩陣可以包括不為1的值,當(dāng)鄰接矩陣不為{0,1}值時(shí),表示鄰接矩陣存儲(chǔ)的是邊的權(quán)重。
為對(duì)稱歸一化矩陣。
PyG中GCNConv模塊說明
GCNConv構(gòu)造函數(shù)接口:
GCNConv(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)
in_channels:輸入數(shù)據(jù)維度;
out_channels:輸出數(shù)據(jù)維度;
improved:如果為true,其目的在于增強(qiáng)中心節(jié)點(diǎn)自身信息;
cached:是否存儲(chǔ)的計(jì)算結(jié)果以便后續(xù)使用,這個(gè)參數(shù)只應(yīng)在歸納學(xué)習(xí)的場景中設(shè)置為true;
add_self_loops:是否在鄰接矩陣中增加自環(huán)邊;
normalize:是否添加自環(huán)邊并在運(yùn)行中計(jì)算對(duì)稱歸一化系數(shù);
bias:是否包含偏置項(xiàng)。
基于GCN圖神經(jīng)網(wǎng)絡(luò)的圖節(jié)點(diǎn)分類
通過將torch.nn.Linear layers 替換為PyG的GNN Conv Layers,可以將MLP模型轉(zhuǎn)化為GNN模型。
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
? ? ? ? ?def __init__(self, hidden_channels):
? ? ? ? ? ? ? ? ? super(GCN, self).__init__()
? ? ? ? ? ? ? ? ? torch.manual_seed(12345)
? ? ? ? ? ? ? ? ? ?self.conv1 = GCNConv(dataset.num_features, hidden_channels)
? ? ? ? ? ? ? ? ? self.conv2 = GCNConv(hidden_channels, dataset.num_classes)
? ? ? ?def forward(self, x, edge_index):
? ? ? ? ? ? ? ? ? ? x = self.conv1(x, edge_index)
? ? ? ? ? ? ? ? ? ? x = x.relu()
? ? ? ? ? ? ? ? ? ? x = F.dropout(x, p=0.5, training=self.training)
? ? ? ? ? ? ? ? ? ? ?x = self.conv2(x, edge_index)
? ? ? ? ? ? ? ? ? ? ? return x
model = GCN(hidden_channels=16)
print(model)
運(yùn)行結(jié)果如下:

可視化未訓(xùn)練的GCN網(wǎng)絡(luò)的結(jié)點(diǎn)表征
model = GCN(hidden_channels=16)
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
運(yùn)行結(jié)果如下:

根據(jù)運(yùn)行結(jié)果可以得出:7維特征的結(jié)點(diǎn)被嵌入到2維的平面上,存在同類結(jié)點(diǎn)聚集的情況。
訓(xùn)練GCN結(jié)點(diǎn)分類器
model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
? ? ? ? ? ?model.train()
? ? ? ? ? ?optimizer.zero_grad()? # Clear gradients.
? ? ? ? ? ?out = model(data.x, data.edge_index)? # Perform a single forward pass.
? ? ? ? ? ?loss = criterion(out[data.train_mask], data.y[data.train_mask])?
? ? ? ? ? ? # Compute the loss solely based on the training nodes.
? ? ? ? ? ? loss.backward()? # Derive gradients.
? ? ? ? ? ? optimizer.step()? # Update parameters based on gradients.
? ? ? ? ? ? return loss
for epoch in range(1, 201):
? ? ? ? ? ? loss = train()
? ? ? ? ? ? ?print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
運(yùn)行結(jié)果如下:

訓(xùn)練過程結(jié)束后,檢測GCN結(jié)點(diǎn)分類器在測試集上的準(zhǔn)確性
def test():
? ? ? ? ? ? model.eval()
? ? ? ? ? ? out = model(data.x, data.edge_index)
? ? ? ? ? ? pred = out.argmax(dim=1)? # Use the class with highest probability.
? ? ? ? ? ? test_correct = pred[data.test_mask] == data.y[data.test_mask]?
? ? ? ? ? ? # Check against ground-truth labels.
? ? ? ? ? ? ?test_acc = int(test_correct.sum()) / int(data.test_mask.sum())?
? ? ? ? ? ? ?# Derive ratio of correct predictions.
? ? ? ? ? ? ?return test_acc
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
運(yùn)行結(jié)果如下:

通過將線性層替換成GCN層,測試準(zhǔn)確率可以達(dá)到81.4%,高于MLP分類器。表明結(jié)點(diǎn)的鄰接信息在取得更好的準(zhǔn)確率方面起著關(guān)鍵作用。
可視化訓(xùn)練過的GCN模型
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
運(yùn)行結(jié)果如下:

五、GAT及其在圖節(jié)點(diǎn)分類任務(wù)中的應(yīng)用
GAT的定義
圖注意網(wǎng)絡(luò)的數(shù)學(xué)定義為:
注意力系數(shù)的計(jì)算方法為:
PyG中GATConv 模塊說明
GATConv構(gòu)造函數(shù)接口:
GATConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, bias: bool = True, **kwargs)
in_channels:輸入數(shù)據(jù)維度;
out_channels:輸出數(shù)據(jù)維度;
heads:在GATConv使用多少個(gè)注意力模型;
concat:如為true,不同注意力模型得到的結(jié)點(diǎn)表征被拼接到一起(表征維度翻倍),否則對(duì)不同注意力模型得到的結(jié)點(diǎn)表征求均值。
基于GAT圖神經(jīng)網(wǎng)絡(luò)的圖結(jié)點(diǎn)分類
將MLP例子中的linear層替換為GATConv層,來實(shí)現(xiàn)基于GAT的圖結(jié)點(diǎn)分類神經(jīng)網(wǎng)絡(luò)。
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
? ? ? ? def __init__(self, hidden_channels):
? ? ? ? ? ? ? ? ?super(GAT, self).__init__()
? ? ? ? ? ? ? ? ?torch.manual_seed(12345)
? ? ? ? ? ? ? ? ?self.conv1 = GATConv(dataset.num_features, hidden_channels)
? ? ? ? ? ? ? ? ?self.conv2 = GATConv(hidden_channels, dataset.num_classes)
? ? ? ?def forward(self, x, edge_index):
? ? ? ? ? ? ? ? ?x = self.conv1(x, edge_index)
? ? ? ? ? ? ? ? x = x.relu()
? ? ? ? ? ? ? ? x = F.dropout(x, p=0.5, training=self.training)
? ? ? ? ? ? ? ? x = self.conv2(x, edge_index)
? ? ? ? ? ? ? ? return x
model = GAT(hidden_channels=16)
print(model)
運(yùn)行結(jié)果如下:

可視化未訓(xùn)練的GAT網(wǎng)絡(luò)的結(jié)點(diǎn)表征
model = GAT(hidden_channels=16)
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
運(yùn)行結(jié)果如下:

訓(xùn)練GAT結(jié)點(diǎn)分類器
model = GAT(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
? ? ? ? ? ?model.train()
? ? ? ? ? ?optimizer.zero_grad()? # Clear gradients.
? ? ? ? ? ?out = model(data.x, data.edge_index)? # Perform a single forward pass.
? ? ? ? ? ?loss = criterion(out[data.train_mask], data.y[data.train_mask])?
? ? ? ? ? ? # Compute the loss solely based on the training nodes.
? ? ? ? ? ? loss.backward()? # Derive gradients.
? ? ? ? ? ? optimizer.step()? # Update parameters based on gradients.
? ? ? ? ? ? return loss
for epoch in range(1, 201):
? ? ? ? ? ? loss = train()
? ? ? ? ? ? ?print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
運(yùn)行結(jié)果如下:

訓(xùn)練過程結(jié)束后,檢測GAT結(jié)點(diǎn)分類器在測試集上的準(zhǔn)確性
def test():
? ? ? ? ? ? model.eval()
? ? ? ? ? ? out = model(data.x, data.edge_index)
? ? ? ? ? ? pred = out.argmax(dim=1)? # Use the class with highest probability.
? ? ? ? ? ? test_correct = pred[data.test_mask] == data.y[data.test_mask]?
? ? ? ? ? ? # Check against ground-truth labels.
? ? ? ? ? ? ?test_acc = int(test_correct.sum()) / int(data.test_mask.sum())?
? ? ? ? ? ? ?# Derive ratio of correct predictions.
? ? ? ? ? ? ?return test_acc
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
運(yùn)行結(jié)果如下:

通過將線性層替換成GATConv層,測試準(zhǔn)確率可以達(dá)到73.8%,高于MLP分類器。
可視化訓(xùn)練過的GAT模型
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
運(yùn)行結(jié)果如下:

六、MLP、GCN、GAT結(jié)點(diǎn)分類器的對(duì)比
在結(jié)點(diǎn)表征的學(xué)習(xí)中,MLP節(jié)點(diǎn)分類器只考慮了結(jié)點(diǎn)自身屬性,忽略了結(jié)點(diǎn)之間的連接關(guān)系,它的結(jié)果是最差的;而GCN與GAT節(jié)點(diǎn)分類器,同時(shí)考慮了結(jié)點(diǎn)自身屬性與周圍鄰居結(jié)點(diǎn)的屬性,它們的結(jié)果優(yōu)于MLP節(jié)點(diǎn)分類器。
從中可以看出鄰居結(jié)點(diǎn)的信息對(duì)于結(jié)點(diǎn)分類任務(wù)的重要性。
基于圖神經(jīng)網(wǎng)絡(luò)的結(jié)點(diǎn)表征的學(xué)習(xí)遵循消息傳遞范式:
在鄰居結(jié)點(diǎn)信息變換階段,GCN與GAT都對(duì)鄰居結(jié)點(diǎn)做歸一化和線性變換;
在鄰居結(jié)點(diǎn)信息聚合階段都將變換后的鄰居結(jié)點(diǎn)信息做求和聚合;
在中心結(jié)點(diǎn)信息變換階段只是簡單返回鄰居結(jié)點(diǎn)信息聚合階段的聚合結(jié)果。
GCN與GAT的區(qū)別在于鄰居結(jié)點(diǎn)信息聚合過程中的歸一化方法不同:
1. GCN根據(jù)中心結(jié)點(diǎn)與鄰居結(jié)點(diǎn)的度計(jì)算歸一化系數(shù),后者根據(jù)中心結(jié)點(diǎn)與鄰居結(jié)點(diǎn)的相似度計(jì)算歸一化系數(shù)。
2. GCN的歸一化方式依賴于圖的拓?fù)浣Y(jié)構(gòu),不同結(jié)點(diǎn)其自身的度不同、其鄰居的度也不同,在一些應(yīng)用中可能會(huì)影響泛化能力。GAT的歸一化方式依賴于中心結(jié)點(diǎn)與鄰居結(jié)點(diǎn)的相似度,相似度是訓(xùn)練得到的,不受圖的拓?fù)浣Y(jié)構(gòu)的影響,在不同的任務(wù)中都會(huì)有較好的泛化表現(xiàn)。
DataWhale開源學(xué)習(xí)資料:
https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN