pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_to_dense_batch.py +2 -2
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
|
@@ -11,7 +11,7 @@ This particular feature store abstraction makes a few key assumptions:
|
|
|
11
11
|
* A feature can be uniquely identified from any associated attributes specified
|
|
12
12
|
in `TensorAttr`.
|
|
13
13
|
|
|
14
|
-
It is the job of a feature store
|
|
14
|
+
It is the job of a feature store implementer class to handle these assumptions
|
|
15
15
|
properly. For example, a simple in-memory feature store implementation may
|
|
16
16
|
concatenate all metadata values with a feature index and use this as a unique
|
|
17
17
|
index in a KV store. More complicated implementations may choose to partition
|
|
@@ -74,13 +74,6 @@ class TensorAttr(CastMixin):
|
|
|
74
74
|
r"""Whether the :obj:`TensorAttr` has no unset fields."""
|
|
75
75
|
return all([self.is_set(key) for key in self.__dataclass_fields__])
|
|
76
76
|
|
|
77
|
-
def fully_specify(self) -> 'TensorAttr':
|
|
78
|
-
r"""Sets all :obj:`UNSET` fields to :obj:`None`."""
|
|
79
|
-
for key in self.__dataclass_fields__:
|
|
80
|
-
if not self.is_set(key):
|
|
81
|
-
setattr(self, key, None)
|
|
82
|
-
return self
|
|
83
|
-
|
|
84
77
|
def update(self, attr: 'TensorAttr') -> 'TensorAttr':
|
|
85
78
|
r"""Updates an :class:`TensorAttr` with set attributes from another
|
|
86
79
|
:class:`TensorAttr`.
|
|
@@ -230,10 +223,11 @@ class AttrView(CastMixin):
|
|
|
230
223
|
|
|
231
224
|
store[group_name, attr_name]()
|
|
232
225
|
"""
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
226
|
+
attr = copy.copy(self._attr)
|
|
227
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
|
228
|
+
if not attr.is_set(key):
|
|
229
|
+
setattr(attr, key, None)
|
|
230
|
+
return self._store.get_tensor(attr)
|
|
237
231
|
|
|
238
232
|
def __copy__(self) -> 'AttrView':
|
|
239
233
|
out = self.__class__.__new__(self.__class__)
|
|
@@ -358,7 +352,7 @@ class FeatureStore(ABC):
|
|
|
358
352
|
|
|
359
353
|
.. note::
|
|
360
354
|
The default implementation simply iterates over all calls to
|
|
361
|
-
:meth:`get_tensor`.
|
|
355
|
+
:meth:`get_tensor`. Implementer classes that can provide
|
|
362
356
|
additional, more performant functionality are recommended to
|
|
363
357
|
to override this method.
|
|
364
358
|
|
|
@@ -415,10 +409,10 @@ class FeatureStore(ABC):
|
|
|
415
409
|
def update_tensor(self, tensor: FeatureTensorType, *args,
|
|
416
410
|
**kwargs) -> bool:
|
|
417
411
|
r"""Updates a :obj:`tensor` in the :class:`FeatureStore` with a new
|
|
418
|
-
value. Returns whether the update was
|
|
412
|
+
value. Returns whether the update was successful.
|
|
419
413
|
|
|
420
414
|
.. note::
|
|
421
|
-
|
|
415
|
+
Implementer classes can choose to define more efficient update
|
|
422
416
|
methods; the default performs a removal and insertion.
|
|
423
417
|
|
|
424
418
|
Args:
|
|
@@ -479,9 +473,7 @@ class FeatureStore(ABC):
|
|
|
479
473
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
|
480
474
|
# object:
|
|
481
475
|
key = self._tensor_attr_cls.cast(key)
|
|
482
|
-
|
|
483
|
-
# sense to work with a view here:
|
|
484
|
-
key.fully_specify()
|
|
476
|
+
assert key.is_fully_specified()
|
|
485
477
|
self.put_tensor(value, key)
|
|
486
478
|
|
|
487
479
|
def __getitem__(self, key: TensorAttr) -> Any:
|
|
@@ -503,13 +495,16 @@ class FeatureStore(ABC):
|
|
|
503
495
|
# If the view is not fully-specified, return a :class:`AttrView`:
|
|
504
496
|
return self.view(attr)
|
|
505
497
|
|
|
506
|
-
def __delitem__(self,
|
|
498
|
+
def __delitem__(self, attr: TensorAttr):
|
|
507
499
|
r"""Supports :obj:`del store[tensor_attr]`."""
|
|
508
500
|
# CastMixin will handle the case of key being a tuple or TensorAttr
|
|
509
501
|
# object:
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
502
|
+
attr = self._tensor_attr_cls.cast(attr)
|
|
503
|
+
attr = copy.copy(attr)
|
|
504
|
+
for key in attr.__dataclass_fields__: # Set all UNSET values to None.
|
|
505
|
+
if not attr.is_set(key):
|
|
506
|
+
setattr(attr, key, None)
|
|
507
|
+
self.remove_tensor(attr)
|
|
513
508
|
|
|
514
509
|
def __iter__(self):
|
|
515
510
|
raise NotImplementedError
|
|
@@ -10,7 +10,7 @@ This particular graph store abstraction makes a few key assumptions:
|
|
|
10
10
|
support dynamic modification of edge indices once they have been inserted
|
|
11
11
|
into the graph store.
|
|
12
12
|
|
|
13
|
-
It is the job of a graph store
|
|
13
|
+
It is the job of a graph store implementer class to handle these assumptions
|
|
14
14
|
properly. For example, a simple in-memory graph store implementation may
|
|
15
15
|
concatenate all metadata values with an edge index and use this as a unique
|
|
16
16
|
index in a KV store. More complicated implementations may choose to partition
|
|
@@ -261,7 +261,8 @@ class GraphStore(ABC):
|
|
|
261
261
|
col = ptr2index(col)
|
|
262
262
|
|
|
263
263
|
if attr.layout != EdgeLayout.CSR: # COO->CSR
|
|
264
|
-
num_rows = attr.size[0] if attr.size else int(
|
|
264
|
+
num_rows = attr.size[0] if attr.size is not None else int(
|
|
265
|
+
row.max()) + 1
|
|
265
266
|
row, perm = index_sort(row, max_value=num_rows)
|
|
266
267
|
col = col[perm]
|
|
267
268
|
row = index2ptr(row, num_rows)
|
|
@@ -282,6 +282,21 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
282
282
|
r"""Returns a list of edge type and edge storage pairs."""
|
|
283
283
|
return list(self._edge_store_dict.items())
|
|
284
284
|
|
|
285
|
+
@property
|
|
286
|
+
def input_type(self) -> Optional[Union[NodeType, EdgeType]]:
|
|
287
|
+
r"""Returns the seed/input node/edge type of the graph in case it
|
|
288
|
+
refers to a sampled subgraph, *e.g.*, obtained via
|
|
289
|
+
:class:`~torch_geometric.loader.NeighborLoader` or
|
|
290
|
+
:class:`~torch_geometric.loader.LinkNeighborLoader`.
|
|
291
|
+
"""
|
|
292
|
+
for node_type, store in self.node_items():
|
|
293
|
+
if hasattr(store, 'input_id'):
|
|
294
|
+
return node_type
|
|
295
|
+
for edge_type, store in self.edge_items():
|
|
296
|
+
if hasattr(store, 'input_id'):
|
|
297
|
+
return edge_type
|
|
298
|
+
return None
|
|
299
|
+
|
|
285
300
|
def to_dict(self) -> Dict[str, Any]:
|
|
286
301
|
out_dict: Dict[str, Any] = {}
|
|
287
302
|
out_dict['_global_store'] = self._global_store.to_dict()
|
|
@@ -472,6 +487,77 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
472
487
|
|
|
473
488
|
return status
|
|
474
489
|
|
|
490
|
+
def connected_components(self) -> List[Self]:
|
|
491
|
+
r"""Extracts connected components of the heterogeneous graph using
|
|
492
|
+
a union-find algorithm. The components are returned as a list of
|
|
493
|
+
:class:`~torch_geometric.data.HeteroData` objects.
|
|
494
|
+
|
|
495
|
+
.. code-block::
|
|
496
|
+
|
|
497
|
+
data = HeteroData()
|
|
498
|
+
data["red"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
|
|
499
|
+
data["blue"].x = torch.tensor([[5.0], [6.0]])
|
|
500
|
+
data["red", "to", "red"].edge_index = torch.tensor(
|
|
501
|
+
[[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
components = data.connected_components()
|
|
505
|
+
print(len(components))
|
|
506
|
+
>>> 4
|
|
507
|
+
|
|
508
|
+
print(components[0])
|
|
509
|
+
>>> HeteroData(
|
|
510
|
+
red={x: tensor([[1.], [2.]])},
|
|
511
|
+
blue={x: tensor([[]])},
|
|
512
|
+
red, to, red={edge_index: tensor([[0, 1], [1, 0]])}
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
Returns:
|
|
516
|
+
List[HeteroData]: A list of connected components.
|
|
517
|
+
"""
|
|
518
|
+
# Initialize union-find structures
|
|
519
|
+
self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {}
|
|
520
|
+
self._ranks: Dict[Tuple[str, int], int] = {}
|
|
521
|
+
|
|
522
|
+
# Union-Find algorithm to find connected components
|
|
523
|
+
for edge_type in self.edge_types:
|
|
524
|
+
src, _, dst = edge_type
|
|
525
|
+
edge_index = self[edge_type].edge_index
|
|
526
|
+
for src_node, dst_node in edge_index.t().tolist():
|
|
527
|
+
self._union((src, src_node), (dst, dst_node))
|
|
528
|
+
|
|
529
|
+
# Rerun _find_parent to ensure all nodes are covered correctly
|
|
530
|
+
for node_type in self.node_types:
|
|
531
|
+
for node_index in range(self[node_type].num_nodes):
|
|
532
|
+
self._find_parent((node_type, node_index))
|
|
533
|
+
|
|
534
|
+
# Group nodes by their representative parent
|
|
535
|
+
components_map = defaultdict(list)
|
|
536
|
+
for node, parent in self._parents.items():
|
|
537
|
+
components_map[parent].append(node)
|
|
538
|
+
del self._parents
|
|
539
|
+
del self._ranks
|
|
540
|
+
|
|
541
|
+
components: List[Self] = []
|
|
542
|
+
for nodes in components_map.values():
|
|
543
|
+
# Prefill subset_dict with all node types to ensure all are present
|
|
544
|
+
subset_dict = {node_type: [] for node_type in self.node_types}
|
|
545
|
+
|
|
546
|
+
# Convert the list of (node_type, node_id) tuples to a subset_dict
|
|
547
|
+
for node_type, node_id in nodes:
|
|
548
|
+
subset_dict[node_type].append(node_id)
|
|
549
|
+
|
|
550
|
+
# Convert lists to tensors
|
|
551
|
+
for node_type, node_ids in subset_dict.items():
|
|
552
|
+
subset_dict[node_type] = torch.tensor(node_ids,
|
|
553
|
+
dtype=torch.long)
|
|
554
|
+
|
|
555
|
+
# Use the existing subgraph function to do all the heavy lifting
|
|
556
|
+
component_data = self.subgraph(subset_dict)
|
|
557
|
+
components.append(component_data)
|
|
558
|
+
|
|
559
|
+
return components
|
|
560
|
+
|
|
475
561
|
def debug(self):
|
|
476
562
|
pass # TODO
|
|
477
563
|
|
|
@@ -551,7 +637,7 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
551
637
|
This is equivalent to writing :obj:`data.x_dict`.
|
|
552
638
|
|
|
553
639
|
Args:
|
|
554
|
-
key (str): The attribute to collect from all node and
|
|
640
|
+
key (str): The attribute to collect from all node and edge types.
|
|
555
641
|
allow_empty (bool, optional): If set to :obj:`True`, will not raise
|
|
556
642
|
an error in case the attribute does not exit in any node or
|
|
557
643
|
edge type. (default: :obj:`False`)
|
|
@@ -570,12 +656,13 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
570
656
|
global _DISPLAYED_TYPE_NAME_WARNING
|
|
571
657
|
if not _DISPLAYED_TYPE_NAME_WARNING and '__' in name:
|
|
572
658
|
_DISPLAYED_TYPE_NAME_WARNING = True
|
|
573
|
-
warnings.warn(
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
659
|
+
warnings.warn(
|
|
660
|
+
f"There exist type names in the "
|
|
661
|
+
f"'{self.__class__.__name__}' object that contain "
|
|
662
|
+
f"double underscores '__' (e.g., '{name}'). This "
|
|
663
|
+
f"may lead to unexpected behavior. To avoid any "
|
|
664
|
+
f"issues, ensure that your type names only contain "
|
|
665
|
+
f"single underscores.", stacklevel=2)
|
|
579
666
|
|
|
580
667
|
def get_node_store(self, key: NodeType) -> NodeStorage:
|
|
581
668
|
r"""Gets the :class:`~torch_geometric.data.storage.NodeStorage` object
|
|
@@ -1132,6 +1219,51 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
|
|
|
1132
1219
|
|
|
1133
1220
|
return list(edge_attrs.values())
|
|
1134
1221
|
|
|
1222
|
+
# Connected Components Helper Functions ###################################
|
|
1223
|
+
|
|
1224
|
+
def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]:
|
|
1225
|
+
r"""Finds and returns the representative parent of the given node in a
|
|
1226
|
+
disjoint-set (union-find) data structure. Implements path compression
|
|
1227
|
+
to optimize future queries.
|
|
1228
|
+
|
|
1229
|
+
Args:
|
|
1230
|
+
node (tuple[str, int]): The node for which to find the parent.
|
|
1231
|
+
First element is the node type, second is the node index.
|
|
1232
|
+
|
|
1233
|
+
Returns:
|
|
1234
|
+
tuple[str, int]: The representative parent of the node.
|
|
1235
|
+
"""
|
|
1236
|
+
if node not in self._parents:
|
|
1237
|
+
self._parents[node] = node
|
|
1238
|
+
self._ranks[node] = 0
|
|
1239
|
+
if self._parents[node] != node:
|
|
1240
|
+
self._parents[node] = self._find_parent(self._parents[node])
|
|
1241
|
+
return self._parents[node]
|
|
1242
|
+
|
|
1243
|
+
def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]):
|
|
1244
|
+
r"""Merges the node1 and node2 in the disjoint-set data structure.
|
|
1245
|
+
|
|
1246
|
+
Finds the root parents of node1 and node2 using the _find_parent
|
|
1247
|
+
method. If they belong to different sets, updates the parent of
|
|
1248
|
+
root2 to be root1, effectively merging the two sets.
|
|
1249
|
+
|
|
1250
|
+
Args:
|
|
1251
|
+
node1 (Tuple[str, int]): The first node to union. First element is
|
|
1252
|
+
the node type, second is the node index.
|
|
1253
|
+
node2 (Tuple[str, int]): The second node to union. First element is
|
|
1254
|
+
the node type, second is the node index.
|
|
1255
|
+
"""
|
|
1256
|
+
root1 = self._find_parent(node1)
|
|
1257
|
+
root2 = self._find_parent(node2)
|
|
1258
|
+
if root1 != root2:
|
|
1259
|
+
if self._ranks[root1] < self._ranks[root2]:
|
|
1260
|
+
self._parents[root1] = root2
|
|
1261
|
+
elif self._ranks[root1] > self._ranks[root2]:
|
|
1262
|
+
self._parents[root2] = root1
|
|
1263
|
+
else:
|
|
1264
|
+
self._parents[root2] = root1
|
|
1265
|
+
self._ranks[root1] += 1
|
|
1266
|
+
|
|
1135
1267
|
|
|
1136
1268
|
# Helper functions ############################################################
|
|
1137
1269
|
|
|
@@ -39,7 +39,7 @@ class HyperGraphData(Data):
|
|
|
39
39
|
edge_index (LongTensor, optional): Hyperedge tensor
|
|
40
40
|
with shape :obj:`[2, num_edges*num_nodes_per_edge]`.
|
|
41
41
|
Where `edge_index[1]` denotes the hyperedge index and
|
|
42
|
-
`edge_index[0]` denotes the node
|
|
42
|
+
`edge_index[0]` denotes the node indices that are connected
|
|
43
43
|
by the hyperedge. (default: :obj:`None`)
|
|
44
44
|
(default: :obj:`None`)
|
|
45
45
|
edge_attr (torch.Tensor, optional): Edge feature matrix with shape
|
|
@@ -223,4 +223,4 @@ def warn_or_raise(msg: str, raise_on_error: bool = True) -> None:
|
|
|
223
223
|
if raise_on_error:
|
|
224
224
|
raise ValueError(msg)
|
|
225
225
|
else:
|
|
226
|
-
warnings.warn(msg)
|
|
226
|
+
warnings.warn(msg, stacklevel=2)
|
|
@@ -297,7 +297,7 @@ class InMemoryDataset(Dataset):
|
|
|
297
297
|
self._data_list = None
|
|
298
298
|
msg += f' {msg4}'
|
|
299
299
|
|
|
300
|
-
warnings.warn(msg)
|
|
300
|
+
warnings.warn(msg, stacklevel=2)
|
|
301
301
|
|
|
302
302
|
return self._data
|
|
303
303
|
|
|
@@ -346,7 +346,7 @@ class InMemoryDataset(Dataset):
|
|
|
346
346
|
|
|
347
347
|
def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
|
|
348
348
|
if isinstance(node, Mapping):
|
|
349
|
-
for
|
|
349
|
+
for value in node.values():
|
|
350
350
|
yield from nested_iter(value)
|
|
351
351
|
elif isinstance(node, Sequence):
|
|
352
352
|
yield from enumerate(node)
|
|
@@ -11,21 +11,27 @@ from torch_geometric.sampler import BaseSampler, NeighborSampler
|
|
|
11
11
|
from torch_geometric.typing import InputEdges, InputNodes, OptTensor
|
|
12
12
|
|
|
13
13
|
try:
|
|
14
|
-
from
|
|
15
|
-
|
|
14
|
+
from lightning.pytorch import LightningDataModule as _LightningDataModule
|
|
15
|
+
_pl_is_available = True
|
|
16
16
|
except ImportError:
|
|
17
|
-
|
|
18
|
-
|
|
17
|
+
try:
|
|
18
|
+
from pytorch_lightning import \
|
|
19
|
+
LightningDataModule as _LightningDataModule
|
|
20
|
+
_pl_is_available = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
_pl_is_available = False
|
|
23
|
+
_LightningDataModule = object
|
|
19
24
|
|
|
20
25
|
|
|
21
|
-
class LightningDataModule(
|
|
26
|
+
class LightningDataModule(_LightningDataModule):
|
|
22
27
|
def __init__(self, has_val: bool, has_test: bool, **kwargs: Any) -> None:
|
|
23
28
|
super().__init__()
|
|
24
29
|
|
|
25
|
-
if
|
|
30
|
+
if not _pl_is_available:
|
|
26
31
|
raise ModuleNotFoundError(
|
|
27
|
-
"No module named 'pytorch_lightning'
|
|
28
|
-
"Run 'pip install
|
|
32
|
+
"No module named 'pytorch_lightning' (or 'lightning') found "
|
|
33
|
+
"in your Python environment. Run 'pip install "
|
|
34
|
+
"pytorch_lightning' or 'pip install lightning'")
|
|
29
35
|
|
|
30
36
|
if not has_val:
|
|
31
37
|
self.val_dataloader = None # type: ignore
|
|
@@ -40,9 +46,11 @@ class LightningDataModule(PLLightningDataModule):
|
|
|
40
46
|
kwargs.get('num_workers', 0) > 0)
|
|
41
47
|
|
|
42
48
|
if 'shuffle' in kwargs:
|
|
43
|
-
warnings.warn(
|
|
44
|
-
|
|
45
|
-
|
|
49
|
+
warnings.warn(
|
|
50
|
+
f"The 'shuffle={kwargs['shuffle']}' option is "
|
|
51
|
+
f"ignored in '{self.__class__.__name__}'. Remove it "
|
|
52
|
+
f"from the argument list to disable this warning",
|
|
53
|
+
stacklevel=2)
|
|
46
54
|
del kwargs['shuffle']
|
|
47
55
|
|
|
48
56
|
self.kwargs = kwargs
|
|
@@ -74,34 +82,39 @@ class LightningData(LightningDataModule):
|
|
|
74
82
|
raise ValueError(f"Undefined 'loader' option (got '{loader}')")
|
|
75
83
|
|
|
76
84
|
if loader == 'full' and kwargs['batch_size'] != 1:
|
|
77
|
-
warnings.warn(
|
|
78
|
-
|
|
79
|
-
|
|
85
|
+
warnings.warn(
|
|
86
|
+
f"Re-setting 'batch_size' to 1 in "
|
|
87
|
+
f"'{self.__class__.__name__}' for loader='full' "
|
|
88
|
+
f"(got '{kwargs['batch_size']}')", stacklevel=2)
|
|
80
89
|
kwargs['batch_size'] = 1
|
|
81
90
|
|
|
82
91
|
if loader == 'full' and kwargs['num_workers'] != 0:
|
|
83
|
-
warnings.warn(
|
|
84
|
-
|
|
85
|
-
|
|
92
|
+
warnings.warn(
|
|
93
|
+
f"Re-setting 'num_workers' to 0 in "
|
|
94
|
+
f"'{self.__class__.__name__}' for loader='full' "
|
|
95
|
+
f"(got '{kwargs['num_workers']}')", stacklevel=2)
|
|
86
96
|
kwargs['num_workers'] = 0
|
|
87
97
|
|
|
88
98
|
if loader == 'full' and kwargs.get('sampler') is not None:
|
|
89
|
-
warnings.warn(
|
|
90
|
-
|
|
99
|
+
warnings.warn(
|
|
100
|
+
"'sampler' option is not supported for "
|
|
101
|
+
"loader='full'", stacklevel=2)
|
|
91
102
|
kwargs.pop('sampler', None)
|
|
92
103
|
|
|
93
104
|
if loader == 'full' and kwargs.get('batch_sampler') is not None:
|
|
94
|
-
warnings.warn(
|
|
95
|
-
|
|
105
|
+
warnings.warn(
|
|
106
|
+
"'batch_sampler' option is not supported for "
|
|
107
|
+
"loader='full'", stacklevel=2)
|
|
96
108
|
kwargs.pop('batch_sampler', None)
|
|
97
109
|
|
|
98
110
|
super().__init__(has_val, has_test, **kwargs)
|
|
99
111
|
|
|
100
112
|
if loader == 'full':
|
|
101
113
|
if kwargs.get('pin_memory', False):
|
|
102
|
-
warnings.warn(
|
|
103
|
-
|
|
104
|
-
|
|
114
|
+
warnings.warn(
|
|
115
|
+
f"Re-setting 'pin_memory' to 'False' in "
|
|
116
|
+
f"'{self.__class__.__name__}' for loader='full' "
|
|
117
|
+
f"(got 'True')", stacklevel=2)
|
|
105
118
|
self.kwargs['pin_memory'] = False
|
|
106
119
|
|
|
107
120
|
self.data = data
|
|
@@ -127,10 +140,11 @@ class LightningData(LightningDataModule):
|
|
|
127
140
|
graph_sampler.__class__,
|
|
128
141
|
)
|
|
129
142
|
if len(sampler_kwargs) > 0:
|
|
130
|
-
warnings.warn(
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
143
|
+
warnings.warn(
|
|
144
|
+
f"Ignoring the arguments "
|
|
145
|
+
f"{list(sampler_kwargs.keys())} in "
|
|
146
|
+
f"'{self.__class__.__name__}' since a custom "
|
|
147
|
+
f"'graph_sampler' was passed", stacklevel=2)
|
|
134
148
|
self.graph_sampler = graph_sampler
|
|
135
149
|
|
|
136
150
|
else:
|
torch_geometric/data/storage.py
CHANGED
|
@@ -454,7 +454,7 @@ class NodeStorage(BaseStorage):
|
|
|
454
454
|
f"'{set(self.keys())}'. Please explicitly set 'num_nodes' as an "
|
|
455
455
|
f"attribute of " +
|
|
456
456
|
("'data'" if self._key is None else f"'data[{self._key}]'") +
|
|
457
|
-
" to suppress this warning")
|
|
457
|
+
" to suppress this warning", stacklevel=2)
|
|
458
458
|
if 'edge_index' in self and isinstance(self.edge_index, Tensor):
|
|
459
459
|
if self.edge_index.numel() > 0:
|
|
460
460
|
return int(self.edge_index.max()) + 1
|
|
@@ -806,6 +806,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
|
|
|
806
806
|
return False
|
|
807
807
|
|
|
808
808
|
cat_dim = self._parent().__cat_dim__(key, value, self)
|
|
809
|
+
|
|
810
|
+
if not isinstance(cat_dim, int):
|
|
811
|
+
return False
|
|
812
|
+
|
|
809
813
|
num_nodes, num_edges = self.num_nodes, self.num_edges
|
|
810
814
|
|
|
811
815
|
if value.shape[cat_dim] != num_nodes:
|
|
@@ -852,6 +856,10 @@ class GlobalStorage(NodeStorage, EdgeStorage):
|
|
|
852
856
|
return False
|
|
853
857
|
|
|
854
858
|
cat_dim = self._parent().__cat_dim__(key, value, self)
|
|
859
|
+
|
|
860
|
+
if not isinstance(cat_dim, int):
|
|
861
|
+
return False
|
|
862
|
+
|
|
855
863
|
num_nodes, num_edges = self.num_nodes, self.num_edges
|
|
856
864
|
|
|
857
865
|
if value.shape[cat_dim] != num_edges:
|
|
@@ -30,6 +30,7 @@ from .faust import FAUST
|
|
|
30
30
|
from .dynamic_faust import DynamicFAUST
|
|
31
31
|
from .shapenet import ShapeNet
|
|
32
32
|
from .modelnet import ModelNet
|
|
33
|
+
from .medshapenet import MedShapeNet
|
|
33
34
|
from .coma import CoMA
|
|
34
35
|
from .shrec2016 import SHREC2016
|
|
35
36
|
from .tosca import TOSCA
|
|
@@ -76,7 +77,14 @@ from .wikidata import Wikidata5M
|
|
|
76
77
|
from .myket import MyketDataset
|
|
77
78
|
from .brca_tgca import BrcaTcga
|
|
78
79
|
from .neurograph import NeuroGraphDataset
|
|
79
|
-
from .web_qsp_dataset import WebQSPDataset
|
|
80
|
+
from .web_qsp_dataset import WebQSPDataset, CWQDataset
|
|
81
|
+
from .git_mol_dataset import GitMolDataset
|
|
82
|
+
from .molecule_gpt_dataset import MoleculeGPTDataset
|
|
83
|
+
from .instruct_mol_dataset import InstructMolDataset
|
|
84
|
+
from .protein_mpnn_dataset import ProteinMPNNDataset
|
|
85
|
+
from .tag_dataset import TAGDataset
|
|
86
|
+
from .city import CityNetwork
|
|
87
|
+
from .teeth3ds import Teeth3DS
|
|
80
88
|
|
|
81
89
|
from .dbp15k import DBP15K
|
|
82
90
|
from .aminer import AMiner
|
|
@@ -142,6 +150,7 @@ homo_datasets = [
|
|
|
142
150
|
'DynamicFAUST',
|
|
143
151
|
'ShapeNet',
|
|
144
152
|
'ModelNet',
|
|
153
|
+
'MedShapeNet',
|
|
145
154
|
'CoMA',
|
|
146
155
|
'SHREC2016',
|
|
147
156
|
'TOSCA',
|
|
@@ -190,6 +199,14 @@ homo_datasets = [
|
|
|
190
199
|
'BrcaTcga',
|
|
191
200
|
'NeuroGraphDataset',
|
|
192
201
|
'WebQSPDataset',
|
|
202
|
+
'CWQDataset',
|
|
203
|
+
'GitMolDataset',
|
|
204
|
+
'MoleculeGPTDataset',
|
|
205
|
+
'InstructMolDataset',
|
|
206
|
+
'ProteinMPNNDataset',
|
|
207
|
+
'TAGDataset',
|
|
208
|
+
'CityNetwork',
|
|
209
|
+
'Teeth3DS',
|
|
193
210
|
]
|
|
194
211
|
|
|
195
212
|
hetero_datasets = [
|
|
@@ -19,17 +19,15 @@ class Actor(InMemoryDataset):
|
|
|
19
19
|
actor's Wikipedia.
|
|
20
20
|
|
|
21
21
|
Args:
|
|
22
|
-
root
|
|
23
|
-
transform
|
|
22
|
+
root: Root directory where the dataset should be saved.
|
|
23
|
+
transform: A function/transform that takes in an
|
|
24
24
|
:obj:`torch_geometric.data.Data` object and returns a transformed
|
|
25
25
|
version. The data object will be transformed before every access.
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
force_reload (bool, optional): Whether to re-process the dataset.
|
|
32
|
-
(default: :obj:`False`)
|
|
26
|
+
pre_transform: A function/transform that takes in an
|
|
27
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
|
28
|
+
version. The data object will be transformed before being saved to
|
|
29
|
+
disk.
|
|
30
|
+
force_reload: Whether to re-process the dataset.
|
|
33
31
|
|
|
34
32
|
**STATS:**
|
|
35
33
|
|
|
@@ -25,13 +25,13 @@ class AirfRANS(InMemoryDataset):
|
|
|
25
25
|
features: the inlet velocity (two components in meter per second), the
|
|
26
26
|
distance to the airfoil (one component in meter), and the normals (two
|
|
27
27
|
components in meter, set to :obj:`0` if the point is not on the airfoil).
|
|
28
|
-
Each point is given a target of 4 components for the
|
|
28
|
+
Each point is given a target of 4 components for the underlying regression
|
|
29
29
|
task: the velocity (two components in meter per second), the pressure
|
|
30
30
|
divided by the specific mass (one component in meter squared per second
|
|
31
31
|
squared), the turbulent kinematic viscosity (one component in meter squared
|
|
32
32
|
per second).
|
|
33
|
-
|
|
34
|
-
the airfoil or not.
|
|
33
|
+
Finally, a boolean is attached to each point to inform if this point lies
|
|
34
|
+
on the airfoil or not.
|
|
35
35
|
|
|
36
36
|
A library for manipulating simulations of the dataset is available `here
|
|
37
37
|
<https://airfrans.readthedocs.io/en/latest/index.html>`_.
|
|
@@ -46,26 +46,24 @@ class AirfRANS(InMemoryDataset):
|
|
|
46
46
|
:obj:`torch_geometric.transforms.RadiusGraph` transform.
|
|
47
47
|
|
|
48
48
|
Args:
|
|
49
|
-
root
|
|
50
|
-
task
|
|
49
|
+
root: Root directory where the dataset should be saved.
|
|
50
|
+
task: The task to study (:obj:`"full"`, :obj:`"scarce"`,
|
|
51
51
|
:obj:`"reynolds"`, :obj:`"aoa"`) that defines the utilized training
|
|
52
52
|
and test splits.
|
|
53
|
-
train
|
|
54
|
-
|
|
55
|
-
transform
|
|
56
|
-
:
|
|
53
|
+
train: If :obj:`True`, loads the training dataset, otherwise the test
|
|
54
|
+
dataset.
|
|
55
|
+
transform: A function/transform that takes in an
|
|
56
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
|
57
57
|
version. The data object will be transformed before every access.
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
an :obj:`torch_geometric.data.Data` object and returns a
|
|
58
|
+
pre_transform: A function/transform that takes in an
|
|
59
|
+
:class:`torch_geometric.data.Data` object and returns a
|
|
61
60
|
transformed version. The data object will be transformed before
|
|
62
|
-
being saved to disk.
|
|
63
|
-
pre_filter
|
|
61
|
+
being saved to disk.
|
|
62
|
+
pre_filter: A function that takes in an
|
|
64
63
|
:obj:`torch_geometric.data.Data` object and returns a boolean
|
|
65
64
|
value, indicating whether the data object should be included in the
|
|
66
|
-
final dataset.
|
|
67
|
-
force_reload
|
|
68
|
-
(default: :obj:`False`)
|
|
65
|
+
final dataset.
|
|
66
|
+
force_reload: Whether to re-process the dataset.
|
|
69
67
|
|
|
70
68
|
**STATS:**
|
|
71
69
|
|
|
@@ -14,22 +14,20 @@ class Airports(InMemoryDataset):
|
|
|
14
14
|
and labels correspond to activity levels.
|
|
15
15
|
Features are given by one-hot encoded node identifiers, as described in the
|
|
16
16
|
`"GraLSP: Graph Neural Networks with Local Structural Patterns"
|
|
17
|
-
|
|
17
|
+
<https://arxiv.org/abs/1911.07675>`_ paper.
|
|
18
18
|
|
|
19
19
|
Args:
|
|
20
|
-
root
|
|
21
|
-
name
|
|
20
|
+
root: Root directory where the dataset should be saved.
|
|
21
|
+
name: The name of the dataset (:obj:`"USA"`, :obj:`"Brazil"`,
|
|
22
22
|
:obj:`"Europe"`).
|
|
23
|
-
transform
|
|
24
|
-
:
|
|
23
|
+
transform: A function/transform that takes in an
|
|
24
|
+
:class:`torch_geometric.data.Data` object and returns a transformed
|
|
25
25
|
version. The data object will be transformed before every access.
|
|
26
|
-
(default: :obj:`None`)
|
|
27
26
|
pre_transform (callable, optional): A function/transform that takes in
|
|
28
|
-
|
|
27
|
+
:class:`torch_geometric.data.Data` object and returns a
|
|
29
28
|
transformed version. The data object will be transformed before
|
|
30
|
-
being saved to disk.
|
|
31
|
-
force_reload
|
|
32
|
-
(default: :obj:`False`)
|
|
29
|
+
being saved to disk.
|
|
30
|
+
force_reload: Whether to re-process the dataset.
|
|
33
31
|
"""
|
|
34
32
|
edge_url = ('https://github.com/leoribeiro/struc2vec/'
|
|
35
33
|
'raw/master/graph/{}-airports.edgelist')
|