pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pyg-nightly might be problematic.

Files changed (228)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_trim_to_layer.py +2 -2
  215. torch_geometric/utils/convert.py +17 -10
  216. torch_geometric/utils/cross_entropy.py +34 -13
  217. torch_geometric/utils/embedding.py +91 -2
  218. torch_geometric/utils/geodesic.py +4 -3
  219. torch_geometric/utils/influence.py +279 -0
  220. torch_geometric/utils/map.py +13 -9
  221. torch_geometric/utils/nested.py +1 -1
  222. torch_geometric/utils/smiles.py +3 -3
  223. torch_geometric/utils/sparse.py +7 -14
  224. torch_geometric/visualization/__init__.py +2 -1
  225. torch_geometric/visualization/graph.py +250 -5
  226. torch_geometric/warnings.py +11 -2
  227. torch_geometric/nn/nlp/__init__.py +0 -7
  228. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: flit 3.9.0
+ Generator: flit 3.12.0
  Root-Is-Purelib: true
  Tag: py3-none-any
@@ -0,0 +1,19 @@
+ Copyright (c) 2023 PyG Team <team@pyg.org>
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
@@ -4,9 +4,10 @@ import torch
  import torch_geometric.typing

  from ._compile import compile, is_compiling
- from ._onnx import is_in_onnx_export
+ from ._onnx import is_in_onnx_export, safe_onnx_export
  from .index import Index
  from .edge_index import EdgeIndex
+ from .hash_tensor import HashTensor
  from .seed import seed_everything
  from .home import get_home_dir, set_home_dir
  from .device import is_mps_available, is_xpu_available, device
@@ -30,17 +31,19 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

- __version__ = '2.7.0.dev20241009'
+ __version__ = '2.8.0.dev20251207'

  __all__ = [
      'Index',
      'EdgeIndex',
+     'HashTensor',
      'seed_everything',
      'get_home_dir',
      'set_home_dir',
      'compile',
      'is_compiling',
      'is_in_onnx_export',
+     'safe_onnx_export',
      'is_mps_available',
      'is_xpu_available',
      'device',
@@ -55,6 +58,14 @@ __all__ = [
      '__version__',
  ]

+ if not torch_geometric.typing.WITH_PT113:
+     import warnings as std_warnings
+
+     std_warnings.warn(
+         "PyG 2.7 removed support for PyTorch < 1.13. "
+         "Consider upgrading to PyTorch >= 1.13 or downgrading "
+         "to PyG <= 2.6.", stacklevel=2)
+
  # Serialization ###############################################################

  if torch_geometric.typing.WITH_PT24:
@@ -67,4 +78,5 @@ if torch_geometric.typing.WITH_PT24:
      EdgeIndex,
      torch_geometric.edge_index.SortOrder,
      torch_geometric.edge_index.CatMetadata,
+     HashTensor,
  ])
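
For context (not part of the diff): after these `__init__` changes, both new symbols are importable from the package root. A minimal, hedged sketch:

    import torch_geometric
    from torch_geometric import HashTensor, safe_onnx_export  # new public exports

    print(torch_geometric.__version__)  # e.g. '2.8.0.dev20251207'
    # On PyTorch >= 2.4, HashTensor is additionally registered through
    # torch.serialization.add_safe_globals(), so checkpoints containing it
    # should load with torch.load(..., weights_only=True).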
@@ -27,10 +27,16 @@ def compile(
      This function has the same signature as :meth:`torch.compile` (see
      `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__).

+     Args:
+         model: The model to compile.
+         *args: Additional arguments of :meth:`torch.compile`.
+         **kwargs: Additional keyword arguments of :meth:`torch.compile`.
+
      .. note::
          :meth:`torch_geometric.compile` is deprecated in favor of
          :meth:`torch.compile`.
      """
-     warnings.warn("'torch_geometric.compile' is deprecated in favor of "
-                   "'torch.compile'")
-     return torch.compile(model, *args, **kwargs)
+     warnings.warn(
+         "'torch_geometric.compile' is deprecated in favor of "
+         "'torch.compile'", stacklevel=2)
+     return torch.compile(model, *args, **kwargs)  # type: ignore
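
Since `torch_geometric.compile` now merely warns and forwards to `torch.compile`, a hedged migration sketch (the GCN model and random inputs below are illustrative, not taken from the diff):

    import torch
    from torch_geometric.nn import GCN

    model = GCN(in_channels=16, hidden_channels=32, num_layers=2)
    x = torch.randn(8, 16)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])

    compiled = torch.compile(model)  # preferred going forward
    out = compiled(x, edge_index)
    # torch_geometric.compile(model) still works, but emits a UserWarning
    # pointing to torch.compile.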
torch_geometric/_onnx.py CHANGED
@@ -1,3 +1,7 @@
+ import warnings
+ from os import PathLike
+ from typing import Any, Union
+
  import torch

  from torch_geometric import is_compiling
@@ -12,3 +16,213 @@ def is_in_onnx_export() -> bool:
      if torch.jit.is_scripting():
          return False
      return torch.onnx.is_in_onnx_export()
+
+
+ def safe_onnx_export(
+     model: torch.nn.Module,
+     args: Union[torch.Tensor, tuple[Any, ...]],
+     f: Union[str, PathLike[Any], None],
+     skip_on_error: bool = False,
+     **kwargs: Any,
+ ) -> bool:
+     r"""A safe wrapper around :meth:`torch.onnx.export` that handles known
+     ONNX serialization issues in PyTorch Geometric.
+
+     This function provides workarounds for the ``onnx_ir.serde.SerdeError``
+     with boolean ``allowzero`` attributes that occurs in certain environments.
+
+     Args:
+         model (torch.nn.Module): The model to export.
+         args (torch.Tensor or tuple): The input arguments for the model.
+         f (str or PathLike): The file path to save the model.
+         skip_on_error (bool): If True, return False instead of raising when
+             workarounds fail. Useful for CI environments.
+         **kwargs: Additional arguments passed to :meth:`torch.onnx.export`.
+
+     Returns:
+         bool: True if export succeeded, False if skipped due to known issues
+         (only when skip_on_error=True).
+
+     Example:
+         >>> from torch_geometric.nn import SAGEConv
+         >>> from torch_geometric import safe_onnx_export
+         >>>
+         >>> class MyModel(torch.nn.Module):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...         self.conv = SAGEConv(8, 16)
+         ...     def forward(self, x, edge_index):
+         ...         return self.conv(x, edge_index)
+         >>>
+         >>> model = MyModel()
+         >>> x = torch.randn(3, 8)
+         >>> edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])
+         >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx')
+         >>>
+         >>> # For CI environments:
+         >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx',
+         ...                            skip_on_error=True)
+         >>> if not success:
+         ...     print("ONNX export skipped due to known upstream issue")
+     """
+     # Convert single tensor to tuple for torch.onnx.export compatibility
+     if isinstance(args, torch.Tensor):
+         args = (args, )
+
+     try:
+         # First attempt: standard ONNX export
+         torch.onnx.export(model, args, f, **kwargs)
+         return True
+
+     except Exception as e:
+         error_str = str(e)
+         error_type = type(e).__name__
+
+         # Check for the specific onnx_ir.serde.SerdeError patterns
+         is_allowzero_error = (('onnx_ir.serde.SerdeError' in error_str
+                                and 'allowzero' in error_str) or
+                               'ValueError: Value out of range: 1' in error_str
+                               or 'serialize_model_into' in error_str
+                               or 'serialize_attribute_into' in error_str)
+
+         if is_allowzero_error:
+             warnings.warn(
+                 f"Encountered known ONNX serialization issue ({error_type}). "
+                 "This is likely the allowzero boolean attribute bug. "
+                 "Attempting workaround...", UserWarning, stacklevel=2)
+
+             # Apply workaround strategies
+             return _apply_onnx_allowzero_workaround(model, args, f,
+                                                     skip_on_error, **kwargs)
+
+         else:
+             # Re-raise other errors
+             raise
+
+
+ def _apply_onnx_allowzero_workaround(
+     model: torch.nn.Module,
+     args: tuple[Any, ...],
+     f: Union[str, PathLike[Any], None],
+     skip_on_error: bool = False,
+     **kwargs: Any,
+ ) -> bool:
+     r"""Apply workaround strategies for onnx_ir.serde.SerdeError with
+     allowzero attributes.
+
+     Returns:
+         bool: True if export succeeded, False if skipped (when
+         skip_on_error=True).
+     """
+     # Strategy 1: Try without dynamo if it was enabled
+     if kwargs.get('dynamo', False):
+         try:
+             kwargs_no_dynamo = kwargs.copy()
+             kwargs_no_dynamo['dynamo'] = False
+
+             warnings.warn(
+                 "Retrying ONNX export with dynamo=False as workaround",
+                 UserWarning, stacklevel=3)
+
+             torch.onnx.export(model, args, f, **kwargs_no_dynamo)
+             return True
+
+         except Exception:
+             pass
+
+     # Strategy 2: Try with different opset versions
+     original_opset = kwargs.get('opset_version', 18)
+     for opset_version in [17, 16, 15, 14, 13, 11]:
+         if opset_version != original_opset:
+             try:
+                 kwargs_opset = kwargs.copy()
+                 kwargs_opset['opset_version'] = opset_version
+
+                 warnings.warn(
+                     f"Retrying ONNX export with opset_version={opset_version}",
+                     UserWarning, stacklevel=3)
+
+                 torch.onnx.export(model, args, f, **kwargs_opset)
+                 return True
+
+             except Exception:
+                 continue
+
+     # Strategy 3: Try legacy export (non-dynamo with older opset)
+     try:
+         kwargs_legacy = kwargs.copy()
+         kwargs_legacy['dynamo'] = False
+         kwargs_legacy['opset_version'] = 11
+
+         warnings.warn(
+             "Retrying ONNX export with legacy settings "
+             "(dynamo=False, opset_version=11)", UserWarning, stacklevel=3)
+
+         torch.onnx.export(model, args, f, **kwargs_legacy)
+         return True
+
+     except Exception:
+         pass
+
+     # Strategy 4: Try with minimal settings
+     try:
+         minimal_kwargs: dict[str, Any] = {
+             'opset_version': 11,
+             'dynamo': False,
+         }
+         # Add optional parameters if they exist
+         if kwargs.get('input_names') is not None:
+             minimal_kwargs['input_names'] = kwargs.get('input_names')
+         if kwargs.get('output_names') is not None:
+             minimal_kwargs['output_names'] = kwargs.get('output_names')
+
+         warnings.warn(
+             "Retrying ONNX export with minimal settings as last resort",
+             UserWarning, stacklevel=3)
+
+         torch.onnx.export(model, args, f, **minimal_kwargs)
+         return True
+
+     except Exception:
+         pass
+
+     # If all strategies fail, handle based on skip_on_error flag
+     import os
+     pytest_detected = 'PYTEST_CURRENT_TEST' in os.environ or 'pytest' in str(f)
+
+     if skip_on_error:
+         # For CI environments: skip gracefully instead of failing
+         warnings.warn(
+             "ONNX export skipped due to known upstream issue "
+             "(onnx_ir.serde.SerdeError). "
+             "This is caused by a bug in the onnx_ir package where boolean "
+             "allowzero attributes cannot be serialized. All workarounds "
+             "failed. Consider updating packages: pip install --upgrade onnx "
+             "onnxscript onnx_ir", UserWarning, stacklevel=3)
+         return False
+
+     # For regular usage: provide detailed error message
+     error_msg = (
+         "Failed to export model to ONNX due to known serialization issue. "
+         "This is caused by a bug in the onnx_ir package where boolean "
+         "allowzero attributes cannot be serialized. "
+         "Workarounds attempted: dynamo=False, multiple opset versions, "
+         "and legacy export. ")
+
+     if pytest_detected:
+         error_msg += (
+             "\n\nThis error commonly occurs in pytest environments. "
+             "Try one of these solutions:\n"
+             "1. Run the export outside of pytest (in a regular Python "
+             "script)\n"
+             "2. Update packages: pip install --upgrade onnx onnxscript "
+             "onnx_ir\n"
+             "3. Use torch.jit.script() instead of ONNX export for testing\n"
+             "4. Use safe_onnx_export(..., skip_on_error=True) to skip "
+             "gracefully in CI")
+     else:
+         error_msg += ("\n\nTry updating packages: pip install --upgrade onnx "
+                       "onnxscript onnx_ir")
+
+     raise RuntimeError(error_msg)
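
For CI suites, a hedged sketch of combining `skip_on_error=True` with pytest, as suggested by the docstring above (the test function and model are illustrative):

    import pytest
    import torch

    from torch_geometric import safe_onnx_export
    from torch_geometric.nn import SAGEConv


    def test_onnx_export(tmp_path):
        model = SAGEConv(8, 16)
        x = torch.randn(3, 8)
        edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])

        # Returns False instead of raising if every workaround fails:
        ok = safe_onnx_export(model, (x, edge_index),
                              str(tmp_path / 'model.onnx'), skip_on_error=True)
        if not ok:
            pytest.skip("known onnx_ir serialization issue")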
@@ -3,6 +3,8 @@ from dataclasses import fields, is_dataclass
  from importlib import import_module
  from typing import Any, Dict

+ from torch.nn import ModuleDict, ModuleList
+
  from torch_geometric.config_store import (
      class_from_dataclass,
      dataclass_from_class,
@@ -71,9 +73,9 @@ def _recursive_config(value: Any) -> Any:
      return value.config()
  if is_torch_instance(value, ConfigMixin):
      return value.config()
- if isinstance(value, (tuple, list)):
+ if isinstance(value, (tuple, list, ModuleList)):
      return [_recursive_config(v) for v in value]
- if isinstance(value, dict):
+ if isinstance(value, (dict, ModuleDict)):
      return {k: _recursive_config(v) for k, v in value.items()}
  return value

@@ -83,7 +85,7 @@ def _recursive_from_config(value: Any) -> Any:
  if is_dataclass(value):
      if getattr(value, '_target_', None):
          try:
-             cls = _locate_cls(value._target_)  # type: ignore[attr-defined]
+             cls = _locate_cls(value._target_)  # type: ignore
          except ImportError:
              pass  # Keep the dataclass as it is.
          else:
@@ -168,7 +168,7 @@ def map_annotation(
  assert origin is not None
  args = tuple(map_annotation(a, mapping) for a in args)
  if type(annotation).__name__ == 'GenericAlias':
-     # If annotated with `list[...]` or `dict[...]` (>= Python 3.10):
+     # If annotated with `list[...]` or `dict[...]`:
      annotation = origin[args]
  else:
      # If annotated with `typing.List[...]` or `typing.Dict[...]`:
@@ -7,6 +7,6 @@ import torch_geometric.contrib.explain # noqa

  warnings.warn(
      "'torch_geometric.contrib' contains experimental code and is subject to "
-     "change. Please use with caution.")
+     "change. Please use with caution.", stacklevel=2)

  __all__ = []
@@ -151,7 +151,7 @@ class PGMExplainer(ExplainerAlgorithm):

      pred_change = torch.max(soft_pred) - soft_pred_perturb[pred_label]

-     sample[num_nodes] = pred_change
+     sample[num_nodes] = pred_change.detach()
      samples.append(sample)

      samples = torch.tensor(np.array(samples))
@@ -125,8 +125,8 @@ class Batch(metaclass=DynamicInheritance):
      cls=self.__class__.__bases__[-1],
      batch=self,
      idx=idx,
-     slice_dict=getattr(self, '_slice_dict'),
-     inc_dict=getattr(self, '_inc_dict'),
+     slice_dict=self._slice_dict,
+     inc_dict=self._inc_dict,
      decrement=True,
  )

@@ -191,10 +191,8 @@ def _collate(
  if torch_geometric.typing.WITH_PT20:
      storage = elem.untyped_storage()._new_shared(
          numel * elem.element_size(), device=elem.device)
- elif torch_geometric.typing.WITH_PT112:
-     storage = elem.storage()._new_shared(numel, device=elem.device)
  else:
-     storage = elem.storage()._new_shared(numel)
+     storage = elem.storage()._new_shared(numel, device=elem.device)
  shape = list(elem.size())
  if cat_dim is None or elem.dim() == 0:
      shape = [len(values)] + shape
@@ -1,5 +1,6 @@
  import copy
  import warnings
+ from collections import defaultdict
  from collections.abc import Mapping, Sequence
  from dataclasses import dataclass
  from itertools import chain
@@ -354,7 +355,7 @@ class BaseData:
      """
      return self.apply(lambda x: x.contiguous(), *args)

- def to(self, device: Union[int, str], *args: str,
+ def to(self, device: Union[int, str, torch.device], *args: str,
        non_blocking: bool = False):
      r"""Performs tensor device conversion, either for all attributes or
      only the ones given in :obj:`*args`.
@@ -659,7 +660,13 @@ class Data(BaseData, FeatureStore, GraphStore):
          return value.get_dim_size()
      return int(value.max()) + 1
  elif 'index' in key or key == 'face':
-     return self.num_nodes
+     num_nodes = self.num_nodes
+     if num_nodes is None:
+         raise RuntimeError(f"Unable to infer 'num_nodes' from the "
+                            f"attribute '{key}'. Please explicitly set "
+                            f"'num_nodes' as an attribute of 'data' to "
+                            f"prevent this error")
+     return num_nodes
  else:
      return 0

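The new error path fires when `num_nodes` cannot be inferred (e.g. the graph has trailing isolated nodes and only `edge_index` is set). A minimal, hedged sketch of the fix the message asks for:

    import torch
    from torch_geometric.data import Data

    edge_index = torch.tensor([[0, 1], [1, 0]])
    data = Data(edge_index=edge_index)
    data.num_nodes = 3  # node 2 is isolated, so it cannot be inferred from edge_index
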
@@ -844,14 +851,14 @@ class Data(BaseData, FeatureStore, GraphStore):
  # that maps global node indices to local ones in the final
  # heterogeneous graph:
  node_ids, index_map = {}, torch.empty_like(node_type)
- for i, key in enumerate(node_type_names):
+ for i in range(len(node_type_names)):
      node_ids[i] = (node_type == i).nonzero(as_tuple=False).view(-1)
      index_map[node_ids[i]] = torch.arange(len(node_ids[i]),
                                            device=index_map.device)

  # We iterate over edge types to find the local edge indices:
  edge_ids = {}
- for i, key in enumerate(edge_type_names):
+ for i in range(len(edge_type_names)):
      edge_ids[i] = (edge_type == i).nonzero(as_tuple=False).view(-1)

  data = HeteroData()
@@ -898,6 +905,60 @@ class Data(BaseData, FeatureStore, GraphStore):

      return data

+ def connected_components(self) -> List[Self]:
+     r"""Extracts connected components of the graph using a union-find
+     algorithm. The components are returned as a list of
+     :class:`~torch_geometric.data.Data` objects, where each object
+     represents a connected component of the graph.
+
+     .. code-block::
+
+         data = Data()
+         data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
+         data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])
+         data.edge_index = torch.tensor(
+             [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
+         )
+
+         components = data.connected_components()
+         print(len(components))
+         >>> 2
+
+         print(components[0])
+         >>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])
+
+     Returns:
+         List[Data]: A list of disconnected components.
+     """
+     # Union-Find algorithm to find connected components
+     self._parents: Dict[int, int] = {}
+     self._ranks: Dict[int, int] = {}
+     for edge in self.edge_index.t().tolist():
+         self._union(edge[0], edge[1])
+
+     # Rerun _find_parent to ensure all nodes are covered correctly
+     for node in range(self.num_nodes):
+         self._find_parent(node)
+
+     # Group parents
+     grouped_parents = defaultdict(list)
+     for node, parent in self._parents.items():
+         grouped_parents[parent].append(node)
+     del self._ranks
+     del self._parents
+
+     # Create components based on the found parents (roots)
+     components: List[Self] = []
+     for nodes in grouped_parents.values():
+         # Convert the list of node IDs to a tensor
+         subset = torch.tensor(nodes, dtype=torch.long)
+
+         # Use the existing subgraph function
+         component_data = self.subgraph(subset)
+         components.append(component_data)
+
+     return components
+
  ###########################################################################

  @classmethod
@@ -1144,6 +1205,49 @@ class Data(BaseData, FeatureStore, GraphStore):

      return list(edge_attrs.values())

+ # Connected Components Helper Functions ###################################
+
+ def _find_parent(self, node: int) -> int:
+     r"""Finds and returns the representative parent of the given node in a
+     disjoint-set (union-find) data structure. Implements path compression
+     to optimize future queries.
+
+     Args:
+         node (int): The node for which to find the representative parent.
+
+     Returns:
+         int: The representative parent of the node.
+     """
+     if node not in self._parents:
+         self._parents[node] = node
+         self._ranks[node] = 0
+     if self._parents[node] != node:
+         self._parents[node] = self._find_parent(self._parents[node])
+     return self._parents[node]
+
+ def _union(self, node1: int, node2: int):
+     r"""Merges the sets containing node1 and node2 in the disjoint-set
+     data structure.
+
+     Finds the root parents of node1 and node2 using the _find_parent
+     method. If they belong to different sets, updates the parent of
+     root2 to be root1, effectively merging the two sets.
+
+     Args:
+         node1 (int): The index of the first node to union.
+         node2 (int): The index of the second node to union.
+     """
+     root1 = self._find_parent(node1)
+     root2 = self._find_parent(node2)
+     if root1 != root2:
+         if self._ranks[root1] < self._ranks[root2]:
+             self._parents[root1] = root2
+         elif self._ranks[root1] > self._ranks[root2]:
+             self._parents[root2] = root1
+         else:
+             self._parents[root2] = root1
+             self._ranks[root1] += 1
+

  ###############################################################################

@@ -1187,4 +1291,4 @@ def warn_or_raise(msg: str, raise_on_error: bool = True):
  if raise_on_error:
      raise ValueError(msg)
  else:
-     warnings.warn(msg)
+     warnings.warn(msg, stacklevel=2)
@@ -111,13 +111,17 @@ class Database(ABC):
          for key, value in schema_dict.items()
      }

+ @abstractmethod
  def connect(self) -> None:
      r"""Connects to the database.
      Databases will automatically connect on instantiation.
      """
+     raise NotImplementedError

+ @abstractmethod
  def close(self) -> None:
      r"""Closes the connection to the database."""
+     raise NotImplementedError

  @abstractmethod
  def insert(self, index: int, data: Any) -> None:
@@ -166,10 +166,11 @@ class Dataset(torch.utils.data.Dataset):
  elif y.numel() == y.size(0) and torch.is_floating_point(y):
      num_classes = torch.unique(y).numel()
      if num_classes > 2:
-         warnings.warn("Found floating-point labels while calling "
-                       "`dataset.num_classes`. Returning the number of "
-                       "unique elements. Please make sure that this "
-                       "is expected before proceeding.")
+         warnings.warn(
+             "Found floating-point labels while calling "
+             "`dataset.num_classes`. Returning the number of "
+             "unique elements. Please make sure that this "
+             "is expected before proceeding.", stacklevel=2)
      return num_classes
  else:
      return y.size(-1)
@@ -235,22 +236,24 @@ class Dataset(torch.utils.data.Dataset):

  def _process(self):
      f = osp.join(self.processed_dir, 'pre_transform.pt')
-     if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
-             self.pre_transform):
+     if not self.force_reload and osp.exists(f) and torch.load(
+             f, weights_only=False) != _repr(self.pre_transform):
          warnings.warn(
              "The `pre_transform` argument differs from the one used in "
              "the pre-processed version of this dataset. If you want to "
              "make use of another pre-processing technique, pass "
-             "`force_reload=True` explicitly to reload the dataset.")
+             "`force_reload=True` explicitly to reload the dataset.",
+             stacklevel=2)

      f = osp.join(self.processed_dir, 'pre_filter.pt')
-     if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
-             self.pre_filter):
+     if not self.force_reload and osp.exists(f) and torch.load(
+             f, weights_only=False) != _repr(self.pre_filter):
          warnings.warn(
              "The `pre_filter` argument differs from the one used in "
              "the pre-processed version of this dataset. If you want to "
              "make use of another pre-filtering technique, pass "
-             "`force_reload=True` explicitly to reload the dataset.")
+             "`force_reload=True` explicitly to reload the dataset.",
+             stacklevel=2)

      if not self.force_reload and files_exist(self.processed_paths):
          return
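
Because the mismatch warnings above are now suppressed when `force_reload=True` is passed, re-processing after changing `pre_transform` looks roughly like this (the dataset class and transform are illustrative, not taken from the diff):

    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid

    dataset = Planetoid(root='/tmp/Cora', name='Cora',
                        pre_transform=T.NormalizeFeatures(),
                        force_reload=True)  # re-runs process() with the new pre_transform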
@@ -383,7 +386,7 @@ class Dataset(torch.utils.data.Dataset):
      r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.

      The returned instance can then be used with :pyg:`PyG's` built-in
-     :class:`DataPipes` for baching graphs as follows:
+     :class:`DataPipes` for batching graphs as follows:

      .. code-block:: python

@@ -28,7 +28,7 @@ def extract_tar(
      """
      maybe_log(path, log)
      with tarfile.open(path, mode) as f:
-         f.extractall(folder)
+         f.extractall(folder, filter='data')


  def extract_zip(path: str, folder: str, log: bool = True) -> None:
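
The added `filter='data'` opts into the standard library's safer tar extraction behavior (PEP 706). A standalone sketch of the same pattern, with a hypothetical archive path:

    import tarfile

    # The 'data' filter rejects absolute paths, path traversal and special
    # files during extraction; the keyword is available in Python 3.12+ and
    # in recent patch releases of 3.8-3.11.
    with tarfile.open('archive.tar.gz', 'r:gz') as f:
        f.extractall('output_dir', filter='data')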