pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic; see the package's registry page for more details.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -0,0 +1,1083 @@
1
+ import logging
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ from collections import OrderedDict
6
+ from typing import List, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.nn import Module
13
+ from tqdm import trange
14
+
15
+ import torch_geometric.transforms as T
16
+ from torch_geometric.data import Data, Dataset, download_url
17
+ from torch_geometric.loader import DataLoader, NeighborLoader
18
+ from torch_geometric.nn import (
19
+ ResGatedGraphConv,
20
+ global_add_pool,
21
+ global_max_pool,
22
+ global_mean_pool,
23
+ )
24
+ from torch_geometric.nn.resolver import activation_resolver
25
+ from torch_geometric.utils import to_dense_batch
26
+
27
+
28
class Linear(torch.nn.Module):
    r"""Thin wrapper around :class:`torch.nn.Linear` that accepts either a
    plain tensor or a batch-like object carrying node features in :obj:`x`.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        bias (bool): Whether the wrapped linear layer learns an additive bias.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool,
    ) -> None:
        super().__init__()
        self.model = torch.nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, batch):
        # Batch-like objects: transform ``batch.x`` in place and hand the
        # same object back to the caller.
        if not isinstance(batch, torch.Tensor):
            batch.x = self.model(batch.x)
            return batch
        # Plain tensors: return the transformed tensor directly.
        return self.model(batch)
44
+
45
+
46
class ResGatedGCNConv(torch.nn.Module):
    r"""Adapter around :class:`~torch_geometric.nn.ResGatedGraphConv` that
    works on a batch object: it reads :obj:`batch.x` and
    :obj:`batch.edge_index`, writes the convolved node features back to
    :obj:`batch.x`, and returns the same batch object.

    Args:
        in_channels (int): Size of each input node feature.
        out_channels (int): Size of each output node feature.
        bias (bool): Whether the convolution learns an additive bias.
        **kwargs: Extra keyword arguments forwarded to
            :class:`~torch_geometric.nn.ResGatedGraphConv`.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool,
        **kwargs,
    ) -> None:
        super().__init__()
        conv = ResGatedGraphConv(in_channels, out_channels, bias=bias,
                                 **kwargs)
        self.model = conv

    def forward(self, batch):
        updated = self.model(batch.x, batch.edge_index)
        batch.x = updated
        return batch
65
+
66
+
67
class GeneralLayer(torch.nn.Module):
    r"""A single linear or ResGatedGCN layer followed by an optional
    post-processing pipeline (batch norm -> dropout -> activation) and an
    optional L2 normalization.

    Args:
        name (str): Layer type, either :obj:`"linear"` or
            :obj:`"resgatedgcnconv"`. Other values raise :class:`KeyError`.
        in_channels (int): Input feature size.
        out_channels (int): Output feature size.
        has_batch_norm (bool): If :obj:`True`, append
            :class:`torch.nn.BatchNorm1d` and build the layer without bias.
        has_l2_norm (bool): If :obj:`True`, L2-normalize outputs along
            dimension 1.
        dropout (float): Dropout probability; dropout is omitted unless it
            is strictly positive.
        act (str, optional): Activation name resolved through
            :func:`activation_resolver`; :obj:`None` disables it.
        **kwargs: Forwarded to the underlying layer constructor.
    """
    def __init__(
        self,
        name: str,
        in_channels: int,
        out_channels: int,
        has_batch_norm: bool,
        has_l2_norm: bool,
        dropout: float,
        act: Optional[str],
        **kwargs,
    ):
        super().__init__()
        self.has_l2_norm = has_l2_norm

        layer_cls = {
            'linear': Linear,
            'resgatedgcnconv': ResGatedGCNConv,
        }[name]
        # A bias is redundant when batch norm follows, since BN re-centers.
        self.layer = layer_cls(in_channels, out_channels,
                               bias=not has_batch_norm, **kwargs)

        steps = []
        if has_batch_norm:
            steps.append(
                torch.nn.BatchNorm1d(out_channels, eps=1e-5, momentum=0.1))
        if dropout > 0:
            steps.append(torch.nn.Dropout(p=dropout, inplace=False))
        if act is not None:
            steps.append(activation_resolver(act))
        self.post_layer = nn.Sequential(*steps)

    def _post_process(self, feats):
        """Apply the post-layer pipeline and the optional L2 normalization."""
        feats = self.post_layer(feats)
        if self.has_l2_norm:
            feats = F.normalize(feats, p=2, dim=1)
        return feats

    def forward(self, batch):
        batch = self.layer(batch)
        if not isinstance(batch, torch.Tensor):
            batch.x = self._post_process(batch.x)
            return batch
        return self._post_process(batch)
113
+
114
+
115
class GeneralMultiLayer(torch.nn.Module):
    r"""A sequential stack of :obj:`num_layers` :class:`GeneralLayer`
    modules registered as :obj:`Layer_0`, :obj:`Layer_1`, ...

    The first layer maps :obj:`in_channels` to the hidden width, the last
    maps the hidden width to :obj:`out_channels`. The hidden width is
    :obj:`hidden_channels`, or :obj:`out_channels` when it is falsy. The
    activation is applied after every layer except the last, which gets it
    only when :obj:`final_act` is :obj:`True`.

    Args:
        name (str): Layer type passed to :class:`GeneralLayer`.
        in_channels (int): Input feature size.
        out_channels (int): Output feature size.
        hidden_channels (int, optional): Hidden width.
        num_layers (int): Number of stacked layers (0 yields an identity).
        has_batch_norm (bool): Forwarded to every layer.
        has_l2_norm (bool): Forwarded to every layer.
        dropout (float): Forwarded to every layer.
        act (str): Activation name.
        final_act (bool): Whether the last layer applies :obj:`act`.
        **kwargs: Forwarded to every :class:`GeneralLayer`.
    """
    def __init__(
        self,
        name: str,
        in_channels: int,
        out_channels: int,
        hidden_channels: Optional[int],
        num_layers: int,
        has_batch_norm: bool,
        has_l2_norm: bool,
        dropout: float,
        act: str,
        final_act: bool,
        **kwargs,
    ) -> None:
        super().__init__()
        width = hidden_channels or out_channels
        last = num_layers - 1

        for idx in range(num_layers):
            is_last = idx == last
            self.add_module(
                f'Layer_{idx}',
                GeneralLayer(
                    name=name,
                    in_channels=width if idx else in_channels,
                    out_channels=out_channels if is_last else width,
                    has_batch_norm=has_batch_norm,
                    has_l2_norm=has_l2_norm,
                    dropout=dropout,
                    act=act if (final_act or not is_last) else None,
                    **kwargs,
                ),
            )

    def forward(self, batch):
        out = batch
        for child in self.children():
            out = child(out)
        return out
152
+
153
+
154
class BatchNorm1dNode(torch.nn.Module):
    r"""Applies :class:`torch.nn.BatchNorm1d` to the node features
    :obj:`batch.x`, stores the result back on the batch, and returns the
    same batch object.

    Args:
        channels (int): Number of node feature channels.
    """
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(num_features=channels, eps=1e-5,
                                       momentum=0.1)

    def forward(self, batch):
        normalized = self.bn(batch.x)
        batch.x = normalized
        return batch
162
+
163
+
164
class BatchNorm1dEdge(torch.nn.Module):
    r"""Applies :class:`torch.nn.BatchNorm1d` to the edge features
    :obj:`batch.edge_attr`, stores the result back on the batch, and returns
    the same batch object.

    Args:
        channels (int): Number of edge feature channels.
    """
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(num_features=channels, eps=1e-5,
                                       momentum=0.1)

    def forward(self, batch):
        normalized = self.bn(batch.edge_attr)
        batch.edge_attr = normalized
        return batch
172
+
173
+
174
class MLP(torch.nn.Module):
    r"""Multi-layer perceptron built from :class:`GeneralMultiLayer` hidden
    layers followed by a biased :class:`Linear` output layer.

    With :obj:`num_layers > 1`, :obj:`num_layers - 1` hidden linear layers
    (all with activation) map :obj:`in_channels` to :obj:`hidden_channels`.
    A final linear layer then maps to :obj:`out_channels`. With
    :obj:`num_layers <= 1` only the output layer is built, and it reads
    :obj:`in_channels` directly.

    Args:
        in_channels (int): Input feature size.
        out_channels (int): Output feature size.
        hidden_channels (int, optional): Hidden width; defaults to
            :obj:`in_channels` when falsy.
        num_layers (int): Total number of linear layers.
        has_batch_norm (bool, optional): Batch norm in hidden layers.
            (default: :obj:`True`)
        has_l2_norm (bool, optional): L2 normalization in hidden layers.
            (default: :obj:`True`)
        dropout (float, optional): Dropout in hidden layers.
            (default: :obj:`0.2`)
        act (str, optional): Hidden-layer activation.
            (default: :obj:`'relu'`)
        **kwargs: Forwarded to :class:`GeneralMultiLayer`.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: Optional[int],
        num_layers: int,
        has_batch_norm: bool = True,
        has_l2_norm: bool = True,
        dropout: float = 0.2,
        act: str = 'relu',
        **kwargs,
    ):
        super().__init__()
        hidden_channels = hidden_channels or in_channels

        layers = []
        if num_layers > 1:
            layer = GeneralMultiLayer(
                'linear',
                in_channels,
                hidden_channels,
                hidden_channels,
                num_layers - 1,
                has_batch_norm,
                has_l2_norm,
                dropout,
                act,
                final_act=True,
                **kwargs,
            )
            layers.append(layer)
            last_in_channels = hidden_channels
        else:
            # No hidden stack: the output layer sees the raw input, so it
            # must read ``in_channels`` (not ``hidden_channels``) features.
            last_in_channels = in_channels
        layers.append(Linear(last_in_channels, out_channels, bias=True))
        self.model = nn.Sequential(*layers)

    def forward(self, batch):
        if isinstance(batch, torch.Tensor):
            batch = self.model(batch)
        else:
            batch.x = self.model(batch.x)
        return batch
215
+
216
+
217
class GNNStackStage(torch.nn.Module):
    r"""Stack of :obj:`num_layers` message-passing :class:`GeneralLayer`
    modules (registered as :obj:`layer0`, :obj:`layer1`, ...) with optional
    skip connections.

    Args:
        in_channels (int): Input feature size of the first layer.
        out_channels (int): Output feature size of every layer.
        num_layers (int): Number of layers.
        layer_type (str): Layer type passed to :class:`GeneralLayer`.
        stage_type (str, optional): :obj:`'skipsum'` adds each layer's input
            to its output. :obj:`'skipconcat'` concatenates each layer's
            input with its output, except after the last layer. Any other
            value uses no skip connection. (default: :obj:`'skipsum'`)
        final_l2_norm (bool, optional): L2-normalize the final node features
            along the last dimension. (default: :obj:`True`)
        has_batch_norm (bool, optional): (default: :obj:`True`)
        has_l2_norm (bool, optional): (default: :obj:`True`)
        dropout (float, optional): (default: :obj:`0.2`)
        act (str, optional): (default: :obj:`'relu'`)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int,
        layer_type: str,
        stage_type: str = 'skipsum',
        final_l2_norm: bool = True,
        has_batch_norm: bool = True,
        has_l2_norm: bool = True,
        dropout: float = 0.2,
        act: Optional[str] = 'relu',
    ):
        super().__init__()
        self.num_layers = num_layers
        self.stage_type = stage_type
        self.final_l2_norm = final_l2_norm

        concat = stage_type == 'skipconcat'
        for i in range(num_layers):
            if i == 0:
                d_in = in_channels
            elif concat:
                # Every earlier layer appended ``out_channels`` features.
                d_in = in_channels + i * out_channels
            else:
                d_in = out_channels
            self.add_module(
                f'layer{i}',
                GeneralLayer(layer_type, d_in, out_channels, has_batch_norm,
                             has_l2_norm, dropout, act),
            )

    def forward(self, batch):
        last = self.num_layers - 1
        for i, layer in enumerate(self.children()):
            skip = batch.x
            batch = layer(batch)
            if self.stage_type == 'skipsum':
                batch.x = skip + batch.x
            elif self.stage_type == 'skipconcat' and i < last:
                batch.x = torch.cat([skip, batch.x], dim=1)

        if self.final_l2_norm:
            batch.x = F.normalize(batch.x, p=2, dim=-1)

        return batch
261
+
262
+
263
class GNNInductiveHybridMultiHead(torch.nn.Module):
    r"""GNN prediction head for inductive node and graph prediction tasks using
    an individual MLP for each task.

    Each node-level target gets its own :class:`MLP` with a single output
    channel; their outputs are concatenated column-wise. All graph-level
    targets share one :class:`MLP` applied to pooled node embeddings.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension. Not used, as the dimension is
            determined by :obj:`num_node_targets` and :obj:`num_graph_targets`
            instead.
        num_node_targets (int): Number of individual PSEs used as node-level
            targets in pretraining :class:`GPSE`.
        num_graph_targets (int): Number of graph-level targets used in
            pretraining :class:`GPSE`.
        layers_post_mp (int): Number of MLP layers after GNN message-passing.
        virtual_node (bool, optional): Whether a virtual node is added to
            graphs in :class:`GPSE` computation. If :obj:`True`, the last
            node of every graph is dropped from the node-level predictions
            and targets. (default: :obj:`True`)
        multi_head_dim_inner (int, optional): Hidden width of the per-target
            node-level MLPs. The graph-level MLP uses :obj:`dim_in` as its
            hidden width instead. (default: :obj:`32`)
        graph_pooling (str, optional): Type of graph pooling applied before
            the graph-level MLP. Options are :obj:`add`, :obj:`max`,
            :obj:`mean`. (default: :obj:`add`)
        has_bn (bool, optional): Whether to apply batch normalization inside
            the MLP heads. (default: :obj:`True`)
        has_l2norm (bool, optional): Whether to apply L2 normalization inside
            the MLP heads. (default: :obj:`True`)
        dropout (float, optional): Dropout ratio inside the MLP heads.
            (default: :obj:`0.2`)
        act (str, optional): Activation applied inside the MLP heads.
            (default: :obj:`relu`)
    """
    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        num_node_targets: int,
        num_graph_targets: int,
        layers_post_mp: int,
        virtual_node: bool = True,
        multi_head_dim_inner: int = 32,
        graph_pooling: str = 'add',
        has_bn: bool = True,
        has_l2norm: bool = True,
        dropout: float = 0.2,
        act: str = 'relu',
    ):
        super().__init__()
        pool_dict = {
            'add': global_add_pool,
            'max': global_max_pool,
            'mean': global_mean_pool
        }
        self.node_target_dim = num_node_targets
        self.graph_target_dim = num_graph_targets
        self.virtual_node = virtual_node
        num_layers = layers_post_mp

        self.node_post_mps = nn.ModuleList([
            MLP(dim_in, 1, multi_head_dim_inner, num_layers, has_bn,
                has_l2norm, dropout, act) for _ in range(self.node_target_dim)
        ])

        self.graph_pooling = pool_dict[graph_pooling]

        self.graph_post_mp = MLP(dim_in, self.graph_target_dim, dim_in,
                                 num_layers, has_bn, has_l2norm, dropout, act)

    def _pad_and_stack(self, x1: torch.Tensor, x2: torch.Tensor, pad1: int,
                       pad2: int):
        """Stack ``x1`` on top of ``x2`` after zero-padding ``x1`` with
        ``pad2`` columns on the right and ``x2`` with ``pad1`` columns on
        the left, so that the two blocks occupy disjoint column ranges.
        """
        padded_x1 = nn.functional.pad(x1, (0, pad2))
        padded_x2 = nn.functional.pad(x2, (pad1, 0))
        return torch.vstack([padded_x1, padded_x2])

    def _apply_index(self, batch, virtual_node: bool, pad_node: int,
                     pad_graph: int):
        """Build the stacked ``(pred, true)`` pair from the node- and
        graph-level predictions and targets stored on ``batch``.
        """
        graph_pred, graph_true = batch.graph_feature, batch.y_graph
        node_pred, node_true = batch.node_feature, batch.y
        if virtual_node:
            # Remove virtual node: drop the last node index of every graph.
            idx = torch.concat([
                torch.where(batch.batch == i)[0][:-1]
                for i in range(batch.batch.max().item() + 1)
            ])
            node_pred, node_true = node_pred[idx], node_true[idx]

        # Stack node predictions on top of graph predictions and pad with zeros
        pred = self._pad_and_stack(node_pred, graph_pred, pad_node, pad_graph)
        true = self._pad_and_stack(node_true, graph_true, pad_node, pad_graph)

        return pred, true

    def forward(self, batch):
        """Run all node- and graph-level heads.

        Sets :obj:`batch.node_feature` and :obj:`batch.graph_feature` as a
        side effect.

        Returns:
            A ``(pred, true)`` tuple. In each, node rows come first and are
            zero-padded on the right by ``num_graph_targets`` columns; graph
            rows follow and are zero-padded on the left by
            ``num_node_targets`` columns. (Assumes :obj:`batch.y` has
            ``num_node_targets`` columns and :obj:`batch.y_graph` has
            ``num_graph_targets`` columns — TODO confirm with the data
            pipeline.)
        """
        batch.node_feature = torch.hstack(
            [m(batch.x) for m in self.node_post_mps])
        graph_emb = self.graph_pooling(batch.x, batch.batch)
        batch.graph_feature = self.graph_post_mp(graph_emb)
        return self._apply_index(batch, self.virtual_node,
                                 self.node_target_dim, self.graph_target_dim)
360
+
361
+
362
class IdentityHead(torch.nn.Module):
    r"""Pass-through prediction head that returns the batch's node features
    :obj:`x` and targets :obj:`y` unchanged, as a ``(pred, true)`` tuple.
    """
    def forward(self, batch):
        pred = batch.x
        true = batch.y
        return pred, true
365
+
366
+
367
+ class GPSE(torch.nn.Module):
368
+ r"""The Graph Positional and Structural Encoder (GPSE) model from the
369
+ `"Graph Positional and Structural Encoder"
370
+ <https://arxiv.org/abs/2307.07107>`_ paper.
371
+
372
+ The GPSE model consists of a (1) deep GNN that consists of stacked
373
+ message passing layers, and a (2) prediction head to predict pre-computed
374
+ positional and structural encodings (PSE).
375
+ When used on downstream datasets, these prediction heads are removed and
376
+ the final fully-connected layer outputs are used as learned PSE embeddings.
377
+
378
+ GPSE also provides a static method :meth:`from_pretrained` to load
379
+ pre-trained GPSE models trained on a variety of molecular datasets.
380
+
381
+ .. code-block:: python
382
+
383
+ from torch_geometric.nn import GPSE, GPSENodeEncoder
384
+ from torch_geometric.transforms import AddGPSE
385
+ from torch_geometric.nn.models.gpse import precompute_GPSE
386
+
387
+ gpse_model = GPSE.from_pretrained('molpcba')
388
+
389
+ # Option 1: Precompute GPSE encodings in-place for a given dataset
390
+ dataset = ZINC(path, subset=True, split='train')
391
+ precompute_gpse(gpse_model, dataset)
392
+
393
+ # Option 2: Use the GPSE model with AddGPSE as a pre_transform to save
394
+ # the encodings
395
+ dataset = ZINC(path, subset=True, split='train',
396
+ pre_transform=AddGPSE(gpse_model, vn=True,
397
+ rand_type='NormalSE'))
398
+
399
+ Both approaches append the generated encodings to the :obj:`pestat_GPSE`
400
+ attribute of :class:`~torch_geometric.data.Data` objects. To use the GPSE
401
+ encodings for a downstream task, one may need to add these encodings to the
402
+ :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects. To
403
+ do so, one can use the :class:`GPSENodeEncoder` provided to map these
404
+ encodings to a desired dimension before appending them to :obj:`x`.
405
+
406
+ Let's say we have a graph dataset with 64 original node features, and we
407
+ have generated GPSE encodings of dimension 32, i.e.
408
+ :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an
409
+ inner dimension of 128. To do so, we can map the 32-dimensional GPSE
410
+ encodings to a higher dimension of 64, and then append them to the :obj:`x`
411
+ attribute of the :class:`~torch_geometric.data.Data` objects to obtain a
412
+ 128-dimensional node feature representation.
413
+ :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and
414
+ concatenation to :obj:`x`, the outputs of which can be used as input to a
415
+ GNN:
416
+
417
+ .. code-block:: python
418
+
419
+ encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
420
+ expand_x=False)
421
+ gnn = GNN(...)
422
+
423
+ for batch in loader:
424
+ x = encoder(batch.x, batch.pestat_GPSE)
425
+ out = gnn(x, batch.edge_index)
426
+
427
+
428
+ Args:
429
+ dim_in (int, optional): Input dimension. (default: :obj:`20`)
430
+ dim_out (int, optional): Output dimension. (default: :obj:`51`)
431
+ dim_inner (int, optional): Width of the encoder layers.
432
+ (default: :obj:`512`)
433
+ layer_type (str, optional): Type of graph convolutional layer for
434
+ message-passing. (default: :obj:`resgatedgcnconv`)
435
+ layers_pre_mp (int, optional): Number of MLP layers before
436
+ message-passing. (default: :obj:`1`)
437
+ layers_mp (int, optional): Number of layers for message-passing.
438
+ (default: :obj:`20`)
439
+ layers_post_mp (int, optional): Number of MLP layers after
440
+ message-passing. (default: :obj:`2`)
441
+ num_node_targets (int, optional): Number of individual PSEs used as
442
+ node-level targets in pretraining :class:`GPSE`.
443
+ (default: :obj:`51`)
444
+ num_graph_targets (int, optional): Number of graph-level targets used
445
+ in pretraining :class:`GPSE`. (default: :obj:`11`)
446
+ stage_type (str, optional): The type of staging to apply. Possible
447
+ values are: :obj:`skipsum`, :obj:`skipconcat`. Any other value will
448
+ default to no skip connections. (default: :obj:`skipsum`)
449
+ has_bn (bool, optional): Whether to apply batch normalization in the
450
+ layer. (default: :obj:`True`)
451
+ final_l2norm (bool, optional): Whether to apply L2 normalization to the
452
+ outputs. (default: :obj:`True`)
453
+ has_l2norm (bool, optional): Whether to apply L2 normalization after
454
+ the layer. (default: :obj:`True`)
455
+ dropout (float, optional): Dropout ratio at layer output.
456
+ (default: :obj:`0.2`)
457
+ has_act (bool, optional): Whether has activation after the layer.
458
+ (default: :obj:`True`)
459
+ final_act (bool, optional): Whether to apply activation after the layer
460
+ stack. (default: :obj:`True`)
461
+ act (str, optional): Activation to apply to layer output if
462
+ :obj:`has_act` is :obj:`True`. (default: :obj:`relu`)
463
+ virtual_node (bool, optional): Whether a virtual node is added to
464
+ graphs in :class:`GPSE` computation. (default: :obj:`True`)
465
+ multi_head_dim_inner (int, optional): Width of MLPs for PSE target
466
+ prediction heads. (default: :obj:`32`)
467
+ graph_pooling (str, optional): Type of graph pooling applied before
468
+ post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`.
469
+ (default: :obj:`add`)
470
+ use_repr (bool, optional): Whether to use the hidden representation of
471
+ the final layer as :class:`GPSE` encodings. (default: :obj:`True`)
472
+ repr_type (str, optional): Type of representation to use. Options are
473
+ :obj:`no_post_mp`, :obj:`one_layer_before`.
474
+ (default: :obj:`no_post_mp`)
475
+ bernoulli_threshold (float, optional): Threshold for Bernoulli sampling
476
+ of virtual nodes. (default: :obj:`0.5`)
477
+ """
478
+
479
+ url_dict = {
480
+ 'molpcba':
481
+ 'https://zenodo.org/record/8145095/files/'
482
+ 'gpse_model_molpcba_1.0.pt',
483
+ 'zinc':
484
+ 'https://zenodo.org/record/8145095/files/gpse_model_zinc_1.0.pt',
485
+ 'pcqm4mv2':
486
+ 'https://zenodo.org/record/8145095/files/'
487
+ 'gpse_model_pcqm4mv2_1.0.pt',
488
+ 'geom':
489
+ 'https://zenodo.org/record/8145095/files/gpse_model_geom_1.0.pt',
490
+ 'chembl':
491
+ 'https://zenodo.org/record/8145095/files/gpse_model_chembl_1.0.pt'
492
+ }
493
+
494
def __init__(
    self,
    dim_in: int = 20,
    dim_out: int = 51,
    dim_inner: int = 512,
    layer_type: str = 'resgatedgcnconv',
    layers_pre_mp: int = 1,
    layers_mp: int = 20,
    layers_post_mp: int = 2,
    num_node_targets: int = 51,
    num_graph_targets: int = 11,
    stage_type: str = 'skipsum',
    has_bn: bool = True,
    head_bn: bool = False,
    final_l2norm: bool = True,
    has_l2norm: bool = True,
    dropout: float = 0.2,
    has_act: bool = True,
    final_act: bool = True,
    act: str = 'relu',
    virtual_node: bool = True,
    multi_head_dim_inner: int = 32,
    graph_pooling: str = 'add',
    use_repr: bool = True,
    repr_type: str = 'no_post_mp',
    bernoulli_threshold: float = 0.5,
):
    r"""Build the optional input MLP (:obj:`pre_mp`, only when
    :obj:`layers_pre_mp > 0`), the optional message-passing stack
    (:obj:`mp`, only when :obj:`layers_mp > 0`) and the multi-task
    prediction head (:obj:`post_mp`). Arguments are described in the class
    docstring; note that :obj:`head_bn` (not :obj:`has_bn`) controls batch
    normalization inside the prediction head.
    """
    super().__init__()

    self.use_repr = use_repr
    self.repr_type = repr_type
    self.bernoulli_threshold = bernoulli_threshold

    # Input projection to ``dim_inner``; afterwards the GNN sees dim_inner.
    if layers_pre_mp > 0:
        self.pre_mp = GeneralMultiLayer(
            name='linear',
            in_channels=dim_in,
            out_channels=dim_inner,
            hidden_channels=dim_inner,
            num_layers=layers_pre_mp,
            has_batch_norm=has_bn,
            has_l2_norm=has_l2norm,
            dropout=dropout,
            act=act,
            final_act=final_act,
        )
        dim_in = dim_inner

    if layers_mp > 0:
        # ``has_act`` only gates the activation of the message-passing layers.
        mp_act = act if has_act else None
        self.mp = GNNStackStage(
            in_channels=dim_in,
            out_channels=dim_inner,
            num_layers=layers_mp,
            layer_type=layer_type,
            stage_type=stage_type,
            final_l2_norm=final_l2norm,
            has_batch_norm=has_bn,
            has_l2_norm=has_l2norm,
            dropout=dropout,
            act=mp_act,
        )

    self.post_mp = GNNInductiveHybridMultiHead(
        dim_in=dim_inner,
        dim_out=dim_out,
        num_node_targets=num_node_targets,
        num_graph_targets=num_graph_targets,
        layers_post_mp=layers_post_mp,
        virtual_node=virtual_node,
        multi_head_dim_inner=multi_head_dim_inner,
        graph_pooling=graph_pooling,
        has_bn=head_bn,
        has_l2norm=has_l2norm,
        dropout=dropout,
        act=act,
    )

    self.reset_parameters()
571
+
572
+ def reset_parameters(self):
573
+ pass
574
+
575
+ @classmethod
576
+ def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):
577
+ r"""Returns a pretrained :class:`GPSE` model on a dataset.
578
+
579
+ Args:
580
+ name (str): The name of the dataset (:obj:`"molpcba"`,
581
+ :obj:`"zinc"`, :obj:`"pcqm4mv2"`, :obj:`"geom"`,
582
+ :obj:`"chembl"`).
583
+ root (str, optional): The root directory to save the pre-trained
584
+ model. (default: :obj:`"GPSE_pretrained"`)
585
+ """
586
+ root = osp.expanduser(osp.normpath(root))
587
+ os.makedirs(root, exist_ok=True)
588
+ path = download_url(cls.url_dict[name], root)
589
+
590
+ model = GPSE() # All pretrained models use the default arguments
591
+ model_state = torch.load(path, map_location='cpu')['model_state']
592
+ model_state_new = OrderedDict([(k.split('.', 1)[1], v)
593
+ for k, v in model_state.items()])
594
+ model.load_state_dict(model_state_new)
595
+
596
+ # Set the final linear layer to identity if we use hidden reprs
597
+ if model.use_repr:
598
+ if model.repr_type == 'one_layer_before':
599
+ model.post_mp.layer_post_mp.model[-1] = torch.nn.Identity()
600
+ elif model.repr_type == 'no_post_mp':
601
+ model.post_mp = IdentityHead()
602
+ else:
603
+ raise ValueError(f"Unknown type '{model.repr_type}'")
604
+
605
+ model.eval()
606
+ return model
607
+
608
+ def forward(self, batch):
609
+ batch = batch.clone()
610
+ for module in self.children():
611
+ batch = module(batch)
612
+ return batch
613
+
614
+
615
class GPSENodeEncoder(torch.nn.Module):
    r"""A helper linear/MLP encoder that takes the :class:`GPSE` encodings
    (based on the `"Graph Positional and Structural Encoder"
    <https://arxiv.org/abs/2307.07107>`_ paper) precomputed as
    :obj:`batch.pestat_GPSE` in the input graphs, maps them to a desired
    dimension defined by :obj:`dim_pe_out` and appends them to node features.

    Suppose a graph dataset has 64 original node features, and we have
    generated GPSE encodings of dimension 32, i.e. :obj:`data.pestat_GPSE`
    has 32 columns. We also want to use a GNN with an inner dimension of 128.
    To do so, we can map the 32-dimensional GPSE encodings to a dimension of
    64 and append them to the :obj:`x` attribute of the
    :class:`~torch_geometric.data.Data` objects. This gives a 128-dimensional
    node feature representation.
    :class:`~torch_geometric.nn.GPSENodeEncoder` handles both the mapping and
    the concatenation to :obj:`x`. Its outputs can be used as input to a GNN:

    .. code-block:: python

        encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                  expand_x=False)
        gnn = GNN(...)

        for batch in loader:
            x = encoder(batch.x, batch.pestat_GPSE)
            batch = gnn(x, batch.edge_index)

    Args:
        dim_emb (int): Size of final node embedding.
        dim_pe_in (int): Original dimension of :obj:`batch.pestat_GPSE`.
        dim_pe_out (int): Desired dimension of :class:`GPSE` after the encoder.
            Must be smaller than :obj:`dim_emb`.
        dim_in (int, optional): Original dimension of input node features,
            required only if :obj:`expand_x` is set to :obj:`True`.
            (default: :obj:`None`)
        expand_x (bool, optional): Whether to linearly map node features
            :obj:`x` from :obj:`dim_in` to (:obj:`dim_emb` -
            :obj:`dim_pe_out`) before concatenation. (default: :obj:`False`)
        norm_type (str, optional): Type of normalization applied to the raw
            encodings. Only :obj:`batchnorm` is supported; any other value
            disables normalization. (default: :obj:`batchnorm`)
        model_type (str, optional): Type of encoder, either :obj:`mlp` or
            :obj:`linear`. (default: :obj:`mlp`)
        n_layers (int, optional): Number of MLP layers if :obj:`model_type` is
            :obj:`mlp`. (default: :obj:`2`)
        dropout_be (float, optional): Dropout ratio of inputs to encoder, i.e.
            before encoding. (default: :obj:`0.5`)
        dropout_ae (float, optional): Dropout ratio of outputs, i.e. after
            encoding. (default: :obj:`0.2`)

    Raises:
        ValueError: If :obj:`expand_x` is :obj:`True` but :obj:`dim_in` is
            not given, or if :obj:`model_type` is unknown.
    """
    def __init__(self, dim_emb: int, dim_pe_in: int, dim_pe_out: int,
                 dim_in: Optional[int] = None, expand_x=False,
                 norm_type='batchnorm', model_type='mlp', n_layers=2,
                 dropout_be=0.5, dropout_ae=0.2):
        super().__init__()

        assert dim_emb > dim_pe_out, ('Desired GPSE dimension (dim_pe_out) '
                                      'must be smaller than the final node '
                                      'embedding dimension (dim_emb).')
        # Fail early with a clear message instead of an opaque error from
        # `nn.Linear(None, ...)`.
        if expand_x and dim_in is None:
            raise ValueError(f"{self.__class__.__name__}: 'dim_in' must be "
                             f"provided when 'expand_x' is set to True.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe_out)
        self.expand_x = expand_x

        self.raw_norm = None
        if norm_type == 'batchnorm':
            self.raw_norm = nn.BatchNorm1d(dim_pe_in)

        self.dropout_be = nn.Dropout(p=dropout_be)
        self.dropout_ae = nn.Dropout(p=dropout_ae)

        activation = nn.ReLU  # the activation is fixed to ReLU
        if model_type == 'mlp':
            layers = []
            if n_layers == 1:
                layers.append(torch.nn.Linear(dim_pe_in, dim_pe_out))
                layers.append(activation())
            else:
                # Hidden layers are twice as wide as the output encoding.
                layers.append(torch.nn.Linear(dim_pe_in, 2 * dim_pe_out))
                layers.append(activation())
                for _ in range(n_layers - 2):
                    layers.append(
                        torch.nn.Linear(2 * dim_pe_out, 2 * dim_pe_out))
                    layers.append(activation())
                layers.append(torch.nn.Linear(2 * dim_pe_out, dim_pe_out))
                layers.append(activation())
            self.pe_encoder = nn.Sequential(*layers)
        elif model_type == 'linear':
            self.pe_encoder = nn.Linear(dim_pe_in, dim_pe_out)
        else:
            raise ValueError(f"{self.__class__.__name__}: Does not support "
                             f"'{model_type}' encoder model.")

    def forward(self, x, pos_enc):
        r"""Encodes :obj:`pos_enc` and concatenates it to the node features.

        Args:
            x (torch.Tensor): Node features (mapped by :obj:`linear_x` first
                if :obj:`expand_x` is :obj:`True`).
            pos_enc (torch.Tensor): Precomputed :class:`GPSE` encodings with
                :obj:`dim_pe_in` columns.

        Returns:
            torch.Tensor: The node features and the encoded :class:`GPSE`
            features, concatenated along dimension 1.
        """
        pos_enc = self.dropout_be(pos_enc)
        if self.raw_norm is not None:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe
        pos_enc = self.dropout_ae(pos_enc)

        # Expand node features if needed
        h = self.linear_x(x) if self.expand_x else x

        # Concatenate final PEs to input embedding
        return torch.cat((h, pos_enc), 1)
716
+
717
+
718
@torch.no_grad()
def gpse_process(
    model: Module,
    data: Data,
    rand_type: str,
    use_vn: bool = True,
    bernoulli_thresh: float = 0.5,
    neighbor_loader: bool = False,
    num_neighbors: Optional[List[int]] = None,
    fillval: int = 5,
    layers_mp: int = None,
    **kwargs,
) -> torch.Tensor:
    r"""Processes the data using the :class:`GPSE` model to generate
    :class:`GPSE` encodings. Identical to :obj:`gpse_process_batch`, but
    operates on a single :class:`~torch_geometric.data.Data` object.

    Unlike transform-based GPSE processing (i.e.
    :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument
    does not append virtual nodes if set to :obj:`True`, and instead assumes
    the input graphs to :obj:`gpse_process` already have virtual nodes. Under
    normal circumstances, one does not need to call this function; running
    :obj:`precompute_GPSE` on your whole dataset is advised instead.

    .. note::
        :obj:`data.x` is overwritten with the sampled random input features.

    Args:
        model (Module): The :class:`GPSE` model.
        data (torch_geometric.data.Data): A :class:`~torch_geometric.data.Data`
            object.
        rand_type (str): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
        use_vn (bool, optional): Whether the input graph has a virtual node,
            expected to be its last node. (default: :obj:`True`)
        bernoulli_thresh (float, optional): Threshold (success probability)
            for the :obj:`BernoulliSE` random features. (default: :obj:`0.5`)
        neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`.
            (default: :obj:`False`)
        num_neighbors (List[int], optional): Number of neighbors to consider
            for each message-passing layer. The given list is not modified.
            (default: :obj:`[30, 20, 10]`)
        fillval (int, optional): Number of neighbors used for the layers not
            covered by :obj:`num_neighbors`. (default: :obj:`5`)
        layers_mp (int, optional): Number of message-passing layers; required
            if :obj:`neighbor_loader` is :obj:`True`. (default: :obj:`None`)
        **kwargs (optional): Additional arguments for :obj:`NeighborLoader`.

    Returns:
        torch.Tensor: The first output of :obj:`model`, moved to the CPU. For
        :obj:`NeighborLoader`, the seed-node outputs of all mini-batches are
        stacked.

    Raises:
        ValueError: If :obj:`rand_type` is unknown, or if
            :obj:`neighbor_loader` is set without :obj:`layers_mp`.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Generate random features for the encoder
    n = data.num_nodes
    # Input width expected by the model, read from the first state-dict entry
    # (presumably the weight of the first linear layer).
    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]

    # Prepare input distributions for GPSE
    if rand_type == 'NormalSE':
        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))
    elif rand_type == 'UniformSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
    elif rand_type == 'BernoulliSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
        rand = (rand < bernoulli_thresh)
    else:
        raise ValueError(f'Unknown {rand_type=!r}')
    data.x = torch.from_numpy(rand.astype('float32'))

    if use_vn:
        # Reset the (last, virtual) node's features to zeros
        data.x[-1] = 0

    model, data = model.to(device), data.to(device)
    # Generate encodings using the pretrained encoder
    if neighbor_loader:
        if layers_mp is None:
            raise ValueError('Please provide the number of message-passing '
                             'layers as "layers_mp".')

        # Copy so that the padding below never mutates the caller's list
        num_neighbors = list(num_neighbors or [30, 20, 10])
        diff = layers_mp - len(num_neighbors)
        if fillval > 0 and diff > 0:
            num_neighbors += [fillval] * diff

        loader = NeighborLoader(data, num_neighbors=num_neighbors,
                                shuffle=False, pin_memory=True, **kwargs)
        out_list = []
        # The context manager closes the progress bar when done
        with trange(data.num_nodes, position=2) as pbar:
            for batch in loader:
                out, _ = model(batch.to(device))
                out = out[:batch.batch_size].to("cpu", non_blocking=True)
                out_list.append(out)
                pbar.update(batch.batch_size)
        out = torch.vstack(out_list)
    else:
        out, _ = model(data)
        out = out.to("cpu")

    return out
815
+
816
+
817
@torch.no_grad()
def gpse_process_batch(
    model: GPSE,
    batch,
    rand_type: str,
    use_vn: bool = True,
    bernoulli_thresh: float = 0.5,
    neighbor_loader: bool = False,
    num_neighbors: Optional[List[int]] = None,
    fillval: int = 5,
    layers_mp: int = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Process a batch of data using the :class:`GPSE` model to generate
    :class:`GPSE` encodings. Identical to `gpse_process`, but operates on a
    batch of :class:`~torch_geometric.data.Data` objects.

    Unlike transform-based GPSE processing (i.e.
    :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument
    does not append virtual nodes if set to :obj:`True`, and instead assumes
    the input graphs to :obj:`gpse_process` already have virtual nodes. This is
    because the virtual nodes are already added to graphs before the call to
    :obj:`gpse_process_batch` in :obj:`precompute_GPSE` for better efficiency.
    Under normal circumstances, one does not need to call this function;
    running :obj:`precompute_GPSE` on your whole dataset is advised instead.

    .. note::
        :obj:`batch.x` is overwritten with the sampled random input features.

    Args:
        model (GPSE): The :class:`GPSE` model.
        batch: A batch of PyG Data objects (must provide :obj:`ptr`).
        rand_type (str): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
        use_vn (bool, optional): Whether the input graphs have virtual nodes,
            expected to be the last node of each graph. (default: :obj:`True`)
        bernoulli_thresh (float, optional): Threshold (success probability)
            for the :obj:`BernoulliSE` random features. (default: :obj:`0.5`)
        neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`.
            (default: :obj:`False`)
        num_neighbors (List[int], optional): Number of neighbors to consider
            for each message-passing layer. The given list is not modified.
            (default: :obj:`[30, 20, 10]`)
        fillval (int, optional): Number of neighbors used for the layers not
            covered by :obj:`num_neighbors`. (default: :obj:`5`)
        layers_mp (int, optional): Number of message-passing layers; required
            if :obj:`neighbor_loader` is :obj:`True`. (default: :obj:`None`)
        **kwargs: Additional keyword arguments for :obj:`NeighborLoader`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of the stacked
        :class:`GPSE` encodings (on the CPU) and the :obj:`ptr` vector of the
        input :obj:`batch`, which indicates the individual graphs.

    Raises:
        ValueError: If :obj:`rand_type` is unknown, or if
            :obj:`neighbor_loader` is set without :obj:`layers_mp`.
    """
    n = batch.num_nodes
    # Input width expected by the model, read from the first state-dict entry
    # (presumably the weight of the first linear layer).
    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]

    # Prepare input distributions for GPSE
    if rand_type == 'NormalSE':
        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))
    elif rand_type == 'UniformSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
    elif rand_type == 'BernoulliSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
        rand = (rand < bernoulli_thresh)
    else:
        raise ValueError(f'Unknown {rand_type=!r}')
    batch.x = torch.from_numpy(rand.astype('float32'))

    if use_vn:
        # HACK: We need to reset virtual node features to zeros to match the
        # pretraining setting (virtual node applied after random node features
        # are set, and the default node features for the virtual node are all
        # zeros). Can potentially test if initializing virtual node features to
        # random features is better than setting them to zeros.
        # The virtual node is the last node of each graph, i.e. the node just
        # before each end pointer; zero all of them in one indexing operation.
        batch.x[batch.ptr[1:].cpu() - 1] = 0

    # Generate encodings using the pretrained encoder
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    if neighbor_loader:
        if layers_mp is None:
            raise ValueError('Please provide the number of message-passing '
                             'layers as "layers_mp".')

        # Copy so that the padding below never mutates the caller's list
        num_neighbors = list(num_neighbors or [30, 20, 10])
        diff = layers_mp - len(num_neighbors)
        if fillval > 0 and diff > 0:
            num_neighbors += [fillval] * diff

        loader = NeighborLoader(batch, num_neighbors=num_neighbors,
                                shuffle=False, pin_memory=True, **kwargs)
        out_list = []
        # Use a distinct loop variable: `batch` must still refer to the full
        # input batch when its `ptr` is returned below.
        with trange(batch.num_nodes, position=2) as pbar:
            for sub_batch in loader:
                out, _ = model(sub_batch.to(device))
                out = out[:sub_batch.batch_size].to('cpu', non_blocking=True)
                out_list.append(out)
                pbar.update(sub_batch.batch_size)
        out = torch.vstack(out_list)
    else:
        out, _ = model(batch.to(device))
        out = out.to('cpu')

    return out, batch.ptr
920
+
921
+
922
@torch.no_grad()
def precompute_GPSE(model: GPSE, dataset: Dataset, use_vn: bool = True,
                    rand_type: str = 'NormalSE', **kwargs):
    r"""Precomputes :class:`GPSE` encodings in-place for a given dataset using
    a :class:`GPSE` model.

    Each graph of :obj:`dataset` receives its encodings as the
    :obj:`pestat_GPSE` attribute. The dataset is then re-collated from the
    processed graphs.

    Args:
        model (GPSE): The :class:`GPSE` model.
        dataset (Dataset): A PyG Dataset.
        use_vn (bool, optional): Whether to append virtual nodes to graphs in
            :class:`GPSE` computation. Should match the setting used when
            pre-training the :class:`GPSE` model. (default :obj:`True`)
        rand_type (str, optional): The type of randomization to use.
            (default :obj:`NormalSE`)
        **kwargs (optional): Additional arguments for
            :class:`~torch_geometric.data.DataLoader`. They are also forwarded
            to :func:`gpse_process_batch`.
    """
    # Temporarily replace the transformation: while encodings are computed,
    # graphs are only augmented with a virtual node (if requested).
    orig_dataset_transform = dataset.transform
    dataset.transform = T.VirtualNode() if use_vn else None

    # Remove split indices, to be recovered at the end of the precomputation
    tmp_store = {}
    for name in [
            'train_mask', 'val_mask', 'test_mask', 'train_graph_index',
            'val_graph_index', 'test_graph_index', 'train_edge_index',
            'val_edge_index', 'test_edge_index'
    ]:
        if (name in dataset.data) and (dataset.slices is None
                                       or name in dataset.slices):
            tmp_store_data = dataset.data.pop(name)
            tmp_store_slices = dataset.slices.pop(name) \
                if dataset.slices else None
            tmp_store[name] = (tmp_store_data, tmp_store_slices)

    loader = DataLoader(dataset, shuffle=False, pin_memory=True, **kwargs)
    # The loader yields `dataset[i]`, which PyG maps to
    # `dataset.get(dataset.indices()[i])`. Use the same mapping below so that
    # encodings are attached to the right graphs when `dataset` is a sliced
    # or shuffled view.
    indices = dataset.indices()

    # Batched GPSE precomputation loop
    data_list = []
    curr_idx = 0
    pbar = trange(len(dataset), desc='Pre-computing GPSE')
    tic = time.perf_counter()
    try:
        for batch in loader:
            # Forward `use_vn`: otherwise `gpse_process_batch` would zero out
            # the last (real) node of every graph even when no virtual node
            # was appended.
            batch_out, batch_ptr = gpse_process_batch(model, batch, rand_type,
                                                      use_vn=use_vn, **kwargs)

            batch_out = batch_out.to('cpu', non_blocking=True)
            # Need to wait for batch_ptr to finish transferring so that start
            # and end indices are ready to use
            batch_ptr = batch_ptr.to('cpu', non_blocking=False)

            for start, end in zip(batch_ptr[:-1], batch_ptr[1:]):
                data = dataset.get(indices[curr_idx])
                if use_vn:
                    end = end - 1  # drop the appended virtual node
                data.pestat_GPSE = batch_out[start:end]
                data_list.append(data)
                curr_idx += 1

            pbar.update(len(batch_ptr) - 1)
    finally:
        pbar.close()
        # Restore the user transformation even if the precomputation fails
        dataset.transform = orig_dataset_transform

    # Collate dataset and reset indices and data list
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)

    # Recover split indices
    for name, (tmp_store_data, tmp_store_slices) in tmp_store.items():
        dataset.data[name] = tmp_store_data
        if tmp_store_slices is not None:
            dataset.slices[name] = tmp_store_slices
    dataset._data_list = None

    timestr = time.strftime('%H:%M:%S', time.gmtime(time.perf_counter() - tic))
    logging.info(f'Finished GPSE pre-computation, took {timestr}')

    # Release cached GPU memory
    torch.cuda.empty_cache()
1005
+
1006
+
1007
def cosim_col_sep(pred: torch.Tensor, true: torch.Tensor,
                  batch_idx: torch.Tensor) -> torch.Tensor:
    r"""Calculates a cosine-similarity loss between predicted and true
    features on a batch of graphs.

    For every graph and every feature column, the cosine similarity between
    the predicted and true column (taken over the graph's rows) is computed.
    Columns whose ground truth is all zeros are excluded.

    Args:
        pred (torch.Tensor): Predicted outputs.
        true (torch.Tensor): Value of ground truths.
        batch_idx (torch.Tensor): Batch indices to separate the graphs. Rows
            with index :obj:`-1` (padded graph-level rows) form an extra group.

    Returns:
        torch.Tensor: One minus the mean cosine similarity over all
        non-trivial (graph, column) pairs.

    Raises:
        ValueError: If :obj:`batch_idx` is not specified.
    """
    if batch_idx is None:
        raise ValueError("mae_cosim_col_sep requires batch index as "
                         "input to distinguish different graphs.")
    # Shift so that padded graph-level rows (index -1) form group 0
    batch_idx = batch_idx + 1 if batch_idx.min() == -1 else batch_idx
    # `to_dense_batch` requires an ordered batch vector, but padded rows are
    # appended after the node rows. Reorder all rows stably (identically for
    # `pred` and `true`) so that each group is contiguous. This is a no-op
    # for an already sorted vector, and the per-column cosine similarity is
    # invariant to a shared row permutation.
    batch_idx, perm = torch.sort(batch_idx, stable=True)
    pred, true = pred[perm], true[perm]
    pred_dense = to_dense_batch(pred, batch_idx)[0]
    true_dense = to_dense_batch(true, batch_idx)[0]
    mask = (true_dense == 0).all(1)  # exclude trivial features from loss
    loss = 1 - F.cosine_similarity(pred_dense, true_dense, dim=1)[~mask].mean()
    return loss
1032
+
1033
+
1034
def gpse_loss(pred: torch.Tensor, true: torch.Tensor,
              batch_idx: torch.Tensor = None) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates the :class:`GPSE` training loss over a batch of graphs.

    The loss is the sum of the mean absolute error between :obj:`pred` and
    :obj:`true` and the column-wise cosine-similarity loss computed by
    :func:`cosim_col_sep`.

    Args:
        pred (torch.Tensor): Predicted outputs.
        true (torch.Tensor): Value of ground truths.
        batch_idx (torch.Tensor): Batch indices to separate the graphs.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The :class:`GPSE` loss and the
        (unchanged) predicted node-and-graph level outputs.

    Raises:
        ValueError: If :obj:`batch_idx` is :obj:`None`.
    """
    if batch_idx is not None:
        mae = F.l1_loss(pred, true)
        cosim = cosim_col_sep(pred, true, batch_idx)
        return mae + cosim, pred
    raise ValueError("mae_cosim_col_sep requires batch index as "
                     "input to distinguish different graphs.")
1058
+
1059
+
1060
def process_batch_idx(batch_idx, true, use_vn=True):
    r"""Processes batch indices to adjust for the removal of virtual nodes, and
    pads batch index for hybrid tasks.

    Args:
        batch_idx (torch.Tensor, optional): Batch indices to separate the
            graphs.
        true (torch.Tensor): Value of ground truths. Only its first dimension
            is used, to determine how many padding entries are needed.
        use_vn (bool, optional): Whether the input graphs have virtual nodes
            that need to be removed. One entry is dropped per graph.
            (default: :obj:`True`)

    Returns:
        torch.Tensor or None: Batch indices that separate the graphs, ordered
        by graph and padded with :obj:`-1` up to :obj:`true.shape[0]`. Returns
        :obj:`None` if :obj:`batch_idx` is :obj:`None`.
    """
    if batch_idx is None:
        return None
    if use_vn:  # remove virtual node
        # Graph `i` keeps `count_i - 1` entries. Building the result from
        # per-graph counts is a single vectorized pass, instead of one boolean
        # mask (and a host sync) per graph. Negative entries are ignored.
        counts = torch.bincount(batch_idx[batch_idx >= 0])
        graph_ids = torch.arange(counts.numel(), dtype=batch_idx.dtype,
                                 device=batch_idx.device)
        batch_idx = torch.repeat_interleave(graph_ids,
                                            (counts - 1).clamp(min=0))
    # Pad batch index for hybrid tasks (set batch index for graph heads to -1)
    if (pad := true.shape[0] - batch_idx.shape[0]) > 0:
        pad_idx = -torch.ones(pad, dtype=torch.long, device=batch_idx.device)
        batch_idx = torch.hstack([batch_idx, pad_idx])
    return batch_idx