pyg-nightly 2.7.0.dev20250404__py3-none-any.whl → 2.7.0.dev20250405__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pyg-nightly
-Version: 2.7.0.dev20250404
+Version: 2.7.0.dev20250405
 Summary: Graph Neural Network Library for PyTorch
 Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
 Author-email: Matthias Fey <matthias@pyg.org>
@@ -1,4 +1,4 @@
-torch_geometric/__init__.py,sha256=FXQ6WZdtxVAyouslhtsPt0xXJ33VOV4_aeSKmDUWYrU,1978
+torch_geometric/__init__.py,sha256=Gv9oWwU7Y-5CKcoO3SA059-jYYjsTMcsubgypKMaUcQ,1978
 torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
 torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
 torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
@@ -426,7 +426,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
 torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
 torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
 torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
-torch_geometric/nn/models/__init__.py,sha256=dwKH_fLrAMV9nda0nQHLyCzIOPezHCZa01pmTCdaJxU,2263
+torch_geometric/nn/models/__init__.py,sha256=4mZ5dyiZ9aa1NaBth1qYV-hZdnG_Np1XWvRLB4Qv6RM,2338
 torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
 torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
 torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
@@ -441,6 +441,7 @@ torch_geometric/nn/models/g_retriever.py,sha256=CdSOasnPiMvq5AjduNTpz-LIZiNp3X0x
 torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
 torch_geometric/nn/models/glem.py,sha256=sT0XM4klVlci9wduvUoXupATUw9p25uXtaJBrmv3yvs,16431
 torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
+torch_geometric/nn/models/gpse.py,sha256=my-KIw_Ov8o0pXSCyh43NZRBAW95TFfmBgxzSimx8-A,42680
 torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
 torch_geometric/nn/models/graph_unet.py,sha256=N8TSmJo8AlbZjjcame0xW_jZvMOirL5ahw6qv5Yjpbs,5586
 torch_geometric/nn/models/jumping_knowledge.py,sha256=9JR2EoViXKjcDSLb8tjJm-UHfv1mQCJvZAAEsYa0Ocw,5496
@@ -521,7 +522,8 @@ torch_geometric/testing/decorators.py,sha256=j45wlxMB1-Pn3wPKBgDziqg6KkWJUb_fcwf
 torch_geometric/testing/distributed.py,sha256=ZZCCXqiQC4-m1ExSjDZhS_a1qPXnHEwhJGTmACxNnVI,2227
 torch_geometric/testing/feature_store.py,sha256=J6JBIt2XK-t8yG8B4JzXp-aJcVl5jaCS1m2H7d6OUxs,2158
 torch_geometric/testing/graph_store.py,sha256=00B7QToCIspYmgN7svQKp1iU-qAzEtrt3VQRFxkHfuk,1044
-torch_geometric/transforms/__init__.py,sha256=9HElLNLbIRgcOVRVbFcVfMwfRsemPAaRFeJdgz2qWmQ,4251
+torch_geometric/transforms/__init__.py,sha256=P0R2CFg9pXxjTX4NnYfNPrifRPAw5lVXEOxO80q-1Ek,4296
+torch_geometric/transforms/add_gpse.py,sha256=4o0UrSmTu3CKsL3UAREiul8O4lC02PUx_ajxP4sPsxU,1570
 torch_geometric/transforms/add_metapaths.py,sha256=GabaPRvUnpFrZJsxLMUBY2Egzx94GTgsMxegL_qTtbk,14239
 torch_geometric/transforms/add_positional_encoding.py,sha256=tuilyubAn3yeyz8mvFc5zxXTlNzh8okKzG9AE2lPG1Q,6049
 torch_geometric/transforms/add_remaining_self_loops.py,sha256=ItU5FAcE-mkbp_wqTLkRhv0RShR5JVr8vr9d5xv3_Ak,2085
@@ -634,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
 torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
 torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
 torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
-pyg_nightly-2.7.0.dev20250404.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
-pyg_nightly-2.7.0.dev20250404.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
-pyg_nightly-2.7.0.dev20250404.dist-info/METADATA,sha256=AO-euV7W-eeQ5YKe9jRqjlvrpn4qZGiLmiYV_VNxKyI,63021
-pyg_nightly-2.7.0.dev20250404.dist-info/RECORD,,
+pyg_nightly-2.7.0.dev20250405.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
+pyg_nightly-2.7.0.dev20250405.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+pyg_nightly-2.7.0.dev20250405.dist-info/METADATA,sha256=9NrBoIq1lq6tX1hOXmse3Dvp7PZNQjcsoyGzucHSgp8,63021
+pyg_nightly-2.7.0.dev20250405.dist-info/RECORD,,
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.7.0.dev20250404'
+__version__ = '2.7.0.dev20250405'
 
 __all__ = [
     'Index',
@@ -12,6 +12,7 @@ from .re_net import RENet
 from .graph_unet import GraphUNet
 from .schnet import SchNet
 from .dimenet import DimeNet, DimeNetPlusPlus
+from .gpse import GPSE, GPSENodeEncoder
 from .captum import to_captum_model
 from .metapath2vec import MetaPath2Vec
 from .deepgcn import DeepGCNLayer
@@ -62,6 +63,8 @@ __all__ = classes = [
     'SchNet',
     'DimeNet',
     'DimeNetPlusPlus',
+    'GPSE',
+    'GPSENodeEncoder',
     'to_captum_model',
     'to_captum_input',
     'captum_output_to_dicts',
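
The hunk above re-exports the two new classes from `torch_geometric.nn.models`. As a quick orientation, a minimal sketch of the resulting import surface, assuming this nightly wheel is installed (`from_pretrained` and the dataset name come from `gpse.py` below):

    from torch_geometric.nn.models import GPSE, GPSENodeEncoder

    # Downloads pretrained weights from Zenodo into ./GPSE_pretrained:
    gpse_model = GPSE.from_pretrained('molpcba')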
@@ -0,0 +1,1079 @@
+import logging
+import os
+import os.path as osp
+import time
+from collections import OrderedDict
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import trange
+
+import torch_geometric.transforms as T
+from torch_geometric.data import Data, Dataset, download_url
+from torch_geometric.loader import DataLoader, NeighborLoader
+from torch_geometric.nn import (
+    ResGatedGraphConv,
+    global_add_pool,
+    global_max_pool,
+    global_mean_pool,
+)
+from torch_geometric.nn.resolver import activation_resolver
+from torch_geometric.utils import to_dense_batch
+
+
+class Linear(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        bias: bool,
+    ) -> None:
+        super().__init__()
+        self.model = torch.nn.Linear(in_channels, out_channels, bias=bias)
+
+    def forward(self, batch):
+        if isinstance(batch, torch.Tensor):
+            batch = self.model(batch)
+        else:
+            batch.x = self.model(batch.x)
+        return batch
+
+
+class ResGatedGCNConv(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        bias: bool,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.model = ResGatedGraphConv(
+            in_channels,
+            out_channels,
+            bias=bias,
+            **kwargs,
+        )
+
+    def forward(self, batch):
+        batch.x = self.model(batch.x, batch.edge_index)
+        return batch
+
+
+class GeneralLayer(torch.nn.Module):
+    def __init__(
+        self,
+        name: str,
+        in_channels: int,
+        out_channels: int,
+        has_batch_norm: bool,
+        has_l2_norm: bool,
+        dropout: float,
+        act: Optional[str],
+        **kwargs,
+    ):
+        super().__init__()
+        self.has_l2_norm = has_l2_norm
+
+        layer_dict = {
+            'linear': Linear,
+            'resgatedgcnconv': ResGatedGCNConv,
+        }
+        self.layer = layer_dict[name](
+            in_channels,
+            out_channels,
+            bias=not has_batch_norm,
+            **kwargs,
+        )
+        post_layers = []
+        if has_batch_norm:
+            post_layers.append(
+                torch.nn.BatchNorm1d(out_channels, eps=1e-5, momentum=0.1))
+        if dropout > 0:
+            post_layers.append(torch.nn.Dropout(p=dropout, inplace=False))
+        if act is not None:
+            post_layers.append(activation_resolver(act))
+        self.post_layer = nn.Sequential(*post_layers)
+
+    def forward(self, batch):
+        batch = self.layer(batch)
+        if isinstance(batch, torch.Tensor):
+            batch = self.post_layer(batch)
+            if self.has_l2_norm:
+                batch = F.normalize(batch, p=2, dim=1)
+        else:
+            batch.x = self.post_layer(batch.x)
+            if self.has_l2_norm:
+                batch.x = F.normalize(batch.x, p=2, dim=1)
+        return batch
+
+
+class GeneralMultiLayer(torch.nn.Module):
+    def __init__(
+        self,
+        name: str,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: Optional[int],
+        num_layers: int,
+        has_batch_norm: bool,
+        has_l2_norm: bool,
+        dropout: float,
+        act: str,
+        final_act: bool,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        hidden_channels = hidden_channels or out_channels
+
+        for i in range(num_layers):
+            d_in = in_channels if i == 0 else hidden_channels
+            d_out = out_channels if i == num_layers - 1 else hidden_channels
+            layer = GeneralLayer(
+                name=name,
+                in_channels=d_in,
+                out_channels=d_out,
+                has_batch_norm=has_batch_norm,
+                has_l2_norm=has_l2_norm,
+                dropout=dropout,
+                act=None if i == num_layers - 1 and not final_act else act,
+                **kwargs,
+            )
+            self.add_module(f'Layer_{i}', layer)
+
+    def forward(self, batch):
+        for layer in self.children():
+            batch = layer(batch)
+        return batch
+
+
+class BatchNorm1dNode(torch.nn.Module):
+    def __init__(self, channels: int) -> None:
+        super().__init__()
+        self.bn = torch.nn.BatchNorm1d(channels, eps=1e-5, momentum=0.1)
+
+    def forward(self, batch):
+        batch.x = self.bn(batch.x)
+        return batch
+
+
+class BatchNorm1dEdge(torch.nn.Module):
+    def __init__(self, channels: int) -> None:
+        super().__init__()
+        self.bn = torch.nn.BatchNorm1d(channels, eps=1e-5, momentum=0.1)
+
+    def forward(self, batch):
+        batch.edge_attr = self.bn(batch.edge_attr)
+        return batch
+
+
+class MLP(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: Optional[int],
+        num_layers: int,
+        has_batch_norm: bool = True,
+        has_l2_norm: bool = True,
+        dropout: float = 0.2,
+        act: str = 'relu',
+        **kwargs,
+    ):
+        super().__init__()
+        hidden_channels = hidden_channels or in_channels
+
+        layers = []
+        if num_layers > 1:
+            layer = GeneralMultiLayer(
+                'linear',
+                in_channels,
+                hidden_channels,
+                hidden_channels,
+                num_layers - 1,
+                has_batch_norm,
+                has_l2_norm,
+                dropout,
+                act,
+                final_act=True,
+                **kwargs,
+            )
+            layers.append(layer)
+        layers.append(Linear(hidden_channels, out_channels, bias=True))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, batch):
+        if isinstance(batch, torch.Tensor):
+            batch = self.model(batch)
+        else:
+            batch.x = self.model(batch.x)
+        return batch
+
+
+class GNNStackStage(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int,
+        layer_type: str,
+        stage_type: str = 'skipsum',
+        final_l2_norm: bool = True,
+        has_batch_norm: bool = True,
+        has_l2_norm: bool = True,
+        dropout: float = 0.2,
+        act: Optional[str] = 'relu',
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.stage_type = stage_type
+        self.final_l2_norm = final_l2_norm
+
+        for i in range(num_layers):
+            if stage_type == 'skipconcat':
+                if i == 0:
+                    d_in = in_channels
+                else:
+                    d_in = in_channels + i * out_channels
+            else:
+                d_in = in_channels if i == 0 else out_channels
+            layer = GeneralLayer(layer_type, d_in, out_channels,
+                                 has_batch_norm, has_l2_norm, dropout, act)
+            self.add_module(f'layer{i}', layer)
+
+    def forward(self, batch):
+        for i, layer in enumerate(self.children()):
+            x = batch.x
+            batch = layer(batch)
+            if self.stage_type == 'skipsum':
+                batch.x = x + batch.x
+            elif self.stage_type == 'skipconcat' and i < self.num_layers - 1:
+                batch.x = torch.cat([x, batch.x], dim=1)
+
+        if self.final_l2_norm:
+            batch.x = F.normalize(batch.x, p=2, dim=-1)
+
+        return batch
+
+
+class GNNInductiveHybridMultiHead(torch.nn.Module):
+    r"""GNN prediction head for inductive node and graph prediction tasks,
+    using an individual MLP for each task.
+
+    Args:
+        dim_in (int): Input dimension.
+        dim_out (int): Output dimension. Not used, as the dimension is
+            determined by :obj:`num_node_targets` and :obj:`num_graph_targets`
+            instead.
+        num_node_targets (int): Number of individual PSEs used as node-level
+            targets in pretraining :class:`GPSE`.
+        num_graph_targets (int): Number of graph-level targets used in
+            pretraining :class:`GPSE`.
+        layers_post_mp (int): Number of MLP layers after GNN message-passing.
+        virtual_node (bool, optional): Whether a virtual node is added to
+            graphs in :class:`GPSE` computation. (default: :obj:`True`)
+        multi_head_dim_inner (int, optional): Width of MLPs for PSE target
+            prediction heads. (default: :obj:`32`)
+        graph_pooling (str, optional): Type of graph pooling applied before
+            post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`.
+            (default: :obj:`add`)
+        has_bn (bool, optional): Whether to apply batch normalization to layer
+            outputs. (default: :obj:`True`)
+        has_l2norm (bool, optional): Whether to apply L2 normalization to the
+            layer outputs. (default: :obj:`True`)
+        dropout (float, optional): Dropout ratio at layer output.
+            (default: :obj:`0.2`)
+        act (str, optional): Activation to apply to layer outputs if
+            :obj:`has_act` is :obj:`True`. (default: :obj:`relu`)
+    """
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        num_node_targets: int,
+        num_graph_targets: int,
+        layers_post_mp: int,
+        virtual_node: bool = True,
+        multi_head_dim_inner: int = 32,
+        graph_pooling: str = 'add',
+        has_bn: bool = True,
+        has_l2norm: bool = True,
+        dropout: float = 0.2,
+        act: str = 'relu',
+    ):
+        super().__init__()
+        pool_dict = {
+            'add': global_add_pool,
+            'max': global_max_pool,
+            'mean': global_mean_pool
+        }
+        self.node_target_dim = num_node_targets
+        self.graph_target_dim = num_graph_targets
+        self.virtual_node = virtual_node
+        num_layers = layers_post_mp
+
+        self.node_post_mps = nn.ModuleList([
+            MLP(dim_in, 1, multi_head_dim_inner, num_layers, has_bn,
+                has_l2norm, dropout, act) for _ in range(self.node_target_dim)
+        ])
+
+        self.graph_pooling = pool_dict[graph_pooling]
+
+        self.graph_post_mp = MLP(dim_in, self.graph_target_dim, dim_in,
+                                 num_layers, has_bn, has_l2norm, dropout, act)
+
+    def _pad_and_stack(self, x1: torch.Tensor, x2: torch.Tensor, pad1: int,
+                       pad2: int):
+        padded_x1 = nn.functional.pad(x1, (0, pad2))
+        padded_x2 = nn.functional.pad(x2, (pad1, 0))
+        return torch.vstack([padded_x1, padded_x2])
+
+    def _apply_index(self, batch, virtual_node: bool, pad_node: int,
+                     pad_graph: int):
+        graph_pred, graph_true = batch.graph_feature, batch.y_graph
+        node_pred, node_true = batch.node_feature, batch.y
+        if virtual_node:
+            # Remove virtual node
+            idx = torch.concat([
+                torch.where(batch.batch == i)[0][:-1]
+                for i in range(batch.batch.max().item() + 1)
+            ])
+            node_pred, node_true = node_pred[idx], node_true[idx]
+
+        # Stack node predictions on top of graph predictions and pad with zeros
+        pred = self._pad_and_stack(node_pred, graph_pred, pad_node, pad_graph)
+        true = self._pad_and_stack(node_true, graph_true, pad_node, pad_graph)
+
+        return pred, true
+
+    def forward(self, batch):
+        batch.node_feature = torch.hstack(
+            [m(batch.x) for m in self.node_post_mps])
+        graph_emb = self.graph_pooling(batch.x, batch.batch)
+        batch.graph_feature = self.graph_post_mp(graph_emb)
+        return self._apply_index(batch, self.virtual_node,
+                                 self.node_target_dim, self.graph_target_dim)
+
+
+class IdentityHead(torch.nn.Module):
+    def forward(self, batch):
+        return batch.x, batch.y
+
+
+class GPSE(torch.nn.Module):
+    r"""The Graph Positional and Structural Encoder (GPSE) model from the
+    `"Graph Positional and Structural Encoder"
+    <https://arxiv.org/abs/2307.07107>`_ paper.
+
+    The GPSE model consists of (1) a deep GNN built from stacked
+    message-passing layers, and (2) prediction heads that predict
+    pre-computed positional and structural encodings (PSEs).
+    When used on downstream datasets, these prediction heads are removed and
+    the final fully-connected layer outputs are used as learned PSE
+    embeddings.
+
+    GPSE also provides a class method :meth:`from_pretrained` to load
+    pre-trained GPSE models trained on a variety of molecular datasets.
+
+    .. code-block:: python
+
+        from torch_geometric.nn import GPSE, GPSENodeEncoder
+        from torch_geometric.transforms import AddGPSE
+        from torch_geometric.nn.models.gpse import precompute_GPSE
+
+        gpse_model = GPSE.from_pretrained('molpcba')
+
+        # Option 1: Precompute GPSE encodings in-place for a given dataset
+        dataset = ZINC(path, subset=True, split='train')
+        precompute_GPSE(gpse_model, dataset)
+
+        # Option 2: Use the GPSE model with AddGPSE as a pre_transform to save
+        # the encodings
+        dataset = ZINC(path, subset=True, split='train',
+                       pre_transform=AddGPSE(gpse_model, use_vn=True,
+                                             rand_type='NormalSE'))
+
+    Both approaches append the generated encodings to the :obj:`pestat_GPSE`
+    attribute of :class:`~torch_geometric.data.Data` objects. To use the GPSE
+    encodings for a downstream task, one may need to add these encodings to
+    the :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects.
+    To do so, one can use the :class:`GPSENodeEncoder` provided to map these
+    encodings to a desired dimension before appending them to :obj:`x`.
+
+    Let's say we have a graph dataset with 64 original node features, and we
+    have generated GPSE encodings of dimension 32, i.e.
+    :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an
+    inner dimension of 128. To do so, we can map the 32-dimensional GPSE
+    encodings to a higher dimension of 64, and then append them to the
+    :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to
+    obtain a 128-dimensional node feature representation.
+    :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and
+    concatenation to :obj:`x`, the outputs of which can be used as input to a
+    GNN:
+
+    .. code-block:: python
+
+        encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
+                                  expand_x=False)
+        gnn = GNN(dim_in=128, dim_out=128, num_layers=4)
+
+        for batch in loader:
+            batch = encoder(batch)
+            batch = gnn(batch)
+            # Do something with the batch, which now includes 128-dimensional
+            # node representations
+
+    Args:
+        dim_in (int, optional): Input dimension. (default: :obj:`20`)
+        dim_out (int, optional): Output dimension. (default: :obj:`51`)
+        dim_inner (int, optional): Width of the encoder layers.
+            (default: :obj:`512`)
+        layer_type (str, optional): Type of graph convolutional layer for
+            message-passing. (default: :obj:`resgatedgcnconv`)
+        layers_pre_mp (int, optional): Number of MLP layers before
+            message-passing. (default: :obj:`1`)
+        layers_mp (int, optional): Number of layers for message-passing.
+            (default: :obj:`20`)
+        layers_post_mp (int, optional): Number of MLP layers after
+            message-passing. (default: :obj:`2`)
+        num_node_targets (int, optional): Number of individual PSEs used as
+            node-level targets in pretraining :class:`GPSE`.
+            (default: :obj:`51`)
+        num_graph_targets (int, optional): Number of graph-level targets used
+            in pretraining :class:`GPSE`. (default: :obj:`11`)
+        stage_type (str, optional): The type of staging to apply. Possible
+            values are: :obj:`skipsum`, :obj:`skipconcat`. Any other value
+            will default to no skip connections. (default: :obj:`skipsum`)
+        has_bn (bool, optional): Whether to apply batch normalization in the
+            layer. (default: :obj:`True`)
+        head_bn (bool, optional): Whether to apply batch normalization in the
+            prediction head. (default: :obj:`False`)
+        final_l2norm (bool, optional): Whether to apply L2 normalization to
+            the outputs. (default: :obj:`True`)
+        has_l2norm (bool, optional): Whether to apply L2 normalization after
+            the layer. (default: :obj:`True`)
+        dropout (float, optional): Dropout ratio at layer output.
+            (default: :obj:`0.2`)
+        has_act (bool, optional): Whether to apply activation after the
+            layer. (default: :obj:`True`)
+        final_act (bool, optional): Whether to apply activation after the
+            layer stack. (default: :obj:`True`)
+        act (str, optional): Activation to apply to layer output if
+            :obj:`has_act` is :obj:`True`. (default: :obj:`relu`)
+        virtual_node (bool, optional): Whether a virtual node is added to
+            graphs in :class:`GPSE` computation. (default: :obj:`True`)
+        multi_head_dim_inner (int, optional): Width of MLPs for PSE target
+            prediction heads. (default: :obj:`32`)
+        graph_pooling (str, optional): Type of graph pooling applied before
+            post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`.
+            (default: :obj:`add`)
+        use_repr (bool, optional): Whether to use the hidden representation of
+            the final layer as :class:`GPSE` encodings. (default: :obj:`True`)
+        repr_type (str, optional): Type of representation to use. Options are
+            :obj:`no_post_mp`, :obj:`one_layer_before`.
+            (default: :obj:`no_post_mp`)
+        bernoulli_threshold (float, optional): Threshold for Bernoulli
+            sampling of virtual nodes. (default: :obj:`0.5`)
+    """
+
+    url_dict = {
+        'molpcba':
+        'https://zenodo.org/record/8145095/files/'
+        'gpse_model_molpcba_1.0.pt',
+        'zinc':
+        'https://zenodo.org/record/8145095/files/gpse_model_zinc_1.0.pt',
+        'pcqm4mv2':
+        'https://zenodo.org/record/8145095/files/'
+        'gpse_model_pcqm4mv2_1.0.pt',
+        'geom':
+        'https://zenodo.org/record/8145095/files/gpse_model_geom_1.0.pt',
+        'chembl':
+        'https://zenodo.org/record/8145095/files/gpse_model_chembl_1.0.pt'
+    }
+
+    def __init__(
+        self,
+        dim_in: int = 20,
+        dim_out: int = 51,
+        dim_inner: int = 512,
+        layer_type: str = 'resgatedgcnconv',
+        layers_pre_mp: int = 1,
+        layers_mp: int = 20,
+        layers_post_mp: int = 2,
+        num_node_targets: int = 51,
+        num_graph_targets: int = 11,
+        stage_type: str = 'skipsum',
+        has_bn: bool = True,
+        head_bn: bool = False,
+        final_l2norm: bool = True,
+        has_l2norm: bool = True,
+        dropout: float = 0.2,
+        has_act: bool = True,
+        final_act: bool = True,
+        act: str = 'relu',
+        virtual_node: bool = True,
+        multi_head_dim_inner: int = 32,
+        graph_pooling: str = 'add',
+        use_repr: bool = True,
+        repr_type: str = 'no_post_mp',
+        bernoulli_threshold: float = 0.5,
+    ):
+        super().__init__()
+
+        self.use_repr = use_repr
+        self.repr_type = repr_type
+        self.bernoulli_threshold = bernoulli_threshold
+
+        if layers_pre_mp > 0:
+            self.pre_mp = GeneralMultiLayer(
+                name='linear',
+                in_channels=dim_in,
+                out_channels=dim_inner,
+                hidden_channels=dim_inner,
+                num_layers=layers_pre_mp,
+                has_batch_norm=has_bn,
+                has_l2_norm=has_l2norm,
+                dropout=dropout,
+                act=act,
+                final_act=final_act,
+            )
+            dim_in = dim_inner
+        if layers_mp > 0:
+            self.mp = GNNStackStage(
+                in_channels=dim_in,
+                out_channels=dim_inner,
+                num_layers=layers_mp,
+                layer_type=layer_type,
+                stage_type=stage_type,
+                final_l2_norm=final_l2norm,
+                has_batch_norm=has_bn,
+                has_l2_norm=has_l2norm,
+                dropout=dropout,
+                act=act if has_act else None,
+            )
+
+        self.post_mp = GNNInductiveHybridMultiHead(
+            dim_inner,
+            dim_out,
+            num_node_targets,
+            num_graph_targets,
+            layers_post_mp,
+            virtual_node,
+            multi_head_dim_inner,
+            graph_pooling,
+            head_bn,
+            has_l2norm,
+            dropout,
+            act,
+        )
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        from torch_geometric.graphgym.init import init_weights
+        self.apply(init_weights)
+
+    @classmethod
+    def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):
+        r"""Returns a pretrained :class:`GPSE` model on a dataset.
+
+        Args:
+            name (str): The name of the dataset (:obj:`"molpcba"`,
+                :obj:`"zinc"`, :obj:`"pcqm4mv2"`, :obj:`"geom"`,
+                :obj:`"chembl"`).
+            root (str, optional): The root directory to save the pre-trained
+                model. (default: :obj:`"GPSE_pretrained"`)
+        """
+        root = osp.expanduser(osp.normpath(root))
+        os.makedirs(root, exist_ok=True)
+        path = download_url(cls.url_dict[name], root)
+
+        model = GPSE()  # All pretrained models use the default arguments
+        model_state = torch.load(path, map_location='cpu')['model_state']
+        model_state_new = OrderedDict([(k.split('.', 1)[1], v)
+                                       for k, v in model_state.items()])
+        model.load_state_dict(model_state_new)
+
+        # Set the final linear layer to identity if we use hidden reprs
+        if model.use_repr:
+            if model.repr_type == 'one_layer_before':
+                model.post_mp.layer_post_mp.model[-1] = torch.nn.Identity()
+            elif model.repr_type == 'no_post_mp':
+                model.post_mp = IdentityHead()
+            else:
+                raise ValueError(f"Unknown type '{model.repr_type}'")
+
+        model.eval()
+        return model
+
+    def forward(self, batch):
+        for module in self.children():
+            batch = module(batch)
+        return batch
+
+
+class GPSENodeEncoder(torch.nn.Module):
+    r"""A helper linear/MLP encoder that takes the :class:`GPSE` encodings
+    (based on the `"Graph Positional and Structural Encoder"
+    <https://arxiv.org/abs/2307.07107>`_ paper) precomputed as
+    :obj:`batch.pestat_GPSE` in the input graphs, maps them to a desired
+    dimension defined by :obj:`dim_pe_out` and appends them to node features.
+
+    Let's say we have a graph dataset with 64 original node features, and we
+    have generated GPSE encodings of dimension 32, i.e.
+    :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an
+    inner dimension of 128. To do so, we can map the 32-dimensional GPSE
+    encodings to a higher dimension of 64, and then append them to the
+    :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to
+    obtain a 128-dimensional node feature representation.
+    :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and
+    concatenation to :obj:`x`, the outputs of which can be used as input to a
+    GNN:
+
+    .. code-block:: python
+
+        encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
+                                  expand_x=False)
+        gnn = GNN(dim_in=128, dim_out=128, num_layers=4)
+
+        for batch in loader:
+            batch = encoder(batch)
+            batch = gnn(batch)
+            # Do something with the batch, which now includes 128-dimensional
+            # node representations
+
+    Args:
+        dim_emb (int): Size of final node embedding.
+        dim_pe_in (int): Original dimension of :obj:`batch.pestat_GPSE`.
+        dim_pe_out (int): Desired dimension of :class:`GPSE` after the
+            encoder.
+        dim_in (int, optional): Original dimension of input node features,
+            required only if :obj:`expand_x` is set to :obj:`True`.
+            (default: :obj:`None`)
+        expand_x (bool, optional): Expand node features :obj:`x` from
+            :obj:`dim_in` to (:obj:`dim_emb` - :obj:`dim_pe_out`).
+            (default: :obj:`False`)
+        norm_type (str, optional): Type of normalization to apply.
+            (default: :obj:`batchnorm`)
+        model_type (str, optional): Type of encoder, either :obj:`mlp` or
+            :obj:`linear`. (default: :obj:`mlp`)
+        n_layers (int, optional): Number of MLP layers if :obj:`model_type` is
+            :obj:`mlp`. (default: :obj:`2`)
+        dropout_be (float, optional): Dropout ratio of inputs to encoder, i.e.
+            before encoding. (default: :obj:`0.5`)
+        dropout_ae (float, optional): Dropout ratio of outputs, i.e. after
+            encoding. (default: :obj:`0.2`)
+    """
+    def __init__(self, dim_emb: int, dim_pe_in: int, dim_pe_out: int,
+                 dim_in: Optional[int] = None, expand_x=False,
+                 norm_type='batchnorm', model_type='mlp', n_layers=2,
+                 dropout_be=0.5, dropout_ae=0.2):
+        super().__init__()
+
+        assert dim_emb > dim_pe_out, ('Desired GPSE dimension (dim_pe_out) '
+                                      'must be smaller than the final node '
+                                      'embedding dimension (dim_emb).')
+
+        if expand_x:
+            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe_out)
+        self.expand_x = expand_x
+
+        self.raw_norm = None
+        if norm_type == 'batchnorm':
+            self.raw_norm = nn.BatchNorm1d(dim_pe_in)
+
+        self.dropout_be = nn.Dropout(p=dropout_be)
+        self.dropout_ae = nn.Dropout(p=dropout_ae)
+
+        activation = nn.ReLU  # register.act_dict[cfg.gnn.act]
+        if model_type == 'mlp':
+            layers = []
+            if n_layers == 1:
+                layers.append(torch.nn.Linear(dim_pe_in, dim_pe_out))
+                layers.append(activation())
+            else:
+                layers.append(torch.nn.Linear(dim_pe_in, 2 * dim_pe_out))
+                layers.append(activation())
+                for _ in range(n_layers - 2):
+                    layers.append(
+                        torch.nn.Linear(2 * dim_pe_out, 2 * dim_pe_out))
+                    layers.append(activation())
+                layers.append(torch.nn.Linear(2 * dim_pe_out, dim_pe_out))
+                layers.append(activation())
+            self.pe_encoder = nn.Sequential(*layers)
+        elif model_type == 'linear':
+            self.pe_encoder = nn.Linear(dim_pe_in, dim_pe_out)
+        else:
+            raise ValueError(f"{self.__class__.__name__}: Does not support "
+                             f"'{model_type}' encoder model.")
+
+    def forward(self, batch):
+        if not hasattr(batch, 'pestat_GPSE'):
+            raise ValueError('Precomputed "pestat_GPSE" variable is required '
+                             'for GPSENodeEncoder; either run '
+                             '`precompute_GPSE(gpse_model, dataset)` on your '
+                             'dataset or add `AddGPSE(gpse_model)` as a (pre) '
+                             'transform.')
+
+        pos_enc = batch.pestat_GPSE
+
+        pos_enc = self.dropout_be(pos_enc)
+        pos_enc = self.raw_norm(pos_enc) if self.raw_norm else pos_enc
+        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe
+        pos_enc = self.dropout_ae(pos_enc)
+
+        # Expand node features if needed
+        h = self.linear_x(batch.x) if self.expand_x else batch.x
+
+        # Concatenate final PEs to input embedding
+        batch.x = torch.cat((h, pos_enc), 1)
+
+        return batch
+
+
+@torch.no_grad()
+def gpse_process(model: GPSE, data: Data, rand_type: str, use_vn: bool = True,
+                 bernoulli_thresh: float = 0.5, neighbor_loader: bool = False,
+                 num_neighbors: List[int] = [30, 20, 10], fillval: int = 5,
+                 layers_mp: Optional[int] = None, **kwargs) -> torch.Tensor:
+    r"""Processes the data using the :class:`GPSE` model to generate and
+    append GPSE encodings. Identical to :obj:`gpse_process_batch`, but
+    operates on a single :class:`~torch_geometric.data.Data` object.
+
+    Unlike transform-based GPSE processing (i.e.
+    :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument
+    does not append virtual nodes if set to :obj:`True`, and instead assumes
+    the input graphs to :obj:`gpse_process` already have virtual nodes. Under
+    normal circumstances, one does not need to call this function; running
+    :obj:`precompute_GPSE` on your whole dataset is advised instead.
+
+    Args:
+        model (GPSE): The :class:`GPSE` model.
+        data (torch_geometric.data.Data): A
+            :class:`~torch_geometric.data.Data` object.
+        rand_type (str): Type of random features to use. Options are
+            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
+        use_vn (bool, optional): Whether the input graphs have virtual nodes.
+            (default: :obj:`True`)
+        bernoulli_thresh (float, optional): Threshold for Bernoulli sampling
+            of virtual nodes. (default: :obj:`0.5`)
+        neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`.
+            (default: :obj:`False`)
+        num_neighbors (List[int], optional): Number of neighbors to consider
+            for each message-passing layer. (default: :obj:`[30, 20, 10]`)
+        fillval (int, optional): Value to fill for missing
+            :obj:`num_neighbors`. (default: :obj:`5`)
+        layers_mp (int, optional): Number of message-passing layers.
+            (default: :obj:`None`)
+        **kwargs (optional): Additional arguments for :obj:`NeighborLoader`.
+
+    Returns:
+        torch.Tensor: The :class:`GPSE` node encodings of the input graph, to
+        be stored as its :obj:`pestat_GPSE` attribute.
+    """
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # Generate random features for the encoder
+    n = data.num_nodes
+    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]
+
+    # Prepare input distributions for GPSE
+    if rand_type == 'NormalSE':
+        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))
+    elif rand_type == 'UniformSE':
+        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
+    elif rand_type == 'BernoulliSE':
+        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
+        rand = (rand < bernoulli_thresh)
+    else:
+        raise ValueError(f'Unknown {rand_type=!r}')
+    data.x = torch.from_numpy(rand.astype('float32'))
+
+    if use_vn:
+        data.x[-1] = 0
+
+    model, data = model.to(device), data.to(device)
+    # Generate encodings using the pretrained encoder
+    if neighbor_loader:
+        if layers_mp is None:
+            raise ValueError('Please provide the number of message-passing '
+                             'layers as "layers_mp".')
+        diff = layers_mp - len(num_neighbors)
+        if fillval > 0 and diff > 0:
+            num_neighbors += [fillval] * diff
+
+        loader = NeighborLoader(data, num_neighbors=num_neighbors,
+                                shuffle=False, pin_memory=True, **kwargs)
+        out_list = []
+        pbar = trange(data.num_nodes, position=2)
+        for i, batch in enumerate(loader):
+            out, _ = model(batch.to(device))
+            out = out[:batch.batch_size].to('cpu', non_blocking=True)
+            out_list.append(out)
+            pbar.update(batch.batch_size)
+        out = torch.vstack(out_list)
+    else:
+        out, _ = model(data)
+        out = out.to('cpu')
+
+    return out
+
+
+@torch.no_grad()
+def gpse_process_batch(model: GPSE, batch, rand_type: str, use_vn: bool = True,
+                       bernoulli_thresh: float = 0.5,
+                       neighbor_loader: bool = False,
+                       num_neighbors: List[int] = [30, 20, 10],
+                       fillval: int = 5, layers_mp: Optional[int] = None,
+                       **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Processes a batch of data using the :class:`GPSE` model to generate
+    and append :class:`GPSE` encodings. Identical to :obj:`gpse_process`, but
+    operates on a batch of :class:`~torch_geometric.data.Data` objects.
+
+    Unlike transform-based GPSE processing (i.e.
+    :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument
+    does not append virtual nodes if set to :obj:`True`, and instead assumes
+    the input graphs to :obj:`gpse_process` already have virtual nodes. This
+    is because the virtual nodes are already added to graphs before the call
+    to :obj:`gpse_process_batch` in :obj:`precompute_GPSE` for better
+    efficiency. Under normal circumstances, one does not need to call this
+    function; running :obj:`precompute_GPSE` on your whole dataset is advised
+    instead.
+
+    Args:
+        model (GPSE): The :class:`GPSE` model.
+        batch: A batch of PyG Data objects.
+        rand_type (str): Type of random features to use. Options are
+            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
+        use_vn (bool, optional): Whether the input graphs have virtual nodes.
+            (default: :obj:`True`)
+        bernoulli_thresh (float, optional): Threshold for Bernoulli sampling
+            of virtual nodes. (default: :obj:`0.5`)
+        neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`.
+            (default: :obj:`False`)
+        num_neighbors (List[int], optional): Number of neighbors to consider
+            for each message-passing layer. (default: :obj:`[30, 20, 10]`)
+        fillval (int, optional): Value to fill for missing
+            :obj:`num_neighbors`. (default: :obj:`5`)
+        layers_mp (int, optional): Number of message-passing layers.
+            (default: :obj:`None`)
+        **kwargs: Additional keyword arguments for :obj:`NeighborLoader`.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors
+        corresponding to the stacked :class:`GPSE` encodings and the pointers
+        indicating individual graphs.
+    """
+    n = batch.num_nodes
+    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]
+
+    # Prepare input distributions for GPSE
+    if rand_type == 'NormalSE':
+        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))
+    elif rand_type == 'UniformSE':
+        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
+    elif rand_type == 'BernoulliSE':
+        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
+        rand = (rand < bernoulli_thresh)
+    else:
+        raise ValueError(f'Unknown {rand_type=!r}')
+    batch.x = torch.from_numpy(rand.astype('float32'))
+
+    if use_vn:
+        # HACK: We need to reset virtual node features to zeros to match the
+        # pretraining setting (virtual node applied after random node features
+        # are set, and the default node features for the virtual node are all
+        # zeros). Can potentially test if initializing virtual node features
+        # to random features is better than setting them to zeros.
+        for i in batch.ptr[1:]:
+            batch.x[i - 1] = 0
+
+    # Generate encodings using the pretrained encoder
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = model.to(device)
+    if neighbor_loader:
+        if layers_mp is None:
+            raise ValueError('Please provide the number of message-passing '
+                             'layers as "layers_mp".')
+        diff = layers_mp - len(num_neighbors)
+        if fillval > 0 and diff > 0:
+            num_neighbors += [fillval] * diff
+
+        loader = NeighborLoader(batch, num_neighbors=num_neighbors,
+                                shuffle=False, pin_memory=True, **kwargs)
+        out_list = []
+        pbar = trange(batch.num_nodes, position=2)
+        # Do not shadow `batch` here, since `batch.ptr` is returned below:
+        for i, subgraph in enumerate(loader):
+            out, _ = model(subgraph.to(device))
+            out = out[:subgraph.batch_size].to('cpu', non_blocking=True)
+            out_list.append(out)
+            pbar.update(subgraph.batch_size)
+        out = torch.vstack(out_list)
+    else:
+        out, _ = model(batch.to(device))
+        out = out.to('cpu')
+
+    return out, batch.ptr
+
+
+@torch.no_grad()
+def precompute_GPSE(model: GPSE, dataset: Dataset, use_vn: bool = True,
+                    rand_type: str = 'NormalSE', **kwargs):
+    r"""Precomputes :class:`GPSE` encodings in-place for a given dataset using
+    a :class:`GPSE` model.
+
+    Args:
+        model (GPSE): The :class:`GPSE` model.
+        dataset (Dataset): A PyG Dataset.
+        use_vn (bool, optional): Whether to append virtual nodes to graphs in
+            :class:`GPSE` computation. Should match the setting used when
+            pre-training the :class:`GPSE` model. (default :obj:`True`)
+        rand_type (str, optional): The type of randomization to use.
+            (default :obj:`NormalSE`)
+        **kwargs (optional): Additional arguments for
+            :class:`~torch_geometric.loader.DataLoader`.
+    """
+    # Temporarily replace the transformation
+    orig_dataset_transform = dataset.transform
+    dataset.transform = None
+    if use_vn:
+        dataset.transform = T.VirtualNode()
+
+    # Remove split indices, to be recovered at the end of the precomputation
+    tmp_store = {}
+    for name in [
+            'train_mask', 'val_mask', 'test_mask', 'train_graph_index',
+            'val_graph_index', 'test_graph_index', 'train_edge_index',
+            'val_edge_index', 'test_edge_index'
+    ]:
+        if (name in dataset.data) and (dataset.slices is None
+                                       or name in dataset.slices):
+            tmp_store_data = dataset.data.pop(name)
+            tmp_store_slices = dataset.slices.pop(name) \
+                if dataset.slices else None
+            tmp_store[name] = (tmp_store_data, tmp_store_slices)
+
+    loader = DataLoader(dataset, shuffle=False, pin_memory=True, **kwargs)
+
+    # Batched GPSE precomputation loop
+    data_list = []
+    curr_idx = 0
+    pbar = trange(len(dataset), desc='Pre-computing GPSE')
+    tic = time.perf_counter()
+    for batch in loader:
+        batch_out, batch_ptr = gpse_process_batch(model, batch, rand_type,
+                                                  **kwargs)
+
+        batch_out = batch_out.to('cpu', non_blocking=True)
+        # Need to wait for batch_ptr to finish transferring so that start and
+        # end indices are ready to use
+        batch_ptr = batch_ptr.to('cpu', non_blocking=False)
+
+        for start, end in zip(batch_ptr[:-1], batch_ptr[1:]):
+            data = dataset.get(curr_idx)
+            if use_vn:
+                end = end - 1
+            data.pestat_GPSE = batch_out[start:end]
+            data_list.append(data)
+            curr_idx += 1
+
+        pbar.update(len(batch_ptr) - 1)
+    pbar.close()
+
+    # Collate dataset and reset indices and data list
+    dataset.transform = orig_dataset_transform
+    dataset._indices = None
+    dataset._data_list = data_list
+    dataset.data, dataset.slices = dataset.collate(data_list)
+
+    # Recover split indices
+    for name, (tmp_store_data, tmp_store_slices) in tmp_store.items():
+        dataset.data[name] = tmp_store_data
+        if tmp_store_slices is not None:
+            dataset.slices[name] = tmp_store_slices
+    dataset._data_list = None
+
+    timestr = time.strftime('%H:%M:%S',
+                            time.gmtime(time.perf_counter() - tic))
+    logging.info(f'Finished GPSE pre-computation, took {timestr}')
+
+    # Release resource and recover original configs
+    del model
+    torch.cuda.empty_cache()
+
+
+def cosim_col_sep(pred: torch.Tensor, true: torch.Tensor,
+                  batch_idx: torch.Tensor) -> torch.Tensor:
+    r"""Calculates the average cosine similarity between predicted and true
+    features on a batch of graphs.
+
+    Args:
+        pred (torch.Tensor): Predicted outputs.
+        true (torch.Tensor): Value of ground truths.
+        batch_idx (torch.Tensor): Batch indices to separate the graphs.
+
+    Returns:
+        torch.Tensor: One minus the average cosine similarity per graph in
+        the batch.
+
+    Raises:
+        ValueError: If :obj:`batch_idx` is not specified.
+    """
+    if batch_idx is None:
+        raise ValueError('cosim_col_sep requires the batch index as input '
+                         'to distinguish different graphs.')
+    batch_idx = batch_idx + 1 if batch_idx.min() == -1 else batch_idx
+    pred_dense = to_dense_batch(pred, batch_idx)[0]
+    true_dense = to_dense_batch(true, batch_idx)[0]
+    mask = (true_dense == 0).all(1)  # exclude trivial features from loss
+    loss = 1 - F.cosine_similarity(pred_dense, true_dense, dim=1)[~mask].mean()
+    return loss
+
+
+def gpse_loss(pred: torch.Tensor, true: torch.Tensor,
+              batch_idx: Optional[torch.Tensor] = None) \
+        -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Calculates :class:`GPSE` loss as the sum of MAE loss and cosine
+    similarity loss over a batch of graphs.
+
+    Args:
+        pred (torch.Tensor): Predicted outputs.
+        true (torch.Tensor): Value of ground truths.
+        batch_idx (torch.Tensor): Batch indices to separate the graphs.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors
+        corresponding to the :class:`GPSE` loss and the predicted
+        node-and-graph level outputs.
+    """
+    if batch_idx is None:
+        raise ValueError('gpse_loss requires the batch index as input '
+                         'to distinguish different graphs.')
+    mae_loss = F.l1_loss(pred, true)
+    cosim_loss = cosim_col_sep(pred, true, batch_idx)
+    loss = mae_loss + cosim_loss
+    return loss, pred
+
+
+def process_batch_idx(batch_idx, true, use_vn=True):
+    r"""Processes batch indices to adjust for the removal of virtual nodes,
+    and pads the batch index for hybrid tasks.
+
+    Args:
+        batch_idx: Batch indices to separate the graphs.
+        true: Value of ground truths.
+        use_vn: If input graphs have virtual nodes that need to be removed.
+
+    Returns:
+        torch.Tensor: Batch indices that separate the graphs.
+    """
+    if batch_idx is None:
+        return
+    if use_vn:  # remove virtual node
+        batch_idx = torch.concat([
+            batch_idx[batch_idx == i][:-1]
+            for i in range(batch_idx.max().item() + 1)
+        ])
+    # Pad batch index for hybrid tasks (set batch index for graph heads to -1)
+    if (pad := true.shape[0] - batch_idx.shape[0]) > 0:
+        pad_idx = -torch.ones(pad, dtype=torch.long, device=batch_idx.device)
+        batch_idx = torch.hstack([batch_idx, pad_idx])
+    return batch_idx
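
The functional API at the end of `gpse.py` (`gpse_process`, `gpse_process_batch`, `precompute_GPSE`) backs both usage options shown in the `GPSE` docstring. A minimal sketch of the eager precomputation path, assuming this nightly wheel is installed; the ZINC root and batch size are illustrative placeholders, and the width of the resulting encodings depends on the pretrained model's hidden dimension:

    from torch_geometric.datasets import ZINC
    from torch_geometric.nn import GPSE
    from torch_geometric.nn.models.gpse import precompute_GPSE

    gpse_model = GPSE.from_pretrained('molpcba')
    dataset = ZINC('data/ZINC', subset=True, split='train')  # placeholder root
    precompute_GPSE(gpse_model, dataset, batch_size=128)

    # Each graph now carries its encodings in-place:
    print(dataset[0].pestat_GPSE.shape)  # [num_nodes, encoding_dim]

To feed a downstream GNN, `GPSENodeEncoder` (documented above) then maps `pestat_GPSE` to a chosen width and concatenates it onto float node features.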
@@ -37,6 +37,7 @@ from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph
 from .largest_connected_components import LargestConnectedComponents
 from .virtual_node import VirtualNode
 from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
+from .add_gpse import AddGPSE
 from .feature_propagation import FeaturePropagation
 from .half_hop import HalfHop
 
@@ -108,6 +109,7 @@ graph_transforms = [
     'VirtualNode',
     'AddLaplacianEigenvectorPE',
     'AddRandomWalkPE',
+    'AddGPSE',
     'FeaturePropagation',
     'HalfHop',
 ]
@@ -0,0 +1,39 @@
+from torch_geometric.data import Data
+from torch_geometric.data.datapipes import functional_transform
+from torch_geometric.nn.models.gpse import GPSE
+from torch_geometric.transforms import BaseTransform, VirtualNode
+
+
+@functional_transform('add_gpse')
+class AddGPSE(BaseTransform):
+    r"""Adds the GPSE encoding from the `"Graph Positional and Structural
+    Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
+    (functional name: :obj:`add_gpse`).
+    To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates
+    the actual encodings.
+
+    Args:
+        model (GPSE): The pre-trained GPSE model.
+        use_vn (bool, optional): Whether to use virtual nodes.
+            (default: :obj:`True`)
+        rand_type (str, optional): Type of random features to use. Options are
+            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
+            (default: :obj:`NormalSE`)
+    """
+    def __init__(self, model: GPSE, use_vn: bool = True,
+                 rand_type: str = 'NormalSE'):
+        self.model = model
+        self.use_vn = use_vn
+        self.vn = VirtualNode()
+        self.rand_type = rand_type
+
+    def __call__(self, data: Data) -> Data:
+        from torch_geometric.nn.models.gpse import gpse_process
+
+        data_vn = self.vn(data.clone()) if self.use_vn else data.clone()
+        batch_out = gpse_process(self.model, data_vn, self.rand_type,
+                                 self.use_vn)
+        batch_out = batch_out.to('cpu', non_blocking=True)
+        data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out
+
+        return data
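
The transform wraps `gpse_process`, so the encodings can also be attached lazily when a dataset is first materialized. A sketch mirroring Option 2 of the `GPSE` docstring, with a placeholder dataset root; since the transform runs one graph at a time, `precompute_GPSE` is usually preferable for whole datasets, as noted in the `gpse_process` docstring:

    from torch_geometric.datasets import ZINC
    from torch_geometric.nn import GPSE
    from torch_geometric.transforms import AddGPSE

    gpse_model = GPSE.from_pretrained('molpcba')
    dataset = ZINC('data/ZINC', subset=True, split='train',
                   pre_transform=AddGPSE(gpse_model, use_vn=True,
                                         rand_type='NormalSE'))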