pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_to_dense_batch.py +2 -2
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
torch_geometric/nn/__init__.py
CHANGED
torch_geometric/nn/aggr/base.py
CHANGED
@@ -135,7 +135,7 @@ class Aggregation(torch.nn.Module):
                 if index.numel() > 0 and dim_size <= int(index.max()):
                     raise ValueError(f"Encountered invalid 'dim_size' (got "
                                      f"'{dim_size}' but expected "
-                                     f">= '{int(index.max()) + 1}')")
+                                     f">= '{int(index.max()) + 1}')") from e
             raise e

     def __repr__(self) -> str:
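The only functional change here is the trailing "from e": the ValueError raised for an invalid dim_size is now chained to the original scatter failure, so the traceback shows both errors. A minimal sketch of the pattern, using a toy aggregate helper (illustrative, not part of the diff):

import torch

def aggregate(src: torch.Tensor, index: torch.Tensor,
              dim_size: int) -> torch.Tensor:
    try:
        out = torch.zeros(dim_size, dtype=src.dtype)
        return out.scatter_add_(0, index, src)
    except (IndexError, RuntimeError) as e:
        # `from e` keeps the original scatter error in the traceback instead
        # of reporting it as an unrelated secondary exception.
        if index.numel() > 0 and dim_size <= int(index.max()):
            raise ValueError(f"Encountered invalid 'dim_size' (got "
                             f"'{dim_size}' but expected "
                             f">= '{int(index.max()) + 1}')") from e
        raise e

# aggregate(torch.ones(3), torch.tensor([0, 1, 5]), 3)  # -> chained ValueError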
torch_geometric/nn/aggr/equilibrium.py
CHANGED
@@ -52,7 +52,7 @@ class MomentumOptimizer(torch.nn.Module):
     layer. It is based on an unrolled Nesterov momentum algorithm.

     Args:
-        learning_rate (
+        learning_rate (float): learning rate for optimizer.
         momentum (float): momentum for optimizer.
         learnable (bool): If :obj:`True` then the :obj:`learning_rate` and
             :obj:`momentum` will be learnable parameters. If False they
torch_geometric/nn/aggr/fused.py
CHANGED
@@ -216,7 +216,7 @@ class FusedAggregation(Aggregation):
         outs: List[Optional[Tensor]] = []

         # Iterate over all reduction ops to compute first results:
-        for
+        for reduce in self.reduce_ops:
             if reduce is None:
                 outs.append(None)
                 continue
torch_geometric/nn/aggr/patch_transformer.py
CHANGED
@@ -32,6 +32,8 @@ class PatchTransformerAggregation(Aggregation):
         aggr (str or list[str], optional): The aggregation module, *e.g.*,
             :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
             :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -43,6 +45,7 @@ class PatchTransformerAggregation(Aggregation):
         heads: int = 1,
         dropout: float = 0.0,
         aggr: Union[str, List[str]] = 'mean',
+        device: Optional[torch.device] = None,
     ) -> None:
         super().__init__()

@@ -55,12 +58,13 @@ class PatchTransformerAggregation(Aggregation):
         for aggr in self.aggrs:
             assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']

-        self.lin = torch.nn.Linear(in_channels, hidden_channels)
+        self.lin = torch.nn.Linear(in_channels, hidden_channels, device=device)
         self.pad_projector = torch.nn.Linear(
             patch_size * hidden_channels,
             hidden_channels,
+            device=device,
         )
-        self.pe = PositionalEncoding(hidden_channels)
+        self.pe = PositionalEncoding(hidden_channels, device=device)

         self.blocks = torch.nn.ModuleList([
             MultiheadAttentionBlock(
@@ -68,12 +72,14 @@ class PatchTransformerAggregation(Aggregation):
                 heads=heads,
                 layer_norm=True,
                 dropout=dropout,
+                device=device,
             ) for _ in range(num_transformer_blocks)
         ])

         self.fc = torch.nn.Linear(
             hidden_channels * len(self.aggrs),
             out_channels,
+            device=device,
         )

     def reset_parameters(self) -> None:
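The new `device=` keyword lets the aggregation allocate its parameters directly on the target device instead of constructing on CPU and moving them afterwards. A hedged sketch, assuming the remaining constructor arguments (`in_channels`, `out_channels`, `patch_size`, `hidden_channels`) keep the order used in the released package:

import torch
from torch_geometric.nn.aggr import PatchTransformerAggregation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameters are materialized on `device` at construction time:
aggr = PatchTransformerAggregation(
    in_channels=16,
    out_channels=32,
    patch_size=4,
    hidden_channels=8,
    heads=2,
    device=device,
)
assert all(p.device.type == device.type for p in aggr.parameters())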
torch_geometric/nn/aggr/set_transformer.py
CHANGED
@@ -38,7 +38,7 @@ class SetTransformerAggregation(Aggregation):
             (default: :obj:`1`)
         concat (bool, optional): If set to :obj:`False`, the seed embeddings
             are averaged instead of concatenated. (default: :obj:`True`)
-
+        layer_norm (bool, optional): If set to :obj:`True`, will apply layer
             normalization. (default: :obj:`False`)
         dropout (float, optional): Dropout probability of attention weights.
             (default: :obj:`0`)
torch_geometric/nn/aggr/utils.py
CHANGED
@@ -26,9 +26,11 @@ class MultiheadAttentionBlock(torch.nn.Module):
             normalization. (default: :obj:`True`)
         dropout (float, optional): Dropout probability of attention weights.
             (default: :obj:`0`)
+        device (torch.device, optional): The device of the module.
+            (default: :obj:`None`)
     """
     def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True,
-                 dropout: float = 0.0):
+                 dropout: float = 0.0, device: Optional[torch.device] = None):
         super().__init__()

         self.channels = channels
@@ -40,10 +42,13 @@ class MultiheadAttentionBlock(torch.nn.Module):
             heads,
             batch_first=True,
             dropout=dropout,
+            device=device,
         )
-        self.lin = Linear(channels, channels)
-        self.layer_norm1 = LayerNorm(channels
-
+        self.lin = Linear(channels, channels, device=device)
+        self.layer_norm1 = LayerNorm(channels,
+                                     device=device) if layer_norm else None
+        self.layer_norm2 = LayerNorm(channels,
+                                     device=device) if layer_norm else None

     def reset_parameters(self):
         self.attn._reset_parameters()
torch_geometric/nn/attention/__init__.py
CHANGED
@@ -1,3 +1,11 @@
 from .performer import PerformerAttention
+from .qformer import QFormer
+from .sgformer import SGFormerAttention
+from .polynormer import PolynormerAttention

-__all__ = [
+__all__ = classes = [
+    'PerformerAttention',
+    'QFormer',
+    'SGFormerAttention',
+    'PolynormerAttention',
+]
torch_geometric/nn/attention/polynormer.py
ADDED
@@ -0,0 +1,107 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class PolynormerAttention(torch.nn.Module):
+    r"""The polynomial-expressive attention mechanism from the
+    `"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time"
+    <https://arxiv.org/abs/2403.01232>`_ paper.
+
+    Args:
+        channels (int): Size of each input sample.
+        heads (int, optional): Number of parallel attention heads.
+        head_channels (int, optional): Size of each attention head.
+            (default: :obj:`64`)
+        beta (float, optional): Polynormer beta initialization.
+            (default: :obj:`0.9`)
+        qkv_bias (bool, optional): If specified, add bias to query, key
+            and value in the self attention. (default: :obj:`False`)
+        qk_shared (bool, optional): Whether weight of query and key are
+            shared. (default: :obj:`True`)
+        dropout (float, optional): Dropout probability of the final
+            attention output. (default: :obj:`0.0`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        heads: int,
+        head_channels: int = 64,
+        beta: float = 0.9,
+        qkv_bias: bool = False,
+        qk_shared: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.head_channels = head_channels
+        self.heads = heads
+        self.beta = beta
+        self.qk_shared = qk_shared
+
+        inner_channels = heads * head_channels
+        self.h_lins = torch.nn.Linear(channels, inner_channels)
+        if not self.qk_shared:
+            self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.lns = torch.nn.LayerNorm(inner_channels)
+        self.lin_out = torch.nn.Linear(inner_channels, inner_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Node feature tensor
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
+                each graph, and feature dimension :math:`F`.
+            mask (torch.Tensor, optional): Mask matrix
+                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
+                the valid nodes for each graph. (default: :obj:`None`)
+        """
+        B, N, *_ = x.shape
+        h = self.h_lins(x)
+        k = self.k(x).sigmoid().view(B, N, self.head_channels, self.heads)
+        if self.qk_shared:
+            q = k
+        else:
+            q = F.sigmoid(self.q(x)).view(B, N, self.head_channels, self.heads)
+        v = self.v(x).view(B, N, self.head_channels, self.heads)
+
+        if mask is not None:
+            mask = mask[:, :, None, None]
+            v.masked_fill_(~mask, 0.)
+
+        # numerator
+        kv = torch.einsum('bndh, bnmh -> bdmh', k, v)
+        num = torch.einsum('bndh, bdmh -> bnmh', q, kv)
+
+        # denominator
+        k_sum = torch.einsum('bndh -> bdh', k)
+        den = torch.einsum('bndh, bdh -> bnh', q, k_sum).unsqueeze(2)
+
+        # linear global attention based on kernel trick
+        x = (num / (den + 1e-6)).reshape(B, N, -1)
+        x = self.lns(x) * (h + self.beta)
+        x = F.relu(self.lin_out(x))
+        x = self.dropout(x)
+
+        return x
+
+    def reset_parameters(self) -> None:
+        self.h_lins.reset_parameters()
+        if not self.qk_shared:
+            self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+        self.lns.reset_parameters()
+        self.lin_out.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
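For orientation, a hedged usage sketch of the new module (shapes follow the docstring: a dense [B, N, F] input with an optional boolean [B, N] mask; the output width is heads * head_channels):

import torch
from torch_geometric.nn.attention import PolynormerAttention

attn = PolynormerAttention(channels=64, heads=4, head_channels=16)

x = torch.randn(2, 10, 64)                  # [B=2, N=10, F=64]
mask = torch.ones(2, 10, dtype=torch.bool)  # valid-node mask
mask[1, 8:] = False                         # last two nodes of graph 1 are padding

out = attn(x, mask)
print(out.shape)  # torch.Size([2, 10, 64]), i.e. heads * head_channels = 64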
torch_geometric/nn/attention/qformer.py
ADDED
@@ -0,0 +1,71 @@
+from typing import Callable
+
+import torch
+
+
+class QFormer(torch.nn.Module):
+    r"""The Querying Transformer (Q-Former) from
+    `"BLIP-2: Bootstrapping Language-Image Pre-training
+    with Frozen Image Encoders and Large Language Models"
+    <https://arxiv.org/pdf/2301.12597>`_ paper.
+
+    Args:
+        input_dim (int): The number of features in the input.
+        hidden_dim (int): The dimension of the feed-forward network in the
+            encoder layer.
+        output_dim (int): The final output dimension.
+        num_heads (int): The number of multi-attention-heads.
+        num_layers (int): The number of sub-encoder-layers in the encoder.
+        dropout (float): The dropout value in each encoder layer.
+
+    .. note::
+        This is a simplified version of the original Q-Former implementation.
+    """
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_heads: int,
+        num_layers: int,
+        dropout: float = 0.0,
+        activation: Callable = torch.nn.ReLU(),
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+
+        self.layer_norm = torch.nn.LayerNorm(input_dim)
+        self.encoder_layer = torch.nn.TransformerEncoderLayer(
+            d_model=input_dim,
+            nhead=num_heads,
+            dim_feedforward=hidden_dim,
+            dropout=dropout,
+            activation=activation,
+            batch_first=True,
+        )
+        self.encoder = torch.nn.TransformerEncoder(
+            self.encoder_layer,
+            num_layers=num_layers,
+        )
+        self.project = torch.nn.Linear(input_dim, output_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Input sequence to the encoder layer.
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, sequence length :math:`N`,
+                and feature dimension :math:`F`.
+        """
+        x = self.layer_norm(x)
+        x = self.encoder(x)
+        out = self.project(x)
+        return out
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'num_heads={self.num_heads}, '
+                f'num_layers={self.num_layers})')
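A hedged usage sketch of the module above (a plain [batch, sequence, features] input; values chosen only so that input_dim is divisible by num_heads):

import torch
from torch_geometric.nn.attention import QFormer

qformer = QFormer(input_dim=16, hidden_dim=32, output_dim=8,
                  num_heads=4, num_layers=2)

x = torch.randn(2, 12, 16)  # [batch, sequence length, input_dim]
out = qformer(x)            # projected to output_dim
print(out.shape)            # torch.Size([2, 12, 8])
print(qformer)              # QFormer(num_heads=4, num_layers=2)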
torch_geometric/nn/attention/sgformer.py
ADDED
@@ -0,0 +1,99 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+class SGFormerAttention(torch.nn.Module):
+    r"""The simple global attention mechanism from the
+    `"SGFormer: Simplifying and Empowering Transformers for
+    Large-Graph Representations"
+    <https://arxiv.org/abs/2306.10759>`_ paper.
+
+    Args:
+        channels (int): Size of each input sample.
+        heads (int, optional): Number of parallel attention heads.
+            (default: :obj:`1`)
+        head_channels (int, optional): Size of each attention head.
+            (default: :obj:`64`)
+        qkv_bias (bool, optional): If specified, add bias to query, key
+            and value in the self attention. (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        heads: int = 1,
+        head_channels: int = 64,
+        qkv_bias: bool = False,
+    ) -> None:
+        super().__init__()
+        assert channels % heads == 0
+        if head_channels is None:
+            head_channels = channels // heads
+
+        self.heads = heads
+        self.head_channels = head_channels
+
+        inner_channels = head_channels * heads
+        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): Node feature tensor
+                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
+                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
+                each graph, and feature dimension :math:`F`.
+            mask (torch.Tensor, optional): Mask matrix
+                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
+                the valid nodes for each graph. (default: :obj:`None`)
+        """
+        B, N, *_ = x.shape
+        qs, ks, vs = self.q(x), self.k(x), self.v(x)
+        # reshape and permute q, k and v to proper shape
+        # (b, n, num_heads * head_channels) to (b, n, num_heads, head_channels)
+        qs, ks, vs = map(
+            lambda t: t.reshape(B, N, self.heads, self.head_channels),
+            (qs, ks, vs))
+
+        if mask is not None:
+            mask = mask[:, :, None, None]
+            vs.masked_fill_(~mask, 0.)
+        # replace 0's with epsilon
+        epsilon = 1e-6
+        qs[qs == 0] = epsilon
+        ks[ks == 0] = epsilon
+        # normalize input, shape not changed
+        qs, ks = map(
+            lambda t: t / torch.linalg.norm(t, ord=2, dim=-1, keepdim=True),
+            (qs, ks))
+
+        # numerator
+        kvs = torch.einsum("blhm,blhd->bhmd", ks, vs)
+        attention_num = torch.einsum("bnhm,bhmd->bnhd", qs, kvs)
+        attention_num += N * vs
+
+        # denominator
+        all_ones = torch.ones([B, N]).to(ks.device)
+        ks_sum = torch.einsum("blhm,bl->bhm", ks, all_ones)
+        attention_normalizer = torch.einsum("bnhm,bhm->bnh", qs, ks_sum)
+        # attentive aggregated results
+        attention_normalizer = torch.unsqueeze(attention_normalizer,
+                                               len(attention_normalizer.shape))
+        attention_normalizer += torch.ones_like(attention_normalizer) * N
+        attn_output = attention_num / attention_normalizer
+
+        return attn_output.mean(dim=2)
+
+    def reset_parameters(self):
+        self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
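And a matching hedged sketch for SGFormerAttention. Note that forward() averages over the heads, so the output width equals head_channels rather than heads * head_channels:

import torch
from torch_geometric.nn.attention import SGFormerAttention

attn = SGFormerAttention(channels=64, heads=1, head_channels=64)

x = torch.randn(2, 10, 64)                  # [B, N, F]
mask = torch.ones(2, 10, dtype=torch.bool)  # valid-node mask

out = attn(x, mask)  # heads are averaged out in forward()
print(out.shape)     # torch.Size([2, 10, 64]), i.e. head_channels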
torch_geometric/nn/conv/__init__.py
CHANGED
@@ -61,6 +61,7 @@ from .gps_conv import GPSConv
 from .antisymmetric_conv import AntiSymmetricConv
 from .dir_gnn_conv import DirGNNConv
 from .mixhop_conv import MixHopConv
+from .meshcnn_conv import MeshCNNConv

 import torch_geometric.nn.conv.utils  # noqa

@@ -131,6 +132,7 @@ __all__ = [
     'AntiSymmetricConv',
     'DirGNNConv',
     'MixHopConv',
+    'MeshCNNConv',
 ]

 classes = __all__
torch_geometric/nn/conv/appnp.py
CHANGED
torch_geometric/nn/conv/cugraph/gat_conv.py
CHANGED
@@ -26,6 +26,9 @@ class CuGraphGATConv(CuGraphModule): # pragma: no cover
     :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`cugraph-ops`
     package that fuses message passing computation for accelerated execution
     and lower memory footprint.
+    The current method to enable :obj:`cugraph-ops`
+    is to use `The NVIDIA PyG Container
+    <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
     """
     def __init__(
         self,
@@ -67,6 +70,7 @@ class CuGraphGATConv(CuGraphModule): # pragma: no cover
         self,
         x: Tensor,
         edge_index: EdgeIndex,
+        edge_attr: Tensor,
         max_num_neighbors: Optional[int] = None,
     ) -> Tensor:
         graph = self.get_cugraph(edge_index, max_num_neighbors)
@@ -75,10 +79,12 @@ class CuGraphGATConv(CuGraphModule): # pragma: no cover

         if LEGACY_MODE:
             out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
-                             self.negative_slope, False, self.concat
+                             self.negative_slope, False, self.concat,
+                             edge_feat=edge_attr)
         else:
             out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
-                             self.negative_slope, self.concat
+                             self.negative_slope, self.concat,
+                             edge_feat=edge_attr)

         if self.bias is not None:
             out = out + self.bias
torch_geometric/nn/conv/cugraph/rgcn_conv.py
CHANGED
@@ -29,6 +29,9 @@ class CuGraphRGCNConv(CuGraphModule): # pragma: no cover
     :class:`~torch_geometric.nn.conv.RGCNConv` based on the :obj:`cugraph-ops`
     package that fuses message passing computation for accelerated execution
     and lower memory footprint.
+    The current method to enable :obj:`cugraph-ops`
+    is to use `The NVIDIA PyG Container
+    <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
     """
     def __init__(self, in_channels: int, out_channels: int, num_relations: int,
                  num_bases: Optional[int] = None, aggr: str = 'mean',
torch_geometric/nn/conv/cugraph/sage_conv.py
CHANGED
@@ -27,6 +27,9 @@ class CuGraphSAGEConv(CuGraphModule): # pragma: no cover
     :class:`~torch_geometric.nn.conv.SAGEConv` based on the :obj:`cugraph-ops`
     package that fuses message passing computation for accelerated execution
     and lower memory footprint.
+    The current method to enable :obj:`cugraph-ops`
+    is to use `The NVIDIA PyG Container
+    <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
     """
     def __init__(
         self,
torch_geometric/nn/conv/dna_conv.py
CHANGED
@@ -163,7 +163,7 @@ class MultiHead(Attention):
     def __repr__(self) -> str: # pragma: no cover
         return (f'{self.__class__.__name__}({self.in_channels}, '
                 f'{self.out_channels}, heads={self.heads}, '
-                f'groups={self.groups}, dropout={self.
+                f'groups={self.groups}, dropout={self.dropout}, '
                 f'bias={self.bias})')
torch_geometric/nn/conv/eg_conv.py
CHANGED
@@ -81,7 +81,7 @@ class EGConv(MessagePassing):
         self,
         in_channels: int,
         out_channels: int,
-        aggregators: List[str] =
+        aggregators: Optional[List[str]] = None,
         num_heads: int = 8,
         num_bases: int = 4,
         cached: bool = False,
@@ -96,23 +96,23 @@ class EGConv(MessagePassing):
                              f"divisible by the number of heads "
                              f"(got {num_heads})")

-        for a in aggregators:
-            if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
-                raise ValueError(f"Unsupported aggregator: '{a}'")
-
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.num_heads = num_heads
         self.num_bases = num_bases
         self.cached = cached
         self.add_self_loops = add_self_loops
-        self.aggregators = aggregators
+        self.aggregators = aggregators or ['symnorm']
+
+        for a in self.aggregators:
+            if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
+                raise ValueError(f"Unsupported aggregator: '{a}'")

         self.bases_lin = Linear(in_channels,
                                 (out_channels // num_heads) * num_bases,
                                 bias=False, weight_initializer='glorot')
         self.comb_lin = Linear(in_channels,
-                               num_heads * num_bases * len(aggregators))
+                               num_heads * num_bases * len(self.aggregators))

         if bias:
             self.bias = Parameter(torch.empty(out_channels))
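The restructuring above is the standard fix for a mutable default argument: the signature takes None and the constructor falls back to a fresh ['symnorm'] list, so instances no longer share a single default list object. A minimal sketch of the pattern, independent of EGConv:

from typing import List, Optional

def resolve_aggregators(aggregators: Optional[List[str]] = None) -> List[str]:
    # A `None` sentinel plus an `or` fallback yields a fresh list per call,
    # unlike `aggregators: List[str] = ['symnorm']`, where every call would
    # share (and could mutate) the same default list object.
    aggregators = aggregators or ['symnorm']
    for a in aggregators:
        if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
            raise ValueError(f"Unsupported aggregator: '{a}'")
    return aggregators

print(resolve_aggregators())          # ['symnorm']
print(resolve_aggregators(['mean']))  # ['mean']

One side effect of the `or` fallback is that an explicitly passed empty list also resolves to ['symnorm'].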
torch_geometric/nn/conv/gen_conv.py
CHANGED
@@ -178,7 +178,7 @@ class GENConv(MessagePassing):
         self.lin_dst = Linear(in_channels[1], out_channels, bias=bias)

         channels = [out_channels]
-        for
+        for _ in range(num_layers - 1):
             channels.append(out_channels * expansion)
         channels.append(out_channels)
         self.mlp = MLP(channels, norm=norm, bias=bias)
torch_geometric/nn/conv/gravnet_conv.py
CHANGED
@@ -63,7 +63,8 @@ class GravNetConv(MessagePassing):
         if num_workers is not None:
             warnings.warn(
                 "'num_workers' attribute in '{self.__class__.__name__}' is "
-                "deprecated and will be removed in a future release"
+                "deprecated and will be removed in a future release",
+                stacklevel=2)

         self.in_channels = in_channels
         self.out_channels = out_channels
torch_geometric/nn/conv/hetero_conv.py
CHANGED
@@ -77,7 +77,8 @@ class HeteroConv(torch.nn.Module):
                 f"There exist node types ({src_node_types - dst_node_types}) "
                 f"whose representations do not get updated during message "
                 f"passing as they do not occur as destination type in any "
-                f"edge type. This may lead to unexpected behavior."
+                f"edge type. This may lead to unexpected behavior.",
+                stacklevel=2)

         self.convs = ModuleDict(convs)
         self.aggr = aggr