pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyg-nightly might be problematic.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/nn/aggr/__init__.py
CHANGED

@@ -25,6 +25,7 @@ from .deep_sets import DeepSetsAggregation
 from .set_transformer import SetTransformerAggregation
 from .lcm import LCMAggregation
 from .variance_preserving import VariancePreservingAggregation
+from .patch_transformer import PatchTransformerAggregation
 
 __all__ = classes = [
     'Aggregation',
@@ -53,4 +54,5 @@ __all__ = classes = [
     'SetTransformerAggregation',
     'LCMAggregation',
     'VariancePreservingAggregation',
+    'PatchTransformerAggregation',
 ]
torch_geometric/nn/aggr/base.py
CHANGED

@@ -135,7 +135,7 @@ class Aggregation(torch.nn.Module):
             if index.numel() > 0 and dim_size <= int(index.max()):
                 raise ValueError(f"Encountered invalid 'dim_size' (got "
                                  f"'{dim_size}' but expected "
-                                 f">= '{int(index.max()) + 1}')")
+                                 f">= '{int(index.max()) + 1}')") from e
             raise e
 
     def __repr__(self) -> str:
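The fix above chains the new `ValueError` to the originally caught exception via `from e`, so the root cause stays visible in the traceback instead of being replaced. A minimal standalone illustration of the pattern (not PyG code):

import torch

x = torch.randn(5, 8)
index = torch.tensor([0, 1, 4])

try:
    # Target has only 2 rows, but `index` refers to row 4:
    torch.zeros(2, 8).index_add_(0, index, x[:3])
except (IndexError, RuntimeError) as e:
    # `from e` records the original error as `__cause__`, so the traceback
    # reads "The above exception was the direct cause of ...":
    raise ValueError("Encountered invalid 'dim_size' (got '2' but "
                     "expected >= '5')") from e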
torch_geometric/nn/aggr/equilibrium.py
CHANGED

@@ -52,7 +52,7 @@ class MomentumOptimizer(torch.nn.Module):
     layer. It is based on an unrolled Nesterov momentum algorithm.
 
     Args:
-        learning_rate (
+        learning_rate (float): learning rate for optimizer.
         momentum (float): momentum for optimizer.
         learnable (bool): If :obj:`True` then the :obj:`learning_rate` and
             :obj:`momentum` will be learnable parameters. If False they
torch_geometric/nn/aggr/fused.py
CHANGED

@@ -216,7 +216,7 @@ class FusedAggregation(Aggregation):
         outs: List[Optional[Tensor]] = []
 
         # Iterate over all reduction ops to compute first results:
-        for
+        for reduce in self.reduce_ops:
             if reduce is None:
                 outs.append(None)
                 continue
torch_geometric/nn/aggr/patch_transformer.py
ADDED

@@ -0,0 +1,149 @@
+import math
+from typing import List, Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch_geometric.experimental import disable_dynamic_shapes
+from torch_geometric.nn.aggr import Aggregation
+from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
+from torch_geometric.nn.encoding import PositionalEncoding
+from torch_geometric.utils import scatter
+
+
+class PatchTransformerAggregation(Aggregation):
+    r"""Performs patch transformer aggregation in which the elements to
+    aggregate are processed by multi-head attention blocks across patches, as
+    described in the `"Simplifying Temporal Heterogeneous Network for
+    Continuous-Time Link Prediction"
+    <https://dl.acm.org/doi/pdf/10.1145/3583780.3615059>`_ paper.
+
+    Args:
+        in_channels (int): Size of each input sample.
+        out_channels (int): Size of each output sample.
+        patch_size (int): Number of elements in a patch.
+        hidden_channels (int): Intermediate size of each sample.
+        num_transformer_blocks (int, optional): Number of transformer blocks
+            (default: :obj:`1`).
+        heads (int, optional): Number of multi-head-attentions.
+            (default: :obj:`1`)
+        dropout (float, optional): Dropout probability of attention weights.
+            (default: :obj:`0.0`)
+        aggr (str or list[str], optional): The aggregation module, *e.g.*,
+            :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
+            :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        patch_size: int,
+        hidden_channels: int,
+        num_transformer_blocks: int = 1,
+        heads: int = 1,
+        dropout: float = 0.0,
+        aggr: Union[str, List[str]] = 'mean',
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.patch_size = patch_size
+        self.aggrs = [aggr] if isinstance(aggr, str) else aggr
+
+        assert len(self.aggrs) > 0
+        for aggr in self.aggrs:
+            assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']
+
+        self.lin = torch.nn.Linear(in_channels, hidden_channels, device=device)
+        self.pad_projector = torch.nn.Linear(
+            patch_size * hidden_channels,
+            hidden_channels,
+            device=device,
+        )
+        self.pe = PositionalEncoding(hidden_channels, device=device)
+
+        self.blocks = torch.nn.ModuleList([
+            MultiheadAttentionBlock(
+                channels=hidden_channels,
+                heads=heads,
+                layer_norm=True,
+                dropout=dropout,
+                device=device,
+            ) for _ in range(num_transformer_blocks)
+        ])
+
+        self.fc = torch.nn.Linear(
+            hidden_channels * len(self.aggrs),
+            out_channels,
+            device=device,
+        )
+
+    def reset_parameters(self) -> None:
+        self.lin.reset_parameters()
+        self.pad_projector.reset_parameters()
+        self.pe.reset_parameters()
+        for block in self.blocks:
+            block.reset_parameters()
+        self.fc.reset_parameters()
+
+    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
+    def forward(
+        self,
+        x: Tensor,
+        index: Tensor,
+        ptr: Optional[Tensor] = None,
+        dim_size: Optional[int] = None,
+        dim: int = -2,
+        max_num_elements: Optional[int] = None,
+    ) -> Tensor:
+
+        if max_num_elements is None:
+            if ptr is not None:
+                count = ptr.diff()
+            else:
+                count = scatter(torch.ones_like(index), index, dim=0,
+                                dim_size=dim_size, reduce='sum')
+            max_num_elements = int(count.max()) + 1
+
+        # Set `max_num_elements` to a multiple of `patch_size`:
+        max_num_elements = (math.floor(max_num_elements / self.patch_size) *
+                            self.patch_size)
+
+        x = self.lin(x)
+
+        # TODO If groups are heavily unbalanced, this will create a lot of
+        # "empty" patches. Try to figure out a way to fix this.
+        # [batch_size, num_patches * patch_size, hidden_channels]
+        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
+                                   max_num_elements=max_num_elements)
+
+        # [batch_size, num_patches, patch_size * hidden_channels]
+        x = x.view(x.size(0), max_num_elements // self.patch_size,
+                   self.patch_size * x.size(-1))
+
+        # [batch_size, num_patches, hidden_channels]
+        x = self.pad_projector(x)
+
+        x = x + self.pe(torch.arange(x.size(1), device=x.device))
+
+        # [batch_size, num_patches, hidden_channels]
+        for block in self.blocks:
+            x = block(x, x)
+
+        # [batch_size, hidden_channels]
+        outs: List[Tensor] = []
+        for aggr in self.aggrs:
+            out = getattr(torch, aggr)(x, dim=1)
+            outs.append(out[0] if isinstance(out, tuple) else out)
+        out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]
+
+        # [batch_size, out_channels]
+        return self.fc(out)
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({self.in_channels}, '
+                f'{self.out_channels}, patch_size={self.patch_size})')
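For orientation, a minimal usage sketch of the new operator (the tensor shapes and argument values below are illustrative, not taken from the diff):

import torch
from torch_geometric.nn.aggr import PatchTransformerAggregation

# 10 elements with 16 features each, grouped into 3 sets via `index`:
x = torch.randn(10, 16)
index = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])

aggr = PatchTransformerAggregation(
    in_channels=16, out_channels=32, patch_size=2, hidden_channels=8,
    num_transformer_blocks=1, heads=2, aggr=['mean', 'max'])

out = aggr(x, index, dim_size=3, max_num_elements=4)  # shape: [3, 32]

Passing `dim_size` and `max_num_elements` explicitly keeps all intermediate shapes static, which the `disable_dynamic_shapes` decorator on `forward` requires when that experimental mode is enabled.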
torch_geometric/nn/aggr/set_transformer.py
CHANGED

@@ -38,7 +38,7 @@ class SetTransformerAggregation(Aggregation):
             (default: :obj:`1`)
         concat (bool, optional): If set to :obj:`False`, the seed embeddings
             are averaged instead of concatenated. (default: :obj:`True`)
-
+        layer_norm (str, optional): If set to :obj:`True`, will apply layer
             normalization. (default: :obj:`False`)
         dropout (float, optional): Dropout probability of attention weights.
             (default: :obj:`0`)
torch_geometric/nn/aggr/utils.py
CHANGED

@@ -26,9 +26,11 @@ class MultiheadAttentionBlock(torch.nn.Module):
             normalization. (default: :obj:`True`)
         dropout (float, optional): Dropout probability of attention weights.
             (default: :obj:`0`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
     def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True,
-                 dropout: float = 0.0):
+                 dropout: float = 0.0, device: Optional[torch.device] = None):
         super().__init__()
 
         self.channels = channels
@@ -40,10 +42,13 @@ class MultiheadAttentionBlock(torch.nn.Module):
             heads,
             batch_first=True,
             dropout=dropout,
+            device=device,
         )
-        self.lin = Linear(channels, channels)
-        self.layer_norm1 = LayerNorm(channels
-
+        self.lin = Linear(channels, channels, device=device)
+        self.layer_norm1 = LayerNorm(channels,
+                                     device=device) if layer_norm else None
+        self.layer_norm2 = LayerNorm(channels,
+                                     device=device) if layer_norm else None
 
     def reset_parameters(self):
         self.attn._reset_parameters()
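The practical effect of the new `device` argument is that all submodules are allocated directly on the target device at construction time rather than created on CPU and moved afterwards. A small sketch (the device choice is illustrative):

import torch
from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameters live on `device` from the start; no separate `.to(device)`:
block = MultiheadAttentionBlock(channels=64, heads=4, dropout=0.1,
                                device=device)

x = torch.randn(2, 10, 64, device=device)
out = block(x, x)  # shape: [2, 10, 64]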
torch_geometric/nn/attention/__init__.py
CHANGED

@@ -1,3 +1,11 @@
 from .performer import PerformerAttention
+from .qformer import QFormer
+from .sgformer import SGFormerAttention
+from .polynormer import PolynormerAttention
 
-__all__ = [
+__all__ = classes = [
+    'PerformerAttention',
+    'QFormer',
+    'SGFormerAttention',
+    'PolynormerAttention',
+]
torch_geometric/nn/attention/polynormer.py
ADDED

@@ -0,0 +1,107 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class PolynormerAttention(torch.nn.Module):
+    r"""The polynomial-expressive attention mechanism from the
+    `"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time"
+    <https://arxiv.org/abs/2403.01232>`_ paper.
+
+    Args:
+        channels (int): Size of each input sample.
+        heads (int, optional): Number of parallel attention heads.
+        head_channels (int, optional): Size of each attention head.
+            (default: :obj:`64.`)
+        beta (float, optional): Polynormer beta initialization.
+            (default: :obj:`0.9`)
+        qkv_bias (bool, optional): If specified, add bias to query, key
+            and value in the self attention. (default: :obj:`False`)
+        qk_shared (bool, optional): Whether weight of query and key are shared.
+            (default: :obj:`True`)
+        dropout (float, optional): Dropout probability of the final
+            attention output. (default: :obj:`0.0`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        heads: int,
+        head_channels: int = 64,
+        beta: float = 0.9,
+        qkv_bias: bool = False,
+        qk_shared: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.head_channels = head_channels
+        self.heads = heads
+        self.beta = beta
+        self.qk_shared = qk_shared
+
+        inner_channels = heads * head_channels
+        self.h_lins = torch.nn.Linear(channels, inner_channels)
+        if not self.qk_shared:
+            self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.lns = torch.nn.LayerNorm(inner_channels)
+        self.lin_out = torch.nn.Linear(inner_channels, inner_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Node feature tensor
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
+                each graph, and feature dimension :math:`F`.
+            mask (torch.Tensor, optional): Mask matrix
+                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
+                the valid nodes for each graph. (default: :obj:`None`)
+        """
+        B, N, *_ = x.shape
+        h = self.h_lins(x)
+        k = self.k(x).sigmoid().view(B, N, self.head_channels, self.heads)
+        if self.qk_shared:
+            q = k
+        else:
+            q = F.sigmoid(self.q(x)).view(B, N, self.head_channels, self.heads)
+        v = self.v(x).view(B, N, self.head_channels, self.heads)
+
+        if mask is not None:
+            mask = mask[:, :, None, None]
+            v.masked_fill_(~mask, 0.)
+
+        # numerator
+        kv = torch.einsum('bndh, bnmh -> bdmh', k, v)
+        num = torch.einsum('bndh, bdmh -> bnmh', q, kv)
+
+        # denominator
+        k_sum = torch.einsum('bndh -> bdh', k)
+        den = torch.einsum('bndh, bdh -> bnh', q, k_sum).unsqueeze(2)
+
+        # linear global attention based on kernel trick
+        x = (num / (den + 1e-6)).reshape(B, N, -1)
+        x = self.lns(x) * (h + self.beta)
+        x = F.relu(self.lin_out(x))
+        x = self.dropout(x)
+
+        return x
+
+    def reset_parameters(self) -> None:
+        self.h_lins.reset_parameters()
+        if not self.qk_shared:
+            self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+        self.lns.reset_parameters()
+        self.lin_out.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
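A minimal usage sketch of the new attention module (shapes illustrative). Note that its output dimension is `heads * head_channels` rather than the input `channels`:

import torch
from torch_geometric.nn.attention import PolynormerAttention

attn = PolynormerAttention(channels=32, heads=4, head_channels=16)

x = torch.randn(2, 50, 32)                  # [batch_size, num_nodes, channels]
mask = torch.ones(2, 50, dtype=torch.bool)  # all nodes valid

out = attn(x, mask)  # shape: [2, 50, 4 * 16] = [2, 50, 64]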
torch_geometric/nn/attention/qformer.py
ADDED

@@ -0,0 +1,71 @@
+from typing import Callable
+
+import torch
+
+
+class QFormer(torch.nn.Module):
+    r"""The Querying Transformer (Q-Former) from
+    `"BLIP-2: Bootstrapping Language-Image Pre-training
+    with Frozen Image Encoders and Large Language Models"
+    <https://arxiv.org/pdf/2301.12597>`_ paper.
+
+    Args:
+        input_dim (int): The number of features in the input.
+        hidden_dim (int): The dimension of the fnn in the encoder layer.
+        output_dim (int): The final output dimension.
+        num_heads (int): The number of multi-attention-heads.
+        num_layers (int): The number of sub-encoder-layers in the encoder.
+        dropout (int): The dropout value in each encoder layer.
+
+    .. note::
+        This is a simplified version of the original Q-Former implementation.
+    """
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_heads: int,
+        num_layers: int,
+        dropout: float = 0.0,
+        activation: Callable = torch.nn.ReLU(),
+    ) -> None:
+
+        super().__init__()
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+
+        self.layer_norm = torch.nn.LayerNorm(input_dim)
+        self.encoder_layer = torch.nn.TransformerEncoderLayer(
+            d_model=input_dim,
+            nhead=num_heads,
+            dim_feedforward=hidden_dim,
+            dropout=dropout,
+            activation=activation,
+            batch_first=True,
+        )
+        self.encoder = torch.nn.TransformerEncoder(
+            self.encoder_layer,
+            num_layers=num_layers,
+        )
+        self.project = torch.nn.Linear(input_dim, output_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Input sequence to the encoder layer.
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, sequence length :math:`N`,
+                and feature dimension :math:`F`.
+        """
+        x = self.layer_norm(x)
+        x = self.encoder(x)
+        out = self.project(x)
+        return out
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'num_heads={self.num_heads}, '
+                f'num_layers={self.num_layers})')
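A minimal usage sketch (dimensions illustrative): the module layer-normalizes the input, runs it through a `torch.nn.TransformerEncoder`, and projects to `output_dim`:

import torch
from torch_geometric.nn.attention import QFormer

model = QFormer(input_dim=16, hidden_dim=64, output_dim=8,
                num_heads=4, num_layers=2)

x = torch.randn(2, 10, 16)  # [batch_size, seq_len, input_dim]
out = model(x)              # shape: [2, 10, 8]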
torch_geometric/nn/attention/sgformer.py
ADDED

@@ -0,0 +1,99 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+class SGFormerAttention(torch.nn.Module):
+    r"""The simple global attention mechanism from the
+    `"SGFormer: Simplifying and Empowering Transformers for
+    Large-Graph Representations"
+    <https://arxiv.org/abs/2306.10759>`_ paper.
+
+    Args:
+        channels (int): Size of each input sample.
+        heads (int, optional): Number of parallel attention heads.
+            (default: :obj:`1.`)
+        head_channels (int, optional): Size of each attention head.
+            (default: :obj:`64.`)
+        qkv_bias (bool, optional): If specified, add bias to query, key
+            and value in the self attention. (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        heads: int = 1,
+        head_channels: int = 64,
+        qkv_bias: bool = False,
+    ) -> None:
+        super().__init__()
+        assert channels % heads == 0
+        if head_channels is None:
+            head_channels = channels // heads
+
+        self.heads = heads
+        self.head_channels = head_channels
+
+        inner_channels = head_channels * heads
+        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Node feature tensor
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
+                each graph, and feature dimension :math:`F`.
+            mask (torch.Tensor, optional): Mask matrix
+                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
+                the valid nodes for each graph. (default: :obj:`None`)
+        """
+        B, N, *_ = x.shape
+        qs, ks, vs = self.q(x), self.k(x), self.v(x)
+        # reshape and permute q, k and v to proper shape
+        # (b, n, num_heads * head_channels) to (b, n, num_heads, head_channels)
+        qs, ks, vs = map(
+            lambda t: t.reshape(B, N, self.heads, self.head_channels),
+            (qs, ks, vs))
+
+        if mask is not None:
+            mask = mask[:, :, None, None]
+            vs.masked_fill_(~mask, 0.)
+        # replace 0's with epsilon
+        epsilon = 1e-6
+        qs[qs == 0] = epsilon
+        ks[ks == 0] = epsilon
+        # normalize input, shape not changed
+        qs, ks = map(
+            lambda t: t / torch.linalg.norm(t, ord=2, dim=-1, keepdim=True),
+            (qs, ks))
+
+        # numerator
+        kvs = torch.einsum("blhm,blhd->bhmd", ks, vs)
+        attention_num = torch.einsum("bnhm,bhmd->bnhd", qs, kvs)
+        attention_num += N * vs
+
+        # denominator
+        all_ones = torch.ones([B, N]).to(ks.device)
+        ks_sum = torch.einsum("blhm,bl->bhm", ks, all_ones)
+        attention_normalizer = torch.einsum("bnhm,bhm->bnh", qs, ks_sum)
+        # attentive aggregated results
+        attention_normalizer = torch.unsqueeze(attention_normalizer,
+                                               len(attention_normalizer.shape))
+        attention_normalizer += torch.ones_like(attention_normalizer) * N
+        attn_output = attention_num / attention_normalizer
+
+        return attn_output.mean(dim=2)
+
+    def reset_parameters(self):
+        self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
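A minimal usage sketch (shapes illustrative). The kernelized attention is linear in the number of nodes, and the head outputs are averaged, so the result has `head_channels` features:

import torch
from torch_geometric.nn.attention import SGFormerAttention

attn = SGFormerAttention(channels=64, heads=1, head_channels=64)

x = torch.randn(2, 100, 64)  # [batch_size, num_nodes, channels]
out = attn(x)                # shape: [2, 100, 64]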
torch_geometric/nn/conv/__init__.py
CHANGED

@@ -61,6 +61,7 @@ from .gps_conv import GPSConv
 from .antisymmetric_conv import AntiSymmetricConv
 from .dir_gnn_conv import DirGNNConv
 from .mixhop_conv import MixHopConv
+from .meshcnn_conv import MeshCNNConv
 
 import torch_geometric.nn.conv.utils  # noqa
 
@@ -131,6 +132,7 @@ __all__ = [
     'AntiSymmetricConv',
     'DirGNNConv',
     'MixHopConv',
+    'MeshCNNConv',
 ]
 
 classes = __all__
torch_geometric/nn/conv/collect.jinja
CHANGED

@@ -98,13 +98,16 @@ def {{collect_name}}(
 
 {%- if 'edge_weight' in collect_param_dict and
        collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
-
+    if torch.jit.is_scripting():
+        assert edge_weight is not None
 {%- elif 'edge_attr' in collect_param_dict and
        collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}
-
+    if torch.jit.is_scripting():
+        assert edge_attr is not None
 {%- elif 'edge_type' in collect_param_dict and
        collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}
-
+    if torch.jit.is_scripting():
+        assert edge_type is not None
 {%- endif %}
 
     # Collect user-defined arguments:
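The guards added to the generated collect functions follow a standard TorchScript idiom: under scripting, `torch.jit.is_scripting()` is treated as a constant `True`, so the `assert` lets the compiler refine `Optional[Tensor]` to `Tensor`; in eager mode the branch is skipped entirely. A standalone sketch of the idiom (hypothetical function, not from the template):

from typing import Optional

import torch
from torch import Tensor


def scale(edge_weight: Optional[Tensor]) -> Tensor:
    if torch.jit.is_scripting():
        # Type refinement for the TorchScript compiler; skipped in eager
        # mode, where callers are expected to pass an actual tensor:
        assert edge_weight is not None
    return edge_weight * 2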
torch_geometric/nn/conv/cugraph/gat_conv.py
CHANGED

@@ -26,6 +26,9 @@ class CuGraphGATConv(CuGraphModule):  # pragma: no cover
     :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`cugraph-ops`
     package that fuses message passing computation for accelerated execution
     and lower memory footprint.
+    The current method to enable :obj:`cugraph-ops`
+    is to use `The NVIDIA PyG Container
+    <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
     """
     def __init__(
         self,
@@ -67,6 +70,7 @@ class CuGraphGATConv(CuGraphModule):  # pragma: no cover
         self,
         x: Tensor,
         edge_index: EdgeIndex,
+        edge_attr: Tensor,
         max_num_neighbors: Optional[int] = None,
     ) -> Tensor:
         graph = self.get_cugraph(edge_index, max_num_neighbors)
@@ -75,10 +79,12 @@ class CuGraphGATConv(CuGraphModule):  # pragma: no cover
 
         if LEGACY_MODE:
             out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
-                             self.negative_slope, False, self.concat
+                             self.negative_slope, False, self.concat,
+                             edge_feat=edge_attr)
         else:
             out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
-                             self.negative_slope, self.concat
+                             self.negative_slope, self.concat,
+                             edge_feat=edge_attr)
 
         if self.bias is not None:
             out = out + self.bias
torch_geometric/nn/conv/cugraph/rgcn_conv.py
CHANGED

@@ -29,6 +29,9 @@ class CuGraphRGCNConv(CuGraphModule):  # pragma: no cover
     :class:`~torch_geometric.nn.conv.RGCNConv` based on the :obj:`cugraph-ops`
     package that fuses message passing computation for accelerated execution
     and lower memory footprint.
+    The current method to enable :obj:`cugraph-ops`
+    is to use `The NVIDIA PyG Container
+    <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
     """
     def __init__(self, in_channels: int, out_channels: int, num_relations: int,
                  num_bases: Optional[int] = None, aggr: str = 'mean',
torch_geometric/nn/conv/cugraph/sage_conv.py
CHANGED

@@ -27,6 +27,9 @@ class CuGraphSAGEConv(CuGraphModule):  # pragma: no cover
     :class:`~torch_geometric.nn.conv.SAGEConv` based on the :obj:`cugraph-ops`
     package that fuses message passing computation for accelerated execution
     and lower memory footprint.
+    The current method to enable :obj:`cugraph-ops`
+    is to use `The NVIDIA PyG Container
+    <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
     """
     def __init__(
         self,
torch_geometric/nn/conv/dna_conv.py
CHANGED

@@ -163,7 +163,7 @@ class MultiHead(Attention):
     def __repr__(self) -> str:  # pragma: no cover
         return (f'{self.__class__.__name__}({self.in_channels}, '
                 f'{self.out_channels}, heads={self.heads}, '
-                f'groups={self.groups}, dropout={self.
+                f'groups={self.groups}, dropout={self.dropout}, '
                 f'bias={self.bias})')