pyg-nightly 2.6.0.dev20240511__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +30 -31
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +205 -181
- {pyg_nightly-2.6.0.dev20240511.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +26 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +16 -14
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/data.py +13 -8
- torch_geometric/data/database.py +15 -7
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +13 -22
- torch_geometric/data/graph_store.py +0 -4
- torch_geometric/data/hetero_data.py +4 -4
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/storage.py +15 -5
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +11 -1
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +6 -5
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +7 -1
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +4 -3
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +2 -2
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +17 -8
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +20 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +2 -3
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +9 -3
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/base.py +0 -1
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +100 -82
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +5 -4
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +3 -4
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +1 -2
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +7 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +203 -77
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +24 -15
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/testing/decorators.py +17 -22
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +4 -4
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +2 -2
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +31 -5
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +37 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +5 -5
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +1 -1
- torch_geometric/utils/smiles.py +66 -28
- torch_geometric/utils/sparse.py +25 -10
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/__init__.py
CHANGED
@@ -1,8 +1,15 @@
+from collections import defaultdict
+
+import torch
+import torch_geometric.typing
+
 from ._compile import compile, is_compiling
+from ._onnx import is_in_onnx_export
 from .index import Index
 from .edge_index import EdgeIndex
 from .seed import seed_everything
 from .home import get_home_dir, set_home_dir
+from .device import is_mps_available, is_xpu_available, device
 from .isinstance import is_torch_instance
 from .debug import is_debug_enabled, debug, set_debug
 
@@ -23,7 +30,7 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
 
-__version__ = '2.6.0.dev20240511'
+__version__ = '2.7.0.dev20250114'
 
 __all__ = [
     'Index',
@@ -33,6 +40,10 @@ __all__ = [
     'set_home_dir',
     'compile',
     'is_compiling',
+    'is_in_onnx_export',
+    'is_mps_available',
+    'is_xpu_available',
+    'device',
     'is_torch_instance',
     'is_debug_enabled',
     'debug',
@@ -43,3 +54,17 @@ __all__ = [
     'torch_geometric',
     '__version__',
 ]
+
+# Serialization ###############################################################
+
+if torch_geometric.typing.WITH_PT24:
+    torch.serialization.add_safe_globals([
+        dict,
+        list,
+        defaultdict,
+        Index,
+        torch_geometric.index.CatMetadata,
+        EdgeIndex,
+        torch_geometric.edge_index.SortOrder,
+        torch_geometric.edge_index.CatMetadata,
+    ])
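On PyTorch 2.4+, the block above allow-lists PyG's tensor types for weights-only deserialization. A minimal sketch of what this is meant to enable (the file name is illustrative, and the exact `weights_only` behavior depends on the installed PyTorch version):

import torch
from torch_geometric import EdgeIndex

edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3))
torch.save(edge_index, 'edge_index.pt')

# Because `EdgeIndex` and its metadata classes are registered via
# `torch.serialization.add_safe_globals`, weights-only loading does not
# reject them as unsafe globals:
out = torch.load('edge_index.pt', weights_only=True)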
torch_geometric/_compile.py
CHANGED
@@ -10,6 +10,8 @@ def is_compiling() -> bool:
     r"""Returns :obj:`True` in case :pytorch:`PyTorch` is compiling via
     :meth:`torch.compile`.
     """
+    if torch_geometric.typing.WITH_PT23:
+        return torch.compiler.is_compiling()
     if torch_geometric.typing.WITH_PT21:
         return torch._dynamo.is_compiling()
     return False  # pragma: no cover
@@ -25,10 +27,15 @@ def compile(
     This function has the same signature as :meth:`torch.compile` (see
     `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__).
 
+    Args:
+        model: The model to compile.
+        *args: Additional arguments of :meth:`torch.compile`.
+        **kwargs: Additional keyword arguments of :meth:`torch.compile`.
+
     .. note::
         :meth:`torch_geometric.compile` is deprecated in favor of
         :meth:`torch.compile`.
     """
     warnings.warn("'torch_geometric.compile' is deprecated in favor of "
                   "'torch.compile'")
-    return torch.compile(model, *args, **kwargs)
+    return torch.compile(model, *args, **kwargs)  # type: ignore
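`is_compiling()` now prefers the public `torch.compiler.is_compiling()` on PyTorch 2.3+ and only falls back to the private `torch._dynamo` API on 2.1/2.2. A usage sketch of the guard; the surrounding function is illustrative, not part of the diff:

import torch_geometric

def normalize(x):
    if not torch_geometric.is_compiling():
        # Run data-dependent checks in eager mode only, so they do not
        # introduce graph breaks under `torch.compile`:
        assert x.dim() == 2
    return x / x.norm(dim=-1, keepdim=True).clamp(min=1e-12)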
torch_geometric/_onnx.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+
+from torch_geometric import is_compiling
+
+
+def is_in_onnx_export() -> bool:
+    r"""Returns :obj:`True` in case :pytorch:`PyTorch` is exporting to ONNX via
+    :meth:`torch.onnx.export`.
+    """
+    if is_compiling():
+        return False
+    if torch.jit.is_scripting():
+        return False
+    return torch.onnx.is_in_onnx_export()
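A short sketch of how the new helper can be used to pick an exporter-friendly code path; the module below is illustrative and not part of this diff:

import torch
from torch_geometric import is_in_onnx_export

class Clamp(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if is_in_onnx_export():
            # Prefer an out-of-place op with a simple ONNX lowering:
            return x.clamp(min=0.0)
        return x.relu_()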
torch_geometric/config_mixin.py
ADDED
@@ -0,0 +1,113 @@
+import inspect
+from dataclasses import fields, is_dataclass
+from importlib import import_module
+from typing import Any, Dict
+
+from torch_geometric.config_store import (
+    class_from_dataclass,
+    dataclass_from_class,
+)
+from torch_geometric.isinstance import is_torch_instance
+
+
+class ConfigMixin:
+    r"""Enables a class to serialize/deserialize itself to a dataclass."""
+    def config(self) -> Any:
+        r"""Creates a serializable configuration of the class."""
+        data_cls = dataclass_from_class(self.__class__)
+        if data_cls is None:
+            raise ValueError(f"Could not find the configuration class that "
+                             f"belongs to '{self.__class__.__name__}'. Please "
+                             f"register it in the configuration store.")
+
+        kwargs: Dict[str, Any] = {}
+        for field in fields(data_cls):
+            if not hasattr(self, field.name):
+                continue
+            kwargs[field.name] = _recursive_config(getattr(self, field.name))
+        return data_cls(**kwargs)
+
+    @classmethod
+    def from_config(cls, cfg: Any, *args: Any, **kwargs: Any) -> Any:
+        r"""Instantiates the class from a serializable configuration."""
+        if getattr(cfg, '_target_', None):
+            cls = _locate_cls(cfg._target_)
+        elif isinstance(cfg, dict) and '_target_' in cfg:
+            cls = _locate_cls(cfg['_target_'])
+
+        data_cls = cfg.__class__
+        if not is_dataclass(data_cls):
+            data_cls = dataclass_from_class(cls)
+            if data_cls is None:
+                raise ValueError(f"Could not find the configuration class "
+                                 f"that belongs to '{cls.__name__}'. Please "
+                                 f"register it in the configuration store.")
+
+        field_names = {field.name for field in fields(data_cls)}
+        if isinstance(cfg, dict):
+            _kwargs = {k: v for k, v in cfg.items() if k in field_names}
+            cfg = data_cls(**_kwargs)
+        assert is_dataclass(cfg)
+
+        if len(args) > 0:  # Convert `*args` to `**kwargs`:
+            param_names = list(inspect.signature(cls).parameters.keys())
+            if 'args' in param_names:
+                param_names.remove('args')
+            if 'kwargs' in param_names:
+                param_names.remove('kwargs')
+
+            for name, arg in zip(param_names, args):
+                kwargs[name] = arg
+
+        for key in field_names:
+            if key not in kwargs and key != '_target_':
+                kwargs[key] = _recursive_from_config(getattr(cfg, key))
+
+        return cls(**kwargs)
+
+
+def _recursive_config(value: Any) -> Any:
+    if isinstance(value, ConfigMixin):
+        return value.config()
+    if is_torch_instance(value, ConfigMixin):
+        return value.config()
+    if isinstance(value, (tuple, list)):
+        return [_recursive_config(v) for v in value]
+    if isinstance(value, dict):
+        return {k: _recursive_config(v) for k, v in value.items()}
+    return value
+
+
+def _recursive_from_config(value: Any) -> Any:
+    cls: Any = None
+    if is_dataclass(value):
+        if getattr(value, '_target_', None):
+            try:
+                cls = _locate_cls(value._target_)  # type: ignore
+            except ImportError:
+                pass  # Keep the dataclass as it is.
+        else:
+            cls = class_from_dataclass(value.__class__)
+    elif isinstance(value, dict) and '_target_' in value:
+        cls = _locate_cls(value['_target_'])
+
+    if cls is not None and issubclass(cls, ConfigMixin):
+        return cls.from_config(value)
+    if isinstance(value, (tuple, list)):
+        return [_recursive_from_config(v) for v in value]
+    if isinstance(value, dict):
+        return {k: _recursive_from_config(v) for k, v in value.items()}
+    return value
+
+
+def _locate_cls(qualname: str) -> Any:
+    parts = qualname.split('.')
+
+    if len(parts) <= 1:
+        raise ValueError(f"Qualified name is missing a dot (got '{qualname}')")
+
+    if any([len(part) == 0 for part in parts]):
+        raise ValueError(f"Relative imports not supported (got '{qualname}')")
+
+    module_name, cls_name = '.'.join(parts[:-1]), parts[-1]
+    return getattr(import_module(module_name), cls_name)
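A hedged sketch of the round-trip `ConfigMixin` is meant to enable. It assumes the class has been made discoverable in the configuration store (here via `register` from `torch_geometric.config_store`, used as a decorator); `MyEncoder` is a made-up example class:

import torch
from torch_geometric.config_mixin import ConfigMixin
from torch_geometric.config_store import register

@register
class MyEncoder(torch.nn.Module, ConfigMixin):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lin = torch.nn.Linear(in_channels, out_channels)

encoder = MyEncoder(16, 32)
cfg = encoder.config()              # Serializable dataclass configuration.
clone = MyEncoder.from_config(cfg)  # Rebuilt from that configuration.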
torch_geometric/config_store.py
CHANGED
@@ -76,7 +76,7 @@ else:
 
     def __call__(cls, *args: Any, **kwargs: Any) -> Any:
         if cls not in cls._instances:
-            instance = super(
+            instance = super().__call__(*args, **kwargs)
             cls._instances[cls] = instance
             return instance
         return cls._instances[cls]
@@ -238,7 +238,7 @@ def to_dataclass(
     if strict:  # Check that keys in map_args or exclude_args are present.
         keys = set() if map_args is None else set(map_args.keys())
         if exclude_args is not None:
-            keys |=
+            keys |= {arg for arg in exclude_args if isinstance(arg, str)}
         diff = keys - set(params.keys())
         if len(diff) > 0:
             raise ValueError(f"Expected input argument(s) {diff} in "
@@ -413,13 +413,13 @@ def fill_config_store() -> None:
 
     # Register `torch_geometric.transforms` ###################################
     transforms = torch_geometric.transforms
-    for cls_name in set(transforms.__all__) -
+    for cls_name in set(transforms.__all__) - {
             'BaseTransform',
             'Compose',
             'ComposeFilters',
             'LinearTransformation',
             'AddMetaPaths',  # TODO
-
+    }:
         cls = to_dataclass(getattr(transforms, cls_name), base_cls=Transform)
         # We use an explicit additional nesting level inside each config to
         # allow for applying multiple transformations.
@@ -433,7 +433,7 @@ def fill_config_store() -> None:
         'pre_transform': (Dict[str, Transform], field(default_factory=dict)),
     }
 
-    for cls_name in set(datasets.__all__) - set(
+    for cls_name in set(datasets.__all__) - set():
         cls = to_dataclass(getattr(datasets, cls_name), base_cls=Dataset,
                            map_args=map_dataset_args,
                            exclude_args=['pre_filter'])
@@ -441,32 +441,34 @@ def fill_config_store() -> None:
 
     # Register `torch_geometric.models` #######################################
     models = torch_geometric.nn.models.basic_gnn
-    for cls_name in set(models.__all__) - set(
+    for cls_name in set(models.__all__) - set():
         cls = to_dataclass(getattr(models, cls_name), base_cls=Model)
         config_store.store(cls_name, group='model', node=cls)
 
     # Register `torch.optim.Optimizer` ########################################
-    for cls_name in
-            key
+    for cls_name in {
+            key
+            for key, cls in torch.optim.__dict__.items()
            if inspect.isclass(cls) and issubclass(cls, torch.optim.Optimizer)
-
+    } - {
        'Optimizer',
-
+    }:
         cls = to_dataclass(getattr(torch.optim, cls_name), base_cls=Optimizer,
                            exclude_args=['params'])
         config_store.store(cls_name, group='optimizer', node=cls)
 
     # Register `torch.optim.lr_scheduler` #####################################
-    for cls_name in
-            key
+    for cls_name in {
+            key
+            for key, cls in torch.optim.lr_scheduler.__dict__.items()
            if inspect.isclass(cls)
-
+    } - {
        'Optimizer',
        '_LRScheduler',
        'Counter',
        'SequentialLR',
        'ChainedScheduler',
-
+    }:
         cls = to_dataclass(getattr(torch.optim.lr_scheduler, cls_name),
                            base_cls=LRScheduler, exclude_args=['optimizer'])
         config_store.store(cls_name, group='lr_scheduler', node=cls)
torch_geometric/data/__init__.py
CHANGED
@@ -1,7 +1,10 @@
 # flake8: noqa
 
+import torch
+import torch_geometric.typing
+
 from .feature_store import FeatureStore, TensorAttr
-from .graph_store import GraphStore, EdgeAttr
+from .graph_store import GraphStore, EdgeAttr, EdgeLayout
 from .data import Data
 from .hetero_data import HeteroData
 from .batch import Batch
@@ -13,6 +16,7 @@ from .on_disk_dataset import OnDiskDataset
 from .makedirs import makedirs
 from .download import download_url, download_google_url
 from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
+from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups
 
 from torch_geometric.lazy_loader import LazyLoader
 
@@ -24,6 +28,8 @@ data_classes = [
     'Dataset',
     'InMemoryDataset',
     'OnDiskDataset',
+    'LargeGraphIndexer',
+    'TripletLike',
 ]
 
 remote_backend_classes = [
@@ -47,6 +53,8 @@ helper_functions = [
     'extract_zip',
     'extract_bz2',
     'extract_gz',
+    'get_features_for_triplets',
+    "get_features_for_triplets_groups",
 ]
 
 __all__ = data_classes + remote_backend_classes + helper_functions
@@ -68,6 +76,21 @@ from torch_geometric.loader import DataLoader
 from torch_geometric.loader import DataListLoader
 from torch_geometric.loader import DenseDataLoader
 
+# Serialization ###############################################################
+
+if torch_geometric.typing.WITH_PT24:
+    torch.serialization.add_safe_globals([
+        Data,
+        HeteroData,
+        TemporalData,
+        ClusterData,
+        TensorAttr,
+        EdgeAttr,
+        EdgeLayout,
+    ])
+
+# Deprecations ################################################################
+
 NeighborSampler = deprecated(  # type: ignore
     details="use 'loader.NeighborSampler' instead",
     func_name='data.NeighborSampler',
torch_geometric/data/batch.py
CHANGED
@@ -118,8 +118,8 @@ class Batch(metaclass=DynamicInheritance):
         """
         if not hasattr(self, '_slice_dict'):
             raise RuntimeError(
-
-
+                "Cannot reconstruct 'Data' object from 'Batch' because "
+                "'Batch' was not created via 'Batch.from_data_list()'")
 
         data = separate(
             cls=self.__class__.__bases__[-1],
torch_geometric/data/data.py
CHANGED
@@ -291,13 +291,14 @@ class BaseData:
         self,
         start_time: Union[float, int],
         end_time: Union[float, int],
+        attr: str = 'time',
     ) -> Self:
         r"""Returns a snapshot of :obj:`data` to only hold events that occurred
         in period :obj:`[start_time, end_time]`.
         """
         out = copy.copy(self)
         for store in out.stores:
-            store.snapshot(start_time, end_time)
+            store.snapshot(start_time, end_time, attr)
         return out
 
     def up_to(self, end_time: Union[float, int]) -> Self:
@@ -645,7 +646,7 @@ class Data(BaseData, FeatureStore, GraphStore):
         return self
 
     def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
-        if is_sparse(value) and 'adj' in key:
+        if is_sparse(value) and ('adj' in key or 'edge_index' in key):
             return (0, 1)
         elif 'index' in key or key == 'face':
             return -1
@@ -658,7 +659,13 @@ class Data(BaseData, FeatureStore, GraphStore):
             return value.get_dim_size()
             return int(value.max()) + 1
         elif 'index' in key or key == 'face':
-
+            num_nodes = self.num_nodes
+            if num_nodes is None:
+                raise RuntimeError(f"Unable to infer 'num_nodes' from the "
+                                   f"attribute '{key}'. Please explicitly set "
+                                   f"'num_nodes' as an attribute of 'data' to "
+                                   f"prevent this error")
+            return num_nodes
         else:
             return 0
 
@@ -937,16 +944,14 @@ class Data(BaseData, FeatureStore, GraphStore):
         r"""Iterates over all attributes in the data, yielding their attribute
         names and values.
         """
-
-            yield key, value
+        yield from self._store.items()
 
     def __call__(self, *args: str) -> Iterable:
         r"""Iterates over all attributes :obj:`*args` in the data, yielding
         their attribute names and values.
         If :obj:`*args` is not given, will iterate over all attributes.
         """
-
-            yield key, value
+        yield from self._store.items(*args)
 
     @property
     def x(self) -> Optional[Tensor]:
@@ -1166,7 +1171,7 @@ def size_repr(key: Any, value: Any, indent: int = 0) -> str:
               f'[{value.num_rows}, {value.num_cols}])')
     elif isinstance(value, str):
         out = f"'{value}'"
-    elif isinstance(value, Sequence):
+    elif isinstance(value, (Sequence, set)):
         out = str([len(value)])
     elif isinstance(value, Mapping) and len(value) == 0:
         out = '{}'
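The `snapshot()` signature gains an `attr` argument, so temporal filtering can target a custom time attribute instead of the default `'time'`. A usage sketch, assuming an edge-level attribute named `edge_time`; the exact filtering semantics live in the storage-level `snapshot()` implementation:

import torch
from torch_geometric.data import Data

data = Data(
    x=torch.randn(3, 8),
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]),
    edge_time=torch.tensor([1.0, 5.0, 9.0]),
)
# Keep only events whose 'edge_time' lies in [0, 6]:
snap = data.snapshot(0, 6, attr='edge_time')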
torch_geometric/data/database.py
CHANGED
@@ -1,4 +1,4 @@
-import
+import io
 import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -115,11 +115,9 @@ class Database(ABC):
         r"""Connects to the database.
         Databases will automatically connect on instantiation.
         """
-        pass
 
     def close(self) -> None:
         r"""Closes the connection to the database."""
-        pass
 
     @abstractmethod
     def insert(self, index: int, data: Any) -> None:
@@ -498,7 +496,9 @@ class SQLiteDatabase(Database):
                 out.append(col)
 
             else:
-
+                buffer = io.BytesIO()
+                torch.save(col, buffer)
+                out.append(buffer.getvalue())
 
         return out
 
@@ -561,7 +561,10 @@ class SQLiteDatabase(Database):
                 out_dict[key] = value
 
             else:
-                out_dict[key] =
+                out_dict[key] = torch.load(
+                    io.BytesIO(value),
+                    weights_only=False,
+                )
 
         # In case `0` exists as integer in the schema, this means that the
         # schema was passed as either a single entry or a tuple:
@@ -646,7 +649,12 @@ class RocksDatabase(Database):
         # Ensure that data is not a view of a larger tensor:
         if isinstance(row, Tensor):
             row = row.clone()
-
+        buffer = io.BytesIO()
+        torch.save(row, buffer)
+        return buffer.getvalue()
 
     def _deserialize(self, row: bytes) -> Any:
-        return
+        return torch.load(
+            io.BytesIO(row),
+            weights_only=False,
+        )
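Both database backends now serialize arbitrary rows by writing them with `torch.save` into an in-memory `io.BytesIO` buffer and reading them back with `torch.load(..., weights_only=False)`. The underlying pattern, extracted into a standalone sketch:

import io
import torch

def serialize(obj) -> bytes:
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer.getvalue()

def deserialize(raw: bytes):
    # `weights_only=False` mirrors the diff above: rows may contain
    # arbitrary (non-tensor) Python objects.
    return torch.load(io.BytesIO(raw), weights_only=False)

blob = serialize({'x': torch.randn(4), 'label': 'demo'})
assert deserialize(blob)['label'] == 'demo'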
torch_geometric/data/dataset.py
CHANGED
@@ -235,7 +235,8 @@ class Dataset(torch.utils.data.Dataset):
 
     def _process(self):
         f = osp.join(self.processed_dir, 'pre_transform.pt')
-        if osp.exists(f) and torch.load(f) != _repr(
+        if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
+                self.pre_transform):
             warnings.warn(
                 "The `pre_transform` argument differs from the one used in "
                 "the pre-processed version of this dataset. If you want to "
@@ -243,7 +244,8 @@ class Dataset(torch.utils.data.Dataset):
                 "`force_reload=True` explicitly to reload the dataset.")
 
         f = osp.join(self.processed_dir, 'pre_filter.pt')
-        if osp.exists(f) and torch.load(f) != _repr(
+        if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
+                self.pre_filter):
             warnings.warn(
                 "The `pre_filter` argument differs from the one used in "
                 "the pre-processed version of this dataset. If you want to "
@@ -367,15 +369,21 @@ class Dataset(torch.utils.data.Dataset):
         from torch_geometric.data.summary import Summary
         return Summary.from_dataset(self)
 
-    def print_summary(self) -> None:
-        r"""Prints summary statistics of the dataset to the console.
-
+    def print_summary(self, fmt: str = "psql") -> None:
+        r"""Prints summary statistics of the dataset to the console.
+
+        Args:
+            fmt (str, optional): Summary tables format. Available table formats
+                can be found `here <https://github.com/astanin/python-tabulate?
+                tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
+        """
+        print(self.get_summary().format(fmt=fmt))
 
     def to_datapipe(self) -> Any:
         r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
 
         The returned instance can then be used with :pyg:`PyG's` built-in
-        :class:`DataPipes` for
+        :class:`DataPipes` for batching graphs as follows:
 
         .. code-block:: python
 
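`Dataset.print_summary()` now forwards a `fmt` argument to the tabulate-based summary formatter. A quick usage sketch; `FakeDataset` is just a convenient stand-in, any dataset works:

from torch_geometric.datasets import FakeDataset

dataset = FakeDataset(num_graphs=100)
dataset.print_summary()              # Default 'psql' table format.
dataset.print_summary(fmt='github')  # Any tabulate-supported format name.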
torch_geometric/data/feature_store.py
CHANGED
@@ -74,13 +74,6 @@ class TensorAttr(CastMixin):
         r"""Whether the :obj:`TensorAttr` has no unset fields."""
         return all([self.is_set(key) for key in self.__dataclass_fields__])
 
-    def fully_specify(self) -> 'TensorAttr':
-        r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
-        for key in self.__dataclass_fields__:
-            if not self.is_set(key):
-                setattr(self, key, None)
-        return self
-
     def update(self, attr: 'TensorAttr') -> 'TensorAttr':
         r"""Updates an :class:`TensorAttr` with set attributes from another
         :class:`TensorAttr`.
@@ -230,10 +223,11 @@ class AttrView(CastMixin):
 
             store[group_name, attr_name]()
         """
-
-
-
-
+        attr = copy.copy(self._attr)
+        for key in attr.__dataclass_fields__:  # Set all UNSET values to None.
+            if not attr.is_set(key):
+                setattr(attr, key, None)
+        return self._store.get_tensor(attr)
 
     def __copy__(self) -> 'AttrView':
         out = self.__class__.__new__(self.__class__)
@@ -283,7 +277,6 @@ class FeatureStore(ABC):
     @abstractmethod
     def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
         r"""To be implemented by :class:`FeatureStore` subclasses."""
-        pass
 
     def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:
         r"""Synchronously adds a :obj:`tensor` to the :class:`FeatureStore`.
@@ -309,7 +302,6 @@ class FeatureStore(ABC):
     @abstractmethod
     def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
         r"""To be implemented by :class:`FeatureStore` subclasses."""
-        pass
 
     def get_tensor(
         self,
@@ -394,7 +386,6 @@ class FeatureStore(ABC):
     @abstractmethod
     def _remove_tensor(self, attr: TensorAttr) -> bool:
         r"""To be implemented by :obj:`FeatureStore` subclasses."""
-        pass
 
     def remove_tensor(self, *args, **kwargs) -> bool:
         r"""Removes a tensor from the :class:`FeatureStore`.
@@ -452,7 +443,6 @@ class FeatureStore(ABC):
     @abstractmethod
     def get_all_tensor_attrs(self) -> List[TensorAttr]:
         r"""Returns all registered tensor attributes."""
-        pass
 
     # `AttrView` methods ######################################################
 
@@ -483,9 +473,7 @@ class FeatureStore(ABC):
         # CastMixin will handle the case of key being a tuple or TensorAttr
         # object:
         key = self._tensor_attr_cls.cast(key)
-
-        # sense to work with a view here:
-        key.fully_specify()
+        assert key.is_fully_specified()
         self.put_tensor(value, key)
 
     def __getitem__(self, key: TensorAttr) -> Any:
@@ -507,13 +495,16 @@ class FeatureStore(ABC):
             # If the view is not fully-specified, return a :class:`AttrView`:
             return self.view(attr)
 
-    def __delitem__(self,
+    def __delitem__(self, attr: TensorAttr):
         r"""Supports :obj:`del store[tensor_attr]`."""
         # CastMixin will handle the case of key being a tuple or TensorAttr
         # object:
-
-
-
+        attr = self._tensor_attr_cls.cast(attr)
+        attr = copy.copy(attr)
+        for key in attr.__dataclass_fields__:  # Set all UNSET values to None.
+            if not attr.is_set(key):
+                setattr(attr, key, None)
+        self.remove_tensor(attr)
 
     def __iter__(self):
         raise NotImplementedError
torch_geometric/data/graph_store.py
CHANGED
@@ -116,7 +116,6 @@ class GraphStore(ABC):
     def _put_edge_index(self, edge_index: EdgeTensorType,
                         edge_attr: EdgeAttr) -> bool:
         r"""To be implemented by :class:`GraphStore` subclasses."""
-        pass
 
     def put_edge_index(self, edge_index: EdgeTensorType, *args,
                        **kwargs) -> bool:
@@ -137,7 +136,6 @@ class GraphStore(ABC):
     @abstractmethod
     def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
         r"""To be implemented by :class:`GraphStore` subclasses."""
-        pass
 
     def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
         r"""Synchronously obtains an :obj:`edge_index` tuple from the
@@ -160,7 +158,6 @@ class GraphStore(ABC):
     @abstractmethod
    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
         r"""To be implemented by :class:`GraphStore` subclasses."""
-        pass
 
     def remove_edge_index(self, *args, **kwargs) -> bool:
         r"""Synchronously deletes an :obj:`edge_index` tuple from the
@@ -177,7 +174,6 @@ class GraphStore(ABC):
     @abstractmethod
     def get_all_edge_attrs(self) -> List[EdgeAttr]:
         r"""Returns all registered edge attributes."""
-        pass
 
     # Layout Conversion #######################################################
 