GNN Learning, Day 9

First, thanks to Datawhale for the GNN course; it is excellent.
Reference: GNN/Markdown版本/6-1-數據完整存于內存的數據集類.md

Task04: In-memory dataset classes + node prediction and link prediction practice

1 Knowledge Review

1.1 The general process of using a dataset

1. Download the raw data files from the network.

2. Process the raw files, generating one **Data object** per graph sample.

3. Apply data processing to each Data object, converting it into a new Data object.

4. Filter the Data objects.

5. Save the Data objects to a file.

6. Retrieve Data objects. Every time a Data object is retrieved, a data transform is applied to it first, so what you get back is the transformed Data object.
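The six steps above can be sketched without PyG. The following is a toy illustration in plain Python, not the `InMemoryDataset` implementation: `ToyInMemoryDataset`, its `raw_samples` argument, and the dict-based samples are hypothetical names chosen for this sketch. Only the roles of `pre_transform`, `pre_filter`, and `transform` mirror the steps listed, and they are applied in the order listed above.

```python
import os
import pickle
import tempfile


class ToyInMemoryDataset:
    """Toy stand-in for an in-memory dataset (hypothetical, not PyG's class)."""

    def __init__(self, root, raw_samples, transform=None,
                 pre_transform=None, pre_filter=None):
        self.transform = transform
        self.processed_path = os.path.join(root, "data.pkl")
        if not os.path.exists(self.processed_path):
            # Steps 1-2: "download" the raw data and build one sample per graph.
            samples = [dict(s) for s in raw_samples]
            # Step 3: one-off processing before saving.
            if pre_transform is not None:
                samples = [pre_transform(s) for s in samples]
            # Step 4: drop samples that do not pass the filter.
            if pre_filter is not None:
                samples = [s for s in samples if pre_filter(s)]
            # Step 5: save the processed samples to a file.
            os.makedirs(root, exist_ok=True)
            with open(self.processed_path, "wb") as f:
                pickle.dump(samples, f)
        # Load the whole processed dataset into memory.
        with open(self.processed_path, "rb") as f:
            self.samples = pickle.load(f)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Step 6: transform a copy on every access; the stored sample is untouched.
        sample = dict(self.samples[idx])
        return sample if self.transform is None else self.transform(sample)


def add_edge_count(sample):
    sample["num_edges"] = len(sample["edges"])
    return sample


def scale_features(sample):
    total = sum(sample["x"]) or 1.0
    sample["x"] = [v / total for v in sample["x"]]
    return sample


raw = [
    {"x": [1.0, 3.0], "edges": [(0, 1)]},
    {"x": [2.0, 2.0], "edges": []},
]
with tempfile.TemporaryDirectory() as root:
    toy_dataset = ToyInMemoryDataset(
        root, raw,
        transform=scale_features,
        pre_transform=add_edge_count,
        pre_filter=lambda s: s["num_edges"] > 0,
    )
    first = toy_dataset[0]

print(len(toy_dataset), first)
```

The second sample has no edges, so the filter removes it before saving; the first sample comes back with its features rescaled on access while the copy held in memory keeps the original values.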

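1.2 Link prediction task

Idea: generate negative samples (node pairs that are not connected) so that the numbers of positive and negative samples are balanced.

Use the train_test_split_edges function to split the positive edges into training, validation, and test sets and to sample negative edges for evaluation. In the code of section 2.3, the negative edges for training are drawn again in every epoch with negative_sampling, in the same number as the positive training edges.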
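To make the balancing idea concrete, here is a minimal sketch in plain Python. It is not how `train_test_split_edges` or `negative_sampling` are implemented in PyG; `sample_negative_edges` is a hypothetical helper that draws as many unconnected node pairs as there are positive edges, treating the graph as undirected.

```python
import random


def sample_negative_edges(pos_edges, num_nodes, seed=0):
    """Draw len(pos_edges) node pairs that are not edges (hypothetical helper)."""
    rng = random.Random(seed)
    existing = {frozenset(e) for e in pos_edges}  # undirected edges
    max_negatives = num_nodes * (num_nodes - 1) // 2 - len(existing)
    if len(pos_edges) > max_negatives:
        raise ValueError("not enough non-edges to balance the positives")
    negatives = set()
    while len(negatives) < len(pos_edges):
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        pair = frozenset((u, v))
        # Reject self-loops and pairs that are already connected.
        if u != v and pair not in existing:
            negatives.add(pair)
    return sorted(tuple(sorted(p)) for p in negatives)


pos_edges = [(0, 1), (1, 2), (2, 3)]
neg_edges = sample_negative_edges(pos_edges, num_nodes=5)

# Positive edges get label 1, negative edges get label 0.
edge_labels = [1] * len(pos_edges) + [0] * len(neg_edges)
print(neg_edges, edge_labels)
```

With equal counts, a classifier cannot reach a low loss simply by predicting "no edge" for every pair.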

2 Hands-on Practice

2.1 Constructing the PlanetoidPubMed dataset class (also trained on the Cora dataset)

```python
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='./tmp/cora', name='Cora')

print('Number of classes:', dataset.num_classes)
print('Number of nodes:', dataset[0].num_nodes)
print('Number of edges:', dataset[0].num_edges)
print('Node feature dimension:', dataset[0].num_features)
```

```python
import os.path as osp

import torch
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.io import read_planetoid_data


class PlanetoidPubMed(InMemoryDataset):
    r"""Nodes represent articles and edges represent citation links.
    The training, validation, and test splits are given as binary masks.

    Args:
        root (string): Path of the folder where the dataset is stored.
        transform (callable, optional): Data transform function, called
            every time a sample is accessed.
        pre_transform (callable, optional): Data transform function, called
            before the data is saved to file.
    """

    # url = 'https://github.com/kimiyoung/planetoid/raw/master/data'
    url = 'https://gitee.com/jiajiewu/planetoid/raw/master/data'

    def __init__(self, root, transform=None, pre_transform=None):
        super(PlanetoidPubMed, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, 'processed')

    @property
    def raw_file_names(self):
        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
        return ['ind.pubmed.{}'.format(name) for name in names]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        # Fetch every raw file into raw_dir.
        for name in self.raw_file_names:
            download_url('{}/{}'.format(self.url, name), self.raw_dir)

    def process(self):
        # Build a single Data object, optionally pre-transform it, then save.
        data = read_planetoid_data(self.raw_dir, 'pubmed')
        data = data if self.pre_transform is None else self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        # The class defines no `name` attribute, so use the class name.
        return '{}()'.format(self.__class__.__name__)
```

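Execution flow:

Check whether the raw data files have already been downloaded.

Check whether the data has already been processed: check the data transform method (pre_transform), check the sample filter method (pre_filter), and check whether the processed data already exists.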
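The checks in this flow amount to "only do the work whose output is missing". Below is a small sketch of that caching logic in plain Python, under stated assumptions: `ensure_processed`, `fake_download`, `fake_process`, and the marker file `pre_transform.txt` are hypothetical names for illustration and are not PyG's internals.

```python
import os
import tempfile
import warnings


def ensure_processed(root, raw_names, download, process, pre_transform=None):
    """Run download/process only when their outputs are missing (toy sketch)."""
    raw_dir = os.path.join(root, "raw")
    processed_dir = os.path.join(root, "processed")
    os.makedirs(raw_dir, exist_ok=True)
    os.makedirs(processed_dir, exist_ok=True)

    # 1. Download only if some raw file is missing.
    if not all(os.path.exists(os.path.join(raw_dir, n)) for n in raw_names):
        download(raw_dir)

    # 2. Warn if the cached data was built with a different pre_transform.
    marker = os.path.join(processed_dir, "pre_transform.txt")
    current = getattr(pre_transform, "__name__", repr(pre_transform))
    if os.path.exists(marker):
        with open(marker) as f:
            if f.read() != current:
                warnings.warn("pre_transform changed; delete the processed "
                              "folder to rebuild the dataset")

    # 3. Process only if the processed file is missing.
    out = os.path.join(processed_dir, "data.txt")
    if not os.path.exists(out):
        process(raw_dir, out)
        with open(marker, "w") as f:
            f.write(current)
    return out


calls = []


def fake_download(raw_dir):
    calls.append("download")
    with open(os.path.join(raw_dir, "graph.txt"), "w") as f:
        f.write("0 1\n1 2\n")


def fake_process(raw_dir, out):
    calls.append("process")
    with open(os.path.join(raw_dir, "graph.txt")) as src, open(out, "w") as dst:
        dst.write(src.read())


with tempfile.TemporaryDirectory() as root:
    ensure_processed(root, ["graph.txt"], fake_download, fake_process)
    ensure_processed(root, ["graph.txt"], fake_download, fake_process)  # cached

print(calls)  # ['download', 'process']
```

The second call finds both the raw and the processed files and does nothing, which is why constructing the dataset a second time is fast.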

```python
dataset = PlanetoidPubMed('dataset/PlanetoidPubMed')
print('Number of classes:', dataset.num_classes)
print('Number of nodes:', dataset[0].num_nodes)
print('Number of edges:', dataset[0].num_edges)
print('Node feature dimension:', dataset[0].num_features)
```

```
Number of classes: 3
Number of nodes: 19717
Number of edges: 88648
Node feature dimension: 500
```

2.2 Node prediction with a GAT graph neural network

```python
import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU
from torch_geometric.nn import GATConv, Sequential


class GAT(torch.nn.Module):
    def __init__(self, num_features, hidden_channels_list, num_classes):
        super(GAT, self).__init__()
        torch.manual_seed(12345)
        # Channel sizes: input features followed by each hidden layer.
        hns = [num_features] + hidden_channels_list
        conv_list = []
        for idx in range(len(hidden_channels_list)):
            conv_list.append((GATConv(hns[idx], hns[idx + 1]), 'x, edge_index -> x'))
            conv_list.append(ReLU(inplace=True))

        self.convseq = Sequential('x, edge_index', conv_list)
        self.linear = Linear(hidden_channels_list[-1], num_classes)

    def forward(self, x, edge_index):
        x = self.convseq(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear(x)
        return x
```

```python
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index)  # Perform a single forward pass.
    # Compute the loss solely based on the training nodes.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss


def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc
```

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# Jupyter magic; omit this line outside a notebook.
%matplotlib inline


def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.figure(figsize=(10, 10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color.cpu(), cmap="Set2")
    plt.show()
```

```python
from torch_geometric.transforms import NormalizeFeatures

dataset = PlanetoidPubMed(root='dataset/PlanetoidPubMed/', transform=NormalizeFeatures())
print('dataset.num_features:', dataset.num_features)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device)

model = GAT(num_features=dataset.num_features, hidden_channels_list=[200, 100],
            num_classes=dataset.num_classes).to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1, 201):
    loss = train()
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
```

Output with PlanetoidPubMed:

```
dataset.num_features: 500
GAT(
  (convseq): Sequential(
    (0): GATConv(500, 200, heads=1)
    (1): ReLU(inplace=True)
    (2): GATConv(200, 100, heads=1)
    (3): ReLU(inplace=True)
  )
  (linear): Linear(in_features=100, out_features=3, bias=True)
)
```

Output with Cora (1433 features, 7 classes), followed by the training log:

```
dataset.num_features: 1433
GAT(
  (convseq): Sequential(
    (0): GATConv(1433, 200, heads=1)
    (1): ReLU(inplace=True)
    (2): GATConv(200, 100, heads=1)
    (3): ReLU(inplace=True)
  )
  (linear): Linear(in_features=100, out_features=7, bias=True)
)
Epoch: 010, Loss: 1.7378
Epoch: 020, Loss: 0.7310
Epoch: 030, Loss: 0.2087
Epoch: 040, Loss: 0.0610
Epoch: 050, Loss: 0.0477
Epoch: 060, Loss: 0.0368
Epoch: 070, Loss: 0.0360
Epoch: 080, Loss: 0.0354
Epoch: 090, Loss: 0.0310
Epoch: 100, Loss: 0.0279
Epoch: 110, Loss: 0.0263
Epoch: 120, Loss: 0.0281
Epoch: 130, Loss: 0.0349
Epoch: 140, Loss: 0.0246
Epoch: 150, Loss: 0.0298
Epoch: 160, Loss: 0.0218
Epoch: 170, Loss: 0.0328
Epoch: 180, Loss: 0.0199
Epoch: 190, Loss: 0.0223
Epoch: 200, Loss: 0.0330
Test Accuracy: 0.7510
```



2.3 Link prediction with a two-layer GCNConv neural network

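```python
import os.path as osp

import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import train_test_split_edges

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_name = 'Cora'
path = osp.join('dataset', dataset_name)

# Load the Cora dataset.
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]
ground_truth_edge_index = data.edge_index.to(device)
# Node labels and node masks are not needed for link prediction.
data.train_mask = data.val_mask = data.test_mask = data.y = None

# Split the edges into training, validation, and test sets.
# Note: train_test_split_edges is deprecated in recent PyG releases
# in favor of the T.RandomLinkSplit transform.
data = train_test_split_edges(data)
data = data.to(device)
```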

```python
from torch_geometric.nn import GCNConv


# Build the neural network.
class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        # Two GCN layers produce one embedding per node.
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        # Score each candidate edge by the inner product of its endpoints.
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        # Predict an edge for every node pair with a positive score.
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
```
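The `decode` method scores each candidate edge by the inner product of its two endpoint embeddings, and `decode_all` keeps the pairs whose score is positive. A minimal plain-Python sketch of that scoring follows; the 2-dimensional embeddings and the helper name `edge_scores` are made up for illustration and are not output of the model above.

```python
import math


def edge_scores(z, edges):
    """Inner product of the endpoint embeddings for each (u, v) pair."""
    return [sum(a * b for a, b in zip(z[u], z[v])) for u, v in edges]


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


# Hypothetical node embeddings, one row per node.
z = [[1.0, 0.0], [2.0, 0.0], [0.0, -1.0], [-1.0, 0.0]]
candidates = [(0, 1), (0, 2), (0, 3)]

logits = edge_scores(z, candidates)
probs = [sigmoid(s) for s in logits]
# A score above 0 corresponds to a probability above 0.5.
predicted = [e for e, s in zip(candidates, logits) if s > 0]
print(logits, predicted)  # [2.0, 0.0, -1.0] [(0, 1)]
```

Embeddings that point in the same direction give a large positive score, orthogonal ones give zero, and opposite ones give a negative score, which is why training pushes connected nodes toward similar embeddings.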

```python
import torch.nn.functional as F
from torch_geometric.utils import negative_sampling


# Build the edge labels in {0, 1}: positives first, then negatives.
def get_link_labels(pos_edge_index, neg_edge_index):
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels


def train(data, model, optimizer):
    model.train()

    # Negative sampling, so that both classes have the same number of samples.
    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1))

    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index).to(data.x.device)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss
```
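The training step labels positive edges 1 and negative edges 0, then applies binary cross-entropy directly to the raw scores. As a rough illustration of what that loss computes, here is a plain-Python sketch; `make_link_labels` and `bce_with_logits` are hypothetical helpers written for this note (not the PyTorch implementation), and the scores are made up.

```python
import math


def make_link_labels(num_pos, num_neg):
    """1.0 for each positive edge, then 0.0 for each negative edge."""
    return [1.0] * num_pos + [0.0] * num_neg


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy on raw logits, in a numerically stable form."""
    total = 0.0
    for x, y in zip(logits, labels):
        # Equivalent to -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))].
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


logits = [2.0, 1.0, 0.0, -1.0]  # made-up scores: two positives, two negatives
labels = make_link_labels(num_pos=2, num_neg=2)
loss = bce_with_logits(logits, labels)
print(f"{loss:.4f}")
```

The loss shrinks when positive edges receive large positive scores and negative edges receive large negative scores, so minimizing it trains the encoder to separate the two groups.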

```python
from sklearn.metrics import roc_auc_score


@torch.no_grad()
def test(data, model):
    model.eval()

    z = model.encode(data.x, data.train_pos_edge_index)

    results = []
    for prefix in ['val', 'test']:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']
        link_logits = model.decode(z, pos_edge_index, neg_edge_index)
        # Convert the scores into edge probabilities.
        link_probs = link_logits.sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        results.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
    return results
```
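The evaluation reports ROC AUC, which can be read as the probability that a randomly chosen positive edge is scored higher than a randomly chosen negative edge. The sketch below computes it that way by comparing every positive/negative pair; `pairwise_auc` is a hypothetical helper for illustration (scikit-learn computes the value from the ROC curve instead), and the scores are made up.

```python
def pairwise_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    # A correctly ordered pair counts 1, a tie counts 0.5.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [1, 1, 0, 0]
scores = [0.9, 0.4, 0.5, 0.1]
auc = pairwise_auc(labels, scores)
print(auc)  # 0.75
```

Three of the four positive/negative pairs are ordered correctly here. Because only the ordering matters, the AUC is the same whether it is computed on the raw scores or on their sigmoid.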

```python
model = Net(dataset.num_features, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

best_val_auc = test_auc = 0
for epoch in range(1, 101):
    loss = train(data, model, optimizer)
    val_auc, tmp_test_auc = test(data, model)
    # Keep the test AUC from the epoch with the best validation AUC.
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = tmp_test_auc
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')

z = model.encode(data.x, data.train_pos_edge_index)
final_edge_index = model.decode_all(z)
print('ground truth edge shape:', ground_truth_edge_index.shape)
print('final edge shape:', final_edge_index.shape)
```
