转摘PyG使用Heterogenous Operators搭建异构网络模型

风尘阅读量 71

专栏前两篇文章已经写过如何使用PyG处理异构图数据,分别为 [PyG将同构模型转化为异构网络模型](https://weibaohang.blog.csdn.net/article/details/128776979)、[PyG将异构图(heterogeneous)转化为同构图(homogeneous)数据](https://weibaohang.blog.csdn.net/article/details/128777035)、[PyG搭建异构网络模型](https://weibaohang.blog.csdn.net/article/details/128790564),如果不了解的可以点开链接看下。

但是对于PyG还提供了另外一种实现方式就是:Heterogenous Operators,在PyG中已经实现好的异构卷积算子有 HGTConv,这个卷积算子与 GCNConv 这类不同,它可以处理异构图数据,针对不同边缘类型数据执行消息传递。

PyG中HGTConv模块介绍:

![在这里插入图片描述](https://img-blog.csdnimg.cn/fd6889c11b474d018b856636d51fac28.png)
这个卷积算子采用了 Transformer 机制来实现,对于参数来说与传统的图卷积算子没有什么不同,都是输入维度、输出维度,但是对于输入维度可以以元组方式进行传递,表明每种节点类型的输入维度,因为可能不同的节点类型有着不同的特征维度,如果为了简便,也可以传入-1,进行动态获取,还有一个不同的参数就是 heads ,这个就是 Transformer 中的注意力机制的头数,对于 group 就是针对不同边缘类型进行消息传递后的特征的聚合方式,默认是采用求和 sum 来聚合不同边缘类型的特征信息。

示例代码

prism language-python 复制代码
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HGTConv, Linear
import torch

dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()
        
        # 添加异构卷积算子
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(-1, hidden_channels, data.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        # 线性输出层
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return self.lin(x_dict['author'])

model = HGT(hidden_channels=64, out_channels=dataset.num_classes, num_heads=5, num_layers=3)
output = model(data.x_dict, data.edge_index_dict)

上述代码来自PyG官方文档,看了文章的小伙伴可能发现文章中还有一段这个代码:

prism language-python 复制代码
self.lin_dict = torch.nn.ModuleDict()
for node_type in data.node_types:
    self.lin_dict[node_type] = Linear(-1, hidden_channels)

这个就是做特征映射,在使用异构卷积算子进行消息传递之前,为不同的节点类型执行一个线性映射函数,对于这个可做可不做,做的好处可能就是能够使所有节点的特征维度统一变为 hidden_channels,对于效果上有没有帮助我没有做测试,感兴趣的小伙伴可以尝试以下。

复制代码
    ===========================
    【来源: CSDN】
    【作者: 海洋.之心】
    【原文链接】 https://weibaohang.blog.csdn.net/article/details/128792910
    声明:转载此文是出于传递更多信息之目的。若有来源标注错误或侵犯了您的合法权益,请作者持权属证明与本网联系,我们将及时更正、删除,谢谢。
标签:
0/300
全部评论0
0/300