pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/typing.py
CHANGED
|
@@ -3,7 +3,7 @@ import os
|
|
|
3
3
|
import sys
|
|
4
4
|
import typing
|
|
5
5
|
import warnings
|
|
6
|
-
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
6
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -14,8 +14,10 @@ WITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1
|
|
|
14
14
|
WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
|
|
15
15
|
WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
|
|
16
16
|
WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
|
|
17
|
-
|
|
18
|
-
|
|
17
|
+
WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
|
|
18
|
+
WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
|
|
19
|
+
WITH_PT27 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 7
|
|
20
|
+
WITH_PT28 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 8
|
|
19
21
|
WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
|
|
20
22
|
|
|
21
23
|
WITH_WINDOWS = os.name == 'nt'
|
|
@@ -62,10 +64,21 @@ try:
|
|
|
62
64
|
pyg_lib.sampler.neighbor_sample).parameters)
|
|
63
65
|
WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
|
|
64
66
|
pyg_lib.sampler.neighbor_sample).parameters)
|
|
67
|
+
try:
|
|
68
|
+
torch.classes.pyg.CPUHashMap # noqa: B018
|
|
69
|
+
WITH_CPU_HASH_MAP = True
|
|
70
|
+
except Exception:
|
|
71
|
+
WITH_CPU_HASH_MAP = False
|
|
72
|
+
try:
|
|
73
|
+
torch.classes.pyg.CUDAHashMap # noqa: B018
|
|
74
|
+
WITH_CUDA_HASH_MAP = True
|
|
75
|
+
except Exception:
|
|
76
|
+
WITH_CUDA_HASH_MAP = False
|
|
65
77
|
except Exception as e:
|
|
66
78
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
67
|
-
warnings.warn(
|
|
68
|
-
|
|
79
|
+
warnings.warn(
|
|
80
|
+
f"An issue occurred while importing 'pyg-lib'. "
|
|
81
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
69
82
|
pyg_lib = object
|
|
70
83
|
WITH_PYG_LIB = False
|
|
71
84
|
WITH_GMM = False
|
|
@@ -76,14 +89,41 @@ except Exception as e:
|
|
|
76
89
|
WITH_METIS = False
|
|
77
90
|
WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
|
|
78
91
|
WITH_WEIGHTED_NEIGHBOR_SAMPLE = False
|
|
92
|
+
WITH_CPU_HASH_MAP = False
|
|
93
|
+
WITH_CUDA_HASH_MAP = False
|
|
94
|
+
|
|
95
|
+
if WITH_CPU_HASH_MAP:
|
|
96
|
+
CPUHashMap: TypeAlias = torch.classes.pyg.CPUHashMap # type: ignore[name-defined] # noqa: E501
|
|
97
|
+
else:
|
|
98
|
+
|
|
99
|
+
class CPUHashMap: # type: ignore
|
|
100
|
+
def __init__(self, key: Tensor) -> None:
|
|
101
|
+
raise ImportError("'CPUHashMap' requires 'pyg-lib'")
|
|
102
|
+
|
|
103
|
+
def get(self, query: Tensor) -> Tensor:
|
|
104
|
+
raise ImportError("'CPUHashMap' requires 'pyg-lib'")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
if WITH_CUDA_HASH_MAP:
|
|
108
|
+
CUDAHashMap: TypeAlias = torch.classes.pyg.CUDAHashMap # type: ignore[name-defined] # noqa: E501
|
|
109
|
+
else:
|
|
110
|
+
|
|
111
|
+
class CUDAHashMap: # type: ignore
|
|
112
|
+
def __init__(self, key: Tensor) -> None:
|
|
113
|
+
raise ImportError("'CUDAHashMap' requires 'pyg-lib'")
|
|
114
|
+
|
|
115
|
+
def get(self, query: Tensor) -> Tensor:
|
|
116
|
+
raise ImportError("'CUDAHashMap' requires 'pyg-lib'")
|
|
117
|
+
|
|
79
118
|
|
|
80
119
|
try:
|
|
81
120
|
import torch_scatter # noqa
|
|
82
121
|
WITH_TORCH_SCATTER = True
|
|
83
122
|
except Exception as e:
|
|
84
123
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
85
|
-
warnings.warn(
|
|
86
|
-
|
|
124
|
+
warnings.warn(
|
|
125
|
+
f"An issue occurred while importing 'torch-scatter'. "
|
|
126
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
87
127
|
torch_scatter = object
|
|
88
128
|
WITH_TORCH_SCATTER = False
|
|
89
129
|
|
|
@@ -93,8 +133,9 @@ try:
|
|
|
93
133
|
WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__
|
|
94
134
|
except Exception as e:
|
|
95
135
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
96
|
-
warnings.warn(
|
|
97
|
-
|
|
136
|
+
warnings.warn(
|
|
137
|
+
f"An issue occurred while importing 'torch-cluster'. "
|
|
138
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
98
139
|
WITH_TORCH_CLUSTER = False
|
|
99
140
|
WITH_TORCH_CLUSTER_BATCH_SIZE = False
|
|
100
141
|
|
|
@@ -111,7 +152,7 @@ except Exception as e:
|
|
|
111
152
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
112
153
|
warnings.warn(
|
|
113
154
|
f"An issue occurred while importing 'torch-spline-conv'. "
|
|
114
|
-
f"Disabling its usage. Stacktrace: {e}")
|
|
155
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
115
156
|
WITH_TORCH_SPLINE_CONV = False
|
|
116
157
|
|
|
117
158
|
try:
|
|
@@ -120,8 +161,9 @@ try:
|
|
|
120
161
|
WITH_TORCH_SPARSE = True
|
|
121
162
|
except Exception as e:
|
|
122
163
|
if not isinstance(e, ImportError): # pragma: no cover
|
|
123
|
-
warnings.warn(
|
|
124
|
-
|
|
164
|
+
warnings.warn(
|
|
165
|
+
f"An issue occurred while importing 'torch-sparse'. "
|
|
166
|
+
f"Disabling its usage. Stacktrace: {e}", stacklevel=2)
|
|
125
167
|
WITH_TORCH_SPARSE = False
|
|
126
168
|
|
|
127
169
|
class SparseStorage: # type: ignore
|
|
@@ -305,6 +347,8 @@ class EdgeTypeStr(str):
|
|
|
305
347
|
r"""A helper class to construct serializable edge types by merging an edge
|
|
306
348
|
type tuple into a single string.
|
|
307
349
|
"""
|
|
350
|
+
edge_type: tuple[str, str, str]
|
|
351
|
+
|
|
308
352
|
def __new__(cls, *args: Any) -> 'EdgeTypeStr':
|
|
309
353
|
if isinstance(args[0], (list, tuple)):
|
|
310
354
|
# Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
|
|
@@ -312,27 +356,37 @@ class EdgeTypeStr(str):
|
|
|
312
356
|
|
|
313
357
|
if len(args) == 1 and isinstance(args[0], str):
|
|
314
358
|
arg = args[0] # An edge type string was passed.
|
|
359
|
+
edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
|
|
360
|
+
if len(edge_type) != 3:
|
|
361
|
+
raise ValueError(f"Cannot convert the edge type '{arg}' to a "
|
|
362
|
+
f"tuple since it holds invalid characters")
|
|
315
363
|
|
|
316
364
|
elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
|
|
317
365
|
# A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
|
|
318
|
-
|
|
366
|
+
edge_type = (args[0], DEFAULT_REL, args[1])
|
|
367
|
+
arg = EDGE_TYPE_STR_SPLIT.join(edge_type)
|
|
319
368
|
|
|
320
369
|
elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
|
|
321
370
|
# A `(src, rel, dst)` edge type was passed:
|
|
371
|
+
edge_type = tuple(args)
|
|
322
372
|
arg = EDGE_TYPE_STR_SPLIT.join(args)
|
|
323
373
|
|
|
324
374
|
else:
|
|
325
375
|
raise ValueError(f"Encountered invalid edge type '{args}'")
|
|
326
376
|
|
|
327
|
-
|
|
377
|
+
out = str.__new__(cls, arg)
|
|
378
|
+
out.edge_type = edge_type # type: ignore
|
|
379
|
+
return out
|
|
328
380
|
|
|
329
381
|
def to_tuple(self) -> EdgeType:
|
|
330
382
|
r"""Returns the original edge type."""
|
|
331
|
-
|
|
332
|
-
if len(out) != 3:
|
|
383
|
+
if len(self.edge_type) != 3:
|
|
333
384
|
raise ValueError(f"Cannot convert the edge type '{self}' to a "
|
|
334
385
|
f"tuple since it holds invalid characters")
|
|
335
|
-
return
|
|
386
|
+
return self.edge_type
|
|
387
|
+
|
|
388
|
+
def __reduce__(self) -> tuple[Any, Any]:
|
|
389
|
+
return (self.__class__, (self.edge_type, ))
|
|
336
390
|
|
|
337
391
|
|
|
338
392
|
# There exist some short-cuts to query edge-types (given that the full triplet
|
|
@@ -370,3 +424,14 @@ MaybeHeteroEdgeTensor = Union[Tensor, Dict[EdgeType, Tensor]]
|
|
|
370
424
|
|
|
371
425
|
InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]]
|
|
372
426
|
InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]]
|
|
427
|
+
|
|
428
|
+
# Serialization ###############################################################
|
|
429
|
+
|
|
430
|
+
if WITH_PT24:
|
|
431
|
+
torch.serialization.add_safe_globals([
|
|
432
|
+
SparseTensor,
|
|
433
|
+
SparseStorage,
|
|
434
|
+
TensorFrame,
|
|
435
|
+
MockTorchCSCTensor,
|
|
436
|
+
EdgeTypeStr,
|
|
437
|
+
])
|
|
@@ -21,6 +21,7 @@ from ._subgraph import (get_num_hops, subgraph, k_hop_subgraph,
|
|
|
21
21
|
from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path
|
|
22
22
|
from ._homophily import homophily
|
|
23
23
|
from ._assortativity import assortativity
|
|
24
|
+
from ._normalize_edge_index import normalize_edge_index
|
|
24
25
|
from .laplacian import get_laplacian
|
|
25
26
|
from .mesh_laplacian import get_mesh_laplacian
|
|
26
27
|
from .mask import mask_select, index_to_mask, mask_to_index
|
|
@@ -52,10 +53,11 @@ from ._negative_sampling import (negative_sampling, batched_negative_sampling,
|
|
|
52
53
|
structured_negative_sampling_feasible)
|
|
53
54
|
from .augmentation import shuffle_node, mask_feature, add_random_edge
|
|
54
55
|
from ._tree_decomposition import tree_decomposition
|
|
55
|
-
from .embedding import get_embeddings
|
|
56
|
+
from .embedding import get_embeddings, get_embeddings_hetero
|
|
56
57
|
from ._trim_to_layer import trim_to_layer
|
|
57
58
|
from .ppr import get_ppr
|
|
58
59
|
from ._train_test_split_edges import train_test_split_edges
|
|
60
|
+
from .influence import total_influence
|
|
59
61
|
|
|
60
62
|
__all__ = [
|
|
61
63
|
'scatter',
|
|
@@ -89,6 +91,7 @@ __all__ = [
|
|
|
89
91
|
'dropout_adj',
|
|
90
92
|
'homophily',
|
|
91
93
|
'assortativity',
|
|
94
|
+
'normalize_edge_index',
|
|
92
95
|
'get_laplacian',
|
|
93
96
|
'get_mesh_laplacian',
|
|
94
97
|
'mask_select',
|
|
@@ -143,9 +146,11 @@ __all__ = [
|
|
|
143
146
|
'add_random_edge',
|
|
144
147
|
'tree_decomposition',
|
|
145
148
|
'get_embeddings',
|
|
149
|
+
'get_embeddings_hetero',
|
|
146
150
|
'trim_to_layer',
|
|
147
151
|
'get_ppr',
|
|
148
152
|
'train_test_split_edges',
|
|
153
|
+
'total_influence',
|
|
149
154
|
]
|
|
150
155
|
|
|
151
156
|
# `structured_negative_sampling_feasible` is a long name and thus destroys the
|
|
@@ -1,11 +1,7 @@
|
|
|
1
1
|
from typing import List
|
|
2
2
|
|
|
3
|
-
import numpy as np
|
|
4
|
-
import torch
|
|
5
3
|
from torch import Tensor
|
|
6
4
|
|
|
7
|
-
import torch_geometric.typing
|
|
8
|
-
|
|
9
5
|
|
|
10
6
|
def lexsort(
|
|
11
7
|
keys: List[Tensor],
|
|
@@ -28,11 +24,6 @@ def lexsort(
|
|
|
28
24
|
"""
|
|
29
25
|
assert len(keys) >= 1
|
|
30
26
|
|
|
31
|
-
if not torch_geometric.typing.WITH_PT113:
|
|
32
|
-
keys = [k.neg() for k in keys] if descending else keys
|
|
33
|
-
out = np.lexsort([k.detach().cpu().numpy() for k in keys], axis=dim)
|
|
34
|
-
return torch.from_numpy(out).to(keys[0].device)
|
|
35
|
-
|
|
36
27
|
out = keys[0].argsort(dim=dim, descending=descending, stable=True)
|
|
37
28
|
for k in keys[1:]:
|
|
38
29
|
index = k.gather(dim, out)
|
|
@@ -12,7 +12,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
|
|
12
12
|
def negative_sampling(
|
|
13
13
|
edge_index: Tensor,
|
|
14
14
|
num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
|
|
15
|
-
num_neg_samples: Optional[int] = None,
|
|
15
|
+
num_neg_samples: Optional[Union[int, float]] = None,
|
|
16
16
|
method: str = "sparse",
|
|
17
17
|
force_undirected: bool = False,
|
|
18
18
|
) -> Tensor:
|
|
@@ -25,10 +25,12 @@ def negative_sampling(
|
|
|
25
25
|
If given as a tuple, then :obj:`edge_index` is interpreted as a
|
|
26
26
|
bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.
|
|
27
27
|
(default: :obj:`None`)
|
|
28
|
-
num_neg_samples (int, optional): The (approximate) number of
|
|
29
|
-
samples to return.
|
|
30
|
-
|
|
31
|
-
positive
|
|
28
|
+
num_neg_samples (int or float, optional): The (approximate) number of
|
|
29
|
+
negative samples to return. If set to a floating-point value, it
|
|
30
|
+
represents the ratio of negative samples to generate based on the
|
|
31
|
+
number of positive edges. If set to :obj:`None`, will try to
|
|
32
|
+
return a negative edge for every positive edge.
|
|
33
|
+
(default: :obj:`None`)
|
|
32
34
|
method (str, optional): The method to use for negative sampling,
|
|
33
35
|
*i.e.* :obj:`"sparse"` or :obj:`"dense"`.
|
|
34
36
|
This is a memory/runtime trade-off.
|
|
@@ -48,6 +50,11 @@ def negative_sampling(
|
|
|
48
50
|
tensor([[3, 0, 0, 3],
|
|
49
51
|
[2, 3, 2, 1]])
|
|
50
52
|
|
|
53
|
+
>>> negative_sampling(edge_index, num_nodes=(3, 4),
|
|
54
|
+
... num_neg_samples=0.5) # 50% of positive edges
|
|
55
|
+
tensor([[0, 3],
|
|
56
|
+
[3, 0]])
|
|
57
|
+
|
|
51
58
|
>>> # For bipartite graph
|
|
52
59
|
>>> negative_sampling(edge_index, num_nodes=(3, 4))
|
|
53
60
|
tensor([[0, 2, 2, 1],
|
|
@@ -74,6 +81,8 @@ def negative_sampling(
|
|
|
74
81
|
|
|
75
82
|
if num_neg_samples is None:
|
|
76
83
|
num_neg_samples = edge_index.size(1)
|
|
84
|
+
elif isinstance(num_neg_samples, float):
|
|
85
|
+
num_neg_samples = int(num_neg_samples * edge_index.size(1))
|
|
77
86
|
if force_undirected:
|
|
78
87
|
num_neg_samples = num_neg_samples // 2
|
|
79
88
|
|
|
@@ -100,10 +109,9 @@ def negative_sampling(
|
|
|
100
109
|
idx = idx.to('cpu')
|
|
101
110
|
for _ in range(3): # Number of tries to sample negative indices.
|
|
102
111
|
rnd = sample(population, sample_size, device='cpu')
|
|
103
|
-
mask = np.isin(rnd.numpy(), idx.numpy())
|
|
112
|
+
mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
|
|
104
113
|
if neg_idx is not None:
|
|
105
|
-
mask |= np.isin(rnd, neg_idx.
|
|
106
|
-
mask = torch.from_numpy(mask).to(torch.bool)
|
|
114
|
+
mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
|
|
107
115
|
rnd = rnd[~mask].to(edge_index.device)
|
|
108
116
|
neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
|
|
109
117
|
if neg_idx.numel() >= num_neg_samples:
|
|
@@ -117,7 +125,7 @@ def negative_sampling(
|
|
|
117
125
|
def batched_negative_sampling(
|
|
118
126
|
edge_index: Tensor,
|
|
119
127
|
batch: Union[Tensor, Tuple[Tensor, Tensor]],
|
|
120
|
-
num_neg_samples: Optional[int] = None,
|
|
128
|
+
num_neg_samples: Optional[Union[int, float]] = None,
|
|
121
129
|
method: str = "sparse",
|
|
122
130
|
force_undirected: bool = False,
|
|
123
131
|
) -> Tensor:
|
|
@@ -131,9 +139,11 @@ def batched_negative_sampling(
|
|
|
131
139
|
node to a specific example.
|
|
132
140
|
If given as a tuple, then :obj:`edge_index` is interpreted as a
|
|
133
141
|
bipartite graph connecting two different node types.
|
|
134
|
-
num_neg_samples (int, optional): The number of negative
|
|
135
|
-
return. If set to :obj:`None`, will try to return a
|
|
136
|
-
for every positive edge.
|
|
142
|
+
num_neg_samples (int or float, optional): The number of negative
|
|
143
|
+
samples to return. If set to :obj:`None`, will try to return a
|
|
144
|
+
negative edge for every positive edge. If float, it will generate
|
|
145
|
+
:obj:`num_neg_samples * num_edges` negative samples.
|
|
146
|
+
(default: :obj:`None`)
|
|
137
147
|
method (str, optional): The method to use for negative sampling,
|
|
138
148
|
*i.e.* :obj:`"sparse"` or :obj:`"dense"`.
|
|
139
149
|
This is a memory/runtime trade-off.
|
|
@@ -157,6 +167,11 @@ def batched_negative_sampling(
|
|
|
157
167
|
tensor([[3, 1, 3, 2, 7, 7, 6, 5],
|
|
158
168
|
[2, 0, 1, 1, 5, 6, 4, 4]])
|
|
159
169
|
|
|
170
|
+
>>> # Using float multiplier for negative samples
|
|
171
|
+
>>> batched_negative_sampling(edge_index, batch, num_neg_samples=1.5)
|
|
172
|
+
tensor([[3, 1, 3, 2, 7, 7, 6, 5, 2, 0, 1, 1],
|
|
173
|
+
[2, 0, 1, 1, 5, 6, 4, 4, 3, 2, 3, 0]])
|
|
174
|
+
|
|
160
175
|
>>> # For bipartite graph
|
|
161
176
|
>>> edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]])
|
|
162
177
|
>>> edge_index2 = edge_index1 + torch.tensor([[2], [4]])
|
|
@@ -265,7 +280,7 @@ def structured_negative_sampling_feasible(
|
|
|
265
280
|
:meth:`~torch_geometric.utils.structured_negative_sampling` is feasible
|
|
266
281
|
on the graph given by :obj:`edge_index`.
|
|
267
282
|
:meth:`~torch_geometric.utils.structured_negative_sampling` is infeasible
|
|
268
|
-
if
|
|
283
|
+
if at least one node is connected to all other nodes.
|
|
269
284
|
|
|
270
285
|
Args:
|
|
271
286
|
edge_index (LongTensor): The edge indices.
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from torch_geometric.utils import add_self_loops as add_self_loops_fn
|
|
7
|
+
from torch_geometric.utils import degree
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def normalize_edge_index(
|
|
11
|
+
edge_index: Tensor,
|
|
12
|
+
num_nodes: Optional[int] = None,
|
|
13
|
+
add_self_loops: bool = True,
|
|
14
|
+
symmetric: bool = True,
|
|
15
|
+
) -> Tuple[Tensor, Tensor]:
|
|
16
|
+
"""Applies normalization to the edges of a graph.
|
|
17
|
+
|
|
18
|
+
This function can add self-loops to the graph and apply either symmetric or
|
|
19
|
+
asymmetric normalization based on the node degrees.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
edge_index (LongTensor): The edge indices.
|
|
23
|
+
num_nodes (int, int], optional): The number of nodes, *i.e.*
|
|
24
|
+
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
|
|
25
|
+
add_self_loops (bool, optional): If set to :obj:`False`, will not add
|
|
26
|
+
self-loops to the input graph. (default: :obj:`True`)
|
|
27
|
+
symmetric (bool, optional): If set to :obj:`True`, symmetric
|
|
28
|
+
normalization (:math:`D^{-1/2} A D^{-1/2}`) is used, otherwise
|
|
29
|
+
asymmetric normalization (:math:`D^{-1} A`).
|
|
30
|
+
"""
|
|
31
|
+
if add_self_loops:
|
|
32
|
+
edge_index, _ = add_self_loops_fn(edge_index, num_nodes=num_nodes)
|
|
33
|
+
|
|
34
|
+
row, col = edge_index[0], edge_index[1]
|
|
35
|
+
deg = degree(row, num_nodes, dtype=torch.get_default_dtype())
|
|
36
|
+
|
|
37
|
+
if symmetric: # D^-1/2 * A * D^-1/2
|
|
38
|
+
deg_inv_sqrt = deg.pow(-0.5)
|
|
39
|
+
deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
|
|
40
|
+
edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]
|
|
41
|
+
else: # D^-1 * A
|
|
42
|
+
deg_inv = deg.pow(-1)
|
|
43
|
+
deg_inv[torch.isinf(deg_inv)] = 0
|
|
44
|
+
edge_weight = deg_inv[row]
|
|
45
|
+
|
|
46
|
+
return edge_index, edge_weight
|