pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -25,7 +25,13 @@ from torch_geometric.sampler import (
25
25
  SamplerOutput,
26
26
  )
27
27
  from torch_geometric.sampler.base import DataType, NumNeighbors, SubgraphType
28
- from torch_geometric.sampler.utils import remap_keys, to_csc, to_hetero_csc
28
+ from torch_geometric.sampler.utils import (
29
+ global_to_local_node_idx,
30
+ remap_keys,
31
+ reverse_edge_type,
32
+ to_csc,
33
+ to_hetero_csc,
34
+ )
29
35
  from torch_geometric.typing import EdgeType, NodeType, OptTensor
30
36
 
31
37
  NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]
@@ -47,23 +53,33 @@ class NeighborSampler(BaseSampler):
47
53
  weight_attr: Optional[str] = None,
48
54
  is_sorted: bool = False,
49
55
  share_memory: bool = False,
50
- # Deprecated:
51
- directed: bool = True,
56
+ directed: bool = True, # Deprecated
57
+ sample_direction: Literal['forward', 'backward'] = 'forward',
52
58
  ):
53
59
  if not directed:
54
60
  subgraph_type = SubgraphType.induced
55
- warnings.warn(f"The usage of the 'directed' argument in "
56
- f"'{self.__class__.__name__}' is deprecated. Use "
57
- f"`subgraph_type='induced'` instead.")
61
+ warnings.warn(
62
+ f"The usage of the 'directed' argument in "
63
+ f"'{self.__class__.__name__}' is deprecated. Use "
64
+ f"`subgraph_type='induced'` instead.", stacklevel=2)
58
65
 
59
66
  if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux'
60
67
  and subgraph_type != SubgraphType.induced):
61
- warnings.warn(f"Using '{self.__class__.__name__}' without a "
62
- f"'pyg-lib' installation is deprecated and will be "
63
- f"removed soon. Please install 'pyg-lib' for "
64
- f"accelerated neighborhood sampling")
68
+ warnings.warn(
69
+ f"Using '{self.__class__.__name__}' without a "
70
+ f"'pyg-lib' installation is deprecated and will be "
71
+ f"removed soon. Please install 'pyg-lib' for "
72
+ f"accelerated neighborhood sampling", stacklevel=2)
65
73
 
66
74
  self.data_type = DataType.from_data(data)
75
+ self.sample_direction = sample_direction
76
+
77
+ if self.sample_direction == 'backward':
78
+ # TODO(zaristei)
79
+ if time_attr is not None:
80
+ raise NotImplementedError(
81
+ "Temporal Sampling not yet supported for backward sampling"
82
+ )
67
83
 
68
84
  if self.data_type == DataType.homogeneous:
69
85
  self.num_nodes = data.num_nodes
@@ -85,7 +101,8 @@ class NeighborSampler(BaseSampler):
85
101
  self.colptr, self.row, self.perm = to_csc(
86
102
  data, device='cpu', share_memory=share_memory,
87
103
  is_sorted=is_sorted, src_node_time=self.node_time,
88
- edge_time=self.edge_time)
104
+ edge_time=self.edge_time,
105
+ to_transpose=self.sample_direction == 'backward')
89
106
 
90
107
  if self.edge_time is not None and self.perm is not None:
91
108
  self.edge_time = self.edge_time[self.perm]
@@ -99,6 +116,17 @@ class NeighborSampler(BaseSampler):
99
116
  elif self.data_type == DataType.heterogeneous:
100
117
  self.node_types, self.edge_types = data.metadata()
101
118
 
119
+ # reverse edge types if sample_direction is backward
120
+ if self.sample_direction == 'backward':
121
+ self.edge_types = [
122
+ reverse_edge_type(edge_type)
123
+ for edge_type in self.edge_types
124
+ ]
125
+ self.to_restored_edge_type = {
126
+ k: reverse_edge_type(k)
127
+ for k in self.edge_types
128
+ }
129
+
102
130
  self.num_nodes = {k: data[k].num_nodes for k in self.node_types}
103
131
 
104
132
  self.node_time: Optional[Dict[NodeType, Tensor]] = None
@@ -139,7 +167,8 @@ class NeighborSampler(BaseSampler):
139
167
  colptr_dict, row_dict, self.perm = to_hetero_csc(
140
168
  data, device='cpu', share_memory=share_memory,
141
169
  is_sorted=is_sorted, node_time_dict=self.node_time,
142
- edge_time_dict=self.edge_time)
170
+ edge_time_dict=self.edge_time,
171
+ to_transpose=self.sample_direction == 'backward')
143
172
 
144
173
  self.row_dict = remap_keys(row_dict, self.to_rel_type)
145
174
  self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
@@ -170,6 +199,21 @@ class NeighborSampler(BaseSampler):
170
199
  edge_attrs = graph_store.get_all_edge_attrs()
171
200
  self.edge_types = list({attr.edge_type for attr in edge_attrs})
172
201
 
202
+ # reverse edge types if sample_direction is backward
203
+ if self.sample_direction == 'backward':
204
+ self.edge_types = [
205
+ reverse_edge_type(edge_type)
206
+ for edge_type in self.edge_types
207
+ ]
208
+ self.to_restored_edge_type = {
209
+ k: reverse_edge_type(k)
210
+ for k in self.edge_types
211
+ }
212
+ self.to_backward_edge_type = {
213
+ v: k
214
+ for k, v in self.to_restored_edge_type.items()
215
+ }
216
+
173
217
  if weight_attr is not None:
174
218
  raise NotImplementedError(
175
219
  f"'weight_attr' argument not yet supported within "
@@ -219,7 +263,10 @@ class NeighborSampler(BaseSampler):
219
263
  else:
220
264
  self.edge_time = time_tensor
221
265
 
222
- self.row, self.colptr, self.perm = graph_store.csc()
266
+ if self.sample_direction == 'forward':
267
+ self.row, self.colptr, self.perm = graph_store.csc()
268
+ elif self.sample_direction == 'backward':
269
+ self.colptr, self.row, self.perm = graph_store.csr()
223
270
 
224
271
  else:
225
272
  node_types = [
@@ -259,8 +306,17 @@ class NeighborSampler(BaseSampler):
259
306
  # Conversion to/from C++ string type (see above):
260
307
  self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
261
308
  self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}
262
- # Convert the graph data into CSC format for sampling:
263
- row_dict, colptr_dict, self.perm = graph_store.csc()
309
+ if self.sample_direction == 'forward':
310
+ row_dict, colptr_dict, self.perm = graph_store.csc()
311
+ elif self.sample_direction == 'backward':
312
+ colptr_dict, row_dict, self.perm = graph_store.csr()
313
+
314
+ colptr_dict = remap_keys(colptr_dict,
315
+ self.to_backward_edge_type)
316
+ row_dict = remap_keys(row_dict, self.to_backward_edge_type)
317
+ self.perm = remap_keys(self.perm,
318
+ self.to_backward_edge_type)
319
+
264
320
  self.row_dict = remap_keys(row_dict, self.to_rel_type)
265
321
  self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
266
322
 
@@ -279,17 +335,42 @@ class NeighborSampler(BaseSampler):
279
335
  self.subgraph_type = SubgraphType(subgraph_type)
280
336
  self.disjoint = disjoint
281
337
  self.temporal_strategy = temporal_strategy
338
+ self.keep_orig_edges = False
282
339
 
283
340
  @property
284
341
  def num_neighbors(self) -> NumNeighbors:
342
+ if self.sample_direction == 'backward':
343
+ return self._input_num_neighbors \
344
+ if self._input_num_neighbors is not None \
345
+ else self._num_neighbors
285
346
  return self._num_neighbors
286
347
 
287
348
  @num_neighbors.setter
288
349
  def num_neighbors(self, num_neighbors: NumNeighborsType):
350
+ # only used if sample direction is backward and num_neighbors has edge
351
+ # keys
352
+ self._input_num_neighbors = None
353
+
289
354
  if isinstance(num_neighbors, NumNeighbors):
290
- self._num_neighbors = num_neighbors
355
+ num_neighbors_values = num_neighbors.values
356
+ if isinstance(num_neighbors_values,
357
+ dict) and self.sample_direction == 'backward':
358
+ # reverse the edge_types if sample_direction is backward
359
+ self._input_num_neighbors = num_neighbors
360
+ num_neighbors_values = remap_keys(num_neighbors_values,
361
+ self.to_backward_edge_type)
362
+ self._num_neighbors = NumNeighbors(num_neighbors_values)
363
+ else:
364
+ self._num_neighbors = num_neighbors
291
365
  else:
292
- self._num_neighbors = NumNeighbors(num_neighbors)
366
+ if isinstance(num_neighbors,
367
+ dict) and self.sample_direction == 'backward':
368
+ # intentionally recursing here to make sure num_neighbors is
369
+ # set as expected for the user
370
+ self.num_neighbors = NumNeighbors(
371
+ remap_keys(num_neighbors, self.to_backward_edge_type))
372
+ else:
373
+ self._num_neighbors = NumNeighbors(num_neighbors)
293
374
 
294
375
  @property
295
376
  def is_hetero(self) -> bool:
@@ -321,7 +402,7 @@ class NeighborSampler(BaseSampler):
321
402
  ) -> Union[SamplerOutput, HeteroSamplerOutput]:
322
403
  out = node_sample(inputs, self._sample)
323
404
  if self.subgraph_type == SubgraphType.bidirectional:
324
- out = out.to_bidirectional()
405
+ out = out.to_bidirectional(keep_orig_edges=self.keep_orig_edges)
325
406
  return out
326
407
 
327
408
  # Edge-based sampling #####################################################
@@ -334,7 +415,7 @@ class NeighborSampler(BaseSampler):
334
415
  out = edge_sample(inputs, self._sample, self.num_nodes, self.disjoint,
335
416
  self.node_time, neg_sampling)
336
417
  if self.subgraph_type == SubgraphType.bidirectional:
337
- out = out.to_bidirectional()
418
+ out = out.to_bidirectional(keep_orig_edges=self.keep_orig_edges)
338
419
  return out
339
420
 
340
421
  # Other Utilities #########################################################
@@ -431,17 +512,34 @@ class NeighborSampler(BaseSampler):
431
512
  raise ImportError(f"'{self.__class__.__name__}' requires "
432
513
  f"either 'pyg-lib' or 'torch-sparse'")
433
514
 
515
+ if self.sample_direction == 'backward':
516
+ row, col = col, row
517
+
518
+ row = remap_keys(row, self.to_edge_type)
519
+ col = remap_keys(col, self.to_edge_type)
520
+ edge = remap_keys(edge, self.to_edge_type)
521
+
522
+ # In the case of backward sampling, we need to restore the edges
523
+ # keys to be forward facing in the HeteroSamplerOutput object.
524
+ if self.sample_direction == 'backward':
525
+ row = remap_keys(row, self.to_restored_edge_type)
526
+ col = remap_keys(col, self.to_restored_edge_type)
527
+ edge = remap_keys(edge, self.to_restored_edge_type)
528
+
434
529
  if num_sampled_edges is not None:
435
530
  num_sampled_edges = remap_keys(
436
531
  num_sampled_edges,
437
532
  self.to_edge_type,
438
533
  )
534
+ if self.sample_direction == 'backward':
535
+ num_sampled_edges = remap_keys(num_sampled_edges,
536
+ self.to_restored_edge_type)
439
537
 
440
538
  return HeteroSamplerOutput(
441
539
  node=node,
442
- row=remap_keys(row, self.to_edge_type),
443
- col=remap_keys(col, self.to_edge_type),
444
- edge=remap_keys(edge, self.to_edge_type),
540
+ row=row,
541
+ col=col,
542
+ edge=edge,
445
543
  batch=batch,
446
544
  num_sampled_nodes=num_sampled_nodes,
447
545
  num_sampled_edges=num_sampled_edges,
@@ -508,6 +606,9 @@ class NeighborSampler(BaseSampler):
508
606
  raise ImportError(f"'{self.__class__.__name__}' requires "
509
607
  f"either 'pyg-lib' or 'torch-sparse'")
510
608
 
609
+ if self.sample_direction == 'backward':
610
+ row, col = col, row
611
+
511
612
  return SamplerOutput(
512
613
  node=node,
513
614
  row=row,
@@ -519,6 +620,178 @@ class NeighborSampler(BaseSampler):
519
620
  )
520
621
 
521
622
 
623
class BidirectionalNeighborSampler(NeighborSampler):
    r"""A sampler that allows for both upstream and downstream sampling.

    It wraps two :class:`NeighborSampler` instances, one created with
    :obj:`sample_direction='forward'` and one with
    :obj:`sample_direction='backward'`. Sampling proceeds one hop at a time:
    at every hop both samplers are run on the current seeds, their outputs
    are merged via :meth:`merge_with`, and every sampled node that has not
    been a seed before becomes a seed of the next hop. The per-hop outputs
    are finally combined via :meth:`SamplerOutput.collate`.

    Not yet supported (each raises when requested): edge-type-specific
    :obj:`num_neighbors`, negative sampling, temporal sampling and
    heterogeneous graphs.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        num_neighbors: NumNeighborsType,
        subgraph_type: Union[SubgraphType, str] = 'directional',
        replace: bool = False,
        disjoint: bool = False,
        temporal_strategy: str = 'uniform',
        time_attr: Optional[str] = None,
        weight_attr: Optional[str] = None,
        is_sorted: bool = False,
        share_memory: bool = False,
        directed: bool = True,  # Deprecated
    ):
        # TODO(zaristei) Support edge-type-specific neighbor counts.
        is_edge_delimited = isinstance(num_neighbors, dict) or (
            isinstance(num_neighbors, NumNeighbors)
            and isinstance(num_neighbors.values, dict))
        if is_edge_delimited:
            raise RuntimeError(
                "BidirectionalNeighborSampler does not yet support edge "
                "delimited sampling.")

        self.forward_sampler = NeighborSampler(
            data, num_neighbors, subgraph_type, replace, disjoint,
            temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,
            sample_direction='forward', directed=directed)
        self.backward_sampler = NeighborSampler(
            data, num_neighbors, subgraph_type, replace, disjoint,
            temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,
            sample_direction='backward', directed=directed)

        # Trigger warnings on init if number of hops is greater than 1:
        self.num_neighbors = num_neighbors

        # Normalize to the enum (as `NeighborSampler.__init__` does), so that
        # the `SubgraphType.bidirectional` check inside the inherited
        # `sample_from_nodes`/`sample_from_edges` also matches string input:
        self.subgraph_type = SubgraphType(subgraph_type)

        # `NeighborSampler.__init__` is intentionally not called, so define
        # the attribute read by the inherited `sample_from_*` methods here
        # (same default as in `NeighborSampler.__init__`):
        self.keep_orig_edges = False

    @property
    def num_neighbors(self) -> NumNeighbors:
        return self._num_neighbors

    @num_neighbors.setter
    def num_neighbors(self, num_neighbors: NumNeighborsType):
        if not isinstance(num_neighbors, NumNeighbors):
            num_neighbors = NumNeighbors(num_neighbors)
        if num_neighbors.num_hops > 1:
            warnings.warn(
                "Number of hops is greater than 1, resulting in "
                "memory-expensive recursive calls.", stacklevel=2)
        self._num_neighbors = num_neighbors

    @property
    def is_hetero(self) -> bool:
        return self.forward_sampler.is_hetero

    @property
    def is_temporal(self) -> bool:
        return self.forward_sampler.is_temporal

    @property
    def disjoint(self) -> bool:
        return self.forward_sampler.disjoint

    @disjoint.setter
    def disjoint(self, disjoint: bool):
        self.forward_sampler.disjoint = disjoint
        self.backward_sampler.disjoint = disjoint

    @property
    def num_nodes(self):
        # Read by the inherited `sample_from_edges`; forwarded because
        # `NeighborSampler.__init__` never runs on this wrapper.
        return self.forward_sampler.num_nodes

    @property
    def node_time(self):
        # Read by the inherited `sample_from_edges`; see `num_nodes`.
        return self.forward_sampler.node_time

    def sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        return super().sample_from_nodes(inputs)

    def sample_from_edges(
        self,
        inputs: EdgeSamplerInput,
        neg_sampling: Optional[NegativeSampling] = None,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        # TODO(zaristei) Figure out what exactly regular and negative sampling
        # imply for bidirectional sampling case
        if neg_sampling is not None:
            raise RuntimeError(
                "BidirectionalNeighborSampler does not yet support "
                "negative sampling.")
        # Not thoroughly tested yet!
        return super().sample_from_edges(inputs)

    @property
    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
        return self.forward_sampler.edge_permutation

    def _sample(
        self,
        seed: Union[Tensor, Dict[NodeType, Tensor]],
        seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,
        **kwargs,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Expands :obj:`seed` hop by hop in both directions.

        Raises:
            NotImplementedError: For temporal or heterogeneous sampling.
        """
        if seed_time is not None:
            raise NotImplementedError(
                "BidirectionalNeighborSampler does not yet support "
                "temporal sampling.")

        if self.is_hetero:
            raise NotImplementedError(
                "BidirectionalNeighborSampler does not yet support "
                "heterogeneous sampling.")

        current_seed = seed
        current_seed_batch = None
        current_seed_time = seed_time  # Always `None` (checked above).

        # Nodes that have already served as seeds. In disjoint mode a node is
        # tracked per originating seed, i.e., as a `(node, batch)` pair:
        if self.disjoint:
            current_seed_batch = torch.arange(len(current_seed))
            seen_seed_set = set(
                zip(map(int, current_seed.tolist()),
                    map(int, current_seed_batch.tolist())))
        else:
            seen_seed_set = set(map(int, current_seed.tolist()))

        iter_results = []

        for n_neighbors in self.num_neighbors.values:
            # Restrict both underlying samplers to exactly one hop:
            self.forward_sampler.num_neighbors = [n_neighbors]
            self.backward_sampler.num_neighbors = [n_neighbors]

            fwd_result = self.forward_sampler._sample(
                current_seed, current_seed_time, **kwargs)
            bwd_result = self.backward_sampler._sample(
                current_seed, current_seed_time, **kwargs)
            iter_result = fwd_result.merge_with(bwd_result)
            iter_results.append(iter_result)

            # Nodes not yet seen become the seeds of the next hop:
            if self.disjoint:
                # `iter_result.batch` holds positions within `current_seed`.
                # Gather through `current_seed_batch` to map them back to the
                # position of the originating seed in `seed`:
                iter_seed_global_batch = current_seed_batch[iter_result.batch]
                iter_result.seed_node = seed[iter_seed_global_batch]
                # NOTE(review): `iter_result.batch` itself stays hop-local;
                # confirm whether `SamplerOutput.collate` expects global ids.

                candidates = zip(map(int, iter_result.node.tolist()),
                                 map(int, iter_seed_global_batch.tolist()))
                next_seed = [
                    pair for pair in candidates if pair not in seen_seed_set
                ]
                # An explicit `dtype` keeps an empty frontier an index tensor
                # (`torch.tensor([])` would default to float):
                current_seed, current_seed_batch = torch.tensor(
                    next_seed, dtype=torch.long).reshape(-1, 2).t().contiguous()
            else:
                next_seed = [
                    node for node in map(int, iter_result.node.tolist())
                    if node not in seen_seed_set
                ]
                current_seed = torch.tensor(next_seed, dtype=torch.long)

            seen_seed_set.update(next_seed)

            # TODO(zaristei) figure out how to update seed times for
            # temporal sampling

        return SamplerOutput.collate(iter_results)
793
+
794
+
522
795
  # Sampling Utilities ##########################################################
523
796
 
524
797
 
@@ -805,7 +1078,7 @@ def neg_sample(
805
1078
  out = out.view(num_samples, seed.numel())
806
1079
  mask = node_time[out] > seed_time # holds all invalid samples.
807
1080
  neg_sampling_complete = False
808
- for i in range(5): # pragma: no cover
1081
+ for _ in range(5): # pragma: no cover
809
1082
  num_invalid = int(mask.sum())
810
1083
  if num_invalid == 0:
811
1084
  neg_sampling_complete = True
@@ -9,6 +9,15 @@ from torch_geometric.index import index2ptr
9
9
  from torch_geometric.typing import EdgeType, NodeType, OptTensor
10
10
  from torch_geometric.utils import coalesce, index_sort, lexsort
11
11
 
12
+
13
def reverse_edge_type(edge_type: Optional[EdgeType]) -> Optional[EdgeType]:
    """Reverses edge types for heterogeneous graphs. Useful in cases of
    backward sampling.

    Args:
        edge_type (EdgeType, optional): A :obj:`(src, rel, dst)` edge type.

    Returns:
        The reversed edge type :obj:`(dst, rel, src)`, or :obj:`None` if
        :obj:`edge_type` is :obj:`None`.
    """
    if edge_type is None:
        return None
    # Swap source and destination node types; the relation is kept as-is:
    return (edge_type[2], edge_type[1], edge_type[0])
19
+
20
+
12
21
  # Edge Layout Conversion ######################################################
13
22
 
14
23
 
@@ -41,6 +50,7 @@ def to_csc(
41
50
  is_sorted: bool = False,
42
51
  src_node_time: Optional[Tensor] = None,
43
52
  edge_time: Optional[Tensor] = None,
53
+ to_transpose: bool = False,
44
54
  ) -> Tuple[Tensor, Tensor, OptTensor]:
45
55
  # Convert the graph data into a suitable format for sampling (CSC format).
46
56
  # Returns the `colptr` and `row` indices of the graph, as well as an
@@ -53,7 +63,10 @@ def to_csc(
53
63
  if src_node_time is not None:
54
64
  raise NotImplementedError("Temporal sampling via 'SparseTensor' "
55
65
  "format not yet supported")
56
- colptr, row, _ = data.adj.csc()
66
+ if to_transpose:
67
+ row, colptr, _ = data.adj.csr()
68
+ else:
69
+ colptr, row, _ = data.adj.csc()
57
70
 
58
71
  elif hasattr(data, 'adj_t'):
59
72
  if src_node_time is not None:
@@ -65,13 +78,21 @@ def to_csc(
65
78
  # raise NotImplementedError("Temporal sampling via 'SparseTensor' "
66
79
  # "format not yet supported")
67
80
  pass
68
- colptr, row, _ = data.adj_t.csr()
81
+ if to_transpose:
82
+ row, colptr, _ = data.adj_t.csc()
83
+ else:
84
+ colptr, row, _ = data.adj_t.csr()
69
85
 
70
86
  elif data.edge_index is not None:
71
- row, col = data.edge_index
87
+ if to_transpose:
88
+ col, row = data.edge_index
89
+ else:
90
+ row, col = data.edge_index
91
+
72
92
  if not is_sorted:
73
93
  row, col, perm = sort_csc(row, col, src_node_time, edge_time)
74
- colptr = index2ptr(col, data.size(1))
94
+ colptr = index2ptr(col,
95
+ data.size(1) if not to_transpose else data.size(0))
75
96
  else:
76
97
  row = torch.empty(0, dtype=torch.long, device=device)
77
98
  colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long,
@@ -97,6 +118,7 @@ def to_hetero_csc(
97
118
  is_sorted: bool = False,
98
119
  node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
99
120
  edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None,
121
+ to_transpose: bool = False,
100
122
  ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
101
123
  # Convert the heterogeneous graph data into a suitable format for sampling
102
124
  # (CSC format).
@@ -108,7 +130,11 @@ def to_hetero_csc(
108
130
  src_node_time = (node_time_dict or {}).get(edge_type[0], None)
109
131
  edge_time = (edge_time_dict or {}).get(edge_type, None)
110
132
  out = to_csc(store, device, share_memory, is_sorted, src_node_time,
111
- edge_time)
133
+ edge_time, to_transpose)
134
+ # Edge types need to be reversed for backward sampling:
135
+ if to_transpose:
136
+ edge_type = reverse_edge_type(edge_type)
137
+
112
138
  colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out
113
139
 
114
140
  return colptr_dict, row_dict, perm_dict
@@ -160,3 +186,65 @@ def remap_keys(
160
186
  k if k in exclude else mapping.get(k, k): v
161
187
  for k, v in inputs.items()
162
188
  }
189
+
190
+
191
def local_to_global_node_idx(node_values: Tensor,
                             local_indices: Tensor) -> Tensor:
    r"""Maps local (positional) indices back to the entries of
    :obj:`node_values` they point to.

    Args:
        node_values (Tensor): The lookup table, indexed along its first
            dimension, *e.g.*, of shape :obj:`[num_nodes]` or
            :obj:`[num_nodes, feature_dim]`.
        local_indices (Tensor): Positions into the first dimension of
            :obj:`node_values`.

    Returns:
        Tensor: The entries of :obj:`node_values` selected by
        :obj:`local_indices`, in the order given by :obj:`local_indices`.
    """
    gathered = node_values.index_select(0, local_indices)
    return gathered
205
+
206
+
207
def global_to_local_node_idx(node_values: Tensor,
                             local_values: Tensor) -> Tensor:
    r"""Converts a tensor of values that are contained in the
    :obj:`node_values` tensor to their indices in that tensor.

    This is the inverse of :meth:`local_to_global_node_idx`. Each entry (or
    row, for 2-D inputs) of :obj:`local_values` is matched against every
    entry/row of :obj:`node_values` by exact equality.

    Args:
        node_values (Tensor): The values to search in, of shape
            :obj:`[num_nodes]` or :obj:`[num_nodes, feature_dim]`.
        local_values (Tensor): The values to look up, of shape
            :obj:`[num_indices]` or :obj:`[num_indices, feature_dim]`.

    Returns:
        Tensor: Indices into :obj:`node_values`, ordered by position in
        :obj:`local_values`. If every value occurs exactly once in
        :obj:`node_values`, the result has shape :obj:`[num_indices]`.
        Values that do not occur are skipped. Values that occur several
        times contribute all of their positions, in ascending order.

    .. note::
        Memory usage is
        :obj:`O(num_indices * num_nodes * feature_dim)`.
    """
    if node_values.dim() == 1:
        node_values = node_values.unsqueeze(1)
    if local_values.dim() == 1:
        local_values = local_values.unsqueeze(1)

    # Boolean match matrix of shape [num_indices, num_nodes]:
    is_match = (local_values.unsqueeze(1) == node_values.unsqueeze(0))
    is_match = is_match.all(dim=-1)

    # `nonzero()` enumerates matches in row-major order, i.e., ordered by
    # position in `local_values` first and by node index second. This gives
    # a deterministic result without an (unstable) extra sort:
    return is_match.nonzero()[:, 1]
234
+
235
+
236
def unique_unsorted(tensor: Tensor) -> Tensor:
    r"""Returns the unique elements (rows, for multi-dimensional inputs) of a
    tensor while preserving the order of their first occurrence.

    Necessary because :meth:`torch.unique` does not reliably honor its
    :obj:`sorted` parameter, so its output order cannot be used directly.

    Args:
        tensor (Tensor): The input tensor. Uniqueness is determined along
            its first dimension. 1-D and higher-dimensional inputs are
            supported.

    Returns:
        Tensor: The unique entries of :obj:`tensor`, with the same dtype,
        device and trailing shape, in order of first appearance.
    """
    if tensor.size(0) == 0:
        return tensor.clone()

    # `inverse[i]` identifies which unique entry `tensor[i]` is equal to:
    _, inverse = torch.unique(tensor, dim=0, return_inverse=True)

    # A stable sort groups equal entries together while keeping their
    # original relative order, so the first element of every group is the
    # first occurrence of that entry:
    sorted_inverse, order = torch.sort(inverse, stable=True)
    is_first = torch.ones_like(sorted_inverse, dtype=torch.bool)
    is_first[1:] = sorted_inverse[1:] != sorted_inverse[:-1]

    # Restore the original order of appearance:
    first_occurrence = order[is_first].sort().values
    return tensor[first_occurrence]
@@ -17,11 +17,13 @@ from .decorators import (
17
17
  onlyOnline,
18
18
  onlyGraphviz,
19
19
  onlyNeighborSampler,
20
+ onlyRAG,
20
21
  has_package,
21
22
  withPackage,
22
23
  withDevice,
23
24
  withCUDA,
24
25
  withMETIS,
26
+ withHashTensor,
25
27
  disableExtensions,
26
28
  withoutExtensions,
27
29
  )
@@ -48,11 +50,13 @@ __all__ = [
48
50
  'onlyOnline',
49
51
  'onlyGraphviz',
50
52
  'onlyNeighborSampler',
53
+ 'onlyRAG',
51
54
  'has_package',
52
55
  'withPackage',
53
56
  'withDevice',
54
57
  'withCUDA',
55
58
  'withMETIS',
59
+ 'withHashTensor',
56
60
  'disableExtensions',
57
61
  'withoutExtensions',
58
62
  'assert_module',