pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +8 -3
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +53 -20
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/sampler/base.py
CHANGED
@@ -5,7 +5,7 @@ from abc import ABC
|
|
5
5
|
from collections import defaultdict
|
6
6
|
from dataclasses import dataclass
|
7
7
|
from enum import Enum
|
8
|
-
from typing import Any, Dict, List, Optional, Union
|
8
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
9
9
|
|
10
10
|
import torch
|
11
11
|
from torch import Tensor
|
@@ -425,6 +425,14 @@ class NumNeighbors:
|
|
425
425
|
else:
|
426
426
|
assert False
|
427
427
|
|
428
|
+
# Confirm that `values` only hold valid edge types:
|
429
|
+
if isinstance(self.values, dict):
|
430
|
+
edge_types_str = {EdgeTypeStr(key) for key in edge_types}
|
431
|
+
invalid_edge_types = set(self.values.keys()) - edge_types_str
|
432
|
+
if len(invalid_edge_types) > 0:
|
433
|
+
raise ValueError("Not all edge types specified in "
|
434
|
+
"'num_neighbors' exist in the graph")
|
435
|
+
|
428
436
|
out = {}
|
429
437
|
for edge_type in edge_types:
|
430
438
|
edge_type_str = EdgeTypeStr(edge_type)
|
@@ -444,7 +452,7 @@ class NumNeighbors:
|
|
444
452
|
out = copy.copy(self.values)
|
445
453
|
|
446
454
|
if isinstance(out, dict):
|
447
|
-
num_hops =
|
455
|
+
num_hops = {len(v) for v in out.values()}
|
448
456
|
if len(num_hops) > 1:
|
449
457
|
raise ValueError(f"Number of hops must be the same across all "
|
450
458
|
f"edge types (got {len(num_hops)} different "
|
@@ -533,24 +541,31 @@ class NegativeSampling(CastMixin):
|
|
533
541
|
destination nodes for each positive source node.
|
534
542
|
amount (int or float, optional): The ratio of sampled negative edges to
|
535
543
|
the number of positive edges. (default: :obj:`1`)
|
536
|
-
|
537
|
-
sampling of nodes. Does not
|
538
|
-
If not given, negative nodes will be sampled uniformly.
|
544
|
+
src_weight (torch.Tensor, optional): A node-level vector determining
|
545
|
+
the sampling of source nodes. Does not necessarily need to sum up
|
546
|
+
to one. If not given, negative nodes will be sampled uniformly.
|
547
|
+
(default: :obj:`None`)
|
548
|
+
dst_weight (torch.Tensor, optional): A node-level vector determining
|
549
|
+
the sampling of destination nodes. Does not necessarily need to sum
|
550
|
+
up to one. If not given, negative nodes will be sampled uniformly.
|
539
551
|
(default: :obj:`None`)
|
540
552
|
"""
|
541
553
|
mode: NegativeSamplingMode
|
542
554
|
amount: Union[int, float] = 1
|
543
|
-
|
555
|
+
src_weight: Optional[Tensor] = None
|
556
|
+
dst_weight: Optional[Tensor] = None
|
544
557
|
|
545
558
|
def __init__(
|
546
559
|
self,
|
547
560
|
mode: Union[NegativeSamplingMode, str],
|
548
561
|
amount: Union[int, float] = 1,
|
549
|
-
|
562
|
+
src_weight: Optional[Tensor] = None,
|
563
|
+
dst_weight: Optional[Tensor] = None,
|
550
564
|
):
|
551
565
|
self.mode = NegativeSamplingMode(mode)
|
552
566
|
self.amount = amount
|
553
|
-
self.
|
567
|
+
self.src_weight = src_weight
|
568
|
+
self.dst_weight = dst_weight
|
554
569
|
|
555
570
|
if self.amount <= 0:
|
556
571
|
raise ValueError(f"The attribute 'amount' needs to be positive "
|
@@ -571,22 +586,28 @@ class NegativeSampling(CastMixin):
|
|
571
586
|
def is_triplet(self) -> bool:
|
572
587
|
return self.mode == NegativeSamplingMode.triplet
|
573
588
|
|
574
|
-
def sample(
|
575
|
-
|
589
|
+
def sample(
|
590
|
+
self,
|
591
|
+
num_samples: int,
|
592
|
+
endpoint: Literal['src', 'dst'],
|
593
|
+
num_nodes: Optional[int] = None,
|
594
|
+
) -> Tensor:
|
576
595
|
r"""Generates :obj:`num_samples` negative samples."""
|
577
|
-
|
596
|
+
weight = self.src_weight if endpoint == 'src' else self.dst_weight
|
597
|
+
|
598
|
+
if weight is None:
|
578
599
|
if num_nodes is None:
|
579
600
|
raise ValueError(
|
580
601
|
f"Cannot sample negatives in '{self.__class__.__name__}' "
|
581
602
|
f"without passing the 'num_nodes' argument")
|
582
603
|
return torch.randint(num_nodes, (num_samples, ))
|
583
604
|
|
584
|
-
if num_nodes is not None and
|
605
|
+
if num_nodes is not None and weight.numel() != num_nodes:
|
585
606
|
raise ValueError(
|
586
607
|
f"The 'weight' attribute in '{self.__class__.__name__}' "
|
587
608
|
f"needs to match the number of nodes {num_nodes} "
|
588
609
|
f"(got {self.weight.numel()})")
|
589
|
-
return torch.multinomial(
|
610
|
+
return torch.multinomial(weight, num_samples, replacement=True)
|
590
611
|
|
591
612
|
|
592
613
|
class BaseSampler(ABC):
|
@@ -2,7 +2,7 @@ import copy
|
|
2
2
|
import math
|
3
3
|
import sys
|
4
4
|
import warnings
|
5
|
-
from typing import Callable, Dict, List, Optional, Tuple, Union
|
5
|
+
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
|
6
6
|
|
7
7
|
import torch
|
8
8
|
from torch import Tensor
|
@@ -168,7 +168,7 @@ class NeighborSampler(BaseSampler):
|
|
168
168
|
attrs = [attr for attr in feature_store.get_all_tensor_attrs()]
|
169
169
|
|
170
170
|
edge_attrs = graph_store.get_all_edge_attrs()
|
171
|
-
self.edge_types = list(
|
171
|
+
self.edge_types = list({attr.edge_type for attr in edge_attrs})
|
172
172
|
|
173
173
|
if weight_attr is not None:
|
174
174
|
raise NotImplementedError(
|
@@ -593,7 +593,7 @@ def edge_sample(
|
|
593
593
|
src_node_time = node_time
|
594
594
|
|
595
595
|
src_neg = neg_sample(src, neg_sampling, num_src_nodes, src_time,
|
596
|
-
src_node_time)
|
596
|
+
src_node_time, endpoint='src')
|
597
597
|
src = torch.cat([src, src_neg], dim=0)
|
598
598
|
|
599
599
|
if isinstance(node_time, dict):
|
@@ -602,7 +602,7 @@ def edge_sample(
|
|
602
602
|
dst_node_time = node_time
|
603
603
|
|
604
604
|
dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,
|
605
|
-
dst_node_time)
|
605
|
+
dst_node_time, endpoint='dst')
|
606
606
|
dst = torch.cat([dst, dst_neg], dim=0)
|
607
607
|
|
608
608
|
if edge_label is None:
|
@@ -623,7 +623,7 @@ def edge_sample(
|
|
623
623
|
dst_node_time = node_time
|
624
624
|
|
625
625
|
dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,
|
626
|
-
dst_node_time)
|
626
|
+
dst_node_time, endpoint='dst')
|
627
627
|
dst = torch.cat([dst, dst_neg], dim=0)
|
628
628
|
|
629
629
|
assert edge_label is None
|
@@ -631,7 +631,7 @@ def edge_sample(
|
|
631
631
|
if edge_label_time is not None:
|
632
632
|
dst_time = edge_label_time.repeat(1 + neg_sampling.amount)
|
633
633
|
|
634
|
-
#
|
634
|
+
# Heterogeneous Neighborhood Sampling #####################################
|
635
635
|
|
636
636
|
if input_type is not None:
|
637
637
|
seed_time_dict = None
|
@@ -724,7 +724,7 @@ def edge_sample(
|
|
724
724
|
src_time,
|
725
725
|
)
|
726
726
|
|
727
|
-
#
|
727
|
+
# Homogeneous Neighborhood Sampling #######################################
|
728
728
|
|
729
729
|
else:
|
730
730
|
|
@@ -781,12 +781,13 @@ def neg_sample(
|
|
781
781
|
num_nodes: int,
|
782
782
|
seed_time: Optional[Tensor],
|
783
783
|
node_time: Optional[Tensor],
|
784
|
+
endpoint: Literal['str', 'dst'],
|
784
785
|
) -> Tensor:
|
785
786
|
num_neg = math.ceil(seed.numel() * neg_sampling.amount)
|
786
787
|
|
787
788
|
# TODO: Do not sample false negatives.
|
788
789
|
if node_time is None:
|
789
|
-
return neg_sampling.sample(num_neg, num_nodes)
|
790
|
+
return neg_sampling.sample(num_neg, endpoint, num_nodes)
|
790
791
|
|
791
792
|
# If we are in a temporal-sampling scenario, we need to respect the
|
792
793
|
# timestamp of the given nodes we can use as negative examples.
|
@@ -800,7 +801,7 @@ def neg_sample(
|
|
800
801
|
num_samples = math.ceil(neg_sampling.amount)
|
801
802
|
seed_time = seed_time.view(1, -1).expand(num_samples, -1)
|
802
803
|
|
803
|
-
out = neg_sampling.sample(num_samples * seed.numel(), num_nodes)
|
804
|
+
out = neg_sampling.sample(num_samples * seed.numel(), endpoint, num_nodes)
|
804
805
|
out = out.view(num_samples, seed.numel())
|
805
806
|
mask = node_time[out] > seed_time # holds all invalid samples.
|
806
807
|
neg_sampling_complete = False
|
@@ -811,7 +812,7 @@ def neg_sample(
|
|
811
812
|
break
|
812
813
|
|
813
814
|
# Greedily search for alternative negatives.
|
814
|
-
out[mask] = tmp = neg_sampling.sample(num_invalid, num_nodes)
|
815
|
+
out[mask] = tmp = neg_sampling.sample(num_invalid, endpoint, num_nodes)
|
815
816
|
mask[mask.clone()] = node_time[tmp] >= seed_time[mask]
|
816
817
|
|
817
818
|
if not neg_sampling_complete: # pragma: no cover
|
torch_geometric/sampler/utils.py
CHANGED
@@ -5,9 +5,9 @@ from torch import Tensor
|
|
5
5
|
|
6
6
|
from torch_geometric.data import Data, HeteroData
|
7
7
|
from torch_geometric.data.storage import EdgeStorage
|
8
|
+
from torch_geometric.index import index2ptr
|
8
9
|
from torch_geometric.typing import EdgeType, NodeType, OptTensor
|
9
10
|
from torch_geometric.utils import coalesce, index_sort, lexsort
|
10
|
-
from torch_geometric.utils.sparse import index2ptr
|
11
11
|
|
12
12
|
# Edge Layout Conversion ######################################################
|
13
13
|
|
torch_geometric/template.py
CHANGED
@@ -10,7 +10,8 @@ from .decorators import (
|
|
10
10
|
onlyDistributedTest,
|
11
11
|
onlyLinux,
|
12
12
|
noWindows,
|
13
|
-
|
13
|
+
noMac,
|
14
|
+
minPython,
|
14
15
|
onlyCUDA,
|
15
16
|
onlyXPU,
|
16
17
|
onlyOnline,
|
@@ -18,6 +19,7 @@ from .decorators import (
|
|
18
19
|
onlyNeighborSampler,
|
19
20
|
has_package,
|
20
21
|
withPackage,
|
22
|
+
withDevice,
|
21
23
|
withCUDA,
|
22
24
|
withMETIS,
|
23
25
|
disableExtensions,
|
@@ -39,7 +41,8 @@ __all__ = [
|
|
39
41
|
'onlyDistributedTest',
|
40
42
|
'onlyLinux',
|
41
43
|
'noWindows',
|
42
|
-
'
|
44
|
+
'noMac',
|
45
|
+
'minPython',
|
43
46
|
'onlyCUDA',
|
44
47
|
'onlyXPU',
|
45
48
|
'onlyOnline',
|
@@ -47,6 +50,7 @@ __all__ = [
|
|
47
50
|
'onlyNeighborSampler',
|
48
51
|
'has_package',
|
49
52
|
'withPackage',
|
53
|
+
'withDevice',
|
50
54
|
'withCUDA',
|
51
55
|
'withMETIS',
|
52
56
|
'disableExtensions',
|
@@ -7,7 +7,9 @@ from typing import Callable
|
|
7
7
|
|
8
8
|
import torch
|
9
9
|
from packaging.requirements import Requirement
|
10
|
+
from packaging.version import Version
|
10
11
|
|
12
|
+
import torch_geometric
|
11
13
|
from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
|
12
14
|
from torch_geometric.visualization.graph import has_graphviz
|
13
15
|
|
@@ -67,15 +69,34 @@ def noWindows(func: Callable) -> Callable:
|
|
67
69
|
)(func)
|
68
70
|
|
69
71
|
|
70
|
-
def
|
72
|
+
def noMac(func: Callable) -> Callable:
|
73
|
+
r"""A decorator to specify that this function should not execute on
|
74
|
+
macOS systems.
|
75
|
+
"""
|
76
|
+
import pytest
|
77
|
+
return pytest.mark.skipif(
|
78
|
+
sys.platform == 'darwin',
|
79
|
+
reason="macOS system",
|
80
|
+
)(func)
|
81
|
+
|
82
|
+
|
83
|
+
def minPython(version: str) -> Callable:
|
71
84
|
r"""A decorator to run tests on specific :python:`Python` versions only."""
|
72
85
|
def decorator(func: Callable) -> Callable:
|
73
86
|
import pytest
|
74
87
|
|
75
|
-
|
88
|
+
major, minor = version.split('.')
|
89
|
+
|
90
|
+
skip = False
|
91
|
+
if sys.version_info.major < int(major):
|
92
|
+
skip = True
|
93
|
+
if (sys.version_info.major == int(major)
|
94
|
+
and sys.version_info.minor < int(minor)):
|
95
|
+
skip = True
|
96
|
+
|
76
97
|
return pytest.mark.skipif(
|
77
|
-
|
78
|
-
reason=f"Python {
|
98
|
+
skip,
|
99
|
+
reason=f"Python {version} required",
|
79
100
|
)(func)
|
80
101
|
|
81
102
|
return decorator
|
@@ -93,13 +114,8 @@ def onlyCUDA(func: Callable) -> Callable:
|
|
93
114
|
def onlyXPU(func: Callable) -> Callable:
|
94
115
|
r"""A decorator to skip tests if XPU is not found."""
|
95
116
|
import pytest
|
96
|
-
try:
|
97
|
-
import intel_extension_for_pytorch as ipex
|
98
|
-
xpu_available = ipex.xpu.is_available()
|
99
|
-
except ImportError:
|
100
|
-
xpu_available = False
|
101
117
|
return pytest.mark.skipif(
|
102
|
-
not
|
118
|
+
not torch_geometric.is_xpu_available(),
|
103
119
|
reason="XPU not available",
|
104
120
|
)(func)
|
105
121
|
|
@@ -157,24 +173,23 @@ def has_package(package: str) -> bool:
|
|
157
173
|
req = Requirement(package)
|
158
174
|
if find_spec(req.name) is None:
|
159
175
|
return False
|
160
|
-
module = import_module(req.name)
|
161
|
-
if not hasattr(module, '__version__'):
|
162
|
-
return True
|
163
176
|
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
version = '.'.join(version.split('.dev')[:-1])
|
177
|
+
try:
|
178
|
+
module = import_module(req.name)
|
179
|
+
if not hasattr(module, '__version__'):
|
180
|
+
return True
|
169
181
|
|
170
|
-
|
182
|
+
version = Version(module.__version__).base_version
|
183
|
+
return version in req.specifier
|
184
|
+
except Exception:
|
185
|
+
return False
|
171
186
|
|
172
187
|
|
173
188
|
def withPackage(*args: str) -> Callable:
|
174
189
|
r"""A decorator to skip tests if certain packages are not installed.
|
175
190
|
Also supports version specification.
|
176
191
|
"""
|
177
|
-
na_packages =
|
192
|
+
na_packages = {package for package in args if not has_package(package)}
|
178
193
|
|
179
194
|
if len(na_packages) == 1:
|
180
195
|
reason = f"Package {list(na_packages)[0]} not found"
|
@@ -196,6 +211,24 @@ def withCUDA(func: Callable) -> Callable:
|
|
196
211
|
if torch.cuda.is_available():
|
197
212
|
devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))
|
198
213
|
|
214
|
+
return pytest.mark.parametrize('device', devices)(func)
|
215
|
+
|
216
|
+
|
217
|
+
def withDevice(func: Callable) -> Callable:
|
218
|
+
r"""A decorator to test on all available tensor processing devices."""
|
219
|
+
import pytest
|
220
|
+
|
221
|
+
devices = [pytest.param(torch.device('cpu'), id='cpu')]
|
222
|
+
|
223
|
+
if torch.cuda.is_available():
|
224
|
+
devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))
|
225
|
+
|
226
|
+
if torch_geometric.is_mps_available():
|
227
|
+
devices.append(pytest.param(torch.device('mps:0'), id='mps'))
|
228
|
+
|
229
|
+
if torch_geometric.is_xpu_available():
|
230
|
+
devices.append(pytest.param(torch.device('xpu:0'), id='xpu'))
|
231
|
+
|
199
232
|
# Additional devices can be registered through environment variables:
|
200
233
|
device = os.getenv('TORCH_DEVICE')
|
201
234
|
if device:
|
@@ -20,6 +20,7 @@ from .target_indegree import TargetIndegree
|
|
20
20
|
from .local_degree_profile import LocalDegreeProfile
|
21
21
|
from .add_self_loops import AddSelfLoops
|
22
22
|
from .add_remaining_self_loops import AddRemainingSelfLoops
|
23
|
+
from .remove_self_loops import RemoveSelfLoops
|
23
24
|
from .remove_isolated_nodes import RemoveIsolatedNodes
|
24
25
|
from .remove_duplicated_edges import RemoveDuplicatedEdges
|
25
26
|
from .knn_graph import KNNGraph
|
@@ -87,6 +88,7 @@ graph_transforms = [
|
|
87
88
|
'LocalDegreeProfile',
|
88
89
|
'AddSelfLoops',
|
89
90
|
'AddRemainingSelfLoops',
|
91
|
+
'RemoveSelfLoops',
|
90
92
|
'RemoveIsolatedNodes',
|
91
93
|
'RemoveDuplicatedEdges',
|
92
94
|
'KNNGraph',
|
@@ -37,7 +37,7 @@ class AddMetaPaths(BaseTransform):
|
|
37
37
|
:class:`~torch_geometric.data.HeteroData` object as edge type
|
38
38
|
:obj:`(src_node_type, "metapath_*", dst_node_type)`, where
|
39
39
|
:obj:`src_node_type` and :obj:`dst_node_type` denote :math:`\mathcal{V}_1`
|
40
|
-
and :math:`\mathcal{V}_{\ell}`,
|
40
|
+
and :math:`\mathcal{V}_{\ell}`, respectively.
|
41
41
|
|
42
42
|
In addition, a :obj:`metapath_dict` object is added to the
|
43
43
|
:class:`~torch_geometric.data.HeteroData` object which maps the
|
@@ -108,12 +108,12 @@ class AddMetaPaths(BaseTransform):
|
|
108
108
|
**kwargs: bool,
|
109
109
|
) -> None:
|
110
110
|
if 'drop_orig_edges' in kwargs:
|
111
|
-
warnings.warn("'drop_orig_edges' is
|
111
|
+
warnings.warn("'drop_orig_edges' is deprecated. Use "
|
112
112
|
"'drop_orig_edge_types' instead")
|
113
113
|
drop_orig_edge_types = kwargs['drop_orig_edges']
|
114
114
|
|
115
115
|
if 'drop_unconnected_nodes' in kwargs:
|
116
|
-
warnings.warn("'drop_unconnected_nodes' is
|
116
|
+
warnings.warn("'drop_unconnected_nodes' is deprecated. Use "
|
117
117
|
"'drop_unconnected_node_types' instead")
|
118
118
|
drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
|
119
119
|
|
@@ -158,7 +158,7 @@ class AddMetaPaths(BaseTransform):
|
|
158
158
|
edge_index, edge_weight)
|
159
159
|
|
160
160
|
new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
|
161
|
-
data[new_edge_type].edge_index = edge_index
|
161
|
+
data[new_edge_type].edge_index = edge_index.as_tensor()
|
162
162
|
if self.weighted:
|
163
163
|
data[new_edge_type].edge_weight = edge_weight
|
164
164
|
data.metapath_dict[new_edge_type] = metapath
|
@@ -231,7 +231,7 @@ class AddRandomMetaPaths(BaseTransform):
|
|
231
231
|
will drop node types not connected by any edge type.
|
232
232
|
(default: :obj:`False`)
|
233
233
|
walks_per_node (int, List[int], optional): The number of random walks
|
234
|
-
for each starting node in a
|
234
|
+
for each starting node in a metapath. (default: :obj:`1`)
|
235
235
|
sample_ratio (float, optional): The ratio of source nodes to start
|
236
236
|
random walks from. (default: :obj:`1.0`)
|
237
237
|
"""
|
@@ -92,7 +92,7 @@ class AddLaplacianEigenvectorPE(BaseTransform):
|
|
92
92
|
from numpy.linalg import eig, eigh
|
93
93
|
eig_fn = eig if not self.is_undirected else eigh
|
94
94
|
|
95
|
-
eig_vals, eig_vecs = eig_fn(L.todense())
|
95
|
+
eig_vals, eig_vecs = eig_fn(L.todense())
|
96
96
|
else:
|
97
97
|
from scipy.sparse.linalg import eigs, eigsh
|
98
98
|
eig_fn = eigs if not self.is_undirected else eigsh
|
@@ -1,4 +1,5 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
|
+
|
2
3
|
import torch
|
3
4
|
|
4
5
|
from torch_geometric.data import Data
|
@@ -6,28 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
|
|
6
7
|
from torch_geometric.transforms import BaseTransform
|
7
8
|
|
8
9
|
|
10
|
+
class _QhullTransform(BaseTransform):
|
11
|
+
r"""Q-hull implementation of delaunay triangulation."""
|
12
|
+
def forward(self, data: Data) -> Data:
|
13
|
+
assert data.pos is not None
|
14
|
+
import scipy.spatial
|
15
|
+
|
16
|
+
pos = data.pos.cpu().numpy()
|
17
|
+
tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
|
18
|
+
face = torch.from_numpy(tri.simplices)
|
19
|
+
|
20
|
+
data.face = face.t().contiguous().to(data.pos.device, torch.long)
|
21
|
+
return data
|
22
|
+
|
23
|
+
|
24
|
+
class _ShullTransform(BaseTransform):
|
25
|
+
r"""Sweep-hull implementation of delaunay triangulation."""
|
26
|
+
def forward(self, data: Data) -> Data:
|
27
|
+
assert data.pos is not None
|
28
|
+
from torch_delaunay.functional import shull2d
|
29
|
+
|
30
|
+
face = shull2d(data.pos.cpu())
|
31
|
+
data.face = face.t().contiguous().to(data.pos.device)
|
32
|
+
return data
|
33
|
+
|
34
|
+
|
35
|
+
class _SequentialTransform(BaseTransform):
|
36
|
+
r"""Runs the first successful transformation.
|
37
|
+
|
38
|
+
All intermediate exceptions are suppressed except the last.
|
39
|
+
"""
|
40
|
+
def __init__(self, transforms: List[BaseTransform]) -> None:
|
41
|
+
assert len(transforms) > 0
|
42
|
+
self.transforms = transforms
|
43
|
+
|
44
|
+
def forward(self, data: Data) -> Data:
|
45
|
+
for i, transform in enumerate(self.transforms):
|
46
|
+
try:
|
47
|
+
return transform.forward(data)
|
48
|
+
except ImportError as e:
|
49
|
+
if i == len(self.transforms) - 1:
|
50
|
+
raise e
|
51
|
+
return data
|
52
|
+
|
53
|
+
|
9
54
|
@functional_transform('delaunay')
|
10
55
|
class Delaunay(BaseTransform):
|
11
56
|
r"""Computes the delaunay triangulation of a set of points
|
12
57
|
(functional name: :obj:`delaunay`).
|
58
|
+
|
59
|
+
.. hint::
|
60
|
+
Consider installing the
|
61
|
+
`torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
|
62
|
+
to speed up computation.
|
13
63
|
"""
|
64
|
+
def __init__(self) -> None:
|
65
|
+
self._transform = _SequentialTransform([
|
66
|
+
_ShullTransform(),
|
67
|
+
_QhullTransform(),
|
68
|
+
])
|
69
|
+
|
14
70
|
def forward(self, data: Data) -> Data:
|
15
71
|
assert data.pos is not None
|
72
|
+
device = data.pos.device
|
16
73
|
|
17
74
|
if data.pos.size(0) < 2:
|
18
|
-
data.edge_index = torch.
|
19
|
-
|
20
|
-
|
21
|
-
data.edge_index = torch.tensor([[0, 1], [1, 0]],
|
22
|
-
device=data.pos.device)
|
75
|
+
data.edge_index = torch.empty(2, 0, dtype=torch.long,
|
76
|
+
device=device)
|
77
|
+
elif data.pos.size(0) == 2:
|
78
|
+
data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
|
23
79
|
elif data.pos.size(0) == 3:
|
24
|
-
data.face = torch.tensor([[0], [1], [2]],
|
25
|
-
|
26
|
-
|
27
|
-
pos = data.pos.cpu().numpy()
|
28
|
-
tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
|
29
|
-
face = torch.from_numpy(tri.simplices)
|
30
|
-
|
31
|
-
data.face = face.t().contiguous().to(data.pos.device, torch.long)
|
80
|
+
data.face = torch.tensor([[0], [1], [2]], device=device)
|
81
|
+
else:
|
82
|
+
data = self._transform.forward(data)
|
32
83
|
|
33
84
|
return data
|
@@ -8,8 +8,15 @@ from torch_geometric.utils import to_undirected
|
|
8
8
|
|
9
9
|
@functional_transform('face_to_edge')
|
10
10
|
class FaceToEdge(BaseTransform):
|
11
|
-
r"""Converts mesh faces :obj:`[3, num_faces]`
|
12
|
-
:obj:`[
|
11
|
+
r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
|
12
|
+
:obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
|
13
|
+
(functional name: :obj:`face_to_edge`).
|
14
|
+
|
15
|
+
This transform supports both 2D triangular faces, represented by a
|
16
|
+
tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
|
17
|
+
represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
|
18
|
+
these faces into edge indices, where each edge is defined by the indices
|
19
|
+
of its two endpoints.
|
13
20
|
|
14
21
|
Args:
|
15
22
|
remove_faces (bool, optional): If set to :obj:`False`, the face tensor
|
@@ -22,7 +29,29 @@ class FaceToEdge(BaseTransform):
|
|
22
29
|
if hasattr(data, 'face'):
|
23
30
|
assert data.face is not None
|
24
31
|
face = data.face
|
25
|
-
|
32
|
+
|
33
|
+
if face.size(0) not in [3, 4]:
|
34
|
+
raise RuntimeError(f"Expected 'face' tensor with shape "
|
35
|
+
f"[3, num_faces] or [4, num_faces] "
|
36
|
+
f"(got {list(face.size())})")
|
37
|
+
|
38
|
+
if face.size()[0] == 3:
|
39
|
+
edge_index = torch.cat([
|
40
|
+
face[:2],
|
41
|
+
face[1:],
|
42
|
+
face[::2],
|
43
|
+
], dim=1)
|
44
|
+
else:
|
45
|
+
assert face.size()[0] == 4
|
46
|
+
edge_index = torch.cat([
|
47
|
+
face[:2],
|
48
|
+
face[1:3],
|
49
|
+
face[2:4],
|
50
|
+
face[::2],
|
51
|
+
face[1::2],
|
52
|
+
face[::3],
|
53
|
+
], dim=1)
|
54
|
+
|
26
55
|
edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
|
27
56
|
|
28
57
|
data.edge_index = edge_index
|
@@ -2,7 +2,6 @@ from typing import Any, Dict, Tuple
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import torch
|
5
|
-
from scipy.linalg import expm
|
6
5
|
from torch import Tensor
|
7
6
|
|
8
7
|
from torch_geometric.data import Data
|
@@ -22,7 +21,7 @@ from torch_geometric.utils import (
|
|
22
21
|
@functional_transform('gdc')
|
23
22
|
class GDC(BaseTransform):
|
24
23
|
r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
|
25
|
-
`"Diffusion Improves Graph Learning" <https://
|
24
|
+
`"Diffusion Improves Graph Learning" <https://arxiv.org/abs/1911.05485>`_
|
26
25
|
paper (functional name: :obj:`gdc`).
|
27
26
|
|
28
27
|
.. note::
|
@@ -338,10 +337,10 @@ class GDC(BaseTransform):
|
|
338
337
|
|
339
338
|
elif method == 'heat':
|
340
339
|
raise NotImplementedError(
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
340
|
+
'Currently no fast heat kernel is implemented. You are '
|
341
|
+
'welcome to create one yourself, e.g., based on '
|
342
|
+
'"Kloster and Gleich: Heat kernel based community detection '
|
343
|
+
'(KDD 2014)."')
|
345
344
|
else:
|
346
345
|
raise ValueError(f"Approximate GDC diffusion '{method}' unknown")
|
347
346
|
|
@@ -473,6 +472,8 @@ class GDC(BaseTransform):
|
|
473
472
|
|
474
473
|
:rtype: (:class:`Tensor`)
|
475
474
|
"""
|
475
|
+
from scipy.linalg import expm
|
476
|
+
|
476
477
|
if symmetric:
|
477
478
|
e, V = torch.linalg.eigh(matrix, UPLO='U')
|
478
479
|
diff_mat = V @ torch.diag(e.exp()) @ V.t()
|
@@ -1,7 +1,5 @@
|
|
1
1
|
from typing import Optional
|
2
2
|
|
3
|
-
from scipy.sparse.linalg import eigs, eigsh
|
4
|
-
|
5
3
|
from torch_geometric.data import Data
|
6
4
|
from torch_geometric.data.datapipes import functional_transform
|
7
5
|
from torch_geometric.transforms import BaseTransform
|
@@ -41,6 +39,8 @@ class LaplacianLambdaMax(BaseTransform):
|
|
41
39
|
self.is_undirected = is_undirected
|
42
40
|
|
43
41
|
def forward(self, data: Data) -> Data:
|
42
|
+
from scipy.sparse.linalg import eigs, eigsh
|
43
|
+
|
44
44
|
assert data.edge_index is not None
|
45
45
|
num_nodes = data.num_nodes
|
46
46
|
|
@@ -62,7 +62,7 @@ class LaplacianLambdaMax(BaseTransform):
|
|
62
62
|
eig_fn = eigsh
|
63
63
|
|
64
64
|
lambda_max = eig_fn(L, k=1, which='LM', return_eigenvectors=False)
|
65
|
-
data.lambda_max =
|
65
|
+
data.lambda_max = lambda_max.real.item()
|
66
66
|
|
67
67
|
return data
|
68
68
|
|