pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyg-nightly might be problematic.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/sampler/base.py
CHANGED

@@ -1,9 +1,9 @@
 import copy
 import math
 import warnings
-from abc import ABC
+from abc import ABC, abstractmethod
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Union

@@ -11,7 +11,12 @@ import torch
 from torch import Tensor

 from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
-from torch_geometric.sampler.utils import to_bidirectional
+from torch_geometric.sampler.utils import (
+    global_to_local_node_idx,
+    local_to_global_node_idx,
+    to_bidirectional,
+    unique_unsorted,
+)
 from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor
 from torch_geometric.utils.mixin import CastMixin

@@ -206,6 +211,39 @@ class SamplerOutput(CastMixin):
     # TODO(manan): refine this further; it does not currently define a proper
     # API for the expected output of a sampler.
     metadata: Optional[Any] = None
+    _seed_node: OptTensor = field(repr=False, default=None)
+
+    @property
+    def global_row(self) -> Tensor:
+        return local_to_global_node_idx(self.node, self.row)
+
+    @property
+    def global_col(self) -> Tensor:
+        return local_to_global_node_idx(self.node, self.col)
+
+    @property
+    def seed_node(self) -> Tensor:
+        # can be set manually if the seed nodes are not contained in the
+        # sampled nodes
+        if self._seed_node is None:
+            self._seed_node = local_to_global_node_idx(
+                self.node, self.batch) if self.batch is not None else None
+        return self._seed_node
+
+    @seed_node.setter
+    def seed_node(self, value: Tensor):
+        assert len(value) == len(self.node)
+        self._seed_node = value
+
+    @property
+    def global_orig_row(self) -> Tensor:
+        return local_to_global_node_idx(
+            self.node, self.orig_row) if self.orig_row is not None else None
+
+    @property
+    def global_orig_col(self) -> Tensor:
+        return local_to_global_node_idx(
+            self.node, self.orig_col) if self.orig_col is not None else None

     def to_bidirectional(
         self,
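Illustrative sketch (not part of the packaged code): the properties added above map the sampler's local row, col, and batch indices back to the global node IDs stored in node. A minimal example, assuming local_to_global_node_idx performs a plain node[index] gather as the property bodies suggest:

    import torch

    from torch_geometric.sampler import SamplerOutput

    # Toy subgraph over global node ids 10, 20, 30. `row`/`col` are local
    # indices into `node`, so the two edges are 20 -> 10 and 30 -> 10.
    out = SamplerOutput(
        node=torch.tensor([10, 20, 30]),
        row=torch.tensor([1, 2]),
        col=torch.tensor([0, 0]),
        edge=torch.tensor([0, 1]),
        batch=torch.tensor([0, 0, 0]),  # every node belongs to subgraph/seed 0
    )

    print(out.global_row)  # expected: tensor([20, 30])
    print(out.global_col)  # expected: tensor([10, 10])
    print(out.seed_node)   # expected: tensor([10, 10, 10]), one seed id per node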
@@ -237,6 +275,230 @@ class SamplerOutput(CastMixin):

         return out

+    @classmethod
+    def collate(cls, outputs: List['SamplerOutput'],
+                replace: bool = True) -> 'SamplerOutput':
+        r"""Collate a list of :class:`~torch_geometric.sampler.SamplerOutput`
+        objects into a single :class:`~torch_geometric.sampler.SamplerOutput`
+        object. Requires that they all have the same fields.
+        """
+        if len(outputs) == 0:
+            raise ValueError("Cannot collate an empty list of SamplerOutputs")
+        out = outputs[0]
+        has_edge = out.edge is not None
+        has_orig_row = out.orig_row is not None
+        has_orig_col = out.orig_col is not None
+        has_batch = out.batch is not None
+        has_num_sampled_nodes = out.num_sampled_nodes is not None
+        has_num_sampled_edges = out.num_sampled_edges is not None
+
+        try:
+            for i, sample_output in enumerate(outputs):  # noqa
+                assert not has_edge == (sample_output.edge is None)
+                assert not has_orig_row == (sample_output.orig_row is None)
+                assert not has_orig_col == (sample_output.orig_col is None)
+                assert not has_batch == (sample_output.batch is None)
+                assert not has_num_sampled_nodes == (
+                    sample_output.num_sampled_nodes is None)
+                assert not has_num_sampled_edges == (
+                    sample_output.num_sampled_edges is None)
+        except AssertionError:
+            error_str = f"Output {i+1} has a different field than the first output"  # noqa
+            raise ValueError(error_str)  # noqa
+
+        for other in outputs[1:]:
+            out = out.merge_with(other, replace=replace)
+        return out
+
+    def merge_with(self, other: 'SamplerOutput',
+                   replace: bool = True) -> 'SamplerOutput':
+        """Merges two SamplerOutputs.
+        If replace is True, self's nodes and edges take precedence.
+        """
+        if not replace:
+            return SamplerOutput(
+                node=torch.cat([self.node, other.node], dim=0),
+                row=torch.cat([self.row, len(self.node) + other.row], dim=0),
+                col=torch.cat([self.col, len(self.node) + other.col], dim=0),
+                edge=torch.cat([self.edge, other.edge], dim=0)
+                if self.edge is not None and other.edge is not None else None,
+                batch=torch.cat(
+                    [self.batch, len(self.node) + other.batch], dim=0) if
+                self.batch is not None and other.batch is not None else None,
+                num_sampled_nodes=self.num_sampled_nodes +
+                other.num_sampled_nodes if self.num_sampled_nodes is not None
+                and other.num_sampled_nodes is not None else None,
+                num_sampled_edges=self.num_sampled_edges +
+                other.num_sampled_edges if self.num_sampled_edges is not None
+                and other.num_sampled_edges is not None else None,
+                orig_row=torch.cat(
+                    [self.orig_row,
+                     len(self.node) +
+                     other.orig_row], dim=0) if self.orig_row is not None
+                and other.orig_row is not None else None,
+                orig_col=torch.cat(
+                    [self.orig_col,
+                     len(self.node) +
+                     other.orig_col], dim=0) if self.orig_col is not None
+                and other.orig_col is not None else None,
+                metadata=[self.metadata, other.metadata],
+            )
+        else:
+
+            # NODES
+            old_nodes, new_nodes = self.node, other.node
+            old_node_uid, new_node_uid = [old_nodes], [new_nodes]
+
+            # batch tracks disjoint subgraph samplings
+            if self.batch is not None and other.batch is not None:
+                # Transform the batch indices to be global node ids
+                old_batch_nodes = self.seed_node
+                new_batch_nodes = other.seed_node
+                old_node_uid.append(old_batch_nodes)
+                new_node_uid.append(new_batch_nodes)
+
+            # NOTE: if any new node fields are added,
+            # they need to be merged here
+
+            old_node_uid = torch.stack(old_node_uid, dim=1)
+            new_node_uid = torch.stack(new_node_uid, dim=1)
+
+            merged_node_uid = unique_unsorted(
+                torch.cat([old_node_uid, new_node_uid], dim=0))
+            num_old_nodes = old_node_uid.shape[0]
+
+            # Recompute num sampled nodes for second output,
+            # subtracting out nodes already seen in first output
+            merged_node_num_sampled_nodes = None
+            if (self.num_sampled_nodes is not None
+                    and other.num_sampled_nodes is not None):
+                merged_node_num_sampled_nodes = copy.copy(
+                    self.num_sampled_nodes)
+                curr_index = 0
+                # NOTE: There's an assumption here that no two nodes will be
+                # sampled twice in the same SampleOutput object
+                for minibatch in other.num_sampled_nodes:
+                    size_of_intersect = torch.cat([
+                        old_node_uid,
+                        new_node_uid[curr_index:curr_index + minibatch]
+                    ]).unique(dim=0, sorted=False).shape[0] - num_old_nodes
+                    merged_node_num_sampled_nodes.append(size_of_intersect)
+                    curr_index += minibatch
+
+            merged_nodes = merged_node_uid[:, 0]
+            merged_batch = None
+            if self.batch is not None and other.batch is not None:
+                # Restore the batch indices to be relative to the nodes field
+                ref_merged_batch_nodes = merged_node_uid[:, 1].unsqueeze(
+                    -1).expand(-1, 2)  # num_nodes x 2
+                merged_batch = global_to_local_node_idx(
+                    merged_node_uid, ref_merged_batch_nodes)
+
+            # EDGES
+            is_bidirectional = self.orig_row is not None \
+                and self.orig_col is not None \
+                and other.orig_row is not None \
+                and other.orig_col is not None
+            if is_bidirectional:
+                old_row, old_col = self.orig_row, self.orig_col
+                new_row, new_col = other.orig_row, other.orig_col
+            else:
+                old_row, old_col = self.row, self.col
+                new_row, new_col = other.row, other.col
+
+            # Transform the row and col indices to be global node ids
+            # instead of relative indices to nodes field
+            # Edge uids build off of node uids
+            old_row_idx, old_col_idx = local_to_global_node_idx(
+                old_node_uid,
+                old_row), local_to_global_node_idx(old_node_uid, old_col)
+            new_row_idx, new_col_idx = local_to_global_node_idx(
+                new_node_uid,
+                new_row), local_to_global_node_idx(new_node_uid, new_col)
+
+            old_edge_uid, new_edge_uid = [old_row_idx, old_col_idx
+                                          ], [new_row_idx, new_col_idx]
+
+            row_idx = 0
+            col_idx = old_row_idx.shape[1]
+            edge_idx = old_row_idx.shape[1] + old_col_idx.shape[1]
+
+            if self.edge is not None and other.edge is not None:
+                if is_bidirectional:
+                    # bidirectional duplicates edge ids
+                    old_edge_uid_ref = torch.stack([self.row, self.col],
+                                                   dim=1)  # num_edges x 2
+                    old_orig_edge_uid_ref = torch.stack(
+                        [self.orig_row, self.orig_col],
+                        dim=1)  # num_orig_edges x 2
+
+                    old_edge_idx = global_to_local_node_idx(
+                        old_edge_uid_ref, old_orig_edge_uid_ref)
+                    old_edge = self.edge[old_edge_idx]
+
+                    new_edge_uid_ref = torch.stack([other.row, other.col],
+                                                   dim=1)  # num_edges x 2
+                    new_orig_edge_uid_ref = torch.stack(
+                        [other.orig_row, other.orig_col],
+                        dim=1)  # num_orig_edges x 2
+
+                    new_edge_idx = global_to_local_node_idx(
+                        new_edge_uid_ref, new_orig_edge_uid_ref)
+                    new_edge = other.edge[new_edge_idx]
+
+                else:
+                    old_edge, new_edge = self.edge, other.edge
+
+                old_edge_uid.append(old_edge.unsqueeze(-1))
+                new_edge_uid.append(new_edge.unsqueeze(-1))
+
+            old_edge_uid = torch.cat(old_edge_uid, dim=1)
+            new_edge_uid = torch.cat(new_edge_uid, dim=1)
+
+            merged_edge_uid = unique_unsorted(
+                torch.cat([old_edge_uid, new_edge_uid], dim=0))
+            num_old_edges = old_edge_uid.shape[0]
+
+            merged_edge_num_sampled_edges = None
+            if (self.num_sampled_edges is not None
+                    and other.num_sampled_edges is not None):
+                merged_edge_num_sampled_edges = copy.copy(
+                    self.num_sampled_edges)
+                curr_index = 0
+                # NOTE: There's an assumption here that no two edges will be
+                # sampled twice in the same SampleOutput object
+                for minibatch in other.num_sampled_edges:
+                    size_of_intersect = torch.cat([
+                        old_edge_uid,
+                        new_edge_uid[curr_index:curr_index + minibatch]
+                    ]).unique(dim=0, sorted=False).shape[0] - num_old_edges
+                    merged_edge_num_sampled_edges.append(size_of_intersect)
+                    curr_index += minibatch
+
+            merged_row = merged_edge_uid[:, row_idx:col_idx]
+            merged_col = merged_edge_uid[:, col_idx:edge_idx]
+            merged_edge = merged_edge_uid[:, edge_idx:].squeeze() \
+                if self.edge is not None and other.edge is not None else None
+
+            # restore to row and col indices relative to nodes field
+            merged_row = global_to_local_node_idx(merged_node_uid, merged_row)
+            merged_col = global_to_local_node_idx(merged_node_uid, merged_col)
+
+            out = SamplerOutput(
+                node=merged_nodes,
+                row=merged_row,
+                col=merged_col,
+                edge=merged_edge,
+                batch=merged_batch,
+                num_sampled_nodes=merged_node_num_sampled_nodes,
+                num_sampled_edges=merged_edge_num_sampled_edges,
+                metadata=[self.metadata, other.metadata],
+            )
+            # Restores orig_row and orig_col if they existed before merging
+            if is_bidirectional:
+                out = out.to_bidirectional(keep_orig_edges=True)
+            return out
+

 @dataclass
 class HeteroSamplerOutput(CastMixin):
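Illustrative sketch (not part of the packaged code): the simplest way to exercise the new API is SamplerOutput.collate with replace=False, which concatenates the outputs and offsets the second output's local row/col indices by len(self.node), exactly as in the merge_with branch above:

    import torch

    from torch_geometric.sampler import SamplerOutput

    a = SamplerOutput(node=torch.tensor([0, 1]), row=torch.tensor([0]),
                      col=torch.tensor([1]), edge=torch.tensor([5]))
    b = SamplerOutput(node=torch.tensor([2, 3]), row=torch.tensor([0]),
                      col=torch.tensor([1]), edge=torch.tensor([7]))

    merged = SamplerOutput.collate([a, b], replace=False)
    print(merged.node)  # tensor([0, 1, 2, 3])
    print(merged.row)   # tensor([0, 2]) -- b's row shifted by len(a.node)
    print(merged.col)   # tensor([1, 3])
    print(merged.edge)  # tensor([5, 7])

With the default replace=True, duplicate nodes and edges across outputs are instead deduplicated via the unique_unsorted helper, and num_sampled_nodes/num_sampled_edges are recomputed so that only nodes and edges not already seen in the first output are counted.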
@@ -294,6 +556,43 @@ class HeteroSamplerOutput(CastMixin):
     # API for the expected output of a sampler.
     metadata: Optional[Any] = None

+    @property
+    def global_row(self) -> Dict[EdgeType, Tensor]:
+        return {
+            edge_type: local_to_global_node_idx(self.node[edge_type[0]], row)
+            for edge_type, row in self.row.items()
+        }
+
+    @property
+    def global_col(self) -> Dict[EdgeType, Tensor]:
+        return {
+            edge_type: local_to_global_node_idx(self.node[edge_type[2]], col)
+            for edge_type, col in self.col.items()
+        }
+
+    @property
+    def seed_node(self) -> Optional[Dict[NodeType, Tensor]]:
+        return {
+            node_type: local_to_global_node_idx(self.node[node_type], batch)
+            for node_type, batch in self.batch.items()
+        } if self.batch is not None else None
+
+    @property
+    def global_orig_row(self) -> Optional[Dict[EdgeType, Tensor]]:
+        return {
+            edge_type: local_to_global_node_idx(self.node[edge_type[0]],
+                                                orig_row)
+            for edge_type, orig_row in self.orig_row.items()
+        } if self.orig_row is not None else None
+
+    @property
+    def global_orig_col(self) -> Optional[Dict[EdgeType, Tensor]]:
+        return {
+            edge_type: local_to_global_node_idx(self.node[edge_type[2]],
+                                                orig_col)
+            for edge_type, orig_col in self.orig_col.items()
+        } if self.orig_col is not None else None
+
     def to_bidirectional(
         self,
         keep_orig_edges: bool = False,
@@ -369,12 +668,32 @@ class HeteroSamplerOutput(CastMixin):
                 out.edge[edge_type] = None

             else:
-                warnings.warn(
-
-
+                warnings.warn(
+                    f"Cannot convert to bidirectional graph "
+                    f"since the edge type {edge_type} does not "
+                    f"seem to have a reverse edge type", stacklevel=2)

         return out

+    @classmethod
+    def collate(cls, outputs: List['HeteroSamplerOutput'],
+                replace: bool = True) -> 'HeteroSamplerOutput':
+        r"""Collate a list of
+        :class:`~torch_geometric.sampler.HeteroSamplerOutput`objects into a
+        single :class:`~torch_geometric.sampler.HeteroSamplerOutput` object.
+        Requires that they all have the same fields.
+        """
+        # TODO(zaristei)
+        raise NotImplementedError
+
+    def merge_with(self, other: 'HeteroSamplerOutput',
+                   replace: bool = True) -> 'HeteroSamplerOutput':
+        """Merges two HeteroSamplerOutputs.
+        If replace is True, self's nodes and edges take precedence.
+        """
+        # TODO(zaristei)
+        raise NotImplementedError
+

 @dataclass(frozen=True)
 class NumNeighbors:
@@ -423,7 +742,15 @@ class NumNeighbors:
         elif isinstance(self.values, dict):
             default = self.default
         else:
-
+            raise AssertionError()
+
+        # Confirm that `values` only hold valid edge types:
+        if isinstance(self.values, dict):
+            edge_types_str = {EdgeTypeStr(key) for key in edge_types}
+            invalid_edge_types = set(self.values.keys()) - edge_types_str
+            if len(invalid_edge_types) > 0:
+                raise ValueError("Not all edge types specified in "
+                                 "'num_neighbors' exist in the graph")

         out = {}
         for edge_type in edge_types:
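Illustrative sketch (not part of the packaged code): the check added above rejects per-edge-type num_neighbors entries that do not correspond to an edge type of the graph. Assuming NumNeighbors.get_values() is the entry point that resolves the values against the graph's edge types, the intended behavior looks roughly like:

    from torch_geometric.sampler.base import NumNeighbors

    # Fan-outs keyed by an edge type the graph does not actually contain:
    num_neighbors = NumNeighbors({('paper', 'cites', 'paper'): [10, 5]})

    try:
        # Resolve against the graph's real edge types:
        num_neighbors.get_values([('author', 'writes', 'paper')])
    except ValueError as e:
        print(e)  # Not all edge types specified in 'num_neighbors' exist in the graph

Previously an unused key in the dictionary simply went unnoticed; it now fails loudly.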
@@ -614,6 +941,7 @@ class BaseSampler(ABC):
     As such, it is recommended to limit the amount of information stored in
     the sampler.
     """
+    @abstractmethod
    def sample_from_nodes(
         self,
         index: NodeSamplerInput,
@@ -634,6 +962,7 @@ class BaseSampler(ABC):
         """
         raise NotImplementedError

+    @abstractmethod
     def sample_from_edges(
         self,
         index: EdgeSamplerInput,
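Illustrative sketch (not part of the packaged code): with sample_from_nodes and sample_from_edges now marked @abstractmethod, every concrete BaseSampler subclass has to implement both, even if one of them is only a stub; the HGTSampler change below does exactly that. A minimal custom sampler skeleton under that assumption (the class name MyNodeOnlySampler is hypothetical):

    from typing import Optional, Union

    from torch_geometric.sampler import (
        BaseSampler,
        EdgeSamplerInput,
        HeteroSamplerOutput,
        NegativeSampling,
        NodeSamplerInput,
        SamplerOutput,
    )


    class MyNodeOnlySampler(BaseSampler):
        def sample_from_nodes(
            self,
            index: NodeSamplerInput,
            **kwargs,
        ) -> Union[HeteroSamplerOutput, SamplerOutput]:
            raise NotImplementedError  # real node-based sampling goes here

        def sample_from_edges(
            self,
            index: EdgeSamplerInput,
            neg_sampling: Optional[NegativeSampling] = None,
        ) -> Union[HeteroSamplerOutput, SamplerOutput]:
            raise NotImplementedError  # stub satisfies the abstract contract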
torch_geometric/sampler/hgt_sampler.py
CHANGED

@@ -1,12 +1,15 @@
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union

 import torch

 from torch_geometric.data import Data, HeteroData
 from torch_geometric.sampler import (
     BaseSampler,
+    EdgeSamplerInput,
     HeteroSamplerOutput,
+    NegativeSampling,
     NodeSamplerInput,
+    SamplerOutput,
 )
 from torch_geometric.sampler.utils import remap_keys, to_hetero_csc
 from torch_geometric.typing import (
@@ -76,6 +79,13 @@ class HGTSampler(BaseSampler):
             metadata=(inputs.input_id, inputs.time),
         )

+    def sample_from_edges(
+        self,
+        index: EdgeSamplerInput,
+        neg_sampling: Optional[NegativeSampling] = None,
+    ) -> Union[HeteroSamplerOutput, SamplerOutput]:
+        pass
+
     @property
     def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
         return self.perm