pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/__init__.py
CHANGED
@@ -1,7 +1,15 @@
+from collections import defaultdict
+
+import torch
+import torch_geometric.typing
+
 from ._compile import compile, is_compiling
+from ._onnx import is_in_onnx_export
+from .index import Index
 from .edge_index import EdgeIndex
 from .seed import seed_everything
 from .home import get_home_dir, set_home_dir
+from .device import is_mps_available, is_xpu_available, device
 from .isinstance import is_torch_instance
 from .debug import is_debug_enabled, debug, set_debug

@@ -22,15 +30,20 @@ from .lazy_loader import LazyLoader
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

-__version__ = '2.6.0.dev20240318'
+__version__ = '2.7.0.dev20250115'

 __all__ = [
+    'Index',
     'EdgeIndex',
     'seed_everything',
     'get_home_dir',
     'set_home_dir',
     'compile',
     'is_compiling',
+    'is_in_onnx_export',
+    'is_mps_available',
+    'is_xpu_available',
+    'device',
     'is_torch_instance',
     'is_debug_enabled',
     'debug',
@@ -41,3 +54,17 @@ __all__ = [
     'torch_geometric',
     '__version__',
 ]
+
+# Serialization ###############################################################
+
+if torch_geometric.typing.WITH_PT24:
+    torch.serialization.add_safe_globals([
+        dict,
+        list,
+        defaultdict,
+        Index,
+        torch_geometric.index.CatMetadata,
+        EdgeIndex,
+        torch_geometric.edge_index.SortOrder,
+        torch_geometric.edge_index.CatMetadata,
+    ])
torch_geometric/_compile.py
CHANGED
@@ -10,6 +10,8 @@ def is_compiling() -> bool:
     r"""Returns :obj:`True` in case :pytorch:`PyTorch` is compiling via
     :meth:`torch.compile`.
     """
+    if torch_geometric.typing.WITH_PT23:
+        return torch.compiler.is_compiling()
     if torch_geometric.typing.WITH_PT21:
         return torch._dynamo.is_compiling()
     return False  # pragma: no cover
@@ -25,10 +27,15 @@ def compile(
     This function has the same signature as :meth:`torch.compile` (see
     `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__).

+    Args:
+        model: The model to compile.
+        *args: Additional arguments of :meth:`torch.compile`.
+        **kwargs: Additional keyword arguments of :meth:`torch.compile`.
+
     .. note::
         :meth:`torch_geometric.compile` is deprecated in favor of
         :meth:`torch.compile`.
     """
     warnings.warn("'torch_geometric.compile' is deprecated in favor of "
                   "'torch.compile'")
-    return torch.compile(model, *args, **kwargs)
+    return torch.compile(model, *args, **kwargs)  # type: ignore
torch_geometric/_onnx.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+
+from torch_geometric import is_compiling
+
+
+def is_in_onnx_export() -> bool:
+    r"""Returns :obj:`True` in case :pytorch:`PyTorch` is exporting to ONNX via
+    :meth:`torch.onnx.export`.
+    """
+    if is_compiling():
+        return False
+    if torch.jit.is_scripting():
+        return False
+    return torch.onnx.is_in_onnx_export()
torch_geometric/config_mixin.py
ADDED
@@ -0,0 +1,113 @@
+import inspect
+from dataclasses import fields, is_dataclass
+from importlib import import_module
+from typing import Any, Dict
+
+from torch_geometric.config_store import (
+    class_from_dataclass,
+    dataclass_from_class,
+)
+from torch_geometric.isinstance import is_torch_instance
+
+
+class ConfigMixin:
+    r"""Enables a class to serialize/deserialize itself to a dataclass."""
+    def config(self) -> Any:
+        r"""Creates a serializable configuration of the class."""
+        data_cls = dataclass_from_class(self.__class__)
+        if data_cls is None:
+            raise ValueError(f"Could not find the configuration class that "
+                             f"belongs to '{self.__class__.__name__}'. Please "
+                             f"register it in the configuration store.")
+
+        kwargs: Dict[str, Any] = {}
+        for field in fields(data_cls):
+            if not hasattr(self, field.name):
+                continue
+            kwargs[field.name] = _recursive_config(getattr(self, field.name))
+        return data_cls(**kwargs)
+
+    @classmethod
+    def from_config(cls, cfg: Any, *args: Any, **kwargs: Any) -> Any:
+        r"""Instantiates the class from a serializable configuration."""
+        if getattr(cfg, '_target_', None):
+            cls = _locate_cls(cfg._target_)
+        elif isinstance(cfg, dict) and '_target_' in cfg:
+            cls = _locate_cls(cfg['_target_'])
+
+        data_cls = cfg.__class__
+        if not is_dataclass(data_cls):
+            data_cls = dataclass_from_class(cls)
+            if data_cls is None:
+                raise ValueError(f"Could not find the configuration class "
+                                 f"that belongs to '{cls.__name__}'. Please "
+                                 f"register it in the configuration store.")
+
+        field_names = {field.name for field in fields(data_cls)}
+        if isinstance(cfg, dict):
+            _kwargs = {k: v for k, v in cfg.items() if k in field_names}
+            cfg = data_cls(**_kwargs)
+        assert is_dataclass(cfg)
+
+        if len(args) > 0:  # Convert `*args` to `**kwargs`:
+            param_names = list(inspect.signature(cls).parameters.keys())
+            if 'args' in param_names:
+                param_names.remove('args')
+            if 'kwargs' in param_names:
+                param_names.remove('kwargs')
+
+            for name, arg in zip(param_names, args):
+                kwargs[name] = arg
+
+        for key in field_names:
+            if key not in kwargs and key != '_target_':
+                kwargs[key] = _recursive_from_config(getattr(cfg, key))
+
+        return cls(**kwargs)
+
+
+def _recursive_config(value: Any) -> Any:
+    if isinstance(value, ConfigMixin):
+        return value.config()
+    if is_torch_instance(value, ConfigMixin):
+        return value.config()
+    if isinstance(value, (tuple, list)):
+        return [_recursive_config(v) for v in value]
+    if isinstance(value, dict):
+        return {k: _recursive_config(v) for k, v in value.items()}
+    return value
+
+
+def _recursive_from_config(value: Any) -> Any:
+    cls: Any = None
+    if is_dataclass(value):
+        if getattr(value, '_target_', None):
+            try:
+                cls = _locate_cls(value._target_)  # type: ignore
+            except ImportError:
+                pass  # Keep the dataclass as it is.
+        else:
+            cls = class_from_dataclass(value.__class__)
+    elif isinstance(value, dict) and '_target_' in value:
+        cls = _locate_cls(value['_target_'])
+
+    if cls is not None and issubclass(cls, ConfigMixin):
+        return cls.from_config(value)
+    if isinstance(value, (tuple, list)):
+        return [_recursive_from_config(v) for v in value]
+    if isinstance(value, dict):
+        return {k: _recursive_from_config(v) for k, v in value.items()}
+    return value
+
+
+def _locate_cls(qualname: str) -> Any:
+    parts = qualname.split('.')
+
+    if len(parts) <= 1:
+        raise ValueError(f"Qualified name is missing a dot (got '{qualname}')")
+
+    if any([len(part) == 0 for part in parts]):
+        raise ValueError(f"Relative imports not supported (got '{qualname}')")
+
+    module_name, cls_name = '.'.join(parts[:-1]), parts[-1]
+    return getattr(import_module(module_name), cls_name)
torch_geometric/config_store.py
CHANGED
@@ -76,7 +76,7 @@ else:

         def __call__(cls, *args: Any, **kwargs: Any) -> Any:
             if cls not in cls._instances:
-                instance = super(
+                instance = super().__call__(*args, **kwargs)
                 cls._instances[cls] = instance
                 return instance
             return cls._instances[cls]
@@ -162,12 +162,19 @@ def map_annotation(
     annotation: Any,
     mapping: Optional[Dict[Any, Any]] = None,
 ) -> Any:
-
     origin = getattr(annotation, '__origin__', None)
-    args = getattr(annotation, '__args__',
-    if origin
-
-
+    args: Tuple[Any, ...] = getattr(annotation, '__args__', tuple())
+    if origin in {Union, list, dict, tuple}:
+        assert origin is not None
+        args = tuple(map_annotation(a, mapping) for a in args)
+        if type(annotation).__name__ == 'GenericAlias':
+            # If annotated with `list[...]` or `dict[...]` (>= Python 3.10):
+            annotation = origin[args]
+        else:
+            # If annotated with `typing.List[...]` or `typing.Dict[...]`:
+            annotation = copy.copy(annotation)
+            annotation.__args__ = args
+
         return annotation

     if mapping is not None and annotation in mapping:
@@ -231,7 +238,7 @@ def to_dataclass(
     if strict:  # Check that keys in map_args or exclude_args are present.
         keys = set() if map_args is None else set(map_args.keys())
         if exclude_args is not None:
-            keys |=
+            keys |= {arg for arg in exclude_args if isinstance(arg, str)}
         diff = keys - set(params.keys())
         if len(diff) > 0:
             raise ValueError(f"Expected input argument(s) {diff} in "
@@ -406,13 +413,13 @@ def fill_config_store() -> None:

     # Register `torch_geometric.transforms` ###################################
     transforms = torch_geometric.transforms
-    for cls_name in set(transforms.__all__) -
+    for cls_name in set(transforms.__all__) - {
             'BaseTransform',
             'Compose',
             'ComposeFilters',
             'LinearTransformation',
             'AddMetaPaths',  # TODO
-
+    }:
         cls = to_dataclass(getattr(transforms, cls_name), base_cls=Transform)
         # We use an explicit additional nesting level inside each config to
         # allow for applying multiple transformations.
@@ -426,7 +433,7 @@ def fill_config_store() -> None:
         'pre_transform': (Dict[str, Transform], field(default_factory=dict)),
     }

-    for cls_name in set(datasets.__all__) - set(
+    for cls_name in set(datasets.__all__) - set():
         cls = to_dataclass(getattr(datasets, cls_name), base_cls=Dataset,
                            map_args=map_dataset_args,
                            exclude_args=['pre_filter'])
@@ -434,32 +441,34 @@ def fill_config_store() -> None:

     # Register `torch_geometric.models` #######################################
     models = torch_geometric.nn.models.basic_gnn
-    for cls_name in set(models.__all__) - set(
+    for cls_name in set(models.__all__) - set():
         cls = to_dataclass(getattr(models, cls_name), base_cls=Model)
         config_store.store(cls_name, group='model', node=cls)

     # Register `torch.optim.Optimizer` ########################################
-    for cls_name in
-            key
+    for cls_name in {
+            key
+            for key, cls in torch.optim.__dict__.items()
             if inspect.isclass(cls) and issubclass(cls, torch.optim.Optimizer)
-
+    } - {
             'Optimizer',
-
+    }:
         cls = to_dataclass(getattr(torch.optim, cls_name), base_cls=Optimizer,
                            exclude_args=['params'])
         config_store.store(cls_name, group='optimizer', node=cls)

     # Register `torch.optim.lr_scheduler` #####################################
-    for cls_name in
-            key
+    for cls_name in {
+            key
+            for key, cls in torch.optim.lr_scheduler.__dict__.items()
             if inspect.isclass(cls)
-
+    } - {
             'Optimizer',
             '_LRScheduler',
             'Counter',
             'SequentialLR',
             'ChainedScheduler',
-
+    }:
         cls = to_dataclass(getattr(torch.optim.lr_scheduler, cls_name),
                            base_cls=LRScheduler, exclude_args=['optimizer'])
         config_store.store(cls_name, group='lr_scheduler', node=cls)
torch_geometric/data/__init__.py
CHANGED
@@ -1,7 +1,10 @@
 # flake8: noqa

+import torch
+import torch_geometric.typing
+
 from .feature_store import FeatureStore, TensorAttr
-from .graph_store import GraphStore, EdgeAttr
+from .graph_store import GraphStore, EdgeAttr, EdgeLayout
 from .data import Data
 from .hetero_data import HeteroData
 from .batch import Batch
@@ -13,6 +16,7 @@ from .on_disk_dataset import OnDiskDataset
 from .makedirs import makedirs
 from .download import download_url, download_google_url
 from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
+from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups

 from torch_geometric.lazy_loader import LazyLoader

@@ -24,6 +28,8 @@ data_classes = [
     'Dataset',
     'InMemoryDataset',
     'OnDiskDataset',
+    'LargeGraphIndexer',
+    'TripletLike',
 ]

 remote_backend_classes = [
@@ -47,6 +53,8 @@ helper_functions = [
     'extract_zip',
     'extract_bz2',
     'extract_gz',
+    'get_features_for_triplets',
+    "get_features_for_triplets_groups",
 ]

 __all__ = data_classes + remote_backend_classes + helper_functions
@@ -68,6 +76,21 @@ from torch_geometric.loader import DataLoader
 from torch_geometric.loader import DataListLoader
 from torch_geometric.loader import DenseDataLoader

+# Serialization ###############################################################
+
+if torch_geometric.typing.WITH_PT24:
+    torch.serialization.add_safe_globals([
+        Data,
+        HeteroData,
+        TemporalData,
+        ClusterData,
+        TensorAttr,
+        EdgeAttr,
+        EdgeLayout,
+    ])
+
+# Deprecations ################################################################
+
 NeighborSampler = deprecated(  # type: ignore
     details="use 'loader.NeighborSampler' instead",
     func_name='data.NeighborSampler',
torch_geometric/data/batch.py
CHANGED
@@ -118,8 +118,8 @@ class Batch(metaclass=DynamicInheritance):
         """
         if not hasattr(self, '_slice_dict'):
             raise RuntimeError(
-
-
+                "Cannot reconstruct 'Data' object from 'Batch' because "
+                "'Batch' was not created via 'Batch.from_data_list()'")

         data = separate(
             cls=self.__class__.__bases__[-1],
torch_geometric/data/collate.py
CHANGED
@@ -16,7 +16,7 @@ import torch
 from torch import Tensor

 import torch_geometric.typing
-from torch_geometric import EdgeIndex
+from torch_geometric import EdgeIndex, Index
 from torch_geometric.data.data import BaseData
 from torch_geometric.data.storage import BaseStorage, NodeStorage
 from torch_geometric.edge_index import SortOrder
@@ -184,7 +184,8 @@ def _collate(
         return value, slices, incs

     out = None
-    if torch.utils.data.get_worker_info() is not None
+    if (torch.utils.data.get_worker_info() is not None
+            and not isinstance(elem, (Index, EdgeIndex))):
         # Write directly into shared memory to avoid an extra copy:
         numel = sum(value.numel() for value in values)
         if torch_geometric.typing.WITH_PT20:
@@ -203,6 +204,11 @@ def _collate(

     value = torch.cat(values, dim=cat_dim or 0, out=out)

+    if increment and isinstance(value, Index) and values[0].is_sorted:
+        # Check whether the whole `Index` is sorted:
+        if (value.diff() >= 0).all():
+            value._is_sorted = True
+
     if increment and isinstance(value, EdgeIndex) and values[0].is_sorted:
         # Check whether the whole `EdgeIndex` is sorted by row:
         if values[0].is_sorted_by_row and (value[0].diff() >= 0).all():
torch_geometric/data/data.py
CHANGED
@@ -31,6 +31,7 @@ from torch_geometric.data.storage import (
     NodeStorage,
 )
 from torch_geometric.deprecation import deprecated
+from torch_geometric.index import Index
 from torch_geometric.typing import (
     EdgeTensorType,
     EdgeType,
@@ -290,13 +291,14 @@ class BaseData:
         self,
         start_time: Union[float, int],
         end_time: Union[float, int],
+        attr: str = 'time',
     ) -> Self:
         r"""Returns a snapshot of :obj:`data` to only hold events that occurred
         in period :obj:`[start_time, end_time]`.
         """
         out = copy.copy(self)
         for store in out.stores:
-            store.snapshot(start_time, end_time)
+            store.snapshot(start_time, end_time, attr)
         return out

     def up_to(self, end_time: Union[float, int]) -> Self:
@@ -644,7 +646,7 @@ class Data(BaseData, FeatureStore, GraphStore):
         return self

     def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
-        if is_sparse(value) and 'adj' in key:
+        if is_sparse(value) and ('adj' in key or 'edge_index' in key):
             return (0, 1)
         elif 'index' in key or key == 'face':
             return -1
@@ -653,9 +655,17 @@ class Data(BaseData, FeatureStore, GraphStore):

     def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
         if 'batch' in key and isinstance(value, Tensor):
+            if isinstance(value, Index):
+                return value.get_dim_size()
             return int(value.max()) + 1
         elif 'index' in key or key == 'face':
-
+            num_nodes = self.num_nodes
+            if num_nodes is None:
+                raise RuntimeError(f"Unable to infer 'num_nodes' from the "
+                                   f"attribute '{key}'. Please explicitly set "
+                                   f"'num_nodes' as an attribute of 'data' to "
+                                   f"prevent this error")
+            return num_nodes
         else:
             return 0

@@ -934,16 +944,14 @@ class Data(BaseData, FeatureStore, GraphStore):
         r"""Iterates over all attributes in the data, yielding their attribute
         names and values.
         """
-
-            yield key, value
+        yield from self._store.items()

     def __call__(self, *args: str) -> Iterable:
         r"""Iterates over all attributes :obj:`*args` in the data, yielding
         their attribute names and values.
         If :obj:`*args` is not given, will iterate over all attributes.
         """
-
-            yield key, value
+        yield from self._store.items(*args)

     @property
     def x(self) -> Optional[Tensor]:
@@ -1163,7 +1171,7 @@ def size_repr(key: Any, value: Any, indent: int = 0) -> str:
                f'[{value.num_rows}, {value.num_cols}])')
     elif isinstance(value, str):
         out = f"'{value}'"
-    elif isinstance(value, Sequence):
+    elif isinstance(value, (Sequence, set)):
         out = str([len(value)])
     elif isinstance(value, Mapping) and len(value) == 0:
         out = '{}'
torch_geometric/data/database.py
CHANGED
@@ -1,16 +1,15 @@
-import
+import io
 import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import cached_property
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
-from uuid import uuid4

 import torch
 from torch import Tensor
 from tqdm import tqdm

-from torch_geometric import EdgeIndex
+from torch_geometric import EdgeIndex, Index
 from torch_geometric.edge_index import SortOrder
 from torch_geometric.utils.mixin import CastMixin

@@ -19,9 +18,17 @@ from torch_geometric.utils.mixin import CastMixin
 class TensorInfo(CastMixin):
     dtype: torch.dtype
     size: Tuple[int, ...] = (-1, )
+    is_index: bool = False
     is_edge_index: bool = False

     def __post_init__(self) -> None:
+        if self.is_index and self.is_edge_index:
+            raise ValueError("Tensor cannot be a 'Index' and 'EdgeIndex' "
+                             "tensor at the same time")
+
+        if self.is_index:
+            self.size = (-1, )
+
         if self.is_edge_index:
             self.size = (2, -1)

@@ -33,7 +40,8 @@ def maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]:
         return value
     if 'dtype' not in value:
         return value
-
+    valid_keys = {'dtype', 'size', 'is_index', 'is_edge_index'}
+    if len(set(value.keys()) | valid_keys) != len(valid_keys):
         return value
     return TensorInfo.cast(value)

@@ -107,11 +115,9 @@ class Database(ABC):
         r"""Connects to the database.
         Databases will automatically connect on instantiation.
         """
-        pass

     def close(self) -> None:
         r"""Closes the connection to the database."""
-        pass

     @abstractmethod
     def insert(self, index: int, data: Any) -> None:
@@ -373,8 +379,9 @@ class SQLiteDatabase(Database):

         # We create a temporary ID table to then perform an INNER JOIN.
         # This avoids having a long IN clause and guarantees sorted outputs:
-        join_table_name = f'{self.name}
-
+        join_table_name = f'{self.name}__join'
+        # Temporary tables do not lock the database.
+        query = (f'CREATE TEMP TABLE {join_table_name} (\n'
                  f'  id INTEGER,\n'
                  f'  row_id INTEGER\n'
                  f')')
@@ -452,10 +459,22 @@ class SQLiteDatabase(Database):
             if isinstance(col, Tensor) and not isinstance(schema, TensorInfo):
                 self.schema[key] = schema = TensorInfo(
                     col.dtype,
+                    is_index=isinstance(col, Index),
                     is_edge_index=isinstance(col, EdgeIndex),
                 )

-            if isinstance(schema, TensorInfo) and schema.
+            if isinstance(schema, TensorInfo) and schema.is_index:
+                assert isinstance(col, Index)
+
+                meta = torch.tensor([
+                    col.dim_size if col.dim_size is not None else -1,
+                    col.is_sorted,
+                ], dtype=torch.long)
+
+                out.append(meta.numpy().tobytes() +
+                           col.as_tensor().numpy().tobytes())
+
+            elif isinstance(schema, TensorInfo) and schema.is_edge_index:
                 assert isinstance(col, EdgeIndex)

                 num_rows, num_cols = col.sparse_size()
@@ -466,7 +485,8 @@ class SQLiteDatabase(Database):
                     col.is_undirected,
                 ], dtype=torch.long)

-                out.append(meta.numpy().tobytes() +
+                out.append(meta.numpy().tobytes() +
+                           col.as_tensor().numpy().tobytes())

             elif isinstance(schema, TensorInfo):
                 assert isinstance(col, Tensor)
@@ -476,7 +496,9 @@ class SQLiteDatabase(Database):
                 out.append(col)

             else:
-
+                buffer = io.BytesIO()
+                torch.save(col, buffer)
+                out.append(buffer.getvalue())

         return out

@@ -490,7 +512,23 @@ class SQLiteDatabase(Database):
         for i, (key, schema) in enumerate(self.schema.items()):
             value = row[i]

-            if isinstance(schema, TensorInfo) and schema.
+            if isinstance(schema, TensorInfo) and schema.is_index:
+                meta = torch.frombuffer(value[:16], dtype=torch.long).tolist()
+                dim_size = meta[0] if meta[0] >= 0 else None
+                is_sorted = meta[1] > 0
+
+                if len(value) > 16:
+                    tensor = torch.frombuffer(value[16:], dtype=schema.dtype)
+                else:
+                    tensor = torch.empty(0, dtype=schema.dtype)
+
+                out_dict[key] = Index(
+                    tensor.view(*schema.size),
+                    dim_size=dim_size,
+                    is_sorted=is_sorted,
+                )
+
+            elif isinstance(schema, TensorInfo) and schema.is_edge_index:
                 meta = torch.frombuffer(value[:32], dtype=torch.long).tolist()
                 num_rows = meta[0] if meta[0] >= 0 else None
                 num_cols = meta[1] if meta[1] >= 0 else None
@@ -523,7 +561,10 @@ class SQLiteDatabase(Database):
                 out_dict[key] = value

             else:
-                out_dict[key] =
+                out_dict[key] = torch.load(
+                    io.BytesIO(value),
+                    weights_only=False,
+                )

         # In case `0` exists as integer in the schema, this means that the
         # schema was passed as either a single entry or a tuple:
@@ -608,7 +649,12 @@ class RocksDatabase(Database):
         # Ensure that data is not a view of a larger tensor:
         if isinstance(row, Tensor):
             row = row.clone()
-
+        buffer = io.BytesIO()
+        torch.save(row, buffer)
+        return buffer.getvalue()

     def _deserialize(self, row: bytes) -> Any:
-        return
+        return torch.load(
+            io.BytesIO(row),
+            weights_only=False,
+        )