pyg-nightly 2.7.0.dev20250403__py3-none-any.whl → 2.7.0.dev20250405__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,6 +37,7 @@ from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph
37
37
  from .largest_connected_components import LargestConnectedComponents
38
38
  from .virtual_node import VirtualNode
39
39
  from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
40
+ from .add_gpse import AddGPSE
40
41
  from .feature_propagation import FeaturePropagation
41
42
  from .half_hop import HalfHop
42
43
 
@@ -108,6 +109,7 @@ graph_transforms = [
108
109
  'VirtualNode',
109
110
  'AddLaplacianEigenvectorPE',
110
111
  'AddRandomWalkPE',
112
+ 'AddGPSE',
111
113
  'FeaturePropagation',
112
114
  'HalfHop',
113
115
  ]
@@ -0,0 +1,39 @@
1
+ from torch_geometric.data import Data
2
+ from torch_geometric.data.datapipes import functional_transform
3
+ from torch_geometric.nn.models.gpse import GPSE
4
+ from torch_geometric.transforms import BaseTransform, VirtualNode
5
+
6
+
7
@functional_transform('add_gpse')
class AddGPSE(BaseTransform):
    r"""Adds the GPSE encoding from the `"Graph Positional and Structural
    Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
    (functional name: :obj:`add_gpse`).
    To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates
    the actual encodings.

    The resulting encoding is stored in :obj:`data.pestat_GPSE`. Note that the
    input :obj:`data` object itself is modified in place (and also returned).

    Args:
        model (GPSE): The pre-trained GPSE model.
        use_vn (bool, optional): Whether to use virtual nodes.
            (default: :obj:`True`)
        rand_type (str, optional): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)
    """
    def __init__(self, model: GPSE, use_vn: bool = True,
                 rand_type: str = 'NormalSE'):
        self.model = model
        self.use_vn = use_vn
        self.vn = VirtualNode()
        self.rand_type = rand_type

    def forward(self, data: Data) -> Data:
        from torch_geometric.nn.models.gpse import gpse_process

        # Encode a clone so that the (optional) virtual node is never added
        # to the caller's graph.
        data_vn = self.vn(data.clone()) if self.use_vn else data.clone()

        # Pass through the user-selected random input distribution rather
        # than a fixed one, so `rand_type` actually takes effect.
        batch_out = gpse_process(self.model, data_vn, self.rand_type,
                                 self.use_vn)

        # Use a blocking copy: a non-blocking device-to-host copy may still be
        # in flight when the caller reads the tensor, yielding stale values.
        batch_out = batch_out.to('cpu')

        # `VirtualNode` appends the virtual node last; drop its row so that
        # (assuming one output row per node) the encoding lines up with the
        # original nodes.
        data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out

        return data

    def __call__(self, data: Data) -> Data:
        # Delegate directly to `forward` so the encoding is written onto the
        # caller's `data` object in place.
        return self.forward(data)