pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/index.py
CHANGED
|
@@ -12,6 +12,7 @@ from typing import (
|
|
|
12
12
|
Union,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
|
+
import numpy as np
|
|
15
16
|
import torch
|
|
16
17
|
import torch.utils._pytree as pytree
|
|
17
18
|
from torch import Tensor
|
|
@@ -103,10 +104,10 @@ class Index(Tensor):
|
|
|
103
104
|
conversion in case its representation is sorted.
|
|
104
105
|
Caches are filled based on demand (*e.g.*, when calling
|
|
105
106
|
:meth:`Index.get_indptr`), or when explicitly requested via
|
|
106
|
-
:meth:`Index.fill_cache_`, and are
|
|
107
|
+
:meth:`Index.fill_cache_`, and are maintained and adjusted over its
|
|
107
108
|
lifespan.
|
|
108
109
|
|
|
109
|
-
This representation ensures
|
|
110
|
+
This representation ensures optimal computation in GNN message passing
|
|
110
111
|
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
|
|
111
112
|
workflows.
|
|
112
113
|
|
|
@@ -120,7 +121,7 @@ class Index(Tensor):
|
|
|
120
121
|
assert index.is_sorted
|
|
121
122
|
|
|
122
123
|
# Flipping order:
|
|
123
|
-
|
|
124
|
+
index.flip(0)
|
|
124
125
|
>>> Index([[2, 1, 1, 0], dim_size=3)
|
|
125
126
|
assert not index.is_sorted
|
|
126
127
|
|
|
@@ -181,7 +182,7 @@ class Index(Tensor):
|
|
|
181
182
|
assert_one_dimensional(data)
|
|
182
183
|
assert_contiguous(data)
|
|
183
184
|
|
|
184
|
-
out = Tensor._make_wrapper_subclass(
|
|
185
|
+
out = Tensor._make_wrapper_subclass(
|
|
185
186
|
cls,
|
|
186
187
|
size=data.size(),
|
|
187
188
|
strides=data.stride(),
|
|
@@ -360,10 +361,10 @@ class Index(Tensor):
|
|
|
360
361
|
return index
|
|
361
362
|
|
|
362
363
|
# Prevent auto-wrapping outputs back into the proper subclass type:
|
|
363
|
-
__torch_function__ = torch._C._disabled_torch_function_impl
|
|
364
|
+
__torch_function__ = torch._C._disabled_torch_function_impl # type: ignore
|
|
364
365
|
|
|
365
366
|
@classmethod
|
|
366
|
-
def __torch_dispatch__(
|
|
367
|
+
def __torch_dispatch__( # type: ignore
|
|
367
368
|
cls: Type,
|
|
368
369
|
func: Callable[..., Any],
|
|
369
370
|
types: Iterable[Type[Any]],
|
|
@@ -410,6 +411,14 @@ class Index(Tensor):
|
|
|
410
411
|
return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
|
|
411
412
|
indent, force_newline=False)
|
|
412
413
|
|
|
414
|
+
def tolist(self) -> List[Any]:
|
|
415
|
+
"""""" # noqa: D419
|
|
416
|
+
return self._data.tolist()
|
|
417
|
+
|
|
418
|
+
def numpy(self, *, force: bool = False) -> np.ndarray:
|
|
419
|
+
"""""" # noqa: D419
|
|
420
|
+
return self._data.numpy(force=force)
|
|
421
|
+
|
|
413
422
|
# Helpers #################################################################
|
|
414
423
|
|
|
415
424
|
def _shallow_copy(self) -> 'Index':
|
|
@@ -632,7 +641,7 @@ def _slice(
|
|
|
632
641
|
step: int = 1,
|
|
633
642
|
) -> Index:
|
|
634
643
|
|
|
635
|
-
if ((start is None or start <= 0)
|
|
644
|
+
if ((start is None or start <= 0 or start <= -input.size(dim))
|
|
636
645
|
and (end is None or end > input.size(dim)) and step == 1):
|
|
637
646
|
return input._shallow_copy() # No-op.
|
|
638
647
|
|
torch_geometric/inspector.py
CHANGED
|
@@ -305,7 +305,7 @@ class Inspector:
|
|
|
305
305
|
according to its function signature from a data blob.
|
|
306
306
|
|
|
307
307
|
Args:
|
|
308
|
-
func (
|
|
308
|
+
func (callable or str): The function.
|
|
309
309
|
kwargs (dict[str, Any]): The data blob which may serve as inputs.
|
|
310
310
|
"""
|
|
311
311
|
out_dict: Dict[str, Any] = {}
|
|
@@ -346,7 +346,7 @@ class Inspector:
|
|
|
346
346
|
type annotations are not found.
|
|
347
347
|
|
|
348
348
|
Args:
|
|
349
|
-
func (
|
|
349
|
+
func (callable or str): The function.
|
|
350
350
|
exclude (list[int or str]): A list of parameters to exclude, either
|
|
351
351
|
given by their name or index. (default: :obj:`None`)
|
|
352
352
|
"""
|
|
@@ -448,6 +448,10 @@ def type_repr(obj: Any, _globals: Dict[str, Any]) -> str:
|
|
|
448
448
|
return '...'
|
|
449
449
|
|
|
450
450
|
if obj.__module__ == 'typing': # Special logic for `typing.*` types:
|
|
451
|
+
|
|
452
|
+
if not hasattr(obj, '_name'):
|
|
453
|
+
return repr(obj)
|
|
454
|
+
|
|
451
455
|
name = obj._name
|
|
452
456
|
if name is None: # In some cases, `_name` is not populated.
|
|
453
457
|
name = str(obj.__origin__).split('.')[-1]
|
torch_geometric/io/fs.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
import io
|
|
2
2
|
import os.path as osp
|
|
3
|
+
import pickle
|
|
4
|
+
import re
|
|
3
5
|
import sys
|
|
6
|
+
import warnings
|
|
4
7
|
from typing import Any, Dict, List, Literal, Optional, Union, overload
|
|
5
8
|
from uuid import uuid4
|
|
6
9
|
|
|
@@ -211,5 +214,29 @@ def torch_save(data: Any, path: str) -> None:
|
|
|
211
214
|
|
|
212
215
|
|
|
213
216
|
def torch_load(path: str, map_location: Any = None) -> Any:
|
|
217
|
+
if torch_geometric.typing.WITH_PT24:
|
|
218
|
+
try:
|
|
219
|
+
with fsspec.open(path, 'rb') as f:
|
|
220
|
+
return torch.load(f, map_location, weights_only=True)
|
|
221
|
+
except pickle.UnpicklingError as e:
|
|
222
|
+
error_msg = str(e)
|
|
223
|
+
if "add_safe_globals" in error_msg:
|
|
224
|
+
warn_msg = ("Weights only load failed. Please file an issue "
|
|
225
|
+
"to make `torch.load(weights_only=True)` "
|
|
226
|
+
"compatible in your case.")
|
|
227
|
+
match = re.search(r'add_safe_globals\(.*?\)', error_msg)
|
|
228
|
+
if match is not None:
|
|
229
|
+
warnings.warn(
|
|
230
|
+
f"{warn_msg} Please use "
|
|
231
|
+
f"`torch.serialization.{match.group()}` to "
|
|
232
|
+
f"allowlist this global.", stacklevel=2)
|
|
233
|
+
else:
|
|
234
|
+
warnings.warn(warn_msg, stacklevel=2)
|
|
235
|
+
|
|
236
|
+
with fsspec.open(path, 'rb') as f:
|
|
237
|
+
return torch.load(f, map_location, weights_only=False)
|
|
238
|
+
else:
|
|
239
|
+
raise e
|
|
240
|
+
|
|
214
241
|
with fsspec.open(path, 'rb') as f:
|
|
215
242
|
return torch.load(f, map_location)
|
torch_geometric/io/tu.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import os.path as osp
|
|
2
2
|
from typing import Dict, List, Optional, Tuple
|
|
3
3
|
|
|
4
|
-
import numpy as np
|
|
5
4
|
import torch
|
|
6
5
|
from torch import Tensor
|
|
7
6
|
|
|
@@ -108,11 +107,11 @@ def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]:
|
|
|
108
107
|
|
|
109
108
|
|
|
110
109
|
def split(data: Data, batch: Tensor) -> Tuple[Data, Dict[str, Tensor]]:
|
|
111
|
-
node_slice = cumsum(torch.
|
|
110
|
+
node_slice = cumsum(torch.bincount(batch))
|
|
112
111
|
|
|
113
112
|
assert data.edge_index is not None
|
|
114
113
|
row, _ = data.edge_index
|
|
115
|
-
edge_slice = cumsum(torch.
|
|
114
|
+
edge_slice = cumsum(torch.bincount(batch[row]))
|
|
116
115
|
|
|
117
116
|
# Edge indices should start at zero for every graph.
|
|
118
117
|
data.edge_index -= node_slice[batch[row]].unsqueeze(0)
|