pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -235,7 +235,8 @@ class Dataset(torch.utils.data.Dataset):
235
235
 
236
236
  def _process(self):
237
237
  f = osp.join(self.processed_dir, 'pre_transform.pt')
238
- if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
238
+ if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
239
+ self.pre_transform):
239
240
  warnings.warn(
240
241
  "The `pre_transform` argument differs from the one used in "
241
242
  "the pre-processed version of this dataset. If you want to "
@@ -243,7 +244,8 @@ class Dataset(torch.utils.data.Dataset):
243
244
  "`force_reload=True` explicitly to reload the dataset.")
244
245
 
245
246
  f = osp.join(self.processed_dir, 'pre_filter.pt')
246
- if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
247
+ if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
248
+ self.pre_filter):
247
249
  warnings.warn(
248
250
  "The `pre_filter` argument differs from the one used in "
249
251
  "the pre-processed version of this dataset. If you want to "
@@ -367,15 +369,21 @@ class Dataset(torch.utils.data.Dataset):
367
369
  from torch_geometric.data.summary import Summary
368
370
  return Summary.from_dataset(self)
369
371
 
370
- def print_summary(self) -> None:
371
- r"""Prints summary statistics of the dataset to the console."""
372
- print(str(self.get_summary()))
372
+ def print_summary(self, fmt: str = "psql") -> None:
373
+ r"""Prints summary statistics of the dataset to the console.
374
+
375
+ Args:
376
+ fmt (str, optional): Summary tables format. Available table formats
377
+ can be found `here <https://github.com/astanin/python-tabulate?
378
+ tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
379
+ """
380
+ print(self.get_summary().format(fmt=fmt))
373
381
 
374
382
  def to_datapipe(self) -> Any:
375
383
  r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
376
384
 
377
385
  The returned instance can then be used with :pyg:`PyG's` built-in
378
- :class:`DataPipes` for baching graphs as follows:
386
+ :class:`DataPipes` for batching graphs as follows:
379
387
 
380
388
  .. code-block:: python
381
389
 
@@ -28,6 +28,7 @@ from typing import Any, List, Optional, Tuple, Union
28
28
 
29
29
  import numpy as np
30
30
  import torch
31
+ from torch import Tensor
31
32
 
32
33
  from torch_geometric.typing import FeatureTensorType, NodeType
33
34
  from torch_geometric.utils.mixin import CastMixin
@@ -73,13 +74,6 @@ class TensorAttr(CastMixin):
73
74
  r"""Whether the :obj:`TensorAttr` has no unset fields."""
74
75
  return all([self.is_set(key) for key in self.__dataclass_fields__])
75
76
 
76
- def fully_specify(self) -> 'TensorAttr':
77
- r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
78
- for key in self.__dataclass_fields__:
79
- if not self.is_set(key):
80
- setattr(self, key, None)
81
- return self
82
-
83
77
  def update(self, attr: 'TensorAttr') -> 'TensorAttr':
84
78
  r"""Updates an :class:`TensorAttr` with set attributes from another
85
79
  :class:`TensorAttr`.
@@ -229,10 +223,11 @@ class AttrView(CastMixin):
229
223
 
230
224
  store[group_name, attr_name]()
231
225
  """
232
- # Set all UNSET values to None:
233
- out = copy.copy(self)
234
- out._attr.fully_specify()
235
- return out._store.get_tensor(out._attr)
226
+ attr = copy.copy(self._attr)
227
+ for key in attr.__dataclass_fields__: # Set all UNSET values to None.
228
+ if not attr.is_set(key):
229
+ setattr(attr, key, None)
230
+ return self._store.get_tensor(attr)
236
231
 
237
232
  def __copy__(self) -> 'AttrView':
238
233
  out = self.__class__.__new__(self.__class__)
@@ -282,7 +277,6 @@ class FeatureStore(ABC):
282
277
  @abstractmethod
283
278
  def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
284
279
  r"""To be implemented by :class:`FeatureStore` subclasses."""
285
- pass
286
280
 
287
281
  def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:
288
282
  r"""Synchronously adds a :obj:`tensor` to the :class:`FeatureStore`.
@@ -308,7 +302,6 @@ class FeatureStore(ABC):
308
302
  @abstractmethod
309
303
  def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
310
304
  r"""To be implemented by :class:`FeatureStore` subclasses."""
311
- pass
312
305
 
313
306
  def get_tensor(
314
307
  self,
@@ -329,8 +322,6 @@ class FeatureStore(ABC):
329
322
  Raises:
330
323
  ValueError: If the input :class:`TensorAttr` is not fully
331
324
  specified.
332
- KeyError: If the tensor corresponding to the input
333
- :class:`TensorAttr` was not found.
334
325
  """
335
326
  attr = self._tensor_attr_cls.cast(*args, **kwargs)
336
327
  if not attr.is_fully_specified():
@@ -339,9 +330,9 @@ class FeatureStore(ABC):
339
330
  f"specifying all 'UNSET' fields.")
340
331
 
341
332
  tensor = self._get_tensor(attr)
342
- if tensor is None:
343
- raise KeyError(f"A tensor corresponding to '{attr}' was not found")
344
- return self._to_type(attr, tensor) if convert_type else tensor
333
+ if convert_type:
334
+ tensor = self._to_type(attr, tensor)
335
+ return tensor
345
336
 
346
337
  def _multi_get_tensor(
347
338
  self,
@@ -375,8 +366,6 @@ class FeatureStore(ABC):
375
366
  Raises:
376
367
  ValueError: If any input :class:`TensorAttr` is not fully
377
368
  specified.
378
- KeyError: If any of the tensors corresponding to the input
379
- :class:`TensorAttr` was not found.
380
369
  """
381
370
  attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs]
382
371
  bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()]
@@ -387,20 +376,16 @@ class FeatureStore(ABC):
387
376
  f"'UNSET' fields")
388
377
 
389
378
  tensors = self._multi_get_tensor(attrs)
390
- if any(v is None for v in tensors):
391
- bad_attrs = [attrs[i] for i, v in enumerate(tensors) if v is None]
392
- raise KeyError(f"Tensors corresponding to attributes "
393
- f"'{bad_attrs}' were not found")
394
-
395
- return [
396
- self._to_type(attr, tensor) if convert_type else tensor
397
- for attr, tensor in zip(attrs, tensors)
398
- ]
379
+ if convert_type:
380
+ tensors = [
381
+ self._to_type(attr, tensor)
382
+ for attr, tensor in zip(attrs, tensors)
383
+ ]
384
+ return tensors
399
385
 
400
386
  @abstractmethod
401
387
  def _remove_tensor(self, attr: TensorAttr) -> bool:
402
388
  r"""To be implemented by :obj:`FeatureStore` subclasses."""
403
- pass
404
389
 
405
390
  def remove_tensor(self, *args, **kwargs) -> bool:
406
391
  r"""Removes a tensor from the :class:`FeatureStore`.
@@ -458,7 +443,6 @@ class FeatureStore(ABC):
458
443
  @abstractmethod
459
444
  def get_all_tensor_attrs(self) -> List[TensorAttr]:
460
445
  r"""Returns all registered tensor attributes."""
461
- pass
462
446
 
463
447
  # `AttrView` methods ######################################################
464
448
 
@@ -476,11 +460,9 @@ class FeatureStore(ABC):
476
460
  attr: TensorAttr,
477
461
  tensor: FeatureTensorType,
478
462
  ) -> FeatureTensorType:
479
- if (isinstance(attr.index, torch.Tensor)
480
- and isinstance(tensor, np.ndarray)):
463
+ if isinstance(attr.index, Tensor) and isinstance(tensor, np.ndarray):
481
464
  return torch.from_numpy(tensor)
482
- if (isinstance(attr.index, np.ndarray)
483
- and isinstance(tensor, torch.Tensor)):
465
+ if isinstance(attr.index, np.ndarray) and isinstance(tensor, Tensor):
484
466
  return tensor.detach().cpu().numpy()
485
467
  return tensor
486
468
 
@@ -491,9 +473,7 @@ class FeatureStore(ABC):
491
473
  # CastMixin will handle the case of key being a tuple or TensorAttr
492
474
  # object:
493
475
  key = self._tensor_attr_cls.cast(key)
494
- # We need to fully-specify the key for __setitem__ as it does not make
495
- # sense to work with a view here:
496
- key.fully_specify()
476
+ assert key.is_fully_specified()
497
477
  self.put_tensor(value, key)
498
478
 
499
479
  def __getitem__(self, key: TensorAttr) -> Any:
@@ -515,13 +495,16 @@ class FeatureStore(ABC):
515
495
  # If the view is not fully-specified, return a :class:`AttrView`:
516
496
  return self.view(attr)
517
497
 
518
- def __delitem__(self, key: TensorAttr):
498
+ def __delitem__(self, attr: TensorAttr):
519
499
  r"""Supports :obj:`del store[tensor_attr]`."""
520
500
  # CastMixin will handle the case of key being a tuple or TensorAttr
521
501
  # object:
522
- key = self._tensor_attr_cls.cast(key)
523
- key.fully_specify()
524
- self.remove_tensor(key)
502
+ attr = self._tensor_attr_cls.cast(attr)
503
+ attr = copy.copy(attr)
504
+ for key in attr.__dataclass_fields__: # Set all UNSET values to None.
505
+ if not attr.is_set(key):
506
+ setattr(attr, key, None)
507
+ self.remove_tensor(attr)
525
508
 
526
509
  def __iter__(self):
527
510
  raise NotImplementedError
@@ -25,10 +25,10 @@ from typing import Any, Dict, List, Optional, Tuple
25
25
 
26
26
  from torch import Tensor
27
27
 
28
+ from torch_geometric.index import index2ptr, ptr2index
28
29
  from torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor
29
30
  from torch_geometric.utils import index_sort
30
31
  from torch_geometric.utils.mixin import CastMixin
31
- from torch_geometric.utils.sparse import index2ptr, ptr2index
32
32
 
33
33
  # The output of converting between two types in the GraphStore is a Tuple of
34
34
  # dictionaries: row, col, and perm. The dictionaries are keyed by the edge
@@ -116,7 +116,6 @@ class GraphStore(ABC):
116
116
  def _put_edge_index(self, edge_index: EdgeTensorType,
117
117
  edge_attr: EdgeAttr) -> bool:
118
118
  r"""To be implemented by :class:`GraphStore` subclasses."""
119
- pass
120
119
 
121
120
  def put_edge_index(self, edge_index: EdgeTensorType, *args,
122
121
  **kwargs) -> bool:
@@ -137,7 +136,6 @@ class GraphStore(ABC):
137
136
  @abstractmethod
138
137
  def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
139
138
  r"""To be implemented by :class:`GraphStore` subclasses."""
140
- pass
141
139
 
142
140
  def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
143
141
  r"""Synchronously obtains an :obj:`edge_index` tuple from the
@@ -160,7 +158,6 @@ class GraphStore(ABC):
160
158
  @abstractmethod
161
159
  def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
162
160
  r"""To be implemented by :class:`GraphStore` subclasses."""
163
- pass
164
161
 
165
162
  def remove_edge_index(self, *args, **kwargs) -> bool:
166
163
  r"""Synchronously deletes an :obj:`edge_index` tuple from the
@@ -177,7 +174,6 @@ class GraphStore(ABC):
177
174
  @abstractmethod
178
175
  def get_all_edge_attrs(self) -> List[EdgeAttr]:
179
176
  r"""Returns all registered edge attributes."""
180
- pass
181
177
 
182
178
  # Layout Conversion #######################################################
183
179
 
@@ -10,6 +10,7 @@ import torch
10
10
  from torch import Tensor
11
11
  from typing_extensions import Self
12
12
 
13
+ from torch_geometric import Index
13
14
  from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr
14
15
  from torch_geometric.data.data import BaseData, Data, size_repr, warn_or_raise
15
16
  from torch_geometric.data.graph_store import EdgeLayout
@@ -36,6 +37,8 @@ from torch_geometric.utils import (
36
37
 
37
38
  NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]
38
39
 
40
+ _DISPLAYED_TYPE_NAME_WARNING: bool = False
41
+
39
42
 
40
43
  class HeteroData(BaseData, FeatureStore, GraphStore):
41
44
  r"""A data object describing a heterogeneous graph, holding multiple node
@@ -334,7 +337,7 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
334
337
  def __cat_dim__(self, key: str, value: Any,
335
338
  store: Optional[NodeOrEdgeStorage] = None, *args,
336
339
  **kwargs) -> Any:
337
- if is_sparse(value) and 'adj' in key:
340
+ if is_sparse(value) and ('adj' in key or 'edge_index' in key):
338
341
  return (0, 1)
339
342
  elif isinstance(store, EdgeStorage) and 'index' in key:
340
343
  return -1
@@ -344,6 +347,8 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
344
347
  store: Optional[NodeOrEdgeStorage] = None, *args,
345
348
  **kwargs) -> Any:
346
349
  if 'batch' in key and isinstance(value, Tensor):
350
+ if isinstance(value, Index):
351
+ return value.get_dim_size()
347
352
  return int(value.max()) + 1
348
353
  elif isinstance(store, EdgeStorage) and 'index' in key:
349
354
  return torch.tensor(store.size()).view(2, 1)
@@ -562,11 +567,15 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
562
567
  return mapping
563
568
 
564
569
  def _check_type_name(self, name: str):
565
- if '__' in name:
566
- warnings.warn(f"The type '{name}' contains double underscores "
567
- f"('__') which may lead to unexpected behavior. "
568
- f"To avoid any issues, ensure that your type names "
569
- f"only contain single underscores.")
570
+ global _DISPLAYED_TYPE_NAME_WARNING
571
+ if not _DISPLAYED_TYPE_NAME_WARNING and '__' in name:
572
+ _DISPLAYED_TYPE_NAME_WARNING = True
573
+ warnings.warn(f"There exist type names in the "
574
+ f"'{self.__class__.__name__}' object that contain "
575
+ f"double underscores '__' (e.g., '{name}'). This "
576
+ f"may lead to unexpected behavior. To avoid any "
577
+ f"issues, ensure that your type names only contain "
578
+ f"single underscores.")
570
579
 
571
580
  def get_node_store(self, key: NodeType) -> NodeStorage:
572
581
  r"""Gets the :class:`~torch_geometric.data.storage.NodeStorage` object
@@ -771,8 +780,8 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
771
780
  for edge_type in self.edge_types:
772
781
  if edge_type not in edge_types:
773
782
  del data[edge_type]
774
- node_types = set(e[0] for e in edge_types)
775
- node_types |= set(e[-1] for e in edge_types)
783
+ node_types = {e[0] for e in edge_types}
784
+ node_types |= {e[-1] for e in edge_types}
776
785
  for node_type in self.node_types:
777
786
  if node_type not in node_types:
778
787
  del data[node_type]
@@ -878,7 +887,7 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
878
887
  if len(sizes) != len(stores):
879
888
  continue
880
889
  # The attributes needs to have the same number of dimensions:
881
- lengths = set([len(size) for size in sizes])
890
+ lengths = {len(size) for size in sizes}
882
891
  if len(lengths) != 1:
883
892
  continue
884
893
  # The attributes needs to have the same size in all dimensions:
@@ -347,10 +347,8 @@ class InMemoryDataset(Dataset):
347
347
  def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
348
348
  if isinstance(node, Mapping):
349
349
  for key, value in node.items():
350
- for inner_key, inner_value in nested_iter(value):
351
- yield inner_key, inner_value
350
+ yield from nested_iter(value)
352
351
  elif isinstance(node, Sequence):
353
- for i, inner_value in enumerate(node):
354
- yield i, inner_value
352
+ yield from enumerate(node)
355
353
  else:
356
354
  yield None, node