pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/nn/encoding.py CHANGED
@@ -1,8 +1,14 @@
 import math
+from typing import Optional

 import torch
 from torch import Tensor

+__all__ = classes = [
+    'PositionalEncoding',
+    'TemporalEncoding',
+]
+

 class PositionalEncoding(torch.nn.Module):
     r"""The positional encoding scheme from the `"Attention Is All You Need"
@@ -23,12 +29,15 @@ class PositionalEncoding(torch.nn.Module):
         granularity (float, optional): The granularity of the positions. If
             set to smaller value, the encoder will capture more fine-grained
             changes in positions. (default: :obj:`1.0`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
         out_channels: int,
         base_freq: float = 1e-4,
         granularity: float = 1.0,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()

@@ -40,7 +49,8 @@ class PositionalEncoding(torch.nn.Module):
         self.base_freq = base_freq
         self.granularity = granularity

-        frequency = torch.logspace(0, 1, out_channels // 2, base_freq)
+        frequency = torch.logspace(0, 1, out_channels // 2, base_freq,
+                                   device=device)
         self.register_buffer('frequency', frequency)

         self.reset_parameters()
@@ -75,13 +85,17 @@ class TemporalEncoding(torch.nn.Module):

     Args:
         out_channels (int): Size :math:`d` of each output sample.
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, out_channels: int):
+    def __init__(self, out_channels: int,
+                 device: Optional[torch.device] = None):
         super().__init__()
         self.out_channels = out_channels

         sqrt = math.sqrt(out_channels)
-        weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels).view(1, -1)
+        weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels,
+                                            device=device).view(1, -1)
         self.register_buffer('weight', weight)

         self.reset_parameters()
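
A minimal usage sketch for the new optional device argument (illustrative only, not part of the diff; it assumes the standard torch_geometric.nn exports of PositionalEncoding and TemporalEncoding):

import torch
from torch_geometric.nn import PositionalEncoding, TemporalEncoding

# With `device` set, the frequency/weight buffers are created directly on
# that device instead of being moved after construction.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pos_enc = PositionalEncoding(out_channels=64, device=device)
tmp_enc = TemporalEncoding(out_channels=64, device=device)

t = torch.arange(10, dtype=torch.float, device=device)
print(pos_enc(t).shape)  # torch.Size([10, 64])
print(tmp_enc(t).shape)  # torch.Size([10, 64])
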
torch_geometric/nn/fx.py CHANGED
@@ -1,8 +1,9 @@
 import copy
 import warnings
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, Type, Union

 import torch
+from torch import Tensor
 from torch.nn import Module, ModuleDict, ModuleList, Sequential

 try:
@@ -129,11 +130,13 @@ class Transformer:
         # (node-level, edge-level) by filling `self._state`:
         for node in list(self.graph.nodes):
             if node.op == 'call_function' and 'training' in node.kwargs:
-                warnings.warn(f"Found function '{node.name}' with keyword "
-                              f"argument 'training'. During FX tracing, this "
-                              f"will likely be baked in as a constant value. "
-                              f"Consider replacing this function by a module "
-                              f"to properly encapsulate its training flag.")
+                warnings.warn(
+                    f"Found function '{node.name}' with keyword "
+                    f"argument 'training'. During FX tracing, this "
+                    f"will likely be baked in as a constant value. "
+                    f"Consider replacing this function by a module "
+                    f"to properly encapsulate its training flag.",
+                    stacklevel=2)

             if node.op == 'placeholder':
                 if node.name not in self._state:
@@ -289,7 +292,7 @@ def symbolic_trace(
        # details on the rationale
        # TODO: Revisit https://github.com/pyg-team/pytorch_geometric/pull/5021
        @st.compatibility(is_backward_compatible=True)
-        def trace(self, root: st.Union[torch.nn.Module, st.Callable[..., Any]],
+        def trace(self, root: Union[torch.nn.Module, Callable[..., Any]],
                  concrete_args: Optional[Dict[str, Any]] = None) -> Graph:

            if isinstance(root, torch.nn.Module):
@@ -303,17 +306,16 @@ def symbolic_trace(
                self.root = torch.nn.Module()
                fn = root

-            tracer_cls: Optional[st.Type['Tracer']] = getattr(
+            tracer_cls: Optional[Type['Tracer']] = getattr(
                self, '__class__', None)
            self.graph = Graph(tracer_cls=tracer_cls)

-            self.tensor_attrs: Dict[st.Union[torch.Tensor, st.ScriptObject],
-                                    str] = {}
+            self.tensor_attrs: Dict[Union[Tensor, st.ScriptObject], str] = {}

            def collect_tensor_attrs(m: torch.nn.Module,
-                                     prefix_atoms: st.List[str]):
+                                     prefix_atoms: List[str]):
                for k, v in m.__dict__.items():
-                    if isinstance(v, (torch.Tensor, st.ScriptObject)):
+                    if isinstance(v, (Tensor, st.ScriptObject)):
                        self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
                for k, v in m.named_children():
                    collect_tensor_attrs(v, prefix_atoms + [k])
torch_geometric/nn/model_hub.py CHANGED
@@ -144,10 +144,10 @@ class PyGModelHubMixin(ModelHubMixin):
        revision,
        cache_dir,
        force_download,
-        proxies,
-        resume_download,
        local_files_only,
        token,
+        proxies=None,
+        resume_download=False,
        dataset_name='',
        model_name='',
        map_location='cpu',
@@ -165,8 +165,6 @@ class PyGModelHubMixin(ModelHubMixin):
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
-            proxies=proxies,
-            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
@@ -188,8 +186,6 @@ class PyGModelHubMixin(ModelHubMixin):
        cls,
        pretrained_model_name_or_path: str,
        force_download: bool = False,
-        resume_download: bool = False,
-        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
@@ -215,13 +211,6 @@ class PyGModelHubMixin(ModelHubMixin):
                (re-)download of the model weights and configuration files,
                overriding the cached versions if they exist.
                (default: :obj:`False`)
-            resume_download (bool, optional): Whether to delete incompletely
-                received files. Will attempt to resume the download if such a
-                file exists. (default: :obj:`False`)
-            proxies (Dict[str, str], optional): A dictionary of proxy servers
-                to use by protocol or endpoint, *e.g.*,
-                :obj:`{'http': 'foo.bar:3128', 'http://host': 'foo.bar:4012'}`.
-                The proxies are used on each request. (default: :obj:`None`)
            token (str or bool, optional): The token to use as HTTP bearer
                authorization for remote files. If set to :obj:`True`, will use
                the token generated when running :obj:`transformers-cli login`
@@ -239,8 +228,6 @@ class PyGModelHubMixin(ModelHubMixin):
        return super().from_pretrained(
            pretrained_model_name_or_path,
            force_download=force_download,
-            resume_download=resume_download,
-            proxies=proxies,
            use_auth_token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
torch_geometric/nn/models/__init__.py CHANGED
@@ -12,6 +12,7 @@ from .re_net import RENet
 from .graph_unet import GraphUNet
 from .schnet import SchNet
 from .dimenet import DimeNet, DimeNetPlusPlus
+from .gpse import GPSE, GPSENodeEncoder
 from .captum import to_captum_model
 from .metapath2vec import MetaPath2Vec
 from .deepgcn import DeepGCNLayer
@@ -28,11 +29,14 @@ from .gnnff import GNNFF
 from .pmlp import PMLP
 from .neural_fingerprint import NeuralFingerprint
 from .visnet import ViSNet
-from .g_retriever import GRetriever
+from .lpformer import LPFormer
+from .sgformer import SGFormer

+from .polynormer import Polynormer
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                       captum_output_to_dicts)
+from .attract_repel import ARLinkPredictor

 __all__ = classes = [
     'MLP',
@@ -58,6 +62,8 @@ __all__ = classes = [
     'SchNet',
     'DimeNet',
     'DimeNetPlusPlus',
+    'GPSE',
+    'GPSENodeEncoder',
     'to_captum_model',
     'to_captum_input',
     'captum_output_to_dicts',
@@ -76,5 +82,8 @@ __all__ = classes = [
     'PMLP',
     'NeuralFingerprint',
     'ViSNet',
-    'GRetriever',
+    'LPFormer',
+    'SGFormer',
+    'Polynormer',
+    'ARLinkPredictor',
 ]
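
A quick import sketch reflecting the re-exports above; GRetriever's new home under torch_geometric.llm is taken from the file list (item 87), and the exact llm import path is an assumption:

# New top-level model exports added in this release:
from torch_geometric.nn.models import (ARLinkPredictor, GPSE, GPSENodeEncoder,
                                        LPFormer, Polynormer, SGFormer)

# GRetriever is no longer re-exported here; it moved to the new `llm` package
# (assumed import path, mirroring torch_geometric/llm/models/g_retriever.py):
from torch_geometric.llm.models import GRetriever
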
torch_geometric/nn/models/attentive_fp.py CHANGED
@@ -160,7 +160,7 @@ class AttentiveFP(torch.nn.Module):
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
-        for t in range(self.num_timesteps):
+        for _ in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()
torch_geometric/nn/models/attract_repel.py ADDED
@@ -0,0 +1,148 @@
+import torch
+import torch.nn.functional as F
+
+
+class ARLinkPredictor(torch.nn.Module):
+    r"""Link predictor using Attract-Repel embeddings from the paper
+    `"Pseudo-Euclidean Attract-Repel Embeddings for Undirected Graphs"
+    <https://arxiv.org/abs/2106.09671>`_.
+
+    This model splits node embeddings into: attract and
+    repel.
+    The edge prediction score is computed as the dot product of attract
+    components minus the dot product of repel components.
+
+    Args:
+        in_channels (int): Size of each input sample.
+        hidden_channels (int): Size of hidden embeddings.
+        out_channels (int, optional): Size of output embeddings.
+            If set to :obj:`None`, will default to :obj:`hidden_channels`.
+            (default: :obj:`None`)
+        num_layers (int): Number of message passing layers.
+            (default: :obj:`2`)
+        dropout (float): Dropout probability. (default: :obj:`0.0`)
+        attract_ratio (float): Ratio to use for attract component.
+            Must be between 0 and 1. (default: :obj:`0.5`)
+    """
+    def __init__(self, in_channels, hidden_channels, out_channels=None,
+                 num_layers=2, dropout=0.0, attract_ratio=0.5):
+        super().__init__()
+
+        if out_channels is None:
+            out_channels = hidden_channels
+
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.num_layers = num_layers
+        self.dropout = dropout
+
+        if not 0 <= attract_ratio <= 1:
+            raise ValueError(
+                f"attract_ratio must be between 0 and 1, got {attract_ratio}")
+
+        self.attract_ratio = attract_ratio
+        self.attract_dim = int(out_channels * attract_ratio)
+        self.repel_dim = out_channels - self.attract_dim
+
+        # Create model layers
+        self.lins = torch.nn.ModuleList()
+        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
+
+        for _ in range(num_layers - 2):
+            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
+
+        # Final layer splits into attract and repel components
+        self.lin_attract = torch.nn.Linear(hidden_channels, self.attract_dim)
+        self.lin_repel = torch.nn.Linear(hidden_channels, self.repel_dim)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        """Reset all learnable parameters."""
+        for lin in self.lins:
+            lin.reset_parameters()
+        self.lin_attract.reset_parameters()
+        self.lin_repel.reset_parameters()
+
+    def encode(self, x, *args, **kwargs):
+        """Encode node features into attract-repel embeddings.
+
+        Args:
+            x (torch.Tensor): Node feature matrix of shape
+                :obj:`[num_nodes, in_channels]`.
+            *args: Variable length argument list
+            **kwargs: Arbitrary keyword arguments
+
+        """
+        for lin in self.lins:
+            x = lin(x)
+            x = F.relu(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # Split into attract and repel components
+        attract_x = self.lin_attract(x)
+        repel_x = self.lin_repel(x)
+
+        return attract_x, repel_x
+
+    def decode(self, attract_z, repel_z, edge_index):
+        """Decode edge scores from attract-repel embeddings.
+
+        Args:
+            attract_z (torch.Tensor): Attract embeddings of shape
+                :obj:`[num_nodes, attract_dim]`.
+            repel_z (torch.Tensor): Repel embeddings of shape
+                :obj:`[num_nodes, repel_dim]`.
+            edge_index (torch.Tensor): Edge indices of shape
+                :obj:`[2, num_edges]`.
+
+        Returns:
+            torch.Tensor: Edge prediction scores.
+        """
+        # Get node embeddings for edges
+        row, col = edge_index
+        attract_z_row = attract_z[row]
+        attract_z_col = attract_z[col]
+        repel_z_row = repel_z[row]
+        repel_z_col = repel_z[col]
+
+        # Compute attract-repel scores
+        attract_score = torch.sum(attract_z_row * attract_z_col, dim=1)
+        repel_score = torch.sum(repel_z_row * repel_z_col, dim=1)
+
+        return attract_score - repel_score
+
+    def forward(self, x, edge_index):
+        """Forward pass for link prediction.
+
+        Args:
+            x (torch.Tensor): Node feature matrix.
+            edge_index (torch.Tensor): Edge indices to predict.
+
+        Returns:
+            torch.Tensor: Predicted edge scores.
+        """
+        # Encode nodes into attract-repel embeddings
+        attract_z, repel_z = self.encode(x)
+
+        # Decode target edges
+        return torch.sigmoid(self.decode(attract_z, repel_z, edge_index))
+
+    def calculate_r_fraction(self, attract_z, repel_z):
+        """Calculate the R-fraction (proportion of energy in repel space).
+
+        Args:
+            attract_z (torch.Tensor): Attract embeddings.
+            repel_z (torch.Tensor): Repel embeddings.
+
+        Returns:
+            float: R-fraction value.
+        """
+        attract_norm_squared = torch.sum(attract_z**2)
+        repel_norm_squared = torch.sum(repel_z**2)
+
+        r_fraction = repel_norm_squared / (attract_norm_squared +
+                                           repel_norm_squared + 1e-10)
+
+        return r_fraction.item()
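
A minimal sketch of the new ARLinkPredictor on random data (illustrative only; the shapes and hyper-parameters below are arbitrary):

import torch
from torch_geometric.nn.models import ARLinkPredictor

x = torch.randn(100, 16)                     # 100 nodes with 16 features
edge_index = torch.randint(0, 100, (2, 40))  # 40 candidate edges

model = ARLinkPredictor(in_channels=16, hidden_channels=32, attract_ratio=0.75)
prob = model(x, edge_index)                  # sigmoid link scores, shape [40]

attract_z, repel_z = model.encode(x)
r = model.calculate_r_fraction(attract_z, repel_z)  # share of embedding energy in the repel space
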
torch_geometric/nn/models/basic_gnn.py CHANGED
@@ -415,7 +415,8 @@ class GCN(BasicGNN):
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
-            node embeddings to the expected output feature dimensionality.
+            node embeddings to the expected output feature dimensionality,
+            while default will not.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
torch_geometric/nn/models/captum.py CHANGED
@@ -94,7 +94,7 @@ def to_captum_model(
            function will return the output of the model for the element at
            the index specified. (default: :obj:`None`)
        metadata (Metadata, optional): The metadata of the heterogeneous graph.
-            Only required if explaning a
+            Only required if explaining a
            :class:`~torch_geometric.data.HeteroData` object.
            (default: :obj:`None`)
    """
torch_geometric/nn/models/deep_graph_infomax.py CHANGED
@@ -106,7 +106,7 @@ class DeepGraphInfomax(torch.nn.Module):
        """
        from sklearn.linear_model import LogisticRegression

-        clf = LogisticRegression(solver=solver, *args,
+        clf = LogisticRegression(*args, solver=solver,
                                 **kwargs).fit(train_z.detach().cpu().numpy(),
                                               train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
torch_geometric/nn/models/dimenet.py CHANGED
@@ -755,7 +755,7 @@ class DimeNetPlusPlus(DimeNet):
            interaction blocks after the skip connection. (default: :obj:`2`)
        num_output_layers: (int, optional): Number of linear layers for the
            output blocks. (default: :obj:`3`)
-        act: (str or Callable, optional): The activation funtion.
+        act: (str or Callable, optional): The activation function.
            (default: :obj:`"swish"`)
        output_initializer (str, optional): The initialization method for the
            output layer (:obj:`"zeros"`, :obj:`"glorot_orthogonal"`).
@@ -805,7 +805,7 @@ class DimeNetPlusPlus(DimeNet):

        # We are re-using the RBF, SBF and embedding layers of `DimeNet` and
        # redefine output_block and interaction_block in DimeNet++.
-        # Hence, it is to be noted that in the above initalization, the
+        # Hence, it is to be noted that in the above initialization, the
        # variable `num_bilinear` does not have any purpose as it is used
        # solely in the `OutputBlock` of DimeNet:
        self.output_blocks = torch.nn.ModuleList([
torch_geometric/nn/models/dimenet_utils.py CHANGED
@@ -1,5 +1,7 @@
 # Shameless steal from: https://github.com/klicperajo/dimenet

+import math
+
 import numpy as np
 import sympy as sym
 from scipy import special as sp
@@ -62,8 +64,8 @@ def bessel_basis(n, k):


 def sph_harm_prefactor(k, m):
-    return ((2 * k + 1) * np.math.factorial(k - abs(m)) /
-            (4 * np.pi * np.math.factorial(k + abs(m))))**0.5
+    return ((2 * k + 1) * math.factorial(k - abs(m)) /
+            (4 * np.pi * math.factorial(k + abs(m))))**0.5


 def associated_legendre_polynomials(k, zero_m_only=True):
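
The sph_harm_prefactor change swaps the removed np.math alias for the stdlib factorial; the numerical result is unchanged. A small standalone check (attributing the change to NumPy 2.x compatibility is an assumption, not stated in the diff):

import math
import numpy as np

def sph_harm_prefactor(k, m):
    # Same formula as the patched helper, using math.factorial so it also
    # runs on NumPy releases that no longer provide the `np.math` alias.
    return ((2 * k + 1) * math.factorial(k - abs(m)) /
            (4 * np.pi * math.factorial(k + abs(m))))**0.5

print(round(sph_harm_prefactor(3, 2), 4))  # 0.0681
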