pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229):
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -1,9 +1,9 @@
1
1
  import copy
2
2
  import math
3
3
  import warnings
4
- from abc import ABC
4
+ from abc import ABC, abstractmethod
5
5
  from collections import defaultdict
6
- from dataclasses import dataclass
6
+ from dataclasses import dataclass, field
7
7
  from enum import Enum
8
8
  from typing import Any, Dict, List, Literal, Optional, Union
9
9
 
@@ -11,7 +11,12 @@ import torch
11
11
  from torch import Tensor
12
12
 
13
13
  from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
14
- from torch_geometric.sampler.utils import to_bidirectional
14
+ from torch_geometric.sampler.utils import (
15
+ global_to_local_node_idx,
16
+ local_to_global_node_idx,
17
+ to_bidirectional,
18
+ unique_unsorted,
19
+ )
15
20
  from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor
16
21
  from torch_geometric.utils.mixin import CastMixin
17
22
 
@@ -206,6 +211,39 @@ class SamplerOutput(CastMixin):
206
211
  # TODO(manan): refine this further; it does not currently define a proper
207
212
  # API for the expected output of a sampler.
208
213
  metadata: Optional[Any] = None
214
+ _seed_node: OptTensor = field(repr=False, default=None)
215
+
216
+ @property
217
+ def global_row(self) -> Tensor:
218
+ return local_to_global_node_idx(self.node, self.row)
219
+
220
+ @property
221
+ def global_col(self) -> Tensor:
222
+ return local_to_global_node_idx(self.node, self.col)
223
+
224
+ @property
225
+ def seed_node(self) -> Tensor:
226
+ # can be set manually if the seed nodes are not contained in the
227
+ # sampled nodes
228
+ if self._seed_node is None:
229
+ self._seed_node = local_to_global_node_idx(
230
+ self.node, self.batch) if self.batch is not None else None
231
+ return self._seed_node
232
+
233
+ @seed_node.setter
234
+ def seed_node(self, value: Tensor):
235
+ assert len(value) == len(self.node)
236
+ self._seed_node = value
237
+
238
+ @property
239
+ def global_orig_row(self) -> Tensor:
240
+ return local_to_global_node_idx(
241
+ self.node, self.orig_row) if self.orig_row is not None else None
242
+
243
+ @property
244
+ def global_orig_col(self) -> Tensor:
245
+ return local_to_global_node_idx(
246
+ self.node, self.orig_col) if self.orig_col is not None else None
209
247
 
210
248
  def to_bidirectional(
211
249
  self,
@@ -237,6 +275,230 @@ class SamplerOutput(CastMixin):
237
275
 
238
276
  return out
239
277
 
278
+ @classmethod
279
+ def collate(cls, outputs: List['SamplerOutput'],
280
+ replace: bool = True) -> 'SamplerOutput':
281
+ r"""Collate a list of :class:`~torch_geometric.sampler.SamplerOutput`
282
+ objects into a single :class:`~torch_geometric.sampler.SamplerOutput`
283
+ object. Requires that they all have the same fields.
284
+ """
285
+ if len(outputs) == 0:
286
+ raise ValueError("Cannot collate an empty list of SamplerOutputs")
287
+ out = outputs[0]
288
+ has_edge = out.edge is not None
289
+ has_orig_row = out.orig_row is not None
290
+ has_orig_col = out.orig_col is not None
291
+ has_batch = out.batch is not None
292
+ has_num_sampled_nodes = out.num_sampled_nodes is not None
293
+ has_num_sampled_edges = out.num_sampled_edges is not None
294
+
295
+ try:
296
+ for i, sample_output in enumerate(outputs): # noqa
297
+ assert not has_edge == (sample_output.edge is None)
298
+ assert not has_orig_row == (sample_output.orig_row is None)
299
+ assert not has_orig_col == (sample_output.orig_col is None)
300
+ assert not has_batch == (sample_output.batch is None)
301
+ assert not has_num_sampled_nodes == (
302
+ sample_output.num_sampled_nodes is None)
303
+ assert not has_num_sampled_edges == (
304
+ sample_output.num_sampled_edges is None)
305
+ except AssertionError:
306
+ error_str = f"Output {i+1} has a different field than the first output" # noqa
307
+ raise ValueError(error_str) # noqa
308
+
309
+ for other in outputs[1:]:
310
+ out = out.merge_with(other, replace=replace)
311
+ return out
312
+
313
+ def merge_with(self, other: 'SamplerOutput',
314
+ replace: bool = True) -> 'SamplerOutput':
315
+ """Merges two SamplerOutputs.
316
+ If replace is True, self's nodes and edges take precedence.
317
+ """
318
+ if not replace:
319
+ return SamplerOutput(
320
+ node=torch.cat([self.node, other.node], dim=0),
321
+ row=torch.cat([self.row, len(self.node) + other.row], dim=0),
322
+ col=torch.cat([self.col, len(self.node) + other.col], dim=0),
323
+ edge=torch.cat([self.edge, other.edge], dim=0)
324
+ if self.edge is not None and other.edge is not None else None,
325
+ batch=torch.cat(
326
+ [self.batch, len(self.node) + other.batch], dim=0) if
327
+ self.batch is not None and other.batch is not None else None,
328
+ num_sampled_nodes=self.num_sampled_nodes +
329
+ other.num_sampled_nodes if self.num_sampled_nodes is not None
330
+ and other.num_sampled_nodes is not None else None,
331
+ num_sampled_edges=self.num_sampled_edges +
332
+ other.num_sampled_edges if self.num_sampled_edges is not None
333
+ and other.num_sampled_edges is not None else None,
334
+ orig_row=torch.cat(
335
+ [self.orig_row,
336
+ len(self.node) +
337
+ other.orig_row], dim=0) if self.orig_row is not None
338
+ and other.orig_row is not None else None,
339
+ orig_col=torch.cat(
340
+ [self.orig_col,
341
+ len(self.node) +
342
+ other.orig_col], dim=0) if self.orig_col is not None
343
+ and other.orig_col is not None else None,
344
+ metadata=[self.metadata, other.metadata],
345
+ )
346
+ else:
347
+
348
+ # NODES
349
+ old_nodes, new_nodes = self.node, other.node
350
+ old_node_uid, new_node_uid = [old_nodes], [new_nodes]
351
+
352
+ # batch tracks disjoint subgraph samplings
353
+ if self.batch is not None and other.batch is not None:
354
+ # Transform the batch indices to be global node ids
355
+ old_batch_nodes = self.seed_node
356
+ new_batch_nodes = other.seed_node
357
+ old_node_uid.append(old_batch_nodes)
358
+ new_node_uid.append(new_batch_nodes)
359
+
360
+ # NOTE: if any new node fields are added,
361
+ # they need to be merged here
362
+
363
+ old_node_uid = torch.stack(old_node_uid, dim=1)
364
+ new_node_uid = torch.stack(new_node_uid, dim=1)
365
+
366
+ merged_node_uid = unique_unsorted(
367
+ torch.cat([old_node_uid, new_node_uid], dim=0))
368
+ num_old_nodes = old_node_uid.shape[0]
369
+
370
+ # Recompute num sampled nodes for second output,
371
+ # subtracting out nodes already seen in first output
372
+ merged_node_num_sampled_nodes = None
373
+ if (self.num_sampled_nodes is not None
374
+ and other.num_sampled_nodes is not None):
375
+ merged_node_num_sampled_nodes = copy.copy(
376
+ self.num_sampled_nodes)
377
+ curr_index = 0
378
+ # NOTE: There's an assumption here that no two nodes will be
379
+ # sampled twice in the same SampleOutput object
380
+ for minibatch in other.num_sampled_nodes:
381
+ size_of_intersect = torch.cat([
382
+ old_node_uid,
383
+ new_node_uid[curr_index:curr_index + minibatch]
384
+ ]).unique(dim=0, sorted=False).shape[0] - num_old_nodes
385
+ merged_node_num_sampled_nodes.append(size_of_intersect)
386
+ curr_index += minibatch
387
+
388
+ merged_nodes = merged_node_uid[:, 0]
389
+ merged_batch = None
390
+ if self.batch is not None and other.batch is not None:
391
+ # Restore the batch indices to be relative to the nodes field
392
+ ref_merged_batch_nodes = merged_node_uid[:, 1].unsqueeze(
393
+ -1).expand(-1, 2) # num_nodes x 2
394
+ merged_batch = global_to_local_node_idx(
395
+ merged_node_uid, ref_merged_batch_nodes)
396
+
397
+ # EDGES
398
+ is_bidirectional = self.orig_row is not None \
399
+ and self.orig_col is not None \
400
+ and other.orig_row is not None \
401
+ and other.orig_col is not None
402
+ if is_bidirectional:
403
+ old_row, old_col = self.orig_row, self.orig_col
404
+ new_row, new_col = other.orig_row, other.orig_col
405
+ else:
406
+ old_row, old_col = self.row, self.col
407
+ new_row, new_col = other.row, other.col
408
+
409
+ # Transform the row and col indices to be global node ids
410
+ # instead of relative indices to nodes field
411
+ # Edge uids build off of node uids
412
+ old_row_idx, old_col_idx = local_to_global_node_idx(
413
+ old_node_uid,
414
+ old_row), local_to_global_node_idx(old_node_uid, old_col)
415
+ new_row_idx, new_col_idx = local_to_global_node_idx(
416
+ new_node_uid,
417
+ new_row), local_to_global_node_idx(new_node_uid, new_col)
418
+
419
+ old_edge_uid, new_edge_uid = [old_row_idx, old_col_idx
420
+ ], [new_row_idx, new_col_idx]
421
+
422
+ row_idx = 0
423
+ col_idx = old_row_idx.shape[1]
424
+ edge_idx = old_row_idx.shape[1] + old_col_idx.shape[1]
425
+
426
+ if self.edge is not None and other.edge is not None:
427
+ if is_bidirectional:
428
+ # bidirectional duplicates edge ids
429
+ old_edge_uid_ref = torch.stack([self.row, self.col],
430
+ dim=1) # num_edges x 2
431
+ old_orig_edge_uid_ref = torch.stack(
432
+ [self.orig_row, self.orig_col],
433
+ dim=1) # num_orig_edges x 2
434
+
435
+ old_edge_idx = global_to_local_node_idx(
436
+ old_edge_uid_ref, old_orig_edge_uid_ref)
437
+ old_edge = self.edge[old_edge_idx]
438
+
439
+ new_edge_uid_ref = torch.stack([other.row, other.col],
440
+ dim=1) # num_edges x 2
441
+ new_orig_edge_uid_ref = torch.stack(
442
+ [other.orig_row, other.orig_col],
443
+ dim=1) # num_orig_edges x 2
444
+
445
+ new_edge_idx = global_to_local_node_idx(
446
+ new_edge_uid_ref, new_orig_edge_uid_ref)
447
+ new_edge = other.edge[new_edge_idx]
448
+
449
+ else:
450
+ old_edge, new_edge = self.edge, other.edge
451
+
452
+ old_edge_uid.append(old_edge.unsqueeze(-1))
453
+ new_edge_uid.append(new_edge.unsqueeze(-1))
454
+
455
+ old_edge_uid = torch.cat(old_edge_uid, dim=1)
456
+ new_edge_uid = torch.cat(new_edge_uid, dim=1)
457
+
458
+ merged_edge_uid = unique_unsorted(
459
+ torch.cat([old_edge_uid, new_edge_uid], dim=0))
460
+ num_old_edges = old_edge_uid.shape[0]
461
+
462
+ merged_edge_num_sampled_edges = None
463
+ if (self.num_sampled_edges is not None
464
+ and other.num_sampled_edges is not None):
465
+ merged_edge_num_sampled_edges = copy.copy(
466
+ self.num_sampled_edges)
467
+ curr_index = 0
468
+ # NOTE: There's an assumption here that no two edges will be
469
+ # sampled twice in the same SampleOutput object
470
+ for minibatch in other.num_sampled_edges:
471
+ size_of_intersect = torch.cat([
472
+ old_edge_uid,
473
+ new_edge_uid[curr_index:curr_index + minibatch]
474
+ ]).unique(dim=0, sorted=False).shape[0] - num_old_edges
475
+ merged_edge_num_sampled_edges.append(size_of_intersect)
476
+ curr_index += minibatch
477
+
478
+ merged_row = merged_edge_uid[:, row_idx:col_idx]
479
+ merged_col = merged_edge_uid[:, col_idx:edge_idx]
480
+ merged_edge = merged_edge_uid[:, edge_idx:].squeeze() \
481
+ if self.edge is not None and other.edge is not None else None
482
+
483
+ # restore to row and col indices relative to nodes field
484
+ merged_row = global_to_local_node_idx(merged_node_uid, merged_row)
485
+ merged_col = global_to_local_node_idx(merged_node_uid, merged_col)
486
+
487
+ out = SamplerOutput(
488
+ node=merged_nodes,
489
+ row=merged_row,
490
+ col=merged_col,
491
+ edge=merged_edge,
492
+ batch=merged_batch,
493
+ num_sampled_nodes=merged_node_num_sampled_nodes,
494
+ num_sampled_edges=merged_edge_num_sampled_edges,
495
+ metadata=[self.metadata, other.metadata],
496
+ )
497
+ # Restores orig_row and orig_col if they existed before merging
498
+ if is_bidirectional:
499
+ out = out.to_bidirectional(keep_orig_edges=True)
500
+ return out
501
+
240
502
 
241
503
  @dataclass
242
504
  class HeteroSamplerOutput(CastMixin):
@@ -294,6 +556,43 @@ class HeteroSamplerOutput(CastMixin):
294
556
  # API for the expected output of a sampler.
295
557
  metadata: Optional[Any] = None
296
558
 
559
+ @property
560
+ def global_row(self) -> Dict[EdgeType, Tensor]:
561
+ return {
562
+ edge_type: local_to_global_node_idx(self.node[edge_type[0]], row)
563
+ for edge_type, row in self.row.items()
564
+ }
565
+
566
+ @property
567
+ def global_col(self) -> Dict[EdgeType, Tensor]:
568
+ return {
569
+ edge_type: local_to_global_node_idx(self.node[edge_type[2]], col)
570
+ for edge_type, col in self.col.items()
571
+ }
572
+
573
+ @property
574
+ def seed_node(self) -> Optional[Dict[NodeType, Tensor]]:
575
+ return {
576
+ node_type: local_to_global_node_idx(self.node[node_type], batch)
577
+ for node_type, batch in self.batch.items()
578
+ } if self.batch is not None else None
579
+
580
+ @property
581
+ def global_orig_row(self) -> Optional[Dict[EdgeType, Tensor]]:
582
+ return {
583
+ edge_type: local_to_global_node_idx(self.node[edge_type[0]],
584
+ orig_row)
585
+ for edge_type, orig_row in self.orig_row.items()
586
+ } if self.orig_row is not None else None
587
+
588
+ @property
589
+ def global_orig_col(self) -> Optional[Dict[EdgeType, Tensor]]:
590
+ return {
591
+ edge_type: local_to_global_node_idx(self.node[edge_type[2]],
592
+ orig_col)
593
+ for edge_type, orig_col in self.orig_col.items()
594
+ } if self.orig_col is not None else None
595
+
297
596
  def to_bidirectional(
298
597
  self,
299
598
  keep_orig_edges: bool = False,
@@ -369,12 +668,32 @@ class HeteroSamplerOutput(CastMixin):
369
668
  out.edge[edge_type] = None
370
669
 
371
670
  else:
372
- warnings.warn(f"Cannot convert to bidirectional graph "
373
- f"since the edge type {edge_type} does not "
374
- f"seem to have a reverse edge type")
671
+ warnings.warn(
672
+ f"Cannot convert to bidirectional graph "
673
+ f"since the edge type {edge_type} does not "
674
+ f"seem to have a reverse edge type", stacklevel=2)
375
675
 
376
676
  return out
377
677
 
678
+ @classmethod
679
+ def collate(cls, outputs: List['HeteroSamplerOutput'],
680
+ replace: bool = True) -> 'HeteroSamplerOutput':
681
+ r"""Collate a list of
682
+ :class:`~torch_geometric.sampler.HeteroSamplerOutput`objects into a
683
+ single :class:`~torch_geometric.sampler.HeteroSamplerOutput` object.
684
+ Requires that they all have the same fields.
685
+ """
686
+ # TODO(zaristei)
687
+ raise NotImplementedError
688
+
689
+ def merge_with(self, other: 'HeteroSamplerOutput',
690
+ replace: bool = True) -> 'HeteroSamplerOutput':
691
+ """Merges two HeteroSamplerOutputs.
692
+ If replace is True, self's nodes and edges take precedence.
693
+ """
694
+ # TODO(zaristei)
695
+ raise NotImplementedError
696
+
378
697
 
379
698
  @dataclass(frozen=True)
380
699
  class NumNeighbors:
@@ -423,7 +742,15 @@ class NumNeighbors:
423
742
  elif isinstance(self.values, dict):
424
743
  default = self.default
425
744
  else:
426
- assert False
745
+ raise AssertionError()
746
+
747
+ # Confirm that `values` only hold valid edge types:
748
+ if isinstance(self.values, dict):
749
+ edge_types_str = {EdgeTypeStr(key) for key in edge_types}
750
+ invalid_edge_types = set(self.values.keys()) - edge_types_str
751
+ if len(invalid_edge_types) > 0:
752
+ raise ValueError("Not all edge types specified in "
753
+ "'num_neighbors' exist in the graph")
427
754
 
428
755
  out = {}
429
756
  for edge_type in edge_types:
@@ -614,6 +941,7 @@ class BaseSampler(ABC):
614
941
  As such, it is recommended to limit the amount of information stored in
615
942
  the sampler.
616
943
  """
944
+ @abstractmethod
617
945
  def sample_from_nodes(
618
946
  self,
619
947
  index: NodeSamplerInput,
@@ -634,6 +962,7 @@ class BaseSampler(ABC):
634
962
  """
635
963
  raise NotImplementedError
636
964
 
965
+ @abstractmethod
637
966
  def sample_from_edges(
638
967
  self,
639
968
  index: EdgeSamplerInput,
@@ -1,12 +1,15 @@
1
- from typing import Dict, List, Union
1
+ from typing import Dict, List, Optional, Union
2
2
 
3
3
  import torch
4
4
 
5
5
  from torch_geometric.data import Data, HeteroData
6
6
  from torch_geometric.sampler import (
7
7
  BaseSampler,
8
+ EdgeSamplerInput,
8
9
  HeteroSamplerOutput,
10
+ NegativeSampling,
9
11
  NodeSamplerInput,
12
+ SamplerOutput,
10
13
  )
11
14
  from torch_geometric.sampler.utils import remap_keys, to_hetero_csc
12
15
  from torch_geometric.typing import (
@@ -76,6 +79,13 @@ class HGTSampler(BaseSampler):
76
79
  metadata=(inputs.input_id, inputs.time),
77
80
  )
78
81
 
82
+ def sample_from_edges(
83
+ self,
84
+ index: EdgeSamplerInput,
85
+ neg_sampling: Optional[NegativeSampling] = None,
86
+ ) -> Union[HeteroSamplerOutput, SamplerOutput]:
87
+ pass
88
+
79
89
  @property
80
90
  def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
81
91
  return self.perm