pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyg-nightly has been flagged as possibly problematic by the registry.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/sampler/neighbor_sampler.py
CHANGED
@@ -25,7 +25,13 @@ from torch_geometric.sampler import (
     SamplerOutput,
 )
 from torch_geometric.sampler.base import DataType, NumNeighbors, SubgraphType
-from torch_geometric.sampler.utils import
+from torch_geometric.sampler.utils import (
+    global_to_local_node_idx,
+    remap_keys,
+    reverse_edge_type,
+    to_csc,
+    to_hetero_csc,
+)
 from torch_geometric.typing import EdgeType, NodeType, OptTensor

 NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]
@@ -47,23 +53,33 @@ class NeighborSampler(BaseSampler):
         weight_attr: Optional[str] = None,
         is_sorted: bool = False,
         share_memory: bool = False,
-        # Deprecated
-
+        directed: bool = True,  # Deprecated
+        sample_direction: Literal['forward', 'backward'] = 'forward',
     ):
         if not directed:
             subgraph_type = SubgraphType.induced
-            warnings.warn(
-
-
+            warnings.warn(
+                f"The usage of the 'directed' argument in "
+                f"'{self.__class__.__name__}' is deprecated. Use "
+                f"`subgraph_type='induced'` instead.", stacklevel=2)

         if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux'
                 and subgraph_type != SubgraphType.induced):
-            warnings.warn(
-
-
-
+            warnings.warn(
+                f"Using '{self.__class__.__name__}' without a "
+                f"'pyg-lib' installation is deprecated and will be "
+                f"removed soon. Please install 'pyg-lib' for "
+                f"accelerated neighborhood sampling", stacklevel=2)

         self.data_type = DataType.from_data(data)
+        self.sample_direction = sample_direction
+
+        if self.sample_direction == 'backward':
+            # TODO(zaristei)
+            if time_attr is not None:
+                raise NotImplementedError(
+                    "Temporal Sampling not yet supported for backward sampling"
+                )

         if self.data_type == DataType.homogeneous:
             self.num_nodes = data.num_nodes
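This hunk introduces the new `sample_direction` switch: 'forward' keeps the existing behaviour, while 'backward' builds the compressed layout of the transposed graph (see the `to_transpose` plumbing further down), so sampling effectively follows outgoing rather than incoming edges of the seed nodes. A minimal usage sketch, not taken from the package itself; the toy graph and fan-outs are invented, and it assumes `pyg-lib` or `torch-sparse` is installed:

```python
# Hypothetical sketch of the new `sample_direction` argument on a toy graph.
import torch

from torch_geometric.data import Data
from torch_geometric.sampler import NeighborSampler, NodeSamplerInput

edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])  # a directed 4-cycle
data = Data(edge_index=edge_index, num_nodes=4)

# 'backward' transposes the adjacency before sampling:
sampler = NeighborSampler(data, num_neighbors=[2, 2],
                          sample_direction='backward')
out = sampler.sample_from_nodes(
    NodeSamplerInput(input_id=None, node=torch.tensor([0])))
print(out.node, out.row, out.col)
```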
@@ -85,7 +101,8 @@ class NeighborSampler(BaseSampler):
             self.colptr, self.row, self.perm = to_csc(
                 data, device='cpu', share_memory=share_memory,
                 is_sorted=is_sorted, src_node_time=self.node_time,
-                edge_time=self.edge_time
+                edge_time=self.edge_time,
+                to_transpose=self.sample_direction == 'backward')

             if self.edge_time is not None and self.perm is not None:
                 self.edge_time = self.edge_time[self.perm]
@@ -99,6 +116,17 @@ class NeighborSampler(BaseSampler):
         elif self.data_type == DataType.heterogeneous:
             self.node_types, self.edge_types = data.metadata()

+            # reverse edge types if sample_direction is backward
+            if self.sample_direction == 'backward':
+                self.edge_types = [
+                    reverse_edge_type(edge_type)
+                    for edge_type in self.edge_types
+                ]
+                self.to_restored_edge_type = {
+                    k: reverse_edge_type(k)
+                    for k in self.edge_types
+                }
+
             self.num_nodes = {k: data[k].num_nodes for k in self.node_types}

             self.node_time: Optional[Dict[NodeType, Tensor]] = None
@@ -139,7 +167,8 @@ class NeighborSampler(BaseSampler):
             colptr_dict, row_dict, self.perm = to_hetero_csc(
                 data, device='cpu', share_memory=share_memory,
                 is_sorted=is_sorted, node_time_dict=self.node_time,
-                edge_time_dict=self.edge_time
+                edge_time_dict=self.edge_time,
+                to_transpose=self.sample_direction == 'backward')

             self.row_dict = remap_keys(row_dict, self.to_rel_type)
             self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
@@ -170,6 +199,21 @@ class NeighborSampler(BaseSampler):
             edge_attrs = graph_store.get_all_edge_attrs()
             self.edge_types = list({attr.edge_type for attr in edge_attrs})

+            # reverse edge types if sample_direction is backward
+            if self.sample_direction == 'backward':
+                self.edge_types = [
+                    reverse_edge_type(edge_type)
+                    for edge_type in self.edge_types
+                ]
+                self.to_restored_edge_type = {
+                    k: reverse_edge_type(k)
+                    for k in self.edge_types
+                }
+                self.to_backward_edge_type = {
+                    v: k
+                    for k, v in self.to_restored_edge_type.items()
+                }
+
             if weight_attr is not None:
                 raise NotImplementedError(
                     f"'weight_attr' argument not yet supported within "
@@ -219,7 +263,10 @@ class NeighborSampler(BaseSampler):
                 else:
                     self.edge_time = time_tensor

-
+                if self.sample_direction == 'forward':
+                    self.row, self.colptr, self.perm = graph_store.csc()
+                elif self.sample_direction == 'backward':
+                    self.colptr, self.row, self.perm = graph_store.csr()

             else:
                 node_types = [
@@ -259,8 +306,17 @@ class NeighborSampler(BaseSampler):
                 # Conversion to/from C++ string type (see above):
                 self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
                 self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}
-
-
+                if self.sample_direction == 'forward':
+                    row_dict, colptr_dict, self.perm = graph_store.csc()
+                elif self.sample_direction == 'backward':
+                    colptr_dict, row_dict, self.perm = graph_store.csr()
+
+                    colptr_dict = remap_keys(colptr_dict,
+                                             self.to_backward_edge_type)
+                    row_dict = remap_keys(row_dict, self.to_backward_edge_type)
+                    self.perm = remap_keys(self.perm,
+                                           self.to_backward_edge_type)
+
                 self.row_dict = remap_keys(row_dict, self.to_rel_type)
                 self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

@@ -279,17 +335,42 @@ class NeighborSampler(BaseSampler):
         self.subgraph_type = SubgraphType(subgraph_type)
         self.disjoint = disjoint
         self.temporal_strategy = temporal_strategy
+        self.keep_orig_edges = False

     @property
     def num_neighbors(self) -> NumNeighbors:
+        if self.sample_direction == 'backward':
+            return self._input_num_neighbors \
+                if self._input_num_neighbors is not None \
+                else self._num_neighbors
         return self._num_neighbors

     @num_neighbors.setter
     def num_neighbors(self, num_neighbors: NumNeighborsType):
+        # only used if sample direction is backward and num_neighbors has edge
+        # keys
+        self._input_num_neighbors = None
+
         if isinstance(num_neighbors, NumNeighbors):
-
+            num_neighbors_values = num_neighbors.values
+            if isinstance(num_neighbors_values,
+                          dict) and self.sample_direction == 'backward':
+                # reverse the edge_types if sample_direction is backward
+                self._input_num_neighbors = num_neighbors
+                num_neighbors_values = remap_keys(num_neighbors_values,
+                                                  self.to_backward_edge_type)
+                self._num_neighbors = NumNeighbors(num_neighbors_values)
+            else:
+                self._num_neighbors = num_neighbors
         else:
-
+            if isinstance(num_neighbors,
+                          dict) and self.sample_direction == 'backward':
+                # intentionally recursing here to make sure num_neighbors is
+                # set as expected for the user
+                self.num_neighbors = NumNeighbors(
+                    remap_keys(num_neighbors, self.to_backward_edge_type))
+            else:
+                self._num_neighbors = NumNeighbors(num_neighbors)

     @property
     def is_hetero(self) -> bool:
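For heterogeneous backward sampling, per-edge-type fan-outs supplied by the user keep their forward-facing keys, while the internal `NumNeighbors` object is re-keyed onto the reversed edge types. A small illustration of that remapping using the helpers from `torch_geometric.sampler.utils` in a nightly that contains this diff (edge types and counts are invented):

```python
from torch_geometric.sampler.utils import remap_keys, reverse_edge_type

# User-facing fan-outs, keyed by forward edge types:
num_neighbors = {
    ('author', 'writes', 'paper'): [10, 5],
    ('paper', 'cites', 'paper'): [10, 5],
}

# Mapping applied internally when sample_direction='backward':
to_backward_edge_type = {k: reverse_edge_type(k) for k in num_neighbors}
print(remap_keys(num_neighbors, to_backward_edge_type))
# {('paper', 'writes', 'author'): [10, 5], ('paper', 'cites', 'paper'): [10, 5]}
```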
@@ -321,7 +402,7 @@ class NeighborSampler(BaseSampler):
     ) -> Union[SamplerOutput, HeteroSamplerOutput]:
         out = node_sample(inputs, self._sample)
         if self.subgraph_type == SubgraphType.bidirectional:
-            out = out.to_bidirectional()
+            out = out.to_bidirectional(keep_orig_edges=self.keep_orig_edges)
         return out

     # Edge-based sampling #####################################################
@@ -334,7 +415,7 @@ class NeighborSampler(BaseSampler):
         out = edge_sample(inputs, self._sample, self.num_nodes, self.disjoint,
                           self.node_time, neg_sampling)
         if self.subgraph_type == SubgraphType.bidirectional:
-            out = out.to_bidirectional()
+            out = out.to_bidirectional(keep_orig_edges=self.keep_orig_edges)
         return out

     # Other Utilities #########################################################
@@ -431,17 +512,34 @@ class NeighborSampler(BaseSampler):
                 raise ImportError(f"'{self.__class__.__name__}' requires "
                                   f"either 'pyg-lib' or 'torch-sparse'")

+            if self.sample_direction == 'backward':
+                row, col = col, row
+
+            row = remap_keys(row, self.to_edge_type)
+            col = remap_keys(col, self.to_edge_type)
+            edge = remap_keys(edge, self.to_edge_type)
+
+            # In the case of backward sampling, we need to restore the edges
+            # keys to be forward facing in the HeteroSamplerOutput object.
+            if self.sample_direction == 'backward':
+                row = remap_keys(row, self.to_restored_edge_type)
+                col = remap_keys(col, self.to_restored_edge_type)
+                edge = remap_keys(edge, self.to_restored_edge_type)
+
             if num_sampled_edges is not None:
                 num_sampled_edges = remap_keys(
                     num_sampled_edges,
                     self.to_edge_type,
                 )
+                if self.sample_direction == 'backward':
+                    num_sampled_edges = remap_keys(num_sampled_edges,
+                                                   self.to_restored_edge_type)

             return HeteroSamplerOutput(
                 node=node,
-                row=
-                col=
-                edge=
+                row=row,
+                col=col,
+                edge=edge,
                 batch=batch,
                 num_sampled_nodes=num_sampled_nodes,
                 num_sampled_edges=num_sampled_edges,
@@ -508,6 +606,9 @@ class NeighborSampler(BaseSampler):
                 raise ImportError(f"'{self.__class__.__name__}' requires "
                                   f"either 'pyg-lib' or 'torch-sparse'")

+            if self.sample_direction == 'backward':
+                row, col = col, row
+
             return SamplerOutput(
                 node=node,
                 row=row,
@@ -519,6 +620,178 @@ class NeighborSampler(BaseSampler):
         )


+class BidirectionalNeighborSampler(NeighborSampler):
+    """A sampler that allows for both upstream and downstream sampling."""
+    def __init__(
+        self,
+        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
+        num_neighbors: NumNeighborsType,
+        subgraph_type: Union[SubgraphType, str] = 'directional',
+        replace: bool = False,
+        disjoint: bool = False,
+        temporal_strategy: str = 'uniform',
+        time_attr: Optional[str] = None,
+        weight_attr: Optional[str] = None,
+        is_sorted: bool = False,
+        share_memory: bool = False,
+        # Deprecated:
+        directed: bool = True,
+    ):
+
+        # TODO(zaristei)
+        if isinstance(num_neighbors, NumNeighbors) and isinstance(
+                num_neighbors.values, dict) or isinstance(num_neighbors, dict):
+            raise RuntimeError(
+                "BidirectionalNeighborSampler does not yet support edge "
+                "delimited sampling.")
+
+        self.forward_sampler = NeighborSampler(
+            data, num_neighbors, subgraph_type, replace, disjoint,
+            temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,
+            sample_direction='forward', directed=directed)
+        self.backward_sampler = NeighborSampler(
+            data, num_neighbors, subgraph_type, replace, disjoint,
+            temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,
+            sample_direction='backward', directed=directed)
+
+        # Trigger warnings on init if number of hops is greater than 1
+        self.num_neighbors = num_neighbors
+        self.subgraph_type = subgraph_type
+
+    @property
+    def num_neighbors(self) -> NumNeighbors:
+        return self._num_neighbors
+
+    @num_neighbors.setter
+    def num_neighbors(self, num_neighbors: NumNeighborsType):
+        if not isinstance(num_neighbors, NumNeighbors):
+            num_neighbors = NumNeighbors(num_neighbors)
+        if num_neighbors.num_hops > 1:
+            print("Warning: Number of hops is greater than 1, resulting in "
+                  "memory-expensive recursive calls.")
+        self._num_neighbors = num_neighbors
+
+    @property
+    def is_hetero(self) -> bool:
+        return self.forward_sampler.is_hetero
+
+    @property
+    def is_temporal(self) -> bool:
+        return self.forward_sampler.is_temporal
+
+    @property
+    def disjoint(self) -> bool:
+        return self.forward_sampler.disjoint
+
+    @disjoint.setter
+    def disjoint(self, disjoint: bool):
+        self.forward_sampler.disjoint = disjoint
+        self.backward_sampler.disjoint = disjoint
+
+    def sample_from_nodes(
+        self,
+        inputs: NodeSamplerInput,
+    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
+        return super().sample_from_nodes(inputs)
+
+    def sample_from_edges(
+        self,
+        inputs: EdgeSamplerInput,
+        neg_sampling: Optional[NegativeSampling] = None,
+    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
+        # TODO(zaristei) Figure out what exactly regular and negative sampling
+        # imply for bidirectional sampling case
+        if neg_sampling is not None:
+            raise RuntimeError(
+                "BidirectionalNeighborSampler does not yet support "
+                "negative sampling.")
+        # Not thoroughly tested yet!
+        return super().sample_from_edges(inputs)
+
+    @property
+    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
+        return self.forward_sampler.edge_permutation
+
+    def _sample(
+        self,
+        seed: Union[Tensor, Dict[NodeType, Tensor]],
+        seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,
+        **kwargs,
+    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
+
+        if seed_time is not None:
+            raise NotImplementedError(
+                "BidirectionalNeighborSampler does not yet support "
+                "temporal sampling.")
+
+        if self.is_hetero:
+            raise NotImplementedError(
+                "BidirectionalNeighborSampler does not yet support "
+                "heterogeneous sampling.")
+        else:
+            current_seed = seed
+            current_seed_batch = None
+            current_seed_time = seed_time
+            seen_seed_set = {int(node) for node in current_seed}
+            if self.disjoint:
+                current_seed_batch = torch.arange(len(current_seed))
+                seen_seed_set = {
+                    (int(node), int(batch))
+                    for node, batch in zip(current_seed, current_seed_batch)
+                }
+
+            iter_results = []
+
+            for n_neighbors in self.num_neighbors.values:
+                current_n_neighbors = [n_neighbors]
+                self.forward_sampler.num_neighbors = current_n_neighbors
+                self.backward_sampler.num_neighbors = current_n_neighbors
+
+                fwd_result = self.forward_sampler._sample(
+                    current_seed, current_seed_time, **kwargs)
+                bwd_result = self.backward_sampler._sample(
+                    current_seed, current_seed_time, **kwargs)
+                # The seeds for the next iteration will be the new nodes in
+                # this iteration
+                iter_result = fwd_result.merge_with(bwd_result)
+                iter_results.append(iter_result)

+                # Find the nodes not yet seen to set a seed for next iteration
+                if self.disjoint:
+                    iter_seed_global_batch = global_to_local_node_idx(
+                        current_seed_batch, iter_result.batch)
+                    iter_result.seed_node = seed[iter_seed_global_batch]
+
+                    keep_mask = torch.tensor([
+                        (int(node), int(batch)) not in seen_seed_set
+                        for node, batch in zip(iter_result.node,
+                                               iter_seed_global_batch)
+                    ])
+                    next_seed = [(int(node), int(batch))
+                                 for node, batch in zip(
+                                     iter_result.node[keep_mask],
+                                     iter_seed_global_batch[keep_mask])
+                                 ] if keep_mask.any() else []
+                    current_seed, current_seed_batch = torch.tensor(
+                        next_seed).reshape(-1, 2).transpose(0, 1).contiguous()
+                else:
+                    keep_mask = torch.tensor([
+                        int(node) not in seen_seed_set
+                        for node in iter_result.node
+                    ])
+                    next_seed = [
+                        int(node) for node in iter_result.node[keep_mask]
+                    ] if keep_mask.any() else []
+                    current_seed = torch.tensor(next_seed)
+
+                seen_seed_set |= set(next_seed)
+
+                # TODO(zaristei) figure out how to update seed times for
+                # temporal sampling
+
+            return SamplerOutput.collate(iter_results)
+
+
 # Sampling Utilities ##########################################################


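A hedged sketch of how the new `BidirectionalNeighborSampler` might be driven directly. It imports the class from the module touched here, the toy graph is invented, and only homogeneous, non-temporal input is exercised, since the other paths raise `NotImplementedError` above; it also assumes the `merge_with`/`collate` support added to `torch_geometric/sampler/base.py` in this release:

```python
import torch

from torch_geometric.data import Data
from torch_geometric.sampler import NodeSamplerInput
from torch_geometric.sampler.neighbor_sampler import (
    BidirectionalNeighborSampler,
)

edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])  # directed path 0 -> 1 -> 2 -> 3 -> 4
data = Data(edge_index=edge_index, num_nodes=5)

# One hop in both directions around node 2 should reach nodes 1 and 3:
sampler = BidirectionalNeighborSampler(data, num_neighbors=[2])
out = sampler.sample_from_nodes(
    NodeSamplerInput(input_id=None, node=torch.tensor([2])))
print(out.node)
```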
@@ -631,7 +904,7 @@ def edge_sample(
         if edge_label_time is not None:
             dst_time = edge_label_time.repeat(1 + neg_sampling.amount)

-    #
+    # Heterogeneous Neighborhood Sampling #####################################

     if input_type is not None:
         seed_time_dict = None
@@ -724,7 +997,7 @@ def edge_sample(
             src_time,
         )

-    #
+    # Homogeneous Neighborhood Sampling #######################################

     else:

@@ -805,7 +1078,7 @@ def neg_sample(
     out = out.view(num_samples, seed.numel())
     mask = node_time[out] > seed_time  # holds all invalid samples.
     neg_sampling_complete = False
-    for
+    for _ in range(5):  # pragma: no cover
         num_invalid = int(mask.sum())
         if num_invalid == 0:
             neg_sampling_complete = True
torch_geometric/sampler/utils.py
CHANGED
@@ -9,6 +9,15 @@ from torch_geometric.index import index2ptr
 from torch_geometric.typing import EdgeType, NodeType, OptTensor
 from torch_geometric.utils import coalesce, index_sort, lexsort

+
+def reverse_edge_type(edge_type: EdgeType) -> EdgeType:
+    """Reverses edge types for heterogeneous graphs. Useful in cases of
+    backward sampling.
+    """
+    return (edge_type[2], edge_type[1],
+            edge_type[0]) if edge_type is not None else None
+
+
 # Edge Layout Conversion ######################################################


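The helper only swaps source and destination node types and leaves the relation untouched, for example:

```python
from torch_geometric.sampler.utils import reverse_edge_type

assert reverse_edge_type(('author', 'writes', 'paper')) == \
    ('paper', 'writes', 'author')
assert reverse_edge_type(None) is None
```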
@@ -41,6 +50,7 @@ def to_csc(
     is_sorted: bool = False,
     src_node_time: Optional[Tensor] = None,
     edge_time: Optional[Tensor] = None,
+    to_transpose: bool = False,
 ) -> Tuple[Tensor, Tensor, OptTensor]:
     # Convert the graph data into a suitable format for sampling (CSC format).
     # Returns the `colptr` and `row` indices of the graph, as well as an
@@ -53,7 +63,10 @@ def to_csc(
         if src_node_time is not None:
             raise NotImplementedError("Temporal sampling via 'SparseTensor' "
                                       "format not yet supported")
-
+        if to_transpose:
+            row, colptr, _ = data.adj.csr()
+        else:
+            colptr, row, _ = data.adj.csc()

     elif hasattr(data, 'adj_t'):
         if src_node_time is not None:
@@ -65,13 +78,21 @@ def to_csc(
             # raise NotImplementedError("Temporal sampling via 'SparseTensor' "
             #                           "format not yet supported")
             pass
-
+        if to_transpose:
+            row, colptr, _ = data.adj_t.csc()
+        else:
+            colptr, row, _ = data.adj_t.csr()

     elif data.edge_index is not None:
-
+        if to_transpose:
+            col, row = data.edge_index
+        else:
+            row, col = data.edge_index
+
         if not is_sorted:
             row, col, perm = sort_csc(row, col, src_node_time, edge_time)
-        colptr = index2ptr(col,
+        colptr = index2ptr(col,
+                           data.size(1) if not to_transpose else data.size(0))
     else:
         row = torch.empty(0, dtype=torch.long, device=device)
         colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long,
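With `to_transpose=True`, the `edge_index` branch simply swaps source and destination before building the compressed layout, i.e. it returns the CSC representation of the transposed graph. A small sketch on an invented three-node graph:

```python
import torch

from torch_geometric.data import Data
from torch_geometric.sampler.utils import to_csc

edge_index = torch.tensor([[0, 0, 1],
                           [1, 2, 2]])  # edges 0->1, 0->2, 1->2
data = Data(edge_index=edge_index, num_nodes=3)

colptr, row, perm = to_csc(data, device='cpu')
colptr_t, row_t, perm_t = to_csc(data, device='cpu', to_transpose=True)

print(colptr, row)      # per-node incoming neighbours (CSC of A)
print(colptr_t, row_t)  # per-node outgoing neighbours (CSC of A^T)
```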
@@ -97,6 +118,7 @@ def to_hetero_csc(
     is_sorted: bool = False,
     node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
     edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None,
+    to_transpose: bool = False,
 ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
     # Convert the heterogeneous graph data into a suitable format for sampling
     # (CSC format).
@@ -108,7 +130,11 @@ def to_hetero_csc(
         src_node_time = (node_time_dict or {}).get(edge_type[0], None)
         edge_time = (edge_time_dict or {}).get(edge_type, None)
         out = to_csc(store, device, share_memory, is_sorted, src_node_time,
-                     edge_time)
+                     edge_time, to_transpose)
+        # Edge types need to be reversed for backward sampling:
+        if to_transpose:
+            edge_type = reverse_edge_type(edge_type)
+
         colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out

     return colptr_dict, row_dict, perm_dict
@@ -160,3 +186,65 @@ def remap_keys(
         k if k in exclude else mapping.get(k, k): v
         for k, v in inputs.items()
     }
+
+
+def local_to_global_node_idx(node_values: Tensor,
+                             local_indices: Tensor) -> Tensor:
+    """Convert a tensor of indices referring to elements in the node_values
+    tensor to their values.
+
+    Args:
+        node_values (Tensor): The node values. (num_nodes, feature_dim)
+        local_indices (Tensor): The local indices. (num_indices)
+
+    Returns:
+        Tensor: The values of the node_values tensor at the local indices.
+            (num_indices, feature_dim)
+    """
+    return torch.index_select(node_values, dim=0, index=local_indices)
+
+
+def global_to_local_node_idx(node_values: Tensor,
+                             local_values: Tensor) -> Tensor:
+    """Converts a tensor of values that are contained in the node_values
+    tensor to their indices in that tensor.
+
+    Args:
+        node_values (Tensor): The node values. (num_nodes, feature_dim)
+        local_values (Tensor): The local values. (num_indices, feature_dim)
+
+    Returns:
+        Tensor: The indices of the local values in the node_values tensor.
+            (num_indices)
+    """
+    if node_values.dim() == 1:
+        node_values = node_values.unsqueeze(1)
+    if local_values.dim() == 1:
+        local_values = local_values.unsqueeze(1)
+    node_values_expand = node_values.unsqueeze(-1).expand(
+        *node_values.shape,
+        local_values.shape[0])  # (num_nodes, feature_dim, num_indices)
+    local_values_expand = local_values.transpose(0, 1).unsqueeze(0).expand(
+        *node_values_expand.shape)  # (num_nodes, feature_dim, num_indices)
+    idx_match = torch.all(node_values_expand == local_values_expand,
+                          dim=1).nonzero()  # (num_indices, 2)
+    sort_idx = torch.argsort(idx_match[:, 1])
+
+    return idx_match[:, 0][sort_idx]
+
+
+def unique_unsorted(tensor: Tensor) -> Tensor:
+    """Returns the unique elements of a tensor while preserving the original
+    order.
+
+    Necessary because torch.unique() ignores sort parameter.
+    """
+    seen = set()
+    output = []
+    for val in tensor:
+        val = tuple(val.tolist())
+        if val not in seen:
+            seen.add(val)
+            output.append(val)
+    return torch.tensor(output, dtype=tensor.dtype,
+                        device=tensor.device).reshape((-1, *tensor.shape[1:]))
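A quick, self-contained illustration of the two helpers the bidirectional sampler relies on (values invented): `global_to_local_node_idx` looks up the positions of given values inside a reference tensor, and `unique_unsorted` deduplicates rows while keeping first-occurrence order:

```python
import torch

from torch_geometric.sampler.utils import (
    global_to_local_node_idx,
    unique_unsorted,
)

node_values = torch.tensor([10, 20, 30, 40])
print(global_to_local_node_idx(node_values, torch.tensor([30, 10])))
# tensor([2, 0]) -- positions of 30 and 10 inside `node_values`

pairs = torch.tensor([[3, 0], [1, 0], [3, 0], [2, 1]])
print(unique_unsorted(pairs))
# tensor([[3, 0], [1, 0], [2, 1]]) -- duplicate rows dropped, order preserved
```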
torch_geometric/testing/__init__.py
CHANGED
@@ -17,11 +17,13 @@ from .decorators import (
     onlyOnline,
     onlyGraphviz,
     onlyNeighborSampler,
+    onlyRAG,
     has_package,
     withPackage,
     withDevice,
     withCUDA,
     withMETIS,
+    withHashTensor,
     disableExtensions,
     withoutExtensions,
 )
@@ -48,11 +50,13 @@ __all__ = [
     'onlyOnline',
     'onlyGraphviz',
     'onlyNeighborSampler',
+    'onlyRAG',
     'has_package',
     'withPackage',
     'withDevice',
     'withCUDA',
     'withMETIS',
+    'withHashTensor',
     'disableExtensions',
     'withoutExtensions',
     'assert_module',