pyg-nightly 2.7.0.dev20250404__py3-none-any.whl → 2.7.0.dev20250405__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250404.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250404.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/RECORD +9 -7
- torch_geometric/__init__.py +1 -1
- torch_geometric/nn/models/__init__.py +3 -0
- torch_geometric/nn/models/gpse.py +1079 -0
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +39 -0
- {pyg_nightly-2.7.0.dev20250404.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250404.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250404.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.dev20250404
|
3
|
+
Version: 2.7.0.dev20250405
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=Gv9oWwU7Y-5CKcoO3SA059-jYYjsTMcsubgypKMaUcQ,1978
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -426,7 +426,7 @@ torch_geometric/nn/kge/distmult.py,sha256=dGQ0bVzjreZgFN1lXE23_IIidsiOq7ehPrMb-N
|
|
426
426
|
torch_geometric/nn/kge/loader.py,sha256=5Uc1j3OUMQnBYSHDqL7pLCty1siFLzoPkztigYO2zP8,771
|
427
427
|
torch_geometric/nn/kge/rotate.py,sha256=XLuO1AbyTt5cJxr97ZzoyAyIEsHKesgW5TvDmnGJAao,3208
|
428
428
|
torch_geometric/nn/kge/transe.py,sha256=jlejq5BLMm-sb1wWcLDp7pZqCdelWBgjDIC8ctbjSdU,3088
|
429
|
-
torch_geometric/nn/models/__init__.py,sha256=
|
429
|
+
torch_geometric/nn/models/__init__.py,sha256=4mZ5dyiZ9aa1NaBth1qYV-hZdnG_Np1XWvRLB4Qv6RM,2338
|
430
430
|
torch_geometric/nn/models/attentive_fp.py,sha256=tkgvw28wg9-JqHIfBllfCwTHrZIUiv85yZJcDqjz3z0,6634
|
431
431
|
torch_geometric/nn/models/attract_repel.py,sha256=h9OyogT0NY0xiT0DkpJHMxH6ZUmo8R-CmwZdKEwq8Ek,5277
|
432
432
|
torch_geometric/nn/models/autoencoder.py,sha256=nGje-zty78Y3hxOJ9o0_6QziJjOvBlknk6z0_fDQwQU,10770
|
@@ -441,6 +441,7 @@ torch_geometric/nn/models/g_retriever.py,sha256=CdSOasnPiMvq5AjduNTpz-LIZiNp3X0x
|
|
441
441
|
torch_geometric/nn/models/git_mol.py,sha256=Wc6Hx6RDDR7sDWRWHfA5eK9e9gFsrTZ9OLmpMfoj3pE,12676
|
442
442
|
torch_geometric/nn/models/glem.py,sha256=sT0XM4klVlci9wduvUoXupATUw9p25uXtaJBrmv3yvs,16431
|
443
443
|
torch_geometric/nn/models/gnnff.py,sha256=15dkiLgy0LmH1hnUrpeoHioIp4BPTfjpVATpnGRt9E0,7860
|
444
|
+
torch_geometric/nn/models/gpse.py,sha256=my-KIw_Ov8o0pXSCyh43NZRBAW95TFfmBgxzSimx8-A,42680
|
444
445
|
torch_geometric/nn/models/graph_mixer.py,sha256=mthMeCOikR8gseEsu4oJ3Cd9C35zHSv1p32ROwnG-6s,9246
|
445
446
|
torch_geometric/nn/models/graph_unet.py,sha256=N8TSmJo8AlbZjjcame0xW_jZvMOirL5ahw6qv5Yjpbs,5586
|
446
447
|
torch_geometric/nn/models/jumping_knowledge.py,sha256=9JR2EoViXKjcDSLb8tjJm-UHfv1mQCJvZAAEsYa0Ocw,5496
|
@@ -521,7 +522,8 @@ torch_geometric/testing/decorators.py,sha256=j45wlxMB1-Pn3wPKBgDziqg6KkWJUb_fcwf
|
|
521
522
|
torch_geometric/testing/distributed.py,sha256=ZZCCXqiQC4-m1ExSjDZhS_a1qPXnHEwhJGTmACxNnVI,2227
|
522
523
|
torch_geometric/testing/feature_store.py,sha256=J6JBIt2XK-t8yG8B4JzXp-aJcVl5jaCS1m2H7d6OUxs,2158
|
523
524
|
torch_geometric/testing/graph_store.py,sha256=00B7QToCIspYmgN7svQKp1iU-qAzEtrt3VQRFxkHfuk,1044
|
524
|
-
torch_geometric/transforms/__init__.py,sha256=
|
525
|
+
torch_geometric/transforms/__init__.py,sha256=P0R2CFg9pXxjTX4NnYfNPrifRPAw5lVXEOxO80q-1Ek,4296
|
526
|
+
torch_geometric/transforms/add_gpse.py,sha256=4o0UrSmTu3CKsL3UAREiul8O4lC02PUx_ajxP4sPsxU,1570
|
525
527
|
torch_geometric/transforms/add_metapaths.py,sha256=GabaPRvUnpFrZJsxLMUBY2Egzx94GTgsMxegL_qTtbk,14239
|
526
528
|
torch_geometric/transforms/add_positional_encoding.py,sha256=tuilyubAn3yeyz8mvFc5zxXTlNzh8okKzG9AE2lPG1Q,6049
|
527
529
|
torch_geometric/transforms/add_remaining_self_loops.py,sha256=ItU5FAcE-mkbp_wqTLkRhv0RShR5JVr8vr9d5xv3_Ak,2085
|
@@ -634,7 +636,7 @@ torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5
|
|
634
636
|
torch_geometric/visualization/__init__.py,sha256=PyR_4K5SafsJrBr6qWrkjKr6GBL1b7FtZybyXCDEVwY,154
|
635
637
|
torch_geometric/visualization/graph.py,sha256=ZuLPL92yGRi7lxlqsUPwL_EVVXF7P2kMcveTtW79vpA,4784
|
636
638
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
637
|
-
pyg_nightly-2.7.0.
|
638
|
-
pyg_nightly-2.7.0.
|
639
|
-
pyg_nightly-2.7.0.
|
640
|
-
pyg_nightly-2.7.0.
|
639
|
+
pyg_nightly-2.7.0.dev20250405.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
640
|
+
pyg_nightly-2.7.0.dev20250405.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
641
|
+
pyg_nightly-2.7.0.dev20250405.dist-info/METADATA,sha256=9NrBoIq1lq6tX1hOXmse3Dvp7PZNQjcsoyGzucHSgp8,63021
|
642
|
+
pyg_nightly-2.7.0.dev20250405.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.
|
34
|
+
__version__ = '2.7.0.dev20250405'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
@@ -12,6 +12,7 @@ from .re_net import RENet
|
|
12
12
|
from .graph_unet import GraphUNet
|
13
13
|
from .schnet import SchNet
|
14
14
|
from .dimenet import DimeNet, DimeNetPlusPlus
|
15
|
+
from .gpse import GPSE, GPSENodeEncoder
|
15
16
|
from .captum import to_captum_model
|
16
17
|
from .metapath2vec import MetaPath2Vec
|
17
18
|
from .deepgcn import DeepGCNLayer
|
@@ -62,6 +63,8 @@ __all__ = classes = [
|
|
62
63
|
'SchNet',
|
63
64
|
'DimeNet',
|
64
65
|
'DimeNetPlusPlus',
|
66
|
+
'GPSE',
|
67
|
+
'GPSENodeEncoder',
|
65
68
|
'to_captum_model',
|
66
69
|
'to_captum_input',
|
67
70
|
'captum_output_to_dicts',
|
@@ -0,0 +1,1079 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
import os.path as osp
|
4
|
+
import time
|
5
|
+
from collections import OrderedDict
|
6
|
+
from typing import List, Optional, Tuple
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import torch
|
10
|
+
import torch.nn as nn
|
11
|
+
import torch.nn.functional as F
|
12
|
+
from tqdm import trange
|
13
|
+
|
14
|
+
import torch_geometric.transforms as T
|
15
|
+
from torch_geometric.data import Data, Dataset, download_url
|
16
|
+
from torch_geometric.loader import DataLoader, NeighborLoader
|
17
|
+
from torch_geometric.nn import (
|
18
|
+
ResGatedGraphConv,
|
19
|
+
global_add_pool,
|
20
|
+
global_max_pool,
|
21
|
+
global_mean_pool,
|
22
|
+
)
|
23
|
+
from torch_geometric.nn.resolver import activation_resolver
|
24
|
+
from torch_geometric.utils import to_dense_batch
|
25
|
+
|
26
|
+
|
27
|
+
class Linear(torch.nn.Module):
    r"""Thin wrapper around :class:`torch.nn.Linear` that accepts either a
    plain tensor or a batch-like object carrying node features in ``x``.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        bias (bool): Whether the affine map has an additive bias.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool,
    ) -> None:
        super().__init__()
        # The submodule name ``model`` is part of the state-dict key layout,
        # so it must stay stable for checkpoint loading.
        self.model = torch.nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, batch):
        # Tensor input: transform and return the resulting tensor.
        if isinstance(batch, torch.Tensor):
            return self.model(batch)
        # Batch-like input: replace ``batch.x`` and hand the object back.
        batch.x = self.model(batch.x)
        return batch
|
43
|
+
|
44
|
+
|
45
|
+
class ResGatedGCNConv(torch.nn.Module):
    r"""Adapter that runs :class:`~torch_geometric.nn.ResGatedGraphConv` on
    ``batch.x`` / ``batch.edge_index`` and writes the result to ``batch.x``.

    Args:
        in_channels (int): Input feature size.
        out_channels (int): Output feature size.
        bias (bool): Forwarded to :class:`ResGatedGraphConv`.
        **kwargs: Additional keyword arguments forwarded to
            :class:`ResGatedGraphConv`.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool,
        **kwargs,
    ) -> None:
        super().__init__()
        # Kept under the name ``model`` so state-dict keys stay stable.
        self.model = ResGatedGraphConv(in_channels, out_channels, bias=bias,
                                       **kwargs)

    def forward(self, batch):
        node_x, edges = batch.x, batch.edge_index
        batch.x = self.model(node_x, edges)
        return batch
|
64
|
+
|
65
|
+
|
66
|
+
class GeneralLayer(torch.nn.Module):
    r"""A single layer of type :obj:`name` followed by optional batch
    normalization, dropout and activation (in that order), and finally an
    optional L2 normalization of the output features.

    Args:
        name (str): Layer type, either :obj:`"linear"` or
            :obj:`"resgatedgcnconv"`.
        in_channels (int): Input feature size.
        out_channels (int): Output feature size.
        has_batch_norm (bool): Whether to apply batch normalization. When
            set, the wrapped layer is created without a bias.
        has_l2_norm (bool): Whether to L2-normalize the output features.
        dropout (float): Dropout probability. No dropout module is added
            unless it is positive.
        act (str, optional): Activation name passed to
            :func:`activation_resolver`, or :obj:`None` for no activation.
        **kwargs: Extra keyword arguments forwarded to the wrapped layer.
    """
    def __init__(
        self,
        name: str,
        in_channels: int,
        out_channels: int,
        has_batch_norm: bool,
        has_l2_norm: bool,
        dropout: float,
        act: Optional[str],
        **kwargs,
    ):
        super().__init__()
        self.has_l2_norm = has_l2_norm

        layer_cls = {
            'linear': Linear,
            'resgatedgcnconv': ResGatedGCNConv,
        }[name]
        # A bias is redundant when batch norm directly follows the layer.
        self.layer = layer_cls(
            in_channels,
            out_channels,
            bias=not has_batch_norm,
            **kwargs,
        )

        post: List[nn.Module] = []
        if has_batch_norm:
            post.append(
                torch.nn.BatchNorm1d(out_channels, eps=1e-5, momentum=0.1))
        if dropout > 0:
            post.append(torch.nn.Dropout(p=dropout, inplace=False))
        if act is not None:
            post.append(activation_resolver(act))
        self.post_layer = nn.Sequential(*post)

    def forward(self, batch):
        batch = self.layer(batch)
        is_tensor = isinstance(batch, torch.Tensor)
        feats = batch if is_tensor else batch.x
        feats = self.post_layer(feats)
        if self.has_l2_norm:
            feats = F.normalize(feats, p=2, dim=1)
        if is_tensor:
            return feats
        batch.x = feats
        return batch
|
112
|
+
|
113
|
+
|
114
|
+
class GeneralMultiLayer(torch.nn.Module):
    r"""Sequential stack of :obj:`num_layers` :class:`GeneralLayer` modules
    of type :obj:`name`, registered as ``Layer_0``, ``Layer_1``, ...

    Args:
        name (str): Layer type passed to every :class:`GeneralLayer`.
        in_channels (int): Input size of the first layer.
        out_channels (int): Output size of the last layer.
        hidden_channels (int, optional): Feature size between layers. Falls
            back to :obj:`out_channels` when falsy (e.g. :obj:`None`).
        num_layers (int): Number of stacked layers.
        has_batch_norm (bool): Forwarded to every layer.
        has_l2_norm (bool): Forwarded to every layer.
        dropout (float): Forwarded to every layer.
        act (str): Activation of every layer except possibly the last.
        final_act (bool): Whether the last layer also applies :obj:`act`.
        **kwargs: Forwarded to every layer.
    """
    def __init__(
        self,
        name: str,
        in_channels: int,
        out_channels: int,
        hidden_channels: Optional[int],
        num_layers: int,
        has_batch_norm: bool,
        has_l2_norm: bool,
        dropout: float,
        act: str,
        final_act: bool,
        **kwargs,
    ) -> None:
        super().__init__()
        hidden_channels = hidden_channels or out_channels
        last_idx = num_layers - 1

        for idx in range(num_layers):
            is_last = idx == last_idx
            block = GeneralLayer(
                name=name,
                in_channels=hidden_channels if idx > 0 else in_channels,
                out_channels=out_channels if is_last else hidden_channels,
                has_batch_norm=has_batch_norm,
                has_l2_norm=has_l2_norm,
                dropout=dropout,
                act=act if (final_act or not is_last) else None,
                **kwargs,
            )
            # Child names are part of the state-dict key layout.
            self.add_module(f'Layer_{idx}', block)

    def forward(self, batch):
        for block in self.children():
            batch = block(batch)
        return batch
|
151
|
+
|
152
|
+
|
153
|
+
class BatchNorm1dNode(torch.nn.Module):
    r"""Applies :class:`torch.nn.BatchNorm1d` to the node features
    ``batch.x`` and returns the same batch object.

    Args:
        channels (int): Number of node feature channels.
    """
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(num_features=channels, eps=1e-5,
                                       momentum=0.1)

    def forward(self, batch):
        normalized = self.bn(batch.x)
        batch.x = normalized
        return batch
|
161
|
+
|
162
|
+
|
163
|
+
class BatchNorm1dEdge(torch.nn.Module):
    r"""Applies :class:`torch.nn.BatchNorm1d` to the edge features
    ``batch.edge_attr`` and returns the same batch object.

    Args:
        channels (int): Number of edge feature channels.
    """
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(num_features=channels, eps=1e-5,
                                       momentum=0.1)

    def forward(self, batch):
        normalized = self.bn(batch.edge_attr)
        batch.edge_attr = normalized
        return batch
|
171
|
+
|
172
|
+
|
173
|
+
class MLP(torch.nn.Module):
    r"""Multi-layer perceptron made of :obj:`num_layers - 1` hidden
    :class:`GeneralMultiLayer` layers followed by a final biased
    :class:`Linear` projection.

    Accepts either a tensor or a batch-like object whose ``x`` attribute is
    transformed and reassigned.

    Args:
        in_channels (int): Input feature size.
        out_channels (int): Output feature size.
        hidden_channels (int, optional): Hidden feature size. Falls back to
            :obj:`in_channels` when falsy (e.g. :obj:`None`).
        num_layers (int): Total number of layers including the final
            projection. With :obj:`num_layers <= 1` only a single projection
            from :obj:`in_channels` to :obj:`out_channels` is used.
        has_batch_norm (bool, optional): Whether hidden layers apply batch
            normalization. (default: :obj:`True`)
        has_l2_norm (bool, optional): Whether hidden layers L2-normalize
            their outputs. (default: :obj:`True`)
        dropout (float, optional): Dropout probability of hidden layers.
            (default: :obj:`0.2`)
        act (str, optional): Activation of hidden layers.
            (default: :obj:`"relu"`)
        **kwargs: Forwarded to :class:`GeneralMultiLayer`.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: Optional[int],
        num_layers: int,
        has_batch_norm: bool = True,
        has_l2_norm: bool = True,
        dropout: float = 0.2,
        act: str = 'relu',
        **kwargs,
    ):
        super().__init__()
        hidden_channels = hidden_channels or in_channels

        layers = []
        if num_layers > 1:
            layers.append(
                GeneralMultiLayer(
                    'linear',
                    in_channels,
                    hidden_channels,
                    hidden_channels,
                    num_layers - 1,
                    has_batch_norm,
                    has_l2_norm,
                    dropout,
                    act,
                    final_act=True,
                    **kwargs,
                ))
            proj_in = hidden_channels
        else:
            # Without hidden layers the projection consumes the raw input.
            # Sizing it by ``hidden_channels`` would break whenever that
            # differs from ``in_channels`` (e.g. a 1-layer PSE head).
            proj_in = in_channels
        layers.append(Linear(proj_in, out_channels, bias=True))
        # ``model`` and its indices are part of the state-dict key layout.
        self.model = nn.Sequential(*layers)

    def forward(self, batch):
        if isinstance(batch, torch.Tensor):
            batch = self.model(batch)
        else:
            batch.x = self.model(batch.x)
        return batch
|
214
|
+
|
215
|
+
|
216
|
+
class GNNStackStage(torch.nn.Module):
    r"""Stack of :obj:`num_layers` message-passing :class:`GeneralLayer`
    modules with optional skip connections, registered as ``layer0``,
    ``layer1``, ...

    Args:
        in_channels (int): Input feature size of the first layer.
        out_channels (int): Output feature size of every layer.
        num_layers (int): Number of layers.
        layer_type (str): Layer type passed to :class:`GeneralLayer`.
        stage_type (str, optional): :obj:`"skipsum"` adds each layer's
            input to its output. :obj:`"skipconcat"` concatenates the input
            in front of the output, except after the last layer. Any other
            value disables skip connections. (default: :obj:`"skipsum"`)
        final_l2_norm (bool, optional): Whether to L2-normalize the stage
            output along the last dimension. (default: :obj:`True`)
        has_batch_norm (bool, optional): Forwarded to every layer.
            (default: :obj:`True`)
        has_l2_norm (bool, optional): Forwarded to every layer.
            (default: :obj:`True`)
        dropout (float, optional): Forwarded to every layer.
            (default: :obj:`0.2`)
        act (str, optional): Forwarded to every layer.
            (default: :obj:`"relu"`)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int,
        layer_type: str,
        stage_type: str = 'skipsum',
        final_l2_norm: bool = True,
        has_batch_norm: bool = True,
        has_l2_norm: bool = True,
        dropout: float = 0.2,
        act: Optional[str] = 'relu',
    ):
        super().__init__()
        self.num_layers = num_layers
        self.stage_type = stage_type
        self.final_l2_norm = final_l2_norm

        concat = stage_type == 'skipconcat'
        for idx in range(num_layers):
            if idx == 0:
                width = in_channels
            elif concat:
                # Every earlier layer's output has been appended to the input.
                width = in_channels + idx * out_channels
            else:
                width = out_channels
            self.add_module(
                f'layer{idx}',
                GeneralLayer(layer_type, width, out_channels, has_batch_norm,
                             has_l2_norm, dropout, act),
            )

    def forward(self, batch):
        last_idx = self.num_layers - 1
        for idx, block in enumerate(self.children()):
            residual = batch.x
            batch = block(batch)
            if self.stage_type == 'skipsum':
                batch.x = residual + batch.x
            elif self.stage_type == 'skipconcat' and idx < last_idx:
                batch.x = torch.cat([residual, batch.x], dim=1)

        if self.final_l2_norm:
            batch.x = F.normalize(batch.x, p=2, dim=-1)

        return batch
|
260
|
+
|
261
|
+
|
262
|
+
class GNNInductiveHybridMultiHead(torch.nn.Module):
    r"""GNN prediction head for inductive node- and graph-level prediction
    tasks that uses a separate MLP per target.

    Every node-level target gets its own MLP with a single output channel.
    All graph-level targets share one MLP applied to pooled node features.
    The forward pass returns a ``(pred, true)`` pair in which node rows are
    stacked above graph rows and zero-padded to a common width.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension. Unused; the output sizes are given
            by :obj:`num_node_targets` and :obj:`num_graph_targets`.
        num_node_targets (int): Number of individual PSEs used as node-level
            targets in pretraining :class:`GPSE`.
        num_graph_targets (int): Number of graph-level targets used in
            pretraining :class:`GPSE`.
        layers_post_mp (int): Number of MLP layers after GNN message-passing.
        virtual_node (bool, optional): Whether a virtual node is added to
            graphs in :class:`GPSE` computation. If set, the last node of
            every graph is dropped from the node-level outputs.
            (default: :obj:`True`)
        multi_head_dim_inner (int, optional): Width of MLPs for PSE target
            prediction heads. (default: :obj:`32`)
        graph_pooling (str, optional): Type of graph pooling applied before
            the graph-level MLP. Options are :obj:`add`, :obj:`max`,
            :obj:`mean`. (default: :obj:`add`)
        has_bn (bool, optional): Whether hidden MLP layers apply batch
            normalization. (default: :obj:`True`)
        has_l2norm (bool, optional): Whether hidden MLP layers apply L2
            normalization. (default: :obj:`True`)
        dropout (float, optional): Dropout ratio of hidden MLP layers.
            (default: :obj:`0.2`)
        act (str, optional): Activation of hidden MLP layers.
            (default: :obj:`relu`)
    """
    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        num_node_targets: int,
        num_graph_targets: int,
        layers_post_mp: int,
        virtual_node: bool = True,
        multi_head_dim_inner: int = 32,
        graph_pooling: str = 'add',
        has_bn: bool = True,
        has_l2norm: bool = True,
        dropout: float = 0.2,
        act: str = 'relu',
    ):
        super().__init__()
        self.node_target_dim = num_node_targets
        self.graph_target_dim = num_graph_targets
        self.virtual_node = virtual_node

        # One single-output MLP per node-level target.
        self.node_post_mps = nn.ModuleList(
            MLP(dim_in, 1, multi_head_dim_inner, layers_post_mp, has_bn,
                has_l2norm, dropout, act)
            for _ in range(num_node_targets))

        self.graph_pooling = {
            'add': global_add_pool,
            'max': global_max_pool,
            'mean': global_mean_pool,
        }[graph_pooling]

        self.graph_post_mp = MLP(dim_in, num_graph_targets, dim_in,
                                 layers_post_mp, has_bn, has_l2norm, dropout,
                                 act)

    def _pad_and_stack(self, x1: torch.Tensor, x2: torch.Tensor, pad1: int,
                       pad2: int):
        # Right-pad the top block and left-pad the bottom block along the
        # last dimension so the two sit in disjoint columns.
        top = F.pad(x1, (0, pad2))
        bottom = F.pad(x2, (pad1, 0))
        return torch.vstack([top, bottom])

    def _apply_index(self, batch, virtual_node: bool, pad_node: int,
                     pad_graph: int):
        graph_pred, graph_true = batch.graph_feature, batch.y_graph
        node_pred, node_true = batch.node_feature, batch.y

        if virtual_node:
            # Drop the last node of each graph, treated as its virtual node.
            num_graphs = batch.batch.max().item() + 1
            kept = []
            for graph_id in range(num_graphs):
                members = torch.where(batch.batch == graph_id)[0]
                kept.append(members[:-1])
            idx = torch.concat(kept)
            node_pred = node_pred[idx]
            node_true = node_true[idx]

        # Node rows on top, graph rows below, zero-padded to equal width.
        pred = self._pad_and_stack(node_pred, graph_pred, pad_node, pad_graph)
        true = self._pad_and_stack(node_true, graph_true, pad_node, pad_graph)

        return pred, true

    def forward(self, batch):
        per_target = [head(batch.x) for head in self.node_post_mps]
        batch.node_feature = torch.hstack(per_target)
        pooled = self.graph_pooling(batch.x, batch.batch)
        batch.graph_feature = self.graph_post_mp(pooled)
        return self._apply_index(batch, self.virtual_node,
                                 self.node_target_dim, self.graph_target_dim)
|
359
|
+
|
360
|
+
|
361
|
+
class IdentityHead(torch.nn.Module):
    r"""Pass-through head: returns the node representations ``batch.x``
    together with the targets ``batch.y`` without modification."""
    def forward(self, batch):
        features, targets = batch.x, batch.y
        return features, targets
|
364
|
+
|
365
|
+
|
366
|
+
class GPSE(torch.nn.Module):
    r"""The Graph Positional and Structural Encoder (GPSE) model from the
    `"Graph Positional and Structural Encoder"
    <https://arxiv.org/abs/2307.07107>`_ paper.

    The GPSE model consists of (1) a deep GNN made of stacked message
    passing layers, and (2) a prediction head that predicts pre-computed
    positional and structural encodings (PSE).
    When used on downstream datasets, these prediction heads are removed and
    the final fully-connected layer outputs are used as learned PSE embeddings.

    GPSE also provides a class method :meth:`from_pretrained` to load
    pre-trained GPSE models trained on a variety of molecular datasets.

    .. code-block:: python

        from torch_geometric.datasets import ZINC
        from torch_geometric.nn import GPSE, GPSENodeEncoder
        from torch_geometric.transforms import AddGPSE
        from torch_geometric.nn.models.gpse import precompute_GPSE

        gpse_model = GPSE.from_pretrained('molpcba')

        # Option 1: Precompute GPSE encodings in-place for a given dataset
        dataset = ZINC(path, subset=True, split='train')
        precompute_GPSE(gpse_model, dataset)

        # Option 2: Use the GPSE model with AddGPSE as a pre_transform to save
        # the encodings
        dataset = ZINC(path, subset=True, split='train',
                       pre_transform=AddGPSE(gpse_model, vn=True,
                                             rand_type='NormalSE'))

    Both approaches append the generated encodings to the :obj:`pestat_GPSE`
    attribute of :class:`~torch_geometric.data.Data` objects. To use the GPSE
    encodings for a downstream task, one may need to add these encodings to the
    :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects. To
    do so, one can use the :class:`GPSENodeEncoder` provided to map these
    encodings to a desired dimension before appending them to :obj:`x`.

    Let's say we have a graph dataset with 64 original node features, and we
    have generated GPSE encodings of dimension 32, i.e.
    :obj:`data.pestat_GPSE.size(-1) == 32`. Additionally, we want to use a GNN
    with an inner dimension of 128. To do so, we can map the 32-dimensional
    GPSE encodings to a higher dimension of 64, and then append them to the
    :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to
    obtain a 128-dimensional node feature representation.
    :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and
    concatenation to :obj:`x`, the outputs of which can be used as input to a
    GNN:

    .. code-block:: python

        encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                  expand_x=False)
        gnn = GNN(dim_in=128, dim_out=128, num_layers=4)

        for batch in loader:
            batch = encoder(batch)
            batch = gnn(batch)
            # Do something with the batch, which now includes 128-dimensional
            # node representations


    Args:
        dim_in (int, optional): Input dimension. (default: :obj:`20`)
        dim_out (int, optional): Output dimension. (default: :obj:`51`)
        dim_inner (int, optional): Width of the encoder layers.
            (default: :obj:`512`)
        layer_type (str, optional): Type of graph convolutional layer for
            message-passing. (default: :obj:`resgatedgcnconv`)
        layers_pre_mp (int, optional): Number of MLP layers before
            message-passing. (default: :obj:`1`)
        layers_mp (int, optional): Number of layers for message-passing.
            (default: :obj:`20`)
        layers_post_mp (int, optional): Number of MLP layers after
            message-passing. (default: :obj:`2`)
        num_node_targets (int, optional): Number of individual PSEs used as
            node-level targets in pretraining :class:`GPSE`.
            (default: :obj:`51`)
        num_graph_targets (int, optional): Number of graph-level targets used
            in pretraining :class:`GPSE`. (default: :obj:`11`)
        stage_type (str, optional): The type of staging to apply. Possible
            values are: :obj:`skipsum`, :obj:`skipconcat`. Any other value will
            default to no skip connections. (default: :obj:`skipsum`)
        has_bn (bool, optional): Whether to apply batch normalization in the
            layer. (default: :obj:`True`)
        head_bn (bool, optional): Whether to apply batch normalization in the
            prediction head MLPs. (default: :obj:`False`)
        final_l2norm (bool, optional): Whether to apply L2 normalization to the
            outputs. (default: :obj:`True`)
        has_l2norm (bool, optional): Whether to apply L2 normalization after
            the layer. (default: :obj:`True`)
        dropout (float, optional): Dropout ratio at layer output.
            (default: :obj:`0.2`)
        has_act (bool, optional): Whether to apply an activation after each
            message-passing layer. (default: :obj:`True`)
        final_act (bool, optional): Whether to apply activation after the
            last pre-message-passing layer. (default: :obj:`True`)
        act (str, optional): Activation to apply to layer outputs.
            (default: :obj:`relu`)
        virtual_node (bool, optional): Whether a virtual node is added to
            graphs in :class:`GPSE` computation. (default: :obj:`True`)
        multi_head_dim_inner (int, optional): Width of MLPs for PSE target
            prediction heads. (default: :obj:`32`)
        graph_pooling (str, optional): Type of graph pooling applied before
            post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`.
            (default: :obj:`add`)
        use_repr (bool, optional): Whether to use the hidden representation of
            the final layer as :class:`GPSE` encodings. (default: :obj:`True`)
        repr_type (str, optional): Type of representation to use. Options are
            :obj:`no_post_mp`, :obj:`one_layer_before`.
            (default: :obj:`no_post_mp`)
        bernoulli_threshold (float, optional): Threshold for Bernoulli sampling
            of virtual nodes. (default: :obj:`0.5`)
    """

    url_dict = {
        'molpcba':
        'https://zenodo.org/record/8145095/files/'
        'gpse_model_molpcba_1.0.pt',
        'zinc':
        'https://zenodo.org/record/8145095/files/gpse_model_zinc_1.0.pt',
        'pcqm4mv2':
        'https://zenodo.org/record/8145095/files/'
        'gpse_model_pcqm4mv2_1.0.pt',
        'geom':
        'https://zenodo.org/record/8145095/files/gpse_model_geom_1.0.pt',
        'chembl':
        'https://zenodo.org/record/8145095/files/gpse_model_chembl_1.0.pt'
    }

    def __init__(
        self,
        dim_in: int = 20,
        dim_out: int = 51,
        dim_inner: int = 512,
        layer_type: str = 'resgatedgcnconv',
        layers_pre_mp: int = 1,
        layers_mp: int = 20,
        layers_post_mp: int = 2,
        num_node_targets: int = 51,
        num_graph_targets: int = 11,
        stage_type: str = 'skipsum',
        has_bn: bool = True,
        head_bn: bool = False,
        final_l2norm: bool = True,
        has_l2norm: bool = True,
        dropout: float = 0.2,
        has_act: bool = True,
        final_act: bool = True,
        act: str = 'relu',
        virtual_node: bool = True,
        multi_head_dim_inner: int = 32,
        graph_pooling: str = 'add',
        use_repr: bool = True,
        repr_type: str = 'no_post_mp',
        bernoulli_threshold: float = 0.5,
    ):
        super().__init__()

        self.use_repr = use_repr
        self.repr_type = repr_type
        self.bernoulli_threshold = bernoulli_threshold

        if layers_pre_mp > 0:
            self.pre_mp = GeneralMultiLayer(
                name='linear',
                in_channels=dim_in,
                out_channels=dim_inner,
                hidden_channels=dim_inner,
                num_layers=layers_pre_mp,
                has_batch_norm=has_bn,
                has_l2_norm=has_l2norm,
                dropout=dropout,
                act=act,
                final_act=final_act,
            )
            # Message passing now consumes the pre-MP output width.
            dim_in = dim_inner
        if layers_mp > 0:
            self.mp = GNNStackStage(
                in_channels=dim_in,
                out_channels=dim_inner,
                num_layers=layers_mp,
                layer_type=layer_type,
                stage_type=stage_type,
                final_l2_norm=final_l2norm,
                has_batch_norm=has_bn,
                has_l2_norm=has_l2norm,
                dropout=dropout,
                act=act if has_act else None,
            )

        self.post_mp = GNNInductiveHybridMultiHead(
            dim_inner,
            dim_out,
            num_node_targets,
            num_graph_targets,
            layers_post_mp,
            virtual_node,
            multi_head_dim_inner,
            graph_pooling,
            head_bn,
            has_l2norm,
            dropout,
            act,
        )

        self.reset_parameters()

    def reset_parameters(self):
        from torch_geometric.graphgym.init import init_weights
        self.apply(init_weights)

    @classmethod
    def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):
        r"""Returns a pretrained :class:`GPSE` model on a dataset.

        Args:
            name (str): The name of the dataset (:obj:`"molpcba"`,
                :obj:`"zinc"`, :obj:`"pcqm4mv2"`, :obj:`"geom"`,
                :obj:`"chembl"`).
            root (str, optional): The root directory to save the pre-trained
                model. (default: :obj:`"GPSE_pretrained"`)

        Raises:
            KeyError: If :obj:`name` is not one of the available models.
        """
        # Fail before touching the filesystem, and list the valid names.
        if name not in cls.url_dict:
            raise KeyError(f"Unknown pretrained GPSE model '{name}' "
                           f"(available: {sorted(cls.url_dict)})")

        root = osp.expanduser(osp.normpath(root))
        os.makedirs(root, exist_ok=True)
        path = download_url(cls.url_dict[name], root)

        # All pretrained models use the default arguments. Build through
        # ``cls`` so subclasses receive an instance of their own type.
        model = cls()
        # NOTE(review): ``torch.load`` unpickles the downloaded file; only
        # point ``url_dict`` at trusted sources.
        model_state = torch.load(path, map_location='cpu')['model_state']
        # Strip the first dotted component of each checkpoint key (presumably
        # a wrapper-module prefix) so keys match this module's names.
        model_state_new = OrderedDict([(k.split('.', 1)[1], v)
                                       for k, v in model_state.items()])
        model.load_state_dict(model_state_new)

        # Set the final linear layer to identity if we use hidden reprs
        if model.use_repr:
            if model.repr_type == 'one_layer_before':
                # NOTE(review): ``GNNInductiveHybridMultiHead`` defines no
                # ``layer_post_mp`` attribute, so this branch appears to raise
                # ``AttributeError``; it is not reached with the default
                # ``repr_type='no_post_mp'``.
                model.post_mp.layer_post_mp.model[-1] = torch.nn.Identity()
            elif model.repr_type == 'no_post_mp':
                model.post_mp = IdentityHead()
            else:
                raise ValueError(f"Unknown type '{model.repr_type}'")

        model.eval()
        return model

    def forward(self, batch):
        # Run the registered stages in order (pre_mp, mp, post_mp, each only
        # if it was created).
        for module in self.children():
            batch = module(batch)
        return batch
|
614
|
+
|
615
|
+
|
616
|
+
class GPSENodeEncoder(torch.nn.Module):
|
617
|
+
r"""A helper linear/MLP encoder that takes the :class:`GPSE` encodings
|
618
|
+
(based on the `"Graph Positional and Structural Encoder"
|
619
|
+
<https://arxiv.org/abs/2307.07107>`_ paper) precomputed as
|
620
|
+
:obj:`batch.pestat_GPSE` in the input graphs, maps them to a desired
|
621
|
+
dimension defined by :obj:`dim_pe_out` and appends them to node features.
|
622
|
+
|
623
|
+
Let's say we have a graph dataset with 64 original node features, and we
|
624
|
+
have generated GPSE encodings of dimension 32, i.e.
|
625
|
+
:obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an
|
626
|
+
inner dimension of 128. To do so, we can map the 32-dimensional GPSE
|
627
|
+
encodings to a higher dimension of 64, and then append them to the
|
628
|
+
:obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to
|
629
|
+
obtain a 128-dimensional node feature representation.
|
630
|
+
:class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and
|
631
|
+
concatenation to :obj:`x`, the outputs of which can be used as input to a
|
632
|
+
GNN:
|
633
|
+
|
634
|
+
.. code-block:: python
|
635
|
+
|
636
|
+
encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
|
637
|
+
expand_x=False)
|
638
|
+
gnn = GNN(dim_in=128, dim_out=128, num_layers=4)
|
639
|
+
|
640
|
+
for batch in loader:
|
641
|
+
batch = encoder(batch)
|
642
|
+
batch = gnn(batch)
|
643
|
+
# Do something with the batch, which now includes 128-dimensional
|
644
|
+
# node representations
|
645
|
+
|
646
|
+
Args:
|
647
|
+
dim_emb (int): Size of final node embedding.
|
648
|
+
dim_pe_in (int): Original dimension of :obj:`batch.pestat_GPSE`.
|
649
|
+
dim_pe_out (int): Desired dimension of :class:`GPSE` after the encoder.
|
650
|
+
dim_in (int, optional): Original dimension of input node features,
|
651
|
+
required only if :obj:`expand_x` is set to :obj:`True`.
|
652
|
+
(default: :obj:`None`)
|
653
|
+
expand_x (bool, optional): Expand node features :obj:`x` from
|
654
|
+
:obj:`dim_in` to (:obj:`dim_emb` - :obj:`dim_pe_out`)
|
655
|
+
norm_type (str, optional): Type of normalization to apply.
|
656
|
+
(default: :obj:`batchnorm`)
|
657
|
+
model_type (str, optional): Type of encoder, either :obj:`mlp` or
|
658
|
+
:obj:`linear`. (default: :obj:`mlp`)
|
659
|
+
n_layers (int, optional): Number of MLP layers if :obj:`model_type` is
|
660
|
+
:obj:`mlp`. (default: :obj:`2`)
|
661
|
+
dropout_be (float, optional): Dropout ratio of inputs to encoder, i.e.
|
662
|
+
before encoding. (default: :obj:`0.5`)
|
663
|
+
dropout_ae (float, optional): Dropout ratio of outputs, i.e. after
|
664
|
+
encoding. (default: :obj:`0.2`)
|
665
|
+
"""
|
666
|
+
def __init__(self, dim_emb: int, dim_pe_in: int, dim_pe_out: int,
             dim_in: int = None, expand_x=False, norm_type='batchnorm',
             model_type='mlp', n_layers=2, dropout_be=0.5, dropout_ae=0.2):
    r"""Builds the submodules of the encoder.

    The optional node-feature projection, the raw-PE normalization, the two
    dropout layers and the PE encoder are created in that order. See the
    class docstring for the meaning of each argument.

    Raises:
        AssertionError: If :obj:`dim_emb` is not larger than
            :obj:`dim_pe_out`.
        ValueError: If :obj:`model_type` is neither :obj:`mlp` nor
            :obj:`linear`.
    """
    super().__init__()

    assert dim_emb > dim_pe_out, ('Desired GPSE dimension (dim_pe_out) '
                                  'must be smaller than the final node '
                                  'embedding dimension (dim_emb).')

    # Project raw node features so that, after concatenation with the
    # encoded PE, the total width is ``dim_emb``.
    if expand_x:
        self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe_out)
    self.expand_x = expand_x

    # Normalization applied to the raw (pre-encoder) GPSE encodings; any
    # other ``norm_type`` disables it.
    self.raw_norm = (nn.BatchNorm1d(dim_pe_in)
                     if norm_type == 'batchnorm' else None)

    self.dropout_be = nn.Dropout(p=dropout_be)
    self.dropout_ae = nn.Dropout(p=dropout_ae)

    activation = nn.ReLU  # register.act_dict[cfg.gnn.act]
    if model_type == 'mlp':
        # Layer widths: a single layer maps straight to ``dim_pe_out``;
        # otherwise every hidden layer is ``2 * dim_pe_out`` wide and the
        # encoder has ``max(n_layers, 2)`` linear layers.
        hidden = 2 * dim_pe_out
        if n_layers == 1:
            widths = [dim_pe_in, dim_pe_out]
        else:
            num_hidden = max(n_layers - 1, 1)
            widths = [dim_pe_in] + [hidden] * num_hidden + [dim_pe_out]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # Each linear layer is followed by an activation, including
            # the last one.
            modules.extend((torch.nn.Linear(fan_in, fan_out), activation()))
        self.pe_encoder = nn.Sequential(*modules)
    elif model_type == 'linear':
        self.pe_encoder = nn.Linear(dim_pe_in, dim_pe_out)
    else:
        raise ValueError(f"{self.__class__.__name__}: Does not support "
                         f"'{model_type}' encoder model.")
|
707
|
+
|
708
|
+
def forward(self, batch):
    r"""Encodes :obj:`batch.pestat_GPSE` and concatenates it to :obj:`batch.x`.

    Args:
        batch: A graph or batch of graphs carrying a precomputed
            :obj:`pestat_GPSE` attribute and node features :obj:`x`.

    Returns:
        The same :obj:`batch` object, with :obj:`batch.x` replaced by the
        concatenation (along dim 1) of the (optionally expanded) node
        features and the encoded GPSE features.

    Raises:
        ValueError: If :obj:`batch` has no :obj:`pestat_GPSE` attribute.
    """
    if not hasattr(batch, 'pestat_GPSE'):
        raise ValueError('Precomputed "pestat_GPSE" variable is required '
                         'for GPSENodeEncoder; either run '
                         '`precompute_GPSE(gpse_model, dataset)` on your '
                         'dataset or add `AddGPSE(gpse_model)` as a (pre) '
                         'transform.')

    pos_enc = self.dropout_be(batch.pestat_GPSE)
    if self.raw_norm is not None:
        pos_enc = self.raw_norm(pos_enc)
    pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe
    pos_enc = self.dropout_ae(pos_enc)

    # Expand node features if needed
    h = self.linear_x(batch.x) if self.expand_x else batch.x

    # Concatenate final PEs to input embedding
    batch.x = torch.cat((h, pos_enc), dim=1)

    return batch
|
730
|
+
|
731
|
+
|
732
|
+
@torch.no_grad()
def gpse_process(model: GPSE, data: Data, rand_type: str = 'NormalSE',
                 use_vn: bool = True, bernoulli_thresh: float = 0.5,
                 neighbor_loader: bool = False,
                 num_neighbors: List[int] = [30, 20, 10], fillval: int = 5,
                 layers_mp: int = None, **kwargs) -> torch.Tensor:
    r"""Processes the data using the :class:`GPSE` model to generate
    :class:`GPSE` encodings. Identical to :obj:`gpse_process_batch`, but
    operates on a single :class:`~torch_geometric.data.Data` object.

    Unlike transform-based GPSE processing (i.e.
    :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument
    does not append virtual nodes if set to :obj:`True`, and instead assumes
    the input graph to :obj:`gpse_process` already has a virtual node as its
    last node. Under normal circumstances, one does not need to call this
    function; running :obj:`precompute_GPSE` on your whole dataset is advised
    instead.

    .. note::
        :obj:`data.x` is overwritten with the random input features.

    Args:
        model (GPSE): The :class:`GPSE` model.
        data (torch_geometric.data.Data): A :class:`~torch_geometric.data.Data`
            object.
        rand_type (str, optional): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)
        use_vn (bool, optional): Whether the input graph has a virtual node
            (its last node), whose input features are then set to zero.
            (default: :obj:`True`)
        bernoulli_thresh (float, optional): Used with :obj:`BernoulliSE`: a
            random feature is 1 where a uniform :math:`[0, 1)` sample falls
            below this threshold, and 0 otherwise. (default: :obj:`0.5`)
        neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`.
            (default: :obj:`False`)
        num_neighbors (List[int], optional): Number of neighbors to consider
            for each message-passing layer. The list is copied and never
            modified in place. (default: :obj:`[30, 20, 10]`)
        fillval (int, optional): Number of neighbors used for every
            message-passing layer beyond :obj:`len(num_neighbors)`; only
            applied when positive. (default: :obj:`5`)
        layers_mp (int, optional): Number of message-passing layers; required
            if :obj:`neighbor_loader` is :obj:`True`. (default: :obj:`None`)
        **kwargs (optional): Additional arguments for :obj:`NeighborLoader`.

    Returns:
        torch.Tensor: The :class:`GPSE` encodings of all nodes of :obj:`data`
        (including the virtual node, if any), on CPU.

    Raises:
        ValueError: If :obj:`rand_type` is unknown, or if
            :obj:`neighbor_loader` is set but :obj:`layers_mp` is not.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Generate random features for the encoder
    n = data.num_nodes
    # Input width of the model, read off its first state-dict entry
    # (presumably the weight of its first linear layer -- verify if the
    # GPSE architecture changes).
    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]

    # Prepare input distributions for GPSE
    if rand_type == 'NormalSE':
        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))
    elif rand_type == 'UniformSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
    elif rand_type == 'BernoulliSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
        rand = (rand < bernoulli_thresh)
    else:
        raise ValueError(f'Unknown {rand_type=!r}')
    data.x = torch.from_numpy(rand.astype('float32'))

    if use_vn:
        # The virtual node is the last node; zero its features to match the
        # pre-training setting.
        data.x[-1] = 0

    model, data = model.to(device), data.to(device)
    # Generate encodings using the pretrained encoder
    if neighbor_loader:
        if layers_mp is None:
            raise ValueError('Please provide the number of message-passing '
                             'layers as "layers_mp".')
        # Copy first: extending in place would mutate the caller's list or
        # the shared default argument across calls.
        num_neighbors = list(num_neighbors)
        diff = layers_mp - len(num_neighbors)
        if fillval > 0 and diff > 0:
            num_neighbors += [fillval] * diff

        loader = NeighborLoader(data, num_neighbors=num_neighbors,
                                shuffle=False, pin_memory=True, **kwargs)
        out_list = []
        pbar = trange(data.num_nodes, position=2)
        for batch in loader:
            out, _ = model(batch.to(device))
            # Only the first ``batch_size`` rows are the seed nodes.
            out = out[:batch.batch_size].to('cpu', non_blocking=True)
            out_list.append(out)
            pbar.update(batch.batch_size)
        pbar.close()
        out = torch.vstack(out_list)
    else:
        out, _ = model(data)
        out = out.to('cpu')

    return out
|
819
|
+
|
820
|
+
|
821
|
+
@torch.no_grad()
def gpse_process_batch(model: GPSE, batch, rand_type: str = 'NormalSE',
                       use_vn: bool = True, bernoulli_thresh: float = 0.5,
                       neighbor_loader: bool = False,
                       num_neighbors: List[int] = [30, 20, 10],
                       fillval: int = 5, layers_mp: int = None,
                       **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Process a batch of data using the :class:`GPSE` model to generate
    :class:`GPSE` encodings. Identical to `gpse_process`, but operates on a
    batch of :class:`~torch_geometric.data.Data` objects.

    Unlike transform-based GPSE processing (i.e.
    :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument
    does not append virtual nodes if set to :obj:`True`, and instead assumes
    each graph in the batch already has a virtual node as its last node. This
    is because the virtual nodes are already added to graphs before the call
    to :obj:`gpse_process_batch` in :obj:`precompute_GPSE` for better
    efficiency. Under normal circumstances, one does not need to call this
    function; running :obj:`precompute_GPSE` on your whole dataset is advised
    instead.

    .. note::
        :obj:`batch.x` is overwritten with the random input features.

    Args:
        model (GPSE): The :class:`GPSE` model.
        batch: A batch of PyG Data objects (must provide :obj:`ptr`).
        rand_type (str, optional): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)
        use_vn (bool, optional): Whether the input graphs have virtual nodes
            (the last node of each graph), whose input features are then set
            to zero. (default: :obj:`True`)
        bernoulli_thresh (float, optional): Used with :obj:`BernoulliSE`: a
            random feature is 1 where a uniform :math:`[0, 1)` sample falls
            below this threshold, and 0 otherwise. (default: :obj:`0.5`)
        neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`.
            (default: :obj:`False`)
        num_neighbors (List[int], optional): Number of neighbors to consider
            for each message-passing layer. The list is copied and never
            modified in place. (default: :obj:`[30, 20, 10]`)
        fillval (int, optional): Number of neighbors used for every
            message-passing layer beyond :obj:`len(num_neighbors)`; only
            applied when positive. (default: :obj:`5`)
        layers_mp (int, optional): Number of message-passing layers; required
            if :obj:`neighbor_loader` is :obj:`True`. (default: :obj:`None`)
        **kwargs: Additional keyword arguments for :obj:`NeighborLoader`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors corresponding
        to the stacked :class:`GPSE` encodings (on CPU) and the :obj:`ptr` of
        the input batch, indicating individual graphs.

    Raises:
        ValueError: If :obj:`rand_type` is unknown, or if
            :obj:`neighbor_loader` is set but :obj:`layers_mp` is not.
    """
    n = batch.num_nodes
    # Input width of the model, read off its first state-dict entry
    # (presumably the weight of its first linear layer -- verify if the
    # GPSE architecture changes).
    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]

    # Prepare input distributions for GPSE
    if rand_type == 'NormalSE':
        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))
    elif rand_type == 'UniformSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
    elif rand_type == 'BernoulliSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
        rand = (rand < bernoulli_thresh)
    else:
        raise ValueError(f'Unknown {rand_type=!r}')
    batch.x = torch.from_numpy(rand.astype('float32'))

    if use_vn:
        # HACK: We need to reset virtual node features to zeros to match the
        # pretraining setting (virtual node applied after random node features
        # are set, and the default node features for the virtual node are all
        # zeros). Can potentially test if initializing virtual node features to
        # random features is better than setting them to zeros.
        # The virtual node of graph ``i`` sits at index ``ptr[i + 1] - 1``;
        # ``batch.x`` lives on CPU, so the index is moved there too.
        vn_index = batch.ptr[1:].cpu() - 1
        batch.x[vn_index] = 0

    # Generate encodings using the pretrained encoder
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    if neighbor_loader:
        if layers_mp is None:
            raise ValueError('Please provide the number of message-passing '
                             'layers as "layers_mp".')
        # Copy first: extending in place would mutate the caller's list or
        # the shared default argument across calls.
        num_neighbors = list(num_neighbors)
        diff = layers_mp - len(num_neighbors)
        if fillval > 0 and diff > 0:
            num_neighbors += [fillval] * diff

        loader = NeighborLoader(batch, num_neighbors=num_neighbors,
                                shuffle=False, pin_memory=True, **kwargs)
        out_list = []
        pbar = trange(batch.num_nodes, position=2)
        # Mini-batches get their own name so that ``batch`` (and hence the
        # returned ``batch.ptr``) keeps referring to the input batch.
        for sub_batch in loader:
            out, _ = model(sub_batch.to(device))
            # Only the first ``batch_size`` rows are the seed nodes.
            out = out[:sub_batch.batch_size].to('cpu', non_blocking=True)
            out_list.append(out)
            pbar.update(sub_batch.batch_size)
        pbar.close()
        out = torch.vstack(out_list)
    else:
        out, _ = model(batch.to(device))
        out = out.to('cpu')

    return out, batch.ptr
|
916
|
+
|
917
|
+
|
918
|
+
@torch.no_grad()
def precompute_GPSE(model: GPSE, dataset: Dataset, use_vn: bool = True,
                    rand_type: str = 'NormalSE', **kwargs):
    r"""Precomputes :class:`GPSE` encodings in-place for a given dataset using
    a :class:`GPSE` model.

    Every graph of :obj:`dataset` receives a :obj:`pestat_GPSE` attribute with
    its node encodings, after which the dataset is re-collated. Split
    attributes (e.g. :obj:`train_mask`) are removed during the computation
    and restored afterwards. The original :obj:`dataset.transform` is also
    restored. Both restorations happen even if the computation fails.

    Args:
        model (GPSE): The :class:`GPSE` model.
        dataset (Dataset): A PyG Dataset.
        use_vn (bool, optional): Whether to append virtual nodes to graphs in
            :class:`GPSE` computation. Should match the setting used when
            pre-training the :class:`GPSE` model. (default: :obj:`True`)
        rand_type (str, optional): The type of randomization to use.
            (default: :obj:`NormalSE`)
        **kwargs (optional): Additional arguments for
            :class:`~torch_geometric.data.DataLoader`; they are also
            forwarded to :obj:`gpse_process_batch`.
    """
    orig_dataset_transform = dataset.transform
    tmp_store = {}
    pbar = None
    try:
        # Temporarily replace the transformation
        dataset.transform = T.VirtualNode() if use_vn else None

        # Remove split indices, to be recovered at the end of the
        # precomputation
        for name in [
                'train_mask', 'val_mask', 'test_mask', 'train_graph_index',
                'val_graph_index', 'test_graph_index', 'train_edge_index',
                'val_edge_index', 'test_edge_index'
        ]:
            if (name in dataset.data) and (dataset.slices is None
                                           or name in dataset.slices):
                tmp_store_data = dataset.data.pop(name)
                tmp_store_slices = dataset.slices.pop(name) \
                    if dataset.slices else None
                tmp_store[name] = (tmp_store_data, tmp_store_slices)

        loader = DataLoader(dataset, shuffle=False, pin_memory=True, **kwargs)

        # Batched GPSE precomputation loop
        data_list = []
        curr_idx = 0
        pbar = trange(len(dataset), desc='Pre-computing GPSE')
        tic = time.perf_counter()
        for batch in loader:
            # ``use_vn`` must be forwarded: otherwise the last real node of
            # every graph would be treated as a virtual node when
            # ``use_vn=False``.
            batch_out, batch_ptr = gpse_process_batch(
                model, batch, rand_type, use_vn=use_vn, **kwargs)

            batch_out = batch_out.to('cpu', non_blocking=True)
            # Need to wait for batch_ptr to finish transferring so that start
            # and end indices are ready to use
            batch_ptr = batch_ptr.to('cpu', non_blocking=False)

            for start, end in zip(batch_ptr[:-1], batch_ptr[1:]):
                data = dataset.get(curr_idx)
                if use_vn:
                    # Drop the encoding of the appended virtual node.
                    end = end - 1
                data.pestat_GPSE = batch_out[start:end]
                data_list.append(data)
                curr_idx += 1

            pbar.update(len(batch_ptr) - 1)

        # Collate dataset and reset indices and data list
        dataset._indices = None
        dataset._data_list = data_list
        dataset.data, dataset.slices = dataset.collate(data_list)
    finally:
        if pbar is not None:
            pbar.close()
        dataset.transform = orig_dataset_transform
        # Recover split indices
        for name, (tmp_store_data, tmp_store_slices) in tmp_store.items():
            dataset.data[name] = tmp_store_data
            if tmp_store_slices is not None:
                dataset.slices[name] = tmp_store_slices
        dataset._data_list = None

    timestr = time.strftime('%H:%M:%S', time.gmtime(time.perf_counter() - tic))
    logging.info(f'Finished GPSE pre-computation, took {timestr}')

    # Release resource and recover original configs
    del model
    torch.cuda.empty_cache()
|
1001
|
+
|
1002
|
+
|
1003
|
+
def cosim_col_sep(pred: torch.Tensor, true: torch.Tensor,
                  batch_idx: torch.Tensor) -> torch.Tensor:
    r"""Calculates the cosine-similarity loss between predicted and true
    features on a batch of graphs.

    Both tensors are converted to dense per-graph form. For every graph and
    every feature column, the cosine similarity between the predicted and
    true column is taken over that graph's nodes. Columns whose ground truth
    is all zeros within a graph are excluded.

    Args:
        pred (torch.Tensor): Predicted outputs.
        true (torch.Tensor): Value of ground truths.
        batch_idx (torch.Tensor): Batch indices to separate the graphs. If
            its minimum is :obj:`-1` (padded graph-level rows), all indices
            are shifted by one before use.

    Returns:
        torch.Tensor: One minus the mean cosine similarity over all
        non-trivial (graph, feature column) pairs.

    Raises:
        ValueError: If :obj:`batch_idx` is :obj:`None`.
    """
    if batch_idx is None:
        raise ValueError("cosim_col_sep requires batch index as "
                         "input to distinguish different graphs.")
    batch_idx = batch_idx + 1 if batch_idx.min() == -1 else batch_idx
    pred_dense = to_dense_batch(pred, batch_idx)[0]
    true_dense = to_dense_batch(true, batch_idx)[0]
    mask = (true_dense == 0).all(1)  # exclude trivial features from loss
    loss = 1 - F.cosine_similarity(pred_dense, true_dense, dim=1)[~mask].mean()
    return loss
|
1028
|
+
|
1029
|
+
|
1030
|
+
def gpse_loss(
    pred: torch.Tensor,
    true: torch.Tensor,
    batch_idx: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates :class:`GPSE` loss as the sum of MAE loss and cosine
    similarity loss over a batch of graphs.

    Args:
        pred (torch.Tensor): Predicted outputs.
        true (torch.Tensor): Value of ground truths.
        batch_idx (torch.Tensor): Batch indices to separate the graphs.
            Required despite defaulting to :obj:`None`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors corresponding
        to the :class:`GPSE` loss and the predicted node-and-graph level
        outputs (:obj:`pred`, returned unchanged).

    Raises:
        ValueError: If :obj:`batch_idx` is :obj:`None`.
    """
    if batch_idx is None:
        raise ValueError("gpse_loss requires batch index as "
                         "input to distinguish different graphs.")
    mae_loss = F.l1_loss(pred, true)
    cosim_loss = cosim_col_sep(pred, true, batch_idx)
    loss = mae_loss + cosim_loss
    return loss, pred
|
1054
|
+
|
1055
|
+
|
1056
|
+
def process_batch_idx(batch_idx, true, use_vn=True):
    r"""Processes batch indices to adjust for the removal of virtual nodes, and
    pads batch index for hybrid tasks.

    Args:
        batch_idx: Batch indices to separate the graphs (may be :obj:`None`).
        true: Value of ground truths; only its first dimension is used, to
            determine how many padding entries are needed.
        use_vn: If input graphs have virtual nodes that need to be removed.
            One entry is dropped per graph id.

    Returns:
        torch.Tensor: Batch indices that separate the graphs, grouped by
        ascending graph id. One :obj:`-1` entry is appended for every row of
        :obj:`true` beyond the number of indices. Returns :obj:`None` if
        :obj:`batch_idx` is :obj:`None`.
    """
    if batch_idx is None:
        return None
    if use_vn:  # remove virtual node
        # Graph ``i`` keeps ``count_i - 1`` entries. Counting once with
        # ``bincount`` replaces a full scan of ``batch_idx`` per graph
        # (O(N) instead of O(G * N)). Negative ids are ignored.
        counts = torch.bincount(batch_idx[batch_idx >= 0])
        graph_ids = torch.arange(counts.numel(), dtype=batch_idx.dtype,
                                 device=batch_idx.device)
        batch_idx = torch.repeat_interleave(graph_ids,
                                            (counts - 1).clamp(min=0))
    # Pad batch index for hybrid tasks (set batch index for graph heads to -1)
    if (pad := true.shape[0] - batch_idx.shape[0]) > 0:
        pad_idx = -torch.ones(pad, dtype=torch.long, device=batch_idx.device)
        batch_idx = torch.hstack([batch_idx, pad_idx])
    return batch_idx
|
@@ -37,6 +37,7 @@ from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph
|
|
37
37
|
from .largest_connected_components import LargestConnectedComponents
|
38
38
|
from .virtual_node import VirtualNode
|
39
39
|
from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
|
40
|
+
from .add_gpse import AddGPSE
|
40
41
|
from .feature_propagation import FeaturePropagation
|
41
42
|
from .half_hop import HalfHop
|
42
43
|
|
@@ -108,6 +109,7 @@ graph_transforms = [
|
|
108
109
|
'VirtualNode',
|
109
110
|
'AddLaplacianEigenvectorPE',
|
110
111
|
'AddRandomWalkPE',
|
112
|
+
'AddGPSE',
|
111
113
|
'FeaturePropagation',
|
112
114
|
'HalfHop',
|
113
115
|
]
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from torch_geometric.data import Data
|
2
|
+
from torch_geometric.data.datapipes import functional_transform
|
3
|
+
from torch_geometric.nn.models.gpse import GPSE
|
4
|
+
from torch_geometric.transforms import BaseTransform, VirtualNode
|
5
|
+
|
6
|
+
|
7
|
+
@functional_transform('add_gpse')
class AddGPSE(BaseTransform):
    r"""Adds the GPSE encoding from the `"Graph Positional and Structural
    Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
    (functional name: :obj:`add_gpse`).
    To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates
    the actual encodings.

    Args:
        model (GPSE): The pre-trained GPSE model.
        use_vn (bool, optional): Whether to append a virtual node to a copy
            of the graph while computing the encodings. The virtual node's
            own encoding is dropped from the result. (default: :obj:`True`)
        rand_type (str, optional): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)

    Raises:
        ValueError: If :obj:`rand_type` is not one of the supported options.
    """
    def __init__(self, model: GPSE, use_vn: bool = True,
                 rand_type: str = 'NormalSE'):
        # Fail at construction time rather than on the first graph.
        if rand_type not in ('NormalSE', 'UniformSE', 'BernoulliSE'):
            raise ValueError(f'Unknown {rand_type=!r}')
        self.model = model
        self.use_vn = use_vn
        self.vn = VirtualNode()
        self.rand_type = rand_type

    def __call__(self, data: Data) -> Data:
        from torch_geometric.nn.models.gpse import gpse_process

        # Work on a copy: ``gpse_process`` overwrites ``x`` with random
        # features, and the virtual node must not leak into ``data``.
        data_vn = self.vn(data.clone()) if self.use_vn else data.clone()
        # Honour the configured ``rand_type`` instead of a fixed 'NormalSE'.
        batch_out = gpse_process(self.model, data_vn, self.rand_type,
                                 self.use_vn)
        batch_out = batch_out.to('cpu', non_blocking=True)
        # The virtual node is the last node; drop its encoding.
        data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out

        return data
|
File without changes
|
{pyg_nightly-2.7.0.dev20250404.dist-info → pyg_nightly-2.7.0.dev20250405.dist-info}/licenses/LICENSE
RENAMED
File without changes
|