pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -1,7 +1,5 @@
|
|
1
1
|
from typing import Optional
|
2
2
|
|
3
|
-
from scipy.sparse.linalg import eigs, eigsh
|
4
|
-
|
5
3
|
from torch_geometric.data import Data
|
6
4
|
from torch_geometric.data.datapipes import functional_transform
|
7
5
|
from torch_geometric.transforms import BaseTransform
|
@@ -41,6 +39,8 @@ class LaplacianLambdaMax(BaseTransform):
|
|
41
39
|
self.is_undirected = is_undirected
|
42
40
|
|
43
41
|
def forward(self, data: Data) -> Data:
|
42
|
+
from scipy.sparse.linalg import eigs, eigsh
|
43
|
+
|
44
44
|
assert data.edge_index is not None
|
45
45
|
num_nodes = data.num_nodes
|
46
46
|
|
@@ -62,7 +62,7 @@ class LaplacianLambdaMax(BaseTransform):
|
|
62
62
|
eig_fn = eigsh
|
63
63
|
|
64
64
|
lambda_max = eig_fn(L, k=1, which='LM', return_eigenvectors=False)
|
65
|
-
data.lambda_max =
|
65
|
+
data.lambda_max = lambda_max.real.item()
|
66
66
|
|
67
67
|
return data
|
68
68
|
|
@@ -19,7 +19,11 @@ def get_attrs_with_suffix(
|
|
19
19
|
return [key for key in store.keys() if key.endswith(suffix)]
|
20
20
|
|
21
21
|
|
22
|
-
def get_mask_size(
|
22
|
+
def get_mask_size(
|
23
|
+
attr: str,
|
24
|
+
store: BaseStorage,
|
25
|
+
size: Optional[int],
|
26
|
+
) -> Optional[int]:
|
23
27
|
if size is not None:
|
24
28
|
return size
|
25
29
|
return store.num_edges if store.is_edge_attr(attr) else store.num_nodes
|
@@ -44,7 +44,6 @@ class NodePropertySplit(BaseTransform):
|
|
44
44
|
of the node property, so that nodes with greater values of the
|
45
45
|
property are considered to be OOD (default: :obj:`True`)
|
46
46
|
|
47
|
-
Example:
|
48
47
|
.. code-block:: python
|
49
48
|
|
50
49
|
from torch_geometric.transforms import NodePropertySplit
|
@@ -54,7 +53,7 @@ class NodePropertySplit(BaseTransform):
|
|
54
53
|
|
55
54
|
property_name = 'popularity'
|
56
55
|
ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
|
57
|
-
|
56
|
+
transform = NodePropertySplit(property_name, ratios)
|
58
57
|
|
59
58
|
data = transform(data)
|
60
59
|
"""
|
@@ -262,15 +262,14 @@ class Pad(BaseTransform):
|
|
262
262
|
All the attributes of node types other than :obj:`v0` and :obj:`v1` are
|
263
263
|
padded using a value of :obj:`1.0`.
|
264
264
|
All the attributes of the :obj:`('v0', 'e0', 'v1')` edge type are padded
|
265
|
-
|
265
|
+
using a value of :obj:`3.5`.
|
266
266
|
The :obj:`edge_attr` attributes of the :obj:`('v1', 'e0', 'v0')` edge type
|
267
267
|
are padded using a value of :obj:`-1.5`, and any other attributes of this
|
268
268
|
edge type are padded using a value of :obj:`5.5`.
|
269
269
|
All the attributes of edge types other than these two are padded using a
|
270
270
|
value of :obj:`1.5`.
|
271
271
|
|
272
|
-
|
273
|
-
.. code-block::
|
272
|
+
.. code-block:: python
|
274
273
|
|
275
274
|
num_nodes = {'v0': 10, 'v1': 20, 'v2':30}
|
276
275
|
num_edges = {('v0', 'e0', 'v1'): 80}
|
@@ -467,9 +466,11 @@ class Pad(BaseTransform):
|
|
467
466
|
edge_type: Optional[EdgeType] = None,
|
468
467
|
) -> None:
|
469
468
|
|
470
|
-
attrs_to_pad =
|
471
|
-
attr
|
472
|
-
|
469
|
+
attrs_to_pad = {
|
470
|
+
attr
|
471
|
+
for attr in store.keys()
|
472
|
+
if store.is_edge_attr(attr) and self.__should_pad_edge_attr(attr)
|
473
|
+
}
|
473
474
|
if not attrs_to_pad:
|
474
475
|
return
|
475
476
|
num_target_edges = self.max_num_edges.get_value(edge_type)
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from typing import Union
|
2
|
+
|
3
|
+
from torch_geometric.data import Data, HeteroData
|
4
|
+
from torch_geometric.data.datapipes import functional_transform
|
5
|
+
from torch_geometric.transforms import BaseTransform
|
6
|
+
from torch_geometric.utils import remove_self_loops
|
7
|
+
|
8
|
+
|
9
|
+
@functional_transform('remove_self_loops')
|
10
|
+
class RemoveSelfLoops(BaseTransform):
|
11
|
+
r"""Removes all self-loops in the given homogeneous or heterogeneous
|
12
|
+
graph (functional name: :obj:`remove_self_loops`).
|
13
|
+
|
14
|
+
Args:
|
15
|
+
attr (str, optional): The name of the attribute of edge weights
|
16
|
+
or multi-dimensional edge features to pass to
|
17
|
+
:meth:`torch_geometric.utils.remove_self_loops`.
|
18
|
+
(default: :obj:`"edge_weight"`)
|
19
|
+
"""
|
20
|
+
def __init__(self, attr: str = 'edge_weight') -> None:
|
21
|
+
self.attr = attr
|
22
|
+
|
23
|
+
def forward(
|
24
|
+
self,
|
25
|
+
data: Union[Data, HeteroData],
|
26
|
+
) -> Union[Data, HeteroData]:
|
27
|
+
for store in data.edge_stores:
|
28
|
+
if store.is_bipartite() or 'edge_index' not in store:
|
29
|
+
continue
|
30
|
+
|
31
|
+
store.edge_index, store[self.attr] = remove_self_loops(
|
32
|
+
store.edge_index,
|
33
|
+
edge_attr=store.get(self.attr, None),
|
34
|
+
)
|
35
|
+
|
36
|
+
return data
|
@@ -11,7 +11,7 @@ class SVDFeatureReduction(BaseTransform):
|
|
11
11
|
Decomposition (SVD) (functional name: :obj:`svd_feature_reduction`).
|
12
12
|
|
13
13
|
Args:
|
14
|
-
out_channels (int): The
|
14
|
+
out_channels (int): The dimensionality of node features after
|
15
15
|
reduction.
|
16
16
|
"""
|
17
17
|
def __init__(self, out_channels: int):
|
@@ -19,7 +19,7 @@ class TwoHop(BaseTransform):
|
|
19
19
|
|
20
20
|
edge_index = EdgeIndex(edge_index, sparse_size=(N, N))
|
21
21
|
edge_index = edge_index.sort_by('row')[0]
|
22
|
-
edge_index2
|
22
|
+
edge_index2 = edge_index.matmul(edge_index)[0].as_tensor()
|
23
23
|
edge_index2, _ = remove_self_loops(edge_index2)
|
24
24
|
edge_index = torch.cat([edge_index, edge_index2], dim=1)
|
25
25
|
|
@@ -37,7 +37,8 @@ class VirtualNode(BaseTransform):
|
|
37
37
|
col = torch.cat([col, full, arange], dim=0)
|
38
38
|
edge_index = torch.stack([row, col], dim=0)
|
39
39
|
|
40
|
-
|
40
|
+
num_edge_types = int(edge_type.max()) if edge_type.numel() > 0 else 0
|
41
|
+
new_type = edge_type.new_full((num_nodes, ), num_edge_types + 1)
|
41
42
|
edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0)
|
42
43
|
|
43
44
|
old_data = copy.copy(data)
|
torch_geometric/typing.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
import inspect
|
2
2
|
import os
|
3
3
|
import sys
|
4
|
+
import typing
|
4
5
|
import warnings
|
5
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
6
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
6
7
|
|
7
8
|
import numpy as np
|
8
9
|
import torch
|
@@ -12,6 +13,9 @@ WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2
|
|
12
13
|
WITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1
|
13
14
|
WITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2
|
14
15
|
WITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3
|
16
|
+
WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
|
17
|
+
WITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5
|
18
|
+
WITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6
|
15
19
|
WITH_PT111 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 11
|
16
20
|
WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
|
17
21
|
WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13
|
@@ -21,6 +25,16 @@ NO_MKL = 'USE_MKL=OFF' in torch.__config__.show() or WITH_WINDOWS
|
|
21
25
|
|
22
26
|
MAX_INT64 = torch.iinfo(torch.int64).max
|
23
27
|
|
28
|
+
if WITH_PT20:
|
29
|
+
INDEX_DTYPES: Set[torch.dtype] = {
|
30
|
+
torch.int32,
|
31
|
+
torch.int64,
|
32
|
+
}
|
33
|
+
elif not typing.TYPE_CHECKING: # pragma: no cover
|
34
|
+
INDEX_DTYPES: Set[torch.dtype] = {
|
35
|
+
torch.int64,
|
36
|
+
}
|
37
|
+
|
24
38
|
if not hasattr(torch, 'sparse_csc'):
|
25
39
|
torch.sparse_csc = torch.sparse_coo
|
26
40
|
|
@@ -293,6 +307,8 @@ class EdgeTypeStr(str):
|
|
293
307
|
r"""A helper class to construct serializable edge types by merging an edge
|
294
308
|
type tuple into a single string.
|
295
309
|
"""
|
310
|
+
edge_type: tuple[str, str, str]
|
311
|
+
|
296
312
|
def __new__(cls, *args: Any) -> 'EdgeTypeStr':
|
297
313
|
if isinstance(args[0], (list, tuple)):
|
298
314
|
# Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
|
@@ -300,27 +316,37 @@ class EdgeTypeStr(str):
|
|
300
316
|
|
301
317
|
if len(args) == 1 and isinstance(args[0], str):
|
302
318
|
arg = args[0] # An edge type string was passed.
|
319
|
+
edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
|
320
|
+
if len(edge_type) != 3:
|
321
|
+
raise ValueError(f"Cannot convert the edge type '{arg}' to a "
|
322
|
+
f"tuple since it holds invalid characters")
|
303
323
|
|
304
324
|
elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
|
305
325
|
# A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
|
306
|
-
|
326
|
+
edge_type = (args[0], DEFAULT_REL, args[1])
|
327
|
+
arg = EDGE_TYPE_STR_SPLIT.join(edge_type)
|
307
328
|
|
308
329
|
elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
|
309
330
|
# A `(src, rel, dst)` edge type was passed:
|
331
|
+
edge_type = tuple(args)
|
310
332
|
arg = EDGE_TYPE_STR_SPLIT.join(args)
|
311
333
|
|
312
334
|
else:
|
313
335
|
raise ValueError(f"Encountered invalid edge type '{args}'")
|
314
336
|
|
315
|
-
|
337
|
+
out = str.__new__(cls, arg)
|
338
|
+
out.edge_type = edge_type # type: ignore
|
339
|
+
return out
|
316
340
|
|
317
341
|
def to_tuple(self) -> EdgeType:
|
318
342
|
r"""Returns the original edge type."""
|
319
|
-
|
320
|
-
if len(out) != 3:
|
343
|
+
if len(self.edge_type) != 3:
|
321
344
|
raise ValueError(f"Cannot convert the edge type '{self}' to a "
|
322
345
|
f"tuple since it holds invalid characters")
|
323
|
-
return
|
346
|
+
return self.edge_type
|
347
|
+
|
348
|
+
def __reduce__(self) -> tuple[Any, Any]:
|
349
|
+
return (self.__class__, (self.edge_type, ))
|
324
350
|
|
325
351
|
|
326
352
|
# There exist some short-cuts to query edge-types (given that the full triplet
|
@@ -358,3 +384,14 @@ MaybeHeteroEdgeTensor = Union[Tensor, Dict[EdgeType, Tensor]]
|
|
358
384
|
|
359
385
|
InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]]
|
360
386
|
InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]]
|
387
|
+
|
388
|
+
# Serialization ###############################################################
|
389
|
+
|
390
|
+
if WITH_PT24:
|
391
|
+
torch.serialization.add_safe_globals([
|
392
|
+
SparseTensor,
|
393
|
+
SparseStorage,
|
394
|
+
TensorFrame,
|
395
|
+
MockTorchCSCTensor,
|
396
|
+
EdgeTypeStr,
|
397
|
+
])
|
@@ -21,6 +21,7 @@ from ._subgraph import (get_num_hops, subgraph, k_hop_subgraph,
|
|
21
21
|
from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path
|
22
22
|
from ._homophily import homophily
|
23
23
|
from ._assortativity import assortativity
|
24
|
+
from ._normalize_edge_index import normalize_edge_index
|
24
25
|
from .laplacian import get_laplacian
|
25
26
|
from .mesh_laplacian import get_mesh_laplacian
|
26
27
|
from .mask import mask_select, index_to_mask, mask_to_index
|
@@ -44,7 +45,7 @@ from .convert import to_networkit, from_networkit
|
|
44
45
|
from .convert import to_trimesh, from_trimesh
|
45
46
|
from .convert import to_cugraph, from_cugraph
|
46
47
|
from .convert import to_dgl, from_dgl
|
47
|
-
from .smiles import from_smiles, to_smiles
|
48
|
+
from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles
|
48
49
|
from .random import (erdos_renyi_graph, stochastic_blockmodel_graph,
|
49
50
|
barabasi_albert_graph)
|
50
51
|
from ._negative_sampling import (negative_sampling, batched_negative_sampling,
|
@@ -89,6 +90,7 @@ __all__ = [
|
|
89
90
|
'dropout_adj',
|
90
91
|
'homophily',
|
91
92
|
'assortativity',
|
93
|
+
'normalize_edge_index',
|
92
94
|
'get_laplacian',
|
93
95
|
'get_mesh_laplacian',
|
94
96
|
'mask_select',
|
@@ -127,6 +129,8 @@ __all__ = [
|
|
127
129
|
'from_cugraph',
|
128
130
|
'to_dgl',
|
129
131
|
'from_dgl',
|
132
|
+
'from_rdmol',
|
133
|
+
'to_rdmol',
|
130
134
|
'from_smiles',
|
131
135
|
'to_smiles',
|
132
136
|
'erdos_renyi_graph',
|
@@ -265,7 +265,7 @@ def structured_negative_sampling_feasible(
|
|
265
265
|
:meth:`~torch_geometric.utils.structured_negative_sampling` is feasible
|
266
266
|
on the graph given by :obj:`edge_index`.
|
267
267
|
:meth:`~torch_geometric.utils.structured_negative_sampling` is infeasible
|
268
|
-
if
|
268
|
+
if at least one node is connected to all other nodes.
|
269
269
|
|
270
270
|
Args:
|
271
271
|
edge_index (LongTensor): The edge indices.
|
@@ -0,0 +1,46 @@
|
|
1
|
+
from typing import Optional, Tuple
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch import Tensor
|
5
|
+
|
6
|
+
from torch_geometric.utils import add_self_loops as add_self_loops_fn
|
7
|
+
from torch_geometric.utils import degree
|
8
|
+
|
9
|
+
|
10
|
+
def normalize_edge_index(
|
11
|
+
edge_index: Tensor,
|
12
|
+
num_nodes: Optional[int] = None,
|
13
|
+
add_self_loops: bool = True,
|
14
|
+
symmetric: bool = True,
|
15
|
+
) -> Tuple[Tensor, Tensor]:
|
16
|
+
"""Applies normalization to the edges of a graph.
|
17
|
+
|
18
|
+
This function can add self-loops to the graph and apply either symmetric or
|
19
|
+
asymmetric normalization based on the node degrees.
|
20
|
+
|
21
|
+
Args:
|
22
|
+
edge_index (LongTensor): The edge indices.
|
23
|
+
num_nodes (int, int], optional): The number of nodes, *i.e.*
|
24
|
+
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
|
25
|
+
add_self_loops (bool, optional): If set to :obj:`False`, will not add
|
26
|
+
self-loops to the input graph. (default: :obj:`True`)
|
27
|
+
symmetric (bool, optional): If set to :obj:`True`, symmetric
|
28
|
+
normalization (:math:`D^{-1/2} A D^{-1/2}`) is used, otherwise
|
29
|
+
asymmetric normalization (:math:`D^{-1} A`).
|
30
|
+
"""
|
31
|
+
if add_self_loops:
|
32
|
+
edge_index, _ = add_self_loops_fn(edge_index, num_nodes=num_nodes)
|
33
|
+
|
34
|
+
row, col = edge_index[0], edge_index[1]
|
35
|
+
deg = degree(row, num_nodes, dtype=torch.get_default_dtype())
|
36
|
+
|
37
|
+
if symmetric: # D^-1/2 * A * D^-1/2
|
38
|
+
deg_inv_sqrt = deg.pow(-0.5)
|
39
|
+
deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
|
40
|
+
edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]
|
41
|
+
else: # D^-1 * A
|
42
|
+
deg_inv = deg.pow(-1)
|
43
|
+
deg_inv[torch.isinf(deg_inv)] = 0
|
44
|
+
edge_weight = deg_inv[row]
|
45
|
+
|
46
|
+
return edge_index, edge_weight
|
@@ -4,7 +4,7 @@ import torch
|
|
4
4
|
from torch import Tensor
|
5
5
|
|
6
6
|
import torch_geometric.typing
|
7
|
-
from torch_geometric import is_compiling, warnings
|
7
|
+
from torch_geometric import is_compiling, is_in_onnx_export, warnings
|
8
8
|
from torch_geometric.typing import torch_scatter
|
9
9
|
from torch_geometric.utils.functions import cumsum
|
10
10
|
|
@@ -88,18 +88,33 @@ if torch_geometric.typing.WITH_PT112: # pragma: no cover
|
|
88
88
|
# in case the input does not require gradients:
|
89
89
|
if reduce in ['min', 'max', 'amin', 'amax']:
|
90
90
|
if (not torch_geometric.typing.WITH_TORCH_SCATTER
|
91
|
-
or is_compiling() or not src.is_cuda
|
91
|
+
or is_compiling() or is_in_onnx_export() or not src.is_cuda
|
92
92
|
or not src.requires_grad):
|
93
93
|
|
94
|
-
if src.is_cuda and src.requires_grad and not is_compiling()
|
94
|
+
if (src.is_cuda and src.requires_grad and not is_compiling()
|
95
|
+
and not is_in_onnx_export()):
|
95
96
|
warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
|
96
97
|
f"can be accelerated via the 'torch-scatter'"
|
97
98
|
f" package, but it was not found")
|
98
99
|
|
99
100
|
index = broadcast(index, src, dim)
|
100
|
-
|
101
|
-
|
102
|
-
|
101
|
+
if not is_in_onnx_export():
|
102
|
+
return src.new_zeros(size).scatter_reduce_(
|
103
|
+
dim, index, src, reduce=f'a{reduce[-3:]}',
|
104
|
+
include_self=False)
|
105
|
+
|
106
|
+
fill = torch.full( # type: ignore
|
107
|
+
size=(1, ),
|
108
|
+
fill_value=src.min() if 'max' in reduce else src.max(),
|
109
|
+
dtype=src.dtype,
|
110
|
+
device=src.device,
|
111
|
+
).expand_as(src)
|
112
|
+
out = src.new_zeros(size).scatter_reduce_(
|
113
|
+
dim, index, fill, reduce=f'a{reduce[-3:]}',
|
114
|
+
include_self=True)
|
115
|
+
return out.scatter_reduce_(dim, index, src,
|
116
|
+
reduce=f'a{reduce[-3:]}',
|
117
|
+
include_self=True)
|
103
118
|
|
104
119
|
return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
|
105
120
|
reduce=reduce[-3:])
|
@@ -175,6 +190,7 @@ else: # pragma: no cover
|
|
175
190
|
|
176
191
|
|
177
192
|
def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:
|
193
|
+
dim = ref.dim() + dim if dim < 0 else dim
|
178
194
|
size = ((1, ) * dim) + (-1, ) + ((1, ) * (ref.dim() - dim - 1))
|
179
195
|
return src.view(size).expand_as(ref)
|
180
196
|
|
@@ -186,7 +202,8 @@ def scatter_argmax(
|
|
186
202
|
dim_size: Optional[int] = None,
|
187
203
|
) -> Tensor:
|
188
204
|
|
189
|
-
if torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()
|
205
|
+
if (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()
|
206
|
+
and not is_in_onnx_export()):
|
190
207
|
out = torch_scatter.scatter_max(src, index, dim=dim, dim_size=dim_size)
|
191
208
|
return out[1]
|
192
209
|
|
@@ -199,9 +216,18 @@ def scatter_argmax(
|
|
199
216
|
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
|
200
217
|
|
201
218
|
if torch_geometric.typing.WITH_PT112:
|
202
|
-
|
203
|
-
|
204
|
-
|
219
|
+
if not is_in_onnx_export():
|
220
|
+
res = src.new_empty(dim_size)
|
221
|
+
res.scatter_reduce_(0, index, src.detach(), reduce='amax',
|
222
|
+
include_self=False)
|
223
|
+
else:
|
224
|
+
# `include_self=False` is currently not supported by ONNX:
|
225
|
+
res = src.new_full(
|
226
|
+
size=(dim_size, ),
|
227
|
+
fill_value=src.min(), # type: ignore
|
228
|
+
)
|
229
|
+
res.scatter_reduce_(0, index, src.detach(), reduce="amax",
|
230
|
+
include_self=True)
|
205
231
|
elif torch_geometric.typing.WITH_PT111:
|
206
232
|
res = torch.scatter_reduce(src.detach(), 0, index, reduce='amax',
|
207
233
|
output_size=dim_size) # type: ignore
|
@@ -294,7 +320,7 @@ def group_cat(
|
|
294
320
|
r"""Concatenates the given sequence of tensors :obj:`tensors` in the given
|
295
321
|
dimension :obj:`dim`.
|
296
322
|
Different from :meth:`torch.cat`, values along the concatenating dimension
|
297
|
-
are grouped according to the
|
323
|
+
are grouped according to the indices defined in the :obj:`index` tensors.
|
298
324
|
All tensors must have the same shape (except in the concatenating
|
299
325
|
dimension).
|
300
326
|
|
@@ -325,5 +351,5 @@ def group_cat(
|
|
325
351
|
"""
|
326
352
|
assert len(tensors) == len(indices)
|
327
353
|
index, perm = torch.cat(indices).sort(stable=True)
|
328
|
-
out = torch.cat(tensors, dim=
|
354
|
+
out = torch.cat(tensors, dim=dim).index_select(dim, perm)
|
329
355
|
return (out, index) if return_index else out
|
@@ -346,10 +346,12 @@ def k_hop_subgraph(
|
|
346
346
|
|
347
347
|
subsets = [node_idx]
|
348
348
|
|
349
|
+
preserved_edge_mask = torch.zeros_like(edge_mask)
|
349
350
|
for _ in range(num_hops):
|
350
351
|
node_mask.fill_(False)
|
351
352
|
node_mask[subsets[-1]] = True
|
352
353
|
torch.index_select(node_mask, 0, row, out=edge_mask)
|
354
|
+
preserved_edge_mask |= edge_mask
|
353
355
|
subsets.append(col[edge_mask])
|
354
356
|
|
355
357
|
subset, inv = torch.cat(subsets).unique(return_inverse=True)
|
@@ -360,6 +362,8 @@ def k_hop_subgraph(
|
|
360
362
|
|
361
363
|
if not directed:
|
362
364
|
edge_mask = node_mask[row] & node_mask[col]
|
365
|
+
else:
|
366
|
+
edge_mask = preserved_edge_mask
|
363
367
|
|
364
368
|
edge_index = edge_index[:, edge_mask]
|
365
369
|
|
@@ -2,7 +2,6 @@ from itertools import chain
|
|
2
2
|
from typing import Any, List, Literal, Tuple, Union, overload
|
3
3
|
|
4
4
|
import torch
|
5
|
-
from scipy.sparse.csgraph import minimum_spanning_tree
|
6
5
|
from torch import Tensor
|
7
6
|
|
8
7
|
from torch_geometric.utils import (
|
@@ -54,6 +53,7 @@ def tree_decomposition(
|
|
54
53
|
:obj:`False`, else :obj:`(LongTensor, LongTensor, int, LongTensor)`
|
55
54
|
"""
|
56
55
|
import rdkit.Chem as Chem
|
56
|
+
from scipy.sparse.csgraph import minimum_spanning_tree
|
57
57
|
|
58
58
|
# Cliques = rings and bonds.
|
59
59
|
cliques: List[List[int]] = [list(x) for x in Chem.GetSymmSSSR(mol)]
|
@@ -64,7 +64,7 @@ def tree_decomposition(
|
|
64
64
|
xs.append(1)
|
65
65
|
|
66
66
|
# Generate `atom2cliques` mappings.
|
67
|
-
atom2cliques: List[List[int]] = [[] for
|
67
|
+
atom2cliques: List[List[int]] = [[] for _ in range(mol.GetNumAtoms())]
|
68
68
|
for c in range(len(cliques)):
|
69
69
|
for atom in cliques[c]:
|
70
70
|
atom2cliques[atom].append(c)
|
@@ -12,7 +12,7 @@ def shuffle_node(
|
|
12
12
|
training: bool = True,
|
13
13
|
) -> Tuple[Tensor, Tensor]:
|
14
14
|
r"""Randomly shuffle the feature matrix :obj:`x` along the
|
15
|
-
first
|
15
|
+
first dimension.
|
16
16
|
|
17
17
|
The method returns (1) the shuffled :obj:`x`, (2) the permutation
|
18
18
|
indicating the orders of original nodes after shuffling.
|
torch_geometric/utils/convert.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
from collections import defaultdict
|
2
2
|
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
|
3
3
|
|
4
|
-
import scipy.sparse
|
5
4
|
import torch
|
6
5
|
from torch import Tensor
|
7
6
|
from torch.utils.dlpack import from_dlpack, to_dlpack
|
@@ -14,7 +13,7 @@ def to_scipy_sparse_matrix(
|
|
14
13
|
edge_index: Tensor,
|
15
14
|
edge_attr: Optional[Tensor] = None,
|
16
15
|
num_nodes: Optional[int] = None,
|
17
|
-
) ->
|
16
|
+
) -> Any:
|
18
17
|
r"""Converts a graph given by edge indices and edge attributes to a scipy
|
19
18
|
sparse matrix.
|
20
19
|
|
@@ -34,22 +33,23 @@ def to_scipy_sparse_matrix(
|
|
34
33
|
<4x4 sparse matrix of type '<class 'numpy.float32'>'
|
35
34
|
with 6 stored elements in COOrdinate format>
|
36
35
|
"""
|
36
|
+
import scipy.sparse as sp
|
37
|
+
|
37
38
|
row, col = edge_index.cpu()
|
38
39
|
|
39
40
|
if edge_attr is None:
|
40
|
-
edge_attr = torch.ones(row.size(0))
|
41
|
+
edge_attr = torch.ones(row.size(0), device="cpu")
|
41
42
|
else:
|
42
43
|
edge_attr = edge_attr.view(-1).cpu()
|
43
44
|
assert edge_attr.size(0) == row.size(0)
|
44
45
|
|
45
46
|
N = maybe_num_nodes(edge_index, num_nodes)
|
46
|
-
out =
|
47
|
+
out = sp.coo_matrix( #
|
47
48
|
(edge_attr.numpy(), (row.numpy(), col.numpy())), (N, N))
|
48
49
|
return out
|
49
50
|
|
50
51
|
|
51
|
-
def from_scipy_sparse_matrix(
|
52
|
-
A: scipy.sparse.spmatrix) -> Tuple[Tensor, Tensor]:
|
52
|
+
def from_scipy_sparse_matrix(A: Any) -> Tuple[Tensor, Tensor]:
|
53
53
|
r"""Converts a scipy sparse matrix to edge indices and edge attributes.
|
54
54
|
|
55
55
|
Args:
|
@@ -527,10 +527,14 @@ def to_dgl(
|
|
527
527
|
if isinstance(data, Data):
|
528
528
|
if data.edge_index is not None:
|
529
529
|
row, col = data.edge_index
|
530
|
-
|
530
|
+
elif 'adj' in data:
|
531
|
+
row, col, _ = data.adj.coo()
|
532
|
+
elif 'adj_t' in data:
|
531
533
|
row, col, _ = data.adj_t.t().coo()
|
534
|
+
else:
|
535
|
+
row, col = [], []
|
532
536
|
|
533
|
-
g = dgl.graph((row, col))
|
537
|
+
g = dgl.graph((row, col), num_nodes=data.num_nodes)
|
534
538
|
|
535
539
|
for attr in data.node_attrs():
|
536
540
|
g.ndata[attr] = data[attr]
|