pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly has been flagged as potentially problematic; see the registry's release page for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Copyright (c) 2023 PyG Team <team@pyg.org>
|
|
2
|
+
|
|
3
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
4
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
5
|
+
in the Software without restriction, including without limitation the rights
|
|
6
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
7
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
8
|
+
furnished to do so, subject to the following conditions:
|
|
9
|
+
|
|
10
|
+
The above copyright notice and this permission notice shall be included in
|
|
11
|
+
all copies or substantial portions of the Software.
|
|
12
|
+
|
|
13
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
15
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
16
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
17
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
18
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
19
|
+
THE SOFTWARE.
|
torch_geometric/__init__.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch_geometric.typing
|
|
5
|
+
|
|
1
6
|
from ._compile import compile, is_compiling
|
|
7
|
+
from ._onnx import is_in_onnx_export, safe_onnx_export
|
|
2
8
|
from .index import Index
|
|
3
9
|
from .edge_index import EdgeIndex
|
|
10
|
+
from .hash_tensor import HashTensor
|
|
4
11
|
from .seed import seed_everything
|
|
5
12
|
from .home import get_home_dir, set_home_dir
|
|
6
13
|
from .device import is_mps_available, is_xpu_available, device
|
|
@@ -24,16 +31,19 @@ from .lazy_loader import LazyLoader
|
|
|
24
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
|
25
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
|
26
33
|
|
|
27
|
-
__version__ = '2.6.0.dev20240704'
|
|
34
|
+
__version__ = '2.8.0.dev20251207'
|
|
28
35
|
|
|
29
36
|
__all__ = [
|
|
30
37
|
'Index',
|
|
31
38
|
'EdgeIndex',
|
|
39
|
+
'HashTensor',
|
|
32
40
|
'seed_everything',
|
|
33
41
|
'get_home_dir',
|
|
34
42
|
'set_home_dir',
|
|
35
43
|
'compile',
|
|
36
44
|
'is_compiling',
|
|
45
|
+
'is_in_onnx_export',
|
|
46
|
+
'safe_onnx_export',
|
|
37
47
|
'is_mps_available',
|
|
38
48
|
'is_xpu_available',
|
|
39
49
|
'device',
|
|
@@ -47,3 +57,26 @@ __all__ = [
|
|
|
47
57
|
'torch_geometric',
|
|
48
58
|
'__version__',
|
|
49
59
|
]
|
|
60
|
+
|
|
61
|
+
if not torch_geometric.typing.WITH_PT113:
|
|
62
|
+
import warnings as std_warnings
|
|
63
|
+
|
|
64
|
+
std_warnings.warn(
|
|
65
|
+
"PyG 2.7 removed support for PyTorch < 1.13. Consider "
|
|
66
|
+
"Consider upgrading to PyTorch >= 1.13 or downgrading "
|
|
67
|
+
"to PyG <= 2.6. ", stacklevel=2)
|
|
68
|
+
|
|
69
|
+
# Serialization ###############################################################
|
|
70
|
+
|
|
71
|
+
if torch_geometric.typing.WITH_PT24:
|
|
72
|
+
torch.serialization.add_safe_globals([
|
|
73
|
+
dict,
|
|
74
|
+
list,
|
|
75
|
+
defaultdict,
|
|
76
|
+
Index,
|
|
77
|
+
torch_geometric.index.CatMetadata,
|
|
78
|
+
EdgeIndex,
|
|
79
|
+
torch_geometric.edge_index.SortOrder,
|
|
80
|
+
torch_geometric.edge_index.CatMetadata,
|
|
81
|
+
HashTensor,
|
|
82
|
+
])
|
torch_geometric/_compile.py
CHANGED
|
@@ -10,6 +10,8 @@ def is_compiling() -> bool:
|
|
|
10
10
|
r"""Returns :obj:`True` in case :pytorch:`PyTorch` is compiling via
|
|
11
11
|
:meth:`torch.compile`.
|
|
12
12
|
"""
|
|
13
|
+
if torch_geometric.typing.WITH_PT23:
|
|
14
|
+
return torch.compiler.is_compiling()
|
|
13
15
|
if torch_geometric.typing.WITH_PT21:
|
|
14
16
|
return torch._dynamo.is_compiling()
|
|
15
17
|
return False # pragma: no cover
|
|
@@ -25,10 +27,16 @@ def compile(
|
|
|
25
27
|
This function has the same signature as :meth:`torch.compile` (see
|
|
26
28
|
`here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__).
|
|
27
29
|
|
|
30
|
+
Args:
|
|
31
|
+
model: The model to compile.
|
|
32
|
+
*args: Additional arguments of :meth:`torch.compile`.
|
|
33
|
+
**kwargs: Additional keyword arguments of :meth:`torch.compile`.
|
|
34
|
+
|
|
28
35
|
.. note::
|
|
29
36
|
:meth:`torch_geometric.compile` is deprecated in favor of
|
|
30
37
|
:meth:`torch.compile`.
|
|
31
38
|
"""
|
|
32
|
-
warnings.warn(
|
|
33
|
-
|
|
34
|
-
|
|
39
|
+
warnings.warn(
|
|
40
|
+
"'torch_geometric.compile' is deprecated in favor of "
|
|
41
|
+
"'torch.compile'", stacklevel=2)
|
|
42
|
+
return torch.compile(model, *args, **kwargs) # type: ignore
|
torch_geometric/_onnx.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from os import PathLike
|
|
3
|
+
from typing import Any, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from torch_geometric import is_compiling
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def is_in_onnx_export() -> bool:
|
|
11
|
+
r"""Returns :obj:`True` in case :pytorch:`PyTorch` is exporting to ONNX via
|
|
12
|
+
:meth:`torch.onnx.export`.
|
|
13
|
+
"""
|
|
14
|
+
if is_compiling():
|
|
15
|
+
return False
|
|
16
|
+
if torch.jit.is_scripting():
|
|
17
|
+
return False
|
|
18
|
+
return torch.onnx.is_in_onnx_export()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def safe_onnx_export(
|
|
22
|
+
model: torch.nn.Module,
|
|
23
|
+
args: Union[torch.Tensor, tuple[Any, ...]],
|
|
24
|
+
f: Union[str, PathLike[Any], None],
|
|
25
|
+
skip_on_error: bool = False,
|
|
26
|
+
**kwargs: Any,
|
|
27
|
+
) -> bool:
|
|
28
|
+
r"""A safe wrapper around :meth:`torch.onnx.export` that handles known
|
|
29
|
+
ONNX serialization issues in PyTorch Geometric.
|
|
30
|
+
|
|
31
|
+
This function provides workarounds for the ``onnx_ir.serde.SerdeError``
|
|
32
|
+
with boolean ``allowzero`` attributes that occurs in certain environments.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
model (torch.nn.Module): The model to export.
|
|
36
|
+
args (torch.Tensor or tuple): The input arguments for the model.
|
|
37
|
+
f (str or PathLike): The file path to save the model.
|
|
38
|
+
skip_on_error (bool): If True, return False instead of raising when
|
|
39
|
+
workarounds fail. Useful for CI environments.
|
|
40
|
+
**kwargs: Additional arguments passed to :meth:`torch.onnx.export`.
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
bool: True if export succeeded, False if skipped due to known issues
|
|
44
|
+
(only when skip_on_error=True).
|
|
45
|
+
|
|
46
|
+
Example:
|
|
47
|
+
>>> from torch_geometric.nn import SAGEConv
|
|
48
|
+
>>> from torch_geometric import safe_onnx_export
|
|
49
|
+
>>>
|
|
50
|
+
>>> class MyModel(torch.nn.Module):
|
|
51
|
+
... def __init__(self):
|
|
52
|
+
... super().__init__()
|
|
53
|
+
... self.conv = SAGEConv(8, 16)
|
|
54
|
+
... def forward(self, x, edge_index):
|
|
55
|
+
... return self.conv(x, edge_index)
|
|
56
|
+
>>>
|
|
57
|
+
>>> model = MyModel()
|
|
58
|
+
>>> x = torch.randn(3, 8)
|
|
59
|
+
>>> edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])
|
|
60
|
+
>>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx')
|
|
61
|
+
>>>
|
|
62
|
+
>>> # For CI environments:
|
|
63
|
+
>>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx',
|
|
64
|
+
... skip_on_error=True)
|
|
65
|
+
>>> if not success:
|
|
66
|
+
... print("ONNX export skipped due to known upstream issue")
|
|
67
|
+
"""
|
|
68
|
+
# Convert single tensor to tuple for torch.onnx.export compatibility
|
|
69
|
+
if isinstance(args, torch.Tensor):
|
|
70
|
+
args = (args, )
|
|
71
|
+
|
|
72
|
+
try:
|
|
73
|
+
# First attempt: standard ONNX export
|
|
74
|
+
torch.onnx.export(model, args, f, **kwargs)
|
|
75
|
+
return True
|
|
76
|
+
|
|
77
|
+
except Exception as e:
|
|
78
|
+
error_str = str(e)
|
|
79
|
+
error_type = type(e).__name__
|
|
80
|
+
|
|
81
|
+
# Check for the specific onnx_ir.serde.SerdeError patterns
|
|
82
|
+
is_allowzero_error = (('onnx_ir.serde.SerdeError' in error_str
|
|
83
|
+
and 'allowzero' in error_str) or
|
|
84
|
+
'ValueError: Value out of range: 1' in error_str
|
|
85
|
+
or 'serialize_model_into' in error_str
|
|
86
|
+
or 'serialize_attribute_into' in error_str)
|
|
87
|
+
|
|
88
|
+
if is_allowzero_error:
|
|
89
|
+
warnings.warn(
|
|
90
|
+
f"Encountered known ONNX serialization issue ({error_type}). "
|
|
91
|
+
"This is likely the allowzero boolean attribute bug. "
|
|
92
|
+
"Attempting workaround...", UserWarning, stacklevel=2)
|
|
93
|
+
|
|
94
|
+
# Apply workaround strategies
|
|
95
|
+
return _apply_onnx_allowzero_workaround(model, args, f,
|
|
96
|
+
skip_on_error, **kwargs)
|
|
97
|
+
|
|
98
|
+
else:
|
|
99
|
+
# Re-raise other errors
|
|
100
|
+
raise
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _apply_onnx_allowzero_workaround(
|
|
104
|
+
model: torch.nn.Module,
|
|
105
|
+
args: tuple[Any, ...],
|
|
106
|
+
f: Union[str, PathLike[Any], None],
|
|
107
|
+
skip_on_error: bool = False,
|
|
108
|
+
**kwargs: Any,
|
|
109
|
+
) -> bool:
|
|
110
|
+
r"""Apply workaround strategies for onnx_ir.serde.SerdeError with allowzero
|
|
111
|
+
attributes.
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
bool: True if export succeeded, False if skipped (when
|
|
115
|
+
skip_on_error=True).
|
|
116
|
+
"""
|
|
117
|
+
# Strategy 1: Try without dynamo if it was enabled
|
|
118
|
+
if kwargs.get('dynamo', False):
|
|
119
|
+
try:
|
|
120
|
+
kwargs_no_dynamo = kwargs.copy()
|
|
121
|
+
kwargs_no_dynamo['dynamo'] = False
|
|
122
|
+
|
|
123
|
+
warnings.warn(
|
|
124
|
+
"Retrying ONNX export with dynamo=False as workaround",
|
|
125
|
+
UserWarning, stacklevel=3)
|
|
126
|
+
|
|
127
|
+
torch.onnx.export(model, args, f, **kwargs_no_dynamo)
|
|
128
|
+
return True
|
|
129
|
+
|
|
130
|
+
except Exception:
|
|
131
|
+
pass
|
|
132
|
+
|
|
133
|
+
# Strategy 2: Try with different opset versions
|
|
134
|
+
original_opset = kwargs.get('opset_version', 18)
|
|
135
|
+
for opset_version in [17, 16, 15, 14, 13, 11]:
|
|
136
|
+
if opset_version != original_opset:
|
|
137
|
+
try:
|
|
138
|
+
kwargs_opset = kwargs.copy()
|
|
139
|
+
kwargs_opset['opset_version'] = opset_version
|
|
140
|
+
|
|
141
|
+
warnings.warn(
|
|
142
|
+
f"Retrying ONNX export with opset_version={opset_version}",
|
|
143
|
+
UserWarning, stacklevel=3)
|
|
144
|
+
|
|
145
|
+
torch.onnx.export(model, args, f, **kwargs_opset)
|
|
146
|
+
return True
|
|
147
|
+
|
|
148
|
+
except Exception:
|
|
149
|
+
continue
|
|
150
|
+
|
|
151
|
+
# Strategy 3: Try legacy export (non-dynamo with older opset)
|
|
152
|
+
try:
|
|
153
|
+
kwargs_legacy = kwargs.copy()
|
|
154
|
+
kwargs_legacy['dynamo'] = False
|
|
155
|
+
kwargs_legacy['opset_version'] = 11
|
|
156
|
+
|
|
157
|
+
warnings.warn(
|
|
158
|
+
"Retrying ONNX export with legacy settings "
|
|
159
|
+
"(dynamo=False, opset_version=11)", UserWarning, stacklevel=3)
|
|
160
|
+
|
|
161
|
+
torch.onnx.export(model, args, f, **kwargs_legacy)
|
|
162
|
+
return True
|
|
163
|
+
|
|
164
|
+
except Exception:
|
|
165
|
+
pass
|
|
166
|
+
|
|
167
|
+
# Strategy 4: Try with minimal settings
|
|
168
|
+
try:
|
|
169
|
+
minimal_kwargs: dict[str, Any] = {
|
|
170
|
+
'opset_version': 11,
|
|
171
|
+
'dynamo': False,
|
|
172
|
+
}
|
|
173
|
+
# Add optional parameters if they exist
|
|
174
|
+
if kwargs.get('input_names') is not None:
|
|
175
|
+
minimal_kwargs['input_names'] = kwargs.get('input_names')
|
|
176
|
+
if kwargs.get('output_names') is not None:
|
|
177
|
+
minimal_kwargs['output_names'] = kwargs.get('output_names')
|
|
178
|
+
|
|
179
|
+
warnings.warn(
|
|
180
|
+
"Retrying ONNX export with minimal settings as last resort",
|
|
181
|
+
UserWarning, stacklevel=3)
|
|
182
|
+
|
|
183
|
+
torch.onnx.export(model, args, f, **minimal_kwargs)
|
|
184
|
+
return True
|
|
185
|
+
|
|
186
|
+
except Exception:
|
|
187
|
+
pass
|
|
188
|
+
|
|
189
|
+
# If all strategies fail, handle based on skip_on_error flag
|
|
190
|
+
import os
|
|
191
|
+
pytest_detected = 'PYTEST_CURRENT_TEST' in os.environ or 'pytest' in str(f)
|
|
192
|
+
|
|
193
|
+
if skip_on_error:
|
|
194
|
+
# For CI environments: skip gracefully instead of failing
|
|
195
|
+
warnings.warn(
|
|
196
|
+
"ONNX export skipped due to known upstream issue "
|
|
197
|
+
"(onnx_ir.serde.SerdeError). "
|
|
198
|
+
"This is caused by a bug in the onnx_ir package where boolean "
|
|
199
|
+
"allowzero attributes cannot be serialized. All workarounds "
|
|
200
|
+
"failed. Consider updating packages: pip install --upgrade onnx "
|
|
201
|
+
"onnxscript "
|
|
202
|
+
"onnx_ir", UserWarning, stacklevel=3)
|
|
203
|
+
return False
|
|
204
|
+
|
|
205
|
+
# For regular usage: provide detailed error message
|
|
206
|
+
error_msg = (
|
|
207
|
+
"Failed to export model to ONNX due to known serialization issue. "
|
|
208
|
+
"This is caused by a bug in the onnx_ir package where boolean "
|
|
209
|
+
"allowzero attributes cannot be serialized. "
|
|
210
|
+
"Workarounds attempted: dynamo=False, multiple opset versions, "
|
|
211
|
+
"and legacy export. ")
|
|
212
|
+
|
|
213
|
+
if pytest_detected:
|
|
214
|
+
error_msg += (
|
|
215
|
+
"\n\nThis error commonly occurs in pytest environments. "
|
|
216
|
+
"Try one of these solutions:\n"
|
|
217
|
+
"1. Run the export outside of pytest (in a regular Python "
|
|
218
|
+
"script)\n"
|
|
219
|
+
"2. Update packages: pip install --upgrade onnx onnxscript "
|
|
220
|
+
"onnx_ir\n"
|
|
221
|
+
"3. Use torch.jit.script() instead of ONNX export for testing\n"
|
|
222
|
+
"4. Use safe_onnx_export(..., skip_on_error=True) to skip "
|
|
223
|
+
"gracefully in CI")
|
|
224
|
+
else:
|
|
225
|
+
error_msg += ("\n\nTry updating packages: pip install --upgrade onnx "
|
|
226
|
+
"onnxscript onnx_ir")
|
|
227
|
+
|
|
228
|
+
raise RuntimeError(error_msg)
|
torch_geometric/config_mixin.py
CHANGED
|
@@ -3,6 +3,8 @@ from dataclasses import fields, is_dataclass
|
|
|
3
3
|
from importlib import import_module
|
|
4
4
|
from typing import Any, Dict
|
|
5
5
|
|
|
6
|
+
from torch.nn import ModuleDict, ModuleList
|
|
7
|
+
|
|
6
8
|
from torch_geometric.config_store import (
|
|
7
9
|
class_from_dataclass,
|
|
8
10
|
dataclass_from_class,
|
|
@@ -71,9 +73,9 @@ def _recursive_config(value: Any) -> Any:
|
|
|
71
73
|
return value.config()
|
|
72
74
|
if is_torch_instance(value, ConfigMixin):
|
|
73
75
|
return value.config()
|
|
74
|
-
if isinstance(value, (tuple, list)):
|
|
76
|
+
if isinstance(value, (tuple, list, ModuleList)):
|
|
75
77
|
return [_recursive_config(v) for v in value]
|
|
76
|
-
if isinstance(value, dict):
|
|
78
|
+
if isinstance(value, (dict, ModuleDict)):
|
|
77
79
|
return {k: _recursive_config(v) for k, v in value.items()}
|
|
78
80
|
return value
|
|
79
81
|
|
|
@@ -82,7 +84,10 @@ def _recursive_from_config(value: Any) -> Any:
|
|
|
82
84
|
cls: Any = None
|
|
83
85
|
if is_dataclass(value):
|
|
84
86
|
if getattr(value, '_target_', None):
|
|
85
|
-
|
|
87
|
+
try:
|
|
88
|
+
cls = _locate_cls(value._target_) # type: ignore
|
|
89
|
+
except ImportError:
|
|
90
|
+
pass # Keep the dataclass as it is.
|
|
86
91
|
else:
|
|
87
92
|
cls = class_from_dataclass(value.__class__)
|
|
88
93
|
elif isinstance(value, dict) and '_target_' in value:
|
torch_geometric/config_store.py
CHANGED
|
@@ -168,7 +168,7 @@ def map_annotation(
|
|
|
168
168
|
assert origin is not None
|
|
169
169
|
args = tuple(map_annotation(a, mapping) for a in args)
|
|
170
170
|
if type(annotation).__name__ == 'GenericAlias':
|
|
171
|
-
# If annotated with `list[...]` or `dict[...]
|
|
171
|
+
# If annotated with `list[...]` or `dict[...]`:
|
|
172
172
|
annotation = origin[args]
|
|
173
173
|
else:
|
|
174
174
|
# If annotated with `typing.List[...]` or `typing.Dict[...]`:
|
torch_geometric/contrib/explain/pgm_explainer.py
CHANGED
|
@@ -151,7 +151,7 @@ class PGMExplainer(ExplainerAlgorithm):
|
|
|
151
151
|
|
|
152
152
|
pred_change = torch.max(soft_pred) - soft_pred_perturb[pred_label]
|
|
153
153
|
|
|
154
|
-
sample[num_nodes] = pred_change
|
|
154
|
+
sample[num_nodes] = pred_change.detach()
|
|
155
155
|
samples.append(sample)
|
|
156
156
|
|
|
157
157
|
samples = torch.tensor(np.array(samples))
|
torch_geometric/data/__init__.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
# flake8: noqa
|
|
2
2
|
|
|
3
|
+
import torch
|
|
4
|
+
import torch_geometric.typing
|
|
5
|
+
|
|
3
6
|
from .feature_store import FeatureStore, TensorAttr
|
|
4
|
-
from .graph_store import GraphStore, EdgeAttr
|
|
7
|
+
from .graph_store import GraphStore, EdgeAttr, EdgeLayout
|
|
5
8
|
from .data import Data
|
|
6
9
|
from .hetero_data import HeteroData
|
|
7
10
|
from .batch import Batch
|
|
@@ -68,6 +71,21 @@ from torch_geometric.loader import DataLoader
|
|
|
68
71
|
from torch_geometric.loader import DataListLoader
|
|
69
72
|
from torch_geometric.loader import DenseDataLoader
|
|
70
73
|
|
|
74
|
+
# Serialization ###############################################################
|
|
75
|
+
|
|
76
|
+
if torch_geometric.typing.WITH_PT24:
|
|
77
|
+
torch.serialization.add_safe_globals([
|
|
78
|
+
Data,
|
|
79
|
+
HeteroData,
|
|
80
|
+
TemporalData,
|
|
81
|
+
ClusterData,
|
|
82
|
+
TensorAttr,
|
|
83
|
+
EdgeAttr,
|
|
84
|
+
EdgeLayout,
|
|
85
|
+
])
|
|
86
|
+
|
|
87
|
+
# Deprecations ################################################################
|
|
88
|
+
|
|
71
89
|
NeighborSampler = deprecated( # type: ignore
|
|
72
90
|
details="use 'loader.NeighborSampler' instead",
|
|
73
91
|
func_name='data.NeighborSampler',
|
torch_geometric/data/batch.py
CHANGED
|
@@ -125,8 +125,8 @@ class Batch(metaclass=DynamicInheritance):
|
|
|
125
125
|
cls=self.__class__.__bases__[-1],
|
|
126
126
|
batch=self,
|
|
127
127
|
idx=idx,
|
|
128
|
-
slice_dict=
|
|
129
|
-
inc_dict=
|
|
128
|
+
slice_dict=self._slice_dict,
|
|
129
|
+
inc_dict=self._inc_dict,
|
|
130
130
|
decrement=True,
|
|
131
131
|
)
|
|
132
132
|
|
torch_geometric/data/collate.py
CHANGED
|
@@ -191,10 +191,8 @@ def _collate(
|
|
|
191
191
|
if torch_geometric.typing.WITH_PT20:
|
|
192
192
|
storage = elem.untyped_storage()._new_shared(
|
|
193
193
|
numel * elem.element_size(), device=elem.device)
|
|
194
|
-
elif torch_geometric.typing.WITH_PT112:
|
|
195
|
-
storage = elem.storage()._new_shared(numel, device=elem.device)
|
|
196
194
|
else:
|
|
197
|
-
storage = elem.storage()._new_shared(numel)
|
|
195
|
+
storage = elem.storage()._new_shared(numel, device=elem.device)
|
|
198
196
|
shape = list(elem.size())
|
|
199
197
|
if cat_dim is None or elem.dim() == 0:
|
|
200
198
|
shape = [len(values)] + shape
|
torch_geometric/data/data.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
import warnings
|
|
3
|
+
from collections import defaultdict
|
|
3
4
|
from collections.abc import Mapping, Sequence
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
from itertools import chain
|
|
@@ -354,7 +355,7 @@ class BaseData:
|
|
|
354
355
|
"""
|
|
355
356
|
return self.apply(lambda x: x.contiguous(), *args)
|
|
356
357
|
|
|
357
|
-
def to(self, device: Union[int, str], *args: str,
|
|
358
|
+
def to(self, device: Union[int, str, torch.device], *args: str,
|
|
358
359
|
non_blocking: bool = False):
|
|
359
360
|
r"""Performs tensor device conversion, either for all attributes or
|
|
360
361
|
only the ones given in :obj:`*args`.
|
|
@@ -659,7 +660,13 @@ class Data(BaseData, FeatureStore, GraphStore):
|
|
|
659
660
|
return value.get_dim_size()
|
|
660
661
|
return int(value.max()) + 1
|
|
661
662
|
elif 'index' in key or key == 'face':
|
|
662
|
-
|
|
663
|
+
num_nodes = self.num_nodes
|
|
664
|
+
if num_nodes is None:
|
|
665
|
+
raise RuntimeError(f"Unable to infer 'num_nodes' from the "
|
|
666
|
+
f"attribute '{key}'. Please explicitly set "
|
|
667
|
+
f"'num_nodes' as an attribute of 'data' to "
|
|
668
|
+
f"prevent this error")
|
|
669
|
+
return num_nodes
|
|
663
670
|
else:
|
|
664
671
|
return 0
|
|
665
672
|
|
|
@@ -844,14 +851,14 @@ class Data(BaseData, FeatureStore, GraphStore):
|
|
|
844
851
|
# that maps global node indices to local ones in the final
|
|
845
852
|
# heterogeneous graph:
|
|
846
853
|
node_ids, index_map = {}, torch.empty_like(node_type)
|
|
847
|
-
for i
|
|
854
|
+
for i in range(len(node_type_names)):
|
|
848
855
|
node_ids[i] = (node_type == i).nonzero(as_tuple=False).view(-1)
|
|
849
856
|
index_map[node_ids[i]] = torch.arange(len(node_ids[i]),
|
|
850
857
|
device=index_map.device)
|
|
851
858
|
|
|
852
859
|
# We iterate over edge types to find the local edge indices:
|
|
853
860
|
edge_ids = {}
|
|
854
|
-
for i
|
|
861
|
+
for i in range(len(edge_type_names)):
|
|
855
862
|
edge_ids[i] = (edge_type == i).nonzero(as_tuple=False).view(-1)
|
|
856
863
|
|
|
857
864
|
data = HeteroData()
|
|
@@ -898,6 +905,60 @@ class Data(BaseData, FeatureStore, GraphStore):
|
|
|
898
905
|
|
|
899
906
|
return data
|
|
900
907
|
|
|
908
|
+
def connected_components(self) -> List[Self]:
|
|
909
|
+
r"""Extracts connected components of the graph using a union-find
|
|
910
|
+
algorithm. The components are returned as a list of
|
|
911
|
+
:class:`~torch_geometric.data.Data` objects, where each object
|
|
912
|
+
represents a connected component of the graph.
|
|
913
|
+
|
|
914
|
+
.. code-block::
|
|
915
|
+
|
|
916
|
+
data = Data()
|
|
917
|
+
data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
|
|
918
|
+
data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])
|
|
919
|
+
data.edge_index = torch.tensor(
|
|
920
|
+
[[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
|
|
921
|
+
)
|
|
922
|
+
|
|
923
|
+
components = data.connected_components()
|
|
924
|
+
print(len(components))
|
|
925
|
+
>>> 2
|
|
926
|
+
|
|
927
|
+
print(components[0].x)
|
|
928
|
+
>>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])
|
|
929
|
+
|
|
930
|
+
Returns:
|
|
931
|
+
List[Data]: A list of disconnected components.
|
|
932
|
+
"""
|
|
933
|
+
# Union-Find algorithm to find connected components
|
|
934
|
+
self._parents: Dict[int, int] = {}
|
|
935
|
+
self._ranks: Dict[int, int] = {}
|
|
936
|
+
for edge in self.edge_index.t().tolist():
|
|
937
|
+
self._union(edge[0], edge[1])
|
|
938
|
+
|
|
939
|
+
# Rerun _find_parent to ensure all nodes are covered correctly
|
|
940
|
+
for node in range(self.num_nodes):
|
|
941
|
+
self._find_parent(node)
|
|
942
|
+
|
|
943
|
+
# Group parents
|
|
944
|
+
grouped_parents = defaultdict(list)
|
|
945
|
+
for node, parent in self._parents.items():
|
|
946
|
+
grouped_parents[parent].append(node)
|
|
947
|
+
del self._ranks
|
|
948
|
+
del self._parents
|
|
949
|
+
|
|
950
|
+
# Create components based on the found parents (roots)
|
|
951
|
+
components: List[Self] = []
|
|
952
|
+
for nodes in grouped_parents.values():
|
|
953
|
+
# Convert the list of node IDs to a tensor
|
|
954
|
+
subset = torch.tensor(nodes, dtype=torch.long)
|
|
955
|
+
|
|
956
|
+
# Use the existing subgraph function
|
|
957
|
+
component_data = self.subgraph(subset)
|
|
958
|
+
components.append(component_data)
|
|
959
|
+
|
|
960
|
+
return components
|
|
961
|
+
|
|
901
962
|
###########################################################################
|
|
902
963
|
|
|
903
964
|
@classmethod
|
|
@@ -1144,6 +1205,49 @@ class Data(BaseData, FeatureStore, GraphStore):
|
|
|
1144
1205
|
|
|
1145
1206
|
return list(edge_attrs.values())
|
|
1146
1207
|
|
|
1208
|
+
# Connected Components Helper Functions ###################################
|
|
1209
|
+
|
|
1210
|
+
def _find_parent(self, node: int) -> int:
|
|
1211
|
+
r"""Finds and returns the representative parent of the given node in a
|
|
1212
|
+
disjoint-set (union-find) data structure. Implements path compression
|
|
1213
|
+
to optimize future queries.
|
|
1214
|
+
|
|
1215
|
+
Args:
|
|
1216
|
+
node (int): The node for which to find the representative parent.
|
|
1217
|
+
|
|
1218
|
+
Returns:
|
|
1219
|
+
int: The representative parent of the node.
|
|
1220
|
+
"""
|
|
1221
|
+
if node not in self._parents:
|
|
1222
|
+
self._parents[node] = node
|
|
1223
|
+
self._ranks[node] = 0
|
|
1224
|
+
if self._parents[node] != node:
|
|
1225
|
+
self._parents[node] = self._find_parent(self._parents[node])
|
|
1226
|
+
return self._parents[node]
|
|
1227
|
+
|
|
1228
|
+
def _union(self, node1: int, node2: int):
|
|
1229
|
+
r"""Merges the sets containing node1 and node2 in the disjoint-set
|
|
1230
|
+
data structure.
|
|
1231
|
+
|
|
1232
|
+
Finds the root parents of node1 and node2 using the _find_parent
|
|
1233
|
+
method. If they belong to different sets, updates the parent of
|
|
1234
|
+
root2 to be root1, effectively merging the two sets.
|
|
1235
|
+
|
|
1236
|
+
Args:
|
|
1237
|
+
node1 (int): The index of the first node to union.
|
|
1238
|
+
node2 (int): The index of the second node to union.
|
|
1239
|
+
"""
|
|
1240
|
+
root1 = self._find_parent(node1)
|
|
1241
|
+
root2 = self._find_parent(node2)
|
|
1242
|
+
if root1 != root2:
|
|
1243
|
+
if self._ranks[root1] < self._ranks[root2]:
|
|
1244
|
+
self._parents[root1] = root2
|
|
1245
|
+
elif self._ranks[root1] > self._ranks[root2]:
|
|
1246
|
+
self._parents[root2] = root1
|
|
1247
|
+
else:
|
|
1248
|
+
self._parents[root2] = root1
|
|
1249
|
+
self._ranks[root1] += 1
|
|
1250
|
+
|
|
1147
1251
|
|
|
1148
1252
|
###############################################################################
|
|
1149
1253
|
|
|
@@ -1165,7 +1269,7 @@ def size_repr(key: Any, value: Any, indent: int = 0) -> str:
|
|
|
1165
1269
|
f'[{value.num_rows}, {value.num_cols}])')
|
|
1166
1270
|
elif isinstance(value, str):
|
|
1167
1271
|
out = f"'{value}'"
|
|
1168
|
-
elif isinstance(value, Sequence):
|
|
1272
|
+
elif isinstance(value, (Sequence, set)):
|
|
1169
1273
|
out = str([len(value)])
|
|
1170
1274
|
elif isinstance(value, Mapping) and len(value) == 0:
|
|
1171
1275
|
out = '{}'
|
|
@@ -1187,4 +1291,4 @@ def warn_or_raise(msg: str, raise_on_error: bool = True):
|
|
|
1187
1291
|
if raise_on_error:
|
|
1188
1292
|
raise ValueError(msg)
|
|
1189
1293
|
else:
|
|
1190
|
-
warnings.warn(msg)
|
|
1294
|
+
warnings.warn(msg, stacklevel=2)
|