pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic; see the registry's advisory page for this release for more details.

Files changed (268):
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: flit 3.9.0
2
+ Generator: flit 3.12.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2023 PyG Team <team@pyg.org>
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
@@ -1,6 +1,13 @@
1
+ from collections import defaultdict
2
+
3
+ import torch
4
+ import torch_geometric.typing
5
+
1
6
  from ._compile import compile, is_compiling
7
+ from ._onnx import is_in_onnx_export, safe_onnx_export
2
8
  from .index import Index
3
9
  from .edge_index import EdgeIndex
10
+ from .hash_tensor import HashTensor
4
11
  from .seed import seed_everything
5
12
  from .home import get_home_dir, set_home_dir
6
13
  from .device import is_mps_available, is_xpu_available, device
@@ -24,16 +31,19 @@ from .lazy_loader import LazyLoader
24
31
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
25
32
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
26
33
 
27
- __version__ = '2.6.0.dev20240704'
34
+ __version__ = '2.8.0.dev20251207'
28
35
 
29
36
  __all__ = [
30
37
  'Index',
31
38
  'EdgeIndex',
39
+ 'HashTensor',
32
40
  'seed_everything',
33
41
  'get_home_dir',
34
42
  'set_home_dir',
35
43
  'compile',
36
44
  'is_compiling',
45
+ 'is_in_onnx_export',
46
+ 'safe_onnx_export',
37
47
  'is_mps_available',
38
48
  'is_xpu_available',
39
49
  'device',
@@ -47,3 +57,26 @@ __all__ = [
47
57
  'torch_geometric',
48
58
  '__version__',
49
59
  ]
60
+
61
+ if not torch_geometric.typing.WITH_PT113:
62
+ import warnings as std_warnings
63
+
64
+ std_warnings.warn(
65
+ "PyG 2.7 removed support for PyTorch < 1.13. Consider "
66
+ "Consider upgrading to PyTorch >= 1.13 or downgrading "
67
+ "to PyG <= 2.6. ", stacklevel=2)
68
+
69
+ # Serialization ###############################################################
70
+
71
+ if torch_geometric.typing.WITH_PT24:
72
+ torch.serialization.add_safe_globals([
73
+ dict,
74
+ list,
75
+ defaultdict,
76
+ Index,
77
+ torch_geometric.index.CatMetadata,
78
+ EdgeIndex,
79
+ torch_geometric.edge_index.SortOrder,
80
+ torch_geometric.edge_index.CatMetadata,
81
+ HashTensor,
82
+ ])
@@ -10,6 +10,8 @@ def is_compiling() -> bool:
10
10
  r"""Returns :obj:`True` in case :pytorch:`PyTorch` is compiling via
11
11
  :meth:`torch.compile`.
12
12
  """
13
+ if torch_geometric.typing.WITH_PT23:
14
+ return torch.compiler.is_compiling()
13
15
  if torch_geometric.typing.WITH_PT21:
14
16
  return torch._dynamo.is_compiling()
15
17
  return False # pragma: no cover
@@ -25,10 +27,16 @@ def compile(
25
27
  This function has the same signature as :meth:`torch.compile` (see
26
28
  `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__).
27
29
 
30
+ Args:
31
+ model: The model to compile.
32
+ *args: Additional arguments of :meth:`torch.compile`.
33
+ **kwargs: Additional keyword arguments of :meth:`torch.compile`.
34
+
28
35
  .. note::
29
36
  :meth:`torch_geometric.compile` is deprecated in favor of
30
37
  :meth:`torch.compile`.
31
38
  """
32
- warnings.warn("'torch_geometric.compile' is deprecated in favor of "
33
- "'torch.compile'")
34
- return torch.compile(model, *args, **kwargs)
39
+ warnings.warn(
40
+ "'torch_geometric.compile' is deprecated in favor of "
41
+ "'torch.compile'", stacklevel=2)
42
+ return torch.compile(model, *args, **kwargs) # type: ignore
@@ -0,0 +1,228 @@
1
+ import warnings
2
+ from os import PathLike
3
+ from typing import Any, Union
4
+
5
+ import torch
6
+
7
+ from torch_geometric import is_compiling
8
+
9
+
10
+ def is_in_onnx_export() -> bool:
11
+ r"""Returns :obj:`True` in case :pytorch:`PyTorch` is exporting to ONNX via
12
+ :meth:`torch.onnx.export`.
13
+ """
14
+ if is_compiling():
15
+ return False
16
+ if torch.jit.is_scripting():
17
+ return False
18
+ return torch.onnx.is_in_onnx_export()
19
+
20
+
21
+ def safe_onnx_export(
22
+ model: torch.nn.Module,
23
+ args: Union[torch.Tensor, tuple[Any, ...]],
24
+ f: Union[str, PathLike[Any], None],
25
+ skip_on_error: bool = False,
26
+ **kwargs: Any,
27
+ ) -> bool:
28
+ r"""A safe wrapper around :meth:`torch.onnx.export` that handles known
29
+ ONNX serialization issues in PyTorch Geometric.
30
+
31
+ This function provides workarounds for the ``onnx_ir.serde.SerdeError``
32
+ with boolean ``allowzero`` attributes that occurs in certain environments.
33
+
34
+ Args:
35
+ model (torch.nn.Module): The model to export.
36
+ args (torch.Tensor or tuple): The input arguments for the model.
37
+ f (str or PathLike): The file path to save the model.
38
+ skip_on_error (bool): If True, return False instead of raising when
39
+ workarounds fail. Useful for CI environments.
40
+ **kwargs: Additional arguments passed to :meth:`torch.onnx.export`.
41
+
42
+ Returns:
43
+ bool: True if export succeeded, False if skipped due to known issues
44
+ (only when skip_on_error=True).
45
+
46
+ Example:
47
+ >>> from torch_geometric.nn import SAGEConv
48
+ >>> from torch_geometric import safe_onnx_export
49
+ >>>
50
+ >>> class MyModel(torch.nn.Module):
51
+ ... def __init__(self):
52
+ ... super().__init__()
53
+ ... self.conv = SAGEConv(8, 16)
54
+ ... def forward(self, x, edge_index):
55
+ ... return self.conv(x, edge_index)
56
+ >>>
57
+ >>> model = MyModel()
58
+ >>> x = torch.randn(3, 8)
59
+ >>> edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])
60
+ >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx')
61
+ >>>
62
+ >>> # For CI environments:
63
+ >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx',
64
+ ... skip_on_error=True)
65
+ >>> if not success:
66
+ ... print("ONNX export skipped due to known upstream issue")
67
+ """
68
+ # Convert single tensor to tuple for torch.onnx.export compatibility
69
+ if isinstance(args, torch.Tensor):
70
+ args = (args, )
71
+
72
+ try:
73
+ # First attempt: standard ONNX export
74
+ torch.onnx.export(model, args, f, **kwargs)
75
+ return True
76
+
77
+ except Exception as e:
78
+ error_str = str(e)
79
+ error_type = type(e).__name__
80
+
81
+ # Check for the specific onnx_ir.serde.SerdeError patterns
82
+ is_allowzero_error = (('onnx_ir.serde.SerdeError' in error_str
83
+ and 'allowzero' in error_str) or
84
+ 'ValueError: Value out of range: 1' in error_str
85
+ or 'serialize_model_into' in error_str
86
+ or 'serialize_attribute_into' in error_str)
87
+
88
+ if is_allowzero_error:
89
+ warnings.warn(
90
+ f"Encountered known ONNX serialization issue ({error_type}). "
91
+ "This is likely the allowzero boolean attribute bug. "
92
+ "Attempting workaround...", UserWarning, stacklevel=2)
93
+
94
+ # Apply workaround strategies
95
+ return _apply_onnx_allowzero_workaround(model, args, f,
96
+ skip_on_error, **kwargs)
97
+
98
+ else:
99
+ # Re-raise other errors
100
+ raise
101
+
102
+
103
+ def _apply_onnx_allowzero_workaround(
104
+ model: torch.nn.Module,
105
+ args: tuple[Any, ...],
106
+ f: Union[str, PathLike[Any], None],
107
+ skip_on_error: bool = False,
108
+ **kwargs: Any,
109
+ ) -> bool:
110
+ r"""Apply workaround strategies for onnx_ir.serde.SerdeError with allowzero
111
+ attributes.
112
+
113
+ Returns:
114
+ bool: True if export succeeded, False if skipped (when
115
+ skip_on_error=True).
116
+ """
117
+ # Strategy 1: Try without dynamo if it was enabled
118
+ if kwargs.get('dynamo', False):
119
+ try:
120
+ kwargs_no_dynamo = kwargs.copy()
121
+ kwargs_no_dynamo['dynamo'] = False
122
+
123
+ warnings.warn(
124
+ "Retrying ONNX export with dynamo=False as workaround",
125
+ UserWarning, stacklevel=3)
126
+
127
+ torch.onnx.export(model, args, f, **kwargs_no_dynamo)
128
+ return True
129
+
130
+ except Exception:
131
+ pass
132
+
133
+ # Strategy 2: Try with different opset versions
134
+ original_opset = kwargs.get('opset_version', 18)
135
+ for opset_version in [17, 16, 15, 14, 13, 11]:
136
+ if opset_version != original_opset:
137
+ try:
138
+ kwargs_opset = kwargs.copy()
139
+ kwargs_opset['opset_version'] = opset_version
140
+
141
+ warnings.warn(
142
+ f"Retrying ONNX export with opset_version={opset_version}",
143
+ UserWarning, stacklevel=3)
144
+
145
+ torch.onnx.export(model, args, f, **kwargs_opset)
146
+ return True
147
+
148
+ except Exception:
149
+ continue
150
+
151
+ # Strategy 3: Try legacy export (non-dynamo with older opset)
152
+ try:
153
+ kwargs_legacy = kwargs.copy()
154
+ kwargs_legacy['dynamo'] = False
155
+ kwargs_legacy['opset_version'] = 11
156
+
157
+ warnings.warn(
158
+ "Retrying ONNX export with legacy settings "
159
+ "(dynamo=False, opset_version=11)", UserWarning, stacklevel=3)
160
+
161
+ torch.onnx.export(model, args, f, **kwargs_legacy)
162
+ return True
163
+
164
+ except Exception:
165
+ pass
166
+
167
+ # Strategy 4: Try with minimal settings
168
+ try:
169
+ minimal_kwargs: dict[str, Any] = {
170
+ 'opset_version': 11,
171
+ 'dynamo': False,
172
+ }
173
+ # Add optional parameters if they exist
174
+ if kwargs.get('input_names') is not None:
175
+ minimal_kwargs['input_names'] = kwargs.get('input_names')
176
+ if kwargs.get('output_names') is not None:
177
+ minimal_kwargs['output_names'] = kwargs.get('output_names')
178
+
179
+ warnings.warn(
180
+ "Retrying ONNX export with minimal settings as last resort",
181
+ UserWarning, stacklevel=3)
182
+
183
+ torch.onnx.export(model, args, f, **minimal_kwargs)
184
+ return True
185
+
186
+ except Exception:
187
+ pass
188
+
189
+ # If all strategies fail, handle based on skip_on_error flag
190
+ import os
191
+ pytest_detected = 'PYTEST_CURRENT_TEST' in os.environ or 'pytest' in str(f)
192
+
193
+ if skip_on_error:
194
+ # For CI environments: skip gracefully instead of failing
195
+ warnings.warn(
196
+ "ONNX export skipped due to known upstream issue "
197
+ "(onnx_ir.serde.SerdeError). "
198
+ "This is caused by a bug in the onnx_ir package where boolean "
199
+ "allowzero attributes cannot be serialized. All workarounds "
200
+ "failed. Consider updating packages: pip install --upgrade onnx "
201
+ "onnxscript "
202
+ "onnx_ir", UserWarning, stacklevel=3)
203
+ return False
204
+
205
+ # For regular usage: provide detailed error message
206
+ error_msg = (
207
+ "Failed to export model to ONNX due to known serialization issue. "
208
+ "This is caused by a bug in the onnx_ir package where boolean "
209
+ "allowzero attributes cannot be serialized. "
210
+ "Workarounds attempted: dynamo=False, multiple opset versions, "
211
+ "and legacy export. ")
212
+
213
+ if pytest_detected:
214
+ error_msg += (
215
+ "\n\nThis error commonly occurs in pytest environments. "
216
+ "Try one of these solutions:\n"
217
+ "1. Run the export outside of pytest (in a regular Python "
218
+ "script)\n"
219
+ "2. Update packages: pip install --upgrade onnx onnxscript "
220
+ "onnx_ir\n"
221
+ "3. Use torch.jit.script() instead of ONNX export for testing\n"
222
+ "4. Use safe_onnx_export(..., skip_on_error=True) to skip "
223
+ "gracefully in CI")
224
+ else:
225
+ error_msg += ("\n\nTry updating packages: pip install --upgrade onnx "
226
+ "onnxscript onnx_ir")
227
+
228
+ raise RuntimeError(error_msg)
@@ -3,6 +3,8 @@ from dataclasses import fields, is_dataclass
3
3
  from importlib import import_module
4
4
  from typing import Any, Dict
5
5
 
6
+ from torch.nn import ModuleDict, ModuleList
7
+
6
8
  from torch_geometric.config_store import (
7
9
  class_from_dataclass,
8
10
  dataclass_from_class,
@@ -71,9 +73,9 @@ def _recursive_config(value: Any) -> Any:
71
73
  return value.config()
72
74
  if is_torch_instance(value, ConfigMixin):
73
75
  return value.config()
74
- if isinstance(value, (tuple, list)):
76
+ if isinstance(value, (tuple, list, ModuleList)):
75
77
  return [_recursive_config(v) for v in value]
76
- if isinstance(value, dict):
78
+ if isinstance(value, (dict, ModuleDict)):
77
79
  return {k: _recursive_config(v) for k, v in value.items()}
78
80
  return value
79
81
 
@@ -82,7 +84,10 @@ def _recursive_from_config(value: Any) -> Any:
82
84
  cls: Any = None
83
85
  if is_dataclass(value):
84
86
  if getattr(value, '_target_', None):
85
- cls = _locate_cls(value._target_)
87
+ try:
88
+ cls = _locate_cls(value._target_) # type: ignore
89
+ except ImportError:
90
+ pass # Keep the dataclass as it is.
86
91
  else:
87
92
  cls = class_from_dataclass(value.__class__)
88
93
  elif isinstance(value, dict) and '_target_' in value:
@@ -168,7 +168,7 @@ def map_annotation(
168
168
  assert origin is not None
169
169
  args = tuple(map_annotation(a, mapping) for a in args)
170
170
  if type(annotation).__name__ == 'GenericAlias':
171
- # If annotated with `list[...]` or `dict[...]` (>= Python 3.10):
171
+ # If annotated with `list[...]` or `dict[...]`:
172
172
  annotation = origin[args]
173
173
  else:
174
174
  # If annotated with `typing.List[...]` or `typing.Dict[...]`:
@@ -7,6 +7,6 @@ import torch_geometric.contrib.explain # noqa
7
7
 
8
8
  warnings.warn(
9
9
  "'torch_geometric.contrib' contains experimental code and is subject to "
10
- "change. Please use with caution.")
10
+ "change. Please use with caution.", stacklevel=2)
11
11
 
12
12
  __all__ = []
@@ -151,7 +151,7 @@ class PGMExplainer(ExplainerAlgorithm):
151
151
 
152
152
  pred_change = torch.max(soft_pred) - soft_pred_perturb[pred_label]
153
153
 
154
- sample[num_nodes] = pred_change
154
+ sample[num_nodes] = pred_change.detach()
155
155
  samples.append(sample)
156
156
 
157
157
  samples = torch.tensor(np.array(samples))
@@ -1,7 +1,10 @@
1
1
  # flake8: noqa
2
2
 
3
+ import torch
4
+ import torch_geometric.typing
5
+
3
6
  from .feature_store import FeatureStore, TensorAttr
4
- from .graph_store import GraphStore, EdgeAttr
7
+ from .graph_store import GraphStore, EdgeAttr, EdgeLayout
5
8
  from .data import Data
6
9
  from .hetero_data import HeteroData
7
10
  from .batch import Batch
@@ -68,6 +71,21 @@ from torch_geometric.loader import DataLoader
68
71
  from torch_geometric.loader import DataListLoader
69
72
  from torch_geometric.loader import DenseDataLoader
70
73
 
74
+ # Serialization ###############################################################
75
+
76
+ if torch_geometric.typing.WITH_PT24:
77
+ torch.serialization.add_safe_globals([
78
+ Data,
79
+ HeteroData,
80
+ TemporalData,
81
+ ClusterData,
82
+ TensorAttr,
83
+ EdgeAttr,
84
+ EdgeLayout,
85
+ ])
86
+
87
+ # Deprecations ################################################################
88
+
71
89
  NeighborSampler = deprecated( # type: ignore
72
90
  details="use 'loader.NeighborSampler' instead",
73
91
  func_name='data.NeighborSampler',
@@ -125,8 +125,8 @@ class Batch(metaclass=DynamicInheritance):
125
125
  cls=self.__class__.__bases__[-1],
126
126
  batch=self,
127
127
  idx=idx,
128
- slice_dict=getattr(self, '_slice_dict'),
129
- inc_dict=getattr(self, '_inc_dict'),
128
+ slice_dict=self._slice_dict,
129
+ inc_dict=self._inc_dict,
130
130
  decrement=True,
131
131
  )
132
132
 
@@ -191,10 +191,8 @@ def _collate(
191
191
  if torch_geometric.typing.WITH_PT20:
192
192
  storage = elem.untyped_storage()._new_shared(
193
193
  numel * elem.element_size(), device=elem.device)
194
- elif torch_geometric.typing.WITH_PT112:
195
- storage = elem.storage()._new_shared(numel, device=elem.device)
196
194
  else:
197
- storage = elem.storage()._new_shared(numel)
195
+ storage = elem.storage()._new_shared(numel, device=elem.device)
198
196
  shape = list(elem.size())
199
197
  if cat_dim is None or elem.dim() == 0:
200
198
  shape = [len(values)] + shape
@@ -1,5 +1,6 @@
1
1
  import copy
2
2
  import warnings
3
+ from collections import defaultdict
3
4
  from collections.abc import Mapping, Sequence
4
5
  from dataclasses import dataclass
5
6
  from itertools import chain
@@ -354,7 +355,7 @@ class BaseData:
354
355
  """
355
356
  return self.apply(lambda x: x.contiguous(), *args)
356
357
 
357
- def to(self, device: Union[int, str], *args: str,
358
+ def to(self, device: Union[int, str, torch.device], *args: str,
358
359
  non_blocking: bool = False):
359
360
  r"""Performs tensor device conversion, either for all attributes or
360
361
  only the ones given in :obj:`*args`.
@@ -659,7 +660,13 @@ class Data(BaseData, FeatureStore, GraphStore):
659
660
  return value.get_dim_size()
660
661
  return int(value.max()) + 1
661
662
  elif 'index' in key or key == 'face':
662
- return self.num_nodes
663
+ num_nodes = self.num_nodes
664
+ if num_nodes is None:
665
+ raise RuntimeError(f"Unable to infer 'num_nodes' from the "
666
+ f"attribute '{key}'. Please explicitly set "
667
+ f"'num_nodes' as an attribute of 'data' to "
668
+ f"prevent this error")
669
+ return num_nodes
663
670
  else:
664
671
  return 0
665
672
 
@@ -844,14 +851,14 @@ class Data(BaseData, FeatureStore, GraphStore):
844
851
  # that maps global node indices to local ones in the final
845
852
  # heterogeneous graph:
846
853
  node_ids, index_map = {}, torch.empty_like(node_type)
847
- for i, key in enumerate(node_type_names):
854
+ for i in range(len(node_type_names)):
848
855
  node_ids[i] = (node_type == i).nonzero(as_tuple=False).view(-1)
849
856
  index_map[node_ids[i]] = torch.arange(len(node_ids[i]),
850
857
  device=index_map.device)
851
858
 
852
859
  # We iterate over edge types to find the local edge indices:
853
860
  edge_ids = {}
854
- for i, key in enumerate(edge_type_names):
861
+ for i in range(len(edge_type_names)):
855
862
  edge_ids[i] = (edge_type == i).nonzero(as_tuple=False).view(-1)
856
863
 
857
864
  data = HeteroData()
@@ -898,6 +905,60 @@ class Data(BaseData, FeatureStore, GraphStore):
898
905
 
899
906
  return data
900
907
 
908
+ def connected_components(self) -> List[Self]:
909
+ r"""Extracts connected components of the graph using a union-find
910
+ algorithm. The components are returned as a list of
911
+ :class:`~torch_geometric.data.Data` objects, where each object
912
+ represents a connected component of the graph.
913
+
914
+ .. code-block::
915
+
916
+ data = Data()
917
+ data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
918
+ data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])
919
+ data.edge_index = torch.tensor(
920
+ [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
921
+ )
922
+
923
+ components = data.connected_components()
924
+ print(len(components))
925
+ >>> 2
926
+
927
+ print(components[0].x)
928
+ >>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])
929
+
930
+ Returns:
931
+ List[Data]: A list of disconnected components.
932
+ """
933
+ # Union-Find algorithm to find connected components
934
+ self._parents: Dict[int, int] = {}
935
+ self._ranks: Dict[int, int] = {}
936
+ for edge in self.edge_index.t().tolist():
937
+ self._union(edge[0], edge[1])
938
+
939
+ # Rerun _find_parent to ensure all nodes are covered correctly
940
+ for node in range(self.num_nodes):
941
+ self._find_parent(node)
942
+
943
+ # Group parents
944
+ grouped_parents = defaultdict(list)
945
+ for node, parent in self._parents.items():
946
+ grouped_parents[parent].append(node)
947
+ del self._ranks
948
+ del self._parents
949
+
950
+ # Create components based on the found parents (roots)
951
+ components: List[Self] = []
952
+ for nodes in grouped_parents.values():
953
+ # Convert the list of node IDs to a tensor
954
+ subset = torch.tensor(nodes, dtype=torch.long)
955
+
956
+ # Use the existing subgraph function
957
+ component_data = self.subgraph(subset)
958
+ components.append(component_data)
959
+
960
+ return components
961
+
901
962
  ###########################################################################
902
963
 
903
964
  @classmethod
@@ -1144,6 +1205,49 @@ class Data(BaseData, FeatureStore, GraphStore):
1144
1205
 
1145
1206
  return list(edge_attrs.values())
1146
1207
 
1208
+ # Connected Components Helper Functions ###################################
1209
+
1210
def _find_parent(self, node: int) -> int:
    r"""Returns the root (representative) of :obj:`node` in the union-find
    structure stored in :obj:`self._parents` / :obj:`self._ranks`.

    A node seen for the first time is registered as its own root with rank
    :obj:`0`. Every node visited on the way to the root is re-pointed
    directly at the root (path compression).

    Args:
        node (int): The node whose representative is requested.

    Returns:
        int: The root of the set containing :obj:`node`.
    """
    parents = self._parents
    if node not in parents:
        parents[node] = node
        self._ranks[node] = 0

    # First pass: climb to the root.
    root = node
    while parents[root] != root:
        root = parents[root]

    # Second pass: compress the path that was just walked.
    while node != root:
        successor = parents[node]
        parents[node] = root
        node = successor

    return root
1227
+
1228
def _union(self, node1: int, node2: int) -> None:
    r"""Merges the sets containing :obj:`node1` and :obj:`node2` in the
    union-find structure (:obj:`self._parents` / :obj:`self._ranks`).

    Uses union by rank: the root with the smaller rank is attached below
    the other root. On a tie, :obj:`node2`'s root is attached below
    :obj:`node1`'s root, and the latter's rank grows by one. Nothing happens
    if both nodes already share a root.

    Args:
        node1 (int): The index of the first node to union.
        node2 (int): The index of the second node to union.
    """
    root_a, root_b = self._find_parent(node1), self._find_parent(node2)
    if root_a == root_b:
        return

    ranks = self._ranks
    if ranks[root_a] < ranks[root_b]:
        self._parents[root_a] = root_b
        return

    self._parents[root_b] = root_a
    if ranks[root_a] == ranks[root_b]:
        ranks[root_a] += 1
1250
+
1147
1251
 
1148
1252
  ###############################################################################
1149
1253
 
@@ -1165,7 +1269,7 @@ def size_repr(key: Any, value: Any, indent: int = 0) -> str:
1165
1269
  f'[{value.num_rows}, {value.num_cols}])')
1166
1270
  elif isinstance(value, str):
1167
1271
  out = f"'{value}'"
1168
- elif isinstance(value, Sequence):
1272
+ elif isinstance(value, (Sequence, set)):
1169
1273
  out = str([len(value)])
1170
1274
  elif isinstance(value, Mapping) and len(value) == 0:
1171
1275
  out = '{}'
@@ -1187,4 +1291,4 @@ def warn_or_raise(msg: str, raise_on_error: bool = True):
1187
1291
  if raise_on_error:
1188
1292
  raise ValueError(msg)
1189
1293
  else:
1190
- warnings.warn(msg)
1294
+ warnings.warn(msg, stacklevel=2)