pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyg-nightly might be problematic. Click here for more details.

Files changed (268) hide show
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
@@ -81,7 +81,7 @@ class EGConv(MessagePassing):
81
81
  self,
82
82
  in_channels: int,
83
83
  out_channels: int,
84
- aggregators: List[str] = ['symnorm'],
84
+ aggregators: Optional[List[str]] = None,
85
85
  num_heads: int = 8,
86
86
  num_bases: int = 4,
87
87
  cached: bool = False,
@@ -96,23 +96,23 @@ class EGConv(MessagePassing):
96
96
  f"divisible by the number of heads "
97
97
  f"(got {num_heads})")
98
98
 
99
- for a in aggregators:
100
- if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
101
- raise ValueError(f"Unsupported aggregator: '{a}'")
102
-
103
99
  self.in_channels = in_channels
104
100
  self.out_channels = out_channels
105
101
  self.num_heads = num_heads
106
102
  self.num_bases = num_bases
107
103
  self.cached = cached
108
104
  self.add_self_loops = add_self_loops
109
- self.aggregators = aggregators
105
+ self.aggregators = aggregators or ['symnorm']
106
+
107
+ for a in self.aggregators:
108
+ if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
109
+ raise ValueError(f"Unsupported aggregator: '{a}'")
110
110
 
111
111
  self.bases_lin = Linear(in_channels,
112
112
  (out_channels // num_heads) * num_bases,
113
113
  bias=False, weight_initializer='glorot')
114
114
  self.comb_lin = Linear(in_channels,
115
- num_heads * num_bases * len(aggregators))
115
+ num_heads * num_bases * len(self.aggregators))
116
116
 
117
117
  if bias:
118
118
  self.bias = Parameter(torch.empty(out_channels))
@@ -107,6 +107,8 @@ class GATConv(MessagePassing):
107
107
  :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`)
108
108
  bias (bool, optional): If set to :obj:`False`, the layer will not learn
109
109
  an additive bias. (default: :obj:`True`)
110
+ residual (bool, optional): If set to :obj:`True`, the layer will add
111
+ a learnable skip-connection. (default: :obj:`False`)
110
112
  **kwargs (optional): Additional arguments of
111
113
  :class:`torch_geometric.nn.conv.MessagePassing`.
112
114
 
@@ -137,6 +139,7 @@ class GATConv(MessagePassing):
137
139
  edge_dim: Optional[int] = None,
138
140
  fill_value: Union[float, Tensor, str] = 'mean',
139
141
  bias: bool = True,
142
+ residual: bool = False,
140
143
  **kwargs,
141
144
  ):
142
145
  kwargs.setdefault('aggr', 'add')
@@ -151,6 +154,7 @@ class GATConv(MessagePassing):
151
154
  self.add_self_loops = add_self_loops
152
155
  self.edge_dim = edge_dim
153
156
  self.fill_value = fill_value
157
+ self.residual = residual
154
158
 
155
159
  # In case we are operating in bipartite graphs, we apply separate
156
160
  # transformations 'lin_src' and 'lin_dst' to source and target nodes:
@@ -176,10 +180,22 @@ class GATConv(MessagePassing):
176
180
  self.lin_edge = None
177
181
  self.register_parameter('att_edge', None)
178
182
 
179
- if bias and concat:
180
- self.bias = Parameter(torch.empty(heads * out_channels))
181
- elif bias and not concat:
182
- self.bias = Parameter(torch.empty(out_channels))
183
+ # The number of output channels:
184
+ total_out_channels = out_channels * (heads if concat else 1)
185
+
186
+ if residual:
187
+ self.res = Linear(
188
+ in_channels
189
+ if isinstance(in_channels, int) else in_channels[1],
190
+ total_out_channels,
191
+ bias=False,
192
+ weight_initializer='glorot',
193
+ )
194
+ else:
195
+ self.register_parameter('res', None)
196
+
197
+ if bias:
198
+ self.bias = Parameter(torch.empty(total_out_channels))
183
199
  else:
184
200
  self.register_parameter('bias', None)
185
201
 
@@ -195,6 +211,8 @@ class GATConv(MessagePassing):
195
211
  self.lin_dst.reset_parameters()
196
212
  if self.lin_edge is not None:
197
213
  self.lin_edge.reset_parameters()
214
+ if self.res is not None:
215
+ self.res.reset_parameters()
198
216
  glorot(self.att_src)
199
217
  glorot(self.att_dst)
200
218
  glorot(self.att_edge)
@@ -270,11 +288,16 @@ class GATConv(MessagePassing):
270
288
 
271
289
  H, C = self.heads, self.out_channels
272
290
 
291
+ res: Optional[Tensor] = None
292
+
273
293
  # We first transform the input node features. If a tuple is passed, we
274
294
  # transform source and target node features via separate weights:
275
295
  if isinstance(x, Tensor):
276
296
  assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
277
297
 
298
+ if self.res is not None:
299
+ res = self.res(x)
300
+
278
301
  if self.lin is not None:
279
302
  x_src = x_dst = self.lin(x).view(-1, H, C)
280
303
  else:
@@ -288,6 +311,9 @@ class GATConv(MessagePassing):
288
311
  x_src, x_dst = x
289
312
  assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
290
313
 
314
+ if x_dst is not None and self.res is not None:
315
+ res = self.res(x_dst)
316
+
291
317
  if self.lin is not None:
292
318
  # If the module is initialized as non-bipartite, we expect that
293
319
  # source and destination node features have the same shape and
@@ -344,6 +370,9 @@ class GATConv(MessagePassing):
344
370
  else:
345
371
  out = out.mean(dim=1)
346
372
 
373
+ if res is not None:
374
+ out = out + res
375
+
347
376
  if self.bias is not None:
348
377
  out = out + self.bias
349
378
 
@@ -110,6 +110,8 @@ class GATv2Conv(MessagePassing):
110
110
  will be applied to the source and the target node of every edge,
111
111
  *i.e.* :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`.
112
112
  (default: :obj:`False`)
113
+ residual (bool, optional): If set to :obj:`True`, the layer will add
114
+ a learnable skip-connection. (default: :obj:`False`)
113
115
  **kwargs (optional): Additional arguments of
114
116
  :class:`torch_geometric.nn.conv.MessagePassing`.
115
117
 
@@ -141,6 +143,7 @@ class GATv2Conv(MessagePassing):
141
143
  fill_value: Union[float, Tensor, str] = 'mean',
142
144
  bias: bool = True,
143
145
  share_weights: bool = False,
146
+ residual: bool = False,
144
147
  **kwargs,
145
148
  ):
146
149
  super().__init__(node_dim=0, **kwargs)
@@ -154,6 +157,7 @@ class GATv2Conv(MessagePassing):
154
157
  self.add_self_loops = add_self_loops
155
158
  self.edge_dim = edge_dim
156
159
  self.fill_value = fill_value
160
+ self.residual = residual
157
161
  self.share_weights = share_weights
158
162
 
159
163
  if isinstance(in_channels, int):
@@ -181,10 +185,22 @@ class GATv2Conv(MessagePassing):
181
185
  else:
182
186
  self.lin_edge = None
183
187
 
184
- if bias and concat:
185
- self.bias = Parameter(torch.empty(heads * out_channels))
186
- elif bias and not concat:
187
- self.bias = Parameter(torch.empty(out_channels))
188
+ # The number of output channels:
189
+ total_out_channels = out_channels * (heads if concat else 1)
190
+
191
+ if residual:
192
+ self.res = Linear(
193
+ in_channels
194
+ if isinstance(in_channels, int) else in_channels[1],
195
+ total_out_channels,
196
+ bias=False,
197
+ weight_initializer='glorot',
198
+ )
199
+ else:
200
+ self.register_parameter('res', None)
201
+
202
+ if bias:
203
+ self.bias = Parameter(torch.empty(total_out_channels))
188
204
  else:
189
205
  self.register_parameter('bias', None)
190
206
 
@@ -196,6 +212,8 @@ class GATv2Conv(MessagePassing):
196
212
  self.lin_r.reset_parameters()
197
213
  if self.lin_edge is not None:
198
214
  self.lin_edge.reset_parameters()
215
+ if self.res is not None:
216
+ self.res.reset_parameters()
199
217
  glorot(self.att)
200
218
  zeros(self.bias)
201
219
 
@@ -255,10 +273,16 @@ class GATv2Conv(MessagePassing):
255
273
  """
256
274
  H, C = self.heads, self.out_channels
257
275
 
276
+ res: Optional[Tensor] = None
277
+
258
278
  x_l: OptTensor = None
259
279
  x_r: OptTensor = None
260
280
  if isinstance(x, Tensor):
261
281
  assert x.dim() == 2
282
+
283
+ if self.res is not None:
284
+ res = self.res(x)
285
+
262
286
  x_l = self.lin_l(x).view(-1, H, C)
263
287
  if self.share_weights:
264
288
  x_r = x_l
@@ -267,6 +291,10 @@ class GATv2Conv(MessagePassing):
267
291
  else:
268
292
  x_l, x_r = x[0], x[1]
269
293
  assert x[0].dim() == 2
294
+
295
+ if x_r is not None and self.res is not None:
296
+ res = self.res(x_r)
297
+
270
298
  x_l = self.lin_l(x_l).view(-1, H, C)
271
299
  if x_r is not None:
272
300
  x_r = self.lin_r(x_r).view(-1, H, C)
@@ -305,6 +333,9 @@ class GATv2Conv(MessagePassing):
305
333
  else:
306
334
  out = out.mean(dim=1)
307
335
 
336
+ if res is not None:
337
+ out = out + res
338
+
308
339
  if self.bias is not None:
309
340
  out = out + self.bias
310
341
 
@@ -178,7 +178,7 @@ class GENConv(MessagePassing):
178
178
  self.lin_dst = Linear(in_channels[1], out_channels, bias=bias)
179
179
 
180
180
  channels = [out_channels]
181
- for i in range(num_layers - 1):
181
+ for _ in range(num_layers - 1):
182
182
  channels.append(out_channels * expansion)
183
183
  channels.append(out_channels)
184
184
  self.mlp = MLP(channels, norm=norm, bias=bias)
@@ -70,7 +70,7 @@ class GeneralConv(MessagePassing):
70
70
  self,
71
71
  in_channels: Union[int, Tuple[int, int]],
72
72
  out_channels: Optional[int],
73
- in_edge_channels: int = None,
73
+ in_edge_channels: Optional[int] = None,
74
74
  aggr: str = "add",
75
75
  skip_linear: str = False,
76
76
  directed_msg: bool = True,
@@ -63,7 +63,8 @@ class GravNetConv(MessagePassing):
63
63
  if num_workers is not None:
64
64
  warnings.warn(
65
65
  "'num_workers' attribute in '{self.__class__.__name__}' is "
66
- "deprecated and will be removed in a future release")
66
+ "deprecated and will be removed in a future release",
67
+ stacklevel=2)
67
68
 
68
69
  self.in_channels = in_channels
69
70
  self.out_channels = out_channels
@@ -77,7 +77,8 @@ class HeteroConv(torch.nn.Module):
77
77
  f"There exist node types ({src_node_types - dst_node_types}) "
78
78
  f"whose representations do not get updated during message "
79
79
  f"passing as they do not occur as destination type in any "
80
- f"edge type. This may lead to unexpected behavior.")
80
+ f"edge type. This may lead to unexpected behavior.",
81
+ stacklevel=2)
81
82
 
82
83
  self.convs = ModuleDict(convs)
83
84
  self.aggr = aggr
@@ -102,7 +103,7 @@ class HeteroConv(torch.nn.Module):
102
103
  individual edge type, either as a :class:`torch.Tensor` of
103
104
  shape :obj:`[2, num_edges]` or a
104
105
  :class:`torch_sparse.SparseTensor`.
105
- *args_dict (optional): Additional forward arguments of invididual
106
+ *args_dict (optional): Additional forward arguments of individual
106
107
  :class:`torch_geometric.nn.conv.MessagePassing` layers.
107
108
  **kwargs_dict (optional): Additional forward arguments of
108
109
  individual :class:`torch_geometric.nn.conv.MessagePassing`