pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -0,0 +1,798 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import (
|
|
4
|
+
Any,
|
|
5
|
+
Callable,
|
|
6
|
+
Dict,
|
|
7
|
+
Iterable,
|
|
8
|
+
List,
|
|
9
|
+
Optional,
|
|
10
|
+
Tuple,
|
|
11
|
+
Type,
|
|
12
|
+
Union,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import torch
|
|
17
|
+
import torch.utils._pytree as pytree
|
|
18
|
+
import xxhash
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
import torch_geometric.typing
|
|
22
|
+
from torch_geometric.typing import CPUHashMap, CUDAHashMap
|
|
23
|
+
|
|
24
|
+
# Shorthand for the ATen operator namespace used by the overrides below:
aten = torch.ops.aten

# Registry of `HashTensor`-specific operator implementations, keyed by the
# ATen overload they replace. Entries are added through `@implements(...)`
# and looked up in `HashTensor.__torch_dispatch__`.
HANDLED_FUNCTIONS: Dict[Callable, Callable] = dict()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def implements(torch_function: Callable) -> Callable:
    r"""Registers a :pytorch:`PyTorch` function override.

    Returns a decorator that stores the decorated function in
    :obj:`HANDLED_FUNCTIONS` under the key :obj:`torch_function`, and then
    returns the decorated function unchanged.

    Args:
        torch_function: The operator to override (*e.g.*, an ATen overload
            such as :obj:`aten.clone.default`).
    """
    def register(handler: Callable) -> Callable:
        HANDLED_FUNCTIONS.update({torch_function: handler})
        return handler

    # Copy the metadata (name, docstring, ...) of `torch_function` onto the
    # returned decorator:
    return functools.wraps(torch_function)(register)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def as_key_tensor(
    key: Any,
    *,
    device: Optional[torch.device] = None,
) -> Tensor:
    r"""Converts arbitrary :obj:`key` data into an integer key tensor.

    Inputs that :meth:`torch.as_tensor` accepts are converted directly. Any
    other input (*e.g.*, a list of strings) is hashed element-wise via
    :obj:`xxhash.xxh64`. The sign bit is cleared, so every hash becomes a
    non-negative :obj:`torch.int64` value.

    The resulting tensor is then bit-reinterpreted (via :meth:`Tensor.view`)
    as the integer dtype with the same element size. This lets, *e.g.*,
    boolean or floating-point keys be used for exact key matching.

    Args:
        key: The key data.
        device: The device on which to place the key tensor.

    Raises:
        ValueError: If the element size of the key tensor is not one of
            1, 2, 4 or 8 bytes.
    """
    try:
        key = torch.as_tensor(key, device=device)
    except Exception:
        # Keys PyTorch cannot represent natively: hash each item instead.
        target_device = device or torch.get_default_device()
        hashes = [
            xxhash.xxh64(item).intdigest() & 0x7FFFFFFFFFFFFFFF
            for item in key
        ]
        key = torch.tensor(hashes, dtype=torch.int64, device=target_device)

    int_dtype_by_num_bytes = {
        1: torch.uint8,
        2: torch.int16,
        4: torch.int32,
        8: torch.int64,
    }
    num_bytes = key.element_size()
    if num_bytes not in int_dtype_by_num_bytes:
        raise ValueError(f"Received invalid dtype '{key.dtype}' with "
                         f"{key.element_size()} bytes")

    return key.view(int_dtype_by_num_bytes[num_bytes])
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_hash_map(key: Tensor) -> Union[CPUHashMap, CUDAHashMap]:
    r"""Builds a key-to-position mapping for the key tensor :obj:`key`.

    The backend is picked in the following order:

    1. :class:`CUDAHashMap`, if :obj:`key` lives on a CUDA device and a CUDA
       hash map is available.
    2. :class:`CPUHashMap`, if available. CUDA keys are copied to the CPU
       first, and a warning about the slower fallback is emitted.
    3. Otherwise, an ordered :class:`pandas.CategoricalDtype` over the keys.
       Note that this fallback is not covered by the return annotation.
    """
    if key.is_cuda:
        if torch_geometric.typing.WITH_CUDA_HASH_MAP:
            return CUDAHashMap(key, 0.5)

        warnings.warn(
            "Fallback to CPU-based mapping algorithm which may "
            "cause slowdowns and device synchronization. Please "
            "install 'pyg-lib' for an accelerated 'HashTensor' "
            "implementation.", stacklevel=2)

    cpu_key = key.cpu()
    if torch_geometric.typing.WITH_CPU_HASH_MAP:
        return CPUHashMap(cpu_key, -1)

    import pandas as pd

    return pd.CategoricalDtype(categories=cpu_key.numpy(), ordered=True)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class HashTensor(Tensor):
    r"""A :pytorch:`null` :class:`torch.Tensor` that can be referenced by
    arbitrary keys rather than indices in the first dimension.

    :class:`HashTensor` sub-classes a general :pytorch:`null`
    :class:`torch.Tensor`, and extends it by CPU- and GPU-accelerated mapping
    routines. This allows for fast and efficient access to non-contiguous
    indices/keys while the underlying data is stored in a compact format.

    This representation is ideal for scenarios where one needs a fast mapping
    routine without relying on CPU-based external packages, and can be used,
    *e.g.*, to perform mapping of global indices to local indices during
    subgraph creation, or in data-processing pipelines to map non-contiguous
    input data into a contiguous space, such as

    * mapping of hashed node IDs to range :obj:`[0, num_nodes - 1]`
    * mapping of raw input data, *e.g.*, categorical data to range
      :obj:`[0, num_categories - 1]`

    Specifically, :class:`HashTensor` supports *any* keys of *any* type,
    *e.g.*, strings, timestamps, etc.

    .. code-block:: python

        from torch_geometric import HashTensor

        key = torch.tensor([1000, 100, 10000])
        value = torch.randn(3, 4)

        tensor = HashTensor(key, value)
        assert tensor.size() == (3, 4)

        # Filtering:
        query = torch.tensor([10000, 1000])
        out = tensor[query]
        assert out.equal(value[[2, 0]])

        # Accessing non-existing keys:
        out = tensor[[10000, 0]]
        out.isnan()
        >>> tensor([[False, False, False, False],
        ...         [True, True, True, True]])

        # If `value` is not given, indexing returns the position of `query` in
        # `key`, and `-1` otherwise:
        key = ['Animation', 'Comedy', 'Fantasy']
        tensor = HashTensor(key)

        out = tensor[['Comedy', 'Romance']]
        >>> tensor([1, -1])

    Args:
        key: The keys in the first dimension.
        value: The values to hold.
        dtype: The desired data type of the values of the returned tensor.
        device: The device of the returned tensor.
    """
    # Key -> position mapping. This is a dense lookup `Tensor` for small key
    # ranges, and otherwise whatever `get_hash_map` returns:
    _map: Union[Tensor, CPUHashMap, CUDAHashMap]
    # The values stored per key (or `None` if only positions are stored):
    _value: Optional[Tensor]
    # Smallest and largest key, as 0-dimensional tensors:
    _min_key: Tensor
    _max_key: Tensor

    @staticmethod
    def __new__(
        cls: Type,
        key: Any,
        value: Optional[Any] = None,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> 'HashTensor':

        if value is not None:
            value = torch.as_tensor(value, dtype=dtype, device=device)
            device = value.device

        key = as_key_tensor(key, device=device)

        if key.dim() != 1:
            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
                             f"one-dimensional (got {key.dim()} dimensions)")

        if not key.is_contiguous():
            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
                             f"contiguous")

        if value is not None:
            if key.device != value.device:
                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
                                 f"are expected to be on the same device (got "
                                 f"'{key.device}' and '{value.device}')")

            if key.numel() != value.size(0):
                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
                                 f"are expected to have the same size in the "
                                 f"first dimension (got {key.size(0)} and "
                                 f"{value.size(0)})")

        min_key = key.min() if key.numel() > 0 else key.new_zeros(())
        max_key = key.max() if key.numel() > 0 else key.new_zeros(())

        # Compute the key range on Python integers. Subtracting in the key
        # dtype can overflow (*e.g.*, `int16` keys spanning `[-30000, 30000]`,
        # or `int32` keys spanning the whole dtype range). That would produce
        # a negative range and an invalid dense lookup table.
        _range = int(max_key) - int(min_key)
        # TODO Expose fixed threshold as argument.
        if (key.dtype in {torch.uint8, torch.int16} or _range <= 1_000_000
                or _range <= 2 * key.numel()):
            # Dense lookup table: slot `k - (min_key - 1)` holds the position
            # of key `k`. Keys occupy slots `1` to `_range + 1`. Slot `0` and
            # slot `_range + 2` are never written, so they stay `-1` and act
            # as sentinels for out-of-range queries (see `_get`).
            _map = torch.full(
                size=(_range + 3, ),
                fill_value=-1,
                dtype=torch.int64,
                device=key.device,
            )
            _map[key.long() - (min_key.long() - 1)] = torch.arange(
                key.numel(),
                dtype=_map.dtype,
                device=_map.device,
            )
        else:
            _map = get_hash_map(key)

        return cls._from_data(
            _map,
            value,
            min_key,
            max_key,
            num_keys=key.numel(),
            dtype=dtype,
        )

    # Private Methods #########################################################

    @classmethod
    def _from_data(
        cls,
        _map: Union[Tensor, CPUHashMap, CUDAHashMap],
        value: Optional[Tensor],
        min_key: Tensor,
        max_key: Tensor,
        *,
        num_keys: int,
        dtype: Optional[torch.dtype],
    ) -> 'HashTensor':
        # Creates the wrapper instance directly from its internal state.
        # The outer metadata (size/stride/dtype/layout/requires_grad) mirrors
        # `value` if given. Otherwise it describes a 1-D tensor holding
        # `num_keys` positions (`int64` unless `dtype` says otherwise).
        if value is not None:
            dtype = value.dtype
            size = value.size()
            stride = value.stride()
            layout = value.layout
            requires_grad = value.requires_grad
        else:
            dtype = dtype or torch.int64
            size = torch.Size([num_keys])
            stride = (1, )
            layout = torch.strided
            requires_grad = False

        out = Tensor._make_wrapper_subclass(
            cls,
            size=size,
            strides=stride,
            dtype=dtype,
            device=min_key.device,
            layout=layout,
            requires_grad=requires_grad,
        )
        assert isinstance(out, HashTensor)

        out._map = _map
        out._value = value
        out._min_key = min_key
        out._max_key = max_key

        return out

    @property
    def _key(self) -> Tensor:
        # Recovers the keys in their original (position) order.
        if isinstance(self._map, Tensor):
            # Slot `p` of the dense lookup table corresponds to the key
            # `p + (min_key - 1)`, and stores that key's position:
            mask = self._map >= 0
            sorted_key = mask.nonzero().view(-1) + (self._min_key.long() - 1)
            # `self._map[mask]` lists each key's position in ascending key
            # order, so scatter the sorted keys back to those positions:
            key = torch.empty_like(sorted_key)
            key[self._map[mask]] = sorted_key
        elif (torch_geometric.typing.WITH_CUDA_HASH_MAP
              or torch_geometric.typing.WITH_CPU_HASH_MAP):
            key = self._map.keys().to(self.device)
        else:
            key = torch.from_numpy(self._map.categories.to_numpy())

        return key.to(self.device)

    def _shallow_copy(self) -> 'HashTensor':
        # A new wrapper that shares all internal state with `self`.
        return self._from_data(
            self._map,
            self._value,
            self._min_key,
            self._max_key,
            num_keys=self.size(0),
            dtype=self.dtype,
        )

    def _get(self, query: Tensor) -> Tensor:
        # Maps `query` keys to positions, with `-1` for missing keys. Without
        # `_value`, the positions are returned (cast to `self.dtype`).
        # Otherwise the matching values are gathered, and rows of missing
        # keys are filled with `NaN` (floating point) or `-1`.
        if isinstance(self._map, Tensor):
            # Out-of-range queries are clamped onto the first/last slot,
            # which are never written and therefore hold `-1`:
            index = query.long() - (self._min_key.long() - 1)
            index = self._map[index.clamp_(min=0, max=self._map.numel() - 1)]
        elif torch_geometric.typing.WITH_CUDA_HASH_MAP and query.is_cuda:
            index = self._map.get(query)
        elif torch_geometric.typing.WITH_CPU_HASH_MAP:
            index = self._map.get(query.cpu())
        else:
            import pandas as pd

            ser = pd.Series(query.cpu().numpy(), dtype=self._map)
            index = torch.from_numpy(ser.cat.codes.to_numpy().copy()).long()

        index = index.to(self.device)

        if self._value is None:
            return index.to(self.dtype)

        # `-1` positions gather the last row here and are masked out below:
        out = self._value[index]

        mask = index != -1
        mask = mask.view([-1] + [1] * (out.dim() - 1))
        fill_value = float('NaN') if out.is_floating_point() else -1
        if torch_geometric.typing.WITH_PT20:
            other: Union[int, float, Tensor] = fill_value
        else:
            other = torch.full_like(out, fill_value)

        return out.where(mask, other)

    # Methods #################################################################

    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`HashTensor` representation back to a
        :class:`torch.Tensor` representation.
        """
        if self._value is not None:
            return self._value
        return torch.arange(self.size(0), dtype=self.dtype, device=self.device)

    # PyTorch/Python builtins #################################################

    # Prevent auto-wrapping outputs back into the proper subclass type:
    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore

    @classmethod
    def __torch_dispatch__(  # type: ignore
        cls: Type,
        func: Callable[..., Any],
        types: Iterable[Type[Any]],
        args: Iterable[Tuple[Any, ...]] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        # Hold a number of `HANDLED_FUNCTIONS` that implement specific
        # functions for valid `HashTensor` routines.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

        # For all other PyTorch functions, we treat them as vanilla tensors.
        args = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(), args)
        if kwargs is not None:
            kwargs = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(),
                                          kwargs)
        return func(*args, **(kwargs or {}))

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
        # Inner attributes plus the metadata (`num_keys`, `dtype`) that
        # `__tensor_unflatten__` needs to rebuild the wrapper.
        attrs = ['_map', '_min_key', '_max_key']
        if self._value is not None:
            attrs.append('_value')

        ctx = (self.size(0), self.dtype)

        return attrs, ctx

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: Dict[str, Any],
        ctx: Tuple[Any, ...],
        outer_size: Tuple[int, ...],
        outer_stride: Tuple[int, ...],
    ) -> 'HashTensor':
        return HashTensor._from_data(
            inner_tensors['_map'],
            inner_tensors.get('_value', None),
            inner_tensors['_min_key'],
            inner_tensors['_max_key'],
            num_keys=ctx[0],
            dtype=ctx[1],
        )

    def __repr__(self) -> str:  # type: ignore
        indent = len(f'{self.__class__.__name__}(')
        tensor_str = torch._tensor_str._tensor_str(self.as_tensor(), indent)
        return torch._tensor_str._str_intern(self, tensor_contents=tensor_str)

    def tolist(self) -> List[Any]:
        """"""  # noqa: D419
        return self.as_tensor().tolist()

    def numpy(self, *, force: bool = False) -> np.ndarray:
        """"""  # noqa: D419
        return self.as_tensor().numpy(force=force)

    def index_select(  # type: ignore
        self,
        dim: int,
        index: Any,
    ) -> Union['HashTensor', Tensor]:
        """"""  # noqa: D419
        return torch.index_select(self, dim, index)

    def select(  # type: ignore
        self,
        dim: int,
        index: Any,
    ) -> Union['HashTensor', Tensor]:
        """"""  # noqa: D419
        return torch.select(self, dim, index)

    def share_memory_(self) -> 'HashTensor':
        """"""  # noqa: D419
        # Only tensor-based state can be moved to shared memory:
        if isinstance(self._map, Tensor):
            self._map.share_memory_()
        if self._value is not None:
            self._value.share_memory_()
        self._min_key.share_memory_()
        self._max_key.share_memory_()
        return self

    def is_shared(self) -> bool:
        """"""  # noqa: D419
        return self._min_key.is_shared()

    def detach_(self) -> 'HashTensor':
        """"""  # noqa: D419
        if self._value is not None:
            self._value.detach_()
        return super().detach_()  # type: ignore

    def __getitem__(self, indices: Any) -> Union['HashTensor', Tensor]:
        if not isinstance(indices, tuple):
            indices = (indices, )
        assert len(indices) > 0

        # We convert any index in the first dimension into a key tensor.
        # This means that downstream handling (i.e. in `aten.index.Tensor`)
        # needs to take this pre-conversion into account. However, detecting
        # whether the first dimension is indexed can be tricky at times:
        # * We need to take into account `Ellipsis`
        # * We need to take any unsqueezing into account
        if indices[0] is Ellipsis and len(indices) > 1:
            nonempty_indices = [i for i in indices[1:] if i is not None]
            if len(nonempty_indices) == self.dim():
                indices = indices[1:]

        if isinstance(indices[0], (int, bool)):
            index: Union[int, Tensor] = int(as_key_tensor([indices[0]]))
            indices = (index, ) + indices[1:]
        elif isinstance(indices[0], (Tensor, list, np.ndarray)):
            index = as_key_tensor(indices[0], device=self.device)
            indices = (index, ) + indices[1:]

        indices = indices[0] if len(indices) == 1 else indices

        return super().__getitem__(indices)
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
@implements(aten.alias.default)
def _alias(tensor: HashTensor) -> HashTensor:
    r"""Implements :obj:`aten.alias` for :class:`HashTensor`.

    Returns a new :class:`HashTensor` wrapper that shares its mapping,
    values and min/max keys with :obj:`tensor`.
    """
    aliased = tensor._shallow_copy()
    return aliased
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
@implements(aten.clone.default)
def _clone(
    tensor: HashTensor,
    *,
    memory_format: torch.memory_format = torch.preserve_format,
) -> HashTensor:
    r"""Implements :obj:`aten.clone` for :class:`HashTensor`.

    Only the stored values (if any) are copied, honoring
    :obj:`memory_format`. The key mapping and min/max keys are shared with
    the input.
    """
    if tensor._value is None:
        cloned_value = None
    else:
        cloned_value = aten.clone.default(tensor._value,
                                          memory_format=memory_format)

    # `_map`, `_min_key` and `_max_key` are treated as read-only, so sharing
    # them is safe and avoids needless copies:
    return tensor._from_data(
        tensor._map,
        cloned_value,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
@implements(aten.detach.default)
def _detach(tensor: HashTensor) -> HashTensor:
    r"""Implements :obj:`aten.detach` for :class:`HashTensor`.

    Detaches the stored values (if any) from the autograd graph. The key
    mapping and min/max keys are reused as-is.
    """
    detached_value = (None if tensor._value is None else
                      aten.detach.default(tensor._value))

    return tensor._from_data(
        tensor._map,
        detached_value,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
@implements(aten._to_copy.default)
def _to_copy(
    tensor: HashTensor,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    non_blocking: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> HashTensor:
    r"""Implements :obj:`aten._to_copy` (*e.g.*, :meth:`torch.Tensor.to`) for
    :class:`HashTensor`.

    The stored values (if any) receive every conversion argument. The min/max
    keys and a dense lookup table are only moved to :obj:`device`, never cast.
    Without stored values, :obj:`dtype` becomes the dtype of the returned
    positions.
    """
    value = tensor._value
    if value is not None:
        value = aten._to_copy.default(
            value,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=memory_format,
        )

    min_key = aten._to_copy.default(tensor._min_key, device=device)
    max_key = aten._to_copy.default(tensor._max_key, device=device)

    _map = tensor._map
    if isinstance(_map, Tensor):
        _map = aten._to_copy.default(_map, device=device)
    # Only rebuild `_map` in case `CUDAHashMap` exists and the device changes
    # to or from a CUDA device - otherwise we use CPU-based mapping anyway and
    # there is no need for a copy. Checking only the source device would keep
    # a `CPUHashMap` when moving from CPU to GPU, which `_get` would then
    # query with CUDA tensors.
    elif (torch_geometric.typing.WITH_CUDA_HASH_MAP
          and tensor.device != min_key.device
          and (tensor.is_cuda or min_key.is_cuda)):
        key = _map.keys()
        key = aten._to_copy.default(key, device=device)
        _map = get_hash_map(key)

    return tensor._from_data(
        _map,
        value,
        min_key,
        max_key,
        num_keys=tensor.size(0),
        dtype=dtype or tensor.dtype,
    )
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
@implements(aten._pin_memory.default)
def _pin_memory(tensor: HashTensor) -> HashTensor:
    r"""Handle :obj:`aten._pin_memory` for :class:`HashTensor`.

    Pins the min/max key tensors and the optional value tensor. :obj:`_map`
    is pinned only if it is a plain :class:`Tensor`; other map types are
    passed through unchanged.
    """
    pin = aten._pin_memory.default

    pinned_map = tensor._map
    if isinstance(pinned_map, Tensor):
        pinned_map = pin(pinned_map)

    pinned_value = tensor._value
    if pinned_value is not None:
        pinned_value = pin(pinned_value)

    return tensor._from_data(
        pinned_map,
        pinned_value,
        pin(tensor._min_key),
        pin(tensor._max_key),
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
@implements(aten.unsqueeze.default)
def _unsqueeze(tensor: HashTensor, dim: int) -> HashTensor:
    r"""Handle :obj:`aten.unsqueeze` for :class:`HashTensor`.

    A new dimension may be inserted anywhere except in front of the key
    dimension. The values are unsqueezed and the hash map is kept.

    Raises:
        IndexError: If :obj:`dim` would insert a new first dimension.
    """
    # Both `0` and `-(ndim + 1)` insert the new axis at the very front.
    leading_positions = (0, -(tensor.dim() + 1))
    if dim in leading_positions:
        raise IndexError(f"Cannot unsqueeze '{tensor.__class__.__name__}' in "
                         f"the first dimension. Please call `as_tensor()` "
                         f"beforehand")

    expanded = aten.unsqueeze.default(tensor.as_tensor(), dim)
    return tensor._from_data(
        tensor._map,
        expanded,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
@implements(aten.squeeze.default)
def _squeeze_default(tensor: HashTensor) -> HashTensor:
    r"""Handle :obj:`aten.squeeze.default` for :class:`HashTensor`.

    Squeezes every size-1 dimension except the key dimension (dimension 0).
    Without an explicit value tensor, a shallow copy is returned.
    """
    if tensor._value is None:
        return tensor._shallow_copy()

    squeezed = tensor.as_tensor()
    # Walk from the last dimension down to dimension 1 so that the remaining
    # indices stay valid. Dimension 0 is never touched.
    for axis in reversed(range(1, tensor.dim())):
        squeezed = squeezed.squeeze(axis)

    return tensor._from_data(
        tensor._map,
        squeezed,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
@implements(aten.squeeze.dim)
@implements(getattr(aten.squeeze, 'dims', aten.squeeze.dim))
def _squeeze_dim(
    tensor: HashTensor,
    dim: Union[int, List[int]],
) -> HashTensor:
    r"""Handle :obj:`aten.squeeze.dim` and :obj:`aten.squeeze.dims`.

    Every requested dimension is squeezed (a no-op if its size is not 1),
    except the key dimension (dimension 0), which is silently skipped.
    Dimensions may be given in any order and with any mix of positive and
    negative indices.

    Raises:
        IndexError: If any entry of :obj:`dim` is out of range.
    """
    if isinstance(dim, int):
        dim = [dim]

    num_dims = tensor.dim()
    for d in dim:
        if d < -num_dims or d >= num_dims:
            raise IndexError(f"Dimension out of range (expected to be in "
                             f"range of [{-tensor.dim()}, {tensor.dim()-1}], "
                             f"but got {d})")

    if tensor._value is None:
        return tensor._shallow_copy()

    # Normalize to non-negative indices, drop the key dimension and any
    # duplicates, and squeeze from the highest index downwards. Squeezing a
    # lower dimension first shifts all later ones, so the order must not
    # depend on how the caller listed (or signed) the dimensions.
    targets = sorted({d % num_dims for d in dim} - {0}, reverse=True)

    value = tensor.as_tensor()
    for d in targets:
        value = value.squeeze(d)

    return tensor._from_data(
        tensor._map,
        value,
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
@implements(aten.slice.Tensor)
def _slice(
    tensor: HashTensor,
    dim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
) -> HashTensor:
    r"""Handle :obj:`aten.slice.Tensor` for :class:`HashTensor`.

    Slicing the key dimension (dimension 0) works as follows:

    - If the slice covers every key, a shallow copy is returned.
    - Otherwise, a new :class:`HashTensor` is built from the sliced keys and
      values.

    Slicing any other dimension only slices the values and keeps the
    existing hash map.
    """
    if dim == 0 or dim == -tensor.dim():
        num_keys = tensor.size(0)
        # `[start:end:1]` selects everything when it begins at (or before)
        # the first key and ends at (or after) the last one. Note that
        # `end == num_keys` already is a full slice, e.g., `x[0:len(x)]`.
        covers_all = start is None or start == 0 or start <= -num_keys
        covers_all &= end is None or end >= num_keys
        covers_all &= step == 1
        if covers_all:
            return tensor._shallow_copy()

        key = aten.slice.Tensor(tensor._key, 0, start, end, step)
        value = aten.slice.Tensor(tensor.as_tensor(), 0, start, end, step)
        return tensor.__class__(key, value)

    return tensor._from_data(
        tensor._map,
        aten.slice.Tensor(tensor.as_tensor(), dim, start, end, step),
        tensor._min_key,
        tensor._max_key,
        num_keys=tensor.size(0),
        dtype=tensor.dtype,
    )
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
# Since PyTorch only allows PyTorch tensors as indices in `index_select`, we
# need to create a wrapper function and monkey patch `index_select` :(
_old_index_select = torch.index_select


def _new_index_select(
    input: Tensor,
    dim: Union[int, str],
    index: Tensor,
    out: Optional[Tensor] = None,
) -> Tensor:
    r"""Drop-in replacement for :func:`torch.index_select`.

    If :obj:`input` is a :class:`HashTensor` and :obj:`dim` refers to its
    first dimension, :obj:`index` is first converted via
    :meth:`as_key_tensor` onto :obj:`input`'s device. The call is then
    forwarded unchanged to the original :func:`torch.index_select`.

    Raises:
        IndexError: If :obj:`dim` is an integer outside of
            :obj:`[-input.dim(), input.dim() - 1]`.
    """

    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")

    # We convert any index used on the first (key) dimension into a key tensor
    # via `as_key_tensor`. This means that downstream handling (i.e. in
    # `aten.index_select.default`) needs to take this pre-conversion into
    # account.
    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)
            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):
        index = as_key_tensor(index, device=input.device)

    # NOTE: Both branches below issue identical calls. They are duplicated on
    # purpose so that `dim` is narrowed to `int` vs. `str` in each branch, so
    # do not merge them.
    if isinstance(dim, int):  # Type narrowing...
        if out is None:
            return _old_index_select(input, dim, index)
        else:
            return _old_index_select(input, dim, index, out=out)
    else:
        if out is None:
            return _old_index_select(input, dim, index)
        else:
            return _old_index_select(input, dim, index, out=out)


torch.index_select = _new_index_select  # type: ignore
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
@implements(aten.index_select.default)
def _index_select(
    tensor: HashTensor,
    dim: int,
    index: Tensor,
) -> Union[HashTensor, Tensor]:
    r"""Handle :obj:`aten.index_select` for :class:`HashTensor`.

    On the key dimension, :obj:`index` is resolved through the hash map via
    :meth:`_get`. On any other dimension, only the values are selected and
    the hash map is kept.
    """
    if dim not in (0, -tensor.dim()):
        picked = aten.index_select.default(tensor.as_tensor(), dim, index)
        return tensor._from_data(
            tensor._map,
            picked,
            tensor._min_key,
            tensor._max_key,
            num_keys=tensor.size(0),
            dtype=tensor.dtype,
        )

    return tensor._get(index)
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
# Since PyTorch only accepts plain integer indices in `select`, keys of a
# `HashTensor` need converting first. Hence we create a wrapper function and
# monkey patch `select` :(
_old_select = torch.select


def _new_select(
    input: Tensor,
    dim: Union[int, str],
    index: int,
) -> Tensor:
    r"""Drop-in replacement for :func:`torch.select`.

    If :obj:`input` is a :class:`HashTensor` and :obj:`dim` refers to its
    first dimension, :obj:`index` is passed through :meth:`as_key_tensor`
    and converted back to a Python :obj:`int`. The call is then forwarded
    to the original :func:`torch.select`.

    Raises:
        IndexError: If :obj:`dim` is an integer outside of
            :obj:`[-input.dim(), input.dim() - 1]`.
    """

    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")

    # We convert any index in the first dimension into an integer. This means
    # that downstream handling (i.e. in `aten.select.int`) needs to take this
    # pre-conversion into account.
    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)
            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):
        index = int(as_key_tensor([index]))

    # NOTE: Both branches issue the identical call. They are duplicated on
    # purpose so that `dim` is narrowed to `int` vs. `str`, so do not merge
    # them.
    if isinstance(dim, int):  # Type narrowing...
        return _old_select(input, dim, index)
    else:
        return _old_select(input, dim, index)


torch.select = _new_select  # type: ignore
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
@implements(aten.select.int)
def _select(
    tensor: HashTensor,
    dim: int,
    index: int,
) -> Union[HashTensor, Tensor]:
    r"""Handle :obj:`aten.select.int` for :class:`HashTensor`.

    On the key dimension, :obj:`index` is wrapped into a one-element tensor
    (with the dtype of the key bounds), looked up via :meth:`_get`, and the
    leading dimension of the result is squeezed away. On any other
    dimension, only the values are selected and the hash map is kept.
    """
    if dim not in (0, -tensor.dim()):
        return tensor._from_data(
            tensor._map,
            aten.select.int(tensor.as_tensor(), dim, index),
            tensor._min_key,
            tensor._max_key,
            num_keys=tensor.size(0),
            dtype=tensor.dtype,
        )

    query = torch.tensor(
        [index],
        dtype=tensor._min_key.dtype,
        device=tensor.device,
    )
    found = tensor._get(query)
    return found.squeeze(0)
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
@implements(aten.index.Tensor)
def _index(
    tensor: HashTensor,
    indices: List[Optional[Tensor]],
) -> Union[HashTensor, Tensor]:
    r"""Handle :obj:`aten.index.Tensor` (advanced indexing).

    If the first entry of :obj:`indices` is given, it is resolved through
    the hash map via :meth:`_get`. Any remaining indices are then applied to
    the result with the key position left as :obj:`None`. If the first entry
    is :obj:`None`, only the values are indexed and the hash map is kept.
    """
    assert len(indices) > 0

    key_index = indices[0]
    if key_index is None:
        return tensor._from_data(
            tensor._map,
            aten.index.Tensor(tensor.as_tensor(), indices),
            tensor._min_key,
            tensor._max_key,
            num_keys=tensor.size(0),
            dtype=tensor.dtype,
        )

    selected = tensor._get(key_index)
    if len(indices) == 1:
        return selected
    return aten.index.Tensor(selected, [None] + indices[1:])
|