pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pyg-nightly has been flagged as possibly problematic.

Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/nn/models/metapath2vec.py
@@ -233,7 +233,7 @@ class MetaPath2Vec(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression
 
-        clf = LogisticRegression(solver=solver, *args,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
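
Note: this reordering (repeated below for `Node2Vec`) is the fix flagged by flake8-bugbear's B026: unpacking `*args` after a keyword argument is legal Python, but the unpacked positionals are bound before the keyword that appears first in the call, so the call reads in a misleading order. A minimal sketch with illustrative names:

# `fit` is an illustrative stand-in for LogisticRegression's signature.
def fit(penalty='l2', solver='lbfgs'):
    return penalty, solver

args = ('l1',)

fit(solver='saga', *args)   # actually binds fit('l1', solver='saga')
fit(*args, solver='saga')   # the same call, written in the order it binds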
torch_geometric/nn/models/mlp.py
@@ -99,8 +99,10 @@ class MLP(torch.nn.Module):
         act_first = act_first or kwargs.get("relu_first", False)
         batch_norm = kwargs.get("batch_norm", None)
         if batch_norm is not None and isinstance(batch_norm, bool):
-            warnings.warn("Argument `batch_norm` is deprecated, "
-                          "please use `norm` to specify normalization layer.")
+            warnings.warn(
+                "Argument `batch_norm` is deprecated, "
+                "please use `norm` to specify normalization layer.",
+                stacklevel=2)
             norm = 'batch_norm' if batch_norm else None
             batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None)
             norm_kwargs = batch_norm_kwargs or {}
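
Note: `stacklevel=2` makes the deprecation warning point at the code that called `MLP(...)` rather than at the `warnings.warn` line inside the library, which is where users need to look to fix the call. A minimal sketch (names are illustrative):

import warnings

def deprecated_api():
    # With the default stacklevel=1, the reported location is this line;
    # stacklevel=2 reports the caller's line instead.
    warnings.warn("use new_api() instead", DeprecationWarning, stacklevel=2)

deprecated_api()  # the warning is attributed to this call site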
torch_geometric/nn/models/node2vec.py
@@ -181,7 +181,7 @@ class Node2Vec(torch.nn.Module):
         """
         from sklearn.linear_model import LogisticRegression
 
-        clf = LogisticRegression(solver=solver, *args,
+        clf = LogisticRegression(*args, solver=solver,
                                  **kwargs).fit(train_z.detach().cpu().numpy(),
                                                train_y.detach().cpu().numpy())
         return clf.score(test_z.detach().cpu().numpy(),
torch_geometric/nn/models/polynormer.py
@@ -0,0 +1,206 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from torch_geometric.nn import GATConv, GCNConv
+from torch_geometric.nn.attention import PolynormerAttention
+from torch_geometric.utils import to_dense_batch
+
+
+class Polynormer(torch.nn.Module):
+    r"""The Polynormer module from the
+    `"Polynormer: Polynomial-Expressive Graph
+    Transformer in Linear Time"
+    <https://arxiv.org/abs/2403.01232>`_ paper.
+
+    Args:
+        in_channels (int): Input channels.
+        hidden_channels (int): Hidden channels.
+        out_channels (int): Output channels.
+        local_layers (int): The number of local attention layers.
+            (default: :obj:`7`)
+        global_layers (int): The number of global attention layers.
+            (default: :obj:`2`)
+        in_dropout (float): Input dropout rate.
+            (default: :obj:`0.15`)
+        dropout (float): Dropout rate.
+            (default: :obj:`0.5`)
+        global_dropout (float): Global dropout rate.
+            (default: :obj:`0.5`)
+        heads (int): The number of heads.
+            (default: :obj:`1`)
+        beta (float): The weight of the residual connection.
+            (default: :obj:`0.9`)
+        qk_shared (bool, optional): Whether the weights of the query and key
+            projections are shared. (default: :obj:`False`)
+        pre_ln (bool): Pre layer normalization.
+            (default: :obj:`False`)
+        post_bn (bool): Post batch normalization.
+            (default: :obj:`True`)
+        local_attn (bool): Whether to use local attention.
+            (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        local_layers: int = 7,
+        global_layers: int = 2,
+        in_dropout: float = 0.15,
+        dropout: float = 0.5,
+        global_dropout: float = 0.5,
+        heads: int = 1,
+        beta: float = 0.9,
+        qk_shared: bool = False,
+        pre_ln: bool = False,
+        post_bn: bool = True,
+        local_attn: bool = False,
+    ) -> None:
+        super().__init__()
+        self._global = False
+        self.in_drop = in_dropout
+        self.dropout = dropout
+        self.pre_ln = pre_ln
+        self.post_bn = post_bn
+
+        self.beta = beta
+
+        self.h_lins = torch.nn.ModuleList()
+        self.local_convs = torch.nn.ModuleList()
+        self.lins = torch.nn.ModuleList()
+        self.lns = torch.nn.ModuleList()
+        if self.pre_ln:
+            self.pre_lns = torch.nn.ModuleList()
+        if self.post_bn:
+            self.post_bns = torch.nn.ModuleList()
+
+        # first layer
+        inner_channels = heads * hidden_channels
+        self.h_lins.append(torch.nn.Linear(in_channels, inner_channels))
+        if local_attn:
+            self.local_convs.append(
+                GATConv(in_channels, hidden_channels, heads=heads, concat=True,
+                        add_self_loops=False, bias=False))
+        else:
+            self.local_convs.append(
+                GCNConv(in_channels, inner_channels, cached=False,
+                        normalize=True))
+
+        self.lins.append(torch.nn.Linear(in_channels, inner_channels))
+        self.lns.append(torch.nn.LayerNorm(inner_channels))
+        if self.pre_ln:
+            self.pre_lns.append(torch.nn.LayerNorm(in_channels))
+        if self.post_bn:
+            self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
+
+        # following layers
+        for _ in range(local_layers - 1):
+            self.h_lins.append(torch.nn.Linear(inner_channels, inner_channels))
+            if local_attn:
+                self.local_convs.append(
+                    GATConv(inner_channels, hidden_channels, heads=heads,
+                            concat=True, add_self_loops=False, bias=False))
+            else:
+                self.local_convs.append(
+                    GCNConv(inner_channels, inner_channels, cached=False,
+                            normalize=True))
+
+            self.lins.append(torch.nn.Linear(inner_channels, inner_channels))
+            self.lns.append(torch.nn.LayerNorm(inner_channels))
+            if self.pre_ln:
+                self.pre_lns.append(torch.nn.LayerNorm(heads *
+                                                       hidden_channels))
+            if self.post_bn:
+                self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))
+
+        self.lin_in = torch.nn.Linear(in_channels, inner_channels)
+        self.ln = torch.nn.LayerNorm(inner_channels)
+
+        self.global_attn = torch.nn.ModuleList()
+        for _ in range(global_layers):
+            self.global_attn.append(
+                PolynormerAttention(
+                    channels=hidden_channels,
+                    heads=heads,
+                    head_channels=hidden_channels,
+                    beta=beta,
+                    dropout=global_dropout,
+                    qk_shared=qk_shared,
+                ))
+        self.pred_local = torch.nn.Linear(inner_channels, out_channels)
+        self.pred_global = torch.nn.Linear(inner_channels, out_channels)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        for local_conv in self.local_convs:
+            local_conv.reset_parameters()
+        for attn in self.global_attn:
+            attn.reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
+        for h_lin in self.h_lins:
+            h_lin.reset_parameters()
+        for ln in self.lns:
+            ln.reset_parameters()
+        if self.pre_ln:
+            for p_ln in self.pre_lns:
+                p_ln.reset_parameters()
+        if self.post_bn:
+            for p_bn in self.post_bns:
+                p_bn.reset_parameters()
+        self.lin_in.reset_parameters()
+        self.ln.reset_parameters()
+        self.pred_local.reset_parameters()
+        self.pred_global.reset_parameters()
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Optional[Tensor],
+    ) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): The input node features.
+            edge_index (torch.Tensor or SparseTensor): The edge indices.
+            batch (torch.Tensor, optional): The batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each element to a specific example.
+        """
+        x = F.dropout(x, p=self.in_drop, training=self.training)
+
+        # equivariant local attention
+        x_local = 0
+        for i, local_conv in enumerate(self.local_convs):
+            if self.pre_ln:
+                x = self.pre_lns[i](x)
+            h = self.h_lins[i](x)
+            h = F.relu(h)
+            x = local_conv(x, edge_index) + self.lins[i](x)
+            if self.post_bn:
+                x = self.post_bns[i](x)
+            x = F.relu(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = (1 - self.beta) * self.lns[i](h * x) + self.beta * x
+            x_local = x_local + x
+
+        # equivariant global attention
+        if self._global:
+            batch, indices = batch.sort()
+            rev_perm = torch.empty_like(indices)
+            rev_perm[indices] = torch.arange(len(indices),
+                                             device=indices.device)
+            x_local = self.ln(x_local[indices])
+            x_global, mask = to_dense_batch(x_local, batch)
+            for attn in self.global_attn:
+                x_global = attn(x_global, mask)
+            x = x_global[mask][rev_perm]
+            x = self.pred_global(x)
+        else:
+            x = self.pred_local(x_local)
+
+        return F.log_softmax(x, dim=-1)
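
A quick smoke test of the new model, assuming it is exported as `torch_geometric.nn.models.Polynormer` (consistent with the `torch_geometric/nn/models/__init__.py` change in the file list). Note that the private `_global` flag defaults to `False`, so this exercises only the local attention branch:

import torch
from torch_geometric.nn.models import Polynormer  # assumed export path

model = Polynormer(in_channels=16, hidden_channels=32, out_channels=4,
                   local_layers=2, global_layers=1, heads=2)

x = torch.randn(10, 16)                     # 10 nodes, 16 features each
edge_index = torch.randint(0, 10, (2, 40))  # random connectivity
batch = torch.zeros(10, dtype=torch.long)   # all nodes in a single graph

out = model(x, edge_index, batch)           # log-probabilities, shape [10, 4]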
torch_geometric/nn/models/rev_gnn.py
@@ -196,8 +196,8 @@ class InvertibleModule(torch.nn.Module):
 class GroupAddRev(InvertibleModule):
     r"""The Grouped Reversible GNN module from the `"Graph Neural Networks with
     1000 Layers" <https://arxiv.org/abs/2106.07476>`_ paper.
-    This module enables training of arbitary deep GNNs with a memory complexity
-    independent of the number of layers.
+    This module enables training of arbitrary deep GNNs with a memory
+    complexity independent of the number of layers.
 
     It does so by partitioning input node features :math:`\mathbf{X}` into
     :math:`C` groups across the feature dimension. Then, a grouped reversible
@@ -249,7 +249,7 @@ class GroupAddRev(InvertibleModule):
         else:
             assert num_groups is not None, "Please specific 'num_groups'"
             self.convs = torch.nn.ModuleList([conv])
-            for i in range(num_groups - 1):
+            for _ in range(num_groups - 1):
                 conv = copy.deepcopy(self.convs[0])
                 if hasattr(conv, 'reset_parameters'):
                     conv.reset_parameters()
torch_geometric/nn/models/schnet.py
@@ -11,6 +11,7 @@ from torch import Tensor
 from torch.nn import Embedding, Linear, ModuleList, Sequential
 
 from torch_geometric.data import Dataset, download_url, extract_zip
+from torch_geometric.io import fs
 from torch_geometric.nn import MessagePassing, SumAggregation, radius_graph
 from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
 from torch_geometric.typing import OptTensor
@@ -216,7 +217,7 @@ class SchNet(torch.nn.Module):
 
         with warnings.catch_warnings():
             warnings.simplefilter('ignore')
-            state = torch.load(path, map_location='cpu')
+            state = fs.torch_load(path, map_location='cpu')
 
         net = SchNet(
             hidden_channels=128,
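
Note: `torch_geometric.io.fs` routes file access through fsspec, so `fs.torch_load` behaves like `torch.load` for local paths while also accepting remote filesystem URLs. A sketch (the remote URL is illustrative):

from torch_geometric.io import fs

state = fs.torch_load('/tmp/ckpt.pt', map_location='cpu')         # local path
state = fs.torch_load('s3://bucket/ckpt.pt', map_location='cpu')  # via fsspec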
torch_geometric/nn/models/sgformer.py
@@ -0,0 +1,219 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from torch_geometric.nn.attention import SGFormerAttention
+from torch_geometric.nn.conv import GCNConv
+from torch_geometric.utils import to_dense_batch
+
+
+class GraphModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        num_layers=2,
+        dropout=0.5,
+    ):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        self.fcs = torch.nn.ModuleList()
+        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
+
+        self.bns = torch.nn.ModuleList()
+        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
+        for _ in range(num_layers):
+            self.convs.append(GCNConv(hidden_channels, hidden_channels))
+            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
+
+        self.dropout = dropout
+        self.activation = F.relu
+
+    def reset_parameters(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+        for fc in self.fcs:
+            fc.reset_parameters()
+
+    def forward(self, x, edge_index):
+        x = self.fcs[0](x)
+        x = self.bns[0](x)
+        x = self.activation(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        last_x = x
+
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            x = self.bns[i + 1](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = x + last_x
+        return x
+
+
+class SGModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        num_layers=2,
+        num_heads=1,
+        dropout=0.5,
+    ):
+        super().__init__()
+
+        self.attns = torch.nn.ModuleList()
+        self.fcs = torch.nn.ModuleList()
+        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
+        self.bns = torch.nn.ModuleList()
+        self.bns.append(torch.nn.LayerNorm(hidden_channels))
+        for _ in range(num_layers):
+            self.attns.append(
+                SGFormerAttention(hidden_channels, num_heads, hidden_channels))
+            self.bns.append(torch.nn.LayerNorm(hidden_channels))
+
+        self.dropout = dropout
+        self.activation = F.relu
+
+    def reset_parameters(self):
+        for attn in self.attns:
+            attn.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+        for fc in self.fcs:
+            fc.reset_parameters()
+
+    def forward(self, x: Tensor, batch: Tensor):
+        # to dense batch expects sorted batch
+        batch, indices = batch.sort(stable=True)
+        rev_perm = torch.empty_like(indices)
+        rev_perm[indices] = torch.arange(len(indices), device=indices.device)
+        x = x[indices]
+        x, mask = to_dense_batch(x, batch)
+        layer_ = []
+
+        # input MLP layer
+        x = self.fcs[0](x)
+        x = self.bns[0](x)
+        x = self.activation(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # store as residual link
+        layer_.append(x)
+
+        for i, attn in enumerate(self.attns):
+            x = attn(x, mask)
+            x = (x + layer_[i]) / 2.
+            x = self.bns[i + 1](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            layer_.append(x)
+
+        x_mask = x[mask]
+        # reverse the sorting
+        unsorted_x_mask = x_mask[rev_perm]
+        return unsorted_x_mask
+
+
+class SGFormer(torch.nn.Module):
+    r"""The SGFormer module from the
+    `"SGFormer: Simplifying and Empowering Transformers for
+    Large-Graph Representations"
+    <https://arxiv.org/abs/2306.10759>`_ paper.
+
+    Args:
+        in_channels (int): Input channels.
+        hidden_channels (int): Hidden channels.
+        out_channels (int): Output channels.
+        trans_num_layers (int): The number of layers for all-pair attention.
+            (default: :obj:`2`)
+        trans_num_heads (int): The number of heads for attention.
+            (default: :obj:`1`)
+        trans_dropout (float): Global dropout rate.
+            (default: :obj:`0.5`)
+        gnn_num_layers (int): The number of layers for GNN.
+            (default: :obj:`3`)
+        gnn_dropout (float): GNN dropout rate.
+            (default: :obj:`0.5`)
+        graph_weight (float): The weight balancing the global and GNN modules.
+            (default: :obj:`0.5`)
+        aggregate (str): Aggregate type.
+            (default: :obj:`add`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        trans_num_layers: int = 2,
+        trans_num_heads: int = 1,
+        trans_dropout: float = 0.5,
+        gnn_num_layers: int = 3,
+        gnn_dropout: float = 0.5,
+        graph_weight: float = 0.5,
+        aggregate: str = 'add',
+    ):
+        super().__init__()
+        self.trans_conv = SGModule(
+            in_channels,
+            hidden_channels,
+            trans_num_layers,
+            trans_num_heads,
+            trans_dropout,
+        )
+        self.graph_conv = GraphModule(
+            in_channels,
+            hidden_channels,
+            gnn_num_layers,
+            gnn_dropout,
+        )
+        self.graph_weight = graph_weight
+
+        self.aggregate = aggregate
+
+        if aggregate == 'add':
+            self.fc = torch.nn.Linear(hidden_channels, out_channels)
+        elif aggregate == 'cat':
+            self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)
+        else:
+            raise ValueError(f'Invalid aggregate type: {aggregate}')
+
+        self.params1 = list(self.trans_conv.parameters())
+        self.params2 = list(self.graph_conv.parameters())
+        self.params2.extend(list(self.fc.parameters()))
+
+        self.out_channels = out_channels
+
+    def reset_parameters(self) -> None:
+        self.trans_conv.reset_parameters()
+        self.graph_conv.reset_parameters()
+        self.fc.reset_parameters()
+
+    def forward(
+        self,
+        x: Tensor,
+        edge_index: Tensor,
+        batch: Optional[Tensor],
+    ) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): The input node features.
+            edge_index (torch.Tensor or SparseTensor): The edge indices.
+            batch (torch.Tensor, optional): The batch vector
+                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
+                each element to a specific example.
+        """
+        x1 = self.trans_conv(x, batch)
+        x2 = self.graph_conv(x, edge_index)
+        if self.aggregate == 'add':
+            x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
+        else:
+            x = torch.cat((x1, x2), dim=1)
+        x = self.fc(x)
+        return F.log_softmax(x, dim=-1)
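
As with Polynormer above, a minimal smoke test, assuming the model is exported as `torch_geometric.nn.models.SGFormer`:

import torch
from torch_geometric.nn.models import SGFormer  # assumed export path

model = SGFormer(in_channels=16, hidden_channels=32, out_channels=4)

x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))
batch = torch.zeros(10, dtype=torch.long)

out = model(x, edge_index, batch)  # log-probabilities, shape [10, 4]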
torch_geometric/nn/models/signed_gcn.py
@@ -45,7 +45,7 @@ class SignedGCN(torch.nn.Module):
         self.conv1 = SignedConv(in_channels, hidden_channels // 2,
                                 first_aggr=True)
         self.convs = torch.nn.ModuleList()
-        for i in range(num_layers - 1):
+        for _ in range(num_layers - 1):
             self.convs.append(
                 SignedConv(hidden_channels // 2, hidden_channels // 2,
                            first_aggr=False))
torch_geometric/nn/models/visnet.py
@@ -11,7 +11,7 @@ from torch_geometric.utils import scatter
 
 
 class CosineCutoff(torch.nn.Module):
-    r"""Appies a cosine cutoff to the input distances.
+    r"""Applies a cosine cutoff to the input distances.
 
     .. math::
         \text{cutoffs} =
@@ -572,7 +572,7 @@ class ViS_MP(MessagePassing):
         d_ij: Tensor,
     ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
         r"""Computes the residual scalar and vector features of the nodes and
-        scalar featues of the edges.
+        scalar features of the edges.
 
         Args:
             x (torch.Tensor): The scalar features of the nodes.
torch_geometric/nn/norm/batch_norm.py
@@ -39,6 +39,8 @@ class BatchNorm(torch.nn.Module):
             with only a single element will work as during in evaluation.
             That is the running mean and variance will be used.
            Requires :obj:`track_running_stats=True`. (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -48,6 +50,7 @@ class BatchNorm(torch.nn.Module):
         affine: bool = True,
         track_running_stats: bool = True,
         allow_single_element: bool = False,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -56,7 +59,7 @@ class BatchNorm(torch.nn.Module):
                              "'track_running_stats' to be set to `True`")
 
         self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
-                                           track_running_stats)
+                                           track_running_stats, device=device)
         self.in_channels = in_channels
         self.allow_single_element = allow_single_element
 
@@ -114,6 +117,8 @@ class HeteroBatchNorm(torch.nn.Module):
             :obj:`False`, this module does not track such statistics and always
             uses batch statistics in both training and eval modes.
             (default: :obj:`True`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -123,6 +128,7 @@ class HeteroBatchNorm(torch.nn.Module):
         momentum: Optional[float] = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -134,17 +140,21 @@ class HeteroBatchNorm(torch.nn.Module):
         self.track_running_stats = track_running_stats
 
         if self.affine:
-            self.weight = Parameter(torch.empty(num_types, in_channels))
-            self.bias = Parameter(torch.empty(num_types, in_channels))
+            self.weight = Parameter(
+                torch.empty(num_types, in_channels, device=device))
+            self.bias = Parameter(
+                torch.empty(num_types, in_channels, device=device))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
 
         if self.track_running_stats:
-            self.register_buffer('running_mean',
-                                 torch.empty(num_types, in_channels))
-            self.register_buffer('running_var',
-                                 torch.empty(num_types, in_channels))
+            self.register_buffer(
+                'running_mean',
+                torch.empty(num_types, in_channels, device=device))
+            self.register_buffer(
+                'running_var',
+                torch.empty(num_types, in_channels, device=device))
             self.register_buffer('num_batches_tracked', torch.tensor(0))
         else:
             self.register_buffer('running_mean', None)
torch_geometric/nn/norm/diff_group_norm.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch import Tensor
 from torch.nn import BatchNorm1d, Linear
@@ -39,6 +41,8 @@ class DiffGroupNorm(torch.nn.Module):
             :obj:`False`, this module does not track such statistics and always
             uses batch statistics in both training and eval modes.
             (default: :obj:`True`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -49,6 +53,7 @@ class DiffGroupNorm(torch.nn.Module):
         momentum: float = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        device: Optional[torch.device] = None,
     ):
         super().__init__()
 
@@ -56,9 +61,9 @@ class DiffGroupNorm(torch.nn.Module):
         self.groups = groups
         self.lamda = lamda
 
-        self.lin = Linear(in_channels, groups, bias=False)
+        self.lin = Linear(in_channels, groups, bias=False, device=device)
         self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine,
-                                track_running_stats)
+                                track_running_stats, device=device)
 
         self.reset_parameters()
 
torch_geometric/nn/norm/graph_norm.py
@@ -26,16 +26,21 @@ class GraphNorm(torch.nn.Module):
         in_channels (int): Size of each input sample.
         eps (float, optional): A value added to the denominator for numerical
             stability. (default: :obj:`1e-5`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
-    def __init__(self, in_channels: int, eps: float = 1e-5):
+    def __init__(self, in_channels: int, eps: float = 1e-5,
+                 device: Optional[torch.device] = None):
         super().__init__()
 
         self.in_channels = in_channels
         self.eps = eps
 
-        self.weight = torch.nn.Parameter(torch.empty(in_channels))
-        self.bias = torch.nn.Parameter(torch.empty(in_channels))
-        self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))
+        self.weight = torch.nn.Parameter(
+            torch.empty(in_channels, device=device))
+        self.bias = torch.nn.Parameter(torch.empty(in_channels, device=device))
+        self.mean_scale = torch.nn.Parameter(
+            torch.empty(in_channels, device=device))
 
         self.reset_parameters()
 
torch_geometric/nn/norm/instance_norm.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.modules.instancenorm import _InstanceNorm
@@ -36,6 +37,8 @@ class InstanceNorm(_InstanceNorm):
             :obj:`False`, this module does not track such statistics and always
             uses instance statistics in both training and eval modes.
             (default: :obj:`False`)
+        device (torch.device, optional): The device to use for the module.
+            (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -44,9 +47,10 @@ class InstanceNorm(_InstanceNorm):
         momentum: float = 0.1,
         affine: bool = False,
         track_running_stats: bool = False,
+        device: Optional[torch.device] = None,
     ):
         super().__init__(in_channels, eps, momentum, affine,
-                         track_running_stats)
+                         track_running_stats, device=device)
 
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""