pyg-nightly 2.6.0.dev20240318__py3-none-any.whl → 2.7.0.dev20250115__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240318.dist-info → pyg_nightly-2.7.0.dev20250115.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +13 -7
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +317 -65
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +3 -5
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +329 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +56 -22
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
@@ -0,0 +1,143 @@
|
|
1
|
+
import math
|
2
|
+
from typing import List, Optional, Union
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from torch import Tensor
|
6
|
+
|
7
|
+
from torch_geometric.experimental import disable_dynamic_shapes
|
8
|
+
from torch_geometric.nn.aggr import Aggregation
|
9
|
+
from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
|
10
|
+
from torch_geometric.nn.encoding import PositionalEncoding
|
11
|
+
from torch_geometric.utils import scatter
|
12
|
+
|
13
|
+
|
14
|
+
class PatchTransformerAggregation(Aggregation):
    r"""Performs patch transformer aggregation in which the elements to
    aggregate are processed by multi-head attention blocks across patches, as
    described in the `"Simplifying Temporal Heterogeneous Network for
    Continuous-Time Link Prediction"
    <https://dl.acm.org/doi/pdf/10.1145/3583780.3615059>`_ paper.

    The elements of every group are first projected to
    :obj:`hidden_channels` and packed into a zero-padded dense batch whose
    length is a multiple of :obj:`patch_size`. That batch is split into
    patches of :obj:`patch_size` consecutive elements. Each flattened patch
    is projected back to :obj:`hidden_channels`, has a positional encoding of
    its patch index added, and is passed through the attention blocks.
    Finally, the patch embeddings are reduced over the patch dimension by
    every aggregation in :obj:`aggr`, concatenated, and projected to
    :obj:`out_channels`.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        patch_size (int): Number of elements in a patch.
        hidden_channels (int): Intermediate size of each sample.
        num_transformer_blocks (int, optional): Number of transformer blocks
            (default: :obj:`1`).
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability of attention weights.
            (default: :obj:`0.0`)
        aggr (str or list[str], optional): The aggregation module, *e.g.*,
            :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        hidden_channels: int,
        num_transformer_blocks: int = 1,
        heads: int = 1,
        dropout: float = 0.0,
        aggr: Union[str, List[str]] = 'mean',
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.aggrs = [aggr] if isinstance(aggr, str) else aggr

        assert len(self.aggrs) > 0
        for name in self.aggrs:
            # Every name is looked up as a `torch.<name>` reduction in
            # `forward`, so only these are supported:
            assert name in ['sum', 'mean', 'min', 'max', 'var', 'std']

        # Per-element input projection:
        self.lin = torch.nn.Linear(in_channels, hidden_channels)
        # Projects a flattened patch (`patch_size` elements) to one vector:
        self.pad_projector = torch.nn.Linear(
            patch_size * hidden_channels,
            hidden_channels,
        )
        self.pe = PositionalEncoding(hidden_channels)

        self.blocks = torch.nn.ModuleList([
            MultiheadAttentionBlock(
                channels=hidden_channels,
                heads=heads,
                layer_norm=True,
                dropout=dropout,
            ) for _ in range(num_transformer_blocks)
        ])

        # One `hidden_channels`-sized slice per aggregation is concatenated:
        self.fc = torch.nn.Linear(
            hidden_channels * len(self.aggrs),
            out_channels,
        )

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        self.lin.reset_parameters()
        self.pad_projector.reset_parameters()
        self.pe.reset_parameters()
        for block in self.blocks:
            block.reset_parameters()
        self.fc.reset_parameters()

    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Tensor,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:
        r"""Aggregates the elements of :obj:`x` grouped by :obj:`index` (or
        :obj:`ptr`).

        Args:
            x (torch.Tensor): The source tensor.
            index (torch.Tensor): The group assignment of every element.
            ptr (torch.Tensor, optional): CSR-style group boundaries. If
                given, group sizes are read from :obj:`ptr.diff()`.
            dim_size (int, optional): The number of groups.
            dim (int, optional): The dimension along which to aggregate.
            max_num_elements (int, optional): Upper bound on the number of
                elements in any group. It is rounded **up** to the next
                multiple of :obj:`patch_size` (and to at least one patch), so
                no element is dropped. Inferred from :obj:`index`/:obj:`ptr`
                if not given.

        Returns:
            torch.Tensor: A tensor of shape :obj:`[batch_size, out_channels]`,
            where :obj:`batch_size` is the number of groups in the dense
            batch.
        """
        if max_num_elements is None:
            if ptr is not None:
                count = ptr.diff()
            else:
                count = scatter(torch.ones_like(index), index, dim=0,
                                dim_size=dim_size, reduce='sum')
            # `count.max()` raises on an empty tensor (no groups at all):
            max_num_elements = int(count.max()) if count.numel() > 0 else 0

        # Round `max_num_elements` *up* to a multiple of `patch_size`. Rounding
        # down would make `to_dense_batch` silently drop the trailing elements
        # of the largest groups, and could even produce zero patches. At least
        # one patch is always kept so the reductions below never run over an
        # empty dimension:
        num_patches = max(math.ceil(max_num_elements / self.patch_size), 1)
        max_num_elements = num_patches * self.patch_size

        x = self.lin(x)

        # TODO If groups are heavily unbalanced, this will create a lot of
        # "empty" patches. Try to figure out a way to fix this.
        # [batch_size, num_patches * patch_size, hidden_channels]
        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                   max_num_elements=max_num_elements)

        # [batch_size, num_patches, patch_size * hidden_channels]
        x = x.view(x.size(0), num_patches, self.patch_size * x.size(-1))

        # [batch_size, num_patches, hidden_channels]
        x = self.pad_projector(x)

        # Add a positional encoding of each patch's index:
        x = x + self.pe(torch.arange(x.size(1), device=x.device))

        # [batch_size, num_patches, hidden_channels]
        for block in self.blocks:
            x = block(x, x)

        # [batch_size, hidden_channels * len(aggrs)]
        outs: List[Tensor] = []
        for name in self.aggrs:
            out = getattr(torch, name)(x, dim=1)
            # `torch.min`/`torch.max` with `dim` return (values, indices):
            outs.append(out[0] if isinstance(out, tuple) else out)
        out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]

        # [batch_size, out_channels]
        return self.fc(out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, patch_size={self.patch_size})')
|
@@ -38,7 +38,7 @@ class SetTransformerAggregation(Aggregation):
|
|
38
38
|
(default: :obj:`1`)
|
39
39
|
concat (bool, optional): If set to :obj:`False`, the seed embeddings
|
40
40
|
are averaged instead of concatenated. (default: :obj:`True`)
|
41
|
-
|
41
|
+
layer_norm (str, optional): If set to :obj:`True`, will apply layer
|
42
42
|
normalization. (default: :obj:`False`)
|
43
43
|
dropout (float, optional): Dropout probability of attention weights.
|
44
44
|
(default: :obj:`0`)
|
@@ -0,0 +1,33 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from torch import Tensor
|
4
|
+
|
5
|
+
from torch_geometric.nn.aggr import Aggregation
|
6
|
+
from torch_geometric.utils import degree
|
7
|
+
from torch_geometric.utils._scatter import broadcast
|
8
|
+
|
9
|
+
|
10
|
+
class VariancePreservingAggregation(Aggregation):
    r"""Performs the Variance Preserving Aggregation (VPA) from the `"GNN-VPA:
    A Variance-Preserving Aggregation Strategy for Graph Neural Networks"
    <https://arxiv.org/abs/2403.04747>`_ paper.

    .. math::
        \mathrm{vpa}(\mathcal{X}) = \frac{1}{\sqrt{|\mathcal{X}|}}
        \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i

    The per-group sum is divided by the square root of the group size. The
    divisor is clamped to at least :obj:`1`, so empty groups are never
    divided by zero.
    """
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:

        summed = self.reduce(x, index, ptr, dim_size, dim, reduce='sum')

        # Number of elements per group, in the dtype of the summed output:
        if ptr is None:
            group_size = degree(index, dim_size, dtype=summed.dtype)
        else:
            group_size = ptr.diff().to(summed.dtype)

        # sqrt(|X|), clamped to 1 to avoid a division by zero, reshaped so it
        # broadcasts against `summed` along `dim`:
        norm = group_size.sqrt().clamp(min=1.0)
        norm = broadcast(norm, ref=summed, dim=dim)

        return summed / norm
|
@@ -0,0 +1,71 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
|
6
|
+
class QFormer(torch.nn.Module):
    r"""The Querying Transformer (Q-Former) from
    `"BLIP-2: Bootstrapping Language-Image Pre-training
    with Frozen Image Encoders and Large Language Models"
    <https://arxiv.org/pdf/2301.12597>`_ paper.

    The input is layer-normalized, encoded by a stack of
    :class:`torch.nn.TransformerEncoderLayer` blocks (with
    :obj:`batch_first=True`), and linearly projected to :obj:`output_dim`.

    Args:
        input_dim (int): The number of features in the input.
        hidden_dim (int): The dimension of the feed-forward network in each
            encoder layer.
        output_dim (int): The final output dimension.
        num_heads (int): The number of attention heads.
        num_layers (int): The number of sub-encoder-layers in the encoder.
        dropout (float, optional): The dropout value in each encoder layer.
            (default: :obj:`0.0`)
        activation (Callable, optional): The activation function of the
            feed-forward network in each encoder layer.
            (default: :obj:`torch.nn.ReLU()`)

    .. note::
        This is a simplified version of the original Q-Former implementation.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_heads: int,
        num_layers: int,
        dropout: float = 0.0,
        activation: Callable = torch.nn.ReLU(),
    ) -> None:

        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads

        self.layer_norm = torch.nn.LayerNorm(input_dim)
        # NOTE(review): `torch.nn.TransformerEncoder` builds its own copies of
        # `encoder_layer`, so the parameters registered under
        # `self.encoder_layer` appear not to be used in `forward`. The
        # attribute is kept as-is to preserve existing state-dict keys.
        self.encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )
        self.encoder = torch.nn.TransformerEncoder(
            self.encoder_layer,
            num_layers=num_layers,
        )
        self.project = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Input sequence to the encoder layer.
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, sequence length :math:`N`,
                and feature dimension :math:`F` (must equal
                :obj:`input_dim`).

        Returns:
            torch.Tensor: A tensor of shape :math:`(B, N, \text{output\_dim})`.
        """
        x = self.layer_norm(x)
        x = self.encoder(x)
        out = self.project(x)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'num_heads={self.num_heads}, '
                f'num_layers={self.num_layers})')
|
@@ -4,8 +4,8 @@ import torch
|
|
4
4
|
from torch import Tensor
|
5
5
|
|
6
6
|
from torch_geometric import EdgeIndex
|
7
|
+
from torch_geometric.index import ptr2index
|
7
8
|
from torch_geometric.utils import is_torch_sparse_tensor
|
8
|
-
from torch_geometric.utils.sparse import ptr2index
|
9
9
|
from torch_geometric.typing import SparseTensor
|
10
10
|
|
11
11
|
|
@@ -98,13 +98,16 @@ def {{collect_name}}(
|
|
98
98
|
|
99
99
|
{%- if 'edge_weight' in collect_param_dict and
|
100
100
|
collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
|
101
|
-
|
101
|
+
if torch.jit.is_scripting():
|
102
|
+
assert edge_weight is not None
|
102
103
|
{%- elif 'edge_attr' in collect_param_dict and
|
103
104
|
collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}
|
104
|
-
|
105
|
+
if torch.jit.is_scripting():
|
106
|
+
assert edge_attr is not None
|
105
107
|
{%- elif 'edge_type' in collect_param_dict and
|
106
108
|
collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}
|
107
|
-
|
109
|
+
if torch.jit.is_scripting():
|
110
|
+
assert edge_type is not None
|
108
111
|
{%- endif %}
|
109
112
|
|
110
113
|
# Collect user-defined arguments:
|
@@ -7,12 +7,7 @@ from torch_geometric import EdgeIndex
|
|
7
7
|
|
8
8
|
try: # pragma: no cover
|
9
9
|
LEGACY_MODE = False
|
10
|
-
from pylibcugraphops.pytorch import
|
11
|
-
SampledCSC,
|
12
|
-
SampledHeteroCSC,
|
13
|
-
StaticCSC,
|
14
|
-
StaticHeteroCSC,
|
15
|
-
)
|
10
|
+
from pylibcugraphops.pytorch import CSC, HeteroCSC
|
16
11
|
HAS_PYLIBCUGRAPHOPS = True
|
17
12
|
except ImportError:
|
18
13
|
HAS_PYLIBCUGRAPHOPS = False
|
@@ -41,7 +36,6 @@ class CuGraphModule(torch.nn.Module): # pragma: no cover
|
|
41
36
|
|
42
37
|
def reset_parameters(self):
|
43
38
|
r"""Resets all learnable parameters of the module."""
|
44
|
-
pass
|
45
39
|
|
46
40
|
def get_cugraph(
|
47
41
|
self,
|
@@ -79,12 +73,13 @@ class CuGraphModule(torch.nn.Module): # pragma: no cover
|
|
79
73
|
return make_mfg_csr(dst_nodes, colptr, row, max_num_neighbors,
|
80
74
|
num_src_nodes)
|
81
75
|
|
82
|
-
return
|
76
|
+
return CSC(colptr, row, num_src_nodes,
|
77
|
+
dst_max_in_degree=max_num_neighbors)
|
83
78
|
|
84
79
|
if LEGACY_MODE:
|
85
80
|
return make_fg_csr(colptr, row)
|
86
81
|
|
87
|
-
return
|
82
|
+
return CSC(colptr, row, num_src_nodes=num_src_nodes)
|
88
83
|
|
89
84
|
def get_typed_cugraph(
|
90
85
|
self,
|
@@ -135,15 +130,16 @@ class CuGraphModule(torch.nn.Module): # pragma: no cover
|
|
135
130
|
out_node_types=None, in_node_types=None,
|
136
131
|
edge_types=edge_type)
|
137
132
|
|
138
|
-
return
|
139
|
-
|
133
|
+
return HeteroCSC(colptr, row, edge_type, num_src_nodes,
|
134
|
+
num_edge_types,
|
135
|
+
dst_max_in_degree=max_num_neighbors)
|
140
136
|
|
141
137
|
if LEGACY_MODE:
|
142
138
|
return make_fg_csr_hg(colptr, row, n_node_types=0,
|
143
139
|
n_edge_types=num_edge_types, node_types=None,
|
144
140
|
edge_types=edge_type)
|
145
141
|
|
146
|
-
return
|
142
|
+
return HeteroCSC(colptr, row, edge_type, num_src_nodes, num_edge_types)
|
147
143
|
|
148
144
|
def forward(
|
149
145
|
self,
|
@@ -3,13 +3,14 @@ from typing import Callable, Optional, Union
|
|
3
3
|
import torch
|
4
4
|
from torch import Tensor
|
5
5
|
|
6
|
+
import torch_geometric.typing
|
6
7
|
from torch_geometric.nn.conv import MessagePassing
|
7
8
|
from torch_geometric.nn.inits import reset
|
8
9
|
from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor
|
9
10
|
|
10
|
-
|
11
|
+
if torch_geometric.typing.WITH_TORCH_CLUSTER:
|
11
12
|
from torch_cluster import knn
|
12
|
-
|
13
|
+
else:
|
13
14
|
knn = None
|
14
15
|
|
15
16
|
|
@@ -3,9 +3,9 @@ from typing import Optional, Tuple
|
|
3
3
|
import torch
|
4
4
|
from torch import Tensor
|
5
5
|
|
6
|
+
from torch_geometric.index import index2ptr
|
6
7
|
from torch_geometric.nn.conv import GATConv
|
7
8
|
from torch_geometric.utils import sort_edge_index
|
8
|
-
from torch_geometric.utils.sparse import index2ptr
|
9
9
|
|
10
10
|
|
11
11
|
class FusedGATConv(GATConv): # pragma: no cover
|
@@ -37,9 +37,8 @@ class GATConv(MessagePassing):
|
|
37
37
|
<https://arxiv.org/abs/1710.10903>`_ paper.
|
38
38
|
|
39
39
|
.. math::
|
40
|
-
\mathbf{x}^{\prime}_i = \
|
41
|
-
\
|
42
|
-
\alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},
|
40
|
+
\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}}
|
41
|
+
\alpha_{i,j}\mathbf{\Theta}_t\mathbf{x}_{j},
|
43
42
|
|
44
43
|
where the attention coefficients :math:`\alpha_{i,j}` are computed as
|
45
44
|
|
@@ -108,6 +107,8 @@ class GATConv(MessagePassing):
|
|
108
107
|
:obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`)
|
109
108
|
bias (bool, optional): If set to :obj:`False`, the layer will not learn
|
110
109
|
an additive bias. (default: :obj:`True`)
|
110
|
+
residual (bool, optional): If set to :obj:`True`, the layer will add
|
111
|
+
a learnable skip-connection. (default: :obj:`False`)
|
111
112
|
**kwargs (optional): Additional arguments of
|
112
113
|
:class:`torch_geometric.nn.conv.MessagePassing`.
|
113
114
|
|
@@ -138,6 +139,7 @@ class GATConv(MessagePassing):
|
|
138
139
|
edge_dim: Optional[int] = None,
|
139
140
|
fill_value: Union[float, Tensor, str] = 'mean',
|
140
141
|
bias: bool = True,
|
142
|
+
residual: bool = False,
|
141
143
|
**kwargs,
|
142
144
|
):
|
143
145
|
kwargs.setdefault('aggr', 'add')
|
@@ -152,6 +154,7 @@ class GATConv(MessagePassing):
|
|
152
154
|
self.add_self_loops = add_self_loops
|
153
155
|
self.edge_dim = edge_dim
|
154
156
|
self.fill_value = fill_value
|
157
|
+
self.residual = residual
|
155
158
|
|
156
159
|
# In case we are operating in bipartite graphs, we apply separate
|
157
160
|
# transformations 'lin_src' and 'lin_dst' to source and target nodes:
|
@@ -177,10 +180,22 @@ class GATConv(MessagePassing):
|
|
177
180
|
self.lin_edge = None
|
178
181
|
self.register_parameter('att_edge', None)
|
179
182
|
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
183
|
+
# The number of output channels:
|
184
|
+
total_out_channels = out_channels * (heads if concat else 1)
|
185
|
+
|
186
|
+
if residual:
|
187
|
+
self.res = Linear(
|
188
|
+
in_channels
|
189
|
+
if isinstance(in_channels, int) else in_channels[1],
|
190
|
+
total_out_channels,
|
191
|
+
bias=False,
|
192
|
+
weight_initializer='glorot',
|
193
|
+
)
|
194
|
+
else:
|
195
|
+
self.register_parameter('res', None)
|
196
|
+
|
197
|
+
if bias:
|
198
|
+
self.bias = Parameter(torch.empty(total_out_channels))
|
184
199
|
else:
|
185
200
|
self.register_parameter('bias', None)
|
186
201
|
|
@@ -196,6 +211,8 @@ class GATConv(MessagePassing):
|
|
196
211
|
self.lin_dst.reset_parameters()
|
197
212
|
if self.lin_edge is not None:
|
198
213
|
self.lin_edge.reset_parameters()
|
214
|
+
if self.res is not None:
|
215
|
+
self.res.reset_parameters()
|
199
216
|
glorot(self.att_src)
|
200
217
|
glorot(self.att_dst)
|
201
218
|
glorot(self.att_edge)
|
@@ -271,11 +288,16 @@ class GATConv(MessagePassing):
|
|
271
288
|
|
272
289
|
H, C = self.heads, self.out_channels
|
273
290
|
|
291
|
+
res: Optional[Tensor] = None
|
292
|
+
|
274
293
|
# We first transform the input node features. If a tuple is passed, we
|
275
294
|
# transform source and target node features via separate weights:
|
276
295
|
if isinstance(x, Tensor):
|
277
296
|
assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
|
278
297
|
|
298
|
+
if self.res is not None:
|
299
|
+
res = self.res(x)
|
300
|
+
|
279
301
|
if self.lin is not None:
|
280
302
|
x_src = x_dst = self.lin(x).view(-1, H, C)
|
281
303
|
else:
|
@@ -289,6 +311,9 @@ class GATConv(MessagePassing):
|
|
289
311
|
x_src, x_dst = x
|
290
312
|
assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
|
291
313
|
|
314
|
+
if x_dst is not None and self.res is not None:
|
315
|
+
res = self.res(x_dst)
|
316
|
+
|
292
317
|
if self.lin is not None:
|
293
318
|
# If the module is initialized as non-bipartite, we expect that
|
294
319
|
# source and destination node features have the same shape and
|
@@ -345,6 +370,9 @@ class GATConv(MessagePassing):
|
|
345
370
|
else:
|
346
371
|
out = out.mean(dim=1)
|
347
372
|
|
373
|
+
if res is not None:
|
374
|
+
out = out + res
|
375
|
+
|
348
376
|
if self.bias is not None:
|
349
377
|
out = out + self.bias
|
350
378
|
|
@@ -41,8 +41,7 @@ class GATv2Conv(MessagePassing):
|
|
41
41
|
In contrast, in :class:`GATv2`, every node can attend to any other node.
|
42
42
|
|
43
43
|
.. math::
|
44
|
-
\mathbf{x}^{\prime}_i = \
|
45
|
-
\sum_{j \in \mathcal{N}(i)}
|
44
|
+
\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}}
|
46
45
|
\alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},
|
47
46
|
|
48
47
|
where the attention coefficients :math:`\alpha_{i,j}` are computed as
|
@@ -111,6 +110,8 @@ class GATv2Conv(MessagePassing):
|
|
111
110
|
will be applied to the source and the target node of every edge,
|
112
111
|
*i.e.* :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`.
|
113
112
|
(default: :obj:`False`)
|
113
|
+
residual (bool, optional): If set to :obj:`True`, the layer will add
|
114
|
+
a learnable skip-connection. (default: :obj:`False`)
|
114
115
|
**kwargs (optional): Additional arguments of
|
115
116
|
:class:`torch_geometric.nn.conv.MessagePassing`.
|
116
117
|
|
@@ -142,6 +143,7 @@ class GATv2Conv(MessagePassing):
|
|
142
143
|
fill_value: Union[float, Tensor, str] = 'mean',
|
143
144
|
bias: bool = True,
|
144
145
|
share_weights: bool = False,
|
146
|
+
residual: bool = False,
|
145
147
|
**kwargs,
|
146
148
|
):
|
147
149
|
super().__init__(node_dim=0, **kwargs)
|
@@ -155,6 +157,7 @@ class GATv2Conv(MessagePassing):
|
|
155
157
|
self.add_self_loops = add_self_loops
|
156
158
|
self.edge_dim = edge_dim
|
157
159
|
self.fill_value = fill_value
|
160
|
+
self.residual = residual
|
158
161
|
self.share_weights = share_weights
|
159
162
|
|
160
163
|
if isinstance(in_channels, int):
|
@@ -182,10 +185,22 @@ class GATv2Conv(MessagePassing):
|
|
182
185
|
else:
|
183
186
|
self.lin_edge = None
|
184
187
|
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
188
|
+
# The number of output channels:
|
189
|
+
total_out_channels = out_channels * (heads if concat else 1)
|
190
|
+
|
191
|
+
if residual:
|
192
|
+
self.res = Linear(
|
193
|
+
in_channels
|
194
|
+
if isinstance(in_channels, int) else in_channels[1],
|
195
|
+
total_out_channels,
|
196
|
+
bias=False,
|
197
|
+
weight_initializer='glorot',
|
198
|
+
)
|
199
|
+
else:
|
200
|
+
self.register_parameter('res', None)
|
201
|
+
|
202
|
+
if bias:
|
203
|
+
self.bias = Parameter(torch.empty(total_out_channels))
|
189
204
|
else:
|
190
205
|
self.register_parameter('bias', None)
|
191
206
|
|
@@ -197,6 +212,8 @@ class GATv2Conv(MessagePassing):
|
|
197
212
|
self.lin_r.reset_parameters()
|
198
213
|
if self.lin_edge is not None:
|
199
214
|
self.lin_edge.reset_parameters()
|
215
|
+
if self.res is not None:
|
216
|
+
self.res.reset_parameters()
|
200
217
|
glorot(self.att)
|
201
218
|
zeros(self.bias)
|
202
219
|
|
@@ -256,10 +273,16 @@ class GATv2Conv(MessagePassing):
|
|
256
273
|
"""
|
257
274
|
H, C = self.heads, self.out_channels
|
258
275
|
|
276
|
+
res: Optional[Tensor] = None
|
277
|
+
|
259
278
|
x_l: OptTensor = None
|
260
279
|
x_r: OptTensor = None
|
261
280
|
if isinstance(x, Tensor):
|
262
281
|
assert x.dim() == 2
|
282
|
+
|
283
|
+
if self.res is not None:
|
284
|
+
res = self.res(x)
|
285
|
+
|
263
286
|
x_l = self.lin_l(x).view(-1, H, C)
|
264
287
|
if self.share_weights:
|
265
288
|
x_r = x_l
|
@@ -268,6 +291,10 @@ class GATv2Conv(MessagePassing):
|
|
268
291
|
else:
|
269
292
|
x_l, x_r = x[0], x[1]
|
270
293
|
assert x[0].dim() == 2
|
294
|
+
|
295
|
+
if x_r is not None and self.res is not None:
|
296
|
+
res = self.res(x_r)
|
297
|
+
|
271
298
|
x_l = self.lin_l(x_l).view(-1, H, C)
|
272
299
|
if x_r is not None:
|
273
300
|
x_r = self.lin_r(x_r).view(-1, H, C)
|
@@ -306,6 +333,9 @@ class GATv2Conv(MessagePassing):
|
|
306
333
|
else:
|
307
334
|
out = out.mean(dim=1)
|
308
335
|
|
336
|
+
if res is not None:
|
337
|
+
out = out + res
|
338
|
+
|
309
339
|
if self.bias is not None:
|
310
340
|
out = out + self.bias
|
311
341
|
|
@@ -70,7 +70,7 @@ class GeneralConv(MessagePassing):
|
|
70
70
|
self,
|
71
71
|
in_channels: Union[int, Tuple[int, int]],
|
72
72
|
out_channels: Optional[int],
|
73
|
-
in_edge_channels: int = None,
|
73
|
+
in_edge_channels: Optional[int] = None,
|
74
74
|
aggr: str = "add",
|
75
75
|
skip_linear: str = False,
|
76
76
|
directed_msg: bool = True,
|
@@ -1,7 +1,9 @@
|
|
1
|
-
from typing import Tuple, Union
|
1
|
+
from typing import Final, Tuple, Union
|
2
2
|
|
3
|
+
import torch
|
3
4
|
from torch import Tensor
|
4
5
|
|
6
|
+
from torch_geometric import EdgeIndex
|
5
7
|
from torch_geometric.nn.conv import MessagePassing
|
6
8
|
from torch_geometric.nn.dense.linear import Linear
|
7
9
|
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
|
@@ -44,6 +46,8 @@ class GraphConv(MessagePassing):
|
|
44
46
|
- **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
|
45
47
|
:math:`(|\mathcal{V}_t|, F_{out})` if bipartite
|
46
48
|
"""
|
49
|
+
SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = True
|
50
|
+
|
47
51
|
def __init__(
|
48
52
|
self,
|
49
53
|
in_channels: Union[int, Tuple[int, int]],
|
@@ -90,5 +94,19 @@ class GraphConv(MessagePassing):
|
|
90
94
|
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
|
91
95
|
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
|
92
96
|
|
93
|
-
def message_and_aggregate(
|
94
|
-
|
97
|
+
def message_and_aggregate(
|
98
|
+
self,
|
99
|
+
edge_index: Adj,
|
100
|
+
x: OptPairTensor,
|
101
|
+
edge_weight: OptTensor,
|
102
|
+
) -> Tensor:
|
103
|
+
|
104
|
+
if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
|
105
|
+
return edge_index.matmul(
|
106
|
+
other=x[0],
|
107
|
+
input_value=edge_weight,
|
108
|
+
reduce=self.aggr,
|
109
|
+
transpose=True,
|
110
|
+
)
|
111
|
+
|
112
|
+
return spmm(edge_index, x[0], reduce=self.aggr)
|
@@ -4,14 +4,15 @@ from typing import Optional, Union
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
6
6
|
|
7
|
+
import torch_geometric.typing
|
7
8
|
from torch_geometric.nn.conv import MessagePassing
|
8
9
|
from torch_geometric.nn.dense.linear import Linear
|
9
10
|
from torch_geometric.typing import OptPairTensor # noqa
|
10
11
|
from torch_geometric.typing import OptTensor, PairOptTensor, PairTensor
|
11
12
|
|
12
|
-
|
13
|
+
if torch_geometric.typing.WITH_TORCH_CLUSTER:
|
13
14
|
from torch_cluster import knn
|
14
|
-
|
15
|
+
else:
|
15
16
|
knn = None
|
16
17
|
|
17
18
|
|