pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyg-nightly might be problematic. Click here for more details.
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
- {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +34 -1
- torch_geometric/_compile.py +11 -3
- torch_geometric/_onnx.py +228 -0
- torch_geometric/config_mixin.py +8 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/__init__.py +19 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +110 -6
- torch_geometric/data/database.py +19 -5
- torch_geometric/data/dataset.py +14 -9
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +20 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +17 -20
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +2 -1
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +19 -5
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcqm4m.py +2 -1
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +3 -2
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +342 -0
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/datasets/wikidata.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +35 -22
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +89 -5
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +16 -7
- torch_geometric/inspector.py +6 -2
- torch_geometric/io/fs.py +27 -0
- torch_geometric/io/tu.py +2 -3
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/llm/models/g_retriever.py +251 -0
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/llm/models/llm.py +470 -0
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +6 -5
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +4 -3
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +23 -2
- torch_geometric/metrics/link_pred.py +755 -85
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/__init__.py +2 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +149 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/collect.jinja +6 -3
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gat_conv.py +33 -4
- torch_geometric/nn/conv/gatv2_conv.py +35 -4
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +3 -2
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +6 -5
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +5 -24
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +17 -15
- torch_geometric/nn/model_hub.py +5 -16
- torch_geometric/nn/models/__init__.py +11 -0
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +15 -9
- torch_geometric/nn/pool/cluster_pool.py +144 -0
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +18 -9
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +337 -8
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +298 -25
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +4 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +10 -8
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +8 -9
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/pad.py +1 -1
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +82 -17
- torch_geometric/utils/__init__.py +6 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +28 -13
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +126 -164
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +1 -1
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +28 -25
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +14 -10
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +32 -24
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/llm.py +0 -283
- torch_geometric/nn/nlp/sentence_transformer.py +0 -94
|
@@ -81,7 +81,7 @@ class EGConv(MessagePassing):
|
|
|
81
81
|
self,
|
|
82
82
|
in_channels: int,
|
|
83
83
|
out_channels: int,
|
|
84
|
-
aggregators: List[str] = ['symnorm'],
|
|
84
|
+
aggregators: Optional[List[str]] = None,
|
|
85
85
|
num_heads: int = 8,
|
|
86
86
|
num_bases: int = 4,
|
|
87
87
|
cached: bool = False,
|
|
@@ -96,23 +96,23 @@ class EGConv(MessagePassing):
|
|
|
96
96
|
f"divisible by the number of heads "
|
|
97
97
|
f"(got {num_heads})")
|
|
98
98
|
|
|
99
|
-
for a in aggregators:
|
|
100
|
-
if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
|
|
101
|
-
raise ValueError(f"Unsupported aggregator: '{a}'")
|
|
102
|
-
|
|
103
99
|
self.in_channels = in_channels
|
|
104
100
|
self.out_channels = out_channels
|
|
105
101
|
self.num_heads = num_heads
|
|
106
102
|
self.num_bases = num_bases
|
|
107
103
|
self.cached = cached
|
|
108
104
|
self.add_self_loops = add_self_loops
|
|
109
|
-
self.aggregators = aggregators
|
|
105
|
+
self.aggregators = aggregators or ['symnorm']
|
|
106
|
+
|
|
107
|
+
for a in self.aggregators:
|
|
108
|
+
if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
|
|
109
|
+
raise ValueError(f"Unsupported aggregator: '{a}'")
|
|
110
110
|
|
|
111
111
|
self.bases_lin = Linear(in_channels,
|
|
112
112
|
(out_channels // num_heads) * num_bases,
|
|
113
113
|
bias=False, weight_initializer='glorot')
|
|
114
114
|
self.comb_lin = Linear(in_channels,
|
|
115
|
-
num_heads * num_bases * len(aggregators))
|
|
115
|
+
num_heads * num_bases * len(self.aggregators))
|
|
116
116
|
|
|
117
117
|
if bias:
|
|
118
118
|
self.bias = Parameter(torch.empty(out_channels))
|
|
@@ -107,6 +107,8 @@ class GATConv(MessagePassing):
|
|
|
107
107
|
:obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`)
|
|
108
108
|
bias (bool, optional): If set to :obj:`False`, the layer will not learn
|
|
109
109
|
an additive bias. (default: :obj:`True`)
|
|
110
|
+
residual (bool, optional): If set to :obj:`True`, the layer will add
|
|
111
|
+
a learnable skip-connection. (default: :obj:`False`)
|
|
110
112
|
**kwargs (optional): Additional arguments of
|
|
111
113
|
:class:`torch_geometric.nn.conv.MessagePassing`.
|
|
112
114
|
|
|
@@ -137,6 +139,7 @@ class GATConv(MessagePassing):
|
|
|
137
139
|
edge_dim: Optional[int] = None,
|
|
138
140
|
fill_value: Union[float, Tensor, str] = 'mean',
|
|
139
141
|
bias: bool = True,
|
|
142
|
+
residual: bool = False,
|
|
140
143
|
**kwargs,
|
|
141
144
|
):
|
|
142
145
|
kwargs.setdefault('aggr', 'add')
|
|
@@ -151,6 +154,7 @@ class GATConv(MessagePassing):
|
|
|
151
154
|
self.add_self_loops = add_self_loops
|
|
152
155
|
self.edge_dim = edge_dim
|
|
153
156
|
self.fill_value = fill_value
|
|
157
|
+
self.residual = residual
|
|
154
158
|
|
|
155
159
|
# In case we are operating in bipartite graphs, we apply separate
|
|
156
160
|
# transformations 'lin_src' and 'lin_dst' to source and target nodes:
|
|
@@ -176,10 +180,22 @@ class GATConv(MessagePassing):
|
|
|
176
180
|
self.lin_edge = None
|
|
177
181
|
self.register_parameter('att_edge', None)
|
|
178
182
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
+
# The number of output channels:
|
|
184
|
+
total_out_channels = out_channels * (heads if concat else 1)
|
|
185
|
+
|
|
186
|
+
if residual:
|
|
187
|
+
self.res = Linear(
|
|
188
|
+
in_channels
|
|
189
|
+
if isinstance(in_channels, int) else in_channels[1],
|
|
190
|
+
total_out_channels,
|
|
191
|
+
bias=False,
|
|
192
|
+
weight_initializer='glorot',
|
|
193
|
+
)
|
|
194
|
+
else:
|
|
195
|
+
self.register_parameter('res', None)
|
|
196
|
+
|
|
197
|
+
if bias:
|
|
198
|
+
self.bias = Parameter(torch.empty(total_out_channels))
|
|
183
199
|
else:
|
|
184
200
|
self.register_parameter('bias', None)
|
|
185
201
|
|
|
@@ -195,6 +211,8 @@ class GATConv(MessagePassing):
|
|
|
195
211
|
self.lin_dst.reset_parameters()
|
|
196
212
|
if self.lin_edge is not None:
|
|
197
213
|
self.lin_edge.reset_parameters()
|
|
214
|
+
if self.res is not None:
|
|
215
|
+
self.res.reset_parameters()
|
|
198
216
|
glorot(self.att_src)
|
|
199
217
|
glorot(self.att_dst)
|
|
200
218
|
glorot(self.att_edge)
|
|
@@ -270,11 +288,16 @@ class GATConv(MessagePassing):
|
|
|
270
288
|
|
|
271
289
|
H, C = self.heads, self.out_channels
|
|
272
290
|
|
|
291
|
+
res: Optional[Tensor] = None
|
|
292
|
+
|
|
273
293
|
# We first transform the input node features. If a tuple is passed, we
|
|
274
294
|
# transform source and target node features via separate weights:
|
|
275
295
|
if isinstance(x, Tensor):
|
|
276
296
|
assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
|
|
277
297
|
|
|
298
|
+
if self.res is not None:
|
|
299
|
+
res = self.res(x)
|
|
300
|
+
|
|
278
301
|
if self.lin is not None:
|
|
279
302
|
x_src = x_dst = self.lin(x).view(-1, H, C)
|
|
280
303
|
else:
|
|
@@ -288,6 +311,9 @@ class GATConv(MessagePassing):
|
|
|
288
311
|
x_src, x_dst = x
|
|
289
312
|
assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
|
|
290
313
|
|
|
314
|
+
if x_dst is not None and self.res is not None:
|
|
315
|
+
res = self.res(x_dst)
|
|
316
|
+
|
|
291
317
|
if self.lin is not None:
|
|
292
318
|
# If the module is initialized as non-bipartite, we expect that
|
|
293
319
|
# source and destination node features have the same shape and
|
|
@@ -344,6 +370,9 @@ class GATConv(MessagePassing):
|
|
|
344
370
|
else:
|
|
345
371
|
out = out.mean(dim=1)
|
|
346
372
|
|
|
373
|
+
if res is not None:
|
|
374
|
+
out = out + res
|
|
375
|
+
|
|
347
376
|
if self.bias is not None:
|
|
348
377
|
out = out + self.bias
|
|
349
378
|
|
|
@@ -110,6 +110,8 @@ class GATv2Conv(MessagePassing):
|
|
|
110
110
|
will be applied to the source and the target node of every edge,
|
|
111
111
|
*i.e.* :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`.
|
|
112
112
|
(default: :obj:`False`)
|
|
113
|
+
residual (bool, optional): If set to :obj:`True`, the layer will add
|
|
114
|
+
a learnable skip-connection. (default: :obj:`False`)
|
|
113
115
|
**kwargs (optional): Additional arguments of
|
|
114
116
|
:class:`torch_geometric.nn.conv.MessagePassing`.
|
|
115
117
|
|
|
@@ -141,6 +143,7 @@ class GATv2Conv(MessagePassing):
|
|
|
141
143
|
fill_value: Union[float, Tensor, str] = 'mean',
|
|
142
144
|
bias: bool = True,
|
|
143
145
|
share_weights: bool = False,
|
|
146
|
+
residual: bool = False,
|
|
144
147
|
**kwargs,
|
|
145
148
|
):
|
|
146
149
|
super().__init__(node_dim=0, **kwargs)
|
|
@@ -154,6 +157,7 @@ class GATv2Conv(MessagePassing):
|
|
|
154
157
|
self.add_self_loops = add_self_loops
|
|
155
158
|
self.edge_dim = edge_dim
|
|
156
159
|
self.fill_value = fill_value
|
|
160
|
+
self.residual = residual
|
|
157
161
|
self.share_weights = share_weights
|
|
158
162
|
|
|
159
163
|
if isinstance(in_channels, int):
|
|
@@ -181,10 +185,22 @@ class GATv2Conv(MessagePassing):
|
|
|
181
185
|
else:
|
|
182
186
|
self.lin_edge = None
|
|
183
187
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
+
# The number of output channels:
|
|
189
|
+
total_out_channels = out_channels * (heads if concat else 1)
|
|
190
|
+
|
|
191
|
+
if residual:
|
|
192
|
+
self.res = Linear(
|
|
193
|
+
in_channels
|
|
194
|
+
if isinstance(in_channels, int) else in_channels[1],
|
|
195
|
+
total_out_channels,
|
|
196
|
+
bias=False,
|
|
197
|
+
weight_initializer='glorot',
|
|
198
|
+
)
|
|
199
|
+
else:
|
|
200
|
+
self.register_parameter('res', None)
|
|
201
|
+
|
|
202
|
+
if bias:
|
|
203
|
+
self.bias = Parameter(torch.empty(total_out_channels))
|
|
188
204
|
else:
|
|
189
205
|
self.register_parameter('bias', None)
|
|
190
206
|
|
|
@@ -196,6 +212,8 @@ class GATv2Conv(MessagePassing):
|
|
|
196
212
|
self.lin_r.reset_parameters()
|
|
197
213
|
if self.lin_edge is not None:
|
|
198
214
|
self.lin_edge.reset_parameters()
|
|
215
|
+
if self.res is not None:
|
|
216
|
+
self.res.reset_parameters()
|
|
199
217
|
glorot(self.att)
|
|
200
218
|
zeros(self.bias)
|
|
201
219
|
|
|
@@ -255,10 +273,16 @@ class GATv2Conv(MessagePassing):
|
|
|
255
273
|
"""
|
|
256
274
|
H, C = self.heads, self.out_channels
|
|
257
275
|
|
|
276
|
+
res: Optional[Tensor] = None
|
|
277
|
+
|
|
258
278
|
x_l: OptTensor = None
|
|
259
279
|
x_r: OptTensor = None
|
|
260
280
|
if isinstance(x, Tensor):
|
|
261
281
|
assert x.dim() == 2
|
|
282
|
+
|
|
283
|
+
if self.res is not None:
|
|
284
|
+
res = self.res(x)
|
|
285
|
+
|
|
262
286
|
x_l = self.lin_l(x).view(-1, H, C)
|
|
263
287
|
if self.share_weights:
|
|
264
288
|
x_r = x_l
|
|
@@ -267,6 +291,10 @@ class GATv2Conv(MessagePassing):
|
|
|
267
291
|
else:
|
|
268
292
|
x_l, x_r = x[0], x[1]
|
|
269
293
|
assert x[0].dim() == 2
|
|
294
|
+
|
|
295
|
+
if x_r is not None and self.res is not None:
|
|
296
|
+
res = self.res(x_r)
|
|
297
|
+
|
|
270
298
|
x_l = self.lin_l(x_l).view(-1, H, C)
|
|
271
299
|
if x_r is not None:
|
|
272
300
|
x_r = self.lin_r(x_r).view(-1, H, C)
|
|
@@ -305,6 +333,9 @@ class GATv2Conv(MessagePassing):
|
|
|
305
333
|
else:
|
|
306
334
|
out = out.mean(dim=1)
|
|
307
335
|
|
|
336
|
+
if res is not None:
|
|
337
|
+
out = out + res
|
|
338
|
+
|
|
308
339
|
if self.bias is not None:
|
|
309
340
|
out = out + self.bias
|
|
310
341
|
|
|
@@ -178,7 +178,7 @@ class GENConv(MessagePassing):
|
|
|
178
178
|
self.lin_dst = Linear(in_channels[1], out_channels, bias=bias)
|
|
179
179
|
|
|
180
180
|
channels = [out_channels]
|
|
181
|
-
for
|
|
181
|
+
for _ in range(num_layers - 1):
|
|
182
182
|
channels.append(out_channels * expansion)
|
|
183
183
|
channels.append(out_channels)
|
|
184
184
|
self.mlp = MLP(channels, norm=norm, bias=bias)
|
|
@@ -70,7 +70,7 @@ class GeneralConv(MessagePassing):
|
|
|
70
70
|
self,
|
|
71
71
|
in_channels: Union[int, Tuple[int, int]],
|
|
72
72
|
out_channels: Optional[int],
|
|
73
|
-
in_edge_channels: int = None,
|
|
73
|
+
in_edge_channels: Optional[int] = None,
|
|
74
74
|
aggr: str = "add",
|
|
75
75
|
skip_linear: str = False,
|
|
76
76
|
directed_msg: bool = True,
|
|
@@ -63,7 +63,8 @@ class GravNetConv(MessagePassing):
|
|
|
63
63
|
if num_workers is not None:
|
|
64
64
|
warnings.warn(
|
|
65
65
|
"'num_workers' attribute in '{self.__class__.__name__}' is "
|
|
66
|
-
"deprecated and will be removed in a future release")
|
|
66
|
+
"deprecated and will be removed in a future release",
|
|
67
|
+
stacklevel=2)
|
|
67
68
|
|
|
68
69
|
self.in_channels = in_channels
|
|
69
70
|
self.out_channels = out_channels
|
|
@@ -77,7 +77,8 @@ class HeteroConv(torch.nn.Module):
|
|
|
77
77
|
f"There exist node types ({src_node_types - dst_node_types}) "
|
|
78
78
|
f"whose representations do not get updated during message "
|
|
79
79
|
f"passing as they do not occur as destination type in any "
|
|
80
|
-
f"edge type. This may lead to unexpected behavior.")
|
|
80
|
+
f"edge type. This may lead to unexpected behavior.",
|
|
81
|
+
stacklevel=2)
|
|
81
82
|
|
|
82
83
|
self.convs = ModuleDict(convs)
|
|
83
84
|
self.aggr = aggr
|
|
@@ -102,7 +103,7 @@ class HeteroConv(torch.nn.Module):
|
|
|
102
103
|
individual edge type, either as a :class:`torch.Tensor` of
|
|
103
104
|
shape :obj:`[2, num_edges]` or a
|
|
104
105
|
:class:`torch_sparse.SparseTensor`.
|
|
105
|
-
*args_dict (optional): Additional forward arguments of
|
|
106
|
+
*args_dict (optional): Additional forward arguments of individual
|
|
106
107
|
:class:`torch_geometric.nn.conv.MessagePassing` layers.
|
|
107
108
|
**kwargs_dict (optional): Additional forward arguments of
|
|
108
109
|
individual :class:`torch_geometric.nn.conv.MessagePassing`
|