pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
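
The bulk of this release is a rewrite of torch_geometric/edge_index.py (entry 93 above, +616 -438), shown in full below. EdgeIndex moves from a plain tensor subclass intercepted via `__torch_function__` to a wrapper subclass: it is now built with `Tensor._make_wrapper_subclass`, keeps its payload in a `_data` attribute, and intercepts individual ATen operators via `__torch_dispatch__`. As a minimal, hedged sketch of that pattern (the class `WrappedIndex` is hypothetical and not part of PyG; only the mechanism mirrors the diff):

import torch
import torch.utils._pytree as pytree
from torch import Tensor


class WrappedIndex(Tensor):
    # The wrapper subclass keeps the real tensor in an attribute instead of
    # acting as the storage itself (cf. `EdgeIndex._data` in the diff below):
    _data: Tensor

    @staticmethod
    def __new__(cls, data: Tensor) -> 'WrappedIndex':
        out = Tensor._make_wrapper_subclass(
            cls,
            size=data.size(),
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            requires_grad=False,
        )
        out._data = data
        return out

    # Disable `__torch_function__` so that all interposition happens at the
    # ATen level (the diff applies the same trick to `EdgeIndex`):
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap wrapper arguments and run the underlying ATen op on plain
        # tensors; a real implementation would re-attach metadata here:
        args = pytree.tree_map_only(WrappedIndex, lambda x: x._data, args)
        kwargs = pytree.tree_map_only(WrappedIndex, lambda x: x._data,
                                      kwargs or {})
        return func(*args, **kwargs)


x = WrappedIndex(torch.tensor([[0, 1, 1], [1, 0, 2]]))
print(x.sum())  # Routed through `__torch_dispatch__` onto `x._data`.

The payoff, visible throughout the diff, is that each override registered via `@implements(aten.<op>.<overload>)` receives raw tensors and decides explicitly which sparse-size and sort-order metadata survives the operation.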
--- a/torch_geometric/edge_index.py
+++ b/torch_geometric/edge_index.py
@@ -1,16 +1,15 @@
 import functools
-import typing
 from enum import Enum
 from typing import (
     Any,
     Callable,
     Dict,
+    Iterable,
     List,
     Literal,
     NamedTuple,
     Optional,
     Sequence,
-    Set,
     Tuple,
     Type,
     Union,
@@ -19,23 +18,17 @@ from typing import (
 )
 
 import torch
+import torch.utils._pytree as pytree
 from torch import Tensor
 
 import torch_geometric.typing
-from torch_geometric import is_compiling
-from torch_geometric.typing import SparseTensor
+from torch_geometric import Index, is_compiling
+from torch_geometric.index import index2ptr, ptr2index
+from torch_geometric.typing import INDEX_DTYPES, SparseTensor
 
-HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
+aten = torch.ops.aten
 
-if torch_geometric.typing.WITH_PT20:
-    SUPPORTED_DTYPES: Set[torch.dtype] = {
-        torch.int32,
-        torch.int64,
-    }
-elif not typing.TYPE_CHECKING:  # pragma: no cover
-    SUPPORTED_DTYPES: Set[torch.dtype] = {
-        torch.int64,
-    }
+HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
 
 ReduceType = Literal['sum', 'mean', 'amin', 'amax', 'add', 'min', 'max']
 PYG_REDUCE: Dict[ReduceType, ReduceType] = {
@@ -114,16 +107,11 @@ def maybe_sub(
                  for v, o in zip(value, other))
 
 
-def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
-    index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
-    return index.repeat_interleave(ptr.diff(), output_size=output_size)
-
-
 def assert_valid_dtype(tensor: Tensor) -> None:
-    if tensor.dtype not in SUPPORTED_DTYPES:
+    if tensor.dtype not in INDEX_DTYPES:
         raise ValueError(f"'EdgeIndex' holds an unsupported data type "
                          f"(got '{tensor.dtype}', but expected one of "
-                         f"{SUPPORTED_DTYPES})")
+                         f"{INDEX_DTYPES})")
 
 
 def assert_two_dimensional(tensor: Tensor) -> None:
@@ -136,7 +124,7 @@ def assert_two_dimensional(tensor: Tensor) -> None:
 
 
 def assert_contiguous(tensor: Tensor) -> None:
-    if not tensor.is_contiguous():
+    if not tensor[0].is_contiguous() or not tensor[1].is_contiguous():
         raise ValueError("'EdgeIndex' needs to be contiguous. Please call "
                          "`edge_index.contiguous()` before proceeding.")
 
@@ -150,13 +138,13 @@ def assert_symmetric(size: Tuple[Optional[int], Optional[int]]) -> None:
 
 def assert_sorted(func: Callable) -> Callable:
     @functools.wraps(func)
-    def wrapper(*args: Any, **kwargs: Any) -> Any:
-        if not args[0].is_sorted:
-            cls_name = args[0].__class__.__name__
+    def wrapper(self: 'EdgeIndex', *args: Any, **kwargs: Any) -> Any:
+        if not self.is_sorted:
+            cls_name = self.__class__.__name__
             raise ValueError(
                 f"Cannot call '{func.__name__}' since '{cls_name}' is not "
                 f"sorted. Please call `{cls_name}.sort_by(...)` first.")
-        return func(*args, **kwargs)
+        return func(self, *args, **kwargs)
 
     return wrapper
 
@@ -185,7 +173,7 @@ class EdgeIndex(Tensor):
     :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
     lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
 
-    This representation ensures for optimal computation in GNN message passing
+    This representation ensures optimal computation in GNN message passing
     schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
     workflows.
 
@@ -229,12 +217,12 @@ class EdgeIndex(Tensor):
     # for a basic tutorial on how to subclass `torch.Tensor`.
 
     # The underlying tensor representation:
-    _data: Optional[Tensor] = None
+    _data: Tensor
 
     # The size of the underlying sparse matrix:
     _sparse_size: Tuple[Optional[int], Optional[int]] = (None, None)
 
-    # Whether the `edge_index` represented is non-sorted (`None`), or sorted
+    # Whether the `edge_index` representation is non-sorted (`None`), or sorted
     # based on row or column values.
     _sort_order: Optional[SortOrder] = None
 
@@ -260,6 +248,7 @@ class EdgeIndex(Tensor):
     # original metadata to be able to reconstruct individual edge indices:
     _cat_metadata: Optional[CatMetadata] = None
 
+    @staticmethod
     def __new__(
         cls: Type,
         data: Any,
@@ -336,21 +325,26 @@ class EdgeIndex(Tensor):
         elif sparse_size[0] is None and sparse_size[1] is not None:
             sparse_size = (sparse_size[1], sparse_size[1])
 
-        if torch_geometric.typing.WITH_PT112:
-            out = super().__new__(cls, data)
-        else:
-            out = Tensor._make_subclass(cls, data)
+        out = Tensor._make_wrapper_subclass(  # type: ignore
+            cls,
+            size=data.size(),
+            strides=data.stride(),
+            dtype=data.dtype,
+            device=data.device,
+            layout=data.layout,
+            requires_grad=False,
+        )
+        assert isinstance(out, EdgeIndex)
 
         # Attach metadata:
-        assert isinstance(out, EdgeIndex)
-        if torch_geometric.typing.WITH_PT22:
-            out._data = data
+        out._data = data
         out._sparse_size = sparse_size
         out._sort_order = None if sort_order is None else SortOrder(sort_order)
         out._is_undirected = is_undirected
         out._indptr = indptr
 
         if isinstance(data, cls):  # If passed `EdgeIndex`, inherit metadata:
+            out._data = data._data
             out._T_perm = data._T_perm
             out._T_index = data._T_index
             out._T_indptr = data._T_indptr
@@ -378,41 +372,43 @@ class EdgeIndex(Tensor):
         * the sort order is correctly set.
         * indices are bidirectional in case it is specified as undirected.
         """
-        assert_valid_dtype(self)
-        assert_two_dimensional(self)
-        assert_contiguous(self)
+        assert_valid_dtype(self._data)
+        assert_two_dimensional(self._data)
+        assert_contiguous(self._data)
         if self.is_undirected:
             assert_symmetric(self.sparse_size())
 
-        if self.numel() > 0 and self.min() < 0:
+        if self.numel() > 0 and self._data.min() < 0:
             raise ValueError(f"'{self.__class__.__name__}' contains negative "
                              f"indices (got {int(self.min())})")
 
         if (self.numel() > 0 and self.num_rows is not None
-                and self[0].max() >= self.num_rows):
+                and self._data[0].max() >= self.num_rows):
             raise ValueError(f"'{self.__class__.__name__}' contains larger "
                              f"indices than its number of rows "
-                             f"(got {int(self[0].max())}, but expected values "
-                             f"smaller than {self.num_rows})")
+                             f"(got {int(self._data[0].max())}, but expected "
+                             f"values smaller than {self.num_rows})")
 
         if (self.numel() > 0 and self.num_cols is not None
-                and self[1].max() >= self.num_cols):
+                and self._data[1].max() >= self.num_cols):
             raise ValueError(f"'{self.__class__.__name__}' contains larger "
                              f"indices than its number of columns "
-                             f"(got {int(self[1].max())}, but expected values "
-                             f"smaller than {self.num_cols})")
+                             f"(got {int(self._data[1].max())}, but expected "
+                             f"values smaller than {self.num_cols})")
 
-        if self.is_sorted_by_row and (self[0].diff() < 0).any():
+        if self.is_sorted_by_row and (self._data[0].diff() < 0).any():
             raise ValueError(f"'{self.__class__.__name__}' is not sorted by "
                              f"row indices")
 
-        if self.is_sorted_by_col and (self[1].diff() < 0).any():
+        if self.is_sorted_by_col and (self._data[1].diff() < 0).any():
             raise ValueError(f"'{self.__class__.__name__}' is not sorted by "
                              f"column indices")
 
         if self.is_undirected:
-            flat_index1 = (self[0] * self.get_num_rows() + self[1]).sort()[0]
-            flat_index2 = (self[1] * self.get_num_cols() + self[0]).sort()[0]
+            flat_index1 = self._data[0] * self.get_num_rows() + self._data[1]
+            flat_index1 = flat_index1.sort()[0]
+            flat_index2 = self._data[1] * self.get_num_cols() + self._data[0]
+            flat_index2 = flat_index2.sort()[0]
             if not torch.equal(flat_index1, flat_index2):
                 raise ValueError(f"'{self.__class__.__name__}' is not "
                                  f"undirected")
@@ -482,6 +478,11 @@ class EdgeIndex(Tensor):
         r"""Returns whether indices are bidirectional."""
         return self._is_undirected
 
+    @property
+    def dtype(self) -> torch.dtype:  # type: ignore
+        # TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
+        return self._data.dtype
+
     # Cache Interface #########################################################
 
     @overload
@@ -511,11 +512,11 @@ class EdgeIndex(Tensor):
             return size
 
         if self.is_undirected:
-            size = int(self.max()) + 1 if self.numel() > 0 else 0
+            size = int(self._data.max()) + 1 if self.numel() > 0 else 0
             self._sparse_size = (size, size)
             return size
 
-        size = int(self[dim].max()) + 1 if self.numel() > 0 else 0
+        size = int(self._data[dim].max()) + 1 if self.numel() > 0 else 0
         self._sparse_size = set_tuple_item(self._sparse_size, dim, size)
         return size
 
@@ -551,11 +552,8 @@ class EdgeIndex(Tensor):
         if ptr is None or size is None:
             return None
 
-        if ptr.numel() - 1 == size:
-            return ptr
-
-        if ptr.numel() - 1 > size:
-            return None
+        if ptr.numel() - 1 >= size:
+            return ptr[:size + 1]
 
         fill_value = ptr.new_full(
             (size - ptr.numel() + 1, ),
@@ -599,11 +597,7 @@ class EdgeIndex(Tensor):
             return self._T_indptr
 
         dim = 0 if self.is_sorted_by_row else 1
-        self._indptr = torch._convert_indices_from_coo_to_csr(
-            self[dim],
-            self.get_sparse_size(dim),
-            out_int32=self.dtype != torch.int64,
-        )
+        self._indptr = index2ptr(self._data[dim], self.get_sparse_size(dim))
 
         return self._indptr
 
@@ -614,13 +608,14 @@ class EdgeIndex(Tensor):
         dim = 1 if self.is_sorted_by_row else 0
 
         if self._T_perm is None:
-            index, perm = index_sort(self[dim], self.get_sparse_size(dim))
+            max_index = self.get_sparse_size(dim)
+            index, perm = index_sort(self._data[dim], max_index)
             self._T_index = set_tuple_item(self._T_index, dim, index)
-            self._T_perm = perm
+            self._T_perm = perm.to(self.dtype)
 
         if self._T_index[1 - dim] is None:
             self._T_index = set_tuple_item(  #
-                self._T_index, 1 - dim, self[1 - dim][self._T_perm])
+                self._T_index, 1 - dim, self._data[1 - dim][self._T_perm])
 
         row, col = self._T_index
         assert row is not None and col is not None
@@ -628,12 +623,12 @@ class EdgeIndex(Tensor):
         return (row, col), self._T_perm
 
     @assert_sorted
-    def get_csr(self) -> Tuple[Tuple[Tensor, Tensor], Union[Tensor, slice]]:
+    def get_csr(self) -> Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]:
         r"""Returns the compressed CSR representation
         :obj:`(rowptr, col), perm` in case :class:`EdgeIndex` is sorted.
         """
         if self.is_sorted_by_row:
-            return (self.get_indptr(), self[1]), slice(None, None, None)
+            return (self.get_indptr(), self._data[1]), None
 
         assert self.is_sorted_by_col
         (row, col), perm = self._sort_by_transpose()
@@ -643,21 +638,17 @@ class EdgeIndex(Tensor):
         elif self.is_undirected and self._indptr is not None:
             rowptr = self._indptr
         else:
-            rowptr = self._T_indptr = torch._convert_indices_from_coo_to_csr(
-                row,
-                self.get_num_rows(),
-                out_int32=self.dtype != torch.int64,
-            )
+            rowptr = self._T_indptr = index2ptr(row, self.get_num_rows())
 
         return (rowptr, col), perm
 
     @assert_sorted
-    def get_csc(self) -> Tuple[Tuple[Tensor, Tensor], Union[Tensor, slice]]:
+    def get_csc(self) -> Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]:
         r"""Returns the compressed CSC representation
         :obj:`(colptr, row), perm` in case :class:`EdgeIndex` is sorted.
         """
         if self.is_sorted_by_col:
-            return (self.get_indptr(), self[0]), slice(None, None, None)
+            return (self.get_indptr(), self._data[0]), None
 
         assert self.is_sorted_by_row
         (row, col), perm = self._sort_by_transpose()
@@ -667,11 +658,7 @@ class EdgeIndex(Tensor):
         elif self.is_undirected and self._indptr is not None:
             colptr = self._indptr
         else:
-            colptr = self._T_indptr = torch._convert_indices_from_coo_to_csr(
-                col,
-                self.get_num_cols(),
-                out_int32=self.dtype != torch.int64,
-            )
+            colptr = self._T_indptr = index2ptr(col, self.get_num_cols())
 
         return (colptr, row), perm
 
@@ -710,11 +697,32 @@ class EdgeIndex(Tensor):
 
     # Methods #################################################################
 
+    def share_memory_(self) -> 'EdgeIndex':
+        """"""  # noqa: D419
+        self._data.share_memory_()
+        if self._indptr is not None:
+            self._indptr.share_memory_()
+        if self._T_perm is not None:
+            self._T_perm.share_memory_()
+        if self._T_index[0] is not None:
+            self._T_index[0].share_memory_()
+        if self._T_index[1] is not None:
+            self._T_index[1].share_memory_()
+        if self._T_indptr is not None:
+            self._T_indptr.share_memory_()
+        if self._value is not None:
+            self._value.share_memory_()
+        return self
+
+    def is_shared(self) -> bool:
+        """"""  # noqa: D419
+        return self._data.is_shared()
+
     def as_tensor(self) -> Tensor:
         r"""Zero-copies the :class:`EdgeIndex` representation back to a
         :class:`torch.Tensor` representation.
         """
-        return self.as_subclass(Tensor)
+        return self._data
 
     def sort_by(
         self,
@@ -735,7 +743,7 @@ class EdgeIndex(Tensor):
         sort_order = SortOrder(sort_order)
 
         if self._sort_order == sort_order:  # Nothing to do.
-            return SortReturnType(self, slice(None, None, None))
+            return SortReturnType(self, None)
 
         if self.is_sorted:
             (row, col), perm = self._sort_by_transpose()
@@ -743,12 +751,12 @@ class EdgeIndex(Tensor):
 
         # Otherwise, perform sorting:
         elif sort_order == SortOrder.ROW:
-            row, perm = index_sort(self[0], self.get_num_rows(), stable)
-            edge_index = torch.stack([row, self[1][perm]], dim=0)
+            row, perm = index_sort(self._data[0], self.get_num_rows(), stable)
+            edge_index = torch.stack([row, self._data[1][perm]], dim=0)
 
         else:
-            col, perm = index_sort(self[1], self.get_num_cols(), stable)
-            edge_index = torch.stack([self[0][perm], col], dim=0)
+            col, perm = index_sort(self._data[1], self.get_num_cols(), stable)
+            edge_index = torch.stack([self._data[0][perm], col], dim=0)
 
         out = self.__class__(edge_index)
 
@@ -798,7 +806,7 @@ class EdgeIndex(Tensor):
             size = size + value.size()[1:]  # type: ignore
 
         out = torch.full(size, fill_value, dtype=dtype, device=self.device)
-        out[self[0], self[1]] = value if value is not None else 1
+        out[self._data[0], self._data[1]] = value if value is not None else 1
 
         return out
 
@@ -812,19 +820,28 @@ class EdgeIndex(Tensor):
             :obj:`1.0`. (default: :obj:`None`)
         """
         value = self._get_value() if value is None else value
-        out = torch.sparse_coo_tensor(
-            indices=self.as_tensor(),
+
+        if not torch_geometric.typing.WITH_PT21:
+            out = torch.sparse_coo_tensor(
+                indices=self._data,
+                values=value,
+                size=self.get_sparse_size(),
+                device=self.device,
+                requires_grad=value.requires_grad,
+            )
+            if self.is_sorted_by_row:
+                out = out._coalesced_(True)
+            return out
+
+        return torch.sparse_coo_tensor(
+            indices=self._data,
             values=value,
             size=self.get_sparse_size(),
             device=self.device,
             requires_grad=value.requires_grad,
+            is_coalesced=True if self.is_sorted_by_row else None,
         )
 
-        if self.is_sorted_by_row:
-            out = out._coalesced_(True)
-
-        return out
-
     def to_sparse_csr(  # type: ignore
         self,
         value: Optional[Tensor] = None,
@@ -838,7 +855,10 @@ class EdgeIndex(Tensor):
             :obj:`1.0`. (default: :obj:`None`)
         """
         (rowptr, col), perm = self.get_csr()
-        value = self._get_value() if value is None else value[perm]
+        if value is not None and perm is not None:
+            value = value[perm]
+        elif value is None:
+            value = self._get_value()
 
         return torch.sparse_csr_tensor(
             crow_indices=rowptr,
@@ -866,7 +886,10 @@ class EdgeIndex(Tensor):
                 "'to_sparse_csc' not supported for PyTorch < 1.12")
 
         (colptr, row), perm = self.get_csc()
-        value = self._get_value() if value is None else value[perm]
+        if value is not None and perm is not None:
+            value = value[perm]
+        elif value is None:
+            value = self._get_value()
 
         return torch.sparse_csc_tensor(
             ccol_indices=colptr,
@@ -916,8 +939,8 @@ class EdgeIndex(Tensor):
             (default: :obj:`None`)
         """
         return SparseTensor(
-            row=self[0],
-            col=self[1],
+            row=self._data[0],
+            col=self._data[1],
             rowptr=self._indptr if self.is_sorted_by_row else None,
             value=value,
             sparse_sizes=self.get_sparse_size(),
@@ -925,7 +948,7 @@ class EdgeIndex(Tensor):
             trust_data=True,
         )
 
-    # TODO investigate how to avoid overlapping return types here.
+    # TODO Investigate how to avoid overlapping return types here.
     @overload
     def matmul(  # type: ignore
         self,
@@ -1034,93 +1057,148 @@ class EdgeIndex(Tensor):
                              f"(got {start})")
 
         if dim == 0:
-            (rowptr, col), _ = self.get_csr()
-            rowptr = rowptr.narrow(0, start, length + 1)
+            if self.is_sorted_by_row:
+                (rowptr, col), _ = self.get_csr()
+                rowptr = rowptr.narrow(0, start, length + 1)
+
+                if rowptr.numel() < 2:
+                    row, col = self._data[0, :0], self._data[1, :0]
+                    rowptr = None
+                    num_rows = 0
+                else:
+                    col = col[rowptr[0]:rowptr[-1]]
+                    rowptr = rowptr - rowptr[0]
+                    num_rows = rowptr.numel() - 1
+
+                row = torch.arange(
+                    num_rows,
+                    dtype=col.dtype,
+                    device=col.device,
+                ).repeat_interleave(
+                    rowptr.diff(),
+                    output_size=col.numel(),
+                )
 
-            if rowptr.numel() < 2:
-                row, col = self[0, :0], self[1, :0]
-                rowptr = None
-                num_rows = 0
-            else:
-                col = col[rowptr[0]:rowptr[-1]]
-                rowptr = rowptr - rowptr[0]
-                num_rows = rowptr.numel() - 1
-
-            row = torch.arange(
-                num_rows,
-                dtype=col.dtype,
-                device=col.device,
-            ).repeat_interleave(
-                rowptr.diff(),
-                output_size=col.numel(),
+                edge_index = EdgeIndex(
+                    torch.stack([row, col], dim=0),
+                    sparse_size=(num_rows, self.sparse_size(1)),
+                    sort_order='row',
                 )
+                edge_index._indptr = rowptr
+                return edge_index
 
-            edge_index = EdgeIndex(
-                torch.stack([row, col], dim=0),
-                sparse_size=(num_rows, self.sparse_size(1)),
-                sort_order='row',
-            )
-            edge_index._indptr = rowptr
-            return edge_index
+            else:
+                mask = self._data[0] >= start
+                mask &= self._data[0] < (start + length)
+                offset = torch.tensor([[start], [0]], device=self.device)
+                edge_index = self[:, mask].sub_(offset)  # type: ignore
+                edge_index._sparse_size = (length, edge_index._sparse_size[1])
+                return edge_index
+
+        else:
+            assert dim == 1
 
-        else:  # dim == 0:
-            (colptr, row), _ = self.get_csc()
-            colptr = colptr.narrow(0, start, length + 1)
+            if self.is_sorted_by_col:
+                (colptr, row), _ = self.get_csc()
+                colptr = colptr.narrow(0, start, length + 1)
 
-            if colptr.numel() < 2:
-                row, col = self[0, :0], self[1, :0]
-                colptr = None
-                num_cols = 0
-            else:
-                row = row[colptr[0]:colptr[-1]]
-                colptr = colptr - colptr[0]
-                num_cols = colptr.numel() - 1
-
-            col = torch.arange(
-                num_cols,
-                dtype=row.dtype,
-                device=row.device,
-            ).repeat_interleave(
-                colptr.diff(),
-                output_size=row.numel(),
+                if colptr.numel() < 2:
+                    row, col = self._data[0, :0], self._data[1, :0]
+                    colptr = None
+                    num_cols = 0
+                else:
+                    row = row[colptr[0]:colptr[-1]]
+                    colptr = colptr - colptr[0]
+                    num_cols = colptr.numel() - 1
+
+                col = torch.arange(
+                    num_cols,
+                    dtype=row.dtype,
+                    device=row.device,
+                ).repeat_interleave(
+                    colptr.diff(),
+                    output_size=row.numel(),
+                )
+
+                edge_index = EdgeIndex(
+                    torch.stack([row, col], dim=0),
+                    sparse_size=(self.sparse_size(0), num_cols),
+                    sort_order='col',
                 )
+                edge_index._indptr = colptr
+                return edge_index
 
-            edge_index = EdgeIndex(
-                torch.stack([row, col], dim=0),
-                sparse_size=(self.sparse_size(0), num_cols),
-                sort_order='col',
-            )
-            edge_index._indptr = colptr
-            return edge_index
+            else:
+                mask = self._data[1] >= start
+                mask &= self._data[1] < (start + length)
+                offset = torch.tensor([[0], [start]], device=self.device)
+                edge_index = self[:, mask].sub_(offset)  # type: ignore
+                edge_index._sparse_size = (edge_index._sparse_size[0], length)
+                return edge_index
+
+    def to_vector(self) -> Tensor:
+        r"""Converts :class:`EdgeIndex` into a one-dimensional index
+        vector representation.
+        """
+        num_rows, num_cols = self.get_sparse_size()
+
+        if num_rows * num_cols > torch_geometric.typing.MAX_INT64:
+            raise ValueError("'to_vector()' will result in an overflow")
+
+        return self._data[0] * num_rows + self._data[1]
+
+    # PyTorch/Python builtins #################################################
 
     def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
-        if not torch_geometric.typing.WITH_PT22:  # pragma: no cover
-            raise RuntimeError("'torch.compile' with 'EdgeIndex' only "
-                               "supported from PyTorch 2.2 onwards")
-        assert self._data is not None
-        # TODO Add `_T_index`.
-        attrs = ['_data', '_indptr', '_T_perm', '_T_indptr']
-        return attrs, ()
+        attrs = ['_data']
+        if self._indptr is not None:
+            attrs.append('_indptr')
+        if self._T_perm is not None:
+            attrs.append('_T_perm')
+        # TODO We cannot save `_T_index` for now since it is stored as tuple.
+        if self._T_indptr is not None:
+            attrs.append('_T_indptr')
+
+        ctx = (
+            self._sparse_size,
+            self._sort_order,
+            self._is_undirected,
+            self._cat_metadata,
+        )
+
+        return attrs, ctx
 
     @staticmethod
     def __tensor_unflatten__(
-        inner_tensors: Tuple[Any],
+        inner_tensors: Dict[str, Any],
         ctx: Tuple[Any, ...],
-        *args: Any,
-        **kwargs: Any,
+        outer_size: Tuple[int, ...],
+        outer_stride: Tuple[int, ...],
     ) -> 'EdgeIndex':
-        if not torch_geometric.typing.WITH_PT22:  # pragma: no cover
-            raise RuntimeError("'torch.compile' with 'EdgeIndex' only "
-                               "supported from PyTorch 2.2 onwards")
-        raise NotImplementedError
+        edge_index = EdgeIndex(
+            inner_tensors['_data'],
+            sparse_size=ctx[0],
+            sort_order=ctx[1],
+            is_undirected=ctx[2],
+        )
+
+        edge_index._indptr = inner_tensors.get('_indptr', None)
+        edge_index._T_perm = inner_tensors.get('_T_perm', None)
+        edge_index._T_indptr = inner_tensors.get('_T_indptr', None)
+        edge_index._cat_metadata = ctx[3]
+
+        return edge_index
+
+    # Prevent auto-wrapping outputs back into the proper subclass type:
+    __torch_function__ = torch._C._disabled_torch_function_impl
 
     @classmethod
-    def __torch_function__(
+    def __torch_dispatch__(
         cls: Type,
-        func: Callable,
-        types: Tuple[Type, ...],
-        args: Tuple[Any, ...] = (),
-        kwargs: Optional[Dict[str, Any]] = None,
+        func: Callable[..., Any],
+        types: Iterable[Type[Any]],
+        args: Iterable[Tuple[Any, ...]] = (),
+        kwargs: Optional[Dict[Any, Any]] = None,
     ) -> Any:
         # `EdgeIndex` should be treated as a regular PyTorch tensor for all
         # standard PyTorch functionalities. However,
@@ -1136,53 +1214,69 @@ class EdgeIndex(Tensor):
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))
 
-        # For all other PyTorch functions, we return a vanilla PyTorch tensor.
-        _types = tuple(Tensor if issubclass(t, cls) else t for t in types)
-        return Tensor.__torch_function__(func, _types, args, kwargs)
-
-
-class SortReturnType(NamedTuple):
-    values: EdgeIndex
-    indices: Union[Tensor, slice]
+        # For all other PyTorch functions, we treat them as vanilla tensors.
+        args = pytree.tree_map_only(EdgeIndex, lambda x: x._data, args)
+        if kwargs is not None:
+            kwargs = pytree.tree_map_only(EdgeIndex, lambda x: x._data, kwargs)
+        return func(*args, **(kwargs or {}))
 
+    def __repr__(self) -> str:  # type: ignore
+        prefix = f'{self.__class__.__name__}('
+        indent = len(prefix)
+        tensor_str = torch._tensor_str._tensor_str(self._data, indent)
 
-@implements(Tensor.__repr__)
-def __repr__(
-    tensor: EdgeIndex,
-    *,
-    tensor_contents: Optional[str] = None,
-) -> str:
-    # Monkey-patch `torch._tensor_str._add_suffixes`. There might exist better
-    # solutions to attach additional metadata, but this seems to be the most
-    # straightforward one to inherit most of the `torch.Tensor` print logic:
-    orig_fn = torch._tensor_str._add_suffixes
-
-    def _add_suffixes(
-        tensor_str: str,
-        suffixes: List[str],
-        indent: int,
-        force_newline: bool,
-    ) -> str:
-
-        num_rows, num_cols = tensor.sparse_size()
+        suffixes = []
+        num_rows, num_cols = self.sparse_size()
         if num_rows is not None or num_cols is not None:
             size_repr = f"({num_rows or '?'}, {num_cols or '?'})"
             suffixes.append(f'sparse_size={size_repr}')
+        suffixes.append(f'nnz={self._data.size(1)}')
+        if (self.device.type != torch._C._get_default_device()
+                or (self.device.type == 'cuda'
+                    and torch.cuda.current_device() != self.device.index)
+                or (self.device.type == 'mps')):
+            suffixes.append(f"device='{self.device}'")
+        if self.dtype != torch.int64:
+            suffixes.append(f'dtype={self.dtype}')
+        if self.is_sorted:
+            suffixes.append(f'sort_order={self.sort_order}')
+        if self.is_undirected:
+            suffixes.append('is_undirected=True')
 
-        suffixes.append(f'nnz={tensor.size(1)}')
+        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
+                                               indent, force_newline=False)
 
-        if tensor.is_sorted:
-            suffixes.append(f'sort_order={tensor.sort_order}')
+    # Helpers #################################################################
 
-        if tensor.is_undirected:
-            suffixes.append('is_undirected=True')
+    def _shallow_copy(self) -> 'EdgeIndex':
+        out = EdgeIndex(self._data)
+        out._sparse_size = self._sparse_size
+        out._sort_order = self._sort_order
+        out._is_undirected = self._is_undirected
+        out._indptr = self._indptr
+        out._T_perm = self._T_perm
+        out._T_index = self._T_index
+        out._T_indptr = self._T_indptr
+        out._value = self._value
+        out._cat_metadata = self._cat_metadata
+        return out
 
-        return orig_fn(tensor_str, suffixes, indent, force_newline)
+    def _clear_metadata(self) -> 'EdgeIndex':
+        self._sparse_size = (None, None)
+        self._sort_order = None
+        self._is_undirected = False
+        self._indptr = None
+        self._T_perm = None
+        self._T_index = (None, None)
+        self._T_indptr = None
+        self._value = None
+        self._cat_metadata = None
+        return self
 
-    torch._tensor_str._add_suffixes = _add_suffixes
-    out = torch._tensor_str._str(tensor, tensor_contents=tensor_contents)
-    torch._tensor_str._add_suffixes = orig_fn
-    return out
+
+class SortReturnType(NamedTuple):
+    values: EdgeIndex
+    indices: Optional[Tensor]
 
 
 def apply_(
@@ -1190,15 +1284,24 @@
     fn: Callable,
     *args: Any,
     **kwargs: Any,
-) -> EdgeIndex:
+) -> Union[EdgeIndex, Tensor]:
+
+    data = fn(tensor._data, *args, **kwargs)
+
+    if data.dtype not in INDEX_DTYPES:
+        return data
 
-    out = Tensor.__torch_function__(fn, (Tensor, ), (tensor, ) + args, kwargs)
-    out = out.as_subclass(EdgeIndex)
+    if tensor._data.data_ptr() != data.data_ptr():
+        out = EdgeIndex(data)
+    else:  # In-place:
+        tensor._data = data
+        out = tensor
 
     # Copy metadata:
-    out._sparse_size = tensor.sparse_size()
+    out._sparse_size = tensor._sparse_size
     out._sort_order = tensor._sort_order
     out._is_undirected = tensor._is_undirected
+    out._cat_metadata = tensor._cat_metadata
 
     # Convert cache (but do not consider `_value`):
     if tensor._indptr is not None:
@@ -1220,77 +1323,68 @@ def apply_(
     return out
 
 
-@implements(torch.clone)
-@implements(Tensor.clone)
-def clone(tensor: EdgeIndex) -> EdgeIndex:
-    return apply_(tensor, Tensor.clone)
-
-
-@implements(Tensor.to)
-def to(
+@implements(aten.clone.default)
+def _clone(
     tensor: EdgeIndex,
-    *args: Any,
-    **kwargs: Any,
-) -> Union[EdgeIndex, Tensor]:
-    out = apply_(tensor, Tensor.to, *args, **kwargs)
-    return out if out.dtype in SUPPORTED_DTYPES else out.as_tensor()
-
-
-@implements(Tensor.int)
-def _int(tensor: EdgeIndex) -> EdgeIndex:
-    return to(tensor, torch.int32)
-
-
-@implements(Tensor.long)
-def long(tensor: EdgeIndex, *args: Any, **kwargs: Any) -> EdgeIndex:
-    return to(tensor, torch.int64)
-
-
-@implements(Tensor.cpu)
-def cpu(tensor: EdgeIndex, *args: Any, **kwargs: Any) -> EdgeIndex:
-    return apply_(tensor, Tensor.cpu, *args, **kwargs)
+    *,
+    memory_format: torch.memory_format = torch.preserve_format,
+) -> EdgeIndex:
+    out = apply_(tensor, aten.clone.default, memory_format=memory_format)
+    assert isinstance(out, EdgeIndex)
+    return out
 
 
-@implements(Tensor.cuda)
-def cuda(  # pragma: no cover
+@implements(aten._to_copy.default)
+def _to_copy(
     tensor: EdgeIndex,
-    *args: Any,
-    **kwargs: Any,
-) -> EdgeIndex:
-    return apply_(tensor, Tensor.cuda, *args, **kwargs)
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+    non_blocking: bool = False,
+    memory_format: Optional[torch.memory_format] = None,
+) -> Union[EdgeIndex, Tensor]:
+    return apply_(
+        tensor,
+        aten._to_copy.default,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        pin_memory=pin_memory,
+        non_blocking=non_blocking,
+        memory_format=memory_format,
+    )
 
 
-@implements(Tensor.share_memory_)
-def share_memory_(tensor: EdgeIndex) -> EdgeIndex:
-    return apply_(tensor, Tensor.share_memory_)
+@implements(aten.alias.default)
+def _alias(tensor: EdgeIndex) -> EdgeIndex:
+    return tensor._shallow_copy()
 
 
-@implements(Tensor.contiguous)
-def contiguous(tensor: EdgeIndex) -> EdgeIndex:
-    return apply_(tensor, Tensor.contiguous)
+@implements(aten._pin_memory.default)
+def _pin_memory(tensor: EdgeIndex) -> EdgeIndex:
+    out = apply_(tensor, aten._pin_memory.default)
+    assert isinstance(out, EdgeIndex)
+    return out
 
 
-@implements(torch.cat)
-def cat(
+@implements(aten.cat.default)
+def _cat(
     tensors: List[Union[EdgeIndex, Tensor]],
     dim: int = 0,
-    *,
-    out: Optional[Tensor] = None,
 ) -> Union[EdgeIndex, Tensor]:
 
-    if len(tensors) == 1:
-        return tensors[0]
-
-    output = Tensor.__torch_function__(torch.cat, (Tensor, ), (tensors, dim),
-                                       dict(out=out))
+    data_list = pytree.tree_map_only(EdgeIndex, lambda x: x._data, tensors)
+    data = aten.cat.default(data_list, dim=dim)
 
     if dim != 1 and dim != -1:  # No valid `EdgeIndex` anymore.
-        return output
+        return data
 
     if any([not isinstance(tensor, EdgeIndex) for tensor in tensors]):
-        return output
+        return data
 
-    output = output.as_subclass(EdgeIndex)
+    out = EdgeIndex(data)
 
     nnz_list = [t.size(1) for t in tensors]
     sparse_size_list = [t.sparse_size() for t in tensors]  # type: ignore
@@ -1312,36 +1406,31 @@ def cat(
                total_num_cols = None
                break
            assert isinstance(total_num_cols, int)
-            num_cols = max(num_cols, total_num_cols)
+            total_num_cols = max(num_cols, total_num_cols)
 
-    output._sparse_size = (num_rows, num_cols)
+    out._sparse_size = (total_num_rows, total_num_cols)
 
     # Post-process `is_undirected`:
-    output._is_undirected = all(is_undirected_list)
+    out._is_undirected = all(is_undirected_list)
 
-    output._cat_metadata = CatMetadata(
+    out._cat_metadata = CatMetadata(
         nnz=nnz_list,
         sparse_size=sparse_size_list,
         sort_order=sort_order_list,
         is_undirected=is_undirected_list,
     )
 
-    return output
+    return out
 
 
-@implements(torch.flip)
-@implements(Tensor.flip)
-def flip(
+@implements(aten.flip.default)
+def _flip(
     input: EdgeIndex,
-    dims: Union[int, List[int], Tuple[int, ...]],
-) -> Union[EdgeIndex, Tensor]:
-
-    if isinstance(dims, int):
-        dims = [dims]
-    assert isinstance(dims, (tuple, list))
+    dims: Union[List[int], Tuple[int, ...]],
+) -> EdgeIndex:
 
-    out = Tensor.__torch_function__(torch.flip, (Tensor, ), (input, dims))
-    out = out.as_subclass(EdgeIndex)
+    data = aten.flip.default(input._data, dims)
+    out = EdgeIndex(data)
 
     out._value = input._value
     out._is_undirected = input.is_undirected
@@ -1364,238 +1453,309 @@ def flip(
1364
1453
  return out
1365
1454
 
1366
1455
 
1367
- @implements(torch.index_select)
1368
- @implements(Tensor.index_select)
1369
- def index_select(
1456
+ @implements(aten.index_select.default)
1457
+ def _index_select(
1370
1458
  input: EdgeIndex,
1371
1459
  dim: int,
1372
1460
  index: Tensor,
1373
- *,
1374
- out: Optional[Tensor] = None,
1375
1461
  ) -> Union[EdgeIndex, Tensor]:
1376
1462
 
1377
- output = Tensor.__torch_function__( #
1378
- torch.index_select, (Tensor, ), (input, dim, index), dict(out=out))
1463
+ out = aten.index_select.default(input._data, dim, index)
1379
1464
 
1380
1465
  if dim == 1 or dim == -1:
1381
- output = output.as_subclass(EdgeIndex)
1382
- output._sparse_size = input.sparse_size()
1466
+ out = EdgeIndex(out)
1467
+ out._sparse_size = input.sparse_size()
1383
1468
 
1384
- return output
1469
+ return out
1385
1470
 
1386
1471
 
1387
- @implements(torch.narrow)
1388
- @implements(Tensor.narrow)
1389
- def narrow(
1472
+ @implements(aten.slice.Tensor)
1473
+ def _slice(
1390
1474
  input: EdgeIndex,
1391
1475
  dim: int,
1392
- start: Union[int, Tensor],
1393
- length: int,
1476
+ start: Optional[int] = None,
1477
+ end: Optional[int] = None,
1478
+ step: int = 1,
1394
1479
  ) -> Union[EdgeIndex, Tensor]:
1395
1480
 
1396
- out = Tensor.__torch_function__( #
1397
- torch.narrow, (Tensor, ), (input, dim, start, length))
1481
+ if ((start is None or start <= 0)
1482
+ and (end is None or end > input.size(dim)) and step == 1):
1483
+ return input._shallow_copy() # No-op.
1484
+
1485
+ out = aten.slice.Tensor(input._data, dim, start, end, step)
1398
1486
 
1399
1487
  if dim == 1 or dim == -1:
1400
- out = out.as_subclass(EdgeIndex)
1488
+ if step != 1:
1489
+ out = out.contiguous()
1490
+
1491
+ out = EdgeIndex(out)
1401
1492
  out._sparse_size = input.sparse_size()
1402
1493
  # NOTE We could potentially maintain `rowptr`/`colptr` attributes here,
1403
1494
  # but it is not really clear if this is worth it. The most important
1404
1495
  # information, the sort order, needs to be maintained though:
1405
- out._sort_order = input._sort_order
1496
+ if step >= 0:
1497
+ out._sort_order = input._sort_order
1498
+ else:
1499
+ if input._sort_order == SortOrder.ROW:
1500
+ out._sort_order = SortOrder.COL
1501
+ elif input._sort_order == SortOrder.COL:
1502
+ out._sort_order = SortOrder.ROW
1406
1503
 
1407
1504
  return out
1408
1505
 
1409
1506
 
1410
- @implements(Tensor.__getitem__)
1411
- def getitem(input: EdgeIndex, index: Any) -> Union[EdgeIndex, Tensor]:
1412
- out = Tensor.__torch_function__( #
1413
- Tensor.__getitem__, (Tensor, ), (input, index))
1507
+ @implements(aten.index.Tensor)
1508
+ def _index(
1509
+ input: Union[EdgeIndex, Tensor],
1510
+ indices: List[Optional[Union[Tensor, EdgeIndex]]],
1511
+ ) -> Union[EdgeIndex, Tensor]:
1512
+
1513
+ if not isinstance(input, EdgeIndex):
1514
+ indices = pytree.tree_map_only(EdgeIndex, lambda x: x._data, indices)
1515
+ return aten.index.Tensor(input, indices)
1516
+
1517
+ out = aten.index.Tensor(input._data, indices)
1518
+
1519
+ if len(indices) != 2 or indices[0] is not None:
1520
+ return out
1414
1521
 
1415
- # There exists 3 possible index types that map back to a valid `EdgeIndex`,
1416
- # and all include selecting/filtering in the last dimension only:
1417
- def is_last_dim_select(i: Any) -> bool:
1418
- # Maps to true for `__getitem__` requests of the form
1419
- # `tensor[..., index]` or `tensor[:, index]`.
1420
- if not isinstance(i, tuple) or len(i) != 2:
1421
- return False
1422
- if i[0] == Ellipsis:
1423
- return True
1424
- if not isinstance(i[0], slice):
1425
- return False
1426
- return i[0].start is None and i[0].stop is None and i[0].step is None
1522
+ index = indices[1]
1523
+ assert isinstance(index, Tensor)
1427
1524
 
1428
- is_valid = is_last_dim_select(index)
1525
+ out = EdgeIndex(out)
1429
1526
 
1430
1527
  # 1. `edge_index[:, mask]` or `edge_index[..., mask]`.
1431
- if (is_valid and isinstance(index[1], Tensor)
1432
- and index[1].dtype in (torch.bool, torch.uint8)):
1433
- out = out.as_subclass(EdgeIndex)
1528
+ if index.dtype in (torch.bool, torch.uint8):
1434
1529
  out._sparse_size = input.sparse_size()
1435
1530
  out._sort_order = input._sort_order
1436
1531
 
1437
- # 2. `edge_index[:, index]` or `edge_index[..., index]`.
1438
- elif is_valid and isinstance(index[1], Tensor):
1439
- out = out.as_subclass(EdgeIndex)
1532
+ else: # 2. `edge_index[:, index]` or `edge_index[..., index]`.
1440
1533
  out._sparse_size = input.sparse_size()
1441
1534
 
1442
- # 3. `edge_index[:, slice]` or `edge_index[..., slice]`.
1443
- elif is_valid and isinstance(index[1], slice):
1444
- out = out.as_subclass(EdgeIndex)
1445
- out._sparse_size = input.sparse_size()
1446
- if index[1].step is None or index[1].step > 0:
1447
- out._sort_order = input._sort_order
1535
+ return out
1536
+
1537
+
1538
+ @implements(aten.select.int)
1539
+ def _select(input: EdgeIndex, dim: int, index: int) -> Union[Tensor, Index]:
1540
+ out = aten.select.int(input._data, dim, index)
1541
+
1542
+ if dim == 0 or dim == -2:
1543
+ out = Index(out)
1544
+
1545
+ if index == 0 or index == -2: # Row-select:
1546
+ out._dim_size = input.sparse_size(0)
1547
+ out._is_sorted = input.is_sorted_by_row
1548
+ if input.is_sorted_by_row:
1549
+ out._indptr = input._indptr
1550
+
1551
+ else: # Col-select:
1552
+ assert index == 1 or index == -1
1553
+ out._dim_size = input.sparse_size(1)
1554
+ out._is_sorted = input.is_sorted_by_col
1555
+ if input.is_sorted_by_col:
1556
+ out._indptr = input._indptr
1448
1557
 
1449
1558
  return out
1450
1559
 
1451
1560
 
1452
- def postprocess_add_(
1561
+ @implements(aten.unbind.int)
1562
+ def _unbind(
1453
1563
  input: EdgeIndex,
1454
- other: Union[int, Tensor],
1455
- out: Tensor,
1564
+ dim: int = 0,
1565
+ ) -> Union[List[Index], List[Tensor]]:
1566
+
1567
+ if dim == 0 or dim == -2:
1568
+ row = input[0]
1569
+ assert isinstance(row, Index)
1570
+ col = input[1]
1571
+ assert isinstance(col, Index)
1572
+ return [row, col]
1573
+
1574
+ return aten.unbind.int(input._data, dim)
1575
+
1576
+
1577
+ @implements(aten.add.Tensor)
1578
+ def _add(
1579
+ input: EdgeIndex,
1580
+ other: Union[int, Tensor, EdgeIndex],
1581
+ *,
1456
1582
  alpha: int = 1,
1457
1583
  ) -> Union[EdgeIndex, Tensor]:
1458
1584
 
1459
- if out.dtype not in SUPPORTED_DTYPES:
1585
+ out = aten.add.Tensor(
1586
+ input._data,
1587
+ other._data if isinstance(other, EdgeIndex) else other,
1588
+ alpha=alpha,
1589
+ )
1590
+
1591
+ if out.dtype not in INDEX_DTYPES:
1460
1592
  return out
1461
1593
  if out.dim() != 2 or out.size(0) != 2:
1462
1594
  return out
1463
1595
 
1464
- output: EdgeIndex = out.as_subclass(EdgeIndex)
1596
+ out = EdgeIndex(out)
1597
+
1598
+ if isinstance(other, Tensor) and other.numel() <= 1:
1599
+ other = int(other)
1465
1600
 
1466
1601
  if isinstance(other, int):
1467
1602
  size = maybe_add(input._sparse_size, other, alpha)
1468
1603
  assert len(size) == 2
1469
- output._sparse_size = size
1470
- output._sort_order = input._sort_order
1471
- output._is_undirected = input.is_undirected
1472
- output._T_perm = input._T_perm
1473
-
1474
- elif isinstance(other, Tensor) and other.numel() <= 1:
1475
- size = maybe_add(input._sparse_size, int(other), alpha)
1476
- assert len(size) == 2
1477
- output._sparse_size = size
1478
- output._sort_order = input._sort_order
1479
- output._is_undirected = input.is_undirected
1480
- output._T_perm = input._T_perm
1604
+ out._sparse_size = size
1605
+ out._sort_order = input._sort_order
1606
+ out._is_undirected = input.is_undirected
1607
+ out._T_perm = input._T_perm
1481
1608
 
1482
1609
  elif isinstance(other, Tensor) and other.size() == (2, 1):
1483
1610
  size = maybe_add(input._sparse_size, other.view(-1).tolist(), alpha)
1484
1611
  assert len(size) == 2
1485
- output._sparse_size = size
1486
- output._sort_order = input._sort_order
1487
- output._T_perm = input._T_perm
1612
+ out._sparse_size = size
1613
+ out._sort_order = input._sort_order
1488
1614
  if torch.equal(other[0], other[1]):
1489
- output._is_undirected = input.is_undirected
1615
+ out._is_undirected = input.is_undirected
1616
+ out._T_perm = input._T_perm
1490
1617
 
1491
1618
  elif isinstance(other, EdgeIndex):
1492
1619
  size = maybe_add(input._sparse_size, other._sparse_size, alpha)
1493
1620
  assert len(size) == 2
1494
- output._sparse_size = size
1621
+ out._sparse_size = size
1495
1622
 
1496
- return output
1623
+ return out
1497
1624
 
1498
1625
 
1499
-@implements(torch.add)
-@implements(Tensor.add)
-def add(
+@implements(aten.add_.Tensor)
+def add_(
     input: EdgeIndex,
-    other: Union[int, Tensor],
+    other: Union[int, Tensor, EdgeIndex],
     *,
     alpha: int = 1,
-    out: Optional[Tensor] = None,
-) -> Union[EdgeIndex, Tensor]:
+) -> EdgeIndex:

-    output = Tensor.__torch_function__(  #
-        torch.add, (Tensor, ), (input, other), dict(alpha=alpha, out=out))
+    sparse_size = input._sparse_size
+    sort_order = input._sort_order
+    is_undirected = input._is_undirected
+    T_perm = input._T_perm
+    input._clear_metadata()

-    return postprocess_add_(input, other, output, alpha)
+    aten.add_.Tensor(
+        input._data,
+        other._data if isinstance(other, EdgeIndex) else other,
+        alpha=alpha,
+    )

+    if isinstance(other, Tensor) and other.numel() <= 1:
+        other = int(other)

-@implements(Tensor.add_)
-def add_(
-    input: EdgeIndex,
-    other: Union[int, Tensor],
-    *,
-    alpha: int = 1,
-) -> Union[EdgeIndex, Tensor]:
+    if isinstance(other, int):
+        size = maybe_add(sparse_size, other, alpha)
+        assert len(size) == 2
+        input._sparse_size = size
+        input._sort_order = sort_order
+        input._is_undirected = is_undirected
+        input._T_perm = T_perm

-    output = Tensor.__torch_function__(  #
-        Tensor.add_, (Tensor, ), (input, other), dict(alpha=alpha))
+    elif isinstance(other, Tensor) and other.size() == (2, 1):
+        size = maybe_add(sparse_size, other.view(-1).tolist(), alpha)
+        assert len(size) == 2
+        input._sparse_size = size
+        input._sort_order = sort_order
+        if torch.equal(other[0], other[1]):
+            input._is_undirected = is_undirected
+            input._T_perm = T_perm

-    return postprocess_add_(input, other, output, alpha)
+    elif isinstance(other, EdgeIndex):
+        size = maybe_add(sparse_size, other._sparse_size, alpha)
+        assert len(size) == 2
+        input._sparse_size = size
+
+    return input


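The in-place variant mutates `input._data` directly, so it stashes the metadata, clears it, runs the raw kernel, and restores only the fields that remain provably valid. A minimal sketch, continuing the example above:

ei.add_(1)                         # in-place shift of both index vectors
assert ei.sparse_size() == (3, 3)  # metadata restored after the kernel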
-def postprocess_sub_(
+@implements(aten.sub.Tensor)
+def _sub(
     input: EdgeIndex,
-    other: Union[int, Tensor],
-    out: Tensor,
+    other: Union[int, Tensor, EdgeIndex],
+    *,
     alpha: int = 1,
 ) -> Union[EdgeIndex, Tensor]:

-    if out.dtype not in SUPPORTED_DTYPES:
+    out = aten.sub.Tensor(
+        input._data,
+        other._data if isinstance(other, EdgeIndex) else other,
+        alpha=alpha,
+    )
+
+    if out.dtype not in INDEX_DTYPES:
         return out
     if out.dim() != 2 or out.size(0) != 2:
         return out

-    output: EdgeIndex = out.as_subclass(EdgeIndex)
+    out = EdgeIndex(out)
+
+    if isinstance(other, Tensor) and other.numel() <= 1:
+        other = int(other)

     if isinstance(other, int):
         size = maybe_sub(input._sparse_size, other, alpha)
         assert len(size) == 2
-        output._sparse_size = size
-        output._sort_order = input._sort_order
-        output._is_undirected = input.is_undirected
-        output._T_perm = input._T_perm
-
-    elif isinstance(other, Tensor) and other.numel() <= 1:
-        size = maybe_sub(input._sparse_size, int(other), alpha)
-        assert len(size) == 2
-        output._sparse_size = size
-        output._sort_order = input._sort_order
-        output._is_undirected = input.is_undirected
-        output._T_perm = input._T_perm
+        out._sparse_size = size
+        out._sort_order = input._sort_order
+        out._is_undirected = input.is_undirected
+        out._T_perm = input._T_perm

     elif isinstance(other, Tensor) and other.size() == (2, 1):
         size = maybe_sub(input._sparse_size, other.view(-1).tolist(), alpha)
         assert len(size) == 2
-        output._sparse_size = size
-        output._sort_order = input._sort_order
-        output._T_perm = input._T_perm
+        out._sparse_size = size
+        out._sort_order = input._sort_order
         if torch.equal(other[0], other[1]):
-            output._is_undirected = input.is_undirected
+            out._is_undirected = input.is_undirected
+            out._T_perm = input._T_perm

-    return output
+    return out


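Subtraction mirrors addition; a typical (hedged) use is relabeling 1-based node IDs to 0-based ones:

import torch
from torch_geometric import EdgeIndex

one_based = EdgeIndex(torch.tensor([[1, 2], [2, 1]]), sparse_size=(3, 3))
zero_based = one_based - 1                 # dispatches to `_sub` above
assert zero_based.sparse_size() == (2, 2)  # sparse size shrinks with it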
-@implements(torch.sub)
-@implements(Tensor.sub)
-def sub(
+@implements(aten.sub_.Tensor)
+def sub_(
     input: EdgeIndex,
-    other: Union[int, Tensor],
+    other: Union[int, Tensor, EdgeIndex],
     *,
     alpha: int = 1,
-    out: Optional[Tensor] = None,
-) -> Union[EdgeIndex, Tensor]:
+) -> EdgeIndex:

-    output = Tensor.__torch_function__(  #
-        torch.sub, (Tensor, ), (input, other), dict(alpha=alpha, out=out))
+    sparse_size = input._sparse_size
+    sort_order = input._sort_order
+    is_undirected = input._is_undirected
+    T_perm = input._T_perm
+    input._clear_metadata()

-    return postprocess_sub_(input, other, output, alpha)
+    aten.sub_.Tensor(
+        input._data,
+        other._data if isinstance(other, EdgeIndex) else other,
+        alpha=alpha,
+    )

+    if isinstance(other, Tensor) and other.numel() <= 1:
+        other = int(other)

-@implements(Tensor.sub_)
-def sub_(
-    input: EdgeIndex,
-    other: Union[int, Tensor],
-    *,
-    alpha: int = 1,
-) -> Union[EdgeIndex, Tensor]:
+    if isinstance(other, int):
+        size = maybe_sub(sparse_size, other, alpha)
+        assert len(size) == 2
+        input._sparse_size = size
+        input._sort_order = sort_order
+        input._is_undirected = is_undirected
+        input._T_perm = T_perm

-    output = Tensor.__torch_function__(  #
-        Tensor.sub_, (Tensor, ), (input, other), dict(alpha=alpha))
+    elif isinstance(other, Tensor) and other.size() == (2, 1):
+        size = maybe_sub(sparse_size, other.view(-1).tolist(), alpha)
+        assert len(size) == 2
+        input._sparse_size = size
+        input._sort_order = sort_order
+        if torch.equal(other[0], other[1]):
+            input._is_undirected = is_undirected
+            input._T_perm = T_perm

-    return postprocess_sub_(input, other, output, alpha)
+    return input


 # Sparse-Dense Matrix Multiplication ##########################################
@@ -1620,13 +1780,13 @@ def _torch_sparse_spmm(
     if not transpose:
         assert input.is_sorted_by_row
         (rowptr, col), _ = input.get_csr()
-        row = input[0]
+        row = input._data[0]
         if other.requires_grad and reduce in ['sum', 'mean']:
             (colptr, _), perm = input.get_csc()
     else:
         assert input.is_sorted_by_col
         (rowptr, col), _ = input.get_csc()
-        row = input[1]
+        row = input._data[1]
         if other.requires_grad and reduce in ['sum', 'mean']:
             (colptr, _), perm = input.get_csr()

@@ -1699,7 +1859,7 @@ class _TorchSPMM(torch.autograd.Function):
             adj = input.to_sparse_csr(value)
         else:
             (colptr, row), perm = input.get_csc()
-            if value is not None:
+            if value is not None and perm is not None:
                 value = value[perm]
             else:
                 value = input._get_value()
@@ -1715,7 +1875,7 @@ class _TorchSPMM(torch.autograd.Function):
             adj = input.to_sparse_csc(value).t()
         else:
             (rowptr, col), perm = input.get_csr()
-            if value is not None:
+            if value is not None and perm is not None:
                 value = value[perm]
             else:
                 value = input._get_value()
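The added `perm is not None` guards matter because `get_csr()`/`get_csc()` return `perm=None` when the data is already stored in the requested order, and indexing a tensor with `None` would unsqueeze a dimension instead of permuting. A sketch of the contract these guards assume, reusing the illustrative `edge_index` from earlier:

import torch

(rowptr, col), perm = edge_index.get_csr()
values = torch.rand(edge_index.size(1))  # illustrative edge values
if perm is not None:
    values = values[perm]  # reorder values into CSR order
# else: values already align with the stored (row-sorted) layout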
@@ -1746,14 +1906,16 @@ def _scatter_spmm(
     from torch_geometric.utils import scatter

     if not transpose:
-        other_j = other[input[1]]
-        index = input[0]
+        other_j = other[input._data[1]]
+        index = input._data[0]
+        dim_size = input.get_sparse_size(0)
     else:
-        other_j = other[input[0]]
-        index = input[1]
+        other_j = other[input._data[0]]
+        index = input._data[1]
+        dim_size = input.get_sparse_size(1)

     other_j = other_j * value.view(-1, 1) if value is not None else other_j
-    return scatter(other_j, index, 0, dim_size=other.size(0), reduce=reduce)
+    return scatter(other_j, index, 0, dim_size=dim_size, reduce=reduce)


 def _spmm(
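The gather-scatter fallback is plain PyTorch; the following self-contained sketch reproduces the same pattern with illustrative names (not the library's internals):

import torch

edge = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])   # (row, col) of A
x = torch.randn(3, 4)                                # dense operand
msg = x[edge[1]]                                     # gather neighbor rows
out = torch.zeros(3, 4).index_add_(0, edge[0], msg)  # scatter-sum per row
# out[i] == sum over A's row i of x[col], matching
# scatter(other_j, index, 0, dim_size=..., reduce='sum'), with `dim_size`
# now taken from the sparse size instead of `other.size(0)`.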
@@ -1775,7 +1937,7 @@ def _spmm(
     if transpose and not input.is_sorted_by_col:
         cls_name = input.__class__.__name__
         raise ValueError(f"'matmul(..., transpose=True)' requires "
-                         f"'{cls_name}' to be sorted by colums")
+                         f"'{cls_name}' to be sorted by columns")

     if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()
             and other.is_cuda):  # pragma: no cover
@@ -1786,17 +1948,17 @@ def _spmm(
             return _torch_sparse_spmm(input, other, value, reduce, transpose)
         return _scatter_spmm(input, other, value, reduce, transpose)

-    if reduce == 'sum' or reduce == 'add':
-        return _TorchSPMM.apply(input, other, value, 'sum', transpose)
+    if torch_geometric.typing.WITH_PT20:
+        if reduce == 'sum' or reduce == 'add':
+            return _TorchSPMM.apply(input, other, value, 'sum', transpose)

-    if reduce == 'mean':
-        out = _TorchSPMM.apply(input, other, value, 'sum', transpose)
-        count = input.get_indptr().diff()
-        return out / count.clamp_(min=1).to(out.dtype).view(-1, 1)
+        if reduce == 'mean':
+            out = _TorchSPMM.apply(input, other, value, 'sum', transpose)
+            count = input.get_indptr().diff()
+            return out / count.clamp_(min=1).to(out.dtype).view(-1, 1)

-    if (torch_geometric.typing.WITH_PT20 and not other.is_cuda
-            and not other.requires_grad):
-        return _TorchSPMM.apply(input, other, value, reduce, transpose)
+        if not other.is_cuda and not other.requires_grad:
+            return _TorchSPMM.apply(input, other, value, reduce, transpose)

     if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling():
         return _torch_sparse_spmm(input, other, value, reduce, transpose)
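In the `'mean'` branch, the `'sum'` kernel is reused and the result is divided by the per-row entry count recovered from the CSR pointer. A minimal sketch of that arithmetic:

import torch

rowptr = torch.tensor([0, 1, 3, 4])  # CSR pointer of a 3-row matrix
sum_out = torch.randn(3, 4)          # output of the 'sum' SpMM
count = rowptr.diff().clamp(min=1)   # entries per row; avoids 0-division
mean_out = sum_out / count.to(sum_out.dtype).view(-1, 1)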
@@ -1858,7 +2020,7 @@ def matmul(
     else:
         raise NotImplementedError

-    edge_index = edge_index.as_subclass(EdgeIndex)
+    edge_index = EdgeIndex(edge_index)
     edge_index._sort_order = SortOrder.ROW
     edge_index._sparse_size = (out.size(0), out.size(1))
     edge_index._indptr = rowptr
@@ -1866,20 +2028,36 @@ def matmul(
     return edge_index, out.values()


-@implements(torch.mm)
-@implements(torch.matmul)
-@implements(Tensor.matmul)
-def _matmul1(
+@implements(aten.mm.default)
+def _mm(
     input: EdgeIndex,
     other: Union[Tensor, EdgeIndex],
 ) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:
     return matmul(input, other)


-@implements(torch.sparse.mm)
-def _matmul2(
+@implements(aten._sparse_addmm.default)
+def _addmm(
+    input: Tensor,
     mat1: EdgeIndex,
-    mat2: Union[Tensor, EdgeIndex],
-    reduce: ReduceType = 'sum',
-) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:
-    return matmul(mat1, mat2, reduce=reduce)
+    mat2: Tensor,
+    beta: float = 1.0,
+    alpha: float = 1.0,
+) -> Tensor:
+    assert input.abs().sum() == 0.0
+    out = matmul(mat1, mat2)
+    assert isinstance(out, Tensor)
+    return alpha * out if alpha != 1.0 else out
+
+
+if hasattr(aten, '_sparse_mm_reduce_impl'):
+
+    @implements(aten._sparse_mm_reduce_impl.default)
+    def _mm_reduce(
+        mat1: EdgeIndex,
+        mat2: Tensor,
+        reduce: ReduceType = 'sum',
+    ) -> Tuple[Tensor, Tensor]:
+        out = matmul(mat1, mat2, reduce=reduce)
+        assert isinstance(out, Tensor)
+        return out, out  # We return a dummy tensor for `argout` for now.
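Taken together, these overrides let a row-sorted `EdgeIndex` act as the sparse operand of a plain `@` product; a hedged end-to-end sketch (same assumed constructor as the earlier examples):

import torch
from torch_geometric import EdgeIndex

ei = EdgeIndex(
    torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    sparse_size=(3, 3),
    sort_order='row',
)
x = torch.randn(3, 8)
out = ei @ x  # routes through the `aten.mm.default` override above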