pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4

{pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL:
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: flit 3.9.0
+ Generator: flit 3.10.1
  Root-Is-Purelib: true
  Tag: py3-none-any

torch_geometric/__init__.py:
@@ -1,7 +1,15 @@
+ from collections import defaultdict
+
+ import torch
+ import torch_geometric.typing
+
  from ._compile import compile, is_compiling
+ from ._onnx import is_in_onnx_export
+ from .index import Index
  from .edge_index import EdgeIndex
  from .seed import seed_everything
  from .home import get_home_dir, set_home_dir
+ from .device import is_mps_available, is_xpu_available, device
  from .isinstance import is_torch_instance
  from .debug import is_debug_enabled, debug, set_debug

@@ -22,15 +30,20 @@ from .lazy_loader import LazyLoader
  contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
  graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

- __version__ = '2.6.0.dev20240319'
+ __version__ = '2.7.0.dev20250114'

  __all__ = [
+     'Index',
      'EdgeIndex',
      'seed_everything',
      'get_home_dir',
      'set_home_dir',
      'compile',
      'is_compiling',
+     'is_in_onnx_export',
+     'is_mps_available',
+     'is_xpu_available',
+     'device',
      'is_torch_instance',
      'is_debug_enabled',
      'debug',
@@ -41,3 +54,17 @@ __all__ = [
      'torch_geometric',
      '__version__',
  ]
+
+ # Serialization ###############################################################
+
+ if torch_geometric.typing.WITH_PT24:
+     torch.serialization.add_safe_globals([
+         dict,
+         list,
+         defaultdict,
+         Index,
+         torch_geometric.index.CatMetadata,
+         EdgeIndex,
+         torch_geometric.edge_index.SortOrder,
+         torch_geometric.edge_index.CatMetadata,
+     ])
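
Registering these classes as safe globals means that checkpoints containing `Index` or `EdgeIndex` tensors can be restored through PyTorch's `weights_only` loading path. A minimal sketch, assuming PyTorch >= 2.4 (the file name is illustrative):

    import torch
    from torch_geometric import EdgeIndex

    edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3))
    torch.save(edge_index, 'edge_index.pt')

    # Intended to work under `weights_only=True`, since `EdgeIndex` and its
    # metadata classes are registered via `add_safe_globals` above:
    out = torch.load('edge_index.pt', weights_only=True)
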

torch_geometric/_compile.py:
@@ -10,6 +10,8 @@ def is_compiling() -> bool:
      r"""Returns :obj:`True` in case :pytorch:`PyTorch` is compiling via
      :meth:`torch.compile`.
      """
+     if torch_geometric.typing.WITH_PT23:
+         return torch.compiler.is_compiling()
      if torch_geometric.typing.WITH_PT21:
          return torch._dynamo.is_compiling()
      return False  # pragma: no cover
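
For reference, `is_compiling()` is typically used to skip Python-side checks that would otherwise break the traced graph under `torch.compile`. A hedged sketch; the helper below is hypothetical:

    import torch
    from torch_geometric import is_compiling

    def checked_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Only run data-dependent validation in eager mode:
        if not is_compiling():
            assert x.shape == y.shape
        return x + y
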
@@ -25,10 +27,15 @@ def compile(
      This function has the same signature as :meth:`torch.compile` (see
      `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__).

+     Args:
+         model: The model to compile.
+         *args: Additional arguments of :meth:`torch.compile`.
+         **kwargs: Additional keyword arguments of :meth:`torch.compile`.
+
      .. note::
          :meth:`torch_geometric.compile` is deprecated in favor of
          :meth:`torch.compile`.
      """
      warnings.warn("'torch_geometric.compile' is deprecated in favor of "
                    "'torch.compile'")
-     return torch.compile(model, *args, **kwargs)
+     return torch.compile(model, *args, **kwargs)  # type: ignore
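
Because the wrapper now simply forwards to `torch.compile`, callers can use the PyTorch entry point directly on any PyG module; a small sketch (model and shapes are illustrative):

    import torch
    from torch_geometric.nn import GCNConv

    conv = GCNConv(16, 32)
    conv = torch.compile(conv)  # Same effect as the deprecated torch_geometric.compile(conv).

    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
    out = conv(x, edge_index)
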

torch_geometric/_onnx.py (new file):
@@ -0,0 +1,14 @@
+ import torch
+
+ from torch_geometric import is_compiling
+
+
+ def is_in_onnx_export() -> bool:
+     r"""Returns :obj:`True` in case :pytorch:`PyTorch` is exporting to ONNX via
+     :meth:`torch.onnx.export`.
+     """
+     if is_compiling():
+         return False
+     if torch.jit.is_scripting():
+         return False
+     return torch.onnx.is_in_onnx_export()
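
A possible way to use the new helper inside model code is to switch to an export-friendly branch while `torch.onnx.export` is tracing; the function below is a hypothetical illustration:

    import torch
    from torch_geometric import is_in_onnx_export

    def gather_rows(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        if is_in_onnx_export():
            # Prefer an op with a well-defined ONNX mapping during export:
            return torch.index_select(x, 0, index)
        return x[index]
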

torch_geometric/config_mixin.py (new file):
@@ -0,0 +1,113 @@
+ import inspect
+ from dataclasses import fields, is_dataclass
+ from importlib import import_module
+ from typing import Any, Dict
+
+ from torch_geometric.config_store import (
+     class_from_dataclass,
+     dataclass_from_class,
+ )
+ from torch_geometric.isinstance import is_torch_instance
+
+
+ class ConfigMixin:
+     r"""Enables a class to serialize/deserialize itself to a dataclass."""
+     def config(self) -> Any:
+         r"""Creates a serializable configuration of the class."""
+         data_cls = dataclass_from_class(self.__class__)
+         if data_cls is None:
+             raise ValueError(f"Could not find the configuration class that "
+                              f"belongs to '{self.__class__.__name__}'. Please "
+                              f"register it in the configuration store.")
+
+         kwargs: Dict[str, Any] = {}
+         for field in fields(data_cls):
+             if not hasattr(self, field.name):
+                 continue
+             kwargs[field.name] = _recursive_config(getattr(self, field.name))
+         return data_cls(**kwargs)
+
+     @classmethod
+     def from_config(cls, cfg: Any, *args: Any, **kwargs: Any) -> Any:
+         r"""Instantiates the class from a serializable configuration."""
+         if getattr(cfg, '_target_', None):
+             cls = _locate_cls(cfg._target_)
+         elif isinstance(cfg, dict) and '_target_' in cfg:
+             cls = _locate_cls(cfg['_target_'])
+
+         data_cls = cfg.__class__
+         if not is_dataclass(data_cls):
+             data_cls = dataclass_from_class(cls)
+             if data_cls is None:
+                 raise ValueError(f"Could not find the configuration class "
+                                  f"that belongs to '{cls.__name__}'. Please "
+                                  f"register it in the configuration store.")
+
+         field_names = {field.name for field in fields(data_cls)}
+         if isinstance(cfg, dict):
+             _kwargs = {k: v for k, v in cfg.items() if k in field_names}
+             cfg = data_cls(**_kwargs)
+         assert is_dataclass(cfg)
+
+         if len(args) > 0:  # Convert `*args` to `**kwargs`:
+             param_names = list(inspect.signature(cls).parameters.keys())
+             if 'args' in param_names:
+                 param_names.remove('args')
+             if 'kwargs' in param_names:
+                 param_names.remove('kwargs')
+
+             for name, arg in zip(param_names, args):
+                 kwargs[name] = arg
+
+         for key in field_names:
+             if key not in kwargs and key != '_target_':
+                 kwargs[key] = _recursive_from_config(getattr(cfg, key))
+
+         return cls(**kwargs)
+
+
+ def _recursive_config(value: Any) -> Any:
+     if isinstance(value, ConfigMixin):
+         return value.config()
+     if is_torch_instance(value, ConfigMixin):
+         return value.config()
+     if isinstance(value, (tuple, list)):
+         return [_recursive_config(v) for v in value]
+     if isinstance(value, dict):
+         return {k: _recursive_config(v) for k, v in value.items()}
+     return value
+
+
+ def _recursive_from_config(value: Any) -> Any:
+     cls: Any = None
+     if is_dataclass(value):
+         if getattr(value, '_target_', None):
+             try:
+                 cls = _locate_cls(value._target_)  # type: ignore
+             except ImportError:
+                 pass  # Keep the dataclass as it is.
+         else:
+             cls = class_from_dataclass(value.__class__)
+     elif isinstance(value, dict) and '_target_' in value:
+         cls = _locate_cls(value['_target_'])
+
+     if cls is not None and issubclass(cls, ConfigMixin):
+         return cls.from_config(value)
+     if isinstance(value, (tuple, list)):
+         return [_recursive_from_config(v) for v in value]
+     if isinstance(value, dict):
+         return {k: _recursive_from_config(v) for k, v in value.items()}
+     return value
+
+
+ def _locate_cls(qualname: str) -> Any:
+     parts = qualname.split('.')
+
+     if len(parts) <= 1:
+         raise ValueError(f"Qualified name is missing a dot (got '{qualname}')")
+
+     if any([len(part) == 0 for part in parts]):
+         raise ValueError(f"Relative imports not supported (got '{qualname}')")
+
+     module_name, cls_name = '.'.join(parts[:-1]), parts[-1]
+     return getattr(import_module(module_name), cls_name)
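
A rough usage sketch of `ConfigMixin`. It assumes the class is registered in the configuration store (here via `torch_geometric.config_store.register`) so that `dataclass_from_class` can resolve its config dataclass, and that constructor arguments are stored as same-named attributes so `config()` can read them back; `MyEncoder` is illustrative:

    import torch

    from torch_geometric.config_mixin import ConfigMixin
    from torch_geometric.config_store import register


    class MyEncoder(torch.nn.Module, ConfigMixin):
        def __init__(self, in_channels: int, out_channels: int) -> None:
            super().__init__()
            self.in_channels = in_channels    # Stored so `config()` can pick it up.
            self.out_channels = out_channels
            self.lin = torch.nn.Linear(in_channels, out_channels)


    register(MyEncoder)  # Generates and stores a config dataclass for `MyEncoder`.

    cfg = MyEncoder(16, 32).config()    # Dataclass with `in_channels=16`, `out_channels=32`.
    model = MyEncoder.from_config(cfg)  # Re-instantiates an equivalent module.
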

torch_geometric/config_store.py:
@@ -76,7 +76,7 @@ else:

          def __call__(cls, *args: Any, **kwargs: Any) -> Any:
              if cls not in cls._instances:
-                 instance = super(Singleton, cls).__call__(*args, **kwargs)
+                 instance = super().__call__(*args, **kwargs)
                  cls._instances[cls] = instance
                  return instance
              return cls._instances[cls]
@@ -162,12 +162,19 @@ def map_annotation(
      annotation: Any,
      mapping: Optional[Dict[Any, Any]] = None,
  ) -> Any:
-
      origin = getattr(annotation, '__origin__', None)
-     args = getattr(annotation, '__args__', [])
-     if origin == Union or origin == list or origin == dict:
-         annotation = copy.copy(annotation)
-         annotation.__args__ = tuple(map_annotation(a, mapping) for a in args)
+     args: Tuple[Any, ...] = getattr(annotation, '__args__', tuple())
+     if origin in {Union, list, dict, tuple}:
+         assert origin is not None
+         args = tuple(map_annotation(a, mapping) for a in args)
+         if type(annotation).__name__ == 'GenericAlias':
+             # If annotated with `list[...]` or `dict[...]` (>= Python 3.10):
+             annotation = origin[args]
+         else:
+             # If annotated with `typing.List[...]` or `typing.Dict[...]`:
+             annotation = copy.copy(annotation)
+             annotation.__args__ = args
+
          return annotation

      if mapping is not None and annotation in mapping:
@@ -231,7 +238,7 @@ def to_dataclass(
      if strict:  # Check that keys in map_args or exclude_args are present.
          keys = set() if map_args is None else set(map_args.keys())
          if exclude_args is not None:
-             keys |= set([arg for arg in exclude_args if isinstance(arg, str)])
+             keys |= {arg for arg in exclude_args if isinstance(arg, str)}
          diff = keys - set(params.keys())
          if len(diff) > 0:
              raise ValueError(f"Expected input argument(s) {diff} in "
@@ -406,13 +413,13 @@ def fill_config_store() -> None:

      # Register `torch_geometric.transforms` ###################################
      transforms = torch_geometric.transforms
-     for cls_name in set(transforms.__all__) - set([
+     for cls_name in set(transforms.__all__) - {
              'BaseTransform',
              'Compose',
              'ComposeFilters',
              'LinearTransformation',
              'AddMetaPaths',  # TODO
-     ]):
+     }:
          cls = to_dataclass(getattr(transforms, cls_name), base_cls=Transform)
          # We use an explicit additional nesting level inside each config to
          # allow for applying multiple transformations.
@@ -426,7 +433,7 @@ def fill_config_store() -> None:
          'pre_transform': (Dict[str, Transform], field(default_factory=dict)),
      }

-     for cls_name in set(datasets.__all__) - set([]):
+     for cls_name in set(datasets.__all__) - set():
          cls = to_dataclass(getattr(datasets, cls_name), base_cls=Dataset,
                             map_args=map_dataset_args,
                             exclude_args=['pre_filter'])
@@ -434,32 +441,34 @@ def fill_config_store() -> None:

      # Register `torch_geometric.models` #######################################
      models = torch_geometric.nn.models.basic_gnn
-     for cls_name in set(models.__all__) - set([]):
+     for cls_name in set(models.__all__) - set():
          cls = to_dataclass(getattr(models, cls_name), base_cls=Model)
          config_store.store(cls_name, group='model', node=cls)

      # Register `torch.optim.Optimizer` ########################################
-     for cls_name in set([
-             key for key, cls in torch.optim.__dict__.items()
+     for cls_name in {
+             key
+             for key, cls in torch.optim.__dict__.items()
              if inspect.isclass(cls) and issubclass(cls, torch.optim.Optimizer)
-     ]) - set([
+     } - {
              'Optimizer',
-     ]):
+     }:
          cls = to_dataclass(getattr(torch.optim, cls_name), base_cls=Optimizer,
                             exclude_args=['params'])
          config_store.store(cls_name, group='optimizer', node=cls)

      # Register `torch.optim.lr_scheduler` #####################################
-     for cls_name in set([
-             key for key, cls in torch.optim.lr_scheduler.__dict__.items()
+     for cls_name in {
+             key
+             for key, cls in torch.optim.lr_scheduler.__dict__.items()
              if inspect.isclass(cls)
-     ]) - set([
+     } - {
              'Optimizer',
              '_LRScheduler',
              'Counter',
              'SequentialLR',
              'ChainedScheduler',
-     ]):
+     }:
          cls = to_dataclass(getattr(torch.optim.lr_scheduler, cls_name),
                             base_cls=LRScheduler, exclude_args=['optimizer'])
          config_store.store(cls_name, group='lr_scheduler', node=cls)

torch_geometric/data/__init__.py:
@@ -1,7 +1,10 @@
  # flake8: noqa

+ import torch
+ import torch_geometric.typing
+
  from .feature_store import FeatureStore, TensorAttr
- from .graph_store import GraphStore, EdgeAttr
+ from .graph_store import GraphStore, EdgeAttr, EdgeLayout
  from .data import Data
  from .hetero_data import HeteroData
  from .batch import Batch
@@ -13,6 +16,7 @@ from .on_disk_dataset import OnDiskDataset
  from .makedirs import makedirs
  from .download import download_url, download_google_url
  from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
+ from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups

  from torch_geometric.lazy_loader import LazyLoader

@@ -24,6 +28,8 @@ data_classes = [
      'Dataset',
      'InMemoryDataset',
      'OnDiskDataset',
+     'LargeGraphIndexer',
+     'TripletLike',
  ]

  remote_backend_classes = [
@@ -47,6 +53,8 @@ helper_functions = [
      'extract_zip',
      'extract_bz2',
      'extract_gz',
+     'get_features_for_triplets',
+     "get_features_for_triplets_groups",
  ]

  __all__ = data_classes + remote_backend_classes + helper_functions
@@ -68,6 +76,21 @@ from torch_geometric.loader import DataLoader
  from torch_geometric.loader import DataListLoader
  from torch_geometric.loader import DenseDataLoader

+ # Serialization ###############################################################
+
+ if torch_geometric.typing.WITH_PT24:
+     torch.serialization.add_safe_globals([
+         Data,
+         HeteroData,
+         TemporalData,
+         ClusterData,
+         TensorAttr,
+         EdgeAttr,
+         EdgeLayout,
+     ])
+
+ # Deprecations ################################################################
+
  NeighborSampler = deprecated(  # type: ignore
      details="use 'loader.NeighborSampler' instead",
      func_name='data.NeighborSampler',

torch_geometric/data/batch.py:
@@ -118,8 +118,8 @@ class Batch(metaclass=DynamicInheritance):
          """
          if not hasattr(self, '_slice_dict'):
              raise RuntimeError(
-                 ("Cannot reconstruct 'Data' object from 'Batch' because "
-                  "'Batch' was not created via 'Batch.from_data_list()'"))
+                 "Cannot reconstruct 'Data' object from 'Batch' because "
+                 "'Batch' was not created via 'Batch.from_data_list()'")

          data = separate(
              cls=self.__class__.__bases__[-1],

torch_geometric/data/collate.py:
@@ -16,7 +16,7 @@ import torch
  from torch import Tensor

  import torch_geometric.typing
- from torch_geometric import EdgeIndex
+ from torch_geometric import EdgeIndex, Index
  from torch_geometric.data.data import BaseData
  from torch_geometric.data.storage import BaseStorage, NodeStorage
  from torch_geometric.edge_index import SortOrder
@@ -184,7 +184,8 @@ def _collate(
              return value, slices, incs

          out = None
-         if torch.utils.data.get_worker_info() is not None:
+         if (torch.utils.data.get_worker_info() is not None
+                 and not isinstance(elem, (Index, EdgeIndex))):
              # Write directly into shared memory to avoid an extra copy:
              numel = sum(value.numel() for value in values)
              if torch_geometric.typing.WITH_PT20:
@@ -203,6 +204,11 @@ def _collate(

          value = torch.cat(values, dim=cat_dim or 0, out=out)

+         if increment and isinstance(value, Index) and values[0].is_sorted:
+             # Check whether the whole `Index` is sorted:
+             if (value.diff() >= 0).all():
+                 value._is_sorted = True
+
          if increment and isinstance(value, EdgeIndex) and values[0].is_sorted:
              # Check whether the whole `EdgeIndex` is sorted by row:
              if values[0].is_sorted_by_row and (value[0].diff() >= 0).all():
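
The explicit re-check is needed because concatenating per-sample sorted index tensors does not by itself produce a sorted result; sortedness can only hold once the per-sample increments are applied, which is what the `diff() >= 0` test verifies. A toy illustration with plain tensors:

    import torch

    a = torch.tensor([0, 1, 2])  # sorted
    b = torch.tensor([0, 0, 3])  # sorted
    cat = torch.cat([a, b])      # [0, 1, 2, 0, 0, 3] -> not sorted
    print(bool((cat.diff() >= 0).all()))  # False

    # After batching, the second sample is shifted by the first sample's
    # number of nodes, which can restore monotonicity:
    cat_inc = torch.cat([a, b + 3])           # [0, 1, 2, 3, 3, 6] -> sorted
    print(bool((cat_inc.diff() >= 0).all()))  # True
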

torch_geometric/data/data.py:
@@ -31,6 +31,7 @@ from torch_geometric.data.storage import (
      NodeStorage,
  )
  from torch_geometric.deprecation import deprecated
+ from torch_geometric.index import Index
  from torch_geometric.typing import (
      EdgeTensorType,
      EdgeType,
@@ -290,13 +291,14 @@ class BaseData:
          self,
          start_time: Union[float, int],
          end_time: Union[float, int],
+         attr: str = 'time',
      ) -> Self:
          r"""Returns a snapshot of :obj:`data` to only hold events that occurred
          in period :obj:`[start_time, end_time]`.
          """
          out = copy.copy(self)
          for store in out.stores:
-             store.snapshot(start_time, end_time)
+             store.snapshot(start_time, end_time, attr)
          return out

      def up_to(self, end_time: Union[float, int]) -> Self:
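
The new `attr` argument selects which timestamp attribute defines the window instead of the default `time`. A hedged usage sketch; the attribute names are illustrative:

    import torch
    from torch_geometric.data import Data

    data = Data(
        x=torch.randn(5, 8),
        edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
        edge_time=torch.tensor([0, 2, 5, 9]),
    )

    # Keep only events whose `edge_time` lies in [0, 4]:
    snapshot = data.snapshot(0, 4, attr='edge_time')
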
@@ -644,7 +646,7 @@ class Data(BaseData, FeatureStore, GraphStore):
          return self

      def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
-         if is_sparse(value) and 'adj' in key:
+         if is_sparse(value) and ('adj' in key or 'edge_index' in key):
              return (0, 1)
          elif 'index' in key or key == 'face':
              return -1
@@ -653,9 +655,17 @@ class Data(BaseData, FeatureStore, GraphStore):

      def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
          if 'batch' in key and isinstance(value, Tensor):
+             if isinstance(value, Index):
+                 return value.get_dim_size()
              return int(value.max()) + 1
          elif 'index' in key or key == 'face':
-             return self.num_nodes
+             num_nodes = self.num_nodes
+             if num_nodes is None:
+                 raise RuntimeError(f"Unable to infer 'num_nodes' from the "
+                                    f"attribute '{key}'. Please explicitly set "
+                                    f"'num_nodes' as an attribute of 'data' to "
+                                    f"prevent this error")
+             return num_nodes
          else:
              return 0
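
The clearer error surfaces when an `*_index`-style attribute is collated on a data object whose number of nodes cannot be inferred; setting `num_nodes` explicitly avoids it. A small sketch (the attribute name is illustrative):

    import torch
    from torch_geometric.data import Batch, Data

    # No node features and no `edge_index`, so `num_nodes` cannot be inferred:
    data = Data(seed_index=torch.tensor([0, 2]))
    data.num_nodes = 4  # Without this, batching now raises the RuntimeError above.

    batch = Batch.from_data_list([data, data])
    print(batch.seed_index)  # tensor([0, 2, 4, 6]), offsets use num_nodes = 4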
 
@@ -934,16 +944,14 @@ class Data(BaseData, FeatureStore, GraphStore):
          r"""Iterates over all attributes in the data, yielding their attribute
          names and values.
          """
-         for key, value in self._store.items():
-             yield key, value
+         yield from self._store.items()

      def __call__(self, *args: str) -> Iterable:
          r"""Iterates over all attributes :obj:`*args` in the data, yielding
          their attribute names and values.
          If :obj:`*args` is not given, will iterate over all attributes.
          """
-         for key, value in self._store.items(*args):
-             yield key, value
+         yield from self._store.items(*args)

      @property
      def x(self) -> Optional[Tensor]:
@@ -1163,7 +1171,7 @@ def size_repr(key: Any, value: Any, indent: int = 0) -> str:
                 f'[{value.num_rows}, {value.num_cols}])')
      elif isinstance(value, str):
          out = f"'{value}'"
-     elif isinstance(value, Sequence):
+     elif isinstance(value, (Sequence, set)):
          out = str([len(value)])
      elif isinstance(value, Mapping) and len(value) == 0:
          out = '{}'

torch_geometric/data/database.py:
@@ -1,16 +1,15 @@
- import pickle
+ import io
  import warnings
  from abc import ABC, abstractmethod
  from dataclasses import dataclass
  from functools import cached_property
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
- from uuid import uuid4

  import torch
  from torch import Tensor
  from tqdm import tqdm

- from torch_geometric import EdgeIndex
+ from torch_geometric import EdgeIndex, Index
  from torch_geometric.edge_index import SortOrder
  from torch_geometric.utils.mixin import CastMixin

@@ -19,9 +18,17 @@ from torch_geometric.utils.mixin import CastMixin
  class TensorInfo(CastMixin):
      dtype: torch.dtype
      size: Tuple[int, ...] = (-1, )
+     is_index: bool = False
      is_edge_index: bool = False

      def __post_init__(self) -> None:
+         if self.is_index and self.is_edge_index:
+             raise ValueError("Tensor cannot be a 'Index' and 'EdgeIndex' "
+                              "tensor at the same time")
+
+         if self.is_index:
+             self.size = (-1, )
+
          if self.is_edge_index:
              self.size = (2, -1)

@@ -33,7 +40,8 @@ def maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]:
          return value
      if 'dtype' not in value:
          return value
-     if len(set(value.keys()) | {'dtype', 'size', 'is_edge_index'}) != 3:
+     valid_keys = {'dtype', 'size', 'is_index', 'is_edge_index'}
+     if len(set(value.keys()) | valid_keys) != len(valid_keys):
          return value
      return TensorInfo.cast(value)

@@ -107,11 +115,9 @@ class Database(ABC):
          r"""Connects to the database.
          Databases will automatically connect on instantiation.
          """
-         pass

      def close(self) -> None:
          r"""Closes the connection to the database."""
-         pass

      @abstractmethod
      def insert(self, index: int, data: Any) -> None:
@@ -373,8 +379,9 @@ class SQLiteDatabase(Database):

          # We create a temporary ID table to then perform an INNER JOIN.
          # This avoids having a long IN clause and guarantees sorted outputs:
-         join_table_name = f'{self.name}__join__{uuid4().hex}'
-         query = (f'CREATE TABLE {join_table_name} (\n'
+         join_table_name = f'{self.name}__join'
+         # Temporary tables do not lock the database.
+         query = (f'CREATE TEMP TABLE {join_table_name} (\n'
                   f' id INTEGER,\n'
                   f' row_id INTEGER\n'
                   f')')
@@ -452,10 +459,22 @@ class SQLiteDatabase(Database):
          if isinstance(col, Tensor) and not isinstance(schema, TensorInfo):
              self.schema[key] = schema = TensorInfo(
                  col.dtype,
+                 is_index=isinstance(col, Index),
                  is_edge_index=isinstance(col, EdgeIndex),
              )

-         if isinstance(schema, TensorInfo) and schema.is_edge_index:
+         if isinstance(schema, TensorInfo) and schema.is_index:
+             assert isinstance(col, Index)
+
+             meta = torch.tensor([
+                 col.dim_size if col.dim_size is not None else -1,
+                 col.is_sorted,
+             ], dtype=torch.long)
+
+             out.append(meta.numpy().tobytes() +
+                        col.as_tensor().numpy().tobytes())
+
+         elif isinstance(schema, TensorInfo) and schema.is_edge_index:
              assert isinstance(col, EdgeIndex)

              num_rows, num_cols = col.sparse_size()
@@ -466,7 +485,8 @@ class SQLiteDatabase(Database):
                  col.is_undirected,
              ], dtype=torch.long)

-             out.append(meta.numpy().tobytes() + col.numpy().tobytes())
+             out.append(meta.numpy().tobytes() +
+                        col.as_tensor().numpy().tobytes())

          elif isinstance(schema, TensorInfo):
              assert isinstance(col, Tensor)
@@ -476,7 +496,9 @@ class SQLiteDatabase(Database):
              out.append(col)

          else:
-             out.append(pickle.dumps(col))
+             buffer = io.BytesIO()
+             torch.save(col, buffer)
+             out.append(buffer.getvalue())

      return out

@@ -490,7 +512,23 @@ class SQLiteDatabase(Database):
      for i, (key, schema) in enumerate(self.schema.items()):
          value = row[i]

-         if isinstance(schema, TensorInfo) and schema.is_edge_index:
+         if isinstance(schema, TensorInfo) and schema.is_index:
+             meta = torch.frombuffer(value[:16], dtype=torch.long).tolist()
+             dim_size = meta[0] if meta[0] >= 0 else None
+             is_sorted = meta[1] > 0
+
+             if len(value) > 16:
+                 tensor = torch.frombuffer(value[16:], dtype=schema.dtype)
+             else:
+                 tensor = torch.empty(0, dtype=schema.dtype)
+
+             out_dict[key] = Index(
+                 tensor.view(*schema.size),
+                 dim_size=dim_size,
+                 is_sorted=is_sorted,
+             )
+
+         elif isinstance(schema, TensorInfo) and schema.is_edge_index:
              meta = torch.frombuffer(value[:32], dtype=torch.long).tolist()
              num_rows = meta[0] if meta[0] >= 0 else None
              num_cols = meta[1] if meta[1] >= 0 else None
@@ -523,7 +561,10 @@ class SQLiteDatabase(Database):
              out_dict[key] = value

          else:
-             out_dict[key] = pickle.loads(value)
+             out_dict[key] = torch.load(
+                 io.BytesIO(value),
+                 weights_only=False,
+             )

      # In case `0` exists as integer in the schema, this means that the
      # schema was passed as either a single entry or a tuple:
@@ -608,7 +649,12 @@ class RocksDatabase(Database):
          # Ensure that data is not a view of a larger tensor:
          if isinstance(row, Tensor):
              row = row.clone()
-         return pickle.dumps(row)
+         buffer = io.BytesIO()
+         torch.save(row, buffer)
+         return buffer.getvalue()

      def _deserialize(self, row: bytes) -> Any:
-         return pickle.loads(row)
+         return torch.load(
+             io.BytesIO(row),
+             weights_only=False,
+         )
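
Both database backends now round-trip values that are not covered by a tensor schema through `torch.save`/`torch.load` on an in-memory buffer instead of raw `pickle`. A standalone sketch of that round-trip, mirroring `_serialize`/`_deserialize` above:

    import io

    import torch

    value = {'y': torch.tensor([1.0, 2.0]), 'name': 'sample-0'}

    buffer = io.BytesIO()
    torch.save(value, buffer)   # Serialize to bytes.
    blob = buffer.getvalue()

    restored = torch.load(io.BytesIO(blob), weights_only=False)  # Deserialize.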