pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +8 -3
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +53 -20
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/data/dataset.py
CHANGED
@@ -235,7 +235,8 @@ class Dataset(torch.utils.data.Dataset):
|
|
235
235
|
|
236
236
|
def _process(self):
|
237
237
|
f = osp.join(self.processed_dir, 'pre_transform.pt')
|
238
|
-
if osp.exists(f) and torch.load(f) != _repr(
|
238
|
+
if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
|
239
|
+
self.pre_transform):
|
239
240
|
warnings.warn(
|
240
241
|
"The `pre_transform` argument differs from the one used in "
|
241
242
|
"the pre-processed version of this dataset. If you want to "
|
@@ -243,7 +244,8 @@ class Dataset(torch.utils.data.Dataset):
|
|
243
244
|
"`force_reload=True` explicitly to reload the dataset.")
|
244
245
|
|
245
246
|
f = osp.join(self.processed_dir, 'pre_filter.pt')
|
246
|
-
if osp.exists(f) and torch.load(f) != _repr(
|
247
|
+
if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
|
248
|
+
self.pre_filter):
|
247
249
|
warnings.warn(
|
248
250
|
"The `pre_filter` argument differs from the one used in "
|
249
251
|
"the pre-processed version of this dataset. If you want to "
|
@@ -367,15 +369,21 @@ class Dataset(torch.utils.data.Dataset):
|
|
367
369
|
from torch_geometric.data.summary import Summary
|
368
370
|
return Summary.from_dataset(self)
|
369
371
|
|
370
|
-
def print_summary(self) -> None:
|
371
|
-
r"""Prints summary statistics of the dataset to the console.
|
372
|
-
|
372
|
+
def print_summary(self, fmt: str = "psql") -> None:
|
373
|
+
r"""Prints summary statistics of the dataset to the console.
|
374
|
+
|
375
|
+
Args:
|
376
|
+
fmt (str, optional): Summary tables format. Available table formats
|
377
|
+
can be found `here <https://github.com/astanin/python-tabulate?
|
378
|
+
tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
|
379
|
+
"""
|
380
|
+
print(self.get_summary().format(fmt=fmt))
|
373
381
|
|
374
382
|
def to_datapipe(self) -> Any:
|
375
383
|
r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
|
376
384
|
|
377
385
|
The returned instance can then be used with :pyg:`PyG's` built-in
|
378
|
-
:class:`DataPipes` for
|
386
|
+
:class:`DataPipes` for batching graphs as follows:
|
379
387
|
|
380
388
|
.. code-block:: python
|
381
389
|
|
@@ -28,6 +28,7 @@ from typing import Any, List, Optional, Tuple, Union
|
|
28
28
|
|
29
29
|
import numpy as np
|
30
30
|
import torch
|
31
|
+
from torch import Tensor
|
31
32
|
|
32
33
|
from torch_geometric.typing import FeatureTensorType, NodeType
|
33
34
|
from torch_geometric.utils.mixin import CastMixin
|
@@ -73,13 +74,6 @@ class TensorAttr(CastMixin):
|
|
73
74
|
r"""Whether the :obj:`TensorAttr` has no unset fields."""
|
74
75
|
return all([self.is_set(key) for key in self.__dataclass_fields__])
|
75
76
|
|
76
|
-
def fully_specify(self) -> 'TensorAttr':
|
77
|
-
r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
|
78
|
-
for key in self.__dataclass_fields__:
|
79
|
-
if not self.is_set(key):
|
80
|
-
setattr(self, key, None)
|
81
|
-
return self
|
82
|
-
|
83
77
|
def update(self, attr: 'TensorAttr') -> 'TensorAttr':
|
84
78
|
r"""Updates an :class:`TensorAttr` with set attributes from another
|
85
79
|
:class:`TensorAttr`.
|
@@ -229,10 +223,11 @@ class AttrView(CastMixin):
|
|
229
223
|
|
230
224
|
store[group_name, attr_name]()
|
231
225
|
"""
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
226
|
+
attr = copy.copy(self._attr)
|
227
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
228
|
+
if not attr.is_set(key):
|
229
|
+
setattr(attr, key, None)
|
230
|
+
return self._store.get_tensor(attr)
|
236
231
|
|
237
232
|
def __copy__(self) -> 'AttrView':
|
238
233
|
out = self.__class__.__new__(self.__class__)
|
@@ -282,7 +277,6 @@ class FeatureStore(ABC):
|
|
282
277
|
@abstractmethod
|
283
278
|
def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
|
284
279
|
r"""To be implemented by :class:`FeatureStore` subclasses."""
|
285
|
-
pass
|
286
280
|
|
287
281
|
def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:
|
288
282
|
r"""Synchronously adds a :obj:`tensor` to the :class:`FeatureStore`.
|
@@ -308,7 +302,6 @@ class FeatureStore(ABC):
|
|
308
302
|
@abstractmethod
|
309
303
|
def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
|
310
304
|
r"""To be implemented by :class:`FeatureStore` subclasses."""
|
311
|
-
pass
|
312
305
|
|
313
306
|
def get_tensor(
|
314
307
|
self,
|
@@ -329,8 +322,6 @@ class FeatureStore(ABC):
|
|
329
322
|
Raises:
|
330
323
|
ValueError: If the input :class:`TensorAttr` is not fully
|
331
324
|
specified.
|
332
|
-
KeyError: If the tensor corresponding to the input
|
333
|
-
:class:`TensorAttr` was not found.
|
334
325
|
"""
|
335
326
|
attr = self._tensor_attr_cls.cast(*args, **kwargs)
|
336
327
|
if not attr.is_fully_specified():
|
@@ -339,9 +330,9 @@ class FeatureStore(ABC):
|
|
339
330
|
f"specifying all 'UNSET' fields.")
|
340
331
|
|
341
332
|
tensor = self._get_tensor(attr)
|
342
|
-
if
|
343
|
-
|
344
|
-
return
|
333
|
+
if convert_type:
|
334
|
+
tensor = self._to_type(attr, tensor)
|
335
|
+
return tensor
|
345
336
|
|
346
337
|
def _multi_get_tensor(
|
347
338
|
self,
|
@@ -375,8 +366,6 @@ class FeatureStore(ABC):
|
|
375
366
|
Raises:
|
376
367
|
ValueError: If any input :class:`TensorAttr` is not fully
|
377
368
|
specified.
|
378
|
-
KeyError: If any of the tensors corresponding to the input
|
379
|
-
:class:`TensorAttr` was not found.
|
380
369
|
"""
|
381
370
|
attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs]
|
382
371
|
bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()]
|
@@ -387,20 +376,16 @@ class FeatureStore(ABC):
|
|
387
376
|
f"'UNSET' fields")
|
388
377
|
|
389
378
|
tensors = self._multi_get_tensor(attrs)
|
390
|
-
if
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
return
|
396
|
-
self._to_type(attr, tensor) if convert_type else tensor
|
397
|
-
for attr, tensor in zip(attrs, tensors)
|
398
|
-
]
|
379
|
+
if convert_type:
|
380
|
+
tensors = [
|
381
|
+
self._to_type(attr, tensor)
|
382
|
+
for attr, tensor in zip(attrs, tensors)
|
383
|
+
]
|
384
|
+
return tensors
|
399
385
|
|
400
386
|
@abstractmethod
|
401
387
|
def _remove_tensor(self, attr: TensorAttr) -> bool:
|
402
388
|
r"""To be implemented by :obj:`FeatureStore` subclasses."""
|
403
|
-
pass
|
404
389
|
|
405
390
|
def remove_tensor(self, *args, **kwargs) -> bool:
|
406
391
|
r"""Removes a tensor from the :class:`FeatureStore`.
|
@@ -458,7 +443,6 @@ class FeatureStore(ABC):
|
|
458
443
|
@abstractmethod
|
459
444
|
def get_all_tensor_attrs(self) -> List[TensorAttr]:
|
460
445
|
r"""Returns all registered tensor attributes."""
|
461
|
-
pass
|
462
446
|
|
463
447
|
# `AttrView` methods ######################################################
|
464
448
|
|
@@ -476,11 +460,9 @@ class FeatureStore(ABC):
|
|
476
460
|
attr: TensorAttr,
|
477
461
|
tensor: FeatureTensorType,
|
478
462
|
) -> FeatureTensorType:
|
479
|
-
if
|
480
|
-
and isinstance(tensor, np.ndarray)):
|
463
|
+
if isinstance(attr.index, Tensor) and isinstance(tensor, np.ndarray):
|
481
464
|
return torch.from_numpy(tensor)
|
482
|
-
if
|
483
|
-
and isinstance(tensor, torch.Tensor)):
|
465
|
+
if isinstance(attr.index, np.ndarray) and isinstance(tensor, Tensor):
|
484
466
|
return tensor.detach().cpu().numpy()
|
485
467
|
return tensor
|
486
468
|
|
@@ -491,9 +473,7 @@ class FeatureStore(ABC):
|
|
491
473
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
492
474
|
# object:
|
493
475
|
key = self._tensor_attr_cls.cast(key)
|
494
|
-
|
495
|
-
# sense to work with a view here:
|
496
|
-
key.fully_specify()
|
476
|
+
assert key.is_fully_specified()
|
497
477
|
self.put_tensor(value, key)
|
498
478
|
|
499
479
|
def __getitem__(self, key: TensorAttr) -> Any:
|
@@ -515,13 +495,16 @@ class FeatureStore(ABC):
|
|
515
495
|
# If the view is not fully-specified, return a :class:`AttrView`:
|
516
496
|
return self.view(attr)
|
517
497
|
|
518
|
-
def __delitem__(self,
|
498
|
+
def __delitem__(self, attr: TensorAttr):
|
519
499
|
r"""Supports :obj:`del store[tensor_attr]`."""
|
520
500
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
521
501
|
# object:
|
522
|
-
|
523
|
-
|
524
|
-
|
502
|
+
attr = self._tensor_attr_cls.cast(attr)
|
503
|
+
attr = copy.copy(attr)
|
504
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
505
|
+
if not attr.is_set(key):
|
506
|
+
setattr(attr, key, None)
|
507
|
+
self.remove_tensor(attr)
|
525
508
|
|
526
509
|
def __iter__(self):
|
527
510
|
raise NotImplementedError
|
@@ -25,10 +25,10 @@ from typing import Any, Dict, List, Optional, Tuple
|
|
25
25
|
|
26
26
|
from torch import Tensor
|
27
27
|
|
28
|
+
from torch_geometric.index import index2ptr, ptr2index
|
28
29
|
from torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor
|
29
30
|
from torch_geometric.utils import index_sort
|
30
31
|
from torch_geometric.utils.mixin import CastMixin
|
31
|
-
from torch_geometric.utils.sparse import index2ptr, ptr2index
|
32
32
|
|
33
33
|
# The output of converting between two types in the GraphStore is a Tuple of
|
34
34
|
# dictionaries: row, col, and perm. The dictionaries are keyed by the edge
|
@@ -116,7 +116,6 @@ class GraphStore(ABC):
|
|
116
116
|
def _put_edge_index(self, edge_index: EdgeTensorType,
|
117
117
|
edge_attr: EdgeAttr) -> bool:
|
118
118
|
r"""To be implemented by :class:`GraphStore` subclasses."""
|
119
|
-
pass
|
120
119
|
|
121
120
|
def put_edge_index(self, edge_index: EdgeTensorType, *args,
|
122
121
|
**kwargs) -> bool:
|
@@ -137,7 +136,6 @@ class GraphStore(ABC):
|
|
137
136
|
@abstractmethod
|
138
137
|
def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
|
139
138
|
r"""To be implemented by :class:`GraphStore` subclasses."""
|
140
|
-
pass
|
141
139
|
|
142
140
|
def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
|
143
141
|
r"""Synchronously obtains an :obj:`edge_index` tuple from the
|
@@ -160,7 +158,6 @@ class GraphStore(ABC):
|
|
160
158
|
@abstractmethod
|
161
159
|
def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
|
162
160
|
r"""To be implemented by :class:`GraphStore` subclasses."""
|
163
|
-
pass
|
164
161
|
|
165
162
|
def remove_edge_index(self, *args, **kwargs) -> bool:
|
166
163
|
r"""Synchronously deletes an :obj:`edge_index` tuple from the
|
@@ -177,7 +174,6 @@ class GraphStore(ABC):
|
|
177
174
|
@abstractmethod
|
178
175
|
def get_all_edge_attrs(self) -> List[EdgeAttr]:
|
179
176
|
r"""Returns all registered edge attributes."""
|
180
|
-
pass
|
181
177
|
|
182
178
|
# Layout Conversion #######################################################
|
183
179
|
|
@@ -10,6 +10,7 @@ import torch
|
|
10
10
|
from torch import Tensor
|
11
11
|
from typing_extensions import Self
|
12
12
|
|
13
|
+
from torch_geometric import Index
|
13
14
|
from torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr
|
14
15
|
from torch_geometric.data.data import BaseData, Data, size_repr, warn_or_raise
|
15
16
|
from torch_geometric.data.graph_store import EdgeLayout
|
@@ -36,6 +37,8 @@ from torch_geometric.utils import (
|
|
36
37
|
|
37
38
|
NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]
|
38
39
|
|
40
|
+
_DISPLAYED_TYPE_NAME_WARNING: bool = False
|
41
|
+
|
39
42
|
|
40
43
|
class HeteroData(BaseData, FeatureStore, GraphStore):
|
41
44
|
r"""A data object describing a heterogeneous graph, holding multiple node
|
@@ -334,7 +337,7 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
334
337
|
def __cat_dim__(self, key: str, value: Any,
|
335
338
|
store: Optional[NodeOrEdgeStorage] = None, *args,
|
336
339
|
**kwargs) -> Any:
|
337
|
-
if is_sparse(value) and 'adj' in key:
|
340
|
+
if is_sparse(value) and ('adj' in key or 'edge_index' in key):
|
338
341
|
return (0, 1)
|
339
342
|
elif isinstance(store, EdgeStorage) and 'index' in key:
|
340
343
|
return -1
|
@@ -344,6 +347,8 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
344
347
|
store: Optional[NodeOrEdgeStorage] = None, *args,
|
345
348
|
**kwargs) -> Any:
|
346
349
|
if 'batch' in key and isinstance(value, Tensor):
|
350
|
+
if isinstance(value, Index):
|
351
|
+
return value.get_dim_size()
|
347
352
|
return int(value.max()) + 1
|
348
353
|
elif isinstance(store, EdgeStorage) and 'index' in key:
|
349
354
|
return torch.tensor(store.size()).view(2, 1)
|
@@ -562,11 +567,15 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
562
567
|
return mapping
|
563
568
|
|
564
569
|
def _check_type_name(self, name: str):
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
f"
|
570
|
+
global _DISPLAYED_TYPE_NAME_WARNING
|
571
|
+
if not _DISPLAYED_TYPE_NAME_WARNING and '__' in name:
|
572
|
+
_DISPLAYED_TYPE_NAME_WARNING = True
|
573
|
+
warnings.warn(f"There exist type names in the "
|
574
|
+
f"'{self.__class__.__name__}' object that contain "
|
575
|
+
f"double underscores '__' (e.g., '{name}'). This "
|
576
|
+
f"may lead to unexpected behavior. To avoid any "
|
577
|
+
f"issues, ensure that your type names only contain "
|
578
|
+
f"single underscores.")
|
570
579
|
|
571
580
|
def get_node_store(self, key: NodeType) -> NodeStorage:
|
572
581
|
r"""Gets the :class:`~torch_geometric.data.storage.NodeStorage` object
|
@@ -771,8 +780,8 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
771
780
|
for edge_type in self.edge_types:
|
772
781
|
if edge_type not in edge_types:
|
773
782
|
del data[edge_type]
|
774
|
-
node_types =
|
775
|
-
node_types |=
|
783
|
+
node_types = {e[0] for e in edge_types}
|
784
|
+
node_types |= {e[-1] for e in edge_types}
|
776
785
|
for node_type in self.node_types:
|
777
786
|
if node_type not in node_types:
|
778
787
|
del data[node_type]
|
@@ -878,7 +887,7 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
878
887
|
if len(sizes) != len(stores):
|
879
888
|
continue
|
880
889
|
# The attributes needs to have the same number of dimensions:
|
881
|
-
lengths =
|
890
|
+
lengths = {len(size) for size in sizes}
|
882
891
|
if len(lengths) != 1:
|
883
892
|
continue
|
884
893
|
# The attributes needs to have the same size in all dimensions:
|
@@ -347,10 +347,8 @@ class InMemoryDataset(Dataset):
|
|
347
347
|
def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
|
348
348
|
if isinstance(node, Mapping):
|
349
349
|
for key, value in node.items():
|
350
|
-
|
351
|
-
yield inner_key, inner_value
|
350
|
+
yield from nested_iter(value)
|
352
351
|
elif isinstance(node, Sequence):
|
353
|
-
|
354
|
-
yield i, inner_value
|
352
|
+
yield from enumerate(node)
|
355
353
|
else:
|
356
354
|
yield None, node
|