pyg-nightly 2.7.0.dev20241009-py3-none-any.whl → 2.8.0.dev20251207-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pyg-nightly might be problematic.

Files changed (228)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_trim_to_layer.py +2 -2
  215. torch_geometric/utils/convert.py +17 -10
  216. torch_geometric/utils/cross_entropy.py +34 -13
  217. torch_geometric/utils/embedding.py +91 -2
  218. torch_geometric/utils/geodesic.py +4 -3
  219. torch_geometric/utils/influence.py +279 -0
  220. torch_geometric/utils/map.py +13 -9
  221. torch_geometric/utils/nested.py +1 -1
  222. torch_geometric/utils/smiles.py +3 -3
  223. torch_geometric/utils/sparse.py +7 -14
  224. torch_geometric/visualization/__init__.py +2 -1
  225. torch_geometric/visualization/graph.py +250 -5
  226. torch_geometric/warnings.py +11 -2
  227. torch_geometric/nn/nlp/__init__.py +0 -7
  228. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/hash_tensor.py (new file)
@@ -0,0 +1,798 @@
+import functools
+import warnings
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+
+import numpy as np
+import torch
+import torch.utils._pytree as pytree
+import xxhash
+from torch import Tensor
+
+import torch_geometric.typing
+from torch_geometric.typing import CPUHashMap, CUDAHashMap
+
+aten = torch.ops.aten
+
+HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
+
+
+def implements(torch_function: Callable) -> Callable:
+    r"""Registers a :pytorch:`PyTorch` function override."""
+    @functools.wraps(torch_function)
+    def decorator(my_function: Callable) -> Callable:
+        HANDLED_FUNCTIONS[torch_function] = my_function
+        return my_function
+
+    return decorator
+
+
+def as_key_tensor(
+    key: Any,
+    *,
+    device: Optional[torch.device] = None,
+) -> Tensor:
+    try:
+        key = torch.as_tensor(key, device=device)
+    except Exception:
+        device = device or torch.get_default_device()
+        key = torch.tensor(
+            [xxhash.xxh64(x).intdigest() & 0x7FFFFFFFFFFFFFFF for x in key],
+            dtype=torch.int64, device=device)
+
+    if key.element_size() == 1:
+        key = key.view(torch.uint8)
+    elif key.element_size() == 2:
+        key = key.view(torch.int16)
+    elif key.element_size() == 4:
+        key = key.view(torch.int32)
+    elif key.element_size() == 8:
+        key = key.view(torch.int64)
+    else:
+        raise ValueError(f"Received invalid dtype '{key.dtype}' with "
+                         f"{key.element_size()} bytes")
+
+    return key
+
+
+def get_hash_map(key: Tensor) -> Union[CPUHashMap, CUDAHashMap]:
+    if torch_geometric.typing.WITH_CUDA_HASH_MAP and key.is_cuda:
+        return CUDAHashMap(key, 0.5)
+
+    if key.is_cuda:
+        warnings.warn(
+            "Fallback to CPU-based mapping algorithm which may "
+            "cause slowdowns and device synchronization. Please "
+            "install 'pyg-lib' for an accelerated 'HashTensor' "
+            "implementation.", stacklevel=2)
+
+    if torch_geometric.typing.WITH_CPU_HASH_MAP:
+        return CPUHashMap(key.cpu(), -1)
+
+    import pandas as pd
+
+    return pd.CategoricalDtype(
+        categories=key.cpu().numpy(),
+        ordered=True,
+    )
+
+
+class HashTensor(Tensor):
+    r"""A :pytorch:`null` :class:`torch.Tensor` that can be referenced by
+    arbitrary keys rather than indices in the first dimension.
+
+    :class:`HashTensor` sub-classes a general :pytorch:`null`
+    :class:`torch.Tensor`, and extends it by CPU- and GPU-accelerated mapping
+    routines. This allows for fast and efficient access to non-contiguous
+    indices/keys while the underlying data is stored in a compact format.
+
+    This representation is ideal for scenarios where one needs a fast mapping
+    routine without relying on CPU-based external packages, and can be used,
+    *e.g.*, to perform mapping of global indices to local indices during
+    subgraph creation, or in data-processing pipelines to map non-contiguous
+    input data into a contiguous space, such as
+
+    * mapping of hashed node IDs to range :obj:`[0, num_nodes - 1]`
+    * mapping of raw input data, *e.g.*, categorical data to range
+      :obj:`[0, num_categories - 1]`
+
+    Specifically, :class:`HashTensor` supports *any* keys of *any* type,
+    *e.g.*, strings, timestamps, etc.
+
+    .. code-block:: python
+
+        from torch_geometric import HashTensor
+
+        key = torch.tensor([1000, 100, 10000])
+        value = torch.randn(3, 4)
+
+        tensor = HashTensor(key, value)
+        assert tensor.size() == (3, 4)
+
+        # Filtering:
+        query = torch.tensor([10000, 1000])
+        out = tensor[query]
+        assert out.equal(value[[2, 0]])
+
+        # Accessing non-existing keys:
+        out = tensor[[10000, 0]]
+        out.isnan()
+        >>> tensor([[False, False, False, False],
+        ...         [True, True, True, True]])
+
+        # If `value` is not given, indexing returns the position of `query` in
+        # `key`, and `-1` otherwise:
+        key = ['Animation', 'Comedy', 'Fantasy']
+        tensor = HashTensor(key)
+
+        out = tensor[['Comedy', 'Romance']]
+        >>> tensor([1, -1])
+
+    Args:
+        key: The keys in the first dimension.
+        value: The values to hold.
+        dtype: The desired data type of the values of the returned tensor.
+        device: The device of the returned tensor.
+    """
+    _map: Union[Tensor, CPUHashMap, CUDAHashMap]
+    _value: Optional[Tensor]
+    _min_key: Tensor
+    _max_key: Tensor
+
+    @staticmethod
+    def __new__(
+        cls: Type,
+        key: Any,
+        value: Optional[Any] = None,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ) -> 'HashTensor':
+
+        if value is not None:
+            value = torch.as_tensor(value, dtype=dtype, device=device)
+            device = value.device
+
+        key = as_key_tensor(key, device=device)
+
+        if key.dim() != 1:
+            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
+                             f"one-dimensional (got {key.dim()} dimensions)")
+
+        if not key.is_contiguous():
+            raise ValueError(f"'key' data in '{cls.__name__}' needs to be "
+                             f"contiguous")
+
+        if value is not None:
+            if key.device != value.device:
+                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
+                                 f"are expected to be on the same device (got "
+                                 f"'{key.device}' and '{value.device}')")
+
+            if key.numel() != value.size(0):
+                raise ValueError(f"'key' and 'value' data in '{cls.__name__}' "
+                                 f"are expected to have the same size in the "
+                                 f"first dimension (got {key.size(0)} and "
+                                 f"{value.size(0)})")
+
+        min_key = key.min() if key.numel() > 0 else key.new_zeros(())
+        max_key = key.max() if key.numel() > 0 else key.new_zeros(())
+
+        _range = max_key - min_key
+        # TODO Expose fixed threshold as argument.
+        if (key.dtype in {torch.uint8, torch.int16} or _range <= 1_000_000
+                or _range <= 2 * key.numel()):
+            _map = torch.full(
+                size=(_range + 3, ),
+                fill_value=-1,
+                dtype=torch.int64,
+                device=key.device,
+            )
+            _map[key.long() - (min_key.long() - 1)] = torch.arange(
+                key.numel(),
+                dtype=_map.dtype,
+                device=_map.device,
+            )
+        else:
+            _map = get_hash_map(key)
+
+        return cls._from_data(
+            _map,
+            value,
+            min_key,
+            max_key,
+            num_keys=key.numel(),
+            dtype=dtype,
+        )
+
+    # Private Methods #########################################################
+
+    @classmethod
+    def _from_data(
+        cls,
+        _map: Union[Tensor, CPUHashMap, CUDAHashMap],
+        value: Optional[Tensor],
+        min_key: Tensor,
+        max_key: Tensor,
+        *,
+        num_keys: int,
+        dtype: Optional[torch.dtype],
+    ) -> 'HashTensor':
+
+        if value is not None:
+            dtype = value.dtype
+            size = value.size()
+            stride = value.stride()
+            layout = value.layout
+            requires_grad = value.requires_grad
+        else:
+            dtype = dtype or torch.int64
+            size = torch.Size([num_keys])
+            stride = (1, )
+            layout = torch.strided
+            requires_grad = False
+
+        out = Tensor._make_wrapper_subclass(
+            cls,
+            size=size,
+            strides=stride,
+            dtype=dtype,
+            device=min_key.device,
+            layout=layout,
+            requires_grad=requires_grad,
+        )
+        assert isinstance(out, HashTensor)
+
+        out._map = _map
+        out._value = value
+        out._min_key = min_key
+        out._max_key = max_key
+
+        return out
+
+    @property
+    def _key(self) -> Tensor:
+        if isinstance(self._map, Tensor):
+            mask = self._map >= 0
+            key = mask.nonzero().view(-1) - 1
+            key = key[self._map[mask]]
+        elif (torch_geometric.typing.WITH_CUDA_HASH_MAP
+              or torch_geometric.typing.WITH_CPU_HASH_MAP):
+            key = self._map.keys().to(self.device)
+        else:
+            key = torch.from_numpy(self._map.categories.to_numpy())
+
+        return key.to(self.device)
+
+    def _shallow_copy(self) -> 'HashTensor':
+        return self._from_data(
+            self._map,
+            self._value,
+            self._min_key,
+            self._max_key,
+            num_keys=self.size(0),
+            dtype=self.dtype,
+        )
+
+    def _get(self, query: Tensor) -> Tensor:
+        if isinstance(self._map, Tensor):
+            index = query.long() - (self._min_key.long() - 1)
+            index = self._map[index.clamp_(min=0, max=self._map.numel() - 1)]
+        elif torch_geometric.typing.WITH_CUDA_HASH_MAP and query.is_cuda:
+            index = self._map.get(query)
+        elif torch_geometric.typing.WITH_CPU_HASH_MAP:
+            index = self._map.get(query.cpu())
+        else:
+            import pandas as pd
+
+            ser = pd.Series(query.cpu().numpy(), dtype=self._map)
+            index = torch.from_numpy(ser.cat.codes.to_numpy().copy()).long()
+
+        index = index.to(self.device)
+
+        if self._value is None:
+            return index.to(self.dtype)
+
+        out = self._value[index]
+
+        mask = index != -1
+        mask = mask.view([-1] + [1] * (out.dim() - 1))
+        fill_value = float('NaN') if out.is_floating_point() else -1
+        if torch_geometric.typing.WITH_PT20:
+            other: Union[int, float, Tensor] = fill_value
+        else:
+            other = torch.full_like(out, fill_value)
+
+        return out.where(mask, other)
+
+    # Methods #################################################################
+
+    def as_tensor(self) -> Tensor:
+        r"""Zero-copies the :class:`HashTensor` representation back to a
+        :class:`torch.Tensor` representation.
+        """
+        if self._value is not None:
+            return self._value
+        return torch.arange(self.size(0), dtype=self.dtype, device=self.device)
+
+    # PyTorch/Python builtins #################################################
+
+    # Prevent auto-wrapping outputs back into the proper subclass type:
+    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore
+
+    @classmethod
+    def __torch_dispatch__(  # type: ignore
+        cls: Type,
+        func: Callable[..., Any],
+        types: Iterable[Type[Any]],
+        args: Iterable[Tuple[Any, ...]] = (),
+        kwargs: Optional[Dict[Any, Any]] = None,
+    ) -> Any:
+        # Hold a number of `HANDLED_FUNCTIONS` that implement specific
+        # functions for valid `HashTensor` routines.
+        if func in HANDLED_FUNCTIONS:
+            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))
+
+        # For all other PyTorch functions, we treat them as vanilla tensors.
+        args = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(), args)
+        if kwargs is not None:
+            kwargs = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(),
+                                          kwargs)
+        return func(*args, **(kwargs or {}))
+
+    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
+        attrs = ['_map', '_min_key', '_max_key']
+        if self._value is not None:
+            attrs.append('_value')
+
+        ctx = (self.size(0), self.dtype)
+
+        return attrs, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(
+        inner_tensors: Dict[str, Any],
+        ctx: Tuple[Any, ...],
+        outer_size: Tuple[int, ...],
+        outer_stride: Tuple[int, ...],
+    ) -> 'HashTensor':
+        return HashTensor._from_data(
+            inner_tensors['_map'],
+            inner_tensors.get('_value', None),
+            inner_tensors['_min_key'],
+            inner_tensors['_max_key'],
+            num_keys=ctx[0],
+            dtype=ctx[1],
+        )
+
+    def __repr__(self) -> str:  # type: ignore
+        indent = len(f'{self.__class__.__name__}(')
+        tensor_str = torch._tensor_str._tensor_str(self.as_tensor(), indent)
+        return torch._tensor_str._str_intern(self, tensor_contents=tensor_str)
+
+    def tolist(self) -> List[Any]:
+        """"""  # noqa: D419
+        return self.as_tensor().tolist()
+
+    def numpy(self, *, force: bool = False) -> np.ndarray:
+        """"""  # noqa: D419
+        return self.as_tensor().numpy(force=force)
+
+    def index_select(  # type: ignore
+        self,
+        dim: int,
+        index: Any,
+    ) -> Union['HashTensor', Tensor]:
+        """"""  # noqa: D419
+        return torch.index_select(self, dim, index)
+
+    def select(  # type: ignore
+        self,
+        dim: int,
+        index: Any,
+    ) -> Union['HashTensor', Tensor]:
+        """"""  # noqa: D419
+        return torch.select(self, dim, index)
+
+    def share_memory_(self) -> 'HashTensor':
+        """"""  # noqa: D419
+        if isinstance(self._map, Tensor):
+            self._map.share_memory_()
+        if self._value is not None:
+            self._value.share_memory_()
+        self._min_key.share_memory_()
+        self._max_key.share_memory_()
+        return self
+
+    def is_shared(self) -> bool:
+        """"""  # noqa: D419
+        return self._min_key.is_shared()
+
+    def detach_(self) -> 'HashTensor':
+        """"""  # noqa: D419
+        if self._value is not None:
+            self._value.detach_()
+        return super().detach_()  # type: ignore
+
+    def __getitem__(self, indices: Any) -> Union['HashTensor', Tensor]:
+        if not isinstance(indices, tuple):
+            indices = (indices, )
+        assert len(indices) > 0
+
+        # We convert any index tensor in the first dimension into a tensor.
+        # This means that downstream handling (i.e. in `aten.index.Tensor`)
+        # needs to take this pre-conversion into account. However, detecting
+        # whether the first dimension is indexed can be tricky at times:
+        # * We need to take into account `Ellipsis`
+        # * We need to take any unsqueezing into account
+        if indices[0] is Ellipsis and len(indices) > 1:
+            nonempty_indices = [i for i in indices[1:] if i is not None]
+            if len(nonempty_indices) == self.dim():
+                indices = indices[1:]
+
+        if isinstance(indices[0], (int, bool)):
+            index: Union[int, Tensor] = int(as_key_tensor([indices[0]]))
+            indices = (index, ) + indices[1:]
+        elif isinstance(indices[0], (Tensor, list, np.ndarray)):
+            index = as_key_tensor(indices[0], device=self.device)
+            indices = (index, ) + indices[1:]
+
+        indices = indices[0] if len(indices) == 1 else indices
+
+        return super().__getitem__(indices)
+
+
+@implements(aten.alias.default)
+def _alias(tensor: HashTensor) -> HashTensor:
+    return tensor._shallow_copy()
+
+
+@implements(aten.clone.default)
+def _clone(
+    tensor: HashTensor,
+    *,
+    memory_format: torch.memory_format = torch.preserve_format,
+) -> HashTensor:
+
+    value = tensor._value
+    if value is not None:
+        value = aten.clone.default(value, memory_format=memory_format)
+
+    return tensor._from_data(
+        tensor._map,  # NOTE No reason to do clone since it is read-only.
+        value,
+        tensor._min_key,  # NOTE No reason to do clone since it is read-only.
+        tensor._max_key,  # NOTE No reason to do clone since it is read-only.
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+@implements(aten.detach.default)
+def _detach(tensor: HashTensor) -> HashTensor:
+    value = tensor._value
+    if value is not None:
+        value = aten.detach.default(value)
+
+    return tensor._from_data(
+        tensor._map,
+        value,
+        tensor._min_key,
+        tensor._max_key,
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+@implements(aten._to_copy.default)
+def _to_copy(
+    tensor: HashTensor,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+    non_blocking: bool = False,
+    memory_format: Optional[torch.memory_format] = None,
+) -> HashTensor:
+
+    value = tensor._value
+    if value is not None:
+        value = aten._to_copy.default(
+            value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            non_blocking=non_blocking,
+            memory_format=memory_format,
+        )
+
+    min_key = aten._to_copy.default(tensor._min_key, device=device)
+    max_key = aten._to_copy.default(tensor._max_key, device=device)
+
+    _map = tensor._map
+    if isinstance(_map, Tensor):
+        _map = aten._to_copy.default(_map, device=device)
+    # Only convert `_map` in case `CUDAHashMap` exists - otherwise we use
+    # CPU-based mapping anyway and there is no need for a copy.
+    elif (torch_geometric.typing.WITH_CUDA_HASH_MAP and tensor.is_cuda
+          and tensor.device != min_key.device):
+        key = _map.keys()
+        key = aten._to_copy.default(key, device=device)
+        _map = get_hash_map(key)
+
+    return tensor._from_data(
+        _map,
+        value,
+        min_key,
+        max_key,
+        num_keys=tensor.size(0),
+        dtype=dtype or tensor.dtype,
+    )
+
+
+@implements(aten._pin_memory.default)
+def _pin_memory(tensor: HashTensor) -> HashTensor:
+    _map = tensor._map
+    if isinstance(_map, Tensor):
+        _map = aten._pin_memory.default(_map)
+
+    value = tensor._value
+    if value is not None:
+        value = aten._pin_memory.default(value)
+
+    return tensor._from_data(
+        _map,
+        value,
+        aten._pin_memory.default(tensor._min_key),
+        aten._pin_memory.default(tensor._max_key),
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+@implements(aten.unsqueeze.default)
+def _unsqueeze(tensor: HashTensor, dim: int) -> HashTensor:
+    if dim == 0 or dim == -(tensor.dim() + 1):
+        raise IndexError(f"Cannot unsqueeze '{tensor.__class__.__name__}' in "
+                         f"the first dimension. Please call `as_tensor()` "
+                         f"beforehand")
+
+    return tensor._from_data(
+        tensor._map,
+        aten.unsqueeze.default(tensor.as_tensor(), dim),
+        tensor._min_key,
+        tensor._max_key,
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+@implements(aten.squeeze.default)
+def _squeeze_default(tensor: HashTensor) -> HashTensor:
+    if tensor._value is None:
+        return tensor._shallow_copy()
+
+    value = tensor.as_tensor()
+    for d in range(tensor.dim() - 1, 0, -1):
+        value = value.squeeze(d)
+
+    return tensor._from_data(
+        tensor._map,
+        value,
+        tensor._min_key,
+        tensor._max_key,
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+@implements(aten.squeeze.dim)
+@implements(getattr(aten.squeeze, 'dims', aten.squeeze.dim))
+def _squeeze_dim(
+    tensor: HashTensor,
+    dim: Union[int, List[int]],
+) -> HashTensor:
+    if isinstance(dim, int):
+        dim = [dim]
+
+    for d in dim:
+        if d < -tensor.dim() or d >= tensor.dim():
+            raise IndexError(f"Dimension out of range (expected to be in "
+                             f"range of [{-tensor.dim()}, {tensor.dim()-1}], "
+                             f"but got {d})")
+
+    if tensor._value is None:
+        return tensor._shallow_copy()
+
+    value = tensor.as_tensor()
+    for d in dim[::-1]:
+        if d != 0 and d != -tensor.dim():
+            value = value.squeeze(d)
+
+    return tensor._from_data(
+        tensor._map,
+        value,
+        tensor._min_key,
+        tensor._max_key,
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+@implements(aten.slice.Tensor)
+def _slice(
+    tensor: HashTensor,
+    dim: int,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    step: int = 1,
+) -> HashTensor:
+
+    if dim == 0 or dim == -tensor.dim():
+        copy = start is None or start == 0 or start <= -tensor.size(0)
+        copy &= end is None or end > tensor.size(0)
+        copy &= step == 1
+        if copy:
+            return tensor._shallow_copy()
+
+        key = aten.slice.Tensor(tensor._key, 0, start, end, step)
+        value = aten.slice.Tensor(tensor.as_tensor(), 0, start, end, step)
+        return tensor.__class__(key, value)
+
+    return tensor._from_data(
+        tensor._map,
+        aten.slice.Tensor(tensor.as_tensor(), dim, start, end, step),
+        tensor._min_key,
+        tensor._max_key,
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+# Since PyTorch does only allow PyTorch tensors as indices in `index_select`,
+# we need to create a wrapper function and monkey patch `index_select` :(
+_old_index_select = torch.index_select
+
+
+def _new_index_select(
+    input: Tensor,
+    dim: Union[int, str],
+    index: Tensor,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+
+    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):
+        raise IndexError(f"Dimension out of range (expected to be in range of "
+                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")
+
+    # We convert any index tensor in the first dimension into a tensor. This
+    # means that downstream handling (i.e. in `aten.index_select.default`)
+    # needs to take this pre-conversion into account.
+    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)
+            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):
+        index = as_key_tensor(index, device=input.device)
+
+    if isinstance(dim, int):  # Type narrowing...
+        if out is None:
+            return _old_index_select(input, dim, index)
+        else:
+            return _old_index_select(input, dim, index, out=out)
+    else:
+        if out is None:
+            return _old_index_select(input, dim, index)
+        else:
+            return _old_index_select(input, dim, index, out=out)
+
+
+torch.index_select = _new_index_select  # type: ignore
+
+
+@implements(aten.index_select.default)
+def _index_select(
+    tensor: HashTensor,
+    dim: int,
+    index: Tensor,
+) -> Union[HashTensor, Tensor]:
+
+    if dim == 0 or dim == -tensor.dim():
+        return tensor._get(index)
+
+    return tensor._from_data(
+        tensor._map,
+        aten.index_select.default(tensor.as_tensor(), dim, index),
+        tensor._min_key,
+        tensor._max_key,
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+# Since PyTorch does only allow PyTorch tensors as indices in `select`, we need
+# to create a wrapper function and monkey patch `select` :(
+_old_select = torch.select
+
+
+def _new_select(
+    input: Tensor,
+    dim: Union[int, str],
+    index: int,
+) -> Tensor:
+
+    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):
+        raise IndexError(f"Dimension out of range (expected to be in range of "
+                         f"[{-input.dim()}, {input.dim()-1}], but got {dim})")
+
+    # We convert any index in the first dimension into an integer. This means
+    # that downstream handling (i.e. in `aten.select.int`) needs to take this
+    # pre-conversion into account.
+    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)
+            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):
+        index = int(as_key_tensor([index]))
+
+    if isinstance(dim, int):  # Type narrowing...
+        return _old_select(input, dim, index)
+    else:
+        return _old_select(input, dim, index)
+
+
+torch.select = _new_select  # type: ignore
+
+
+@implements(aten.select.int)
+def _select(
+    tensor: HashTensor,
+    dim: int,
+    index: int,
+) -> Union[HashTensor, Tensor]:

+    if dim == 0 or dim == -tensor.dim():
+        key = torch.tensor(
+            [index],
+            dtype=tensor._min_key.dtype,
+            device=tensor.device,
+        )
+        return tensor._get(key).squeeze(0)
+
+    return tensor._from_data(
+        tensor._map,
+        aten.select.int(tensor.as_tensor(), dim, index),
+        tensor._min_key,
+        tensor._max_key,
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
+
+
+@implements(aten.index.Tensor)
+def _index(
+    tensor: HashTensor,
+    indices: List[Optional[Tensor]],
+) -> Union[HashTensor, Tensor]:
+
+    assert len(indices) > 0
+
+    if indices[0] is not None:
+        out = tensor._get(indices[0])
+        if len(indices) > 1:
+            out = aten.index.Tensor(out, [None] + indices[1:])
+        return out
+
+    return tensor._from_data(
+        tensor._map,
+        aten.index.Tensor(tensor.as_tensor(), indices),
+        tensor._min_key,
+        tensor._max_key,
+        num_keys=tensor.size(0),
+        dtype=tensor.dtype,
+    )
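
For reference, a minimal usage sketch of the new HashTensor class added in this release, based on the class docstring in the diff above (the keys, shapes, and variable names below are illustrative only, not taken from the package):

import torch

from torch_geometric import HashTensor

# Non-contiguous global node IDs serve as keys:
key = torch.tensor([1000, 100, 10000])

# Without values, indexing returns each query's position in `key`
# (or -1 for unknown keys), i.e. a global-to-local index mapping:
mapping = HashTensor(key)
local = mapping[torch.tensor([10000, 100, 7])]  # -> tensor([2, 1, -1])

# With values, indexing gathers the stored rows for each key; missing keys
# produce NaN rows for floating-point values:
value = torch.randn(3, 4)
store = HashTensor(key, value)
out = store[torch.tensor([10000, 1000])]  # rows 2 and 0 of `value`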