pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/nn/encoding.py
CHANGED
@@ -1,8 +1,14 @@
 import math
+from typing import Optional

 import torch
 from torch import Tensor

+__all__ = classes = [
+    'PositionalEncoding',
+    'TemporalEncoding',
+]
+

 class PositionalEncoding(torch.nn.Module):
     r"""The positional encoding scheme from the `"Attention Is All You Need"
@@ -23,12 +29,15 @@ class PositionalEncoding(torch.nn.Module):
         granularity (float, optional): The granularity of the positions. If
             set to smaller value, the encoder will capture more fine-grained
             changes in positions. (default: :obj:`1.0`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
         out_channels: int,
         base_freq: float = 1e-4,
         granularity: float = 1.0,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()

@@ -40,7 +49,8 @@ class PositionalEncoding(torch.nn.Module):
         self.base_freq = base_freq
         self.granularity = granularity

-        frequency = torch.logspace(0, 1, out_channels // 2, base_freq)
+        frequency = torch.logspace(0, 1, out_channels // 2, base_freq,
+                                   device=device)
         self.register_buffer('frequency', frequency)

         self.reset_parameters()
@@ -75,13 +85,17 @@ class TemporalEncoding(torch.nn.Module):

     Args:
         out_channels (int): Size :math:`d` of each output sample.
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, out_channels: int):
+    def __init__(self, out_channels: int,
+                 device: Optional[torch.device] = None):
         super().__init__()
         self.out_channels = out_channels

         sqrt = math.sqrt(out_channels)
-        weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels).view(1, -1)
+        weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels,
+                                            device=device).view(1, -1)
         self.register_buffer('weight', weight)

         self.reset_parameters()
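The only user-facing change in this file is the new optional device argument, which lets the frequency/weight buffers be allocated directly on the target device. A minimal usage sketch (the sizes and the CUDA check are illustrative assumptions, not part of the diff):

    import torch
    from torch_geometric.nn import PositionalEncoding, TemporalEncoding

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Buffers are now created on `device` instead of being moved afterwards.
    pos_enc = PositionalEncoding(out_channels=64, device=device)
    tmp_enc = TemporalEncoding(out_channels=64, device=device)

    x = torch.arange(10, dtype=torch.float, device=device)  # positions / timestamps
    print(pos_enc(x).shape)  # torch.Size([10, 64])
    print(tmp_enc(x).shape)  # torch.Size([10, 64])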
torch_geometric/nn/fx.py
CHANGED
@@ -1,8 +1,9 @@
 import copy
 import warnings
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, Type, Union

 import torch
+from torch import Tensor
 from torch.nn import Module, ModuleDict, ModuleList, Sequential

 try:
@@ -18,8 +19,8 @@ class Transformer:
     :class:`~torch.nn.Module`.
     :class:`Transformer` works entirely symbolically.

-    Methods in the :class:`Transformer` class can be
-    behavior of transformation.
+    Methods in the :class:`Transformer` class can be overridden to customize
+    the behavior of transformation.

     .. code-block:: none

@@ -129,11 +130,13 @@ class Transformer:
         # (node-level, edge-level) by filling `self._state`:
         for node in list(self.graph.nodes):
             if node.op == 'call_function' and 'training' in node.kwargs:
-                warnings.warn(
-
-
-
-
+                warnings.warn(
+                    f"Found function '{node.name}' with keyword "
+                    f"argument 'training'. During FX tracing, this "
+                    f"will likely be baked in as a constant value. "
+                    f"Consider replacing this function by a module "
+                    f"to properly encapsulate its training flag.",
+                    stacklevel=2)

             if node.op == 'placeholder':
                 if node.name not in self._state:
@@ -283,13 +286,13 @@ def symbolic_trace(
             # TODO We currently only trace top-level modules.
             return not isinstance(module, torch.nn.Sequential)

-        # Note: This is a hack around the fact that `
+        # Note: This is a hack around the fact that `Aggregation.__call__`
         # is not patched by the base implementation of `trace`.
         # see https://github.com/pyg-team/pytorch_geometric/pull/5021 for
         # details on the rationale
         # TODO: Revisit https://github.com/pyg-team/pytorch_geometric/pull/5021
         @st.compatibility(is_backward_compatible=True)
-        def trace(self, root:
+        def trace(self, root: Union[torch.nn.Module, Callable[..., Any]],
                   concrete_args: Optional[Dict[str, Any]] = None) -> Graph:

             if isinstance(root, torch.nn.Module):
@@ -303,17 +306,16 @@ def symbolic_trace(
                 self.root = torch.nn.Module()
                 fn = root

-            tracer_cls: Optional[
+            tracer_cls: Optional[Type['Tracer']] = getattr(
                 self, '__class__', None)
             self.graph = Graph(tracer_cls=tracer_cls)

-            self.tensor_attrs: Dict[
-                str] = {}
+            self.tensor_attrs: Dict[Union[Tensor, st.ScriptObject], str] = {}

             def collect_tensor_attrs(m: torch.nn.Module,
-                                     prefix_atoms:
+                                     prefix_atoms: List[str]):
                 for k, v in m.__dict__.items():
-                    if isinstance(v, (
+                    if isinstance(v, (Tensor, st.ScriptObject)):
                         self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
                 for k, v in m.named_children():
                     collect_tensor_attrs(v, prefix_atoms + [k])
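The new warning text points at a common FX-tracing pitfall: a call such as F.dropout(x, p=0.5, training=self.training) gets its training flag baked in as a constant at trace time. A minimal sketch of the module-based workaround the warning suggests (the wrapper class below is an illustrative assumption, not code from this diff):

    import torch

    class Block(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # `torch.nn.Dropout` carries its own `training` state, so FX tracing
            # does not freeze train/eval behavior into the traced graph.
            self.dropout = torch.nn.Dropout(p=0.5)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Instead of: torch.nn.functional.dropout(x, p=0.5, training=self.training)
            return self.dropout(x)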
torch_geometric/nn/model_hub.py
CHANGED
@@ -4,6 +4,8 @@ from typing import Any, Dict, Optional, Union

 import torch

+from torch_geometric.io import fs
+
 try:
     from huggingface_hub import ModelHubMixin, hf_hub_download
 except ImportError:
@@ -142,10 +144,10 @@ class PyGModelHubMixin(ModelHubMixin):
        revision,
        cache_dir,
        force_download,
-       proxies,
-       resume_download,
        local_files_only,
        token,
+       proxies=None,
+       resume_download=False,
        dataset_name='',
        model_name='',
        map_location='cpu',
@@ -163,8 +165,6 @@ class PyGModelHubMixin(ModelHubMixin):
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
-           proxies=proxies,
-           resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
@@ -175,7 +175,7 @@ class PyGModelHubMixin(ModelHubMixin):

        model = cls(dataset_name, model_name, model_kwargs)

-       state_dict =
+       state_dict = fs.torch_load(model_file, map_location=map_location)
        model.load_state_dict(state_dict, strict=strict)
        model.eval()

@@ -186,8 +186,6 @@ class PyGModelHubMixin(ModelHubMixin):
        cls,
        pretrained_model_name_or_path: str,
        force_download: bool = False,
-       resume_download: bool = False,
-       proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
@@ -213,13 +211,6 @@ class PyGModelHubMixin(ModelHubMixin):
                (re-)download of the model weights and configuration files,
                overriding the cached versions if they exist.
                (default: :obj:`False`)
-           resume_download (bool, optional): Whether to delete incompletely
-               received files. Will attempt to resume the download if such a
-               file exists. (default: :obj:`False`)
-           proxies (Dict[str, str], optional): A dictionary of proxy servers
-               to use by protocol or endpoint, *e.g.*,
-               :obj:`{'http': 'foo.bar:3128', 'http://host': 'foo.bar:4012'}`.
-               The proxies are used on each request. (default: :obj:`None`)
            token (str or bool, optional): The token to use as HTTP bearer
                authorization for remote files. If set to :obj:`True`, will use
                the token generated when running :obj:`transformers-cli login`
@@ -237,8 +228,6 @@ class PyGModelHubMixin(ModelHubMixin):
        return super().from_pretrained(
            pretrained_model_name_or_path,
            force_download=force_download,
-           resume_download=resume_download,
-           proxies=proxies,
            use_auth_token=token,
            cache_dir=cache_dir,
            local_files_only,
torch_geometric/nn/models/__init__.py
CHANGED
@@ -12,6 +12,7 @@ from .re_net import RENet
 from .graph_unet import GraphUNet
 from .schnet import SchNet
 from .dimenet import DimeNet, DimeNetPlusPlus
+from .gpse import GPSE, GPSENodeEncoder
 from .captum import to_captum_model
 from .metapath2vec import MetaPath2Vec
 from .deepgcn import DeepGCNLayer
@@ -28,10 +29,14 @@ from .gnnff import GNNFF
 from .pmlp import PMLP
 from .neural_fingerprint import NeuralFingerprint
 from .visnet import ViSNet
+from .lpformer import LPFormer
+from .sgformer import SGFormer

+from .polynormer import Polynormer
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                        captum_output_to_dicts)
+from .attract_repel import ARLinkPredictor

 __all__ = classes = [
     'MLP',
@@ -57,6 +62,8 @@ __all__ = classes = [
     'SchNet',
     'DimeNet',
     'DimeNetPlusPlus',
+    'GPSE',
+    'GPSENodeEncoder',
     'to_captum_model',
     'to_captum_input',
     'captum_output_to_dicts',
@@ -75,4 +82,8 @@ __all__ = classes = [
     'PMLP',
     'NeuralFingerprint',
     'ViSNet',
+    'LPFormer',
+    'SGFormer',
+    'Polynormer',
+    'ARLinkPredictor',
 ]
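With these registrations in place, the newly added models become importable from the package namespace; a quick sketch:

    from torch_geometric.nn.models import (ARLinkPredictor, GPSE, GPSENodeEncoder,
                                            LPFormer, Polynormer, SGFormer)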
torch_geometric/nn/models/attentive_fp.py
CHANGED
@@ -160,7 +160,7 @@ class AttentiveFP(torch.nn.Module):
         edge_index = torch.stack([row, batch], dim=0)

         out = global_add_pool(x, batch).relu_()
-        for
+        for _ in range(self.num_timesteps):
             h = F.elu_(self.mol_conv((x, out), edge_index))
             h = F.dropout(h, p=self.dropout, training=self.training)
             out = self.mol_gru(h, out).relu_()
torch_geometric/nn/models/attract_repel.py
ADDED
@@ -0,0 +1,148 @@
+import torch
+import torch.nn.functional as F
+
+
+class ARLinkPredictor(torch.nn.Module):
+    r"""Link predictor using Attract-Repel embeddings from the paper
+    `"Pseudo-Euclidean Attract-Repel Embeddings for Undirected Graphs"
+    <https://arxiv.org/abs/2106.09671>`_.
+
+    This model splits node embeddings into: attract and
+    repel.
+    The edge prediction score is computed as the dot product of attract
+    components minus the dot product of repel components.
+
+    Args:
+        in_channels (int): Size of each input sample.
+        hidden_channels (int): Size of hidden embeddings.
+        out_channels (int, optional): Size of output embeddings.
+            If set to :obj:`None`, will default to :obj:`hidden_channels`.
+            (default: :obj:`None`)
+        num_layers (int): Number of message passing layers.
+            (default: :obj:`2`)
+        dropout (float): Dropout probability. (default: :obj:`0.0`)
+        attract_ratio (float): Ratio to use for attract component.
+            Must be between 0 and 1. (default: :obj:`0.5`)
+    """
+    def __init__(self, in_channels, hidden_channels, out_channels=None,
+                 num_layers=2, dropout=0.0, attract_ratio=0.5):
+        super().__init__()
+
+        if out_channels is None:
+            out_channels = hidden_channels
+
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.num_layers = num_layers
+        self.dropout = dropout
+
+        if not 0 <= attract_ratio <= 1:
+            raise ValueError(
+                f"attract_ratio must be between 0 and 1, got {attract_ratio}")
+
+        self.attract_ratio = attract_ratio
+        self.attract_dim = int(out_channels * attract_ratio)
+        self.repel_dim = out_channels - self.attract_dim
+
+        # Create model layers
+        self.lins = torch.nn.ModuleList()
+        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
+
+        for _ in range(num_layers - 2):
+            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
+
+        # Final layer splits into attract and repel components
+        self.lin_attract = torch.nn.Linear(hidden_channels, self.attract_dim)
+        self.lin_repel = torch.nn.Linear(hidden_channels, self.repel_dim)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        """Reset all learnable parameters."""
+        for lin in self.lins:
+            lin.reset_parameters()
+        self.lin_attract.reset_parameters()
+        self.lin_repel.reset_parameters()
+
+    def encode(self, x, *args, **kwargs):
+        """Encode node features into attract-repel embeddings.
+
+        Args:
+            x (torch.Tensor): Node feature matrix of shape
+                :obj:`[num_nodes, in_channels]`.
+            *args: Variable length argument list
+            **kwargs: Arbitrary keyword arguments
+
+        """
+        for lin in self.lins:
+            x = lin(x)
+            x = F.relu(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # Split into attract and repel components
+        attract_x = self.lin_attract(x)
+        repel_x = self.lin_repel(x)
+
+        return attract_x, repel_x
+
+    def decode(self, attract_z, repel_z, edge_index):
+        """Decode edge scores from attract-repel embeddings.
+
+        Args:
+            attract_z (torch.Tensor): Attract embeddings of shape
+                :obj:`[num_nodes, attract_dim]`.
+            repel_z (torch.Tensor): Repel embeddings of shape
+                :obj:`[num_nodes, repel_dim]`.
+            edge_index (torch.Tensor): Edge indices of shape
+                :obj:`[2, num_edges]`.
+
+        Returns:
+            torch.Tensor: Edge prediction scores.
+        """
+        # Get node embeddings for edges
+        row, col = edge_index
+        attract_z_row = attract_z[row]
+        attract_z_col = attract_z[col]
+        repel_z_row = repel_z[row]
+        repel_z_col = repel_z[col]
+
+        # Compute attract-repel scores
+        attract_score = torch.sum(attract_z_row * attract_z_col, dim=1)
+        repel_score = torch.sum(repel_z_row * repel_z_col, dim=1)
+
+        return attract_score - repel_score
+
+    def forward(self, x, edge_index):
+        """Forward pass for link prediction.
+
+        Args:
+            x (torch.Tensor): Node feature matrix.
+            edge_index (torch.Tensor): Edge indices to predict.
+
+        Returns:
+            torch.Tensor: Predicted edge scores.
+        """
+        # Encode nodes into attract-repel embeddings
+        attract_z, repel_z = self.encode(x)
+
+        # Decode target edges
+        return torch.sigmoid(self.decode(attract_z, repel_z, edge_index))
+
+    def calculate_r_fraction(self, attract_z, repel_z):
+        """Calculate the R-fraction (proportion of energy in repel space).
+
+        Args:
+            attract_z (torch.Tensor): Attract embeddings.
+            repel_z (torch.Tensor): Repel embeddings.
+
+        Returns:
+            float: R-fraction value.
+        """
+        attract_norm_squared = torch.sum(attract_z**2)
+        repel_norm_squared = torch.sum(repel_z**2)
+
+        r_fraction = repel_norm_squared / (attract_norm_squared +
+                                           repel_norm_squared + 1e-10)
+
+        return r_fraction.item()
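Since the whole class is new, here is a minimal usage sketch with random data (the sizes are illustrative assumptions; the import path follows the models/__init__.py change above):

    import torch
    from torch_geometric.nn.models import ARLinkPredictor

    model = ARLinkPredictor(in_channels=16, hidden_channels=32)

    x = torch.randn(100, 16)                      # random node features
    edge_index = torch.randint(0, 100, (2, 256))  # random candidate edges

    prob = model(x, edge_index)  # sigmoid(attract score - repel score), shape [256]

    attract_z, repel_z = model.encode(x)
    print(model.calculate_r_fraction(attract_z, repel_z))  # share of "repel" energy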
torch_geometric/nn/models/basic_gnn.py
CHANGED
@@ -415,7 +415,8 @@ class GCN(BasicGNN):
             (default: :obj:`None`)
         jk (str, optional): The Jumping Knowledge mode. If specified, the model
             will additionally apply a final linear transformation to transform
-            node embeddings to the expected output feature dimensionality
+            node embeddings to the expected output feature dimensionality,
+            while default will not.
             (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
             :obj:`"lstm"`). (default: :obj:`None`)
         **kwargs (optional): Additional arguments of
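To make the clarified jk wording concrete: with jk set, the per-layer embeddings are combined and passed through an extra linear layer to reach out_channels, while with the default jk=None the last convolution itself already outputs out_channels. A small sketch with random data (sizes are arbitrary assumptions):

    import torch
    from torch_geometric.nn import GCN

    x = torch.randn(50, 16)
    edge_index = torch.randint(0, 50, (2, 200))

    model = GCN(in_channels=16, hidden_channels=32, num_layers=3,
                out_channels=7, jk='cat')  # extra linear head maps 3 * 32 -> 7
    print(model(x, edge_index).shape)  # torch.Size([50, 7])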
torch_geometric/nn/models/captum.py
CHANGED
@@ -94,7 +94,7 @@ def to_captum_model(
             function will return the output of the model for the element at
             the index specified. (default: :obj:`None`)
         metadata (Metadata, optional): The metadata of the heterogeneous graph.
-            Only required if
+            Only required if explaining a
             :class:`~torch_geometric.data.HeteroData` object.
             (default: :obj:`None`)
     """
torch_geometric/nn/models/deep_graph_infomax.py
CHANGED
@@ -106,7 +106,7 @@ class DeepGraphInfomax(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression

-        clf = LogisticRegression(solver=solver,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
torch_geometric/nn/models/dimenet.py
CHANGED
@@ -755,7 +755,7 @@ class DimeNetPlusPlus(DimeNet):
             interaction blocks after the skip connection. (default: :obj:`2`)
         num_output_layers: (int, optional): Number of linear layers for the
             output blocks. (default: :obj:`3`)
-        act: (str or Callable, optional): The activation
+        act: (str or Callable, optional): The activation function.
             (default: :obj:`"swish"`)
         output_initializer (str, optional): The initialization method for the
             output layer (:obj:`"zeros"`, :obj:`"glorot_orthogonal"`).
@@ -805,7 +805,7 @@ class DimeNetPlusPlus(DimeNet):

         # We are re-using the RBF, SBF and embedding layers of `DimeNet` and
         # redefine output_block and interaction_block in DimeNet++.
-        # Hence, it is to be noted that in the above
+        # Hence, it is to be noted that in the above initialization, the
         # variable `num_bilinear` does not have any purpose as it is used
         # solely in the `OutputBlock` of DimeNet:
         self.output_blocks = torch.nn.ModuleList([
torch_geometric/nn/models/dimenet_utils.py
CHANGED
@@ -1,5 +1,7 @@
 # Shameless steal from: https://github.com/klicperajo/dimenet

+import math
+
 import numpy as np
 import sympy as sym
 from scipy import special as sp
@@ -62,8 +64,8 @@ def bessel_basis(n, k):


 def sph_harm_prefactor(k, m):
-    return ((2 * k + 1) *
-            (4 * np.pi *
+    return ((2 * k + 1) * math.factorial(k - abs(m)) /
+            (4 * np.pi * math.factorial(k + abs(m))))**0.5


 def associated_legendre_polynomials(k, zero_m_only=True):
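For reference, the new return statement of sph_harm_prefactor computes the standard normalization prefactor of the real spherical harmonics, now via math.factorial:

    \sqrt{\frac{(2k + 1)\,(k - |m|)!}{4\pi\,(k + |m|)!}}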
|