pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl

Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +13 -7
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +317 -65
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +3 -5
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +329 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +56 -22
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,826 @@
+import functools
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+
+import torch
+import torch.utils._pytree as pytree
+from torch import Tensor
+
+from torch_geometric.typing import INDEX_DTYPES
+
+aten = torch.ops.aten
+
+HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
+
+
+def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
+    index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
+    return index.repeat_interleave(ptr.diff(), output_size=output_size)
+
+
+def index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor:
+    if size is None:
+        size = int(index.max()) + 1 if index.numel() > 0 else 0
+
+    return torch._convert_indices_from_coo_to_csr(
+        index, size, out_int32=index.dtype != torch.int64)
+
+
+class CatMetadata(NamedTuple):
+    nnz: List[int]
+    dim_size: List[Optional[int]]
+    is_sorted: List[bool]
+
+
+def implements(torch_function: Callable) -> Callable:
+    r"""Registers a :pytorch:`PyTorch` function override."""
+    @functools.wraps(torch_function)
+    def decorator(my_function: Callable) -> Callable:
+        HANDLED_FUNCTIONS[torch_function] = my_function
+        return my_function
+
+    return decorator
+
+
+def assert_valid_dtype(tensor: Tensor) -> None:
+    if tensor.dtype not in INDEX_DTYPES:
+        raise ValueError(f"'Index' holds an unsupported data type "
+                         f"(got '{tensor.dtype}', but expected one of "
+                         f"{INDEX_DTYPES})")
+
+
+def assert_one_dimensional(tensor: Tensor) -> None:
+    if tensor.dim() != 1:
+        raise ValueError(f"'Index' needs to be one-dimensional "
+                         f"(got {tensor.dim()} dimensions)")
+
+
+def assert_contiguous(tensor: Tensor) -> None:
+    if not tensor.is_contiguous():
+        raise ValueError("'Index' needs to be contiguous. Please call "
+                         "`index.contiguous()` before proceeding.")
+
+
+def assert_sorted(func: Callable) -> Callable:
+    @functools.wraps(func)
+    def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:
+        if not self.is_sorted:
+            cls_name = self.__class__.__name__
+            raise ValueError(
+                f"Cannot call '{func.__name__}' since '{cls_name}' is not "
+                f"sorted. Please call `{cls_name}.sort()` first.")
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class Index(Tensor):
+    r"""A one-dimensional :obj:`index` tensor with additional (meta)data
+    attached.
+
+    :class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
+    indices of shape :obj:`[num_indices]`.
+
+    While :class:`Index` sub-classes a general :pytorch:`null`
+    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:
+
+    * :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*,
+      the size of a dimension that can be indexed via :obj:`index`.
+      By default, it is inferred as :obj:`dim_size=index.max() + 1`.
+    * :obj:`is_sorted`: Whether indices are sorted in ascending order.
+
+    Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
+    conversion in case its representation is sorted.
+    Caches are filled based on demand (*e.g.*, when calling
+    :meth:`Index.get_indptr`), or when explicitly requested via
+    :meth:`Index.fill_cache_`, and are maintained and adjusted over its
+    lifespan.
+
+    This representation ensures optimal computation in GNN message passing
+    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
+    workflows.
+
+    .. code-block:: python
+
+        from torch_geometric import Index
+
+        index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
+        >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
+        assert index.dim_size == 3
+        assert index.is_sorted
+
+        # Flipping order:
+        index.flip(0)
+        >>> Index([2, 1, 1, 0], dim_size=3)
+        assert not index.is_sorted
+
+        # Filtering:
+        mask = torch.tensor([True, True, True, False])
+        index[mask]
+        >>> Index([0, 1, 1], dim_size=3, is_sorted=True)
+        assert index.is_sorted
+    """
+    # See "https://pytorch.org/docs/stable/notes/extending.html"
+    # for a basic tutorial on how to subclass `torch.Tensor`.
+
+    # The underlying tensor representation:
+    _data: Tensor
+
+    # The size of the underlying sparse vector, e.g., `_data.max() + 1`:
+    _dim_size: Optional[int] = None
+
+    # Whether the `index` representation is sorted:
+    _is_sorted: bool = False
+
+    # A cache for its compressed representation:
+    _indptr: Optional[Tensor] = None
+
+    # Whenever we perform a concatenation of indices, we cache the original
+    # metadata to be able to reconstruct individual indices:
+    _cat_metadata: Optional[CatMetadata] = None
+
+    @staticmethod
+    def __new__(
+        cls: Type,
+        data: Any,
+        *args: Any,
+        dim_size: Optional[int] = None,
+        is_sorted: bool = False,
+        **kwargs: Any,
+    ) -> 'Index':
+        if not isinstance(data, Tensor):
+            data = torch.tensor(data, *args, **kwargs)
+        elif len(args) > 0:
+            raise TypeError(
+                f"new() received an invalid combination of arguments - got "
+                f"(Tensor, {', '.join(str(type(arg)) for arg in args)})")
+        elif len(kwargs) > 0:
+            raise TypeError(f"new() received invalid keyword arguments - got "
+                            f"{set(kwargs.keys())})")
+
+        assert isinstance(data, Tensor)
+
+        indptr: Optional[Tensor] = None
+
+        if isinstance(data, cls):  # If passed `Index`, inherit metadata:
+            indptr = data._indptr
+            dim_size = dim_size or data.dim_size
+            is_sorted = is_sorted or data.is_sorted
+
+        assert_valid_dtype(data)
+        assert_one_dimensional(data)
+        assert_contiguous(data)
+
+        out = Tensor._make_wrapper_subclass(  # type: ignore
+            cls,
+            size=data.size(),
+            strides=data.stride(),
+            dtype=data.dtype,
+            device=data.device,
+            layout=data.layout,
+            requires_grad=False,
+        )
+        assert isinstance(out, Index)
+
+        # Attach metadata:
+        out._data = data
+        out._dim_size = dim_size
+        out._is_sorted = is_sorted
+        out._indptr = indptr
+
+        if isinstance(data, cls):
+            out._data = data._data
+
+            # Reset metadata if cache is invalidated:
+            if dim_size is not None and dim_size != data.dim_size:
+                out._indptr = None
+
+        return out
+
+    # Validation ##############################################################
+
+    def validate(self) -> 'Index':
+        r"""Validates the :class:`Index` representation.
+
+        In particular, it ensures that
+
+        * it only holds valid indices.
+        * the sort order is correctly set.
+        """
+        assert_valid_dtype(self._data)
+        assert_one_dimensional(self._data)
+        assert_contiguous(self._data)
+
+        if self.numel() > 0 and self._data.min() < 0:
+            raise ValueError(f"'{self.__class__.__name__}' contains negative "
+                             f"indices (got {int(self.min())})")
+
+        if (self.numel() > 0 and self.dim_size is not None
+                and self._data.max() >= self.dim_size):
+            raise ValueError(f"'{self.__class__.__name__}' contains larger "
+                             f"indices than its registered size "
+                             f"(got {int(self._data.max())}, but expected "
+                             f"values smaller than {self.dim_size})")
+
+        if self.is_sorted and (self._data.diff() < 0).any():
+            raise ValueError(f"'{self.__class__.__name__}' is not sorted")
+
+        return self
+
+    # Properties ##############################################################
+
+    @property
+    def dim_size(self) -> Optional[int]:
+        r"""The size of the underlying sparse vector."""
+        return self._dim_size
+
+    @property
+    def is_sorted(self) -> bool:
+        r"""Returns whether indices are sorted in ascending order."""
+        return self._is_sorted
+
+    @property
+    def dtype(self) -> torch.dtype:  # type: ignore
+        # TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
+        return self._data.dtype
+
+    # Cache Interface #########################################################
+
+    def get_dim_size(self) -> int:
+        r"""The size of the underlying sparse vector.
+        Automatically computed and cached when not explicitly set.
+        """
+        if self._dim_size is None:
+            dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
+            self._dim_size = dim_size
+
+        assert isinstance(self._dim_size, int)
+        return self._dim_size
+
+    def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
+        r"""Assigns or re-assigns the size of the underlying sparse vector."""
+        if self.is_sorted and self._indptr is not None:
+            if dim_size is None:
+                self._indptr = None
+
+            elif self._indptr.numel() - 1 >= dim_size:
+                self._indptr = self._indptr[:dim_size + 1]
+
+            else:
+                fill_value = self._indptr.new_full(
+                    (dim_size - self._indptr.numel() + 1, ),
+                    fill_value=self._indptr[-1],  # type: ignore
+                )
+                self._indptr = torch.cat([self._indptr, fill_value], dim=0)
+
+        self._dim_size = dim_size
+
+        return self
+
+    @assert_sorted
+    def get_indptr(self) -> Tensor:
+        r"""Returns the compressed index representation in case :class:`Index`
+        is sorted.
+        """
+        if self._indptr is None:
+            self._indptr = index2ptr(self._data, self.get_dim_size())
+
+        assert isinstance(self._indptr, Tensor)
+        return self._indptr
+
+    def fill_cache_(self) -> 'Index':
+        r"""Fills the cache with (meta)data information."""
+        self.get_dim_size()
+
+        if self.is_sorted:
+            self.get_indptr()
+
+        return self
+
+    # Methods #################################################################
+
+    def share_memory_(self) -> 'Index':
+        """"""  # noqa: D419
+        self._data.share_memory_()
+        if self._indptr is not None:
+            self._indptr.share_memory_()
+        return self
+
+    def is_shared(self) -> bool:
+        """"""  # noqa: D419
+        return self._data.is_shared()
+
+    def as_tensor(self) -> Tensor:
+        r"""Zero-copies the :class:`Index` representation back to a
+        :class:`torch.Tensor` representation.
+        """
+        return self._data
+
+    # PyTorch/Python builtins #################################################
+
+    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
+        attrs = ['_data']
+        if self._indptr is not None:
+            attrs.append('_indptr')
+
+        ctx = (
+            self._dim_size,
+            self._is_sorted,
+            self._cat_metadata,
+        )
+
+        return attrs, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(
+        inner_tensors: Dict[str, Any],
+        ctx: Tuple[Any, ...],
+        outer_size: Tuple[int, ...],
+        outer_stride: Tuple[int, ...],
+    ) -> 'Index':
+        index = Index(
+            inner_tensors['_data'],
+            dim_size=ctx[0],
+            is_sorted=ctx[1],
+        )
+
+        index._indptr = inner_tensors.get('_indptr', None)
+        index._cat_metadata = ctx[2]
+
+        return index
+
+    # Prevent auto-wrapping outputs back into the proper subclass type:
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    @classmethod
+    def __torch_dispatch__(
+        cls: Type,
+        func: Callable[..., Any],
+        types: Iterable[Type[Any]],
+        args: Iterable[Tuple[Any, ...]] = (),
+        kwargs: Optional[Dict[Any, Any]] = None,
+    ) -> Any:
+        # `Index` should be treated as a regular PyTorch tensor for all
+        # standard PyTorch functionalities. However,
+        # * some of its metadata can be transferred to new functions, e.g.,
+        #   `torch.narrow()` can inherit the `is_sorted` property.
+        # * not all operations lead to valid `Index` tensors again, e.g.,
+        #   `torch.sum()` does not yield an `Index` as its output, or
+        #   `torch.stack()` violates the [*] shape assumption.
+
+        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that
+        # implement specific functions for valid `Index` routines.
+        if func in HANDLED_FUNCTIONS:
+            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))
+
+        # For all other PyTorch functions, we treat them as vanilla tensors.
+        args = pytree.tree_map_only(Index, lambda x: x._data, args)
+        if kwargs is not None:
+            kwargs = pytree.tree_map_only(Index, lambda x: x._data, kwargs)
+        return func(*args, **(kwargs or {}))
+
+    def __repr__(self) -> str:  # type: ignore
+        prefix = f'{self.__class__.__name__}('
+        indent = len(prefix)
+        tensor_str = torch._tensor_str._tensor_str(self._data, indent)
+
+        suffixes = []
+        if self.dim_size is not None:
+            suffixes.append(f'dim_size={self.dim_size}')
+        if (self.device.type != torch._C._get_default_device()
+                or (self.device.type == 'cuda'
+                    and torch.cuda.current_device() != self.device.index)
+                or (self.device.type == 'mps')):
+            suffixes.append(f"device='{self.device}'")
+        if self.dtype != torch.int64:
+            suffixes.append(f'dtype={self.dtype}')
+        if self.is_sorted:
+            suffixes.append('is_sorted=True')
+
+        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
+                                               indent, force_newline=False)
+
+    # Helpers #################################################################
+
+    def _shallow_copy(self) -> 'Index':
+        out = Index(self._data)
+        out._dim_size = self._dim_size
+        out._is_sorted = self._is_sorted
+        out._indptr = self._indptr
+        out._cat_metadata = self._cat_metadata
+        return out
+
+    def _clear_metadata(self) -> 'Index':
+        self._dim_size = None
+        self._is_sorted = False
+        self._indptr = None
+        self._cat_metadata = None
+        return self
+
+
+def apply_(
+    tensor: Index,
+    fn: Callable,
+    *args: Any,
+    **kwargs: Any,
+) -> Union[Index, Tensor]:
+
+    data = fn(tensor._data, *args, **kwargs)
+
+    if data.dtype not in INDEX_DTYPES:
+        return data
+
+    if tensor._data.data_ptr() != data.data_ptr():
+        out = Index(data)
+    else:  # In-place:
+        tensor._data = data
+        out = tensor
+
+    # Copy metadata:
+    out._dim_size = tensor._dim_size
+    out._is_sorted = tensor._is_sorted
+    out._cat_metadata = tensor._cat_metadata
+
+    # Convert cache:
+    if tensor._indptr is not None:
+        out._indptr = fn(tensor._indptr, *args, **kwargs)
+
+    return out
+
+
+@implements(aten.clone.default)
+def _clone(
+    tensor: Index,
+    *,
+    memory_format: torch.memory_format = torch.preserve_format,
+) -> Index:
+    out = apply_(tensor, aten.clone.default, memory_format=memory_format)
+    assert isinstance(out, Index)
+    return out
+
+
+@implements(aten._to_copy.default)
+def _to_copy(
+    tensor: Index,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+    non_blocking: bool = False,
+    memory_format: Optional[torch.memory_format] = None,
+) -> Union[Index, Tensor]:
+    return apply_(
+        tensor,
+        aten._to_copy.default,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        pin_memory=pin_memory,
+        non_blocking=non_blocking,
+        memory_format=memory_format,
+    )
+
+
+@implements(aten.alias.default)
+def _alias(tensor: Index) -> Index:
+    return tensor._shallow_copy()
+
+
+@implements(aten._pin_memory.default)
+def _pin_memory(tensor: Index) -> Index:
+    out = apply_(tensor, aten._pin_memory.default)
+    assert isinstance(out, Index)
+    return out
+
+
+@implements(aten.sort.default)
+def _sort(
+    tensor: Index,
+    dim: int = -1,
+    descending: bool = False,
+) -> Tuple[Index, Tensor]:
+
+    if tensor.is_sorted and not descending:
+        return tensor, torch.arange(tensor._data.numel(),
+                                    device=tensor._data.device)
+
+    data, perm = aten.sort.default(tensor._data, dim, descending)
+
+    out = Index(data)
+    out._dim_size = tensor._dim_size
+
+    if not descending:
+        out._is_sorted = True
+
+    return out, perm
+
+
+@implements(aten.sort.stable)
+def _sort_stable(
+    tensor: Index,
+    *,
+    stable: bool = False,
+    dim: int = -1,
+    descending: bool = False,
+) -> Tuple[Index, Tensor]:
+
+    if tensor.is_sorted and not descending:
+        return tensor, torch.arange(tensor._data.numel(),
+                                    device=tensor._data.device)
+
+    data, perm = aten.sort.stable(tensor._data, stable=stable, dim=dim,
+                                  descending=descending)
+
+    out = Index(data)
+    out._dim_size = tensor._dim_size
+
+    if not descending:
+        out._is_sorted = True
+
+    return out, perm
+
+
+@implements(aten.cat.default)
+def _cat(
+    tensors: List[Union[Index, Tensor]],
+    dim: int = 0,
+) -> Union[Index, Tensor]:
+
+    data_list = pytree.tree_map_only(Index, lambda x: x._data, tensors)
+    data = aten.cat.default(data_list, dim=dim)
+
+    if any([not isinstance(tensor, Index) for tensor in tensors]):
+        return data
+
+    out = Index(data)
+
+    nnz_list = [t.numel() for t in tensors]
+    dim_size_list = [t.dim_size for t in tensors]  # type: ignore
+    is_sorted_list = [t.is_sorted for t in tensors]  # type: ignore
+
+    # Post-process `dim_size`:
+    total_dim_size: Optional[int] = 0
+    for dim_size in dim_size_list:
+        if dim_size is None:
+            total_dim_size = None
+            break
+        assert isinstance(total_dim_size, int)
+        total_dim_size = max(dim_size, total_dim_size)
+
+    out._dim_size = total_dim_size
+
+    out._cat_metadata = CatMetadata(
+        nnz=nnz_list,
+        dim_size=dim_size_list,
+        is_sorted=is_sorted_list,
+    )
+
+    return out
+
+
+@implements(aten.flip.default)
+def _flip(
+    input: Index,
+    dims: Union[List[int], Tuple[int, ...]],
+) -> Index:
+
+    data = aten.flip.default(input._data, dims)
+
+    out = Index(data)
+    out._dim_size = input.dim_size
+
+    return out
+
+
+@implements(aten.index_select.default)
+def _index_select(
+    input: Union[Index, Tensor],
+    dim: int,
+    index: Union[Index, Tensor],
+) -> Union[Index, Tensor]:
+
+    out = aten.index_select.default(
+        input._data if isinstance(input, Index) else input,
+        dim,
+        index._data if isinstance(index, Index) else index,
+    )
+
+    if isinstance(input, Index):
+        out = Index(out)
+        out._dim_size = input.dim_size
+
+    return out
+
+
+@implements(aten.slice.Tensor)
+def _slice(
+    input: Index,
+    dim: int,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    step: int = 1,
+) -> Index:
+
+    if ((start is None or start <= 0)
+            and (end is None or end > input.size(dim)) and step == 1):
+        return input._shallow_copy()  # No-op.
+
+    data = aten.slice.Tensor(input._data, dim, start, end, step)
+
+    if step != 1:
+        data = data.contiguous()
+
+    out = Index(data)
+    out._dim_size = input.dim_size
+    # NOTE We could potentially maintain the `indptr` attribute here,
+    # but it is not really clear if this is worth it. The most important
+    # information `is_sorted` needs to be maintained though:
+    if step >= 0:
+        out._is_sorted = input.is_sorted
+
+    return out
+
+
+@implements(aten.index.Tensor)
+def _index(
+    input: Union[Index, Tensor],
+    indices: List[Optional[Union[Tensor, Index]]],
+) -> Union[Index, Tensor]:
+
+    if not isinstance(input, Index):
+        indices = pytree.tree_map_only(Index, lambda x: x._data, indices)
+        return aten.index.Tensor(input, indices)
+
+    data = aten.index.Tensor(input._data, indices)
+
+    if data.dim() != 1:
+        return data
+
+    assert len(indices) == 1
+    index = indices[0]
+    assert index is not None
+
+    out = Index(data)
+
+    if index.dtype in (torch.bool, torch.uint8):  # 1. `index[mask]`.
+        out._dim_size = input.dim_size
+        out._is_sorted = input.is_sorted
+
+    else:  # 2. `index[index]`.
+        out._dim_size = input.dim_size
+
+    return out
+
+
+@implements(aten.add.Tensor)
+def _add(
+    input: Union[int, Tensor, Index],
+    other: Union[int, Tensor, Index],
+    *,
+    alpha: int = 1,
+) -> Union[Index, Tensor]:
+
+    data = aten.add.Tensor(
+        input._data if isinstance(input, Index) else input,
+        other._data if isinstance(other, Index) else other,
+        alpha=alpha,
+    )
+
+    if data.dtype not in INDEX_DTYPES:
+        return data
+    if data.dim() != 1:
+        return data
+
+    out = Index(data)
+
+    if isinstance(input, Tensor) and input.numel() <= 1:
+        input = int(input)
+
+    if isinstance(other, Tensor) and other.numel() <= 1:
+        other = int(other)
+
+    if isinstance(other, int):
+        assert isinstance(input, Index)
+        if input.dim_size is not None:
+            out._dim_size = input.dim_size + alpha * other
+        out._is_sorted = input.is_sorted
+
+    elif isinstance(input, int):
+        assert isinstance(other, Index)
+        if other.dim_size is not None:
+            out._dim_size = input + alpha * other.dim_size
+        out._is_sorted = other.is_sorted
+
+    elif isinstance(input, Index) and isinstance(other, Index):
+        if input.dim_size is not None and other.dim_size is not None:
+            out._dim_size = input.dim_size + alpha * other.dim_size
+
+    return out
+
+
+@implements(aten.add_.Tensor)
+def add_(
+    input: Index,
+    other: Union[int, Tensor, Index],
+    *,
+    alpha: int = 1,
+) -> Index:
+
+    dim_size = input.dim_size
+    is_sorted = input.is_sorted
+    input._clear_metadata()
+
+    aten.add_.Tensor(
+        input._data,
+        other._data if isinstance(other, Index) else other,
+        alpha=alpha,
+    )
+
+    if isinstance(other, Tensor) and other.numel() <= 1:
+        other = int(other)
+
+    if isinstance(other, int):
+        if dim_size is not None:
+            input._dim_size = dim_size + alpha * other
+        input._is_sorted = is_sorted
+
+    elif isinstance(other, Index):
+        if dim_size is not None and other.dim_size is not None:
+            input._dim_size = dim_size + alpha * other.dim_size
+
+    return input
+
+
+@implements(aten.sub.Tensor)
+def _sub(
+    input: Union[int, Tensor, Index],
+    other: Union[int, Tensor, Index],
+    *,
+    alpha: int = 1,
+) -> Union[Index, Tensor]:
+
+    data = aten.sub.Tensor(
+        input._data if isinstance(input, Index) else input,
+        other._data if isinstance(other, Index) else other,
+        alpha=alpha,
+    )
+
+    if data.dtype not in INDEX_DTYPES:
+        return data
+    if data.dim() != 1:
+        return data
+
+    out = Index(data)
+
+    if not isinstance(input, Index):
+        return out
+
+    if isinstance(other, Tensor) and other.numel() <= 1:
+        other = int(other)
+
+    if isinstance(other, int):
+        if input.dim_size is not None:
+            out._dim_size = input.dim_size - alpha * other
+        out._is_sorted = input.is_sorted
+
+    return out
+
+
+@implements(aten.sub_.Tensor)
+def sub_(
+    input: Index,
+    other: Union[int, Tensor, Index],
+    *,
+    alpha: int = 1,
+) -> Index:
+
+    dim_size = input.dim_size
+    is_sorted = input.is_sorted
+    input._clear_metadata()
+
+    aten.sub_.Tensor(
+        input._data,
+        other._data if isinstance(other, Index) else other,
+        alpha=alpha,
+    )
+
+    if isinstance(other, Tensor) and other.numel() <= 1:
+        other = int(other)
+
+    if isinstance(other, int):
+        if dim_size is not None:
+            input._dim_size = dim_size - alpha * other
+        input._is_sorted = is_sorted
+
+    return input
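
The centerpiece of this diff is the new `torch_geometric/index.py` module shown above. As a quick orientation, here is a minimal usage sketch derived from the `Index` docstring and the `aten` overrides it registers; it is illustrative only (not part of the diff), and exact behavior may vary across nightly builds:

.. code-block:: python

    import torch
    from torch_geometric import Index

    # A sorted COO-style index over a sparse dimension of size 3:
    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
    assert index.dim_size == 3 and index.is_sorted

    # Sorted indices admit a cached CSR conversion (see `index2ptr` above):
    assert index.get_indptr().tolist() == [0, 1, 3, 4]

    # Tensor ops route through `__torch_dispatch__`; metadata is carried
    # along wherever it stays valid, e.g., boolean masking preserves both
    # the sort order and the sparse dimension size:
    mask = torch.tensor([True, True, True, False])
    sub = index[mask]
    assert sub.is_sorted and sub.dim_size == 3

    # Ops that break the one-dimensional index contract (e.g., `torch.sum`)
    # simply fall back to returning plain tensors.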