pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/edge_index.py
CHANGED
@@ -1,16 +1,15 @@
|
|
1
1
|
import functools
|
2
|
-
import typing
|
3
2
|
from enum import Enum
|
4
3
|
from typing import (
|
5
4
|
Any,
|
6
5
|
Callable,
|
7
6
|
Dict,
|
7
|
+
Iterable,
|
8
8
|
List,
|
9
9
|
Literal,
|
10
10
|
NamedTuple,
|
11
11
|
Optional,
|
12
12
|
Sequence,
|
13
|
-
Set,
|
14
13
|
Tuple,
|
15
14
|
Type,
|
16
15
|
Union,
|
@@ -19,23 +18,17 @@ from typing import (
|
|
19
18
|
)
|
20
19
|
|
21
20
|
import torch
|
21
|
+
import torch.utils._pytree as pytree
|
22
22
|
from torch import Tensor
|
23
23
|
|
24
24
|
import torch_geometric.typing
|
25
|
-
from torch_geometric import is_compiling
|
26
|
-
from torch_geometric.
|
25
|
+
from torch_geometric import Index, is_compiling
|
26
|
+
from torch_geometric.index import index2ptr, ptr2index
|
27
|
+
from torch_geometric.typing import INDEX_DTYPES, SparseTensor
|
27
28
|
|
28
|
-
|
29
|
+
aten = torch.ops.aten
|
29
30
|
|
30
|
-
|
31
|
-
SUPPORTED_DTYPES: Set[torch.dtype] = {
|
32
|
-
torch.int32,
|
33
|
-
torch.int64,
|
34
|
-
}
|
35
|
-
elif not typing.TYPE_CHECKING: # pragma: no cover
|
36
|
-
SUPPORTED_DTYPES: Set[torch.dtype] = {
|
37
|
-
torch.int64,
|
38
|
-
}
|
31
|
+
HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
|
39
32
|
|
40
33
|
ReduceType = Literal['sum', 'mean', 'amin', 'amax', 'add', 'min', 'max']
|
41
34
|
PYG_REDUCE: Dict[ReduceType, ReduceType] = {
|
@@ -114,16 +107,11 @@ def maybe_sub(
|
|
114
107
|
for v, o in zip(value, other))
|
115
108
|
|
116
109
|
|
117
|
-
def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
|
118
|
-
index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
|
119
|
-
return index.repeat_interleave(ptr.diff(), output_size=output_size)
|
120
|
-
|
121
|
-
|
122
110
|
def assert_valid_dtype(tensor: Tensor) -> None:
|
123
|
-
if tensor.dtype not in
|
111
|
+
if tensor.dtype not in INDEX_DTYPES:
|
124
112
|
raise ValueError(f"'EdgeIndex' holds an unsupported data type "
|
125
113
|
f"(got '{tensor.dtype}', but expected one of "
|
126
|
-
f"{
|
114
|
+
f"{INDEX_DTYPES})")
|
127
115
|
|
128
116
|
|
129
117
|
def assert_two_dimensional(tensor: Tensor) -> None:
|
@@ -136,7 +124,7 @@ def assert_two_dimensional(tensor: Tensor) -> None:
|
|
136
124
|
|
137
125
|
|
138
126
|
def assert_contiguous(tensor: Tensor) -> None:
|
139
|
-
if not tensor.is_contiguous():
|
127
|
+
if not tensor[0].is_contiguous() or not tensor[1].is_contiguous():
|
140
128
|
raise ValueError("'EdgeIndex' needs to be contiguous. Please call "
|
141
129
|
"`edge_index.contiguous()` before proceeding.")
|
142
130
|
|
@@ -150,13 +138,13 @@ def assert_symmetric(size: Tuple[Optional[int], Optional[int]]) -> None:
|
|
150
138
|
|
151
139
|
def assert_sorted(func: Callable) -> Callable:
|
152
140
|
@functools.wraps(func)
|
153
|
-
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
154
|
-
if not
|
155
|
-
cls_name =
|
141
|
+
def wrapper(self: 'EdgeIndex', *args: Any, **kwargs: Any) -> Any:
|
142
|
+
if not self.is_sorted:
|
143
|
+
cls_name = self.__class__.__name__
|
156
144
|
raise ValueError(
|
157
145
|
f"Cannot call '{func.__name__}' since '{cls_name}' is not "
|
158
146
|
f"sorted. Please call `{cls_name}.sort_by(...)` first.")
|
159
|
-
return func(*args, **kwargs)
|
147
|
+
return func(self, *args, **kwargs)
|
160
148
|
|
161
149
|
return wrapper
|
162
150
|
|
@@ -185,7 +173,7 @@ class EdgeIndex(Tensor):
|
|
185
173
|
:meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
|
186
174
|
lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
|
187
175
|
|
188
|
-
This representation ensures
|
176
|
+
This representation ensures optimal computation in GNN message passing
|
189
177
|
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
|
190
178
|
workflows.
|
191
179
|
|
@@ -229,12 +217,12 @@ class EdgeIndex(Tensor):
|
|
229
217
|
# for a basic tutorial on how to subclass `torch.Tensor`.
|
230
218
|
|
231
219
|
# The underlying tensor representation:
|
232
|
-
_data:
|
220
|
+
_data: Tensor
|
233
221
|
|
234
222
|
# The size of the underlying sparse matrix:
|
235
223
|
_sparse_size: Tuple[Optional[int], Optional[int]] = (None, None)
|
236
224
|
|
237
|
-
# Whether the `edge_index`
|
225
|
+
# Whether the `edge_index` representation is non-sorted (`None`), or sorted
|
238
226
|
# based on row or column values.
|
239
227
|
_sort_order: Optional[SortOrder] = None
|
240
228
|
|
@@ -260,6 +248,7 @@ class EdgeIndex(Tensor):
|
|
260
248
|
# original metadata to be able to reconstruct individual edge indices:
|
261
249
|
_cat_metadata: Optional[CatMetadata] = None
|
262
250
|
|
251
|
+
@staticmethod
|
263
252
|
def __new__(
|
264
253
|
cls: Type,
|
265
254
|
data: Any,
|
@@ -336,21 +325,26 @@ class EdgeIndex(Tensor):
|
|
336
325
|
elif sparse_size[0] is None and sparse_size[1] is not None:
|
337
326
|
sparse_size = (sparse_size[1], sparse_size[1])
|
338
327
|
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
328
|
+
out = Tensor._make_wrapper_subclass( # type: ignore
|
329
|
+
cls,
|
330
|
+
size=data.size(),
|
331
|
+
strides=data.stride(),
|
332
|
+
dtype=data.dtype,
|
333
|
+
device=data.device,
|
334
|
+
layout=data.layout,
|
335
|
+
requires_grad=False,
|
336
|
+
)
|
337
|
+
assert isinstance(out, EdgeIndex)
|
343
338
|
|
344
339
|
# Attach metadata:
|
345
|
-
|
346
|
-
if torch_geometric.typing.WITH_PT22:
|
347
|
-
out._data = data
|
340
|
+
out._data = data
|
348
341
|
out._sparse_size = sparse_size
|
349
342
|
out._sort_order = None if sort_order is None else SortOrder(sort_order)
|
350
343
|
out._is_undirected = is_undirected
|
351
344
|
out._indptr = indptr
|
352
345
|
|
353
346
|
if isinstance(data, cls): # If passed `EdgeIndex`, inherit metadata:
|
347
|
+
out._data = data._data
|
354
348
|
out._T_perm = data._T_perm
|
355
349
|
out._T_index = data._T_index
|
356
350
|
out._T_indptr = data._T_indptr
|
@@ -378,41 +372,43 @@ class EdgeIndex(Tensor):
|
|
378
372
|
* the sort order is correctly set.
|
379
373
|
* indices are bidirectional in case it is specified as undirected.
|
380
374
|
"""
|
381
|
-
assert_valid_dtype(self)
|
382
|
-
assert_two_dimensional(self)
|
383
|
-
assert_contiguous(self)
|
375
|
+
assert_valid_dtype(self._data)
|
376
|
+
assert_two_dimensional(self._data)
|
377
|
+
assert_contiguous(self._data)
|
384
378
|
if self.is_undirected:
|
385
379
|
assert_symmetric(self.sparse_size())
|
386
380
|
|
387
|
-
if self.numel() > 0 and self.min() < 0:
|
381
|
+
if self.numel() > 0 and self._data.min() < 0:
|
388
382
|
raise ValueError(f"'{self.__class__.__name__}' contains negative "
|
389
383
|
f"indices (got {int(self.min())})")
|
390
384
|
|
391
385
|
if (self.numel() > 0 and self.num_rows is not None
|
392
|
-
and self[0].max() >= self.num_rows):
|
386
|
+
and self._data[0].max() >= self.num_rows):
|
393
387
|
raise ValueError(f"'{self.__class__.__name__}' contains larger "
|
394
388
|
f"indices than its number of rows "
|
395
|
-
f"(got {int(self[0].max())}, but expected
|
396
|
-
f"smaller than {self.num_rows})")
|
389
|
+
f"(got {int(self._data[0].max())}, but expected "
|
390
|
+
f"values smaller than {self.num_rows})")
|
397
391
|
|
398
392
|
if (self.numel() > 0 and self.num_cols is not None
|
399
|
-
and self[1].max() >= self.num_cols):
|
393
|
+
and self._data[1].max() >= self.num_cols):
|
400
394
|
raise ValueError(f"'{self.__class__.__name__}' contains larger "
|
401
395
|
f"indices than its number of columns "
|
402
|
-
f"(got {int(self[1].max())}, but expected
|
403
|
-
f"smaller than {self.num_cols})")
|
396
|
+
f"(got {int(self._data[1].max())}, but expected "
|
397
|
+
f"values smaller than {self.num_cols})")
|
404
398
|
|
405
|
-
if self.is_sorted_by_row and (self[0].diff() < 0).any():
|
399
|
+
if self.is_sorted_by_row and (self._data[0].diff() < 0).any():
|
406
400
|
raise ValueError(f"'{self.__class__.__name__}' is not sorted by "
|
407
401
|
f"row indices")
|
408
402
|
|
409
|
-
if self.is_sorted_by_col and (self[1].diff() < 0).any():
|
403
|
+
if self.is_sorted_by_col and (self._data[1].diff() < 0).any():
|
410
404
|
raise ValueError(f"'{self.__class__.__name__}' is not sorted by "
|
411
405
|
f"column indices")
|
412
406
|
|
413
407
|
if self.is_undirected:
|
414
|
-
flat_index1 =
|
415
|
-
|
408
|
+
flat_index1 = self._data[0] * self.get_num_rows() + self._data[1]
|
409
|
+
flat_index1 = flat_index1.sort()[0]
|
410
|
+
flat_index2 = self._data[1] * self.get_num_cols() + self._data[0]
|
411
|
+
flat_index2 = flat_index2.sort()[0]
|
416
412
|
if not torch.equal(flat_index1, flat_index2):
|
417
413
|
raise ValueError(f"'{self.__class__.__name__}' is not "
|
418
414
|
f"undirected")
|
@@ -482,6 +478,11 @@ class EdgeIndex(Tensor):
|
|
482
478
|
r"""Returns whether indices are bidirectional."""
|
483
479
|
return self._is_undirected
|
484
480
|
|
481
|
+
@property
|
482
|
+
def dtype(self) -> torch.dtype: # type: ignore
|
483
|
+
# TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
|
484
|
+
return self._data.dtype
|
485
|
+
|
485
486
|
# Cache Interface #########################################################
|
486
487
|
|
487
488
|
@overload
|
@@ -511,11 +512,11 @@ class EdgeIndex(Tensor):
|
|
511
512
|
return size
|
512
513
|
|
513
514
|
if self.is_undirected:
|
514
|
-
size = int(self.max()) + 1 if self.numel() > 0 else 0
|
515
|
+
size = int(self._data.max()) + 1 if self.numel() > 0 else 0
|
515
516
|
self._sparse_size = (size, size)
|
516
517
|
return size
|
517
518
|
|
518
|
-
size = int(self[dim].max()) + 1 if self.numel() > 0 else 0
|
519
|
+
size = int(self._data[dim].max()) + 1 if self.numel() > 0 else 0
|
519
520
|
self._sparse_size = set_tuple_item(self._sparse_size, dim, size)
|
520
521
|
return size
|
521
522
|
|
@@ -551,11 +552,8 @@ class EdgeIndex(Tensor):
|
|
551
552
|
if ptr is None or size is None:
|
552
553
|
return None
|
553
554
|
|
554
|
-
if ptr.numel() - 1
|
555
|
-
return ptr
|
556
|
-
|
557
|
-
if ptr.numel() - 1 > size:
|
558
|
-
return None
|
555
|
+
if ptr.numel() - 1 >= size:
|
556
|
+
return ptr[:size + 1]
|
559
557
|
|
560
558
|
fill_value = ptr.new_full(
|
561
559
|
(size - ptr.numel() + 1, ),
|
@@ -599,11 +597,7 @@ class EdgeIndex(Tensor):
|
|
599
597
|
return self._T_indptr
|
600
598
|
|
601
599
|
dim = 0 if self.is_sorted_by_row else 1
|
602
|
-
self._indptr =
|
603
|
-
self[dim],
|
604
|
-
self.get_sparse_size(dim),
|
605
|
-
out_int32=self.dtype != torch.int64,
|
606
|
-
)
|
600
|
+
self._indptr = index2ptr(self._data[dim], self.get_sparse_size(dim))
|
607
601
|
|
608
602
|
return self._indptr
|
609
603
|
|
@@ -614,13 +608,14 @@ class EdgeIndex(Tensor):
|
|
614
608
|
dim = 1 if self.is_sorted_by_row else 0
|
615
609
|
|
616
610
|
if self._T_perm is None:
|
617
|
-
|
611
|
+
max_index = self.get_sparse_size(dim)
|
612
|
+
index, perm = index_sort(self._data[dim], max_index)
|
618
613
|
self._T_index = set_tuple_item(self._T_index, dim, index)
|
619
|
-
self._T_perm = perm
|
614
|
+
self._T_perm = perm.to(self.dtype)
|
620
615
|
|
621
616
|
if self._T_index[1 - dim] is None:
|
622
617
|
self._T_index = set_tuple_item( #
|
623
|
-
self._T_index, 1 - dim, self[1 - dim][self._T_perm])
|
618
|
+
self._T_index, 1 - dim, self._data[1 - dim][self._T_perm])
|
624
619
|
|
625
620
|
row, col = self._T_index
|
626
621
|
assert row is not None and col is not None
|
@@ -628,12 +623,12 @@ class EdgeIndex(Tensor):
|
|
628
623
|
return (row, col), self._T_perm
|
629
624
|
|
630
625
|
@assert_sorted
|
631
|
-
def get_csr(self) -> Tuple[Tuple[Tensor, Tensor],
|
626
|
+
def get_csr(self) -> Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]:
|
632
627
|
r"""Returns the compressed CSR representation
|
633
628
|
:obj:`(rowptr, col), perm` in case :class:`EdgeIndex` is sorted.
|
634
629
|
"""
|
635
630
|
if self.is_sorted_by_row:
|
636
|
-
return (self.get_indptr(), self[1]),
|
631
|
+
return (self.get_indptr(), self._data[1]), None
|
637
632
|
|
638
633
|
assert self.is_sorted_by_col
|
639
634
|
(row, col), perm = self._sort_by_transpose()
|
@@ -643,21 +638,17 @@ class EdgeIndex(Tensor):
|
|
643
638
|
elif self.is_undirected and self._indptr is not None:
|
644
639
|
rowptr = self._indptr
|
645
640
|
else:
|
646
|
-
rowptr = self._T_indptr =
|
647
|
-
row,
|
648
|
-
self.get_num_rows(),
|
649
|
-
out_int32=self.dtype != torch.int64,
|
650
|
-
)
|
641
|
+
rowptr = self._T_indptr = index2ptr(row, self.get_num_rows())
|
651
642
|
|
652
643
|
return (rowptr, col), perm
|
653
644
|
|
654
645
|
@assert_sorted
|
655
|
-
def get_csc(self) -> Tuple[Tuple[Tensor, Tensor],
|
646
|
+
def get_csc(self) -> Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]:
|
656
647
|
r"""Returns the compressed CSC representation
|
657
648
|
:obj:`(colptr, row), perm` in case :class:`EdgeIndex` is sorted.
|
658
649
|
"""
|
659
650
|
if self.is_sorted_by_col:
|
660
|
-
return (self.get_indptr(), self[0]),
|
651
|
+
return (self.get_indptr(), self._data[0]), None
|
661
652
|
|
662
653
|
assert self.is_sorted_by_row
|
663
654
|
(row, col), perm = self._sort_by_transpose()
|
@@ -667,11 +658,7 @@ class EdgeIndex(Tensor):
|
|
667
658
|
elif self.is_undirected and self._indptr is not None:
|
668
659
|
colptr = self._indptr
|
669
660
|
else:
|
670
|
-
colptr = self._T_indptr =
|
671
|
-
col,
|
672
|
-
self.get_num_cols(),
|
673
|
-
out_int32=self.dtype != torch.int64,
|
674
|
-
)
|
661
|
+
colptr = self._T_indptr = index2ptr(col, self.get_num_cols())
|
675
662
|
|
676
663
|
return (colptr, row), perm
|
677
664
|
|
@@ -710,11 +697,32 @@ class EdgeIndex(Tensor):
|
|
710
697
|
|
711
698
|
# Methods #################################################################
|
712
699
|
|
700
|
+
def share_memory_(self) -> 'EdgeIndex':
|
701
|
+
"""""" # noqa: D419
|
702
|
+
self._data.share_memory_()
|
703
|
+
if self._indptr is not None:
|
704
|
+
self._indptr.share_memory_()
|
705
|
+
if self._T_perm is not None:
|
706
|
+
self._T_perm.share_memory_()
|
707
|
+
if self._T_index[0] is not None:
|
708
|
+
self._T_index[0].share_memory_()
|
709
|
+
if self._T_index[1] is not None:
|
710
|
+
self._T_index[1].share_memory_()
|
711
|
+
if self._T_indptr is not None:
|
712
|
+
self._T_indptr.share_memory_()
|
713
|
+
if self._value is not None:
|
714
|
+
self._value.share_memory_()
|
715
|
+
return self
|
716
|
+
|
717
|
+
def is_shared(self) -> bool:
|
718
|
+
"""""" # noqa: D419
|
719
|
+
return self._data.is_shared()
|
720
|
+
|
713
721
|
def as_tensor(self) -> Tensor:
|
714
722
|
r"""Zero-copies the :class:`EdgeIndex` representation back to a
|
715
723
|
:class:`torch.Tensor` representation.
|
716
724
|
"""
|
717
|
-
return self.
|
725
|
+
return self._data
|
718
726
|
|
719
727
|
def sort_by(
|
720
728
|
self,
|
@@ -735,7 +743,7 @@ class EdgeIndex(Tensor):
|
|
735
743
|
sort_order = SortOrder(sort_order)
|
736
744
|
|
737
745
|
if self._sort_order == sort_order: # Nothing to do.
|
738
|
-
return SortReturnType(self,
|
746
|
+
return SortReturnType(self, None)
|
739
747
|
|
740
748
|
if self.is_sorted:
|
741
749
|
(row, col), perm = self._sort_by_transpose()
|
@@ -743,12 +751,12 @@ class EdgeIndex(Tensor):
|
|
743
751
|
|
744
752
|
# Otherwise, perform sorting:
|
745
753
|
elif sort_order == SortOrder.ROW:
|
746
|
-
row, perm = index_sort(self[0], self.get_num_rows(), stable)
|
747
|
-
edge_index = torch.stack([row, self[1][perm]], dim=0)
|
754
|
+
row, perm = index_sort(self._data[0], self.get_num_rows(), stable)
|
755
|
+
edge_index = torch.stack([row, self._data[1][perm]], dim=0)
|
748
756
|
|
749
757
|
else:
|
750
|
-
col, perm = index_sort(self[1], self.get_num_cols(), stable)
|
751
|
-
edge_index = torch.stack([self[0][perm], col], dim=0)
|
758
|
+
col, perm = index_sort(self._data[1], self.get_num_cols(), stable)
|
759
|
+
edge_index = torch.stack([self._data[0][perm], col], dim=0)
|
752
760
|
|
753
761
|
out = self.__class__(edge_index)
|
754
762
|
|
@@ -798,7 +806,7 @@ class EdgeIndex(Tensor):
|
|
798
806
|
size = size + value.size()[1:] # type: ignore
|
799
807
|
|
800
808
|
out = torch.full(size, fill_value, dtype=dtype, device=self.device)
|
801
|
-
out[self[0], self[1]] = value if value is not None else 1
|
809
|
+
out[self._data[0], self._data[1]] = value if value is not None else 1
|
802
810
|
|
803
811
|
return out
|
804
812
|
|
@@ -812,19 +820,28 @@ class EdgeIndex(Tensor):
|
|
812
820
|
:obj:`1.0`. (default: :obj:`None`)
|
813
821
|
"""
|
814
822
|
value = self._get_value() if value is None else value
|
815
|
-
|
816
|
-
|
823
|
+
|
824
|
+
if not torch_geometric.typing.WITH_PT21:
|
825
|
+
out = torch.sparse_coo_tensor(
|
826
|
+
indices=self._data,
|
827
|
+
values=value,
|
828
|
+
size=self.get_sparse_size(),
|
829
|
+
device=self.device,
|
830
|
+
requires_grad=value.requires_grad,
|
831
|
+
)
|
832
|
+
if self.is_sorted_by_row:
|
833
|
+
out = out._coalesced_(True)
|
834
|
+
return out
|
835
|
+
|
836
|
+
return torch.sparse_coo_tensor(
|
837
|
+
indices=self._data,
|
817
838
|
values=value,
|
818
839
|
size=self.get_sparse_size(),
|
819
840
|
device=self.device,
|
820
841
|
requires_grad=value.requires_grad,
|
842
|
+
is_coalesced=True if self.is_sorted_by_row else None,
|
821
843
|
)
|
822
844
|
|
823
|
-
if self.is_sorted_by_row:
|
824
|
-
out = out._coalesced_(True)
|
825
|
-
|
826
|
-
return out
|
827
|
-
|
828
845
|
def to_sparse_csr( # type: ignore
|
829
846
|
self,
|
830
847
|
value: Optional[Tensor] = None,
|
@@ -838,7 +855,10 @@ class EdgeIndex(Tensor):
|
|
838
855
|
:obj:`1.0`. (default: :obj:`None`)
|
839
856
|
"""
|
840
857
|
(rowptr, col), perm = self.get_csr()
|
841
|
-
|
858
|
+
if value is not None and perm is not None:
|
859
|
+
value = value[perm]
|
860
|
+
elif value is None:
|
861
|
+
value = self._get_value()
|
842
862
|
|
843
863
|
return torch.sparse_csr_tensor(
|
844
864
|
crow_indices=rowptr,
|
@@ -866,7 +886,10 @@ class EdgeIndex(Tensor):
|
|
866
886
|
"'to_sparse_csc' not supported for PyTorch < 1.12")
|
867
887
|
|
868
888
|
(colptr, row), perm = self.get_csc()
|
869
|
-
|
889
|
+
if value is not None and perm is not None:
|
890
|
+
value = value[perm]
|
891
|
+
elif value is None:
|
892
|
+
value = self._get_value()
|
870
893
|
|
871
894
|
return torch.sparse_csc_tensor(
|
872
895
|
ccol_indices=colptr,
|
@@ -916,8 +939,8 @@ class EdgeIndex(Tensor):
|
|
916
939
|
(default: :obj:`None`)
|
917
940
|
"""
|
918
941
|
return SparseTensor(
|
919
|
-
row=self[0],
|
920
|
-
col=self[1],
|
942
|
+
row=self._data[0],
|
943
|
+
col=self._data[1],
|
921
944
|
rowptr=self._indptr if self.is_sorted_by_row else None,
|
922
945
|
value=value,
|
923
946
|
sparse_sizes=self.get_sparse_size(),
|
@@ -925,7 +948,7 @@ class EdgeIndex(Tensor):
|
|
925
948
|
trust_data=True,
|
926
949
|
)
|
927
950
|
|
928
|
-
# TODO
|
951
|
+
# TODO Investigate how to avoid overlapping return types here.
|
929
952
|
@overload
|
930
953
|
def matmul( # type: ignore
|
931
954
|
self,
|
@@ -1034,93 +1057,148 @@ class EdgeIndex(Tensor):
|
|
1034
1057
|
f"(got {start})")
|
1035
1058
|
|
1036
1059
|
if dim == 0:
|
1037
|
-
|
1038
|
-
|
1060
|
+
if self.is_sorted_by_row:
|
1061
|
+
(rowptr, col), _ = self.get_csr()
|
1062
|
+
rowptr = rowptr.narrow(0, start, length + 1)
|
1063
|
+
|
1064
|
+
if rowptr.numel() < 2:
|
1065
|
+
row, col = self._data[0, :0], self._data[1, :0]
|
1066
|
+
rowptr = None
|
1067
|
+
num_rows = 0
|
1068
|
+
else:
|
1069
|
+
col = col[rowptr[0]:rowptr[-1]]
|
1070
|
+
rowptr = rowptr - rowptr[0]
|
1071
|
+
num_rows = rowptr.numel() - 1
|
1072
|
+
|
1073
|
+
row = torch.arange(
|
1074
|
+
num_rows,
|
1075
|
+
dtype=col.dtype,
|
1076
|
+
device=col.device,
|
1077
|
+
).repeat_interleave(
|
1078
|
+
rowptr.diff(),
|
1079
|
+
output_size=col.numel(),
|
1080
|
+
)
|
1039
1081
|
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
else:
|
1045
|
-
col = col[rowptr[0]:rowptr[-1]]
|
1046
|
-
rowptr = rowptr - rowptr[0]
|
1047
|
-
num_rows = rowptr.numel() - 1
|
1048
|
-
|
1049
|
-
row = torch.arange(
|
1050
|
-
num_rows,
|
1051
|
-
dtype=col.dtype,
|
1052
|
-
device=col.device,
|
1053
|
-
).repeat_interleave(
|
1054
|
-
rowptr.diff(),
|
1055
|
-
output_size=col.numel(),
|
1082
|
+
edge_index = EdgeIndex(
|
1083
|
+
torch.stack([row, col], dim=0),
|
1084
|
+
sparse_size=(num_rows, self.sparse_size(1)),
|
1085
|
+
sort_order='row',
|
1056
1086
|
)
|
1087
|
+
edge_index._indptr = rowptr
|
1088
|
+
return edge_index
|
1057
1089
|
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1090
|
+
else:
|
1091
|
+
mask = self._data[0] >= start
|
1092
|
+
mask &= self._data[0] < (start + length)
|
1093
|
+
offset = torch.tensor([[start], [0]], device=self.device)
|
1094
|
+
edge_index = self[:, mask].sub_(offset) # type: ignore
|
1095
|
+
edge_index._sparse_size = (length, edge_index._sparse_size[1])
|
1096
|
+
return edge_index
|
1097
|
+
|
1098
|
+
else:
|
1099
|
+
assert dim == 1
|
1065
1100
|
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1101
|
+
if self.is_sorted_by_col:
|
1102
|
+
(colptr, row), _ = self.get_csc()
|
1103
|
+
colptr = colptr.narrow(0, start, length + 1)
|
1069
1104
|
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1105
|
+
if colptr.numel() < 2:
|
1106
|
+
row, col = self._data[0, :0], self._data[1, :0]
|
1107
|
+
colptr = None
|
1108
|
+
num_cols = 0
|
1109
|
+
else:
|
1110
|
+
row = row[colptr[0]:colptr[-1]]
|
1111
|
+
colptr = colptr - colptr[0]
|
1112
|
+
num_cols = colptr.numel() - 1
|
1113
|
+
|
1114
|
+
col = torch.arange(
|
1115
|
+
num_cols,
|
1116
|
+
dtype=row.dtype,
|
1117
|
+
device=row.device,
|
1118
|
+
).repeat_interleave(
|
1119
|
+
colptr.diff(),
|
1120
|
+
output_size=row.numel(),
|
1121
|
+
)
|
1122
|
+
|
1123
|
+
edge_index = EdgeIndex(
|
1124
|
+
torch.stack([row, col], dim=0),
|
1125
|
+
sparse_size=(self.sparse_size(0), num_cols),
|
1126
|
+
sort_order='col',
|
1086
1127
|
)
|
1128
|
+
edge_index._indptr = colptr
|
1129
|
+
return edge_index
|
1087
1130
|
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1131
|
+
else:
|
1132
|
+
mask = self._data[1] >= start
|
1133
|
+
mask &= self._data[1] < (start + length)
|
1134
|
+
offset = torch.tensor([[0], [start]], device=self.device)
|
1135
|
+
edge_index = self[:, mask].sub_(offset) # type: ignore
|
1136
|
+
edge_index._sparse_size = (edge_index._sparse_size[0], length)
|
1137
|
+
return edge_index
|
1138
|
+
|
1139
|
+
def to_vector(self) -> Tensor:
|
1140
|
+
r"""Converts :class:`EdgeIndex` into a one-dimensional index
|
1141
|
+
vector representation.
|
1142
|
+
"""
|
1143
|
+
num_rows, num_cols = self.get_sparse_size()
|
1144
|
+
|
1145
|
+
if num_rows * num_cols > torch_geometric.typing.MAX_INT64:
|
1146
|
+
raise ValueError("'to_vector()' will result in an overflow")
|
1147
|
+
|
1148
|
+
return self._data[0] * num_rows + self._data[1]
|
1149
|
+
|
1150
|
+
# PyTorch/Python builtins #################################################
|
1095
1151
|
|
1096
1152
|
def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1153
|
+
attrs = ['_data']
|
1154
|
+
if self._indptr is not None:
|
1155
|
+
attrs.append('_indptr')
|
1156
|
+
if self._T_perm is not None:
|
1157
|
+
attrs.append('_T_perm')
|
1158
|
+
# TODO We cannot save `_T_index` for now since it is stored as tuple.
|
1159
|
+
if self._T_indptr is not None:
|
1160
|
+
attrs.append('_T_indptr')
|
1161
|
+
|
1162
|
+
ctx = (
|
1163
|
+
self._sparse_size,
|
1164
|
+
self._sort_order,
|
1165
|
+
self._is_undirected,
|
1166
|
+
self._cat_metadata,
|
1167
|
+
)
|
1168
|
+
|
1169
|
+
return attrs, ctx
|
1104
1170
|
|
1105
1171
|
@staticmethod
|
1106
1172
|
def __tensor_unflatten__(
|
1107
|
-
inner_tensors:
|
1173
|
+
inner_tensors: Dict[str, Any],
|
1108
1174
|
ctx: Tuple[Any, ...],
|
1109
|
-
|
1110
|
-
|
1175
|
+
outer_size: Tuple[int, ...],
|
1176
|
+
outer_stride: Tuple[int, ...],
|
1111
1177
|
) -> 'EdgeIndex':
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1178
|
+
edge_index = EdgeIndex(
|
1179
|
+
inner_tensors['_data'],
|
1180
|
+
sparse_size=ctx[0],
|
1181
|
+
sort_order=ctx[1],
|
1182
|
+
is_undirected=ctx[2],
|
1183
|
+
)
|
1184
|
+
|
1185
|
+
edge_index._indptr = inner_tensors.get('_indptr', None)
|
1186
|
+
edge_index._T_perm = inner_tensors.get('_T_perm', None)
|
1187
|
+
edge_index._T_indptr = inner_tensors.get('_T_indptr', None)
|
1188
|
+
edge_index._cat_metadata = ctx[3]
|
1189
|
+
|
1190
|
+
return edge_index
|
1191
|
+
|
1192
|
+
# Prevent auto-wrapping outputs back into the proper subclass type:
|
1193
|
+
__torch_function__ = torch._C._disabled_torch_function_impl
|
1116
1194
|
|
1117
1195
|
@classmethod
|
1118
|
-
def
|
1196
|
+
def __torch_dispatch__(
|
1119
1197
|
cls: Type,
|
1120
|
-
func: Callable,
|
1121
|
-
types:
|
1122
|
-
args: Tuple[Any, ...] = (),
|
1123
|
-
kwargs: Optional[Dict[
|
1198
|
+
func: Callable[..., Any],
|
1199
|
+
types: Iterable[Type[Any]],
|
1200
|
+
args: Iterable[Tuple[Any, ...]] = (),
|
1201
|
+
kwargs: Optional[Dict[Any, Any]] = None,
|
1124
1202
|
) -> Any:
|
1125
1203
|
# `EdgeIndex` should be treated as a regular PyTorch tensor for all
|
1126
1204
|
# standard PyTorch functionalities. However,
|
@@ -1136,53 +1214,69 @@ class EdgeIndex(Tensor):
|
|
1136
1214
|
if func in HANDLED_FUNCTIONS:
|
1137
1215
|
return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))
|
1138
1216
|
|
1139
|
-
# For all other PyTorch functions, we
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
class SortReturnType(NamedTuple):
|
1145
|
-
values: EdgeIndex
|
1146
|
-
indices: Union[Tensor, slice]
|
1217
|
+
# For all other PyTorch functions, we treat them as vanilla tensors.
|
1218
|
+
args = pytree.tree_map_only(EdgeIndex, lambda x: x._data, args)
|
1219
|
+
if kwargs is not None:
|
1220
|
+
kwargs = pytree.tree_map_only(EdgeIndex, lambda x: x._data, kwargs)
|
1221
|
+
return func(*args, **(kwargs or {}))
|
1147
1222
|
|
1223
|
+
def __repr__(self) -> str: # type: ignore
|
1224
|
+
prefix = f'{self.__class__.__name__}('
|
1225
|
+
indent = len(prefix)
|
1226
|
+
tensor_str = torch._tensor_str._tensor_str(self._data, indent)
|
1148
1227
|
|
1149
|
-
|
1150
|
-
|
1151
|
-
tensor: EdgeIndex,
|
1152
|
-
*,
|
1153
|
-
tensor_contents: Optional[str] = None,
|
1154
|
-
) -> str:
|
1155
|
-
# Monkey-patch `torch._tensor_str._add_suffixes`. There might exist better
|
1156
|
-
# solutions to attach additional metadata, but this seems to be the most
|
1157
|
-
# straightforward one to inherit most of the `torch.Tensor` print logic:
|
1158
|
-
orig_fn = torch._tensor_str._add_suffixes
|
1159
|
-
|
1160
|
-
def _add_suffixes(
|
1161
|
-
tensor_str: str,
|
1162
|
-
suffixes: List[str],
|
1163
|
-
indent: int,
|
1164
|
-
force_newline: bool,
|
1165
|
-
) -> str:
|
1166
|
-
|
1167
|
-
num_rows, num_cols = tensor.sparse_size()
|
1228
|
+
suffixes = []
|
1229
|
+
num_rows, num_cols = self.sparse_size()
|
1168
1230
|
if num_rows is not None or num_cols is not None:
|
1169
1231
|
size_repr = f"({num_rows or '?'}, {num_cols or '?'})"
|
1170
1232
|
suffixes.append(f'sparse_size={size_repr}')
|
1233
|
+
suffixes.append(f'nnz={self._data.size(1)}')
|
1234
|
+
if (self.device.type != torch._C._get_default_device()
|
1235
|
+
or (self.device.type == 'cuda'
|
1236
|
+
and torch.cuda.current_device() != self.device.index)
|
1237
|
+
or (self.device.type == 'mps')):
|
1238
|
+
suffixes.append(f"device='{self.device}'")
|
1239
|
+
if self.dtype != torch.int64:
|
1240
|
+
suffixes.append(f'dtype={self.dtype}')
|
1241
|
+
if self.is_sorted:
|
1242
|
+
suffixes.append(f'sort_order={self.sort_order}')
|
1243
|
+
if self.is_undirected:
|
1244
|
+
suffixes.append('is_undirected=True')
|
1171
1245
|
|
1172
|
-
|
1246
|
+
return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
|
1247
|
+
indent, force_newline=False)
|
1173
1248
|
|
1174
|
-
|
1175
|
-
suffixes.append(f'sort_order={tensor.sort_order}')
|
1249
|
+
# Helpers #################################################################
|
1176
1250
|
|
1177
|
-
|
1178
|
-
|
1251
|
+
def _shallow_copy(self) -> 'EdgeIndex':
|
1252
|
+
out = EdgeIndex(self._data)
|
1253
|
+
out._sparse_size = self._sparse_size
|
1254
|
+
out._sort_order = self._sort_order
|
1255
|
+
out._is_undirected = self._is_undirected
|
1256
|
+
out._indptr = self._indptr
|
1257
|
+
out._T_perm = self._T_perm
|
1258
|
+
out._T_index = self._T_index
|
1259
|
+
out._T_indptr = self._T_indptr
|
1260
|
+
out._value = self._value
|
1261
|
+
out._cat_metadata = self._cat_metadata
|
1262
|
+
return out
|
1179
1263
|
|
1180
|
-
|
1264
|
+
def _clear_metadata(self) -> 'EdgeIndex':
|
1265
|
+
self._sparse_size = (None, None)
|
1266
|
+
self._sort_order = None
|
1267
|
+
self._is_undirected = False
|
1268
|
+
self._indptr = None
|
1269
|
+
self._T_perm = None
|
1270
|
+
self._T_index = (None, None)
|
1271
|
+
self._T_indptr = None
|
1272
|
+
self._value = None
|
1273
|
+
self._cat_metadata = None
|
1274
|
+
return self
|
1181
1275
|
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1276
|
+
|
1277
|
+
class SortReturnType(NamedTuple):
|
1278
|
+
values: EdgeIndex
|
1279
|
+
indices: Optional[Tensor]
|
1186
1280
|
|
1187
1281
|
|
1188
1282
|
def apply_(
|
@@ -1190,15 +1284,24 @@ def apply_(
|
|
1190
1284
|
fn: Callable,
|
1191
1285
|
*args: Any,
|
1192
1286
|
**kwargs: Any,
|
1193
|
-
) -> EdgeIndex:
|
1287
|
+
) -> Union[EdgeIndex, Tensor]:
|
1288
|
+
|
1289
|
+
data = fn(tensor._data, *args, **kwargs)
|
1290
|
+
|
1291
|
+
if data.dtype not in INDEX_DTYPES:
|
1292
|
+
return data
|
1194
1293
|
|
1195
|
-
|
1196
|
-
|
1294
|
+
if tensor._data.data_ptr() != data.data_ptr():
|
1295
|
+
out = EdgeIndex(data)
|
1296
|
+
else: # In-place:
|
1297
|
+
tensor._data = data
|
1298
|
+
out = tensor
|
1197
1299
|
|
1198
1300
|
# Copy metadata:
|
1199
|
-
out._sparse_size = tensor.
|
1301
|
+
out._sparse_size = tensor._sparse_size
|
1200
1302
|
out._sort_order = tensor._sort_order
|
1201
1303
|
out._is_undirected = tensor._is_undirected
|
1304
|
+
out._cat_metadata = tensor._cat_metadata
|
1202
1305
|
|
1203
1306
|
# Convert cache (but do not consider `_value`):
|
1204
1307
|
if tensor._indptr is not None:
|
@@ -1220,77 +1323,68 @@ def apply_(
|
|
1220
1323
|
return out
|
1221
1324
|
|
1222
1325
|
|
1223
|
-
@implements(
|
1224
|
-
|
1225
|
-
def clone(tensor: EdgeIndex) -> EdgeIndex:
|
1226
|
-
return apply_(tensor, Tensor.clone)
|
1227
|
-
|
1228
|
-
|
1229
|
-
@implements(Tensor.to)
|
1230
|
-
def to(
|
1326
|
+
@implements(aten.clone.default)
|
1327
|
+
def _clone(
|
1231
1328
|
tensor: EdgeIndex,
|
1232
|
-
|
1233
|
-
|
1234
|
-
) ->
|
1235
|
-
out = apply_(tensor,
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
@implements(Tensor.int)
|
1240
|
-
def _int(tensor: EdgeIndex) -> EdgeIndex:
|
1241
|
-
return to(tensor, torch.int32)
|
1242
|
-
|
1243
|
-
|
1244
|
-
@implements(Tensor.long)
|
1245
|
-
def long(tensor: EdgeIndex, *args: Any, **kwargs: Any) -> EdgeIndex:
|
1246
|
-
return to(tensor, torch.int64)
|
1247
|
-
|
1248
|
-
|
1249
|
-
@implements(Tensor.cpu)
|
1250
|
-
def cpu(tensor: EdgeIndex, *args: Any, **kwargs: Any) -> EdgeIndex:
|
1251
|
-
return apply_(tensor, Tensor.cpu, *args, **kwargs)
|
1329
|
+
*,
|
1330
|
+
memory_format: torch.memory_format = torch.preserve_format,
|
1331
|
+
) -> EdgeIndex:
|
1332
|
+
out = apply_(tensor, aten.clone.default, memory_format=memory_format)
|
1333
|
+
assert isinstance(out, EdgeIndex)
|
1334
|
+
return out
|
1252
1335
|
|
1253
1336
|
|
1254
|
-
@implements(
|
1255
|
-
def
|
1337
|
+
@implements(aten._to_copy.default)
|
1338
|
+
def _to_copy(
|
1256
1339
|
tensor: EdgeIndex,
|
1257
|
-
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1340
|
+
*,
|
1341
|
+
dtype: Optional[torch.dtype] = None,
|
1342
|
+
layout: Optional[torch.layout] = None,
|
1343
|
+
device: Optional[torch.device] = None,
|
1344
|
+
pin_memory: bool = False,
|
1345
|
+
non_blocking: bool = False,
|
1346
|
+
memory_format: Optional[torch.memory_format] = None,
|
1347
|
+
) -> Union[EdgeIndex, Tensor]:
|
1348
|
+
return apply_(
|
1349
|
+
tensor,
|
1350
|
+
aten._to_copy.default,
|
1351
|
+
dtype=dtype,
|
1352
|
+
layout=layout,
|
1353
|
+
device=device,
|
1354
|
+
pin_memory=pin_memory,
|
1355
|
+
non_blocking=non_blocking,
|
1356
|
+
memory_format=memory_format,
|
1357
|
+
)
|
1261
1358
|
|
1262
1359
|
|
1263
|
-
@implements(
|
1264
|
-
def
|
1265
|
-
return
|
1360
|
+
@implements(aten.alias.default)
|
1361
|
+
def _alias(tensor: EdgeIndex) -> EdgeIndex:
|
1362
|
+
return tensor._shallow_copy()
|
1266
1363
|
|
1267
1364
|
|
1268
|
-
@implements(
|
1269
|
-
def
|
1270
|
-
|
1365
|
+
@implements(aten._pin_memory.default)
|
1366
|
+
def _pin_memory(tensor: EdgeIndex) -> EdgeIndex:
|
1367
|
+
out = apply_(tensor, aten._pin_memory.default)
|
1368
|
+
assert isinstance(out, EdgeIndex)
|
1369
|
+
return out
|
1271
1370
|
|
1272
1371
|
|
1273
|
-
@implements(
|
1274
|
-
def
|
1372
|
+
@implements(aten.cat.default)
|
1373
|
+
def _cat(
|
1275
1374
|
tensors: List[Union[EdgeIndex, Tensor]],
|
1276
1375
|
dim: int = 0,
|
1277
|
-
*,
|
1278
|
-
out: Optional[Tensor] = None,
|
1279
1376
|
) -> Union[EdgeIndex, Tensor]:
|
1280
1377
|
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
output = Tensor.__torch_function__(torch.cat, (Tensor, ), (tensors, dim),
|
1285
|
-
dict(out=out))
|
1378
|
+
data_list = pytree.tree_map_only(EdgeIndex, lambda x: x._data, tensors)
|
1379
|
+
data = aten.cat.default(data_list, dim=dim)
|
1286
1380
|
|
1287
1381
|
if dim != 1 and dim != -1: # No valid `EdgeIndex` anymore.
|
1288
|
-
return
|
1382
|
+
return data
|
1289
1383
|
|
1290
1384
|
if any([not isinstance(tensor, EdgeIndex) for tensor in tensors]):
|
1291
|
-
return
|
1385
|
+
return data
|
1292
1386
|
|
1293
|
-
|
1387
|
+
out = EdgeIndex(data)
|
1294
1388
|
|
1295
1389
|
nnz_list = [t.size(1) for t in tensors]
|
1296
1390
|
sparse_size_list = [t.sparse_size() for t in tensors] # type: ignore
|
@@ -1312,36 +1406,31 @@ def cat(
|
|
1312
1406
|
total_num_cols = None
|
1313
1407
|
break
|
1314
1408
|
assert isinstance(total_num_cols, int)
|
1315
|
-
|
1409
|
+
total_num_cols = max(num_cols, total_num_cols)
|
1316
1410
|
|
1317
|
-
|
1411
|
+
out._sparse_size = (total_num_rows, total_num_cols)
|
1318
1412
|
|
1319
1413
|
# Post-process `is_undirected`:
|
1320
|
-
|
1414
|
+
out._is_undirected = all(is_undirected_list)
|
1321
1415
|
|
1322
|
-
|
1416
|
+
out._cat_metadata = CatMetadata(
|
1323
1417
|
nnz=nnz_list,
|
1324
1418
|
sparse_size=sparse_size_list,
|
1325
1419
|
sort_order=sort_order_list,
|
1326
1420
|
is_undirected=is_undirected_list,
|
1327
1421
|
)
|
1328
1422
|
|
1329
|
-
return
|
1423
|
+
return out
|
1330
1424
|
|
1331
1425
|
|
1332
|
-
@implements(
|
1333
|
-
|
1334
|
-
def flip(
|
1426
|
+
@implements(aten.flip.default)
|
1427
|
+
def _flip(
|
1335
1428
|
input: EdgeIndex,
|
1336
|
-
dims: Union[
|
1337
|
-
) ->
|
1338
|
-
|
1339
|
-
if isinstance(dims, int):
|
1340
|
-
dims = [dims]
|
1341
|
-
assert isinstance(dims, (tuple, list))
|
1429
|
+
dims: Union[List[int], Tuple[int, ...]],
|
1430
|
+
) -> EdgeIndex:
|
1342
1431
|
|
1343
|
-
|
1344
|
-
out =
|
1432
|
+
data = aten.flip.default(input._data, dims)
|
1433
|
+
out = EdgeIndex(data)
|
1345
1434
|
|
1346
1435
|
out._value = input._value
|
1347
1436
|
out._is_undirected = input.is_undirected
|
@@ -1364,238 +1453,309 @@ def flip(
|
|
1364
1453
|
return out
|
1365
1454
|
|
1366
1455
|
|
1367
|
-
@implements(
|
1368
|
-
|
1369
|
-
def index_select(
|
1456
|
+
@implements(aten.index_select.default)
|
1457
|
+
def _index_select(
|
1370
1458
|
input: EdgeIndex,
|
1371
1459
|
dim: int,
|
1372
1460
|
index: Tensor,
|
1373
|
-
*,
|
1374
|
-
out: Optional[Tensor] = None,
|
1375
1461
|
) -> Union[EdgeIndex, Tensor]:
|
1376
1462
|
|
1377
|
-
|
1378
|
-
torch.index_select, (Tensor, ), (input, dim, index), dict(out=out))
|
1463
|
+
out = aten.index_select.default(input._data, dim, index)
|
1379
1464
|
|
1380
1465
|
if dim == 1 or dim == -1:
|
1381
|
-
|
1382
|
-
|
1466
|
+
out = EdgeIndex(out)
|
1467
|
+
out._sparse_size = input.sparse_size()
|
1383
1468
|
|
1384
|
-
return
|
1469
|
+
return out
|
1385
1470
|
|
1386
1471
|
|
1387
|
-
@implements(
|
1388
|
-
|
1389
|
-
def narrow(
|
1472
|
+
@implements(aten.slice.Tensor)
|
1473
|
+
def _slice(
|
1390
1474
|
input: EdgeIndex,
|
1391
1475
|
dim: int,
|
1392
|
-
start:
|
1393
|
-
|
1476
|
+
start: Optional[int] = None,
|
1477
|
+
end: Optional[int] = None,
|
1478
|
+
step: int = 1,
|
1394
1479
|
) -> Union[EdgeIndex, Tensor]:
|
1395
1480
|
|
1396
|
-
|
1397
|
-
|
1481
|
+
if ((start is None or start <= 0)
|
1482
|
+
and (end is None or end > input.size(dim)) and step == 1):
|
1483
|
+
return input._shallow_copy() # No-op.
|
1484
|
+
|
1485
|
+
out = aten.slice.Tensor(input._data, dim, start, end, step)
|
1398
1486
|
|
1399
1487
|
if dim == 1 or dim == -1:
|
1400
|
-
|
1488
|
+
if step != 1:
|
1489
|
+
out = out.contiguous()
|
1490
|
+
|
1491
|
+
out = EdgeIndex(out)
|
1401
1492
|
out._sparse_size = input.sparse_size()
|
1402
1493
|
# NOTE We could potentially maintain `rowptr`/`colptr` attributes here,
|
1403
1494
|
# but it is not really clear if this is worth it. The most important
|
1404
1495
|
# information, the sort order, needs to be maintained though:
|
1405
|
-
|
1496
|
+
if step >= 0:
|
1497
|
+
out._sort_order = input._sort_order
|
1498
|
+
else:
|
1499
|
+
if input._sort_order == SortOrder.ROW:
|
1500
|
+
out._sort_order = SortOrder.COL
|
1501
|
+
elif input._sort_order == SortOrder.COL:
|
1502
|
+
out._sort_order = SortOrder.ROW
|
1406
1503
|
|
1407
1504
|
return out
|
1408
1505
|
|
1409
1506
|
|
1410
|
-
@implements(Tensor
|
1411
|
-
def
|
1412
|
-
|
1413
|
-
|
1507
|
+
@implements(aten.index.Tensor)
|
1508
|
+
def _index(
|
1509
|
+
input: Union[EdgeIndex, Tensor],
|
1510
|
+
indices: List[Optional[Union[Tensor, EdgeIndex]]],
|
1511
|
+
) -> Union[EdgeIndex, Tensor]:
|
1512
|
+
|
1513
|
+
if not isinstance(input, EdgeIndex):
|
1514
|
+
indices = pytree.tree_map_only(EdgeIndex, lambda x: x._data, indices)
|
1515
|
+
return aten.index.Tensor(input, indices)
|
1516
|
+
|
1517
|
+
out = aten.index.Tensor(input._data, indices)
|
1518
|
+
|
1519
|
+
if len(indices) != 2 or indices[0] is not None:
|
1520
|
+
return out
|
1414
1521
|
|
1415
|
-
|
1416
|
-
|
1417
|
-
def is_last_dim_select(i: Any) -> bool:
|
1418
|
-
# Maps to true for `__getitem__` requests of the form
|
1419
|
-
# `tensor[..., index]` or `tensor[:, index]`.
|
1420
|
-
if not isinstance(i, tuple) or len(i) != 2:
|
1421
|
-
return False
|
1422
|
-
if i[0] == Ellipsis:
|
1423
|
-
return True
|
1424
|
-
if not isinstance(i[0], slice):
|
1425
|
-
return False
|
1426
|
-
return i[0].start is None and i[0].stop is None and i[0].step is None
|
1522
|
+
index = indices[1]
|
1523
|
+
assert isinstance(index, Tensor)
|
1427
1524
|
|
1428
|
-
|
1525
|
+
out = EdgeIndex(out)
|
1429
1526
|
|
1430
1527
|
# 1. `edge_index[:, mask]` or `edge_index[..., mask]`.
|
1431
|
-
if
|
1432
|
-
and index[1].dtype in (torch.bool, torch.uint8)):
|
1433
|
-
out = out.as_subclass(EdgeIndex)
|
1528
|
+
if index.dtype in (torch.bool, torch.uint8):
|
1434
1529
|
out._sparse_size = input.sparse_size()
|
1435
1530
|
out._sort_order = input._sort_order
|
1436
1531
|
|
1437
|
-
# 2. `edge_index[:, index]` or `edge_index[..., index]`.
|
1438
|
-
elif is_valid and isinstance(index[1], Tensor):
|
1439
|
-
out = out.as_subclass(EdgeIndex)
|
1532
|
+
else: # 2. `edge_index[:, index]` or `edge_index[..., index]`.
|
1440
1533
|
out._sparse_size = input.sparse_size()
|
1441
1534
|
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1446
|
-
|
1447
|
-
|
1535
|
+
return out
|
1536
|
+
|
1537
|
+
|
1538
|
+
@implements(aten.select.int)
|
1539
|
+
def _select(input: EdgeIndex, dim: int, index: int) -> Union[Tensor, Index]:
|
1540
|
+
out = aten.select.int(input._data, dim, index)
|
1541
|
+
|
1542
|
+
if dim == 0 or dim == -2:
|
1543
|
+
out = Index(out)
|
1544
|
+
|
1545
|
+
if index == 0 or index == -2: # Row-select:
|
1546
|
+
out._dim_size = input.sparse_size(0)
|
1547
|
+
out._is_sorted = input.is_sorted_by_row
|
1548
|
+
if input.is_sorted_by_row:
|
1549
|
+
out._indptr = input._indptr
|
1550
|
+
|
1551
|
+
else: # Col-select:
|
1552
|
+
assert index == 1 or index == -1
|
1553
|
+
out._dim_size = input.sparse_size(1)
|
1554
|
+
out._is_sorted = input.is_sorted_by_col
|
1555
|
+
if input.is_sorted_by_col:
|
1556
|
+
out._indptr = input._indptr
|
1448
1557
|
|
1449
1558
|
return out
|
1450
1559
|
|
1451
1560
|
|
1452
|
-
|
1561
|
+
@implements(aten.unbind.int)
|
1562
|
+
def _unbind(
|
1453
1563
|
input: EdgeIndex,
|
1454
|
-
|
1455
|
-
|
1564
|
+
dim: int = 0,
|
1565
|
+
) -> Union[List[Index], List[Tensor]]:
|
1566
|
+
|
1567
|
+
if dim == 0 or dim == -2:
|
1568
|
+
row = input[0]
|
1569
|
+
assert isinstance(row, Index)
|
1570
|
+
col = input[1]
|
1571
|
+
assert isinstance(col, Index)
|
1572
|
+
return [row, col]
|
1573
|
+
|
1574
|
+
return aten.unbind.int(input._data, dim)
|
1575
|
+
|
1576
|
+
|
1577
|
+
@implements(aten.add.Tensor)
|
1578
|
+
def _add(
|
1579
|
+
input: EdgeIndex,
|
1580
|
+
other: Union[int, Tensor, EdgeIndex],
|
1581
|
+
*,
|
1456
1582
|
alpha: int = 1,
|
1457
1583
|
) -> Union[EdgeIndex, Tensor]:
|
1458
1584
|
|
1459
|
-
|
1585
|
+
out = aten.add.Tensor(
|
1586
|
+
input._data,
|
1587
|
+
other._data if isinstance(other, EdgeIndex) else other,
|
1588
|
+
alpha=alpha,
|
1589
|
+
)
|
1590
|
+
|
1591
|
+
if out.dtype not in INDEX_DTYPES:
|
1460
1592
|
return out
|
1461
1593
|
if out.dim() != 2 or out.size(0) != 2:
|
1462
1594
|
return out
|
1463
1595
|
|
1464
|
-
|
1596
|
+
out = EdgeIndex(out)
|
1597
|
+
|
1598
|
+
if isinstance(other, Tensor) and other.numel() <= 1:
|
1599
|
+
other = int(other)
|
1465
1600
|
|
1466
1601
|
if isinstance(other, int):
|
1467
1602
|
size = maybe_add(input._sparse_size, other, alpha)
|
1468
1603
|
assert len(size) == 2
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1472
|
-
|
1473
|
-
|
1474
|
-
elif isinstance(other, Tensor) and other.numel() <= 1:
|
1475
|
-
size = maybe_add(input._sparse_size, int(other), alpha)
|
1476
|
-
assert len(size) == 2
|
1477
|
-
output._sparse_size = size
|
1478
|
-
output._sort_order = input._sort_order
|
1479
|
-
output._is_undirected = input.is_undirected
|
1480
|
-
output._T_perm = input._T_perm
|
1604
|
+
out._sparse_size = size
|
1605
|
+
out._sort_order = input._sort_order
|
1606
|
+
out._is_undirected = input.is_undirected
|
1607
|
+
out._T_perm = input._T_perm
|
1481
1608
|
|
1482
1609
|
elif isinstance(other, Tensor) and other.size() == (2, 1):
|
1483
1610
|
size = maybe_add(input._sparse_size, other.view(-1).tolist(), alpha)
|
1484
1611
|
assert len(size) == 2
|
1485
|
-
|
1486
|
-
|
1487
|
-
output._T_perm = input._T_perm
|
1612
|
+
out._sparse_size = size
|
1613
|
+
out._sort_order = input._sort_order
|
1488
1614
|
if torch.equal(other[0], other[1]):
|
1489
|
-
|
1615
|
+
out._is_undirected = input.is_undirected
|
1616
|
+
out._T_perm = input._T_perm
|
1490
1617
|
|
1491
1618
|
elif isinstance(other, EdgeIndex):
|
1492
1619
|
size = maybe_add(input._sparse_size, other._sparse_size, alpha)
|
1493
1620
|
assert len(size) == 2
|
1494
|
-
|
1621
|
+
out._sparse_size = size
|
1495
1622
|
|
1496
|
-
return
|
1623
|
+
return out
|
1497
1624
|
|
1498
1625
|
|
1499
|
-
@implements(
|
1500
|
-
|
1501
|
-
def add(
|
1626
|
+
@implements(aten.add_.Tensor)
|
1627
|
+
def add_(
|
1502
1628
|
input: EdgeIndex,
|
1503
|
-
other: Union[int, Tensor],
|
1629
|
+
other: Union[int, Tensor, EdgeIndex],
|
1504
1630
|
*,
|
1505
1631
|
alpha: int = 1,
|
1506
|
-
|
1507
|
-
) -> Union[EdgeIndex, Tensor]:
|
1632
|
+
) -> EdgeIndex:
|
1508
1633
|
|
1509
|
-
|
1510
|
-
|
1634
|
+
sparse_size = input._sparse_size
|
1635
|
+
sort_order = input._sort_order
|
1636
|
+
is_undirected = input._is_undirected
|
1637
|
+
T_perm = input._T_perm
|
1638
|
+
input._clear_metadata()
|
1511
1639
|
|
1512
|
-
|
1640
|
+
aten.add_.Tensor(
|
1641
|
+
input._data,
|
1642
|
+
other._data if isinstance(other, EdgeIndex) else other,
|
1643
|
+
alpha=alpha,
|
1644
|
+
)
|
1513
1645
|
|
1646
|
+
if isinstance(other, Tensor) and other.numel() <= 1:
|
1647
|
+
other = int(other)
|
1514
1648
|
|
1515
|
-
|
1516
|
-
|
1517
|
-
|
1518
|
-
|
1519
|
-
|
1520
|
-
|
1521
|
-
|
1649
|
+
if isinstance(other, int):
|
1650
|
+
size = maybe_add(sparse_size, other, alpha)
|
1651
|
+
assert len(size) == 2
|
1652
|
+
input._sparse_size = size
|
1653
|
+
input._sort_order = sort_order
|
1654
|
+
input._is_undirected = is_undirected
|
1655
|
+
input._T_perm = T_perm
|
1522
1656
|
|
1523
|
-
|
1524
|
-
|
1657
|
+
elif isinstance(other, Tensor) and other.size() == (2, 1):
|
1658
|
+
size = maybe_add(sparse_size, other.view(-1).tolist(), alpha)
|
1659
|
+
assert len(size) == 2
|
1660
|
+
input._sparse_size = size
|
1661
|
+
input._sort_order = sort_order
|
1662
|
+
if torch.equal(other[0], other[1]):
|
1663
|
+
input._is_undirected = is_undirected
|
1664
|
+
input._T_perm = T_perm
|
1525
1665
|
|
1526
|
-
|
1666
|
+
elif isinstance(other, EdgeIndex):
|
1667
|
+
size = maybe_add(sparse_size, other._sparse_size, alpha)
|
1668
|
+
assert len(size) == 2
|
1669
|
+
input._sparse_size = size
|
1670
|
+
|
1671
|
+
return input
|
1527
1672
|
|
1528
1673
|
|
1529
|
-
|
1674
|
+
@implements(aten.sub.Tensor)
|
1675
|
+
def _sub(
|
1530
1676
|
input: EdgeIndex,
|
1531
|
-
other: Union[int, Tensor],
|
1532
|
-
|
1677
|
+
other: Union[int, Tensor, EdgeIndex],
|
1678
|
+
*,
|
1533
1679
|
alpha: int = 1,
|
1534
1680
|
) -> Union[EdgeIndex, Tensor]:
|
1535
1681
|
|
1536
|
-
|
1682
|
+
out = aten.sub.Tensor(
|
1683
|
+
input._data,
|
1684
|
+
other._data if isinstance(other, EdgeIndex) else other,
|
1685
|
+
alpha=alpha,
|
1686
|
+
)
|
1687
|
+
|
1688
|
+
if out.dtype not in INDEX_DTYPES:
|
1537
1689
|
return out
|
1538
1690
|
if out.dim() != 2 or out.size(0) != 2:
|
1539
1691
|
return out
|
1540
1692
|
|
1541
|
-
|
1693
|
+
out = EdgeIndex(out)
|
1694
|
+
|
1695
|
+
if isinstance(other, Tensor) and other.numel() <= 1:
|
1696
|
+
other = int(other)
|
1542
1697
|
|
1543
1698
|
if isinstance(other, int):
|
1544
1699
|
size = maybe_sub(input._sparse_size, other, alpha)
|
1545
1700
|
assert len(size) == 2
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
|
1551
|
-
elif isinstance(other, Tensor) and other.numel() <= 1:
|
1552
|
-
size = maybe_sub(input._sparse_size, int(other), alpha)
|
1553
|
-
assert len(size) == 2
|
1554
|
-
output._sparse_size = size
|
1555
|
-
output._sort_order = input._sort_order
|
1556
|
-
output._is_undirected = input.is_undirected
|
1557
|
-
output._T_perm = input._T_perm
|
1701
|
+
out._sparse_size = size
|
1702
|
+
out._sort_order = input._sort_order
|
1703
|
+
out._is_undirected = input.is_undirected
|
1704
|
+
out._T_perm = input._T_perm
|
1558
1705
|
|
1559
1706
|
elif isinstance(other, Tensor) and other.size() == (2, 1):
|
1560
1707
|
size = maybe_sub(input._sparse_size, other.view(-1).tolist(), alpha)
|
1561
1708
|
assert len(size) == 2
|
1562
|
-
|
1563
|
-
|
1564
|
-
output._T_perm = input._T_perm
|
1709
|
+
out._sparse_size = size
|
1710
|
+
out._sort_order = input._sort_order
|
1565
1711
|
if torch.equal(other[0], other[1]):
|
1566
|
-
|
1712
|
+
out._is_undirected = input.is_undirected
|
1713
|
+
out._T_perm = input._T_perm
|
1567
1714
|
|
1568
|
-
return
|
1715
|
+
return out
|
1569
1716
|
|
1570
1717
|
|
1571
|
-
@implements(
|
1572
|
-
|
1573
|
-
def sub(
|
1718
|
+
@implements(aten.sub_.Tensor)
|
1719
|
+
def sub_(
|
1574
1720
|
input: EdgeIndex,
|
1575
|
-
other: Union[int, Tensor],
|
1721
|
+
other: Union[int, Tensor, EdgeIndex],
|
1576
1722
|
*,
|
1577
1723
|
alpha: int = 1,
|
1578
|
-
|
1579
|
-
) -> Union[EdgeIndex, Tensor]:
|
1724
|
+
) -> EdgeIndex:
|
1580
1725
|
|
1581
|
-
|
1582
|
-
|
1726
|
+
sparse_size = input._sparse_size
|
1727
|
+
sort_order = input._sort_order
|
1728
|
+
is_undirected = input._is_undirected
|
1729
|
+
T_perm = input._T_perm
|
1730
|
+
input._clear_metadata()
|
1583
1731
|
|
1584
|
-
|
1732
|
+
aten.sub_.Tensor(
|
1733
|
+
input._data,
|
1734
|
+
other._data if isinstance(other, EdgeIndex) else other,
|
1735
|
+
alpha=alpha,
|
1736
|
+
)
|
1585
1737
|
|
1738
|
+
if isinstance(other, Tensor) and other.numel() <= 1:
|
1739
|
+
other = int(other)
|
1586
1740
|
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
|
1592
|
-
|
1593
|
-
|
1741
|
+
if isinstance(other, int):
|
1742
|
+
size = maybe_sub(sparse_size, other, alpha)
|
1743
|
+
assert len(size) == 2
|
1744
|
+
input._sparse_size = size
|
1745
|
+
input._sort_order = sort_order
|
1746
|
+
input._is_undirected = is_undirected
|
1747
|
+
input._T_perm = T_perm
|
1594
1748
|
|
1595
|
-
|
1596
|
-
|
1749
|
+
elif isinstance(other, Tensor) and other.size() == (2, 1):
|
1750
|
+
size = maybe_sub(sparse_size, other.view(-1).tolist(), alpha)
|
1751
|
+
assert len(size) == 2
|
1752
|
+
input._sparse_size = size
|
1753
|
+
input._sort_order = sort_order
|
1754
|
+
if torch.equal(other[0], other[1]):
|
1755
|
+
input._is_undirected = is_undirected
|
1756
|
+
input._T_perm = T_perm
|
1597
1757
|
|
1598
|
-
return
|
1758
|
+
return input
|
1599
1759
|
|
1600
1760
|
|
1601
1761
|
# Sparse-Dense Matrix Multiplication ##########################################
|
@@ -1620,13 +1780,13 @@ def _torch_sparse_spmm(
|
|
1620
1780
|
if not transpose:
|
1621
1781
|
assert input.is_sorted_by_row
|
1622
1782
|
(rowptr, col), _ = input.get_csr()
|
1623
|
-
row = input[0]
|
1783
|
+
row = input._data[0]
|
1624
1784
|
if other.requires_grad and reduce in ['sum', 'mean']:
|
1625
1785
|
(colptr, _), perm = input.get_csc()
|
1626
1786
|
else:
|
1627
1787
|
assert input.is_sorted_by_col
|
1628
1788
|
(rowptr, col), _ = input.get_csc()
|
1629
|
-
row = input[1]
|
1789
|
+
row = input._data[1]
|
1630
1790
|
if other.requires_grad and reduce in ['sum', 'mean']:
|
1631
1791
|
(colptr, _), perm = input.get_csr()
|
1632
1792
|
|
@@ -1699,7 +1859,7 @@ class _TorchSPMM(torch.autograd.Function):
|
|
1699
1859
|
adj = input.to_sparse_csr(value)
|
1700
1860
|
else:
|
1701
1861
|
(colptr, row), perm = input.get_csc()
|
1702
|
-
if value is not None:
|
1862
|
+
if value is not None and perm is not None:
|
1703
1863
|
value = value[perm]
|
1704
1864
|
else:
|
1705
1865
|
value = input._get_value()
|
@@ -1715,7 +1875,7 @@ class _TorchSPMM(torch.autograd.Function):
|
|
1715
1875
|
adj = input.to_sparse_csc(value).t()
|
1716
1876
|
else:
|
1717
1877
|
(rowptr, col), perm = input.get_csr()
|
1718
|
-
if value is not None:
|
1878
|
+
if value is not None and perm is not None:
|
1719
1879
|
value = value[perm]
|
1720
1880
|
else:
|
1721
1881
|
value = input._get_value()
|
@@ -1746,14 +1906,16 @@ def _scatter_spmm(
|
|
1746
1906
|
from torch_geometric.utils import scatter
|
1747
1907
|
|
1748
1908
|
if not transpose:
|
1749
|
-
other_j = other[input[1]]
|
1750
|
-
index = input[0]
|
1909
|
+
other_j = other[input._data[1]]
|
1910
|
+
index = input._data[0]
|
1911
|
+
dim_size = input.get_sparse_size(0)
|
1751
1912
|
else:
|
1752
|
-
other_j = other[input[0]]
|
1753
|
-
index = input[1]
|
1913
|
+
other_j = other[input._data[0]]
|
1914
|
+
index = input._data[1]
|
1915
|
+
dim_size = input.get_sparse_size(1)
|
1754
1916
|
|
1755
1917
|
other_j = other_j * value.view(-1, 1) if value is not None else other_j
|
1756
|
-
return scatter(other_j, index, 0, dim_size=
|
1918
|
+
return scatter(other_j, index, 0, dim_size=dim_size, reduce=reduce)
|
1757
1919
|
|
1758
1920
|
|
1759
1921
|
def _spmm(
|
@@ -1775,7 +1937,7 @@ def _spmm(
|
|
1775
1937
|
if transpose and not input.is_sorted_by_col:
|
1776
1938
|
cls_name = input.__class__.__name__
|
1777
1939
|
raise ValueError(f"'matmul(..., transpose=True)' requires "
|
1778
|
-
f"'{cls_name}' to be sorted by
|
1940
|
+
f"'{cls_name}' to be sorted by columns")
|
1779
1941
|
|
1780
1942
|
if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()
|
1781
1943
|
and other.is_cuda): # pragma: no cover
|
@@ -1786,17 +1948,17 @@ def _spmm(
|
|
1786
1948
|
return _torch_sparse_spmm(input, other, value, reduce, transpose)
|
1787
1949
|
return _scatter_spmm(input, other, value, reduce, transpose)
|
1788
1950
|
|
1789
|
-
if
|
1790
|
-
|
1951
|
+
if torch_geometric.typing.WITH_PT20:
|
1952
|
+
if reduce == 'sum' or reduce == 'add':
|
1953
|
+
return _TorchSPMM.apply(input, other, value, 'sum', transpose)
|
1791
1954
|
|
1792
|
-
|
1793
|
-
|
1794
|
-
|
1795
|
-
|
1955
|
+
if reduce == 'mean':
|
1956
|
+
out = _TorchSPMM.apply(input, other, value, 'sum', transpose)
|
1957
|
+
count = input.get_indptr().diff()
|
1958
|
+
return out / count.clamp_(min=1).to(out.dtype).view(-1, 1)
|
1796
1959
|
|
1797
|
-
|
1798
|
-
|
1799
|
-
return _TorchSPMM.apply(input, other, value, reduce, transpose)
|
1960
|
+
if not other.is_cuda and not other.requires_grad:
|
1961
|
+
return _TorchSPMM.apply(input, other, value, reduce, transpose)
|
1800
1962
|
|
1801
1963
|
if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling():
|
1802
1964
|
return _torch_sparse_spmm(input, other, value, reduce, transpose)
|
@@ -1858,7 +2020,7 @@ def matmul(
|
|
1858
2020
|
else:
|
1859
2021
|
raise NotImplementedError
|
1860
2022
|
|
1861
|
-
edge_index = edge_index
|
2023
|
+
edge_index = EdgeIndex(edge_index)
|
1862
2024
|
edge_index._sort_order = SortOrder.ROW
|
1863
2025
|
edge_index._sparse_size = (out.size(0), out.size(1))
|
1864
2026
|
edge_index._indptr = rowptr
|
@@ -1866,20 +2028,36 @@ def matmul(
|
|
1866
2028
|
return edge_index, out.values()
|
1867
2029
|
|
1868
2030
|
|
1869
|
-
@implements(
|
1870
|
-
|
1871
|
-
@implements(Tensor.matmul)
|
1872
|
-
def _matmul1(
|
2031
|
+
@implements(aten.mm.default)
|
2032
|
+
def _mm(
|
1873
2033
|
input: EdgeIndex,
|
1874
2034
|
other: Union[Tensor, EdgeIndex],
|
1875
2035
|
) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:
|
1876
2036
|
return matmul(input, other)
|
1877
2037
|
|
1878
2038
|
|
1879
|
-
@implements(
|
1880
|
-
def
|
2039
|
+
@implements(aten._sparse_addmm.default)
|
2040
|
+
def _addmm(
|
2041
|
+
input: Tensor,
|
1881
2042
|
mat1: EdgeIndex,
|
1882
|
-
mat2:
|
1883
|
-
|
1884
|
-
|
1885
|
-
|
2043
|
+
mat2: Tensor,
|
2044
|
+
beta: float = 1.0,
|
2045
|
+
alpha: float = 1.0,
|
2046
|
+
) -> Tensor:
|
2047
|
+
assert input.abs().sum() == 0.0
|
2048
|
+
out = matmul(mat1, mat2)
|
2049
|
+
assert isinstance(out, Tensor)
|
2050
|
+
return alpha * out if alpha != 1.0 else out
|
2051
|
+
|
2052
|
+
|
2053
|
+
if hasattr(aten, '_sparse_mm_reduce_impl'):
|
2054
|
+
|
2055
|
+
@implements(aten._sparse_mm_reduce_impl.default)
|
2056
|
+
def _mm_reduce(
|
2057
|
+
mat1: EdgeIndex,
|
2058
|
+
mat2: Tensor,
|
2059
|
+
reduce: ReduceType = 'sum',
|
2060
|
+
) -> Tuple[Tensor, Tensor]:
|
2061
|
+
out = matmul(mat1, mat2, reduce=reduce)
|
2062
|
+
assert isinstance(out, Tensor)
|
2063
|
+
return out, out # We return a dummy tensor for `argout` for now.
|