pyg-nightly 2.7.0.dev20250403__py3-none-any.whl → 2.7.0.dev20250405__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250403.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250403.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/RECORD +10 -7
- torch_geometric/__init__.py +1 -1
- torch_geometric/nn/models/__init__.py +5 -0
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/gpse.py +1079 -0
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +39 -0
- {pyg_nightly-2.7.0.dev20250403.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250403.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/licenses/LICENSE +0 -0
@@ -37,6 +37,7 @@ from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph
|
|
37
37
|
from .largest_connected_components import LargestConnectedComponents
|
38
38
|
from .virtual_node import VirtualNode
|
39
39
|
from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
|
40
|
+
from .add_gpse import AddGPSE
|
40
41
|
from .feature_propagation import FeaturePropagation
|
41
42
|
from .half_hop import HalfHop
|
42
43
|
|
@@ -108,6 +109,7 @@ graph_transforms = [
|
|
108
109
|
'VirtualNode',
|
109
110
|
'AddLaplacianEigenvectorPE',
|
110
111
|
'AddRandomWalkPE',
|
112
|
+
'AddGPSE',
|
111
113
|
'FeaturePropagation',
|
112
114
|
'HalfHop',
|
113
115
|
]
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from torch_geometric.data import Data
|
2
|
+
from torch_geometric.data.datapipes import functional_transform
|
3
|
+
from torch_geometric.nn.models.gpse import GPSE
|
4
|
+
from torch_geometric.transforms import BaseTransform, VirtualNode
|
5
|
+
|
6
|
+
|
7
|
+
@functional_transform('add_gpse')
class AddGPSE(BaseTransform):
    r"""Adds the GPSE encoding from the `"Graph Positional and Structural
    Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
    (functional name: :obj:`add_gpse`).
    To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates
    the actual encodings.

    The result is written to the :obj:`pestat_GPSE` attribute of the input
    graph, which is then returned.

    Args:
        model (GPSE): The pre-trained GPSE model.
        use_vn (bool, optional): Whether to use virtual nodes.
            (default: :obj:`True`)
        rand_type (str, optional): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)

    """
    def __init__(self, model: GPSE, use_vn: bool = True,
                 rand_type: str = 'NormalSE'):
        self.model = model
        self.use_vn = use_vn
        # Built unconditionally; only applied in `__call__` when `use_vn` is
        # set.
        self.vn = VirtualNode()
        self.rand_type = rand_type

    def __call__(self, data: Data) -> Data:
        # Kept as a function-scope import, as in the original code.
        from torch_geometric.nn.models.gpse import gpse_process

        # Work on a clone so that the virtual node (if any) is never added to
        # the caller's graph.
        data_vn = self.vn(data.clone()) if self.use_vn else data.clone()

        # Forward the configured random feature type. Previously the literal
        # 'NormalSE' was passed here, which silently ignored `rand_type`.
        batch_out = gpse_process(self.model, data_vn, self.rand_type,
                                 self.use_vn)
        batch_out = batch_out.to('cpu', non_blocking=True)

        # With virtual nodes enabled, the last row is dropped -- presumably
        # the entry of the added virtual node (TODO confirm against
        # `VirtualNode` node ordering).
        data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out

        return data
|
File without changes
|
{pyg_nightly-2.7.0.dev20250403.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/licenses/LICENSE
RENAMED
File without changes
|