pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/METADATA +77 -53
  2. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/RECORD +227 -190
  3. {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251228.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251228.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +14 -2
  6. torch_geometric/_compile.py +9 -3
  7. torch_geometric/_onnx.py +214 -0
  8. torch_geometric/config_mixin.py +5 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/batch.py +2 -2
  13. torch_geometric/data/collate.py +1 -3
  14. torch_geometric/data/data.py +109 -5
  15. torch_geometric/data/database.py +4 -0
  16. torch_geometric/data/dataset.py +14 -11
  17. torch_geometric/data/extract.py +1 -1
  18. torch_geometric/data/feature_store.py +17 -22
  19. torch_geometric/data/graph_store.py +3 -2
  20. torch_geometric/data/hetero_data.py +139 -7
  21. torch_geometric/data/hypergraph_data.py +2 -2
  22. torch_geometric/data/in_memory_dataset.py +2 -2
  23. torch_geometric/data/lightning/datamodule.py +42 -28
  24. torch_geometric/data/storage.py +9 -1
  25. torch_geometric/datasets/__init__.py +18 -1
  26. torch_geometric/datasets/actor.py +7 -9
  27. torch_geometric/datasets/airfrans.py +15 -17
  28. torch_geometric/datasets/airports.py +8 -10
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +8 -9
  31. torch_geometric/datasets/amazon_products.py +7 -9
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/city.py +157 -0
  38. torch_geometric/datasets/dbp15k.py +1 -1
  39. torch_geometric/datasets/git_mol_dataset.py +263 -0
  40. torch_geometric/datasets/hgb_dataset.py +2 -2
  41. torch_geometric/datasets/hm.py +1 -1
  42. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  43. torch_geometric/datasets/md17.py +3 -3
  44. torch_geometric/datasets/medshapenet.py +145 -0
  45. torch_geometric/datasets/modelnet.py +1 -1
  46. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  47. torch_geometric/datasets/molecule_net.py +3 -2
  48. torch_geometric/datasets/ppi.py +2 -1
  49. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  50. torch_geometric/datasets/qm7.py +1 -1
  51. torch_geometric/datasets/qm9.py +1 -1
  52. torch_geometric/datasets/snap_dataset.py +8 -4
  53. torch_geometric/datasets/tag_dataset.py +462 -0
  54. torch_geometric/datasets/teeth3ds.py +269 -0
  55. torch_geometric/datasets/web_qsp_dataset.py +310 -209
  56. torch_geometric/datasets/wikics.py +2 -1
  57. torch_geometric/deprecation.py +1 -1
  58. torch_geometric/distributed/__init__.py +13 -0
  59. torch_geometric/distributed/dist_loader.py +2 -2
  60. torch_geometric/distributed/partition.py +2 -2
  61. torch_geometric/distributed/rpc.py +3 -3
  62. torch_geometric/edge_index.py +18 -14
  63. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  64. torch_geometric/explain/algorithm/base.py +2 -2
  65. torch_geometric/explain/algorithm/captum.py +1 -1
  66. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  67. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  68. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  69. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  70. torch_geometric/explain/explainer.py +2 -2
  71. torch_geometric/explain/explanation.py +87 -3
  72. torch_geometric/explain/metric/faithfulness.py +1 -1
  73. torch_geometric/graphgym/config.py +3 -2
  74. torch_geometric/graphgym/imports.py +15 -4
  75. torch_geometric/graphgym/logger.py +1 -1
  76. torch_geometric/graphgym/loss.py +1 -1
  77. torch_geometric/graphgym/models/encoder.py +2 -2
  78. torch_geometric/graphgym/models/layer.py +1 -1
  79. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  80. torch_geometric/hash_tensor.py +798 -0
  81. torch_geometric/index.py +14 -5
  82. torch_geometric/inspector.py +4 -0
  83. torch_geometric/io/fs.py +5 -4
  84. torch_geometric/llm/__init__.py +9 -0
  85. torch_geometric/llm/large_graph_indexer.py +741 -0
  86. torch_geometric/llm/models/__init__.py +23 -0
  87. torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
  88. torch_geometric/llm/models/git_mol.py +336 -0
  89. torch_geometric/llm/models/glem.py +397 -0
  90. torch_geometric/{nn/nlp → llm/models}/llm.py +180 -32
  91. torch_geometric/llm/models/llm_judge.py +158 -0
  92. torch_geometric/llm/models/molecule_gpt.py +222 -0
  93. torch_geometric/llm/models/protein_mpnn.py +333 -0
  94. torch_geometric/llm/models/sentence_transformer.py +188 -0
  95. torch_geometric/llm/models/txt2kg.py +353 -0
  96. torch_geometric/llm/models/vision_transformer.py +38 -0
  97. torch_geometric/llm/rag_loader.py +154 -0
  98. torch_geometric/llm/utils/__init__.py +10 -0
  99. torch_geometric/llm/utils/backend_utils.py +443 -0
  100. torch_geometric/llm/utils/feature_store.py +169 -0
  101. torch_geometric/llm/utils/graph_store.py +199 -0
  102. torch_geometric/llm/utils/vectorrag.py +125 -0
  103. torch_geometric/loader/cluster.py +4 -4
  104. torch_geometric/loader/ibmb_loader.py +4 -4
  105. torch_geometric/loader/link_loader.py +1 -1
  106. torch_geometric/loader/link_neighbor_loader.py +2 -1
  107. torch_geometric/loader/mixin.py +6 -5
  108. torch_geometric/loader/neighbor_loader.py +1 -1
  109. torch_geometric/loader/neighbor_sampler.py +2 -2
  110. torch_geometric/loader/prefetch.py +3 -2
  111. torch_geometric/loader/temporal_dataloader.py +2 -2
  112. torch_geometric/loader/utils.py +10 -10
  113. torch_geometric/metrics/__init__.py +14 -0
  114. torch_geometric/metrics/link_pred.py +745 -92
  115. torch_geometric/nn/__init__.py +1 -0
  116. torch_geometric/nn/aggr/base.py +1 -1
  117. torch_geometric/nn/aggr/equilibrium.py +1 -1
  118. torch_geometric/nn/aggr/fused.py +1 -1
  119. torch_geometric/nn/aggr/patch_transformer.py +8 -2
  120. torch_geometric/nn/aggr/set_transformer.py +1 -1
  121. torch_geometric/nn/aggr/utils.py +9 -4
  122. torch_geometric/nn/attention/__init__.py +9 -1
  123. torch_geometric/nn/attention/polynormer.py +107 -0
  124. torch_geometric/nn/attention/qformer.py +71 -0
  125. torch_geometric/nn/attention/sgformer.py +99 -0
  126. torch_geometric/nn/conv/__init__.py +2 -0
  127. torch_geometric/nn/conv/appnp.py +1 -1
  128. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  129. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  130. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  131. torch_geometric/nn/conv/dna_conv.py +1 -1
  132. torch_geometric/nn/conv/eg_conv.py +7 -7
  133. torch_geometric/nn/conv/gen_conv.py +1 -1
  134. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  135. torch_geometric/nn/conv/hetero_conv.py +2 -1
  136. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  137. torch_geometric/nn/conv/message_passing.py +5 -4
  138. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  139. torch_geometric/nn/conv/sg_conv.py +1 -1
  140. torch_geometric/nn/conv/spline_conv.py +2 -1
  141. torch_geometric/nn/conv/ssg_conv.py +1 -1
  142. torch_geometric/nn/conv/transformer_conv.py +5 -3
  143. torch_geometric/nn/data_parallel.py +5 -4
  144. torch_geometric/nn/dense/linear.py +0 -20
  145. torch_geometric/nn/encoding.py +17 -3
  146. torch_geometric/nn/fx.py +14 -12
  147. torch_geometric/nn/model_hub.py +2 -15
  148. torch_geometric/nn/models/__init__.py +11 -2
  149. torch_geometric/nn/models/attentive_fp.py +1 -1
  150. torch_geometric/nn/models/attract_repel.py +148 -0
  151. torch_geometric/nn/models/basic_gnn.py +2 -1
  152. torch_geometric/nn/models/captum.py +1 -1
  153. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  154. torch_geometric/nn/models/dimenet.py +2 -2
  155. torch_geometric/nn/models/dimenet_utils.py +4 -2
  156. torch_geometric/nn/models/gpse.py +1083 -0
  157. torch_geometric/nn/models/graph_unet.py +13 -4
  158. torch_geometric/nn/models/lpformer.py +783 -0
  159. torch_geometric/nn/models/metapath2vec.py +1 -1
  160. torch_geometric/nn/models/mlp.py +4 -2
  161. torch_geometric/nn/models/node2vec.py +1 -1
  162. torch_geometric/nn/models/polynormer.py +206 -0
  163. torch_geometric/nn/models/rev_gnn.py +3 -3
  164. torch_geometric/nn/models/sgformer.py +219 -0
  165. torch_geometric/nn/models/signed_gcn.py +1 -1
  166. torch_geometric/nn/models/visnet.py +2 -2
  167. torch_geometric/nn/norm/batch_norm.py +17 -7
  168. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  169. torch_geometric/nn/norm/graph_norm.py +9 -4
  170. torch_geometric/nn/norm/instance_norm.py +5 -1
  171. torch_geometric/nn/norm/layer_norm.py +15 -7
  172. torch_geometric/nn/norm/msg_norm.py +8 -2
  173. torch_geometric/nn/pool/__init__.py +8 -4
  174. torch_geometric/nn/pool/cluster_pool.py +3 -4
  175. torch_geometric/nn/pool/connect/base.py +1 -3
  176. torch_geometric/nn/pool/knn.py +13 -10
  177. torch_geometric/nn/pool/select/base.py +1 -4
  178. torch_geometric/nn/to_hetero_module.py +4 -3
  179. torch_geometric/nn/to_hetero_transformer.py +3 -3
  180. torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
  181. torch_geometric/profile/__init__.py +2 -0
  182. torch_geometric/profile/nvtx.py +66 -0
  183. torch_geometric/profile/utils.py +20 -5
  184. torch_geometric/sampler/__init__.py +2 -1
  185. torch_geometric/sampler/base.py +336 -7
  186. torch_geometric/sampler/hgt_sampler.py +11 -1
  187. torch_geometric/sampler/neighbor_sampler.py +296 -23
  188. torch_geometric/sampler/utils.py +93 -5
  189. torch_geometric/testing/__init__.py +4 -0
  190. torch_geometric/testing/decorators.py +35 -5
  191. torch_geometric/testing/distributed.py +1 -1
  192. torch_geometric/transforms/__init__.py +2 -0
  193. torch_geometric/transforms/add_gpse.py +49 -0
  194. torch_geometric/transforms/add_metapaths.py +8 -6
  195. torch_geometric/transforms/add_positional_encoding.py +2 -2
  196. torch_geometric/transforms/base_transform.py +2 -1
  197. torch_geometric/transforms/delaunay.py +65 -15
  198. torch_geometric/transforms/face_to_edge.py +32 -3
  199. torch_geometric/transforms/gdc.py +7 -8
  200. torch_geometric/transforms/largest_connected_components.py +1 -1
  201. torch_geometric/transforms/mask.py +5 -1
  202. torch_geometric/transforms/normalize_features.py +3 -3
  203. torch_geometric/transforms/random_link_split.py +1 -1
  204. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  205. torch_geometric/transforms/rooted_subgraph.py +1 -1
  206. torch_geometric/typing.py +70 -17
  207. torch_geometric/utils/__init__.py +4 -1
  208. torch_geometric/utils/_lexsort.py +0 -9
  209. torch_geometric/utils/_negative_sampling.py +27 -12
  210. torch_geometric/utils/_scatter.py +132 -195
  211. torch_geometric/utils/_sort_edge_index.py +0 -2
  212. torch_geometric/utils/_spmm.py +16 -14
  213. torch_geometric/utils/_subgraph.py +4 -0
  214. torch_geometric/utils/_to_dense_batch.py +2 -2
  215. torch_geometric/utils/_trim_to_layer.py +2 -2
  216. torch_geometric/utils/convert.py +17 -10
  217. torch_geometric/utils/cross_entropy.py +34 -13
  218. torch_geometric/utils/embedding.py +91 -2
  219. torch_geometric/utils/geodesic.py +4 -3
  220. torch_geometric/utils/influence.py +279 -0
  221. torch_geometric/utils/map.py +13 -9
  222. torch_geometric/utils/nested.py +1 -1
  223. torch_geometric/utils/smiles.py +3 -3
  224. torch_geometric/utils/sparse.py +7 -14
  225. torch_geometric/visualization/__init__.py +2 -1
  226. torch_geometric/visualization/graph.py +250 -5
  227. torch_geometric/warnings.py +11 -2
  228. torch_geometric/nn/nlp/__init__.py +0 -7
  229. torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -8,6 +8,7 @@ from .encoding import PositionalEncoding, TemporalEncoding
8
8
  from .summary import summary
9
9
 
10
10
  from .aggr import * # noqa
11
+ from .attention import * # noqa
11
12
  from .conv import * # noqa
12
13
  from .pool import * # noqa
13
14
  from .glob import * # noqa
@@ -135,7 +135,7 @@ class Aggregation(torch.nn.Module):
135
135
  if index.numel() > 0 and dim_size <= int(index.max()):
136
136
  raise ValueError(f"Encountered invalid 'dim_size' (got "
137
137
  f"'{dim_size}' but expected "
138
- f">= '{int(index.max()) + 1}')")
138
+ f">= '{int(index.max()) + 1}')") from e
139
139
  raise e
140
140
 
141
141
  def __repr__(self) -> str:
@@ -52,7 +52,7 @@ class MomentumOptimizer(torch.nn.Module):
52
52
  layer. It is based on an unrolled Nesterov momentum algorithm.
53
53
 
54
54
  Args:
55
- learning_rate (flaot): learning rate for optimizer.
55
+ learning_rate (float): learning rate for optimizer.
56
56
  momentum (float): momentum for optimizer.
57
57
  learnable (bool): If :obj:`True` then the :obj:`learning_rate` and
58
58
  :obj:`momentum` will be learnable parameters. If False they
@@ -216,7 +216,7 @@ class FusedAggregation(Aggregation):
216
216
  outs: List[Optional[Tensor]] = []
217
217
 
218
218
  # Iterate over all reduction ops to compute first results:
219
- for i, reduce in enumerate(self.reduce_ops):
219
+ for reduce in self.reduce_ops:
220
220
  if reduce is None:
221
221
  outs.append(None)
222
222
  continue
@@ -32,6 +32,8 @@ class PatchTransformerAggregation(Aggregation):
32
32
  aggr (str or list[str], optional): The aggregation module, *e.g.*,
33
33
  :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
34
34
  :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
35
+ device (torch.device, optional): The device of the module.
36
+ (default: :obj:`None`)
35
37
  """
36
38
  def __init__(
37
39
  self,
@@ -43,6 +45,7 @@ class PatchTransformerAggregation(Aggregation):
43
45
  heads: int = 1,
44
46
  dropout: float = 0.0,
45
47
  aggr: Union[str, List[str]] = 'mean',
48
+ device: Optional[torch.device] = None,
46
49
  ) -> None:
47
50
  super().__init__()
48
51
 
@@ -55,12 +58,13 @@ class PatchTransformerAggregation(Aggregation):
55
58
  for aggr in self.aggrs:
56
59
  assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']
57
60
 
58
- self.lin = torch.nn.Linear(in_channels, hidden_channels)
61
+ self.lin = torch.nn.Linear(in_channels, hidden_channels, device=device)
59
62
  self.pad_projector = torch.nn.Linear(
60
63
  patch_size * hidden_channels,
61
64
  hidden_channels,
65
+ device=device,
62
66
  )
63
- self.pe = PositionalEncoding(hidden_channels)
67
+ self.pe = PositionalEncoding(hidden_channels, device=device)
64
68
 
65
69
  self.blocks = torch.nn.ModuleList([
66
70
  MultiheadAttentionBlock(
@@ -68,12 +72,14 @@ class PatchTransformerAggregation(Aggregation):
68
72
  heads=heads,
69
73
  layer_norm=True,
70
74
  dropout=dropout,
75
+ device=device,
71
76
  ) for _ in range(num_transformer_blocks)
72
77
  ])
73
78
 
74
79
  self.fc = torch.nn.Linear(
75
80
  hidden_channels * len(self.aggrs),
76
81
  out_channels,
82
+ device=device,
77
83
  )
78
84
 
79
85
  def reset_parameters(self) -> None:
@@ -38,7 +38,7 @@ class SetTransformerAggregation(Aggregation):
38
38
  (default: :obj:`1`)
39
39
  concat (bool, optional): If set to :obj:`False`, the seed embeddings
40
40
  are averaged instead of concatenated. (default: :obj:`True`)
41
- norm (str, optional): If set to :obj:`True`, will apply layer
41
+ layer_norm (str, optional): If set to :obj:`True`, will apply layer
42
42
  normalization. (default: :obj:`False`)
43
43
  dropout (float, optional): Dropout probability of attention weights.
44
44
  (default: :obj:`0`)
@@ -26,9 +26,11 @@ class MultiheadAttentionBlock(torch.nn.Module):
26
26
  normalization. (default: :obj:`True`)
27
27
  dropout (float, optional): Dropout probability of attention weights.
28
28
  (default: :obj:`0`)
29
+ device (torch.device, optional): The device of the module.
30
+ (default: :obj:`None`)
29
31
  """
30
32
  def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True,
31
- dropout: float = 0.0):
33
+ dropout: float = 0.0, device: Optional[torch.device] = None):
32
34
  super().__init__()
33
35
 
34
36
  self.channels = channels
@@ -40,10 +42,13 @@ class MultiheadAttentionBlock(torch.nn.Module):
40
42
  heads,
41
43
  batch_first=True,
42
44
  dropout=dropout,
45
+ device=device,
43
46
  )
44
- self.lin = Linear(channels, channels)
45
- self.layer_norm1 = LayerNorm(channels) if layer_norm else None
46
- self.layer_norm2 = LayerNorm(channels) if layer_norm else None
47
+ self.lin = Linear(channels, channels, device=device)
48
+ self.layer_norm1 = LayerNorm(channels,
49
+ device=device) if layer_norm else None
50
+ self.layer_norm2 = LayerNorm(channels,
51
+ device=device) if layer_norm else None
47
52
 
48
53
  def reset_parameters(self):
49
54
  self.attn._reset_parameters()
@@ -1,3 +1,11 @@
1
1
  from .performer import PerformerAttention
2
+ from .qformer import QFormer
3
+ from .sgformer import SGFormerAttention
4
+ from .polynormer import PolynormerAttention
2
5
 
3
- __all__ = ['PerformerAttention']
6
+ __all__ = classes = [
7
+ 'PerformerAttention',
8
+ 'QFormer',
9
+ 'SGFormerAttention',
10
+ 'PolynormerAttention',
11
+ ]
@@ -0,0 +1,107 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+
7
+
8
class PolynormerAttention(torch.nn.Module):
    r"""The polynomial-expressive attention mechanism from the
    `"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time"
    <https://arxiv.org/abs/2403.01232>`_ paper.

    The module projects the input to :obj:`heads * head_channels` features,
    so the output feature size is :obj:`heads * head_channels` rather than
    :obj:`channels`.

    Args:
        channels (int): Size of each input sample.
        heads (int): Number of parallel attention heads.
        head_channels (int, optional): Size of each attention head.
            (default: :obj:`64`)
        beta (float, optional): Polynormer beta initialization.
            (default: :obj:`0.9`)
        qkv_bias (bool, optional): If set to :obj:`True`, adds a bias to the
            query, key and value projections. (default: :obj:`False`)
        qk_shared (bool, optional): Whether the weights of query and key are
            shared. (default: :obj:`True`)
        dropout (float, optional): Dropout probability of the final
            attention output. (default: :obj:`0.0`)
    """
    def __init__(
        self,
        channels: int,
        heads: int,
        head_channels: int = 64,
        beta: float = 0.9,
        qkv_bias: bool = False,
        qk_shared: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.head_channels = head_channels
        self.heads = heads
        self.beta = beta
        self.qk_shared = qk_shared

        inner_channels = heads * head_channels
        self.h_lins = torch.nn.Linear(channels, inner_channels)
        # A dedicated query projection only exists when query and key do not
        # share their weights; otherwise the key projection is reused.
        if not self.qk_shared:
            self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.lns = torch.nn.LayerNorm(inner_channels)
        self.lin_out = torch.nn.Linear(inner_channels, inner_channels)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. Boolean, integer and
                floating-point masks are accepted; non-zero entries mark
                valid nodes. (default: :obj:`None`)

        Returns:
            torch.Tensor: Tensor of shape
            :obj:`[B, N, heads * head_channels]`.
        """
        B, N, *_ = x.shape
        head_shape = (B, N, self.head_channels, self.heads)

        h = self.h_lins(x)
        k = self.k(x).sigmoid().view(head_shape)
        if self.qk_shared:
            q = k
        else:
            q = self.q(x).sigmoid().view(head_shape)
        v = self.v(x).view(head_shape)

        if mask is not None:
            # `~` is only a logical NOT on boolean tensors, so normalize the
            # documented {0, 1} mask to bool before inverting it.
            valid = mask.to(torch.bool)[:, :, None, None]
            v = v.masked_fill(~valid, 0.)

        # numerator
        kv = torch.einsum('bndh, bnmh -> bdmh', k, v)
        num = torch.einsum('bndh, bdmh -> bnmh', q, kv)

        # denominator
        k_sum = torch.einsum('bndh -> bdh', k)
        den = torch.einsum('bndh, bdh -> bnh', q, k_sum).unsqueeze(2)

        # linear global attention based on kernel trick
        x = (num / (den + 1e-6)).reshape(B, N, -1)
        x = self.lns(x) * (h + self.beta)
        x = F.relu(self.lin_out(x))
        x = self.dropout(x)

        return x

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        self.h_lins.reset_parameters()
        if not self.qk_shared:
            self.q.reset_parameters()
        self.k.reset_parameters()
        self.v.reset_parameters()
        self.lns.reset_parameters()
        self.lin_out.reset_parameters()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'heads={self.heads}, '
                f'head_channels={self.head_channels})')
@@ -0,0 +1,71 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+
5
+
6
class QFormer(torch.nn.Module):
    r"""The Querying Transformer (Q-Former) from
    `"BLIP-2: Bootstrapping Language-Image Pre-training
    with Frozen Image Encoders and Large Language Models"
    <https://arxiv.org/pdf/2301.12597>`_ paper.

    The input is layer-normalized, passed through a stack of transformer
    encoder layers and finally projected to :obj:`output_dim` features.

    Args:
        input_dim (int): The number of features in the input.
        hidden_dim (int): The dimension of the feed-forward network in each
            encoder layer.
        output_dim (int): The final output dimension.
        num_heads (int): The number of attention heads.
        num_layers (int): The number of sub-encoder-layers in the encoder.
        dropout (float, optional): The dropout probability in each encoder
            layer. (default: :obj:`0.0`)
        activation (Callable, optional): The activation function used inside
            each encoder layer. (default: :obj:`torch.nn.ReLU()`)

    .. note::
        This is a simplified version of the original Q-Former implementation.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_heads: int,
        num_layers: int,
        dropout: float = 0.0,
        activation: Callable = torch.nn.ReLU(),
    ) -> None:

        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads

        self.layer_norm = torch.nn.LayerNorm(input_dim)
        # NOTE(review): the prototype layer is kept as an attribute, so its
        # parameters are registered on this module in addition to whatever
        # `TransformerEncoder` builds from it. Presumably these are unused in
        # `forward` -- confirm before removing, since dropping the attribute
        # would change the `state_dict` keys.
        self.encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )
        self.encoder = torch.nn.TransformerEncoder(
            self.encoder_layer,
            num_layers=num_layers,
        )
        self.project = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Input sequence to the encoder layer.
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, sequence length :math:`N`,
                and feature dimension :math:`F`.

        Returns:
            torch.Tensor: The projected sequence, with the last dimension
            of size :obj:`output_dim`.
        """
        x = self.layer_norm(x)
        x = self.encoder(x)
        out = self.project(x)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'num_heads={self.num_heads}, '
                f'num_layers={self.num_layers})')
@@ -0,0 +1,99 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
class SGFormerAttention(torch.nn.Module):
    r"""The simple global attention mechanism from the
    `"SGFormer: Simplifying and Empowering Transformers for
    Large-Graph Representations"
    <https://arxiv.org/abs/2306.10759>`_ paper.

    Args:
        channels (int): Size of each input sample.
        heads (int, optional): Number of parallel attention heads.
            (default: :obj:`1`)
        head_channels (int, optional): Size of each attention head. If set to
            :obj:`None`, :obj:`channels // heads` is used.
            (default: :obj:`64`)
        qkv_bias (bool, optional): If set to :obj:`True`, adds a bias to the
            query, key and value projections. (default: :obj:`False`)
    """
    def __init__(
        self,
        channels: int,
        heads: int = 1,
        head_channels: Optional[int] = 64,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()
        assert channels % heads == 0
        if head_channels is None:
            head_channels = channels // heads

        self.heads = heads
        self.head_channels = head_channels

        inner_channels = head_channels * heads
        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. Boolean, integer and
                floating-point masks are accepted; non-zero entries mark
                valid nodes. (default: :obj:`None`)

        Returns:
            torch.Tensor: Tensor of shape :obj:`[B, N, head_channels]`, i.e.
            the attention output averaged over all heads.
        """
        B, N, *_ = x.shape
        # (b, n, num_heads * head_channels) to (b, n, num_heads, head_channels)
        head_shape = (B, N, self.heads, self.head_channels)
        qs = self.q(x).reshape(head_shape)
        ks = self.k(x).reshape(head_shape)
        vs = self.v(x).reshape(head_shape)

        if mask is not None:
            # `~` is only a logical NOT on boolean tensors, so normalize the
            # documented {0, 1} mask to bool before inverting it.
            valid = mask.to(torch.bool)[:, :, None, None]
            vs = vs.masked_fill(~valid, 0.)

        # replace 0's with epsilon
        epsilon = 1e-6
        qs = qs.masked_fill(qs == 0, epsilon)
        ks = ks.masked_fill(ks == 0, epsilon)

        # normalize input, shape not changed
        qs = qs / torch.linalg.norm(qs, ord=2, dim=-1, keepdim=True)
        ks = ks / torch.linalg.norm(ks, ord=2, dim=-1, keepdim=True)

        # numerator
        kvs = torch.einsum("blhm,blhd->bhmd", ks, vs)
        attention_num = torch.einsum("bnhm,bhmd->bnhd", qs, kvs)
        attention_num = attention_num + N * vs

        # denominator
        # Summing over the node dimension directly keeps the dtype and device
        # of `ks`; contracting with a separately built all-ones tensor would
        # use the default dtype and break for non-float32 inputs.
        ks_sum = ks.sum(dim=1)
        attention_normalizer = torch.einsum("bnhm,bhm->bnh", qs, ks_sum)

        # attentive aggregated results
        attention_normalizer = attention_normalizer.unsqueeze(-1) + N
        attn_output = attention_num / attention_normalizer

        return attn_output.mean(dim=2)

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        self.q.reset_parameters()
        self.k.reset_parameters()
        self.v.reset_parameters()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'heads={self.heads}, '
                f'head_channels={self.head_channels})')
@@ -61,6 +61,7 @@ from .gps_conv import GPSConv
61
61
  from .antisymmetric_conv import AntiSymmetricConv
62
62
  from .dir_gnn_conv import DirGNNConv
63
63
  from .mixhop_conv import MixHopConv
64
+ from .meshcnn_conv import MeshCNNConv
64
65
 
65
66
  import torch_geometric.nn.conv.utils # noqa
66
67
 
@@ -131,6 +132,7 @@ __all__ = [
131
132
  'AntiSymmetricConv',
132
133
  'DirGNNConv',
133
134
  'MixHopConv',
135
+ 'MeshCNNConv',
134
136
  ]
135
137
 
136
138
  classes = __all__
@@ -109,7 +109,7 @@ class APPNP(MessagePassing):
109
109
  edge_index = cache
110
110
 
111
111
  h = x
112
- for k in range(self.K):
112
+ for _ in range(self.K):
113
113
  if self.dropout > 0 and self.training:
114
114
  if isinstance(edge_index, Tensor):
115
115
  if is_torch_sparse_tensor(edge_index):
@@ -26,6 +26,9 @@ class CuGraphGATConv(CuGraphModule): # pragma: no cover
26
26
  :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`cugraph-ops`
27
27
  package that fuses message passing computation for accelerated execution
28
28
  and lower memory footprint.
29
+ The current method to enable :obj:`cugraph-ops`
30
+ is to use `The NVIDIA PyG Container
31
+ <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
29
32
  """
30
33
  def __init__(
31
34
  self,
@@ -67,6 +70,7 @@ class CuGraphGATConv(CuGraphModule): # pragma: no cover
67
70
  self,
68
71
  x: Tensor,
69
72
  edge_index: EdgeIndex,
73
+ edge_attr: Tensor,
70
74
  max_num_neighbors: Optional[int] = None,
71
75
  ) -> Tensor:
72
76
  graph = self.get_cugraph(edge_index, max_num_neighbors)
@@ -75,10 +79,12 @@ class CuGraphGATConv(CuGraphModule): # pragma: no cover
75
79
 
76
80
  if LEGACY_MODE:
77
81
  out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
78
- self.negative_slope, False, self.concat)
82
+ self.negative_slope, False, self.concat,
83
+ edge_feat=edge_attr)
79
84
  else:
80
85
  out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
81
- self.negative_slope, self.concat)
86
+ self.negative_slope, self.concat,
87
+ edge_feat=edge_attr)
82
88
 
83
89
  if self.bias is not None:
84
90
  out = out + self.bias
@@ -29,6 +29,9 @@ class CuGraphRGCNConv(CuGraphModule): # pragma: no cover
29
29
  :class:`~torch_geometric.nn.conv.RGCNConv` based on the :obj:`cugraph-ops`
30
30
  package that fuses message passing computation for accelerated execution
31
31
  and lower memory footprint.
32
+ The current method to enable :obj:`cugraph-ops`
33
+ is to use `The NVIDIA PyG Container
34
+ <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
32
35
  """
33
36
  def __init__(self, in_channels: int, out_channels: int, num_relations: int,
34
37
  num_bases: Optional[int] = None, aggr: str = 'mean',
@@ -27,6 +27,9 @@ class CuGraphSAGEConv(CuGraphModule): # pragma: no cover
27
27
  :class:`~torch_geometric.nn.conv.SAGEConv` based on the :obj:`cugraph-ops`
28
28
  package that fuses message passing computation for accelerated execution
29
29
  and lower memory footprint.
30
+ The current method to enable :obj:`cugraph-ops`
31
+ is to use `The NVIDIA PyG Container
32
+ <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
30
33
  """
31
34
  def __init__(
32
35
  self,
@@ -163,7 +163,7 @@ class MultiHead(Attention):
163
163
  def __repr__(self) -> str: # pragma: no cover
164
164
  return (f'{self.__class__.__name__}({self.in_channels}, '
165
165
  f'{self.out_channels}, heads={self.heads}, '
166
- f'groups={self.groups}, dropout={self.droput}, '
166
+ f'groups={self.groups}, dropout={self.dropout}, '
167
167
  f'bias={self.bias})')
168
168
 
169
169
 
@@ -81,7 +81,7 @@ class EGConv(MessagePassing):
81
81
  self,
82
82
  in_channels: int,
83
83
  out_channels: int,
84
- aggregators: List[str] = ['symnorm'],
84
+ aggregators: Optional[List[str]] = None,
85
85
  num_heads: int = 8,
86
86
  num_bases: int = 4,
87
87
  cached: bool = False,
@@ -96,23 +96,23 @@ class EGConv(MessagePassing):
96
96
  f"divisible by the number of heads "
97
97
  f"(got {num_heads})")
98
98
 
99
- for a in aggregators:
100
- if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
101
- raise ValueError(f"Unsupported aggregator: '{a}'")
102
-
103
99
  self.in_channels = in_channels
104
100
  self.out_channels = out_channels
105
101
  self.num_heads = num_heads
106
102
  self.num_bases = num_bases
107
103
  self.cached = cached
108
104
  self.add_self_loops = add_self_loops
109
- self.aggregators = aggregators
105
+ self.aggregators = aggregators or ['symnorm']
106
+
107
+ for a in self.aggregators:
108
+ if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
109
+ raise ValueError(f"Unsupported aggregator: '{a}'")
110
110
 
111
111
  self.bases_lin = Linear(in_channels,
112
112
  (out_channels // num_heads) * num_bases,
113
113
  bias=False, weight_initializer='glorot')
114
114
  self.comb_lin = Linear(in_channels,
115
- num_heads * num_bases * len(aggregators))
115
+ num_heads * num_bases * len(self.aggregators))
116
116
 
117
117
  if bias:
118
118
  self.bias = Parameter(torch.empty(out_channels))
@@ -178,7 +178,7 @@ class GENConv(MessagePassing):
178
178
  self.lin_dst = Linear(in_channels[1], out_channels, bias=bias)
179
179
 
180
180
  channels = [out_channels]
181
- for i in range(num_layers - 1):
181
+ for _ in range(num_layers - 1):
182
182
  channels.append(out_channels * expansion)
183
183
  channels.append(out_channels)
184
184
  self.mlp = MLP(channels, norm=norm, bias=bias)
@@ -63,7 +63,8 @@ class GravNetConv(MessagePassing):
63
63
  if num_workers is not None:
64
64
  warnings.warn(
65
65
  "'num_workers' attribute in '{self.__class__.__name__}' is "
66
- "deprecated and will be removed in a future release")
66
+ "deprecated and will be removed in a future release",
67
+ stacklevel=2)
67
68
 
68
69
  self.in_channels = in_channels
69
70
  self.out_channels = out_channels
@@ -77,7 +77,8 @@ class HeteroConv(torch.nn.Module):
77
77
  f"There exist node types ({src_node_types - dst_node_types}) "
78
78
  f"whose representations do not get updated during message "
79
79
  f"passing as they do not occur as destination type in any "
80
- f"edge type. This may lead to unexpected behavior.")
80
+ f"edge type. This may lead to unexpected behavior.",
81
+ stacklevel=2)
81
82
 
82
83
  self.convs = ModuleDict(convs)
83
84
  self.aggr = aggr