pyg-nightly 2.7.0.dev20250826__py3-none-any.whl → 2.7.0.dev20250827__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,7 +25,13 @@ from torch_geometric.sampler import (
25
25
  SamplerOutput,
26
26
  )
27
27
  from torch_geometric.sampler.base import DataType, NumNeighbors, SubgraphType
28
- from torch_geometric.sampler.utils import remap_keys, to_csc, to_hetero_csc
28
+ from torch_geometric.sampler.utils import (
29
+ global_to_local_node_idx,
30
+ remap_keys,
31
+ reverse_edge_type,
32
+ to_csc,
33
+ to_hetero_csc,
34
+ )
29
35
  from torch_geometric.typing import EdgeType, NodeType, OptTensor
30
36
 
31
37
  NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]
@@ -47,8 +53,8 @@ class NeighborSampler(BaseSampler):
47
53
  weight_attr: Optional[str] = None,
48
54
  is_sorted: bool = False,
49
55
  share_memory: bool = False,
50
- # Deprecated:
51
- directed: bool = True,
56
+ directed: bool = True, # Deprecated
57
+ sample_direction: Literal['forward', 'backward'] = 'forward',
52
58
  ):
53
59
  if not directed:
54
60
  subgraph_type = SubgraphType.induced
@@ -66,6 +72,14 @@ class NeighborSampler(BaseSampler):
66
72
  f"accelerated neighborhood sampling", stacklevel=2)
67
73
 
68
74
  self.data_type = DataType.from_data(data)
75
+ self.sample_direction = sample_direction
76
+
77
+ if self.sample_direction == 'backward':
78
+ # TODO(zaristei)
79
+ if time_attr is not None:
80
+ raise NotImplementedError(
81
+ "Temporal Sampling not yet supported for backward sampling"
82
+ )
69
83
 
70
84
  if self.data_type == DataType.homogeneous:
71
85
  self.num_nodes = data.num_nodes
@@ -87,7 +101,8 @@ class NeighborSampler(BaseSampler):
87
101
  self.colptr, self.row, self.perm = to_csc(
88
102
  data, device='cpu', share_memory=share_memory,
89
103
  is_sorted=is_sorted, src_node_time=self.node_time,
90
- edge_time=self.edge_time)
104
+ edge_time=self.edge_time,
105
+ to_transpose=self.sample_direction == 'backward')
91
106
 
92
107
  if self.edge_time is not None and self.perm is not None:
93
108
  self.edge_time = self.edge_time[self.perm]
@@ -101,6 +116,17 @@ class NeighborSampler(BaseSampler):
101
116
  elif self.data_type == DataType.heterogeneous:
102
117
  self.node_types, self.edge_types = data.metadata()
103
118
 
119
+ # reverse edge types if sample_direction is backward
120
+ if self.sample_direction == 'backward':
121
+ self.edge_types = [
122
+ reverse_edge_type(edge_type)
123
+ for edge_type in self.edge_types
124
+ ]
125
+ self.to_restored_edge_type = {
126
+ k: reverse_edge_type(k)
127
+ for k in self.edge_types
128
+ }
129
+
104
130
  self.num_nodes = {k: data[k].num_nodes for k in self.node_types}
105
131
 
106
132
  self.node_time: Optional[Dict[NodeType, Tensor]] = None
@@ -141,7 +167,8 @@ class NeighborSampler(BaseSampler):
141
167
  colptr_dict, row_dict, self.perm = to_hetero_csc(
142
168
  data, device='cpu', share_memory=share_memory,
143
169
  is_sorted=is_sorted, node_time_dict=self.node_time,
144
- edge_time_dict=self.edge_time)
170
+ edge_time_dict=self.edge_time,
171
+ to_transpose=self.sample_direction == 'backward')
145
172
 
146
173
  self.row_dict = remap_keys(row_dict, self.to_rel_type)
147
174
  self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
@@ -172,6 +199,21 @@ class NeighborSampler(BaseSampler):
172
199
  edge_attrs = graph_store.get_all_edge_attrs()
173
200
  self.edge_types = list({attr.edge_type for attr in edge_attrs})
174
201
 
202
+ # reverse edge types if sample_direction is backward
203
+ if self.sample_direction == 'backward':
204
+ self.edge_types = [
205
+ reverse_edge_type(edge_type)
206
+ for edge_type in self.edge_types
207
+ ]
208
+ self.to_restored_edge_type = {
209
+ k: reverse_edge_type(k)
210
+ for k in self.edge_types
211
+ }
212
+ self.to_backward_edge_type = {
213
+ v: k
214
+ for k, v in self.to_restored_edge_type.items()
215
+ }
216
+
175
217
  if weight_attr is not None:
176
218
  raise NotImplementedError(
177
219
  f"'weight_attr' argument not yet supported within "
@@ -221,7 +263,10 @@ class NeighborSampler(BaseSampler):
221
263
  else:
222
264
  self.edge_time = time_tensor
223
265
 
224
- self.row, self.colptr, self.perm = graph_store.csc()
266
+ if self.sample_direction == 'forward':
267
+ self.row, self.colptr, self.perm = graph_store.csc()
268
+ elif self.sample_direction == 'backward':
269
+ self.colptr, self.row, self.perm = graph_store.csr()
225
270
 
226
271
  else:
227
272
  node_types = [
@@ -261,8 +306,17 @@ class NeighborSampler(BaseSampler):
261
306
  # Conversion to/from C++ string type (see above):
262
307
  self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
263
308
  self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}
264
- # Convert the graph data into CSC format for sampling:
265
- row_dict, colptr_dict, self.perm = graph_store.csc()
309
+ if self.sample_direction == 'forward':
310
+ row_dict, colptr_dict, self.perm = graph_store.csc()
311
+ elif self.sample_direction == 'backward':
312
+ colptr_dict, row_dict, self.perm = graph_store.csr()
313
+
314
+ colptr_dict = remap_keys(colptr_dict,
315
+ self.to_backward_edge_type)
316
+ row_dict = remap_keys(row_dict, self.to_backward_edge_type)
317
+ self.perm = remap_keys(self.perm,
318
+ self.to_backward_edge_type)
319
+
266
320
  self.row_dict = remap_keys(row_dict, self.to_rel_type)
267
321
  self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
268
322
 
@@ -285,14 +339,38 @@ class NeighborSampler(BaseSampler):
285
339
 
286
340
  @property
287
341
  def num_neighbors(self) -> NumNeighbors:
342
+ if self.sample_direction == 'backward':
343
+ return self._input_num_neighbors \
344
+ if self._input_num_neighbors is not None \
345
+ else self._num_neighbors
288
346
  return self._num_neighbors
289
347
 
290
348
  @num_neighbors.setter
291
349
  def num_neighbors(self, num_neighbors: NumNeighborsType):
350
+ # only used if sample direction is backward and num_neighbors has edge
351
+ # keys
352
+ self._input_num_neighbors = None
353
+
292
354
  if isinstance(num_neighbors, NumNeighbors):
293
- self._num_neighbors = num_neighbors
355
+ num_neighbors_values = num_neighbors.values
356
+ if isinstance(num_neighbors_values,
357
+ dict) and self.sample_direction == 'backward':
358
+ # reverse the edge_types if sample_direction is backward
359
+ self._input_num_neighbors = num_neighbors
360
+ num_neighbors_values = remap_keys(num_neighbors_values,
361
+ self.to_backward_edge_type)
362
+ self._num_neighbors = NumNeighbors(num_neighbors_values)
363
+ else:
364
+ self._num_neighbors = num_neighbors
294
365
  else:
295
- self._num_neighbors = NumNeighbors(num_neighbors)
366
+ if isinstance(num_neighbors,
367
+ dict) and self.sample_direction == 'backward':
368
+ # intentionally recursing here to make sure num_neighbors is
369
+ # set as expected for the user
370
+ self.num_neighbors = NumNeighbors(
371
+ remap_keys(num_neighbors, self.to_backward_edge_type))
372
+ else:
373
+ self._num_neighbors = NumNeighbors(num_neighbors)
296
374
 
297
375
  @property
298
376
  def is_hetero(self) -> bool:
@@ -434,17 +512,34 @@ class NeighborSampler(BaseSampler):
434
512
  raise ImportError(f"'{self.__class__.__name__}' requires "
435
513
  f"either 'pyg-lib' or 'torch-sparse'")
436
514
 
515
+ if self.sample_direction == 'backward':
516
+ row, col = col, row
517
+
518
+ row = remap_keys(row, self.to_edge_type)
519
+ col = remap_keys(col, self.to_edge_type)
520
+ edge = remap_keys(edge, self.to_edge_type)
521
+
522
+ # In the case of backward sampling, we need to restore the edges
523
+ # keys to be forward facing in the HeteroSamplerOutput object.
524
+ if self.sample_direction == 'backward':
525
+ row = remap_keys(row, self.to_restored_edge_type)
526
+ col = remap_keys(col, self.to_restored_edge_type)
527
+ edge = remap_keys(edge, self.to_restored_edge_type)
528
+
437
529
  if num_sampled_edges is not None:
438
530
  num_sampled_edges = remap_keys(
439
531
  num_sampled_edges,
440
532
  self.to_edge_type,
441
533
  )
534
+ if self.sample_direction == 'backward':
535
+ num_sampled_edges = remap_keys(num_sampled_edges,
536
+ self.to_restored_edge_type)
442
537
 
443
538
  return HeteroSamplerOutput(
444
539
  node=node,
445
- row=remap_keys(row, self.to_edge_type),
446
- col=remap_keys(col, self.to_edge_type),
447
- edge=remap_keys(edge, self.to_edge_type),
540
+ row=row,
541
+ col=col,
542
+ edge=edge,
448
543
  batch=batch,
449
544
  num_sampled_nodes=num_sampled_nodes,
450
545
  num_sampled_edges=num_sampled_edges,
@@ -511,6 +606,9 @@ class NeighborSampler(BaseSampler):
511
606
  raise ImportError(f"'{self.__class__.__name__}' requires "
512
607
  f"either 'pyg-lib' or 'torch-sparse'")
513
608
 
609
+ if self.sample_direction == 'backward':
610
+ row, col = col, row
611
+
514
612
  return SamplerOutput(
515
613
  node=node,
516
614
  row=row,
@@ -522,6 +620,178 @@ class NeighborSampler(BaseSampler):
522
620
  )
523
621
 
524
622
 
623
class BidirectionalNeighborSampler(NeighborSampler):
    """A sampler that allows for both upstream and downstream sampling.

    Two :class:`NeighborSampler` instances are built over the same data, one
    with ``sample_direction='forward'`` and one with
    ``sample_direction='backward'``. Sampling proceeds hop by hop: each hop
    is sampled with both samplers, the two results are merged, and the nodes
    not seen so far become the seeds of the next hop.

    Edge-type keyed ``num_neighbors``, heterogeneous graphs, temporal
    sampling and negative sampling are rejected with an error.

    Args:
        data: The graph (or feature/graph store pair) to sample from.
        num_neighbors: Number of neighbors per hop; must not be keyed by
            edge type.
        subgraph_type, replace, disjoint, temporal_strategy, time_attr,
        weight_attr, is_sorted, share_memory, directed: Forwarded unchanged
            to both wrapped :class:`NeighborSampler` instances.

    Raises:
        RuntimeError: If ``num_neighbors`` is a dict, or a ``NumNeighbors``
            whose values are a dict.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        num_neighbors: NumNeighborsType,
        subgraph_type: Union[SubgraphType, str] = 'directional',
        replace: bool = False,
        disjoint: bool = False,
        temporal_strategy: str = 'uniform',
        time_attr: Optional[str] = None,
        weight_attr: Optional[str] = None,
        is_sorted: bool = False,
        share_memory: bool = False,
        # Deprecated:
        directed: bool = True,
    ):
        # TODO(zaristei)
        keyed_by_edge_type = isinstance(num_neighbors, dict) or (
            isinstance(num_neighbors, NumNeighbors)
            and isinstance(num_neighbors.values, dict))
        if keyed_by_edge_type:
            raise RuntimeError(
                "BidirectionalNeighborSampler does not yet support edge "
                "delimited sampling.")

        # NOTE: the parent constructor is intentionally not invoked here; all
        # sampling state lives in the two wrapped samplers below.
        self.forward_sampler = NeighborSampler(
            data, num_neighbors, subgraph_type, replace, disjoint,
            temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,
            sample_direction='forward', directed=directed)
        self.backward_sampler = NeighborSampler(
            data, num_neighbors, subgraph_type, replace, disjoint,
            temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,
            sample_direction='backward', directed=directed)

        # Trigger warnings on init if number of hops is greater than 1
        self.num_neighbors = num_neighbors
        self.subgraph_type = subgraph_type

    @property
    def num_neighbors(self) -> NumNeighbors:
        """The per-hop neighbor configuration of this sampler."""
        return self._num_neighbors

    @num_neighbors.setter
    def num_neighbors(self, num_neighbors: NumNeighborsType):
        """Store ``num_neighbors`` (wrapped in ``NumNeighbors`` if needed).

        Emits a warning when more than one hop is configured.
        """
        import warnings  # local import keeps this block self-contained

        if not isinstance(num_neighbors, NumNeighbors):
            num_neighbors = NumNeighbors(num_neighbors)
        if num_neighbors.num_hops > 1:
            # Use the warnings machinery instead of `print` so callers can
            # filter or capture the message.
            warnings.warn(
                "Number of hops is greater than 1, resulting in "
                "memory-expensive recursive calls.", stacklevel=2)
        self._num_neighbors = num_neighbors

    @property
    def is_hetero(self) -> bool:
        """Whether the wrapped forward sampler is heterogeneous."""
        return self.forward_sampler.is_hetero

    @property
    def is_temporal(self) -> bool:
        """Whether the wrapped forward sampler is temporal."""
        return self.forward_sampler.is_temporal

    @property
    def disjoint(self) -> bool:
        """The ``disjoint`` flag of the wrapped forward sampler."""
        return self.forward_sampler.disjoint

    @disjoint.setter
    def disjoint(self, disjoint: bool):
        """Set ``disjoint`` on both wrapped samplers."""
        self.forward_sampler.disjoint = disjoint
        self.backward_sampler.disjoint = disjoint

    def sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Delegate node-seeded sampling to the parent implementation."""
        return super().sample_from_nodes(inputs)

    def sample_from_edges(
        self,
        inputs: EdgeSamplerInput,
        neg_sampling: Optional[NegativeSampling] = None,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Delegate edge-seeded sampling to the parent implementation.

        Raises:
            RuntimeError: If ``neg_sampling`` is given.
        """
        # TODO(zaristei) Figure out what exactly regular and negative sampling
        # imply for bidirectional sampling case
        if neg_sampling is not None:
            raise RuntimeError(
                "BidirectionalNeighborSampler does not yet support "
                "negative sampling.")
        # Not thoroughly tested yet!
        return super().sample_from_edges(inputs)

    @property
    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
        """The edge permutation of the wrapped forward sampler."""
        return self.forward_sampler.edge_permutation

    def _sample(
        self,
        seed: Union[Tensor, Dict[NodeType, Tensor]],
        seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,
        **kwargs,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample hop by hop in both directions and collate the results.

        Args:
            seed: The seed nodes of the first hop.
            seed_time: Must be ``None``.
            **kwargs: Forwarded to the ``_sample`` call of both wrapped
                samplers.

        Raises:
            NotImplementedError: If ``seed_time`` is given or the sampler is
                heterogeneous.
        """
        if seed_time is not None:
            raise NotImplementedError(
                "BidirectionalNeighborSampler does not yet support "
                "temporal sampling.")

        if self.is_hetero:
            raise NotImplementedError(
                "BidirectionalNeighborSampler does not yet support "
                "heterogeneous sampling.")

        current_seed = seed
        current_seed_batch = None
        current_seed_time = seed_time
        if self.disjoint:
            # In disjoint mode a node is identified by (node, batch).
            current_seed_batch = torch.arange(len(current_seed))
            seen_seed_set = {
                (int(node), int(batch))
                for node, batch in zip(current_seed, current_seed_batch)
            }
        else:
            seen_seed_set = {int(node) for node in current_seed}

        iter_results = []

        for n_neighbors in self.num_neighbors.values:
            # Each wrapped sampler performs exactly one hop per iteration.
            current_n_neighbors = [n_neighbors]
            self.forward_sampler.num_neighbors = current_n_neighbors
            self.backward_sampler.num_neighbors = current_n_neighbors

            fwd_result = self.forward_sampler._sample(
                current_seed, current_seed_time, **kwargs)
            bwd_result = self.backward_sampler._sample(
                current_seed, current_seed_time, **kwargs)
            # The seeds for the next iteration will be the new nodes in
            # this iteration
            iter_result = fwd_result.merge_with(bwd_result)
            iter_results.append(iter_result)

            # Find the nodes not yet seen to set a seed for next iteration
            if self.disjoint:
                iter_seed_global_batch = global_to_local_node_idx(
                    current_seed_batch, iter_result.batch)
                iter_result.seed_node = seed[iter_seed_global_batch]

                keep_mask = torch.tensor(
                    [(int(node), int(batch)) not in seen_seed_set
                     for node, batch in zip(iter_result.node,
                                            iter_seed_global_batch)],
                    dtype=torch.bool,
                )
                next_seed = [
                    (int(node), int(batch)) for node, batch in zip(
                        iter_result.node[keep_mask],
                        iter_seed_global_batch[keep_mask])
                ] if keep_mask.any() else []
                # Explicit integer dtype: an empty list would otherwise
                # produce a float tensor.
                current_seed, current_seed_batch = torch.tensor(
                    next_seed, dtype=torch.long).reshape(-1, 2).transpose(
                        0, 1).contiguous()
            else:
                keep_mask = torch.tensor(
                    [int(node) not in seen_seed_set
                     for node in iter_result.node],
                    dtype=torch.bool,
                )
                next_seed = [
                    int(node) for node in iter_result.node[keep_mask]
                ] if keep_mask.any() else []
                # Explicit integer dtype: an empty list would otherwise
                # produce a float tensor.
                current_seed = torch.tensor(next_seed, dtype=torch.long)

            seen_seed_set |= set(next_seed)

            # TODO(zaristei) figure out how to update seed times for
            # temporal sampling

        return SamplerOutput.collate(iter_results)
793
+
794
+
525
795
  # Sampling Utilities ##########################################################
526
796
 
527
797
 
@@ -9,6 +9,15 @@ from torch_geometric.index import index2ptr
9
9
  from torch_geometric.typing import EdgeType, NodeType, OptTensor
10
10
  from torch_geometric.utils import coalesce, index_sort, lexsort
11
11
 
12
+
13
def reverse_edge_type(edge_type: EdgeType) -> EdgeType:
    """Swap the first and third entries of an edge type.

    The middle entry is kept in place. ``None`` is passed through unchanged.
    Useful in cases of backward sampling.
    """
    if edge_type is None:
        return None

    first, middle, last = edge_type[0], edge_type[1], edge_type[2]
    return (last, middle, first)
19
+
20
+
12
21
  # Edge Layout Conversion ######################################################
13
22
 
14
23
 
@@ -41,6 +50,7 @@ def to_csc(
41
50
  is_sorted: bool = False,
42
51
  src_node_time: Optional[Tensor] = None,
43
52
  edge_time: Optional[Tensor] = None,
53
+ to_transpose: bool = False,
44
54
  ) -> Tuple[Tensor, Tensor, OptTensor]:
45
55
  # Convert the graph data into a suitable format for sampling (CSC format).
46
56
  # Returns the `colptr` and `row` indices of the graph, as well as an
@@ -53,7 +63,10 @@ def to_csc(
53
63
  if src_node_time is not None:
54
64
  raise NotImplementedError("Temporal sampling via 'SparseTensor' "
55
65
  "format not yet supported")
56
- colptr, row, _ = data.adj.csc()
66
+ if to_transpose:
67
+ row, colptr, _ = data.adj.csr()
68
+ else:
69
+ colptr, row, _ = data.adj.csc()
57
70
 
58
71
  elif hasattr(data, 'adj_t'):
59
72
  if src_node_time is not None:
@@ -65,13 +78,21 @@ def to_csc(
65
78
  # raise NotImplementedError("Temporal sampling via 'SparseTensor' "
66
79
  # "format not yet supported")
67
80
  pass
68
- colptr, row, _ = data.adj_t.csr()
81
+ if to_transpose:
82
+ row, colptr, _ = data.adj_t.csc()
83
+ else:
84
+ colptr, row, _ = data.adj_t.csr()
69
85
 
70
86
  elif data.edge_index is not None:
71
- row, col = data.edge_index
87
+ if to_transpose:
88
+ col, row = data.edge_index
89
+ else:
90
+ row, col = data.edge_index
91
+
72
92
  if not is_sorted:
73
93
  row, col, perm = sort_csc(row, col, src_node_time, edge_time)
74
- colptr = index2ptr(col, data.size(1))
94
+ colptr = index2ptr(col,
95
+ data.size(1) if not to_transpose else data.size(0))
75
96
  else:
76
97
  row = torch.empty(0, dtype=torch.long, device=device)
77
98
  colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long,
@@ -97,6 +118,7 @@ def to_hetero_csc(
97
118
  is_sorted: bool = False,
98
119
  node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
99
120
  edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None,
121
+ to_transpose: bool = False,
100
122
  ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
101
123
  # Convert the heterogeneous graph data into a suitable format for sampling
102
124
  # (CSC format).
@@ -108,7 +130,11 @@ def to_hetero_csc(
108
130
  src_node_time = (node_time_dict or {}).get(edge_type[0], None)
109
131
  edge_time = (edge_time_dict or {}).get(edge_type, None)
110
132
  out = to_csc(store, device, share_memory, is_sorted, src_node_time,
111
- edge_time)
133
+ edge_time, to_transpose)
134
+ # Edge types need to be reversed for backward sampling:
135
+ if to_transpose:
136
+ edge_type = reverse_edge_type(edge_type)
137
+
112
138
  colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out
113
139
 
114
140
  return colptr_dict, row_dict, perm_dict
@@ -205,3 +231,20 @@ def global_to_local_node_idx(node_values: Tensor,
205
231
  sort_idx = torch.argsort(idx_match[:, 1])
206
232
 
207
233
  return idx_match[:, 0][sort_idx]
234
+
235
+
236
def unique_unsorted(tensor: Tensor) -> Tensor:
    """Returns the unique elements of a tensor while preserving the original
    order.

    Uniqueness is determined along the first dimension: each ``tensor[i]``
    is one element, and only its first occurrence is kept. Inputs of any
    dimensionality ``>= 1`` are supported.

    Necessary because torch.unique() ignores sort parameter.

    Args:
        tensor: The tensor to de-duplicate along dimension ``0``.

    Returns:
        A tensor with the same dtype, device and trailing shape as
        ``tensor`` that contains the first occurrence of every distinct
        entry, in order of appearance.
    """
    seen = set()
    first_positions = []
    for position, entry in enumerate(tensor):
        # Flatten before converting so that scalar entries (1-D input) and
        # nested entries (>2-D input) both yield a hashable key.
        key = tuple(entry.reshape(-1).tolist())
        if key not in seen:
            seen.add(key)
            first_positions.append(position)

    # Gathering by index keeps dtype, device and trailing shape intact, also
    # for an empty selection.
    index = torch.tensor(first_positions, dtype=torch.long,
                         device=tensor.device)
    return tensor[index]