转摘PyG中global_max_pool()函数介绍
最近在研究PyG这个框架,所以借这个机会讲解一些常见的函数,本文将说明 global_max_pool()
这个池化函数,并且以图解及示例的方式说明。
PyG中global_max_pool()函数定义如下:

参数列表:
- x:节点特征向量,维度为【num_nodes,feature_size】
- batch:维度为【num_nodes,】,标志着每个节点属于哪个图或者哪个簇
- size:节点个数,如果不提供可以根据输入的x进行计算
我们在图分类任务时,需要获取每张图的特征表示,但是每张图会存在多个节点,我们需要根据这些节点形成一个新的特征向量作为图的表示,这时就利用到了全局池化。
对于 global_max_pool()
这个函数会对多个节点按照 batch
进行池化,这里举个例子:
prism language-python
x = torch.randn(5, 10)
batch = torch.tensor([0, 0, 0, 1, 1])
output = pyn_nn.global_add_pool(x, batch)
print(output.shape)
>>>torch.Size([2, 10])
首先我们定义了一个【5,10】代表节点特征矩阵,每个节点的特征维度为10,然后我们又定义了 batch=[0, 0, 0, 1, 1]
表明第一个、第二个和第三个节点进行池化,第四个和第五个节点进行池化,最终我们形成了2个特征向量,表示每个簇进行池化后的结果。
这个函数一般用于在使用 Loader
进行批次计算时使用,因为 Loader
会将多个子图拼接成为一个大图,这时如果对这张图进行池化,不能够分清哪个节点属于哪张图,所以可以使用 global_max_pool()
这个函数,将大图的节点特征向量传入,然后送入 batch
这个参数指明每个节点属于该批次中的哪张图。
在PyG中还有一个池化函数,跟他很类似,就是
max_pool()
,他们两个作用是类似的,都是对于同一簇内的节点进行池化,对于全局池化相当于对一整张图进行池化,如果我们设置max_pool()
的参数cluster
为一定值,那么此时跟global_max_pool()
这个函数作用是一致的。
对于 max_pool()
的介绍可以参考这篇文章 [PyG中max_pool()函数介绍](https://blog.csdn.net/m0_47256162/article/details/128853900?spm=1001.2014.3001.5501)
。
===========================
【来源: CSDN】
【作者: 海洋.之心】
【原文链接】 https://weibaohang.blog.csdn.net/article/details/128854628
声明:转载此文是出于传递更多信息之目的。若有来源标注错误或侵犯了您的合法权益,请作者持权属证明与本网联系,我们将及时更正、删除,谢谢。