pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -10,10 +10,16 @@ from packaging.requirements import Requirement
|
|
|
10
10
|
from packaging.version import Version
|
|
11
11
|
|
|
12
12
|
import torch_geometric
|
|
13
|
+
import torch_geometric.typing
|
|
13
14
|
from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
|
|
14
15
|
from torch_geometric.visualization.graph import has_graphviz
|
|
15
16
|
|
|
16
17
|
|
|
18
|
+
def is_rag_test() -> bool:
|
|
19
|
+
r"""Whether to run the RAG test suite."""
|
|
20
|
+
return os.getenv('RAG_TEST', '0') == '1'
|
|
21
|
+
|
|
22
|
+
|
|
17
23
|
def is_full_test() -> bool:
|
|
18
24
|
r"""Whether to run the full but time-consuming test suite."""
|
|
19
25
|
return os.getenv('FULL_TEST', '0') == '1'
|
|
@@ -32,8 +38,8 @@ def onlyFullTest(func: Callable) -> Callable:
|
|
|
32
38
|
|
|
33
39
|
def is_distributed_test() -> bool:
|
|
34
40
|
r"""Whether to run the distributed test suite."""
|
|
35
|
-
return (
|
|
36
|
-
and
|
|
41
|
+
return (os.getenv('DIST_TEST', '0') == '1' and sys.platform == 'linux'
|
|
42
|
+
and has_package('pyg_lib'))
|
|
37
43
|
|
|
38
44
|
|
|
39
45
|
def onlyDistributedTest(func: Callable) -> Callable:
|
|
@@ -203,6 +209,18 @@ def withPackage(*args: str) -> Callable:
|
|
|
203
209
|
return decorator
|
|
204
210
|
|
|
205
211
|
|
|
212
|
+
def onlyRAG(func: Callable) -> Callable:
|
|
213
|
+
r"""A decorator to specify that this function belongs to the RAG test
|
|
214
|
+
suite.
|
|
215
|
+
"""
|
|
216
|
+
import pytest
|
|
217
|
+
func = pytest.mark.rag(func)
|
|
218
|
+
return pytest.mark.skipif(
|
|
219
|
+
not is_rag_test(),
|
|
220
|
+
reason="RAG tests are disabled",
|
|
221
|
+
)(func)
|
|
222
|
+
|
|
223
|
+
|
|
206
224
|
def withCUDA(func: Callable) -> Callable:
|
|
207
225
|
r"""A decorator to test both on CPU and CUDA (if available)."""
|
|
208
226
|
import pytest
|
|
@@ -234,8 +252,9 @@ def withDevice(func: Callable) -> Callable:
|
|
|
234
252
|
if device:
|
|
235
253
|
backend = os.getenv('TORCH_BACKEND')
|
|
236
254
|
if backend is None:
|
|
237
|
-
warnings.warn(
|
|
238
|
-
|
|
255
|
+
warnings.warn(
|
|
256
|
+
f"Please specify the backend via 'TORCH_BACKEND' in"
|
|
257
|
+
f"order to test against '{device}'", stacklevel=2)
|
|
239
258
|
else:
|
|
240
259
|
import_module(backend)
|
|
241
260
|
devices.append(pytest.param(torch.device(device), id=device))
|
|
@@ -250,7 +269,7 @@ def withMETIS(func: Callable) -> Callable:
|
|
|
250
269
|
with_metis = WITH_METIS
|
|
251
270
|
|
|
252
271
|
if with_metis:
|
|
253
|
-
try: # Test that METIS can
|
|
272
|
+
try: # Test that METIS can successfully execute:
|
|
254
273
|
# TODO Using `pyg-lib` metis partitioning leads to some weird bugs
|
|
255
274
|
# in the # CI. As such, we require `torch-sparse` for now.
|
|
256
275
|
rowptr = torch.tensor([0, 2, 4, 6])
|
|
@@ -265,6 +284,17 @@ def withMETIS(func: Callable) -> Callable:
|
|
|
265
284
|
)(func)
|
|
266
285
|
|
|
267
286
|
|
|
287
|
+
def withHashTensor(func: Callable) -> Callable:
|
|
288
|
+
r"""A decorator to only test in case :class:`HashTensor` is available."""
|
|
289
|
+
import pytest
|
|
290
|
+
|
|
291
|
+
return pytest.mark.skipif(
|
|
292
|
+
not torch_geometric.typing.WITH_CPU_HASH_MAP
|
|
293
|
+
and not has_package('pandas'),
|
|
294
|
+
reason="HashTensor dependencies not available",
|
|
295
|
+
)(func)
|
|
296
|
+
|
|
297
|
+
|
|
268
298
|
def disableExtensions(func: Callable) -> Callable:
|
|
269
299
|
r"""A decorator to temporarily disable the usage of the
|
|
270
300
|
:obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension
|
|
@@ -20,6 +20,7 @@ from .target_indegree import TargetIndegree
|
|
|
20
20
|
from .local_degree_profile import LocalDegreeProfile
|
|
21
21
|
from .add_self_loops import AddSelfLoops
|
|
22
22
|
from .add_remaining_self_loops import AddRemainingSelfLoops
|
|
23
|
+
from .remove_self_loops import RemoveSelfLoops
|
|
23
24
|
from .remove_isolated_nodes import RemoveIsolatedNodes
|
|
24
25
|
from .remove_duplicated_edges import RemoveDuplicatedEdges
|
|
25
26
|
from .knn_graph import KNNGraph
|
|
@@ -36,6 +37,7 @@ from .rooted_subgraph import RootedEgoNets, RootedRWSubgraph
|
|
|
36
37
|
from .largest_connected_components import LargestConnectedComponents
|
|
37
38
|
from .virtual_node import VirtualNode
|
|
38
39
|
from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
|
|
40
|
+
from .add_gpse import AddGPSE
|
|
39
41
|
from .feature_propagation import FeaturePropagation
|
|
40
42
|
from .half_hop import HalfHop
|
|
41
43
|
|
|
@@ -87,6 +89,7 @@ graph_transforms = [
|
|
|
87
89
|
'LocalDegreeProfile',
|
|
88
90
|
'AddSelfLoops',
|
|
89
91
|
'AddRemainingSelfLoops',
|
|
92
|
+
'RemoveSelfLoops',
|
|
90
93
|
'RemoveIsolatedNodes',
|
|
91
94
|
'RemoveDuplicatedEdges',
|
|
92
95
|
'KNNGraph',
|
|
@@ -106,6 +109,7 @@ graph_transforms = [
|
|
|
106
109
|
'VirtualNode',
|
|
107
110
|
'AddLaplacianEigenvectorPE',
|
|
108
111
|
'AddRandomWalkPE',
|
|
112
|
+
'AddGPSE',
|
|
109
113
|
'FeaturePropagation',
|
|
110
114
|
'HalfHop',
|
|
111
115
|
]
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from torch.nn import Module
|
|
4
|
+
|
|
5
|
+
from torch_geometric.data import Data
|
|
6
|
+
from torch_geometric.data.datapipes import functional_transform
|
|
7
|
+
from torch_geometric.transforms import BaseTransform, VirtualNode
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@functional_transform('add_gpse')
|
|
11
|
+
class AddGPSE(BaseTransform):
|
|
12
|
+
r"""Adds the GPSE encoding from the `"Graph Positional and Structural
|
|
13
|
+
Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
|
|
14
|
+
(functional name: :obj:`add_gpse`).
|
|
15
|
+
To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates
|
|
16
|
+
the actual encodings.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
model (Module): The pre-trained GPSE model.
|
|
20
|
+
use_vn (bool, optional): Whether to use virtual nodes.
|
|
21
|
+
(default: :obj:`True`)
|
|
22
|
+
rand_type (str, optional): Type of random features to use. Options are
|
|
23
|
+
:obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
|
|
24
|
+
(default: :obj:`NormalSE`)
|
|
25
|
+
|
|
26
|
+
"""
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
model: Module,
|
|
30
|
+
use_vn: bool = True,
|
|
31
|
+
rand_type: str = 'NormalSE',
|
|
32
|
+
):
|
|
33
|
+
self.model = model
|
|
34
|
+
self.use_vn = use_vn
|
|
35
|
+
self.vn = VirtualNode()
|
|
36
|
+
self.rand_type = rand_type
|
|
37
|
+
|
|
38
|
+
def forward(self, data: Data) -> Any:
|
|
39
|
+
pass
|
|
40
|
+
|
|
41
|
+
def __call__(self, data: Data) -> Data:
|
|
42
|
+
from torch_geometric.nn.models.gpse import gpse_process
|
|
43
|
+
|
|
44
|
+
data_vn = self.vn(data.clone()) if self.use_vn else data.clone()
|
|
45
|
+
batch_out = gpse_process(self.model, data_vn, 'NormalSE', self.use_vn)
|
|
46
|
+
batch_out = batch_out.to('cpu', non_blocking=True)
|
|
47
|
+
data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out
|
|
48
|
+
|
|
49
|
+
return data
|
|
@@ -37,7 +37,7 @@ class AddMetaPaths(BaseTransform):
|
|
|
37
37
|
:class:`~torch_geometric.data.HeteroData` object as edge type
|
|
38
38
|
:obj:`(src_node_type, "metapath_*", dst_node_type)`, where
|
|
39
39
|
:obj:`src_node_type` and :obj:`dst_node_type` denote :math:`\mathcal{V}_1`
|
|
40
|
-
and :math:`\mathcal{V}_{\ell}`,
|
|
40
|
+
and :math:`\mathcal{V}_{\ell}`, respectively.
|
|
41
41
|
|
|
42
42
|
In addition, a :obj:`metapath_dict` object is added to the
|
|
43
43
|
:class:`~torch_geometric.data.HeteroData` object which maps the
|
|
@@ -108,13 +108,15 @@ class AddMetaPaths(BaseTransform):
|
|
|
108
108
|
**kwargs: bool,
|
|
109
109
|
) -> None:
|
|
110
110
|
if 'drop_orig_edges' in kwargs:
|
|
111
|
-
warnings.warn(
|
|
112
|
-
|
|
111
|
+
warnings.warn(
|
|
112
|
+
"'drop_orig_edges' is deprecated. Use "
|
|
113
|
+
"'drop_orig_edge_types' instead", stacklevel=2)
|
|
113
114
|
drop_orig_edge_types = kwargs['drop_orig_edges']
|
|
114
115
|
|
|
115
116
|
if 'drop_unconnected_nodes' in kwargs:
|
|
116
|
-
warnings.warn(
|
|
117
|
-
|
|
117
|
+
warnings.warn(
|
|
118
|
+
"'drop_unconnected_nodes' is deprecated. Use "
|
|
119
|
+
"'drop_unconnected_node_types' instead", stacklevel=2)
|
|
118
120
|
drop_unconnected_node_types = kwargs['drop_unconnected_nodes']
|
|
119
121
|
|
|
120
122
|
for path in metapaths:
|
|
@@ -144,7 +146,7 @@ class AddMetaPaths(BaseTransform):
|
|
|
144
146
|
if self.max_sample is not None:
|
|
145
147
|
edge_index, edge_weight = self._sample(edge_index, edge_weight)
|
|
146
148
|
|
|
147
|
-
for
|
|
149
|
+
for edge_type in metapath[1:]:
|
|
148
150
|
edge_index2, edge_weight2 = self._edge_index(data, edge_type)
|
|
149
151
|
|
|
150
152
|
edge_index, edge_weight = edge_index.matmul(
|
|
@@ -231,7 +233,7 @@ class AddRandomMetaPaths(BaseTransform):
|
|
|
231
233
|
will drop node types not connected by any edge type.
|
|
232
234
|
(default: :obj:`False`)
|
|
233
235
|
walks_per_node (int, List[int], optional): The number of random walks
|
|
234
|
-
for each starting node in a
|
|
236
|
+
for each starting node in a metapath. (default: :obj:`1`)
|
|
235
237
|
sample_ratio (float, optional): The ratio of source nodes to start
|
|
236
238
|
random walks from. (default: :obj:`1.0`)
|
|
237
239
|
"""
|
|
@@ -276,7 +278,7 @@ class AddRandomMetaPaths(BaseTransform):
|
|
|
276
278
|
row = start = torch.randperm(num_nodes)[:num_starts].repeat(
|
|
277
279
|
self.walks_per_node[j])
|
|
278
280
|
|
|
279
|
-
for
|
|
281
|
+
for edge_type in metapath:
|
|
280
282
|
edge_index = EdgeIndex(
|
|
281
283
|
data[edge_type].edge_index,
|
|
282
284
|
sparse_size=data[edge_type].size(),
|
|
@@ -92,12 +92,12 @@ class AddLaplacianEigenvectorPE(BaseTransform):
|
|
|
92
92
|
from numpy.linalg import eig, eigh
|
|
93
93
|
eig_fn = eig if not self.is_undirected else eigh
|
|
94
94
|
|
|
95
|
-
eig_vals, eig_vecs = eig_fn(L.todense())
|
|
95
|
+
eig_vals, eig_vecs = eig_fn(L.todense())
|
|
96
96
|
else:
|
|
97
97
|
from scipy.sparse.linalg import eigs, eigsh
|
|
98
98
|
eig_fn = eigs if not self.is_undirected else eigsh
|
|
99
99
|
|
|
100
|
-
eig_vals, eig_vecs = eig_fn(
|
|
100
|
+
eig_vals, eig_vecs = eig_fn(
|
|
101
101
|
L,
|
|
102
102
|
k=self.k + 1,
|
|
103
103
|
which='SR' if not self.is_undirected else 'SA',
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import copy
|
|
2
|
-
from abc import ABC
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
5
|
|
|
@@ -31,6 +31,7 @@ class BaseTransform(ABC):
|
|
|
31
31
|
# Shallow-copy the data so that we prevent in-place data modification.
|
|
32
32
|
return self.forward(copy.copy(data))
|
|
33
33
|
|
|
34
|
+
@abstractmethod
|
|
34
35
|
def forward(self, data: Any) -> Any:
|
|
35
36
|
pass
|
|
36
37
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
|
|
3
5
|
from torch_geometric.data import Data
|
|
@@ -5,30 +7,78 @@ from torch_geometric.data.datapipes import functional_transform
|
|
|
5
7
|
from torch_geometric.transforms import BaseTransform
|
|
6
8
|
|
|
7
9
|
|
|
10
|
+
class _QhullTransform(BaseTransform):
|
|
11
|
+
r"""Q-hull implementation of delaunay triangulation."""
|
|
12
|
+
def forward(self, data: Data) -> Data:
|
|
13
|
+
assert data.pos is not None
|
|
14
|
+
import scipy.spatial
|
|
15
|
+
|
|
16
|
+
pos = data.pos.cpu().numpy()
|
|
17
|
+
tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
|
|
18
|
+
face = torch.from_numpy(tri.simplices)
|
|
19
|
+
|
|
20
|
+
data.face = face.t().contiguous().to(data.pos.device, torch.long)
|
|
21
|
+
return data
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _ShullTransform(BaseTransform):
|
|
25
|
+
r"""Sweep-hull implementation of delaunay triangulation."""
|
|
26
|
+
def forward(self, data: Data) -> Data:
|
|
27
|
+
assert data.pos is not None
|
|
28
|
+
from torch_delaunay.functional import shull2d
|
|
29
|
+
|
|
30
|
+
face = shull2d(data.pos.cpu())
|
|
31
|
+
data.face = face.t().contiguous().to(data.pos.device)
|
|
32
|
+
return data
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class _SequentialTransform(BaseTransform):
|
|
36
|
+
r"""Runs the first successful transformation.
|
|
37
|
+
|
|
38
|
+
All intermediate exceptions are suppressed except the last.
|
|
39
|
+
"""
|
|
40
|
+
def __init__(self, transforms: List[BaseTransform]) -> None:
|
|
41
|
+
assert len(transforms) > 0
|
|
42
|
+
self.transforms = transforms
|
|
43
|
+
|
|
44
|
+
def forward(self, data: Data) -> Data:
|
|
45
|
+
for i, transform in enumerate(self.transforms):
|
|
46
|
+
try:
|
|
47
|
+
return transform.forward(data)
|
|
48
|
+
except ImportError as e:
|
|
49
|
+
if i == len(self.transforms) - 1:
|
|
50
|
+
raise e
|
|
51
|
+
return data
|
|
52
|
+
|
|
53
|
+
|
|
8
54
|
@functional_transform('delaunay')
|
|
9
55
|
class Delaunay(BaseTransform):
|
|
10
56
|
r"""Computes the delaunay triangulation of a set of points
|
|
11
57
|
(functional name: :obj:`delaunay`).
|
|
58
|
+
|
|
59
|
+
.. hint::
|
|
60
|
+
Consider installing the
|
|
61
|
+
`torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
|
|
62
|
+
to speed up computation.
|
|
12
63
|
"""
|
|
13
|
-
def
|
|
14
|
-
|
|
64
|
+
def __init__(self) -> None:
|
|
65
|
+
self._transform = _SequentialTransform([
|
|
66
|
+
_ShullTransform(),
|
|
67
|
+
_QhullTransform(),
|
|
68
|
+
])
|
|
15
69
|
|
|
70
|
+
def forward(self, data: Data) -> Data:
|
|
16
71
|
assert data.pos is not None
|
|
72
|
+
device = data.pos.device
|
|
17
73
|
|
|
18
74
|
if data.pos.size(0) < 2:
|
|
19
|
-
data.edge_index = torch.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
data.edge_index = torch.tensor([[0, 1], [1, 0]],
|
|
23
|
-
device=data.pos.device)
|
|
75
|
+
data.edge_index = torch.empty(2, 0, dtype=torch.long,
|
|
76
|
+
device=device)
|
|
77
|
+
elif data.pos.size(0) == 2:
|
|
78
|
+
data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
|
|
24
79
|
elif data.pos.size(0) == 3:
|
|
25
|
-
data.face = torch.tensor([[0], [1], [2]],
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
pos = data.pos.cpu().numpy()
|
|
29
|
-
tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
|
|
30
|
-
face = torch.from_numpy(tri.simplices)
|
|
31
|
-
|
|
32
|
-
data.face = face.t().contiguous().to(data.pos.device, torch.long)
|
|
80
|
+
data.face = torch.tensor([[0], [1], [2]], device=device)
|
|
81
|
+
else:
|
|
82
|
+
data = self._transform.forward(data)
|
|
33
83
|
|
|
34
84
|
return data
|
|
@@ -8,8 +8,15 @@ from torch_geometric.utils import to_undirected
|
|
|
8
8
|
|
|
9
9
|
@functional_transform('face_to_edge')
|
|
10
10
|
class FaceToEdge(BaseTransform):
|
|
11
|
-
r"""Converts mesh faces :obj:`[3, num_faces]`
|
|
12
|
-
:obj:`[
|
|
11
|
+
r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
|
|
12
|
+
:obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
|
|
13
|
+
(functional name: :obj:`face_to_edge`).
|
|
14
|
+
|
|
15
|
+
This transform supports both 2D triangular faces, represented by a
|
|
16
|
+
tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
|
|
17
|
+
represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
|
|
18
|
+
these faces into edge indices, where each edge is defined by the indices
|
|
19
|
+
of its two endpoints.
|
|
13
20
|
|
|
14
21
|
Args:
|
|
15
22
|
remove_faces (bool, optional): If set to :obj:`False`, the face tensor
|
|
@@ -22,7 +29,29 @@ class FaceToEdge(BaseTransform):
|
|
|
22
29
|
if hasattr(data, 'face'):
|
|
23
30
|
assert data.face is not None
|
|
24
31
|
face = data.face
|
|
25
|
-
|
|
32
|
+
|
|
33
|
+
if face.size(0) not in [3, 4]:
|
|
34
|
+
raise RuntimeError(f"Expected 'face' tensor with shape "
|
|
35
|
+
f"[3, num_faces] or [4, num_faces] "
|
|
36
|
+
f"(got {list(face.size())})")
|
|
37
|
+
|
|
38
|
+
if face.size()[0] == 3:
|
|
39
|
+
edge_index = torch.cat([
|
|
40
|
+
face[:2],
|
|
41
|
+
face[1:],
|
|
42
|
+
face[::2],
|
|
43
|
+
], dim=1)
|
|
44
|
+
else:
|
|
45
|
+
assert face.size()[0] == 4
|
|
46
|
+
edge_index = torch.cat([
|
|
47
|
+
face[:2],
|
|
48
|
+
face[1:3],
|
|
49
|
+
face[2:4],
|
|
50
|
+
face[::2],
|
|
51
|
+
face[1::2],
|
|
52
|
+
face[::3],
|
|
53
|
+
], dim=1)
|
|
54
|
+
|
|
26
55
|
edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
|
|
27
56
|
|
|
28
57
|
data.edge_index = edge_index
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Dict, Tuple
|
|
1
|
+
from typing import Any, Dict, Optional, Tuple
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
@@ -21,7 +21,7 @@ from torch_geometric.utils import (
|
|
|
21
21
|
@functional_transform('gdc')
|
|
22
22
|
class GDC(BaseTransform):
|
|
23
23
|
r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
|
|
24
|
-
`"Diffusion Improves Graph Learning" <https://
|
|
24
|
+
`"Diffusion Improves Graph Learning" <https://arxiv.org/abs/1911.05485>`_
|
|
25
25
|
paper (functional name: :obj:`gdc`).
|
|
26
26
|
|
|
27
27
|
.. note::
|
|
@@ -78,18 +78,17 @@ class GDC(BaseTransform):
|
|
|
78
78
|
self_loop_weight: float = 1.,
|
|
79
79
|
normalization_in: str = 'sym',
|
|
80
80
|
normalization_out: str = 'col',
|
|
81
|
-
diffusion_kwargs: Dict[str, Any] =
|
|
82
|
-
sparsification_kwargs: Dict[str, Any] =
|
|
83
|
-
method='threshold',
|
|
84
|
-
avg_degree=64,
|
|
85
|
-
),
|
|
81
|
+
diffusion_kwargs: Optional[Dict[str, Any]] = None,
|
|
82
|
+
sparsification_kwargs: Optional[Dict[str, Any]] = None,
|
|
86
83
|
exact: bool = True,
|
|
87
84
|
) -> None:
|
|
88
85
|
self.self_loop_weight = self_loop_weight
|
|
89
86
|
self.normalization_in = normalization_in
|
|
90
87
|
self.normalization_out = normalization_out
|
|
91
|
-
self.diffusion_kwargs = diffusion_kwargs
|
|
92
|
-
|
|
88
|
+
self.diffusion_kwargs = diffusion_kwargs or dict(
|
|
89
|
+
method='ppr', alpha=0.15)
|
|
90
|
+
self.sparsification_kwargs = sparsification_kwargs or dict(
|
|
91
|
+
method='threshold', avg_degree=64)
|
|
93
92
|
self.exact = exact
|
|
94
93
|
|
|
95
94
|
if self_loop_weight:
|
|
@@ -47,7 +47,7 @@ class LargestConnectedComponents(BaseTransform):
|
|
|
47
47
|
return data
|
|
48
48
|
|
|
49
49
|
_, count = np.unique(component, return_counts=True)
|
|
50
|
-
subset_np = np.
|
|
50
|
+
subset_np = np.isin(component, count.argsort()[-self.num_components:])
|
|
51
51
|
subset = torch.from_numpy(subset_np)
|
|
52
52
|
subset = subset.to(data.edge_index.device, torch.bool)
|
|
53
53
|
|
|
@@ -19,7 +19,11 @@ def get_attrs_with_suffix(
|
|
|
19
19
|
return [key for key in store.keys() if key.endswith(suffix)]
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
def get_mask_size(
|
|
22
|
+
def get_mask_size(
|
|
23
|
+
attr: str,
|
|
24
|
+
store: BaseStorage,
|
|
25
|
+
size: Optional[int],
|
|
26
|
+
) -> Optional[int]:
|
|
23
27
|
if size is not None:
|
|
24
28
|
return size
|
|
25
29
|
return store.num_edges if store.is_edge_attr(attr) else store.num_nodes
|
|
@@ -53,7 +53,7 @@ class NodePropertySplit(BaseTransform):
|
|
|
53
53
|
|
|
54
54
|
property_name = 'popularity'
|
|
55
55
|
ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
|
|
56
|
-
|
|
56
|
+
transform = NodePropertySplit(property_name, ratios)
|
|
57
57
|
|
|
58
58
|
data = transform(data)
|
|
59
59
|
"""
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import List, Union
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
2
|
|
|
3
3
|
from torch_geometric.data import Data, HeteroData
|
|
4
4
|
from torch_geometric.data.datapipes import functional_transform
|
|
@@ -14,8 +14,8 @@ class NormalizeFeatures(BaseTransform):
|
|
|
14
14
|
attrs (List[str]): The names of attributes to normalize.
|
|
15
15
|
(default: :obj:`["x"]`)
|
|
16
16
|
"""
|
|
17
|
-
def __init__(self, attrs: List[str] =
|
|
18
|
-
self.attrs = attrs
|
|
17
|
+
def __init__(self, attrs: Optional[List[str]] = None) -> None:
|
|
18
|
+
self.attrs = attrs or ["x"]
|
|
19
19
|
|
|
20
20
|
def forward(
|
|
21
21
|
self,
|
|
@@ -262,7 +262,7 @@ class Pad(BaseTransform):
|
|
|
262
262
|
All the attributes of node types other than :obj:`v0` and :obj:`v1` are
|
|
263
263
|
padded using a value of :obj:`1.0`.
|
|
264
264
|
All the attributes of the :obj:`('v0', 'e0', 'v1')` edge type are padded
|
|
265
|
-
|
|
265
|
+
using a value of :obj:`3.5`.
|
|
266
266
|
The :obj:`edge_attr` attributes of the :obj:`('v1', 'e0', 'v0')` edge type
|
|
267
267
|
are padded using a value of :obj:`-1.5`, and any other attributes of this
|
|
268
268
|
edge type are padded using a value of :obj:`5.5`.
|
|
@@ -245,7 +245,7 @@ class RandomLinkSplit(BaseTransform):
|
|
|
245
245
|
warnings.warn(
|
|
246
246
|
f"There are not enough negative edges to satisfy "
|
|
247
247
|
"the provided sampling ratio. The ratio will be "
|
|
248
|
-
f"adjusted to {ratio:.2f}.")
|
|
248
|
+
f"adjusted to {ratio:.2f}.", stacklevel=2)
|
|
249
249
|
num_neg_train = int((num_neg_train / num_neg) * num_neg_found)
|
|
250
250
|
num_neg_val = int((num_neg_val / num_neg) * num_neg_found)
|
|
251
251
|
num_neg_test = num_neg_found - num_neg_train - num_neg_val
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import List, Union
|
|
1
|
+
from typing import List, Optional, Union
|
|
2
2
|
|
|
3
3
|
from torch_geometric.data import Data, HeteroData
|
|
4
4
|
from torch_geometric.data.datapipes import functional_transform
|
|
@@ -22,9 +22,11 @@ class RemoveDuplicatedEdges(BaseTransform):
|
|
|
22
22
|
"""
|
|
23
23
|
def __init__(
|
|
24
24
|
self,
|
|
25
|
-
key: Union[str, List[str]] =
|
|
25
|
+
key: Optional[Union[str, List[str]]] = None,
|
|
26
26
|
reduce: str = "add",
|
|
27
27
|
) -> None:
|
|
28
|
+
key = key or ['edge_attr', 'edge_weight']
|
|
29
|
+
|
|
28
30
|
if isinstance(key, str):
|
|
29
31
|
key = [key]
|
|
30
32
|
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from torch_geometric.data import Data, HeteroData
|
|
4
|
+
from torch_geometric.data.datapipes import functional_transform
|
|
5
|
+
from torch_geometric.transforms import BaseTransform
|
|
6
|
+
from torch_geometric.utils import remove_self_loops
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@functional_transform('remove_self_loops')
|
|
10
|
+
class RemoveSelfLoops(BaseTransform):
|
|
11
|
+
r"""Removes all self-loops in the given homogeneous or heterogeneous
|
|
12
|
+
graph (functional name: :obj:`remove_self_loops`).
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
attr (str, optional): The name of the attribute of edge weights
|
|
16
|
+
or multi-dimensional edge features to pass to
|
|
17
|
+
:meth:`torch_geometric.utils.remove_self_loops`.
|
|
18
|
+
(default: :obj:`"edge_weight"`)
|
|
19
|
+
"""
|
|
20
|
+
def __init__(self, attr: str = 'edge_weight') -> None:
|
|
21
|
+
self.attr = attr
|
|
22
|
+
|
|
23
|
+
def forward(
|
|
24
|
+
self,
|
|
25
|
+
data: Union[Data, HeteroData],
|
|
26
|
+
) -> Union[Data, HeteroData]:
|
|
27
|
+
for store in data.edge_stores:
|
|
28
|
+
if store.is_bipartite() or 'edge_index' not in store:
|
|
29
|
+
continue
|
|
30
|
+
|
|
31
|
+
store.edge_index, store[self.attr] = remove_self_loops(
|
|
32
|
+
store.edge_index,
|
|
33
|
+
edge_attr=store.get(self.attr, None),
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
return data
|
|
@@ -94,7 +94,7 @@ class RootedSubgraph(BaseTransform, ABC):
|
|
|
94
94
|
arange = torch.arange(n_id.size(0), device=data.edge_index.device)
|
|
95
95
|
node_map = data.edge_index.new_ones(num_nodes, num_nodes)
|
|
96
96
|
node_map[n_sub_batch, n_id] = arange
|
|
97
|
-
sub_edge_index += (arange *
|
|
97
|
+
sub_edge_index += (arange * num_nodes)[e_sub_batch]
|
|
98
98
|
sub_edge_index = node_map.view(-1)[sub_edge_index]
|
|
99
99
|
|
|
100
100
|
return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch
|
|
@@ -11,7 +11,7 @@ class SVDFeatureReduction(BaseTransform):
|
|
|
11
11
|
Decomposition (SVD) (functional name: :obj:`svd_feature_reduction`).
|
|
12
12
|
|
|
13
13
|
Args:
|
|
14
|
-
out_channels (int): The
|
|
14
|
+
out_channels (int): The dimensionality of node features after
|
|
15
15
|
reduction.
|
|
16
16
|
"""
|
|
17
17
|
def __init__(self, out_channels: int):
|
|
@@ -37,7 +37,8 @@ class VirtualNode(BaseTransform):
|
|
|
37
37
|
col = torch.cat([col, full, arange], dim=0)
|
|
38
38
|
edge_index = torch.stack([row, col], dim=0)
|
|
39
39
|
|
|
40
|
-
|
|
40
|
+
num_edge_types = int(edge_type.max()) if edge_type.numel() > 0 else 0
|
|
41
|
+
new_type = edge_type.new_full((num_nodes, ), num_edge_types + 1)
|
|
41
42
|
edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0)
|
|
42
43
|
|
|
43
44
|
old_data = copy.copy(data)
|