pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +8 -3
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +53 -20
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/index.py
ADDED
@@ -0,0 +1,826 @@
|
|
1
|
+
import functools
|
2
|
+
from typing import (
|
3
|
+
Any,
|
4
|
+
Callable,
|
5
|
+
Dict,
|
6
|
+
Iterable,
|
7
|
+
List,
|
8
|
+
NamedTuple,
|
9
|
+
Optional,
|
10
|
+
Tuple,
|
11
|
+
Type,
|
12
|
+
Union,
|
13
|
+
)
|
14
|
+
|
15
|
+
import torch
|
16
|
+
import torch.utils._pytree as pytree
|
17
|
+
from torch import Tensor
|
18
|
+
|
19
|
+
from torch_geometric.typing import INDEX_DTYPES
|
20
|
+
|
21
|
+
# Short alias for the ATen operator namespace (e.g., `aten.clone.default`):
aten = torch.ops.aten

# Registry mapping ATen operator overloads to their `Index`-aware
# implementations. Entries are added by the `@implements(...)` decorator and
# looked up in `Index.__torch_dispatch__`:
HANDLED_FUNCTIONS: Dict[Callable, Callable] = dict()
|
24
|
+
|
25
|
+
|
26
|
+
def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
    r"""Expands a compressed pointer vector into an uncompressed index vector.

    Row :obj:`i` is repeated :obj:`ptr[i + 1] - ptr[i]` times, so for sorted
    indices this reverses :meth:`index2ptr`.

    Args:
        ptr (torch.Tensor): The pointer vector of length :obj:`num_rows + 1`.
        output_size (int, optional): The total number of output entries, if
            known in advance. Forwarded to :meth:`repeat_interleave`.
    """
    num_rows = ptr.numel() - 1
    rows = torch.arange(num_rows, dtype=ptr.dtype, device=ptr.device)
    counts = ptr.diff()
    return rows.repeat_interleave(counts, output_size=output_size)
|
29
|
+
|
30
|
+
|
31
|
+
def index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor:
    r"""Compresses a sorted index vector into its pointer (CSR) form.

    Args:
        index (torch.Tensor): The index vector, expected to be sorted in
            ascending order.
        size (int, optional): The number of rows. If :obj:`None`, it is
            inferred as :obj:`index.max() + 1` (or :obj:`0` for empty input).

    Returns:
        A pointer vector of length :obj:`size + 1`. Its dtype is
        :obj:`torch.int64` for :obj:`torch.int64` input and
        :obj:`torch.int32` otherwise.
    """
    if size is None:
        size = 0 if index.numel() == 0 else int(index.max()) + 1

    use_int32 = index.dtype != torch.int64
    return torch._convert_indices_from_coo_to_csr(index, size,
                                                  out_int32=use_int32)
|
37
|
+
|
38
|
+
|
39
|
+
# Book-keeping recorded when `Index` tensors are concatenated: one list entry
# per concatenated input, so that the individual inputs can be reconstructed.
CatMetadata = NamedTuple('CatMetadata', [
    # Number of indices contributed by each input:
    ('nnz', List[int]),
    # The `dim_size` of each input (`None` if it was not set):
    ('dim_size', List[Optional[int]]),
    # Whether each input was flagged as sorted:
    ('is_sorted', List[bool]),
])
|
43
|
+
|
44
|
+
|
45
|
+
def implements(torch_function: Callable) -> Callable:
    r"""Registers a :pytorch:`PyTorch` function override.

    Returns a decorator that stores the decorated function in
    :obj:`HANDLED_FUNCTIONS` under the key :obj:`torch_function` and then
    returns the decorated function unchanged.
    """
    def register(impl: Callable) -> Callable:
        HANDLED_FUNCTIONS[torch_function] = impl
        return impl

    # The returned decorator carries `torch_function`'s metadata
    # (`__name__`, `__doc__`, `__wrapped__`, ...):
    return functools.wraps(torch_function)(register)
|
53
|
+
|
54
|
+
|
55
|
+
def assert_valid_dtype(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` unless the data type of :obj:`tensor`
    is one of :obj:`INDEX_DTYPES`.
    """
    if tensor.dtype in INDEX_DTYPES:
        return
    raise ValueError(f"'Index' holds an unsupported data type "
                     f"(got '{tensor.dtype}', but expected one of "
                     f"{INDEX_DTYPES})")
|
60
|
+
|
61
|
+
|
62
|
+
def assert_one_dimensional(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` if :obj:`tensor` is not 1-dimensional."""
    num_dims = tensor.dim()
    if num_dims == 1:
        return
    raise ValueError(f"'Index' needs to be one-dimensional "
                     f"(got {num_dims} dimensions)")
|
66
|
+
|
67
|
+
|
68
|
+
def assert_contiguous(tensor: Tensor) -> None:
    r"""Raises a :class:`ValueError` if :obj:`tensor` is not contiguous in
    memory.
    """
    if tensor.is_contiguous():
        return
    raise ValueError("'Index' needs to be contiguous. Please call "
                     "`index.contiguous()` before proceeding.")
|
72
|
+
|
73
|
+
|
74
|
+
def assert_sorted(func: Callable) -> Callable:
    r"""Decorates a method so that it raises a :class:`ValueError` when
    called on an instance whose :obj:`is_sorted` flag is not set.
    """
    @functools.wraps(func)
    def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:
        if self.is_sorted:
            return func(self, *args, **kwargs)

        cls_name = self.__class__.__name__
        raise ValueError(
            f"Cannot call '{func.__name__}' since '{cls_name}' is not "
            f"sorted. Please call `{cls_name}.sort()` first.")

    return wrapper
|
85
|
+
|
86
|
+
|
87
|
+
class Index(Tensor):
    r"""A one-dimensional :obj:`index` tensor with additional (meta)data
    attached.

    :class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
    indices of shape :obj:`[num_indices]`.

    While :class:`Index` sub-classes a general :pytorch:`null`
    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:

    * :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*,
      the size of a dimension that can be indexed via :obj:`index`.
      If not given, it is computed lazily as :obj:`index.max() + 1` when
      requested via :meth:`Index.get_dim_size`.
    * :obj:`is_sorted`: Whether indices are sorted in ascending order.

    Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
    conversion in case its representation is sorted.
    Caches are filled based on demand (*e.g.*, when calling
    :meth:`Index.get_indptr`), or when explicitly requested via
    :meth:`Index.fill_cache_`, and are maintained and adjusted over its
    lifespan.

    This representation ensures optimal computation in GNN message passing
    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
    workflows.

    .. code-block:: python

        from torch_geometric import Index

        index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
        >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
        assert index.dim_size == 3
        assert index.is_sorted

        # Flipping order:
        flipped = index.flip(0)
        >>> Index([2, 1, 1, 0], dim_size=3)
        assert not flipped.is_sorted

        # Filtering:
        mask = torch.tensor([True, True, True, False])
        filtered = index[mask]
        >>> Index([0, 1, 1], dim_size=3, is_sorted=True)
        assert filtered.is_sorted
    """
    # See "https://pytorch.org/docs/stable/notes/extending.html"
    # for a basic tutorial on how to subclass `torch.Tensor`.

    # The underlying tensor representation:
    _data: Tensor

    # The size of the underlying sparse vector, e.g. `_data.max() + 1` :
    _dim_size: Optional[int] = None

    # Whether the `index` representation is sorted:
    _is_sorted: bool = False

    # A cache for its compressed representation:
    _indptr: Optional[Tensor] = None

    # Whenever we perform a concatenation of indices, we cache the original
    # metadata to be able to reconstruct individual indices:
    _cat_metadata: Optional[CatMetadata] = None

    @staticmethod
    def __new__(
        cls: Type,
        data: Any,
        *args: Any,
        dim_size: Optional[int] = None,
        is_sorted: bool = False,
        **kwargs: Any,
    ) -> 'Index':
        r"""Creates an :class:`Index` from :obj:`data`.

        Non-tensor :obj:`data` is converted via :meth:`torch.tensor`, which
        receives any extra positional and keyword arguments. For tensor
        :obj:`data`, such extra arguments raise a :class:`TypeError`. If
        :obj:`data` is already an :class:`Index`, its :obj:`dim_size`,
        :obj:`is_sorted` flag and cached :obj:`indptr` are inherited.

        Raises:
            TypeError: If extra arguments are passed alongside a tensor.
            ValueError: If :obj:`data` has an unsupported dtype, is not
                one-dimensional, or is not contiguous.
        """
        if not isinstance(data, Tensor):
            data = torch.tensor(data, *args, **kwargs)
        elif len(args) > 0:
            raise TypeError(
                f"new() received an invalid combination of arguments - got "
                f"(Tensor, {', '.join(str(type(arg)) for arg in args)})")
        elif len(kwargs) > 0:
            raise TypeError(f"new() received invalid keyword arguments - got "
                            f"{set(kwargs.keys())})")

        assert isinstance(data, Tensor)

        indptr: Optional[Tensor] = None

        if isinstance(data, cls):  # If passed `Index`, inherit metadata:
            indptr = data._indptr
            # An explicitly passed `dim_size` (including `0`) takes
            # precedence over the inherited one:
            if dim_size is None:
                dim_size = data.dim_size
            is_sorted = is_sorted or data.is_sorted

        assert_valid_dtype(data)
        assert_one_dimensional(data)
        assert_contiguous(data)

        out = Tensor._make_wrapper_subclass(  # type: ignore
            cls,
            size=data.size(),
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            layout=data.layout,
            requires_grad=False,
        )
        assert isinstance(out, Index)

        # Attach metadata:
        out._data = data
        out._dim_size = dim_size
        out._is_sorted = is_sorted
        out._indptr = indptr

        if isinstance(data, cls):
            # Store the wrapped tensor rather than the `Index` wrapper:
            out._data = data._data

            # Reset metadata if cache is invalidated:
            if dim_size is not None and dim_size != data.dim_size:
                out._indptr = None

        return out

    # Validation ##############################################################

    def validate(self) -> 'Index':
        r"""Validates the :class:`Index` representation.

        In particular, it ensures that

        * it only holds valid indices.
        * the sort order is correctly set.

        Returns:
            The :class:`Index` itself, so that calls can be chained.

        Raises:
            ValueError: If the data type, dimensionality or memory layout is
                invalid, if it contains negative indices or indices not
                smaller than :obj:`dim_size`, or if it is flagged as sorted
                but is not.
        """
        assert_valid_dtype(self._data)
        assert_one_dimensional(self._data)
        assert_contiguous(self._data)

        if self.numel() > 0 and self._data.min() < 0:
            raise ValueError(f"'{self.__class__.__name__}' contains negative "
                             f"indices (got {int(self._data.min())})")

        if (self.numel() > 0 and self.dim_size is not None
                and self._data.max() >= self.dim_size):
            raise ValueError(f"'{self.__class__.__name__}' contains larger "
                             f"indices than its registered size "
                             f"(got {int(self._data.max())}, but expected "
                             f"values smaller than {self.dim_size})")

        if self.is_sorted and (self._data.diff() < 0).any():
            raise ValueError(f"'{self.__class__.__name__}' is not sorted")

        return self

    # Properties ##############################################################

    @property
    def dim_size(self) -> Optional[int]:
        r"""The size of the underlying sparse vector."""
        return self._dim_size

    @property
    def is_sorted(self) -> bool:
        r"""Returns whether indices are sorted in ascending order."""
        return self._is_sorted

    @property
    def dtype(self) -> torch.dtype:  # type: ignore
        # TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
        return self._data.dtype

    # Cache Interface #########################################################

    def get_dim_size(self) -> int:
        r"""The size of the underlying sparse vector.
        Automatically computed and cached when not explicitly set.
        """
        if self._dim_size is None:
            # Empty indices address a vector of size zero:
            dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
            self._dim_size = dim_size

        assert isinstance(self._dim_size, int)
        return self._dim_size

    def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
        r"""Assigns or re-assigns the size of the underlying sparse vector.

        A cached :obj:`indptr` is kept consistent with the new size. It is
        dropped when :obj:`dim_size` is :obj:`None`, truncated when the size
        shrinks, and padded with its last value when the size grows.

        NOTE(review): shrinking does not check that all indices remain
        smaller than the new size; call :meth:`validate` if needed.
        """
        if self.is_sorted and self._indptr is not None:
            if dim_size is None:
                self._indptr = None

            elif self._indptr.numel() - 1 >= dim_size:
                self._indptr = self._indptr[:dim_size + 1]

            else:
                # Newly added trailing rows are empty, so their pointers all
                # repeat the final offset:
                fill_value = self._indptr.new_full(
                    (dim_size - self._indptr.numel() + 1, ),
                    fill_value=self._indptr[-1],  # type: ignore
                )
                self._indptr = torch.cat([self._indptr, fill_value], dim=0)

        self._dim_size = dim_size

        return self

    @assert_sorted
    def get_indptr(self) -> Tensor:
        r"""Returns the compressed index representation in case :class:`Index`
        is sorted.

        The result is computed on first access and cached afterwards.

        Raises:
            ValueError: If the :class:`Index` is not sorted.
        """
        if self._indptr is None:
            self._indptr = index2ptr(self._data, self.get_dim_size())

        assert isinstance(self._indptr, Tensor)
        return self._indptr

    def fill_cache_(self) -> 'Index':
        r"""Fills the cache with (meta)data information.

        Always computes :obj:`dim_size`; computes :obj:`indptr` only if the
        :class:`Index` is sorted.
        """
        self.get_dim_size()

        if self.is_sorted:
            self.get_indptr()

        return self

    # Methods #################################################################

    def share_memory_(self) -> 'Index':
        """"""  # noqa: D419
        # Moves both the data and the (optional) `indptr` cache into shared
        # memory:
        self._data.share_memory_()
        if self._indptr is not None:
            self._indptr.share_memory_()
        return self

    def is_shared(self) -> bool:
        """"""  # noqa: D419
        # Only the underlying data is inspected, not the `indptr` cache:
        return self._data.is_shared()

    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`Index` representation back to a
        :class:`torch.Tensor` representation.
        """
        return self._data

    # PyTorch/Python builtins #################################################

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
        # Splits this wrapper into the names of its inner tensor attributes
        # and a non-tensor context; reversed by `__tensor_unflatten__`:
        attrs = ['_data']
        if self._indptr is not None:
            attrs.append('_indptr')

        ctx = (
            self._dim_size,
            self._is_sorted,
            self._cat_metadata,
        )

        return attrs, ctx

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: Dict[str, Any],
        ctx: Tuple[Any, ...],
        outer_size: Tuple[int, ...],
        outer_stride: Tuple[int, ...],
    ) -> 'Index':
        # Rebuilds an `Index` from the output of `__tensor_flatten__`:
        index = Index(
            inner_tensors['_data'],
            dim_size=ctx[0],
            is_sorted=ctx[1],
        )

        index._indptr = inner_tensors.get('_indptr', None)
        index._cat_metadata = ctx[2]

        return index

    # Prevent auto-wrapping outputs back into the proper subclass type:
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(
        cls: Type,
        func: Callable[..., Any],
        types: Iterable[Type[Any]],
        args: Iterable[Tuple[Any, ...]] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        # `Index` should be treated as a regular PyTorch tensor for all
        # standard PyTorch functionalities. However,
        # * some of its metadata can be transferred to new functions, e.g.,
        #   `torch.narrow()` can inherit the `is_sorted` property.
        # * not all operations lead to valid `Index` tensors again, e.g.,
        #   `torch.sum()` does not yield a `Index` as its output, or
        #   `torch.stack() violates the [*] shape assumption.

        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that
        # implement specific functions for valid `Index` routines.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

        # For all other PyTorch functions, we treat them as vanilla tensors.
        args = pytree.tree_map_only(Index, lambda x: x._data, args)
        if kwargs is not None:
            kwargs = pytree.tree_map_only(Index, lambda x: x._data, kwargs)
        return func(*args, **(kwargs or {}))

    def __repr__(self) -> str:  # type: ignore
        prefix = f'{self.__class__.__name__}('
        indent = len(prefix)
        tensor_str = torch._tensor_str._tensor_str(self._data, indent)

        # Only non-default attributes are shown:
        suffixes = []
        if self.dim_size is not None:
            suffixes.append(f'dim_size={self.dim_size}')
        if (self.device.type != torch._C._get_default_device()
                or (self.device.type == 'cuda'
                    and torch.cuda.current_device() != self.device.index)
                or (self.device.type == 'mps')):
            suffixes.append(f"device='{self.device}'")
        if self.dtype != torch.int64:
            suffixes.append(f'dtype={self.dtype}')
        if self.is_sorted:
            suffixes.append('is_sorted=True')

        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
                                               indent, force_newline=False)

    # Helpers #################################################################

    def _shallow_copy(self) -> 'Index':
        # New wrapper sharing the same data, metadata and cache objects:
        out = Index(self._data)
        out._dim_size = self._dim_size
        out._is_sorted = self._is_sorted
        out._indptr = self._indptr
        out._cat_metadata = self._cat_metadata
        return out

    def _clear_metadata(self) -> 'Index':
        # Resets all metadata and caches to their defaults (in-place):
        self._dim_size = None
        self._is_sorted = False
        self._indptr = None
        self._cat_metadata = None
        return self
|
429
|
+
|
430
|
+
|
431
|
+
def apply_(
    tensor: Index,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Union[Index, Tensor]:
    r"""Applies :obj:`fn` to the underlying data of :obj:`tensor` and carries
    over its metadata.

    :obj:`fn` is called as :obj:`fn(data, *args, **kwargs)`. If the result's
    dtype is not in :obj:`INDEX_DTYPES`, the plain result tensor is returned.
    Otherwise, an :class:`Index` is returned that inherits :obj:`dim_size`,
    :obj:`is_sorted` and the concatenation metadata. A cached :obj:`indptr`
    is transformed by :obj:`fn` (with the same arguments) as well.
    """
    data = fn(tensor._data, *args, **kwargs)

    if data.dtype not in INDEX_DTYPES:
        return data

    # Detect in-place results via object identity. Comparing `data_ptr()` is
    # unreliable: empty tensors typically have a null `data_ptr()`, so e.g.
    # cloning or moving an empty `Index` to another device would be mistaken
    # for an in-place operation and would silently mutate the input.
    if data is not tensor._data:
        out = Index(data)
    else:  # In-place:
        out = tensor

    # Copy metadata:
    out._dim_size = tensor._dim_size
    out._is_sorted = tensor._is_sorted
    out._cat_metadata = tensor._cat_metadata

    # Convert cache:
    if tensor._indptr is not None:
        out._indptr = fn(tensor._indptr, *args, **kwargs)

    return out
|
459
|
+
|
460
|
+
|
461
|
+
@implements(aten.clone.default)
def _clone(
    tensor: Index,
    *,
    memory_format: torch.memory_format = torch.preserve_format,
) -> Index:
    r"""Clones :obj:`tensor`, carrying over its metadata and cloning a cached
    :obj:`indptr`.
    """
    cloned = apply_(tensor, aten.clone.default, memory_format=memory_format)
    # Cloning keeps the index dtype, so `apply_` yields an `Index`:
    assert isinstance(cloned, Index)
    return cloned
|
470
|
+
|
471
|
+
|
472
|
+
@implements(aten._to_copy.default)
def _to_copy(
    tensor: Index,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    non_blocking: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> Union[Index, Tensor]:
    r"""Copies :obj:`tensor` to a new dtype, layout and/or device.

    Returns a plain :class:`torch.Tensor` if the target dtype is not a
    supported index dtype, and an :class:`Index` with its metadata otherwise.
    """
    copy_options = dict(
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        non_blocking=non_blocking,
        memory_format=memory_format,
    )
    return apply_(tensor, aten._to_copy.default, **copy_options)
|
493
|
+
|
494
|
+
|
495
|
+
@implements(aten.alias.default)
def _alias(tensor: Index) -> Index:
    r"""Returns a new :class:`Index` wrapper that shares the data, metadata
    and cache of :obj:`tensor`.
    """
    shallow = tensor._shallow_copy()
    return shallow
|
498
|
+
|
499
|
+
|
500
|
+
@implements(aten._pin_memory.default)
def _pin_memory(tensor: Index) -> Index:
    r"""Copies :obj:`tensor` (and a cached :obj:`indptr`) into pinned memory
    while carrying over its metadata.
    """
    pinned = apply_(tensor, aten._pin_memory.default)
    assert isinstance(pinned, Index)
    return pinned
|
505
|
+
|
506
|
+
|
507
|
+
@implements(aten.sort.default)
def _sort(
    tensor: Index,
    dim: int = -1,
    descending: bool = False,
) -> Tuple[Index, Tensor]:
    r"""Sorts :obj:`tensor` and returns the sorted :class:`Index` together
    with the applied permutation.

    If :obj:`tensor` is already flagged as sorted and an ascending order is
    requested, the input itself is returned with an identity permutation.
    """
    if not descending and tensor.is_sorted:
        identity = torch.arange(tensor._data.numel(),
                                device=tensor._data.device)
        return tensor, identity

    data, perm = aten.sort.default(tensor._data, dim, descending)

    out = Index(data)
    out._dim_size = tensor._dim_size
    # Only an ascending sort establishes the `is_sorted` invariant:
    out._is_sorted = not descending
    return out, perm
|
527
|
+
|
528
|
+
|
529
|
+
@implements(aten.sort.stable)
def _sort_stable(
    tensor: Index,
    *,
    stable: bool = False,
    dim: int = -1,
    descending: bool = False,
) -> Tuple[Index, Tensor]:
    r"""Sorts :obj:`tensor` (optionally stably) and returns the sorted
    :class:`Index` together with the applied permutation.

    If :obj:`tensor` is already flagged as sorted and an ascending order is
    requested, the input itself is returned with an identity permutation.
    """
    if not descending and tensor.is_sorted:
        identity = torch.arange(tensor._data.numel(),
                                device=tensor._data.device)
        return tensor, identity

    data, perm = aten.sort.stable(
        tensor._data,
        stable=stable,
        dim=dim,
        descending=descending,
    )

    out = Index(data)
    out._dim_size = tensor._dim_size
    # Only an ascending sort establishes the `is_sorted` invariant:
    out._is_sorted = not descending
    return out, perm
|
552
|
+
|
553
|
+
|
554
|
+
@implements(aten.cat.default)
def _cat(
    tensors: List[Union[Index, Tensor]],
    dim: int = 0,
) -> Union[Index, Tensor]:
    r"""Concatenates :obj:`tensors`.

    If any input is a plain :class:`torch.Tensor`, the plain concatenation is
    returned. Otherwise, the result is an :class:`Index` whose
    :obj:`dim_size` is the maximum over all inputs (or :obj:`None` if any
    input has no :obj:`dim_size`). The per-input :obj:`nnz`,
    :obj:`dim_size` and :obj:`is_sorted` values are recorded in
    :obj:`_cat_metadata`.
    """
    data_list = pytree.tree_map_only(Index, lambda x: x._data, tensors)
    data = aten.cat.default(data_list, dim=dim)

    if not all(isinstance(tensor, Index) for tensor in tensors):
        return data

    out = Index(data)

    nnz_list = [t.numel() for t in tensors]
    dim_size_list = [t.dim_size for t in tensors]  # type: ignore
    is_sorted_list = [t.is_sorted for t in tensors]  # type: ignore

    # The total `dim_size` is only known if every input has one:
    known_sizes = [size for size in dim_size_list if size is not None]
    total_dim_size: Optional[int] = None
    if len(known_sizes) == len(dim_size_list):
        total_dim_size = max(known_sizes, default=0)

    out._dim_size = total_dim_size

    out._cat_metadata = CatMetadata(
        nnz=nnz_list,
        dim_size=dim_size_list,
        is_sorted=is_sorted_list,
    )

    return out
|
590
|
+
|
591
|
+
|
592
|
+
@implements(aten.flip.default)
def _flip(
    input: Index,
    dims: Union[List[int], Tuple[int, ...]],
) -> Index:
    r"""Reverses :obj:`input` along :obj:`dims`.

    :obj:`dim_size` is kept. The result is not flagged as sorted, since
    reversing generally breaks ascending order.
    """
    flipped = Index(aten.flip.default(input._data, dims))
    flipped._dim_size = input.dim_size
    return flipped
|
604
|
+
|
605
|
+
|
606
|
+
@implements(aten.index_select.default)
def _index_select(
    input: Union[Index, Tensor],
    dim: int,
    index: Union[Index, Tensor],
) -> Union[Index, Tensor]:
    r"""Selects entries of :obj:`input` at positions :obj:`index`.

    Returns an :class:`Index` (keeping :obj:`dim_size`) when :obj:`input` is
    an :class:`Index`, and a plain tensor otherwise.
    """
    def _unwrap(x: Union[Index, Tensor]) -> Tensor:
        # Plain-tensor view of either argument:
        return x._data if isinstance(x, Index) else x

    selected = aten.index_select.default(_unwrap(input), dim, _unwrap(index))

    if not isinstance(input, Index):
        return selected

    result = Index(selected)
    result._dim_size = input.dim_size
    return result
|
624
|
+
|
625
|
+
|
626
|
+
@implements(aten.slice.Tensor)
def _slice(
    input: Index,
    dim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
) -> Index:
    r"""Slices an :class:`Index` along ``dim`` via ``start:end:step``.

    A slice that covers the whole input is short-circuited to a shallow copy.
    Otherwise ``dim_size`` is carried over, and ``is_sorted`` is kept for
    non-negative steps, since an in-order subsequence of a sorted sequence
    is still sorted.
    """
    size = input.size(dim)

    # `start`/`end` may arrive un-normalized, so a negative `start` such as
    # -1 (i.e. `index[-1:]`) addresses the tail and must NOT be treated as a
    # no-op. Only `start == 0`, or a negative start that reaches past the
    # front (`start <= -size`), begins at the first element. Likewise, any
    # `end >= size` runs through the last element.
    covers_start = start is None or start == 0 or start <= -size
    covers_end = end is None or end >= size
    if covers_start and covers_end and step == 1:
        return input._shallow_copy()  # No-op.

    data = aten.slice.Tensor(input._data, dim, start, end, step)

    if step != 1:
        data = data.contiguous()

    out = Index(data)
    out._dim_size = input.dim_size
    # NOTE We could potentially maintain the `indptr` attribute here,
    # but it is not really clear if this is worth it. The most important
    # information `is_sorted` needs to be maintained though:
    if step >= 0:
        out._is_sorted = input.is_sorted

    return out
|
653
|
+
|
654
|
+
|
655
|
+
@implements(aten.index.Tensor)
def _index(
    input: Union[Index, Tensor],
    indices: List[Optional[Union[Tensor, Index]]],
) -> Union[Index, Tensor]:
    r"""Implements advanced indexing ``input[indices]`` for :class:`Index`.

    * If ``input`` is a plain tensor, any :class:`Index` inside ``indices``
      is unwrapped and the plain result is returned.
    * If the result is not one-dimensional, it is returned as a plain tensor.
    * Otherwise the result is an :class:`Index` that inherits ``dim_size``.
      Boolean/``uint8`` masking additionally keeps ``is_sorted``, because
      masking preserves relative order. Integer gathering may reorder
      entries, so ``is_sorted`` is not propagated there.
    """
    if not isinstance(input, Index):
        raw_indices = pytree.tree_map_only(Index, lambda x: x._data, indices)
        return aten.index.Tensor(input, raw_indices)

    data = aten.index.Tensor(input._data, indices)
    if data.dim() != 1:
        return data

    assert len(indices) == 1
    index = indices[0]
    assert index is not None

    out = Index(data)
    out._dim_size = input.dim_size
    if index.dtype in (torch.bool, torch.uint8):  # 1. `index[mask]`.
        out._is_sorted = input.is_sorted
    # 2. `index[index]`: gathering may reorder, so `is_sorted` is dropped.
    return out
|
684
|
+
|
685
|
+
|
686
|
+
@implements(aten.add.Tensor)
def _add(
    input: Union[int, Tensor, Index],
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:
    r"""Computes ``input + alpha * other`` where at least one side is an
    :class:`Index`.

    Results whose dtype is not in ``INDEX_DTYPES``, or that are not
    one-dimensional, are returned as plain tensors. Otherwise an
    :class:`Index` is returned, and its metadata is derived where possible:

    * ``Index + alpha * c`` (scalar ``c``): ``dim_size`` is shifted by
      ``alpha * c`` and ``is_sorted`` is kept.
    * ``c + alpha * Index``: ``is_sorted`` is kept for ``alpha >= 0``, and
      ``dim_size`` becomes ``c + alpha * dim_size`` for ``alpha > 0``.
    * ``Index + alpha * Index`` with ``alpha >= 0``: ``dim_size`` becomes
      the sum of both (scaled) ``dim_size`` values.
    """
    data = aten.add.Tensor(
        input._data if isinstance(input, Index) else input,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    # Single-element tensors broadcast like scalars, so collapse such an
    # operand to a Python `int`, but only when the *other* operand is an
    # `Index` that metadata can be derived from. Collapsing both sides
    # (e.g. `Index([3]) + 5`) would leave no `Index` to reason about.
    # Zero-element tensors are never collapsed: `int(...)` cannot convert
    # them.
    if (isinstance(input, Index) and isinstance(other, Tensor)
            and other.numel() == 1):
        other = int(other)
    elif (isinstance(other, Index) and isinstance(input, Tensor)
          and input.numel() == 1):
        input = int(input)

    if isinstance(input, Index) and isinstance(other, int):
        if input.dim_size is not None:
            out._dim_size = input.dim_size + alpha * other
        out._is_sorted = input.is_sorted

    elif isinstance(input, int) and isinstance(other, Index):
        # A negative `alpha` reverses the order of `other`, and `alpha == 0`
        # collapses every entry onto `input`, so the metadata only carries
        # over for the cases below:
        if alpha > 0 and other.dim_size is not None:
            out._dim_size = input + alpha * other.dim_size
        if alpha >= 0:
            out._is_sorted = other.is_sorted

    elif isinstance(input, Index) and isinstance(other, Index):
        # With a negative `alpha`, the summed `dim_size` is not an upper
        # bound on the resulting values:
        if (alpha >= 0 and input.dim_size is not None
                and other.dim_size is not None):
            out._dim_size = input.dim_size + alpha * other.dim_size

    return out
|
730
|
+
|
731
|
+
|
732
|
+
@implements(aten.add_.Tensor)
def add_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:
    r"""In-place ``input += alpha * other`` for an :class:`Index`.

    The metadata of ``input`` is cleared before the update and then restored
    where it can be inferred:

    * Adding a scalar (or a single-element tensor) shifts ``dim_size`` by
      ``alpha * other`` and keeps ``is_sorted``.
    * Adding another :class:`Index` with ``alpha >= 0`` sums both (scaled)
      ``dim_size`` values.

    Returns ``input`` itself.
    """
    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.add_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    # Single-element tensors broadcast like scalars. Zero-element tensors are
    # left untouched, since `int(...)` cannot convert them.
    if isinstance(other, Tensor) and other.numel() == 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size + alpha * other
        input._is_sorted = is_sorted

    elif isinstance(other, Index):
        # With a negative `alpha`, the summed `dim_size` is not an upper
        # bound on the resulting values:
        if (alpha >= 0 and dim_size is not None
                and other.dim_size is not None):
            input._dim_size = dim_size + alpha * other.dim_size

    return input
|
763
|
+
|
764
|
+
|
765
|
+
@implements(aten.sub.Tensor)
def _sub(
    input: Union[int, Tensor, Index],
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:
    r"""Computes ``input - alpha * other`` where at least one side is an
    :class:`Index`.

    Results whose dtype is not in ``INDEX_DTYPES``, or that are not
    one-dimensional, are returned as plain tensors. If ``input`` is not an
    :class:`Index`, an :class:`Index` without metadata is returned.
    Subtracting a scalar (or a single-element tensor) from an :class:`Index`
    shifts ``dim_size`` by ``-alpha * other`` and keeps ``is_sorted``. Any
    other combination yields no metadata.
    """
    data = aten.sub.Tensor(
        input._data if isinstance(input, Index) else input,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    if not isinstance(input, Index):
        return out

    # Single-element tensors broadcast like scalars. Zero-element tensors are
    # left untouched, since `int(...)` cannot convert them.
    if isinstance(other, Tensor) and other.numel() == 1:
        other = int(other)

    if isinstance(other, int):
        if input.dim_size is not None:
            out._dim_size = input.dim_size - alpha * other
        out._is_sorted = input.is_sorted

    return out
|
798
|
+
|
799
|
+
|
800
|
+
@implements(aten.sub_.Tensor)
def sub_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:
    r"""In-place ``input -= alpha * other`` for an :class:`Index`.

    The metadata of ``input`` is cleared before the update. It is restored
    only when ``other`` is a scalar (or a single-element tensor): then
    ``dim_size`` is shifted by ``-alpha * other`` and ``is_sorted`` is kept.

    Returns ``input`` itself.
    """
    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.sub_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    # Single-element tensors broadcast like scalars. Zero-element tensors are
    # left untouched, since `int(...)` cannot convert them.
    if isinstance(other, Tensor) and other.numel() == 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size - alpha * other
        input._is_sorted = is_sorted

    return input
|