pyg-nightly 2.7.0.dev20250826__py3-none-any.whl → 2.7.0.dev20250828__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250826.dist-info → pyg_nightly-2.7.0.dev20250828.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250826.dist-info → pyg_nightly-2.7.0.dev20250828.dist-info}/RECORD +13 -12
- torch_geometric/__init__.py +3 -2
- torch_geometric/_onnx.py +214 -0
- torch_geometric/loader/link_neighbor_loader.py +1 -0
- torch_geometric/nn/models/__init__.py +3 -0
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +258 -3
- torch_geometric/sampler/neighbor_sampler.py +283 -13
- torch_geometric/sampler/utils.py +48 -5
- {pyg_nightly-2.7.0.dev20250826.dist-info → pyg_nightly-2.7.0.dev20250828.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250826.dist-info → pyg_nightly-2.7.0.dev20250828.dist-info}/licenses/LICENSE +0 -0
@@ -25,7 +25,13 @@ from torch_geometric.sampler import (
|
|
25
25
|
SamplerOutput,
|
26
26
|
)
|
27
27
|
from torch_geometric.sampler.base import DataType, NumNeighbors, SubgraphType
|
28
|
-
from torch_geometric.sampler.utils import remap_keys, to_csc, to_hetero_csc
|
28
|
+
from torch_geometric.sampler.utils import (
|
29
|
+
global_to_local_node_idx,
|
30
|
+
remap_keys,
|
31
|
+
reverse_edge_type,
|
32
|
+
to_csc,
|
33
|
+
to_hetero_csc,
|
34
|
+
)
|
29
35
|
from torch_geometric.typing import EdgeType, NodeType, OptTensor
|
30
36
|
|
31
37
|
NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]
|
@@ -47,8 +53,8 @@ class NeighborSampler(BaseSampler):
|
|
47
53
|
weight_attr: Optional[str] = None,
|
48
54
|
is_sorted: bool = False,
|
49
55
|
share_memory: bool = False,
|
50
|
-
# Deprecated
|
51
|
-
|
56
|
+
directed: bool = True, # Deprecated
|
57
|
+
sample_direction: Literal['forward', 'backward'] = 'forward',
|
52
58
|
):
|
53
59
|
if not directed:
|
54
60
|
subgraph_type = SubgraphType.induced
|
@@ -66,6 +72,14 @@ class NeighborSampler(BaseSampler):
|
|
66
72
|
f"accelerated neighborhood sampling", stacklevel=2)
|
67
73
|
|
68
74
|
self.data_type = DataType.from_data(data)
|
75
|
+
self.sample_direction = sample_direction
|
76
|
+
|
77
|
+
if self.sample_direction == 'backward':
|
78
|
+
# TODO(zaristei)
|
79
|
+
if time_attr is not None:
|
80
|
+
raise NotImplementedError(
|
81
|
+
"Temporal Sampling not yet supported for backward sampling"
|
82
|
+
)
|
69
83
|
|
70
84
|
if self.data_type == DataType.homogeneous:
|
71
85
|
self.num_nodes = data.num_nodes
|
@@ -87,7 +101,8 @@ class NeighborSampler(BaseSampler):
|
|
87
101
|
self.colptr, self.row, self.perm = to_csc(
|
88
102
|
data, device='cpu', share_memory=share_memory,
|
89
103
|
is_sorted=is_sorted, src_node_time=self.node_time,
|
90
|
-
edge_time=self.edge_time)
|
104
|
+
edge_time=self.edge_time,
|
105
|
+
to_transpose=self.sample_direction == 'backward')
|
91
106
|
|
92
107
|
if self.edge_time is not None and self.perm is not None:
|
93
108
|
self.edge_time = self.edge_time[self.perm]
|
@@ -101,6 +116,17 @@ class NeighborSampler(BaseSampler):
|
|
101
116
|
elif self.data_type == DataType.heterogeneous:
|
102
117
|
self.node_types, self.edge_types = data.metadata()
|
103
118
|
|
119
|
+
# reverse edge types if sample_direction is backward
|
120
|
+
if self.sample_direction == 'backward':
|
121
|
+
self.edge_types = [
|
122
|
+
reverse_edge_type(edge_type)
|
123
|
+
for edge_type in self.edge_types
|
124
|
+
]
|
125
|
+
self.to_restored_edge_type = {
|
126
|
+
k: reverse_edge_type(k)
|
127
|
+
for k in self.edge_types
|
128
|
+
}
|
129
|
+
|
104
130
|
self.num_nodes = {k: data[k].num_nodes for k in self.node_types}
|
105
131
|
|
106
132
|
self.node_time: Optional[Dict[NodeType, Tensor]] = None
|
@@ -141,7 +167,8 @@ class NeighborSampler(BaseSampler):
|
|
141
167
|
colptr_dict, row_dict, self.perm = to_hetero_csc(
|
142
168
|
data, device='cpu', share_memory=share_memory,
|
143
169
|
is_sorted=is_sorted, node_time_dict=self.node_time,
|
144
|
-
edge_time_dict=self.edge_time)
|
170
|
+
edge_time_dict=self.edge_time,
|
171
|
+
to_transpose=self.sample_direction == 'backward')
|
145
172
|
|
146
173
|
self.row_dict = remap_keys(row_dict, self.to_rel_type)
|
147
174
|
self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
|
@@ -172,6 +199,21 @@ class NeighborSampler(BaseSampler):
|
|
172
199
|
edge_attrs = graph_store.get_all_edge_attrs()
|
173
200
|
self.edge_types = list({attr.edge_type for attr in edge_attrs})
|
174
201
|
|
202
|
+
# reverse edge types if sample_direction is backward
|
203
|
+
if self.sample_direction == 'backward':
|
204
|
+
self.edge_types = [
|
205
|
+
reverse_edge_type(edge_type)
|
206
|
+
for edge_type in self.edge_types
|
207
|
+
]
|
208
|
+
self.to_restored_edge_type = {
|
209
|
+
k: reverse_edge_type(k)
|
210
|
+
for k in self.edge_types
|
211
|
+
}
|
212
|
+
self.to_backward_edge_type = {
|
213
|
+
v: k
|
214
|
+
for k, v in self.to_restored_edge_type.items()
|
215
|
+
}
|
216
|
+
|
175
217
|
if weight_attr is not None:
|
176
218
|
raise NotImplementedError(
|
177
219
|
f"'weight_attr' argument not yet supported within "
|
@@ -221,7 +263,10 @@ class NeighborSampler(BaseSampler):
|
|
221
263
|
else:
|
222
264
|
self.edge_time = time_tensor
|
223
265
|
|
224
|
-
|
266
|
+
if self.sample_direction == 'forward':
|
267
|
+
self.row, self.colptr, self.perm = graph_store.csc()
|
268
|
+
elif self.sample_direction == 'backward':
|
269
|
+
self.colptr, self.row, self.perm = graph_store.csr()
|
225
270
|
|
226
271
|
else:
|
227
272
|
node_types = [
|
@@ -261,8 +306,17 @@ class NeighborSampler(BaseSampler):
|
|
261
306
|
# Conversion to/from C++ string type (see above):
|
262
307
|
self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
|
263
308
|
self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}
|
264
|
-
|
265
|
-
|
309
|
+
if self.sample_direction == 'forward':
|
310
|
+
row_dict, colptr_dict, self.perm = graph_store.csc()
|
311
|
+
elif self.sample_direction == 'backward':
|
312
|
+
colptr_dict, row_dict, self.perm = graph_store.csr()
|
313
|
+
|
314
|
+
colptr_dict = remap_keys(colptr_dict,
|
315
|
+
self.to_backward_edge_type)
|
316
|
+
row_dict = remap_keys(row_dict, self.to_backward_edge_type)
|
317
|
+
self.perm = remap_keys(self.perm,
|
318
|
+
self.to_backward_edge_type)
|
319
|
+
|
266
320
|
self.row_dict = remap_keys(row_dict, self.to_rel_type)
|
267
321
|
self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
|
268
322
|
|
@@ -285,14 +339,38 @@ class NeighborSampler(BaseSampler):
|
|
285
339
|
|
286
340
|
@property
|
287
341
|
def num_neighbors(self) -> NumNeighbors:
|
342
|
+
if self.sample_direction == 'backward':
|
343
|
+
return self._input_num_neighbors \
|
344
|
+
if self._input_num_neighbors is not None \
|
345
|
+
else self._num_neighbors
|
288
346
|
return self._num_neighbors
|
289
347
|
|
290
348
|
@num_neighbors.setter
|
291
349
|
def num_neighbors(self, num_neighbors: NumNeighborsType):
|
350
|
+
# only used if sample direction is backward and num_neighbors has edge
|
351
|
+
# keys
|
352
|
+
self._input_num_neighbors = None
|
353
|
+
|
292
354
|
if isinstance(num_neighbors, NumNeighbors):
|
293
|
-
|
355
|
+
num_neighbors_values = num_neighbors.values
|
356
|
+
if isinstance(num_neighbors_values,
|
357
|
+
dict) and self.sample_direction == 'backward':
|
358
|
+
# reverse the edge_types if sample_direction is backward
|
359
|
+
self._input_num_neighbors = num_neighbors
|
360
|
+
num_neighbors_values = remap_keys(num_neighbors_values,
|
361
|
+
self.to_backward_edge_type)
|
362
|
+
self._num_neighbors = NumNeighbors(num_neighbors_values)
|
363
|
+
else:
|
364
|
+
self._num_neighbors = num_neighbors
|
294
365
|
else:
|
295
|
-
|
366
|
+
if isinstance(num_neighbors,
|
367
|
+
dict) and self.sample_direction == 'backward':
|
368
|
+
# intentionally recursing here to make sure num_neighbors is
|
369
|
+
# set as expected for the user
|
370
|
+
self.num_neighbors = NumNeighbors(
|
371
|
+
remap_keys(num_neighbors, self.to_backward_edge_type))
|
372
|
+
else:
|
373
|
+
self._num_neighbors = NumNeighbors(num_neighbors)
|
296
374
|
|
297
375
|
@property
|
298
376
|
def is_hetero(self) -> bool:
|
@@ -434,17 +512,34 @@ class NeighborSampler(BaseSampler):
|
|
434
512
|
raise ImportError(f"'{self.__class__.__name__}' requires "
|
435
513
|
f"either 'pyg-lib' or 'torch-sparse'")
|
436
514
|
|
515
|
+
if self.sample_direction == 'backward':
|
516
|
+
row, col = col, row
|
517
|
+
|
518
|
+
row = remap_keys(row, self.to_edge_type)
|
519
|
+
col = remap_keys(col, self.to_edge_type)
|
520
|
+
edge = remap_keys(edge, self.to_edge_type)
|
521
|
+
|
522
|
+
# In the case of backward sampling, we need to restore the edges
|
523
|
+
# keys to be forward facing in the HeteroSamplerOutput object.
|
524
|
+
if self.sample_direction == 'backward':
|
525
|
+
row = remap_keys(row, self.to_restored_edge_type)
|
526
|
+
col = remap_keys(col, self.to_restored_edge_type)
|
527
|
+
edge = remap_keys(edge, self.to_restored_edge_type)
|
528
|
+
|
437
529
|
if num_sampled_edges is not None:
|
438
530
|
num_sampled_edges = remap_keys(
|
439
531
|
num_sampled_edges,
|
440
532
|
self.to_edge_type,
|
441
533
|
)
|
534
|
+
if self.sample_direction == 'backward':
|
535
|
+
num_sampled_edges = remap_keys(num_sampled_edges,
|
536
|
+
self.to_restored_edge_type)
|
442
537
|
|
443
538
|
return HeteroSamplerOutput(
|
444
539
|
node=node,
|
445
|
-
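row=remap_keys(row, self.to_edge_type),
|
446
|
-
col=remap_keys(col, self.to_edge_type),
|
447
|
-
edge=remap_keys(edge, self.to_edge_type),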
|
540
|
+
row=row,
|
541
|
+
col=col,
|
542
|
+
edge=edge,
|
448
543
|
batch=batch,
|
449
544
|
num_sampled_nodes=num_sampled_nodes,
|
450
545
|
num_sampled_edges=num_sampled_edges,
|
@@ -511,6 +606,9 @@ class NeighborSampler(BaseSampler):
|
|
511
606
|
raise ImportError(f"'{self.__class__.__name__}' requires "
|
512
607
|
f"either 'pyg-lib' or 'torch-sparse'")
|
513
608
|
|
609
|
+
if self.sample_direction == 'backward':
|
610
|
+
row, col = col, row
|
611
|
+
|
514
612
|
return SamplerOutput(
|
515
613
|
node=node,
|
516
614
|
row=row,
|
@@ -522,6 +620,178 @@ class NeighborSampler(BaseSampler):
|
|
522
620
|
)
|
523
621
|
|
524
622
|
|
623
|
+
class BidirectionalNeighborSampler(NeighborSampler):
|
624
|
+
"""A sampler that allows for both upstream and downstream sampling."""
|
625
|
+
def __init__(
|
626
|
+
self,
|
627
|
+
data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
|
628
|
+
num_neighbors: NumNeighborsType,
|
629
|
+
subgraph_type: Union[SubgraphType, str] = 'directional',
|
630
|
+
replace: bool = False,
|
631
|
+
disjoint: bool = False,
|
632
|
+
temporal_strategy: str = 'uniform',
|
633
|
+
time_attr: Optional[str] = None,
|
634
|
+
weight_attr: Optional[str] = None,
|
635
|
+
is_sorted: bool = False,
|
636
|
+
share_memory: bool = False,
|
637
|
+
# Deprecated:
|
638
|
+
directed: bool = True,
|
639
|
+
):
|
640
|
+
|
641
|
+
# TODO(zaristei)
|
642
|
+
if isinstance(num_neighbors, NumNeighbors) and isinstance(
|
643
|
+
num_neighbors.values, dict) or isinstance(num_neighbors, dict):
|
644
|
+
raise RuntimeError(
|
645
|
+
"BidirectionalNeighborSampler does not yet support edge "
|
646
|
+
"delimited sampling.")
|
647
|
+
|
648
|
+
self.forward_sampler = NeighborSampler(
|
649
|
+
data, num_neighbors, subgraph_type, replace, disjoint,
|
650
|
+
temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,
|
651
|
+
sample_direction='forward', directed=directed)
|
652
|
+
self.backward_sampler = NeighborSampler(
|
653
|
+
data, num_neighbors, subgraph_type, replace, disjoint,
|
654
|
+
temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,
|
655
|
+
sample_direction='backward', directed=directed)
|
656
|
+
|
657
|
+
# Trigger warnings on init if number of hops is greater than 1
|
658
|
+
self.num_neighbors = num_neighbors
|
659
|
+
self.subgraph_type = subgraph_type
|
660
|
+
|
661
|
+
@property
|
662
|
+
def num_neighbors(self) -> NumNeighbors:
|
663
|
+
return self._num_neighbors
|
664
|
+
|
665
|
+
@num_neighbors.setter
|
666
|
+
def num_neighbors(self, num_neighbors: NumNeighborsType):
|
667
|
+
if not isinstance(num_neighbors, NumNeighbors):
|
668
|
+
num_neighbors = NumNeighbors(num_neighbors)
|
669
|
+
if num_neighbors.num_hops > 1:
|
670
|
+
print("Warning: Number of hops is greater than 1, resulting in "
|
671
|
+
"memory-expensive recursive calls.")
|
672
|
+
self._num_neighbors = num_neighbors
|
673
|
+
|
674
|
+
@property
|
675
|
+
def is_hetero(self) -> bool:
|
676
|
+
return self.forward_sampler.is_hetero
|
677
|
+
|
678
|
+
@property
|
679
|
+
def is_temporal(self) -> bool:
|
680
|
+
return self.forward_sampler.is_temporal
|
681
|
+
|
682
|
+
@property
|
683
|
+
def disjoint(self) -> bool:
|
684
|
+
return self.forward_sampler.disjoint
|
685
|
+
|
686
|
+
@disjoint.setter
|
687
|
+
def disjoint(self, disjoint: bool):
|
688
|
+
self.forward_sampler.disjoint = disjoint
|
689
|
+
self.backward_sampler.disjoint = disjoint
|
690
|
+
|
691
|
+
def sample_from_nodes(
|
692
|
+
self,
|
693
|
+
inputs: NodeSamplerInput,
|
694
|
+
) -> Union[SamplerOutput, HeteroSamplerOutput]:
|
695
|
+
return super().sample_from_nodes(inputs)
|
696
|
+
|
697
|
+
def sample_from_edges(
|
698
|
+
self,
|
699
|
+
inputs: EdgeSamplerInput,
|
700
|
+
neg_sampling: Optional[NegativeSampling] = None,
|
701
|
+
) -> Union[SamplerOutput, HeteroSamplerOutput]:
|
702
|
+
# TODO(zaristei) Figure out what exactly regular and negative sampling
|
703
|
+
# imply for bidirectional sampling case
|
704
|
+
if neg_sampling is not None:
|
705
|
+
raise RuntimeError(
|
706
|
+
"BidirectionalNeighborSampler does not yet support "
|
707
|
+
"negative sampling.")
|
708
|
+
# Not thoroughly tested yet!
|
709
|
+
return super().sample_from_edges(inputs)
|
710
|
+
|
711
|
+
@property
|
712
|
+
def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
|
713
|
+
return self.forward_sampler.edge_permutation
|
714
|
+
|
715
|
+
def _sample(
|
716
|
+
self,
|
717
|
+
seed: Union[Tensor, Dict[NodeType, Tensor]],
|
718
|
+
seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,
|
719
|
+
**kwargs,
|
720
|
+
) -> Union[SamplerOutput, HeteroSamplerOutput]:
|
721
|
+
|
722
|
+
if seed_time is not None:
|
723
|
+
raise NotImplementedError(
|
724
|
+
"BidirectionalNeighborSampler does not yet support "
|
725
|
+
"temporal sampling.")
|
726
|
+
|
727
|
+
if self.is_hetero:
|
728
|
+
raise NotImplementedError(
|
729
|
+
"BidirectionalNeighborSampler does not yet support "
|
730
|
+
"heterogeneous sampling.")
|
731
|
+
else:
|
732
|
+
current_seed = seed
|
733
|
+
current_seed_batch = None
|
734
|
+
current_seed_time = seed_time
|
735
|
+
seen_seed_set = {int(node) for node in current_seed}
|
736
|
+
if self.disjoint:
|
737
|
+
current_seed_batch = torch.arange(len(current_seed))
|
738
|
+
seen_seed_set = {
|
739
|
+
(int(node), int(batch))
|
740
|
+
for node, batch in zip(current_seed, current_seed_batch)
|
741
|
+
}
|
742
|
+
|
743
|
+
iter_results = []
|
744
|
+
|
745
|
+
for n_neighbors in self.num_neighbors.values:
|
746
|
+
current_n_neighbors = [n_neighbors]
|
747
|
+
self.forward_sampler.num_neighbors = current_n_neighbors
|
748
|
+
self.backward_sampler.num_neighbors = current_n_neighbors
|
749
|
+
|
750
|
+
fwd_result = self.forward_sampler._sample(
|
751
|
+
current_seed, current_seed_time, **kwargs)
|
752
|
+
bwd_result = self.backward_sampler._sample(
|
753
|
+
current_seed, current_seed_time, **kwargs)
|
754
|
+
# The seeds for the next iteration will be the new nodes in
|
755
|
+
# this iteration
|
756
|
+
iter_result = fwd_result.merge_with(bwd_result)
|
757
|
+
iter_results.append(iter_result)
|
758
|
+
|
759
|
+
# Find the nodes not yet seen to set a seed for next iteration
|
760
|
+
if self.disjoint:
|
761
|
+
iter_seed_global_batch = global_to_local_node_idx(
|
762
|
+
current_seed_batch, iter_result.batch)
|
763
|
+
iter_result.seed_node = seed[iter_seed_global_batch]
|
764
|
+
|
765
|
+
keep_mask = torch.tensor([
|
766
|
+
(int(node), int(batch)) not in seen_seed_set
|
767
|
+
for node, batch in zip(iter_result.node,
|
768
|
+
iter_seed_global_batch)
|
769
|
+
])
|
770
|
+
next_seed = [(int(node), int(batch))
|
771
|
+
for node, batch in zip(
|
772
|
+
iter_result.node[keep_mask],
|
773
|
+
iter_seed_global_batch[keep_mask])
|
774
|
+
] if keep_mask.any() else []
|
775
|
+
current_seed, current_seed_batch = torch.tensor(
|
776
|
+
next_seed).reshape(-1, 2).transpose(0, 1).contiguous()
|
777
|
+
else:
|
778
|
+
keep_mask = torch.tensor([
|
779
|
+
int(node) not in seen_seed_set
|
780
|
+
for node in iter_result.node
|
781
|
+
])
|
782
|
+
next_seed = [
|
783
|
+
int(node) for node in iter_result.node[keep_mask]
|
784
|
+
] if keep_mask.any() else []
|
785
|
+
current_seed = torch.tensor(next_seed)
|
786
|
+
|
787
|
+
seen_seed_set |= set(next_seed)
|
788
|
+
|
789
|
+
# TODO(zaristei) figure out how to update seed times for
|
790
|
+
# temporal sampling
|
791
|
+
|
792
|
+
return SamplerOutput.collate(iter_results)
|
793
|
+
|
794
|
+
|
525
795
|
# Sampling Utilities ##########################################################
|
526
796
|
|
527
797
|
|
torch_geometric/sampler/utils.py
CHANGED
@@ -9,6 +9,15 @@ from torch_geometric.index import index2ptr
|
|
9
9
|
from torch_geometric.typing import EdgeType, NodeType, OptTensor
|
10
10
|
from torch_geometric.utils import coalesce, index_sort, lexsort
|
11
11
|
|
12
|
+
|
13
|
+
def reverse_edge_type(edge_type: EdgeType) -> EdgeType:
|
14
|
+
"""Reverses edge types for heterogeneous graphs. Useful in cases of
|
15
|
+
backward sampling.
|
16
|
+
"""
|
17
|
+
return (edge_type[2], edge_type[1],
|
18
|
+
edge_type[0]) if edge_type is not None else None
|
19
|
+
|
20
|
+
|
12
21
|
# Edge Layout Conversion ######################################################
|
13
22
|
|
14
23
|
|
@@ -41,6 +50,7 @@ def to_csc(
|
|
41
50
|
is_sorted: bool = False,
|
42
51
|
src_node_time: Optional[Tensor] = None,
|
43
52
|
edge_time: Optional[Tensor] = None,
|
53
|
+
to_transpose: bool = False,
|
44
54
|
) -> Tuple[Tensor, Tensor, OptTensor]:
|
45
55
|
# Convert the graph data into a suitable format for sampling (CSC format).
|
46
56
|
# Returns the `colptr` and `row` indices of the graph, as well as an
|
@@ -53,7 +63,10 @@ def to_csc(
|
|
53
63
|
if src_node_time is not None:
|
54
64
|
raise NotImplementedError("Temporal sampling via 'SparseTensor' "
|
55
65
|
"format not yet supported")
|
56
|
-
|
66
|
+
if to_transpose:
|
67
|
+
row, colptr, _ = data.adj.csr()
|
68
|
+
else:
|
69
|
+
colptr, row, _ = data.adj.csc()
|
57
70
|
|
58
71
|
elif hasattr(data, 'adj_t'):
|
59
72
|
if src_node_time is not None:
|
@@ -65,13 +78,21 @@ def to_csc(
|
|
65
78
|
# raise NotImplementedError("Temporal sampling via 'SparseTensor' "
|
66
79
|
# "format not yet supported")
|
67
80
|
pass
|
68
|
-
|
81
|
+
if to_transpose:
|
82
|
+
row, colptr, _ = data.adj_t.csc()
|
83
|
+
else:
|
84
|
+
colptr, row, _ = data.adj_t.csr()
|
69
85
|
|
70
86
|
elif data.edge_index is not None:
|
71
|
-
|
87
|
+
if to_transpose:
|
88
|
+
col, row = data.edge_index
|
89
|
+
else:
|
90
|
+
row, col = data.edge_index
|
91
|
+
|
72
92
|
if not is_sorted:
|
73
93
|
row, col, perm = sort_csc(row, col, src_node_time, edge_time)
|
74
|
-
colptr = index2ptr(col, data.size(1))
|
94
|
+
colptr = index2ptr(col,
|
95
|
+
data.size(1) if not to_transpose else data.size(0))
|
75
96
|
else:
|
76
97
|
row = torch.empty(0, dtype=torch.long, device=device)
|
77
98
|
colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long,
|
@@ -97,6 +118,7 @@ def to_hetero_csc(
|
|
97
118
|
is_sorted: bool = False,
|
98
119
|
node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
|
99
120
|
edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None,
|
121
|
+
to_transpose: bool = False,
|
100
122
|
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
|
101
123
|
# Convert the heterogeneous graph data into a suitable format for sampling
|
102
124
|
# (CSC format).
|
@@ -108,7 +130,11 @@ def to_hetero_csc(
|
|
108
130
|
src_node_time = (node_time_dict or {}).get(edge_type[0], None)
|
109
131
|
edge_time = (edge_time_dict or {}).get(edge_type, None)
|
110
132
|
out = to_csc(store, device, share_memory, is_sorted, src_node_time,
|
111
|
-
edge_time)
|
133
|
+
edge_time, to_transpose)
|
134
|
+
# Edge types need to be reversed for backward sampling:
|
135
|
+
if to_transpose:
|
136
|
+
edge_type = reverse_edge_type(edge_type)
|
137
|
+
|
112
138
|
colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out
|
113
139
|
|
114
140
|
return colptr_dict, row_dict, perm_dict
|
@@ -205,3 +231,20 @@ def global_to_local_node_idx(node_values: Tensor,
|
|
205
231
|
sort_idx = torch.argsort(idx_match[:, 1])
|
206
232
|
|
207
233
|
return idx_match[:, 0][sort_idx]
|
234
|
+
|
235
|
+
|
236
|
+
def unique_unsorted(tensor: Tensor) -> Tensor:
|
237
|
+
"""Returns the unique elements of a tensor while preserving the original
|
238
|
+
order.
|
239
|
+
|
240
|
+
Necessary because torch.unique() ignores sort parameter.
|
241
|
+
"""
|
242
|
+
seen = set()
|
243
|
+
output = []
|
244
|
+
for val in tensor:
|
245
|
+
val = tuple(val.tolist())
|
246
|
+
if val not in seen:
|
247
|
+
seen.add(val)
|
248
|
+
output.append(val)
|
249
|
+
return torch.tensor(output, dtype=tensor.dtype,
|
250
|
+
device=tensor.device).reshape((-1, *tensor.shape[1:]))
|
File without changes
|
{pyg_nightly-2.7.0.dev20250826.dist-info → pyg_nightly-2.7.0.dev20250828.dist-info}/licenses/LICENSE
RENAMED
File without changes
|