pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_to_dense_batch.py +2 -2
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
|
@@ -10,10 +10,16 @@ from packaging.requirements import Requirement
|
|
|
10
10
|
from packaging.version import Version
|
|
11
11
|
|
|
12
12
|
import torch_geometric
|
|
13
|
+
import torch_geometric.typing
|
|
13
14
|
from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
|
|
14
15
|
from torch_geometric.visualization.graph import has_graphviz
|
|
15
16
|
|
|
16
17
|
|
|
18
|
+
def is_rag_test() -> bool:
    r"""Returns :obj:`True` if the RAG test suite is enabled, *i.e.* if the
    environment variable :obj:`RAG_TEST` is set to exactly :obj:`"1"`."""
    flag = os.environ.get('RAG_TEST', '0')
    return flag == '1'
|
|
21
|
+
|
|
22
|
+
|
|
17
23
|
def is_full_test() -> bool:
    r"""Returns :obj:`True` if the full (time-consuming) test suite is
    enabled, *i.e.* if :obj:`FULL_TEST` is set to exactly :obj:`"1"`."""
    flag = os.environ.get('FULL_TEST', '0')
    return flag == '1'
|
|
@@ -32,8 +38,8 @@ def onlyFullTest(func: Callable) -> Callable:
|
|
|
32
38
|
|
|
33
39
|
def is_distributed_test() -> bool:
    r"""Returns whether the distributed test suite should run.

    Requires :obj:`DIST_TEST` to be set to :obj:`"1"`, a Linux platform, and
    :obj:`has_package('pyg_lib')` to succeed. The checks are evaluated in
    that order and stop at the first failing one.
    """
    if os.getenv('DIST_TEST', '0') != '1':
        return False
    if sys.platform != 'linux':
        return False
    return has_package('pyg_lib')
|
|
37
43
|
|
|
38
44
|
|
|
39
45
|
def onlyDistributedTest(func: Callable) -> Callable:
|
|
@@ -203,6 +209,18 @@ def withPackage(*args: str) -> Callable:
|
|
|
203
209
|
return decorator
|
|
204
210
|
|
|
205
211
|
|
|
212
|
+
def onlyRAG(func: Callable) -> Callable:
    r"""A decorator that tags :obj:`func` with the :obj:`rag` pytest marker
    and skips it unless :func:`is_rag_test` returns :obj:`True`.

    Note that :func:`is_rag_test` is evaluated once, at decoration time.
    """
    import pytest

    marked = pytest.mark.rag(func)
    skip_if_disabled = pytest.mark.skipif(
        not is_rag_test(),
        reason="RAG tests are disabled",
    )
    return skip_if_disabled(marked)
|
|
222
|
+
|
|
223
|
+
|
|
206
224
|
def withCUDA(func: Callable) -> Callable:
|
|
207
225
|
r"""A decorator to test both on CPU and CUDA (if available)."""
|
|
208
226
|
import pytest
|
|
@@ -234,8 +252,9 @@ def withDevice(func: Callable) -> Callable:
|
|
|
234
252
|
if device:
|
|
235
253
|
backend = os.getenv('TORCH_BACKEND')
|
|
236
254
|
if backend is None:
|
|
237
|
-
warnings.warn(
|
|
238
|
-
|
|
255
|
+
warnings.warn(
|
|
256
|
+
f"Please specify the backend via 'TORCH_BACKEND' in"
|
|
257
|
+
f"order to test against '{device}'", stacklevel=2)
|
|
239
258
|
else:
|
|
240
259
|
import_module(backend)
|
|
241
260
|
devices.append(pytest.param(torch.device(device), id=device))
|
|
@@ -250,7 +269,7 @@ def withMETIS(func: Callable) -> Callable:
|
|
|
250
269
|
with_metis = WITH_METIS
|
|
251
270
|
|
|
252
271
|
if with_metis:
|
|
253
|
-
try: # Test that METIS can
|
|
272
|
+
try: # Test that METIS can successfully execute:
|
|
254
273
|
# TODO Using `pyg-lib` metis partitioning leads to some weird bugs
|
|
255
274
|
# in the # CI. As such, we require `torch-sparse` for now.
|
|
256
275
|
rowptr = torch.tensor([0, 2, 4, 6])
|
|
@@ -265,6 +284,17 @@ def withMETIS(func: Callable) -> Callable:
|
|
|
265
284
|
)(func)
|
|
266
285
|
|
|
267
286
|
|
|
287
|
+
def withHashTensor(func: Callable) -> Callable:
    r"""A decorator to only test in case :class:`HashTensor` is available.

    The test is skipped only if neither the native CPU hash map
    (:obj:`torch_geometric.typing.WITH_CPU_HASH_MAP`) nor :obj:`pandas` is
    available.
    """
    import pytest

    has_backend = (torch_geometric.typing.WITH_CPU_HASH_MAP
                   or has_package('pandas'))
    skip_marker = pytest.mark.skipif(
        not has_backend,
        reason="HashTensor dependencies not available",
    )
    return skip_marker(func)
|
|
296
|
+
|
|
297
|
+
|
|
268
298
|
def disableExtensions(func: Callable) -> Callable:
|
|
269
299
|
r"""A decorator to temporarily disable the usage of the
|
|
270
300
|
:obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension
|
|
@@ -37,6 +37,7 @@ from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph
|
|
|
37
37
|
from .largest_connected_components import LargestConnectedComponents
|
|
38
38
|
from .virtual_node import VirtualNode
|
|
39
39
|
from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
|
|
40
|
+
from .add_gpse import AddGPSE
|
|
40
41
|
from .feature_propagation import FeaturePropagation
|
|
41
42
|
from .half_hop import HalfHop
|
|
42
43
|
|
|
@@ -108,6 +109,7 @@ graph_transforms = [
|
|
|
108
109
|
'VirtualNode',
|
|
109
110
|
'AddLaplacianEigenvectorPE',
|
|
110
111
|
'AddRandomWalkPE',
|
|
112
|
+
'AddGPSE',
|
|
111
113
|
'FeaturePropagation',
|
|
112
114
|
'HalfHop',
|
|
113
115
|
]
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from torch.nn import Module
|
|
4
|
+
|
|
5
|
+
from torch_geometric.data import Data
|
|
6
|
+
from torch_geometric.data.datapipes import functional_transform
|
|
7
|
+
from torch_geometric.transforms import BaseTransform, VirtualNode
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@functional_transform('add_gpse')
class AddGPSE(BaseTransform):
    r"""Adds the GPSE encoding from the `"Graph Positional and Structural
    Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
    (functional name: :obj:`add_gpse`).
    To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates
    the actual encodings.

    The resulting encodings are moved to the CPU and stored in
    :obj:`data.pestat_GPSE`. The input :obj:`data` object is updated in place
    and returned.

    Args:
        model (Module): The pre-trained GPSE model.
        use_vn (bool, optional): Whether to use virtual nodes.
            (default: :obj:`True`)
        rand_type (str, optional): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)
    """
    def __init__(
        self,
        model: Module,
        use_vn: bool = True,
        rand_type: str = 'NormalSE',
    ) -> None:
        self.model = model
        self.use_vn = use_vn
        self.vn = VirtualNode()
        self.rand_type = rand_type

    def forward(self, data: Data) -> Data:
        r"""Computes the GPSE encodings for :obj:`data` and attaches them as
        :obj:`data.pestat_GPSE`.

        Args:
            data (Data): The input graph.

        Returns:
            Data: The same :obj:`data` object, with :obj:`pestat_GPSE` set.
        """
        from torch_geometric.nn.models.gpse import gpse_process

        # Operate on a clone so that the virtual node (if any) is never added
        # to the caller's graph.
        data_vn = self.vn(data.clone()) if self.use_vn else data.clone()
        # Forward the configured random-feature type instead of always using
        # 'NormalSE', so that `rand_type` actually takes effect.
        batch_out = gpse_process(self.model, data_vn, self.rand_type,
                                 self.use_vn)
        batch_out = batch_out.to('cpu', non_blocking=True)
        # With a virtual node, the last row is dropped. Presumably this is the
        # virtual node's encoding, since VirtualNode appends it last.
        data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out

        return data

    def __call__(self, data: Data) -> Data:
        # Unlike `BaseTransform.__call__`, no shallow copy is made: `data` is
        # updated in place, as before.
        return self.forward(data)
|
|
@@ -108,13 +108,15 @@ class AddMetaPaths(BaseTransform):
|
|
|
108
108
|
**kwargs: bool,
|
|
109
109
|
) -> None:
|
|
110
110
|
if 'drop_orig_edges' in kwargs:
|
|
111
|
-
warnings.warn(
|
|
112
|
-
|
|
111
|
+
warnings.warn(
|
|
112
|
+
"'drop_orig_edges' is deprecated. Use "
|
|
113
|
+
"'drop_orig_edge_types' instead", stacklevel=2)
|
|
113
114
|
drop_orig_edge_types = kwargs['drop_orig_edges']
|
|
114
115
|
|
|
115
116
|
if 'drop_unconnected_nodes' in kwargs:
|
|
116
|
-
warnings.warn(
|
|
117
|
-
|
|
117
|
+
warnings.warn(
|
|
118
|
+
"'drop_unconnected_nodes' is deprecated. Use "
|
|
119
|
+
"'drop_unconnected_node_types' instead", stacklevel=2)
|
|
118
120
|
drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
|
|
119
121
|
|
|
120
122
|
for path in metapaths:
|
|
@@ -144,7 +146,7 @@ class AddMetaPaths(BaseTransform):
|
|
|
144
146
|
if self.max_sample is not None:
|
|
145
147
|
edge_index, edge_weight = self._sample(edge_index, edge_weight)
|
|
146
148
|
|
|
147
|
-
for
|
|
149
|
+
for edge_type in metapath[1:]:
|
|
148
150
|
edge_index2, edge_weight2 = self._edge_index(data, edge_type)
|
|
149
151
|
|
|
150
152
|
edge_index, edge_weight = edge_index.matmul(
|
|
@@ -276,7 +278,7 @@ class AddRandomMetaPaths(BaseTransform):
|
|
|
276
278
|
row = start = torch.randperm(num_nodes)[:num_starts].repeat(
|
|
277
279
|
self.walks_per_node[j])
|
|
278
280
|
|
|
279
|
-
for
|
|
281
|
+
for edge_type in metapath:
|
|
280
282
|
edge_index = EdgeIndex(
|
|
281
283
|
data[edge_type].edge_index,
|
|
282
284
|
sparse_size=data[edge_type].size(),
|
|
@@ -92,12 +92,12 @@ class AddLaplacianEigenvectorPE(BaseTransform):
|
|
|
92
92
|
from numpy.linalg import eig, eigh
|
|
93
93
|
eig_fn = eig if not self.is_undirected else eigh
|
|
94
94
|
|
|
95
|
-
eig_vals, eig_vecs = eig_fn(L.todense())
|
|
95
|
+
eig_vals, eig_vecs = eig_fn(L.todense())
|
|
96
96
|
else:
|
|
97
97
|
from scipy.sparse.linalg import eigs, eigsh
|
|
98
98
|
eig_fn = eigs if not self.is_undirected else eigsh
|
|
99
99
|
|
|
100
|
-
eig_vals, eig_vecs = eig_fn(
|
|
100
|
+
eig_vals, eig_vecs = eig_fn(
|
|
101
101
|
L,
|
|
102
102
|
k=self.k + 1,
|
|
103
103
|
which='SR' if not self.is_undirected else 'SA',
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import copy
|
|
2
|
-
from abc import ABC
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
5
|
|
|
@@ -31,6 +31,7 @@ class BaseTransform(ABC):
|
|
|
31
31
|
# Shallow-copy the data so that we prevent in-place data modification.
|
|
32
32
|
return self.forward(copy.copy(data))
|
|
33
33
|
|
|
34
|
+
@abstractmethod
def forward(self, data: Any) -> Any:
    r"""Transforms :obj:`data` and returns the result.

    Every subclass must implement this method. :meth:`__call__` invokes it
    on a shallow copy of the input.
    """
|
|
36
37
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
|
|
3
5
|
from torch_geometric.data import Data
|
|
@@ -5,30 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
|
|
|
5
7
|
from torch_geometric.transforms import BaseTransform
|
|
6
8
|
|
|
7
9
|
|
|
10
|
+
class _QhullTransform(BaseTransform):
    r"""Delaunay triangulation backed by :class:`scipy.spatial.Delaunay`
    (Q-hull). Stores the simplices in :obj:`data.face` as a :obj:`torch.long`
    tensor on the device of :obj:`data.pos`."""
    def forward(self, data: Data) -> Data:
        assert data.pos is not None
        import scipy.spatial

        points = data.pos.cpu().numpy()
        # 'QJ' makes Q-hull joggle the input, which avoids failures on
        # degenerate (e.g. co-planar) point sets.
        simplices = scipy.spatial.Delaunay(points, qhull_options='QJ').simplices
        face_t = torch.from_numpy(simplices).t().contiguous()
        data.face = face_t.to(data.pos.device, torch.long)
        return data
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _ShullTransform(BaseTransform):
    r"""Sweep-hull implementation of delaunay triangulation, backed by the
    optional :obj:`torch_delaunay` package.

    Raises :class:`ImportError` if :obj:`torch_delaunay` is not installed.
    Stores the triangles in :obj:`data.face` as a :obj:`torch.long` tensor on
    the device of :obj:`data.pos`.
    """
    def forward(self, data: Data) -> Data:
        assert data.pos is not None
        from torch_delaunay.functional import shull2d

        face = shull2d(data.pos.cpu())
        # Cast to long, as the Q-hull fallback does, so that `data.face` has
        # the same dtype regardless of which backend produced it.
        data.face = face.t().contiguous().to(data.pos.device, torch.long)
        return data
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class _SequentialTransform(BaseTransform):
    r"""Runs the first transformation whose dependencies are available.

    Transforms are tried in order. An :class:`ImportError` raised by any
    transform except the last causes a fallback to the next one. All other
    exceptions propagate immediately. An :class:`ImportError` raised by the
    last transform is propagated to the caller.

    Args:
        transforms (List[BaseTransform]): Non-empty list of candidate
            transforms, in order of preference.
    """
    def __init__(self, transforms: List[BaseTransform]) -> None:
        assert len(transforms) > 0
        self.transforms = transforms

    def forward(self, data: Data) -> Data:
        *fallbacks, last = self.transforms
        for transform in fallbacks:
            try:
                return transform.forward(data)
            except ImportError:
                # Optional dependency missing: try the next transform.
                continue
        # The final transform runs unguarded so that its ImportError (if any)
        # reaches the caller with its original traceback.
        return last.forward(data)
|
|
52
|
+
|
|
53
|
+
|
|
8
54
|
@functional_transform('delaunay')
class Delaunay(BaseTransform):
    r"""Computes the delaunay triangulation of a set of points
    (functional name: :obj:`delaunay`).

    Small point sets are handled directly:
    - fewer than two points yield an empty :obj:`edge_index`;
    - two points yield a single undirected edge;
    - three points yield a single face.

    Larger point sets are triangulated into :obj:`data.face`.

    .. hint::
        Consider installing the
        `torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
        to speed up computation.
    """
    def __init__(self) -> None:
        # Prefer the sweep-hull backend. Fall back to SciPy's Q-hull when
        # `torch_delaunay` cannot be imported.
        backends = [_ShullTransform(), _QhullTransform()]
        self._transform = _SequentialTransform(backends)

    def forward(self, data: Data) -> Data:
        assert data.pos is not None
        device = data.pos.device
        num_points = data.pos.size(0)

        if num_points > 3:
            return self._transform.forward(data)

        if num_points == 3:
            data.face = torch.tensor([[0], [1], [2]], device=device)
        elif num_points == 2:
            data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
        else:
            data.edge_index = torch.empty(2, 0, dtype=torch.long,
                                          device=device)
        return data
|
|
@@ -8,8 +8,15 @@ from torch_geometric.utils import to_undirected
|
|
|
8
8
|
|
|
9
9
|
@functional_transform('face_to_edge')
|
|
10
10
|
class FaceToEdge(BaseTransform):
|
|
11
|
-
r"""Converts mesh faces :obj:`[3, num_faces]`
|
|
12
|
-
:obj:`[
|
|
11
|
+
r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
|
|
12
|
+
:obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
|
|
13
|
+
(functional name: :obj:`face_to_edge`).
|
|
14
|
+
|
|
15
|
+
This transform supports both 2D triangular faces, represented by a
|
|
16
|
+
tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
|
|
17
|
+
represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
|
|
18
|
+
these faces into edge indices, where each edge is defined by the indices
|
|
19
|
+
of its two endpoints.
|
|
13
20
|
|
|
14
21
|
Args:
|
|
15
22
|
remove_faces (bool, optional): If set to :obj:`False`, the face tensor
|
|
@@ -22,7 +29,29 @@ class FaceToEdge(BaseTransform):
|
|
|
22
29
|
if hasattr(data, 'face'):
|
|
23
30
|
assert data.face is not None
|
|
24
31
|
face = data.face
|
|
25
|
-
|
|
32
|
+
|
|
33
|
+
if face.size(0) not in [3, 4]:
|
|
34
|
+
raise RuntimeError(f"Expected 'face' tensor with shape "
|
|
35
|
+
f"[3, num_faces] or [4, num_faces] "
|
|
36
|
+
f"(got {list(face.size())})")
|
|
37
|
+
|
|
38
|
+
if face.size()[0] == 3:
|
|
39
|
+
edge_index = torch.cat([
|
|
40
|
+
face[:2],
|
|
41
|
+
face[1:],
|
|
42
|
+
face[::2],
|
|
43
|
+
], dim=1)
|
|
44
|
+
else:
|
|
45
|
+
assert face.size()[0] == 4
|
|
46
|
+
edge_index = torch.cat([
|
|
47
|
+
face[:2],
|
|
48
|
+
face[1:3],
|
|
49
|
+
face[2:4],
|
|
50
|
+
face[::2],
|
|
51
|
+
face[1::2],
|
|
52
|
+
face[::3],
|
|
53
|
+
], dim=1)
|
|
54
|
+
|
|
26
55
|
edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
|
|
27
56
|
|
|
28
57
|
data.edge_index = edge_index
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Dict, Tuple
|
|
1
|
+
from typing import Any, Dict, Optional, Tuple
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
@@ -78,18 +78,17 @@ class GDC(BaseTransform):
|
|
|
78
78
|
self_loop_weight: float = 1.,
|
|
79
79
|
normalization_in: str = 'sym',
|
|
80
80
|
normalization_out: str = 'col',
|
|
81
|
-
diffusion_kwargs: Dict[str, Any] =
|
|
82
|
-
sparsification_kwargs: Dict[str, Any] =
|
|
83
|
-
method='threshold',
|
|
84
|
-
avg_degree=64,
|
|
85
|
-
),
|
|
81
|
+
diffusion_kwargs: Optional[Dict[str, Any]] = None,
|
|
82
|
+
sparsification_kwargs: Optional[Dict[str, Any]] = None,
|
|
86
83
|
exact: bool = True,
|
|
87
84
|
) -> None:
|
|
88
85
|
self.self_loop_weight = self_loop_weight
|
|
89
86
|
self.normalization_in = normalization_in
|
|
90
87
|
self.normalization_out = normalization_out
|
|
91
|
-
self.diffusion_kwargs = diffusion_kwargs
|
|
92
|
-
|
|
88
|
+
self.diffusion_kwargs = diffusion_kwargs or dict(
|
|
89
|
+
method='ppr', alpha=0.15)
|
|
90
|
+
self.sparsification_kwargs = sparsification_kwargs or dict(
|
|
91
|
+
method='threshold', avg_degree=64)
|
|
93
92
|
self.exact = exact
|
|
94
93
|
|
|
95
94
|
if self_loop_weight:
|
|
@@ -47,7 +47,7 @@ class LargestConnectedComponents(BaseTransform):
|
|
|
47
47
|
return data
|
|
48
48
|
|
|
49
49
|
_, count = np.unique(component, return_counts=True)
|
|
50
|
-
subset_np = np.
|
|
50
|
+
subset_np = np.isin(component, count.argsort()[-self.num_components:])
|
|
51
51
|
subset = torch.from_numpy(subset_np)
|
|
52
52
|
subset = subset.to(data.edge_index.device, torch.bool)
|
|
53
53
|
|
|
@@ -19,7 +19,11 @@ def get_attrs_with_suffix(
|
|
|
19
19
|
return [key for key in store.keys() if key.endswith(suffix)]
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
def get_mask_size(
    attr: str,
    store: BaseStorage,
    size: Optional[int],
) -> Optional[int]:
    r"""Returns :obj:`size` if given. Otherwise returns
    :obj:`store.num_edges` when :obj:`attr` is an edge-level attribute of
    :obj:`store`, and :obj:`store.num_nodes` when it is not."""
    if size is not None:
        return size
    if store.is_edge_attr(attr):
        return store.num_edges
    return store.num_nodes
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import List, Union
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
2
|
|
|
3
3
|
from torch_geometric.data import Data, HeteroData
|
|
4
4
|
from torch_geometric.data.datapipes import functional_transform
|
|
@@ -14,8 +14,8 @@ class NormalizeFeatures(BaseTransform):
|
|
|
14
14
|
attrs (List[str]): The names of attributes to normalize.
|
|
15
15
|
(default: :obj:`["x"]`)
|
|
16
16
|
"""
|
|
17
|
-
def __init__(self, attrs: Optional[List[str]] = None) -> None:
    # Any falsy value (None or an empty list) falls back to the default
    # attribute list ["x"].
    if not attrs:
        attrs = ["x"]
    self.attrs = attrs
|
|
19
19
|
|
|
20
20
|
def forward(
|
|
21
21
|
self,
|
|
@@ -245,7 +245,7 @@ class RandomLinkSplit(BaseTransform):
|
|
|
245
245
|
warnings.warn(
|
|
246
246
|
f"There are not enough negative edges to satisfy "
|
|
247
247
|
"the provided sampling ratio. The ratio will be "
|
|
248
|
-
f"adjusted to {ratio:.2f}.")
|
|
248
|
+
f"adjusted to {ratio:.2f}.", stacklevel=2)
|
|
249
249
|
num_neg_train = int((num_neg_train / num_neg) * num_neg_found)
|
|
250
250
|
num_neg_val = int((num_neg_val / num_neg) * num_neg_found)
|
|
251
251
|
num_neg_test = num_neg_found - num_neg_train - num_neg_val
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import List, Union
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
2
|
|
|
3
3
|
from torch_geometric.data import Data, HeteroData
|
|
4
4
|
from torch_geometric.data.datapipes import functional_transform
|
|
@@ -22,9 +22,11 @@ class RemoveDuplicatedEdges(BaseTransform):
|
|
|
22
22
|
"""
|
|
23
23
|
def __init__(
|
|
24
24
|
self,
|
|
25
|
-
key: Union[str, List[str]] =
|
|
25
|
+
key: Optional[Union[str, List[str]]] = None,
|
|
26
26
|
reduce: str = "add",
|
|
27
27
|
) -> None:
|
|
28
|
+
key = key or ['edge_attr', 'edge_weight']
|
|
29
|
+
|
|
28
30
|
if isinstance(key, str):
|
|
29
31
|
key = [key]
|
|
30
32
|
|
|
@@ -94,7 +94,7 @@ class RootedSubgraph(BaseTransform, ABC):
|
|
|
94
94
|
arange = torch.arange(n_id.size(0), device=data.edge_index.device)
|
|
95
95
|
node_map = data.edge_index.new_ones(num_nodes, num_nodes)
|
|
96
96
|
node_map[n_sub_batch, n_id] = arange
|
|
97
|
-
sub_edge_index += (arange *
|
|
97
|
+
sub_edge_index += (arange * num_nodes)[e_sub_batch]
|
|
98
98
|
sub_edge_index = node_map.view(-1)[sub_edge_index]
|
|
99
99
|
|
|
100
100
|
return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch
|
torch_geometric/typing.py
CHANGED
|
@@ -3,7 +3,7 @@ import os
|
|
|
3
3
|
import sys
|
|
4
4
|
import typing
|
|
5
5
|
import warnings
|
|
6
|
-
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
6
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -15,8 +15,9 @@ WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
|
|
|
15
15
|
WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
|
|
16
16
|
WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
|
|
17
17
|
WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
|
|
18
|
-
|
|
19
|
-
|
|
18
|
+
WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
|
|
19
|
+
WITH_PT27 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 7
|
|
20
|
+
WITH_PT28 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 8
|
|
20
21
|
WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
|
|
21
22
|
|
|
22
23
|
WITH_WINDOWS = os.name == 'nt'
|
|
@@ -63,10 +64,21 @@ try:
|
|
|
63
64
|
pyg_lib.sampler.neighbor_sample).parameters)
|
|
64
65
|
WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
|
|
65
66
|
pyg_lib.sampler.neighbor_sample).parameters)
|
|
67
|
+
try:
|
|
68
|
+
torch.classes.pyg.CPUHashMap # noqa: B018
|
|
69
|
+
WITH_CPU_HASH_MAP = True
|
|
70
|
+
except Exception:
|
|
71
|
+
WITH_CPU_HASH_MAP = False
|
|
72
|
+
try:
|
|
73
|
+
torch.classes.pyg.CUDAHashMap # noqa: B018
|
|
74
|
+
WITH_CUDA_HASH_MAP = True
|
|
75
|
+
except Exception:
|
|
76
|
+
WITH_CUDA_HASH_MAP = False
|
|
66
77
|
except Exception as e:
|
|
67
78
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
68
|
-
warnings.warn(
|
|
69
|
-
|
|
79
|
+
warnings.warn(
|
|
80
|
+
f"An issue occurred while importing 'pyg-lib'. "
|
|
81
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
70
82
|
pyg_lib = object
|
|
71
83
|
WITH_PYG_LIB = False
|
|
72
84
|
WITH_GMM = False
|
|
@@ -77,14 +89,41 @@ except Exception as e:
|
|
|
77
89
|
WITH_METIS = False
|
|
78
90
|
WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
|
|
79
91
|
WITH_WEIGHTED_NEIGHBOR_SAMPLE = False
|
|
92
|
+
WITH_CPU_HASH_MAP = False
|
|
93
|
+
WITH_CUDA_HASH_MAP = False
|
|
94
|
+
|
|
95
|
+
# Hash-map classes provided by `pyg-lib`. When the `WITH_*_HASH_MAP` probes
# found them under `torch.classes.pyg`, the names below are aliases of those
# registered classes; otherwise they are stand-ins that raise `ImportError`
# as soon as they are constructed.
if WITH_CPU_HASH_MAP:
    CPUHashMap: TypeAlias = torch.classes.pyg.CPUHashMap  # type: ignore[name-defined] # noqa: E501
else:

    class CPUHashMap:  # type: ignore
        r"""Placeholder used when :obj:`torch.classes.pyg.CPUHashMap` is not
        available. Every method raises :class:`ImportError`.
        """
        def __init__(self, key: Tensor) -> None:
            raise ImportError("'CPUHashMap' requires 'pyg-lib'")

        def get(self, query: Tensor) -> Tensor:
            # Unreachable in practice since `__init__` always raises; kept so
            # the stub exposes the same `get` method name as the real class.
            raise ImportError("'CPUHashMap' requires 'pyg-lib'")


if WITH_CUDA_HASH_MAP:
    CUDAHashMap: TypeAlias = torch.classes.pyg.CUDAHashMap  # type: ignore[name-defined] # noqa: E501
else:

    class CUDAHashMap:  # type: ignore
        r"""Placeholder used when :obj:`torch.classes.pyg.CUDAHashMap` is not
        available. Every method raises :class:`ImportError`.
        """
        def __init__(self, key: Tensor) -> None:
            raise ImportError("'CUDAHashMap' requires 'pyg-lib'")

        def get(self, query: Tensor) -> Tensor:
            # Unreachable in practice since `__init__` always raises; kept so
            # the stub exposes the same `get` method name as the real class.
            raise ImportError("'CUDAHashMap' requires 'pyg-lib'")
|
|
117
|
+
|
|
80
118
|
|
|
81
119
|
# Optional `torch-scatter` backend: record whether it can be imported, and
# rebind the module name to a placeholder when it cannot.
try:
    import torch_scatter  # noqa
except Exception as e:
    # A plain `ImportError` is handled silently; any other failure is
    # reported as a warning before the backend is disabled.
    if not isinstance(e, ImportError):  # pragma: no cover
        warnings.warn(
            f"An issue occurred while importing 'torch-scatter'. "
            f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
    torch_scatter = object
    WITH_TORCH_SCATTER = False
else:
    WITH_TORCH_SCATTER = True
|
|
90
129
|
|
|
@@ -94,8 +133,9 @@ try:
|
|
|
94
133
|
WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__
|
|
95
134
|
except Exception as e:
|
|
96
135
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
97
|
-
warnings.warn(
|
|
98
|
-
|
|
136
|
+
warnings.warn(
|
|
137
|
+
f"An issue occurred while importing 'torch-cluster'. "
|
|
138
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
99
139
|
WITH_TORCH_CLUSTER = False
|
|
100
140
|
WITH_TORCH_CLUSTER_BATCH_SIZE = False
|
|
101
141
|
|
|
@@ -112,7 +152,7 @@ except Exception as e:
|
|
|
112
152
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
113
153
|
warnings.warn(
|
|
114
154
|
f"An issue occurred while importing 'torch-spline-conv'. "
|
|
115
|
-
f"Disabling its usage. Stacktrace: {e}")
|
|
155
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
116
156
|
WITH_TORCH_SPLINE_CONV = False
|
|
117
157
|
|
|
118
158
|
try:
|
|
@@ -121,8 +161,9 @@ try:
|
|
|
121
161
|
WITH_TORCH_SPARSE = True
|
|
122
162
|
except Exception as e:
|
|
123
163
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
124
|
-
warnings.warn(
|
|
125
|
-
|
|
164
|
+
warnings.warn(
|
|
165
|
+
f"An issue occurred while importing 'torch-sparse'. "
|
|
166
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
126
167
|
WITH_TORCH_SPARSE = False
|
|
127
168
|
|
|
128
169
|
class SparseStorage: # type: ignore
|
|
@@ -306,6 +347,8 @@ class EdgeTypeStr(str):
|
|
|
306
347
|
r"""A helper class to construct serializable edge types by merging an edge
|
|
307
348
|
type tuple into a single string.
|
|
308
349
|
"""
|
|
350
|
+
# The `(src, rel, dst)` triple this string was built from. It is assigned
# in `__new__` and returned by `to_tuple()`.
edge_type: tuple[str, str, str]
|
|
351
|
+
|
|
309
352
|
def __new__(cls, *args: Any) -> 'EdgeTypeStr':
|
|
310
353
|
if isinstance(args[0], (list, tuple)):
|
|
311
354
|
# Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
|
|
@@ -313,27 +356,37 @@ class EdgeTypeStr(str):
|
|
|
313
356
|
|
|
314
357
|
if len(args) == 1 and isinstance(args[0], str):
|
|
315
358
|
arg = args[0] # An edge type string was passed.
|
|
359
|
+
edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
|
|
360
|
+
if len(edge_type) != 3:
|
|
361
|
+
raise ValueError(f"Cannot convert the edge type '{arg}' to a "
|
|
362
|
+
f"tuple since it holds invalid characters")
|
|
316
363
|
|
|
317
364
|
elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
|
|
318
365
|
# A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
|
|
319
|
-
|
|
366
|
+
edge_type = (args[0], DEFAULT_REL, args[1])
|
|
367
|
+
arg = EDGE_TYPE_STR_SPLIT.join(edge_type)
|
|
320
368
|
|
|
321
369
|
elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
|
|
322
370
|
# A `(src, rel, dst)` edge type was passed:
|
|
371
|
+
edge_type = tuple(args)
|
|
323
372
|
arg = EDGE_TYPE_STR_SPLIT.join(args)
|
|
324
373
|
|
|
325
374
|
else:
|
|
326
375
|
raise ValueError(f"Encountered invalid edge type '{args}'")
|
|
327
376
|
|
|
328
|
-
|
|
377
|
+
out = str.__new__(cls, arg)
|
|
378
|
+
out.edge_type = edge_type # type: ignore
|
|
379
|
+
return out
|
|
329
380
|
|
|
330
381
|
def to_tuple(self) -> EdgeType:
    r"""Returns the original edge type.

    The ``(src, rel, dst)`` triple is the one stored on the instance when it
    was constructed, so the string itself is not re-split here.

    Raises:
        ValueError: If the stored triple does not have exactly three parts.
    """
    edge_type = self.edge_type
    if len(edge_type) == 3:
        return edge_type
    raise ValueError(f"Cannot convert the edge type '{self}' to a "
                     f"tuple since it holds invalid characters")
|
|
387
|
+
|
|
388
|
+
def __reduce__(self) -> tuple[Any, Any]:
    r"""Pickles the string via its ``(src, rel, dst)`` triple.

    On unpickling, the class is called with the stored tuple as its single
    argument, which ``__new__`` unwraps, so ``edge_type`` is rebuilt by the
    constructor.
    """
    ctor_args = (self.edge_type, )
    return type(self), ctor_args
|
|
337
390
|
|
|
338
391
|
|
|
339
392
|
# There exist some short-cuts to query edge-types (given that the full triplet
|