pyg-nightly 2.6.0.dev20240704__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of pyg-nightly might be problematic.
Files changed (268)
  1. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +81 -58
  2. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +265 -221
  3. {pyg_nightly-2.6.0.dev20240704.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
  4. pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
  5. torch_geometric/__init__.py +34 -1
  6. torch_geometric/_compile.py +11 -3
  7. torch_geometric/_onnx.py +228 -0
  8. torch_geometric/config_mixin.py +8 -3
  9. torch_geometric/config_store.py +1 -1
  10. torch_geometric/contrib/__init__.py +1 -1
  11. torch_geometric/contrib/explain/pgm_explainer.py +1 -1
  12. torch_geometric/data/__init__.py +19 -1
  13. torch_geometric/data/batch.py +2 -2
  14. torch_geometric/data/collate.py +1 -3
  15. torch_geometric/data/data.py +110 -6
  16. torch_geometric/data/database.py +19 -5
  17. torch_geometric/data/dataset.py +14 -9
  18. torch_geometric/data/extract.py +1 -1
  19. torch_geometric/data/feature_store.py +17 -22
  20. torch_geometric/data/graph_store.py +3 -2
  21. torch_geometric/data/hetero_data.py +139 -7
  22. torch_geometric/data/hypergraph_data.py +2 -2
  23. torch_geometric/data/in_memory_dataset.py +2 -2
  24. torch_geometric/data/lightning/datamodule.py +42 -28
  25. torch_geometric/data/storage.py +9 -1
  26. torch_geometric/datasets/__init__.py +20 -1
  27. torch_geometric/datasets/actor.py +7 -9
  28. torch_geometric/datasets/airfrans.py +17 -20
  29. torch_geometric/datasets/airports.py +8 -10
  30. torch_geometric/datasets/amazon.py +8 -11
  31. torch_geometric/datasets/amazon_book.py +8 -9
  32. torch_geometric/datasets/amazon_products.py +7 -9
  33. torch_geometric/datasets/aminer.py +8 -9
  34. torch_geometric/datasets/aqsol.py +10 -13
  35. torch_geometric/datasets/attributed_graph_dataset.py +8 -10
  36. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  37. torch_geometric/datasets/ba_shapes.py +5 -6
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/city.py +157 -0
  40. torch_geometric/datasets/dbp15k.py +1 -1
  41. torch_geometric/datasets/gdelt_lite.py +3 -2
  42. torch_geometric/datasets/ged_dataset.py +3 -2
  43. torch_geometric/datasets/git_mol_dataset.py +263 -0
  44. torch_geometric/datasets/gnn_benchmark_dataset.py +3 -2
  45. torch_geometric/datasets/hgb_dataset.py +2 -2
  46. torch_geometric/datasets/hm.py +1 -1
  47. torch_geometric/datasets/instruct_mol_dataset.py +134 -0
  48. torch_geometric/datasets/linkx_dataset.py +4 -3
  49. torch_geometric/datasets/lrgb.py +3 -5
  50. torch_geometric/datasets/malnet_tiny.py +2 -1
  51. torch_geometric/datasets/md17.py +3 -3
  52. torch_geometric/datasets/medshapenet.py +145 -0
  53. torch_geometric/datasets/mnist_superpixels.py +2 -3
  54. torch_geometric/datasets/modelnet.py +1 -1
  55. torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
  56. torch_geometric/datasets/molecule_net.py +3 -2
  57. torch_geometric/datasets/neurograph.py +1 -3
  58. torch_geometric/datasets/ogb_mag.py +1 -1
  59. torch_geometric/datasets/opf.py +19 -5
  60. torch_geometric/datasets/pascal_pf.py +1 -1
  61. torch_geometric/datasets/pcqm4m.py +2 -1
  62. torch_geometric/datasets/ppi.py +2 -1
  63. torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
  64. torch_geometric/datasets/qm7.py +1 -1
  65. torch_geometric/datasets/qm9.py +3 -2
  66. torch_geometric/datasets/shrec2016.py +2 -2
  67. torch_geometric/datasets/snap_dataset.py +8 -4
  68. torch_geometric/datasets/tag_dataset.py +462 -0
  69. torch_geometric/datasets/teeth3ds.py +269 -0
  70. torch_geometric/datasets/web_qsp_dataset.py +342 -0
  71. torch_geometric/datasets/wikics.py +2 -1
  72. torch_geometric/datasets/wikidata.py +2 -1
  73. torch_geometric/deprecation.py +1 -1
  74. torch_geometric/distributed/__init__.py +13 -0
  75. torch_geometric/distributed/dist_loader.py +2 -2
  76. torch_geometric/distributed/local_feature_store.py +3 -2
  77. torch_geometric/distributed/local_graph_store.py +2 -1
  78. torch_geometric/distributed/partition.py +9 -8
  79. torch_geometric/distributed/rpc.py +3 -3
  80. torch_geometric/edge_index.py +35 -22
  81. torch_geometric/explain/algorithm/attention_explainer.py +219 -29
  82. torch_geometric/explain/algorithm/base.py +2 -2
  83. torch_geometric/explain/algorithm/captum.py +1 -1
  84. torch_geometric/explain/algorithm/captum_explainer.py +2 -1
  85. torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
  86. torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
  87. torch_geometric/explain/algorithm/pg_explainer.py +305 -47
  88. torch_geometric/explain/explainer.py +2 -2
  89. torch_geometric/explain/explanation.py +89 -5
  90. torch_geometric/explain/metric/faithfulness.py +1 -1
  91. torch_geometric/graphgym/checkpoint.py +2 -1
  92. torch_geometric/graphgym/config.py +3 -2
  93. torch_geometric/graphgym/imports.py +15 -4
  94. torch_geometric/graphgym/logger.py +1 -1
  95. torch_geometric/graphgym/loss.py +1 -1
  96. torch_geometric/graphgym/models/encoder.py +2 -2
  97. torch_geometric/graphgym/models/layer.py +1 -1
  98. torch_geometric/graphgym/utils/comp_budget.py +4 -3
  99. torch_geometric/hash_tensor.py +798 -0
  100. torch_geometric/index.py +16 -7
  101. torch_geometric/inspector.py +6 -2
  102. torch_geometric/io/fs.py +27 -0
  103. torch_geometric/io/tu.py +2 -3
  104. torch_geometric/llm/__init__.py +9 -0
  105. torch_geometric/llm/large_graph_indexer.py +741 -0
  106. torch_geometric/llm/models/__init__.py +23 -0
  107. torch_geometric/llm/models/g_retriever.py +251 -0
  108. torch_geometric/llm/models/git_mol.py +336 -0
  109. torch_geometric/llm/models/glem.py +397 -0
  110. torch_geometric/llm/models/llm.py +470 -0
  111. torch_geometric/llm/models/llm_judge.py +158 -0
  112. torch_geometric/llm/models/molecule_gpt.py +222 -0
  113. torch_geometric/llm/models/protein_mpnn.py +333 -0
  114. torch_geometric/llm/models/sentence_transformer.py +188 -0
  115. torch_geometric/llm/models/txt2kg.py +353 -0
  116. torch_geometric/llm/models/vision_transformer.py +38 -0
  117. torch_geometric/llm/rag_loader.py +154 -0
  118. torch_geometric/llm/utils/__init__.py +10 -0
  119. torch_geometric/llm/utils/backend_utils.py +443 -0
  120. torch_geometric/llm/utils/feature_store.py +169 -0
  121. torch_geometric/llm/utils/graph_store.py +199 -0
  122. torch_geometric/llm/utils/vectorrag.py +125 -0
  123. torch_geometric/loader/cluster.py +6 -5
  124. torch_geometric/loader/graph_saint.py +2 -1
  125. torch_geometric/loader/ibmb_loader.py +4 -4
  126. torch_geometric/loader/link_loader.py +1 -1
  127. torch_geometric/loader/link_neighbor_loader.py +2 -1
  128. torch_geometric/loader/mixin.py +6 -5
  129. torch_geometric/loader/neighbor_loader.py +1 -1
  130. torch_geometric/loader/neighbor_sampler.py +2 -2
  131. torch_geometric/loader/prefetch.py +4 -3
  132. torch_geometric/loader/temporal_dataloader.py +2 -2
  133. torch_geometric/loader/utils.py +10 -10
  134. torch_geometric/metrics/__init__.py +23 -2
  135. torch_geometric/metrics/link_pred.py +755 -85
  136. torch_geometric/nn/__init__.py +1 -0
  137. torch_geometric/nn/aggr/__init__.py +2 -0
  138. torch_geometric/nn/aggr/base.py +1 -1
  139. torch_geometric/nn/aggr/equilibrium.py +1 -1
  140. torch_geometric/nn/aggr/fused.py +1 -1
  141. torch_geometric/nn/aggr/patch_transformer.py +149 -0
  142. torch_geometric/nn/aggr/set_transformer.py +1 -1
  143. torch_geometric/nn/aggr/utils.py +9 -4
  144. torch_geometric/nn/attention/__init__.py +9 -1
  145. torch_geometric/nn/attention/polynormer.py +107 -0
  146. torch_geometric/nn/attention/qformer.py +71 -0
  147. torch_geometric/nn/attention/sgformer.py +99 -0
  148. torch_geometric/nn/conv/__init__.py +2 -0
  149. torch_geometric/nn/conv/appnp.py +1 -1
  150. torch_geometric/nn/conv/collect.jinja +6 -3
  151. torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
  152. torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
  153. torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
  154. torch_geometric/nn/conv/dna_conv.py +1 -1
  155. torch_geometric/nn/conv/eg_conv.py +7 -7
  156. torch_geometric/nn/conv/gat_conv.py +33 -4
  157. torch_geometric/nn/conv/gatv2_conv.py +35 -4
  158. torch_geometric/nn/conv/gen_conv.py +1 -1
  159. torch_geometric/nn/conv/general_conv.py +1 -1
  160. torch_geometric/nn/conv/gravnet_conv.py +2 -1
  161. torch_geometric/nn/conv/hetero_conv.py +3 -2
  162. torch_geometric/nn/conv/meshcnn_conv.py +487 -0
  163. torch_geometric/nn/conv/message_passing.py +6 -5
  164. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  165. torch_geometric/nn/conv/rgcn_conv.py +2 -1
  166. torch_geometric/nn/conv/sg_conv.py +1 -1
  167. torch_geometric/nn/conv/spline_conv.py +2 -1
  168. torch_geometric/nn/conv/ssg_conv.py +1 -1
  169. torch_geometric/nn/conv/transformer_conv.py +5 -3
  170. torch_geometric/nn/data_parallel.py +5 -4
  171. torch_geometric/nn/dense/linear.py +5 -24
  172. torch_geometric/nn/encoding.py +17 -3
  173. torch_geometric/nn/fx.py +17 -15
  174. torch_geometric/nn/model_hub.py +5 -16
  175. torch_geometric/nn/models/__init__.py +11 -0
  176. torch_geometric/nn/models/attentive_fp.py +1 -1
  177. torch_geometric/nn/models/attract_repel.py +148 -0
  178. torch_geometric/nn/models/basic_gnn.py +2 -1
  179. torch_geometric/nn/models/captum.py +1 -1
  180. torch_geometric/nn/models/deep_graph_infomax.py +1 -1
  181. torch_geometric/nn/models/dimenet.py +2 -2
  182. torch_geometric/nn/models/dimenet_utils.py +4 -2
  183. torch_geometric/nn/models/gpse.py +1083 -0
  184. torch_geometric/nn/models/graph_unet.py +13 -4
  185. torch_geometric/nn/models/lpformer.py +783 -0
  186. torch_geometric/nn/models/metapath2vec.py +1 -1
  187. torch_geometric/nn/models/mlp.py +4 -2
  188. torch_geometric/nn/models/node2vec.py +1 -1
  189. torch_geometric/nn/models/polynormer.py +206 -0
  190. torch_geometric/nn/models/rev_gnn.py +3 -3
  191. torch_geometric/nn/models/schnet.py +2 -1
  192. torch_geometric/nn/models/sgformer.py +219 -0
  193. torch_geometric/nn/models/signed_gcn.py +1 -1
  194. torch_geometric/nn/models/visnet.py +2 -2
  195. torch_geometric/nn/norm/batch_norm.py +17 -7
  196. torch_geometric/nn/norm/diff_group_norm.py +7 -2
  197. torch_geometric/nn/norm/graph_norm.py +9 -4
  198. torch_geometric/nn/norm/instance_norm.py +5 -1
  199. torch_geometric/nn/norm/layer_norm.py +15 -7
  200. torch_geometric/nn/norm/msg_norm.py +8 -2
  201. torch_geometric/nn/pool/__init__.py +15 -9
  202. torch_geometric/nn/pool/cluster_pool.py +144 -0
  203. torch_geometric/nn/pool/connect/base.py +1 -3
  204. torch_geometric/nn/pool/edge_pool.py +1 -1
  205. torch_geometric/nn/pool/knn.py +13 -10
  206. torch_geometric/nn/pool/select/base.py +1 -4
  207. torch_geometric/nn/summary.py +1 -1
  208. torch_geometric/nn/to_hetero_module.py +4 -3
  209. torch_geometric/nn/to_hetero_transformer.py +3 -3
  210. torch_geometric/nn/to_hetero_with_bases_transformer.py +5 -5
  211. torch_geometric/profile/__init__.py +2 -0
  212. torch_geometric/profile/nvtx.py +66 -0
  213. torch_geometric/profile/profiler.py +18 -9
  214. torch_geometric/profile/utils.py +20 -5
  215. torch_geometric/sampler/__init__.py +2 -1
  216. torch_geometric/sampler/base.py +337 -8
  217. torch_geometric/sampler/hgt_sampler.py +11 -1
  218. torch_geometric/sampler/neighbor_sampler.py +298 -25
  219. torch_geometric/sampler/utils.py +93 -5
  220. torch_geometric/testing/__init__.py +4 -0
  221. torch_geometric/testing/decorators.py +35 -5
  222. torch_geometric/testing/distributed.py +1 -1
  223. torch_geometric/transforms/__init__.py +4 -0
  224. torch_geometric/transforms/add_gpse.py +49 -0
  225. torch_geometric/transforms/add_metapaths.py +10 -8
  226. torch_geometric/transforms/add_positional_encoding.py +2 -2
  227. torch_geometric/transforms/base_transform.py +2 -1
  228. torch_geometric/transforms/delaunay.py +65 -15
  229. torch_geometric/transforms/face_to_edge.py +32 -3
  230. torch_geometric/transforms/gdc.py +8 -9
  231. torch_geometric/transforms/largest_connected_components.py +1 -1
  232. torch_geometric/transforms/mask.py +5 -1
  233. torch_geometric/transforms/node_property_split.py +1 -1
  234. torch_geometric/transforms/normalize_features.py +3 -3
  235. torch_geometric/transforms/pad.py +1 -1
  236. torch_geometric/transforms/random_link_split.py +1 -1
  237. torch_geometric/transforms/remove_duplicated_edges.py +4 -2
  238. torch_geometric/transforms/remove_self_loops.py +36 -0
  239. torch_geometric/transforms/rooted_subgraph.py +1 -1
  240. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  241. torch_geometric/transforms/virtual_node.py +2 -1
  242. torch_geometric/typing.py +82 -17
  243. torch_geometric/utils/__init__.py +6 -1
  244. torch_geometric/utils/_lexsort.py +0 -9
  245. torch_geometric/utils/_negative_sampling.py +28 -13
  246. torch_geometric/utils/_normalize_edge_index.py +46 -0
  247. torch_geometric/utils/_scatter.py +126 -164
  248. torch_geometric/utils/_sort_edge_index.py +0 -2
  249. torch_geometric/utils/_spmm.py +16 -14
  250. torch_geometric/utils/_subgraph.py +4 -0
  251. torch_geometric/utils/_tree_decomposition.py +1 -1
  252. torch_geometric/utils/_trim_to_layer.py +2 -2
  253. torch_geometric/utils/augmentation.py +1 -1
  254. torch_geometric/utils/convert.py +17 -10
  255. torch_geometric/utils/cross_entropy.py +34 -13
  256. torch_geometric/utils/embedding.py +91 -2
  257. torch_geometric/utils/geodesic.py +28 -25
  258. torch_geometric/utils/influence.py +279 -0
  259. torch_geometric/utils/map.py +14 -10
  260. torch_geometric/utils/nested.py +1 -1
  261. torch_geometric/utils/smiles.py +3 -3
  262. torch_geometric/utils/sparse.py +32 -24
  263. torch_geometric/visualization/__init__.py +2 -1
  264. torch_geometric/visualization/graph.py +250 -5
  265. torch_geometric/warnings.py +11 -2
  266. torch_geometric/nn/nlp/__init__.py +0 -7
  267. torch_geometric/nn/nlp/llm.py +0 -283
  268. torch_geometric/nn/nlp/sentence_transformer.py +0 -94
torch_geometric/nn/conv/meshcnn_conv.py
@@ -0,0 +1,487 @@
+ # The below is to suppress the warning on torch.nn.conv.MeshCNNConv::update
+ # pyright: reportIncompatibleMethodOverride=false
+ import warnings
+ from typing import Optional
+
+ import torch
+ from torch.nn import Linear, Module, ModuleList
+
+ from torch_geometric.nn.conv import MessagePassing
+ from torch_geometric.typing import Tensor
+
+
+ class MeshCNNConv(MessagePassing):
+     r"""The convolutional layer introduced by the paper
+     `"MeshCNN: A Network With An Edge" <https://arxiv.org/abs/1809.05910>`_.
+
+     Recall that, given a set of categories :math:`C`, MeshCNN is a
+     function that takes as its input a triangular mesh
+     :math:`\mathcal{m} = (V, F) \in \mathbb{R}^{|V| \times 3} \times
+     \{0,...,|V|-1\}^{3 \times |F|}` and returns as its output a
+     :math:`|C|`-dimensional vector, whose :math:`i` th component denotes
+     the probability of the input mesh belonging to category
+     :math:`c_i \in C`.
+
+     Let :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}`
+     denote the output of the prior (:math:`k` th) layer of our neural
+     network. The :math:`i` th row of :math:`X^{(k)}` is a
+     :math:`\text{Dim-Out}(k)`-dimensional vector that represents the
+     features computed by the :math:`k` th layer for edge :math:`e_i` of
+     the input mesh :math:`\mathcal{m}`. Let
+     :math:`A \in \{0, ..., |E|-1\}^{2 \times 4*|E|}` denote the
+     *edge adjacency* matrix of our input mesh :math:`\mathcal{m}`.
+     The :math:`j` th column of :math:`A` holds a pair of indices
+     :math:`k,l \in \{0,...,|E|-1\}`, which means that edge :math:`e_k`
+     is adjacent to edge :math:`e_l` in our input mesh
+     :math:`\mathcal{m}`. The definition of edge adjacency in a
+     triangular mesh is illustrated in Figure 1. In a triangular mesh,
+     each edge :math:`e_i` is expected to be adjacent to exactly
+     :math:`4` neighboring edges, hence the number of columns of
+     :math:`A`: :math:`4*|E|`.
+     We write *the neighborhood* of edge :math:`e_i` as
+     :math:`\mathcal{N}(i) = (a(i), b(i), c(i), d(i))` where
+
+     1. :math:`a(i)` denotes the index of the *first* counter-clockwise
+        edge of the face *above* :math:`e_i`.
+
+     2. :math:`b(i)` denotes the index of the *second* counter-clockwise
+        edge of the face *above* :math:`e_i`.
+
+     3. :math:`c(i)` denotes the index of the *first* counter-clockwise
+        edge of the face *below* :math:`e_i`.
+
+     4. :math:`d(i)` denotes the index of the *second* counter-clockwise
+        edge of the face *below* :math:`e_i`.
+
+     .. figure:: ../_figures/meshcnn_edge_adjacency.svg
+         :align: center
+         :width: 80%
+
+         **Figure 1:** The neighbors of edge :math:`\mathbf{e_1}` are
+         :math:`\mathbf{e_2}, \mathbf{e_3}, \mathbf{e_4}` and
+         :math:`\mathbf{e_5}`, respectively. We write this as
+         :math:`\mathcal{N}(1) = (a(1), b(1), c(1), d(1)) = (2, 3, 4, 5)`.
+
+     Because of this ordering constraint, :obj:`MeshCNNConv` **requires
+     that the columns of** :math:`A` **be ordered in the following way**:
+
+     .. math::
+         &A[:,0] = (0, \text{The index of the "a" edge for edge } 0) \\
+         &A[:,1] = (0, \text{The index of the "b" edge for edge } 0) \\
+         &A[:,2] = (0, \text{The index of the "c" edge for edge } 0) \\
+         &A[:,3] = (0, \text{The index of the "d" edge for edge } 0) \\
+         \vdots \\
+         &A[:,4*|E|-4] = \bigl(|E|-1, a\bigl(|E|-1\bigr)\bigr) \\
+         &A[:,4*|E|-3] = \bigl(|E|-1, b\bigl(|E|-1\bigr)\bigr) \\
+         &A[:,4*|E|-2] = \bigl(|E|-1, c\bigl(|E|-1\bigr)\bigr) \\
+         &A[:,4*|E|-1] = \bigl(|E|-1, d\bigl(|E|-1\bigr)\bigr)
+
+     Stated a bit more compactly, for every edge :math:`e_i` in the input
+     mesh, :math:`A` should have the following entries:
+
+     .. math::
+         A[:, 4*i] &= (i, a(i)) \\
+         A[:, 4*i + 1] &= (i, b(i)) \\
+         A[:, 4*i + 2] &= (i, c(i)) \\
+         A[:, 4*i + 3] &= (i, d(i))
+
+     To summarize so far, we have defined 3 things:
+
+     1. The activation of the prior (:math:`k` th) layer,
+        :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}`.
+
+     2. The edge adjacency matrix
+        :math:`A \in \{0,...,|E|-1\}^{2 \times 4*|E|}` and the definition
+        of edge adjacency.
+
+     3. The way the columns of :math:`A` must be ordered.
+
+     We are now finally able to define the :obj:`MeshCNNConv` class/layer.
+     In the following definition we assume :obj:`MeshCNNConv` is at the
+     :math:`k+1` th layer of our neural network.
+
+     The :obj:`MeshCNNConv` layer is a function,
+
+     .. math::
+         \text{MeshCNNConv}^{(k+1)}(X^{(k)}, A) = X^{(k+1)},
+
+     that, given the prior layer's output
+     :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}`
+     and the edge adjacency matrix :math:`A` of the input mesh (graph)
+     :math:`\mathcal{m}`, returns a new edge feature tensor
+     :math:`X^{(k+1)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k+1)}`,
+     where the :math:`i` th row of :math:`X^{(k+1)}`, denoted by
+     :math:`x^{(k+1)}_i`, represents the
+     :math:`\text{Dim-Out}(k+1)`-dimensional feature vector of edge
+     :math:`e_i`, **and is defined as follows**:
+
+     .. math::
+         x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\
+         &+ W^{(k+1)}_1 \bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \bigr| \\
+         &+ W^{(k+1)}_2 \bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \bigr) \\
+         &+ W^{(k+1)}_3 \bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \bigr| \\
+         &+ W^{(k+1)}_4 \bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \bigr).
+
+     :math:`W_0^{(k+1)},W_1^{(k+1)},W_2^{(k+1)},W_3^{(k+1)},W_4^{(k+1)}
+     \in \mathbb{R}^{\text{Dim-Out}(k+1) \times \text{Dim-Out}(k)}`
+     are trainable linear functions (i.e. "the weights" of this layer).
+     :math:`x^{(k)}_i` is the :math:`\text{Dim-Out}(k)`-dimensional
+     feature vector of edge :math:`e_i` computed by the prior
+     (:math:`k` th) layer.
+     :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`, and
+     :math:`x^{(k)}_{d(i)}` are the :math:`\text{Dim-Out}(k)`-dimensional
+     feature vectors, computed in the :math:`k` th layer, that are
+     associated with the :math:`4` neighboring edges of :math:`e_i`.
+
+
+     Args:
+         in_channels (int): Corresponds to :math:`\text{Dim-Out}(k)` in
+             the above overview. This represents the output dimension of
+             the prior layer. For the given input mesh
+             :math:`\mathcal{m} = (V, F)`, the prior layer is expected to
+             output a
+             :math:`X \in \mathbb{R}^{|E| \times \textit{in_channels}}`
+             feature matrix. Assuming the instance of this class is
+             situated at layer :math:`k+1`, we write that
+             :math:`X^{(k)} \in
+             \mathbb{R}^{|E| \times \textit{in_channels}}`.
+         out_channels (int): Corresponds to :math:`\text{Dim-Out}(k+1)`
+             in the above overview. This represents the output dimension
+             of this layer. Assuming the instance of this class is
+             situated at layer :math:`k+1`, we write that
+             :math:`X^{(k+1)} \in
+             \mathbb{R}^{|E| \times \textit{out_channels}}`.
+         kernels (torch.nn.ModuleList, optional): A list of length 5,
+             where each element is a :class:`torch.nn.Module` (i.e. a
+             neural network), each of which MUST take as input a vector
+             of dimension :obj:`in_channels` and return a vector of
+             dimension :obj:`out_channels`. In particular,
+             :obj:`kernels[0]` is :math:`W^{(k+1)}_0` in the above
+             overview (see :obj:`MeshCNNConv`), :obj:`kernels[1]` is
+             :math:`W^{(k+1)}_1`, :obj:`kernels[2]` is
+             :math:`W^{(k+1)}_2`, :obj:`kernels[3]` is
+             :math:`W^{(k+1)}_3`, and :obj:`kernels[4]` is
+             :math:`W^{(k+1)}_4`. Note that this input is optional, in
+             which case each of the 5 kernels will be a linear layer
+             :class:`torch.nn.Linear`, correctly configured to take as
+             input :attr:`in_channels`-dimensional vectors and return a
+             vector of dimension :attr:`out_channels`.
+
+     Discussion:
+         The key difference that separates :obj:`MeshCNNConv` from a
+         traditional message passing graph neural network is that
+         :obj:`MeshCNNConv` requires the set of neighbors of a node,
+         :math:`\mathcal{N}(u) = (v_1, v_2, ...)`, to *be an ordered set*
+         (i.e. a tuple). In fact, :obj:`MeshCNNConv` goes further,
+         requiring that :math:`\mathcal{N}(u)` always return an ordered
+         set of size :math:`4`. This is different from most message
+         passing graph neural networks, which assume that
+         :math:`\mathcal{N}(u) = \{v_1, v_2, ...\}` returns an
+         *unordered* set. This lends :obj:`MeshCNNConv` more expressive
+         power, at the cost of no longer being invariant to permutations
+         in :math:`\mathbb{S}_4`. Put more plainly, in traditional
+         message passing GNNs, the network is *unable* to distinguish
+         one neighboring node from another. In contrast, in
+         :obj:`MeshCNNConv`, each of the 4 neighbors has a "role", either
+         the "a", "b", "c", or "d" neighbor. We encode this fact by
+         requiring that :math:`\mathcal{N}` return the 4-tuple, where the
+         first component is the "a" neighbor, and so on.
+
+         To summarize this comparison, we may re-define
+         :obj:`MeshCNNConv` in terms of :math:`\text{UPDATE}` and
+         :math:`\text{AGGREGATE}` functions, which is a general way to
+         define a traditional GNN layer. If we let :math:`x_i^{(k+1)}`
+         denote the output of a GNN layer for node :math:`i` at layer
+         :math:`k+1`, and let :math:`\mathcal{N}(i)` denote the set of
+         nodes adjacent to node :math:`i`, then we can describe the
+         :math:`k+1` th layer of a traditional GNN as
+
+         .. math::
+             x_i^{(k+1)} = \text{UPDATE}^{(k+1)}\bigl(x^{(k)}_i,
+             \text{AGGREGATE}^{(k+1)}\bigl(\mathcal{N}(i)\bigr)\bigr).
+
+         Here, :math:`\text{UPDATE}^{(k+1)}` is a function of :math:`2`
+         :math:`\text{Dim-Out}(k)`-dimensional vectors that returns a
+         :math:`\text{Dim-Out}(k+1)`-dimensional vector. The
+         :math:`\text{AGGREGATE}^{(k+1)}` function is a function of an
+         *unordered set* of nodes that are neighbors of node :math:`i`,
+         as defined by :math:`\mathcal{N}(i)`. Usually the size of this
+         set varies across different nodes :math:`i`, and one of the most
+         basic examples of such a function is the "sum aggregation",
+         defined as :math:`\text{AGGREGATE}^{(k+1)}(\mathcal{N}(i)) =
+         \sum_{j \in \mathcal{N}(i)} x^{(k)}_j`. See
+         :class:`SumAggregation
+         <torch_geometric.nn.aggr.basic.SumAggregation>` for more.
+
+         In contrast, while :obj:`MeshCNNConv` 's :math:`\text{UPDATE}`
+         function follows a traditional GNN, its :math:`\text{AGGREGATE}`
+         is a function of a tuple (i.e. an ordered set) of neighbors
+         rather than an unordered set of neighbors. In particular, while
+         the :math:`\text{UPDATE}` function of :obj:`MeshCNNConv` for
+         :math:`e_i` is
+
+         .. math::
+             x_i^{(k+1)} = \text{UPDATE}^{(k+1)}(x_i^{(k)}, s_i^{(k+1)})
+             = W_0^{(k+1)}x_i^{(k)} + s_i^{(k+1)},
+
+         in contrast, :obj:`MeshCNNConv` 's :math:`\text{AGGREGATE}`
+         function is
+
+         .. math::
+             s_i^{(k+1)} = \text{AGGREGATE}^{(k+1)}(A, B, C, D)
+             &= W_1^{(k+1)}\bigl|A - C \bigr| \\
+             &+ W_2^{(k+1)}\bigl(A + C \bigr) \\
+             &+ W_3^{(k+1)}\bigl|B - D \bigr| \\
+             &+ W_4^{(k+1)}\bigl(B + D \bigr),
+
+         where :math:`A=x_{a(i)}^{(k)}, B=x_{b(i)}^{(k)},
+         C=x_{c(i)}^{(k)},` and :math:`D=x_{d(i)}^{(k)}`.
+
+     ..
+
+         The :math:`i` th row of :math:`V \in \mathbb{R}^{|V| \times 3}`
+         holds the cartesian :math:`xyz` coordinates for node
+         :math:`v_i` in the mesh, and the :math:`j` th column of
+         :math:`F \in \{0,...,|V|-1\}^{3 \times |F|}` holds the
+         :math:`3` indices :math:`(k,l,m)` that correspond to the
+         :math:`3` nodes :math:`(v_k, v_l, v_m)` that construct face
+         :math:`j` of the mesh.
+     """
+     def __init__(self, in_channels: int, out_channels: int,
+                  kernels: Optional[ModuleList] = None):
+         super().__init__(aggr='add')
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+
+         if kernels is None:
+             self.kernels = ModuleList(
+                 [Linear(in_channels, out_channels) for _ in range(5)])
+         else:
+             # Ensures kernels is properly formed, otherwise throws
+             # the appropriate error.
+             self._assert_kernels(kernels)
+             self.kernels = kernels
+
+     def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
+         r"""Forward pass.
+
+         Args:
+             x (torch.Tensor): :math:`X^{(k)} \in
+                 \mathbb{R}^{|E| \times \textit{in_channels}}`.
+                 The edge feature tensor returned by the prior layer
+                 (e.g. :math:`k`). The tensor is of shape
+                 :math:`|E| \times \text{Dim-Out}(k)`, or equivalently,
+                 :obj:`(|E|, self.in_channels)`.
+             edge_index (torch.Tensor):
+                 :math:`A \in \{0,...,|E|-1\}^{2 \times 4*|E|}`.
+                 The edge adjacency tensor of the network's input mesh
+                 :math:`\mathcal{m} = (V, F)`. The edge adjacency tensor
+                 **MUST** have the following form:
+
+                 .. math::
+                     &A[:,0] = (0,
+                     \text{The index of the "a" edge for edge } 0) \\
+                     &A[:,1] = (0,
+                     \text{The index of the "b" edge for edge } 0) \\
+                     &A[:,2] = (0,
+                     \text{The index of the "c" edge for edge } 0) \\
+                     &A[:,3] = (0,
+                     \text{The index of the "d" edge for edge } 0) \\
+                     \vdots \\
+                     &A[:,4*|E|-4] = \bigl(|E|-1, a\bigl(|E|-1\bigr)\bigr) \\
+                     &A[:,4*|E|-3] = \bigl(|E|-1, b\bigl(|E|-1\bigr)\bigr) \\
+                     &A[:,4*|E|-2] = \bigl(|E|-1, c\bigl(|E|-1\bigr)\bigr) \\
+                     &A[:,4*|E|-1] = \bigl(|E|-1, d\bigl(|E|-1\bigr)\bigr)
+
+                 See :obj:`MeshCNNConv` for what "the index of the
+                 'a' (b, c, d) edge for edge i" means, and also for the
+                 general definition of edge adjacency in MeshCNN. These
+                 definitions are also provided in the
+                 `paper <https://arxiv.org/abs/1809.05910>`_ itself.
+
+         Returns:
+             torch.Tensor:
+             :math:`X^{(k+1)} \in
+             \mathbb{R}^{|E| \times \textit{out_channels}}`.
+             The edge feature tensor for this (e.g. the :math:`k+1` th)
+             layer. The :math:`i` th row of :math:`X^{(k+1)}` is computed
+             according to the formula
+
+             .. math::
+                 x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\
+                 &+ W^{(k+1)}_1 \bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \bigr| \\
+                 &+ W^{(k+1)}_2 \bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \bigr) \\
+                 &+ W^{(k+1)}_3 \bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \bigr| \\
+                 &+ W^{(k+1)}_4 \bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \bigr),
+
+             where :math:`W_0^{(k+1)}, W_1^{(k+1)}, W_2^{(k+1)},
+             W_3^{(k+1)}, W_4^{(k+1)}
+             \in \mathbb{R}^{\text{Dim-Out}(k+1) \times \text{Dim-Out}(k)}`
+             are the trainable linear functions (i.e. the trainable
+             "weights") of this layer, and
+             :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`,
+             :math:`x^{(k)}_{d(i)}` are the
+             :math:`\text{Dim-Out}(k)`-dimensional edge feature vectors
+             computed by the prior (:math:`k` th) layer that are
+             associated with the :math:`4` neighboring edges of
+             :math:`e_i`.
+         """
+         return self.propagate(edge_index, x=x)
+
+     def message(self, x_j: Tensor) -> Tensor:
+         r"""The message passing step of :obj:`MeshCNNConv`.
+
+         Args:
+             x_j (torch.Tensor): A :obj:`[4*|E|, in_channels]` tensor.
+                 Its :math:`i` th row holds the previous-layer value
+                 stored by the source node of edge :math:`i`.
+
+         Returns:
+             A :obj:`[4*|E|, out_channels]` tensor, whose :math:`i` th
+             row is the value that the target node of edge :math:`i`
+             will receive.
+         """
+         # The following variable names are taken from the paper.
+         # MeshCNN computes the features associated with edge e by
+         # (|a - c|, a + c, |b - d|, b + d), where a, b, c, d are the
+         # neighboring edges of e: a is the first edge of the upper face,
+         # b the second edge of the upper face, c the first edge of the
+         # lower face, and d the second edge of the lower face of the
+         # input mesh.
+
+         # TODO: It is unclear if view is faster. If it is not,
+         # then we should prefer the strided method commented out below.
+
+         E4 = x_j.size(0)  # E4 = 4*|E|, i.e. num edges in the line graph
+         # Option 1
+         n_a = x_j[0::4]  # shape: |E| x in_channels
+         n_b = x_j[1::4]  # shape: |E| x in_channels
+         n_c = x_j[2::4]  # shape: |E| x in_channels
+         n_d = x_j[3::4]  # shape: |E| x in_channels
+         m = torch.empty(E4, self.out_channels, device=x_j.device,
+                         dtype=x_j.dtype)
+         m[0::4] = self.kernels[1](torch.abs(n_a - n_c))
+         m[1::4] = self.kernels[2](n_a + n_c)
+         m[2::4] = self.kernels[3](torch.abs(n_b - n_d))
+         m[3::4] = self.kernels[4](n_b + n_d)
+         return m
+
+         # Option 2
+         # E4, in_channels = x_j.size()
+         # E = E4 // 4
+         # x_j = x_j.view(E, 4, in_channels)  # (|E|, 4, in_channels)
+         # n_a, n_b, n_c, n_d = x_j.unbind(dim=1)  # each: (|E|, in_channels)
+         # m = torch.stack(
+         #     [
+         #         self.kernels[1]((n_a - n_c).abs()),
+         #         self.kernels[2](n_a + n_c),
+         #         self.kernels[3]((n_b - n_d).abs()),
+         #         self.kernels[4](n_b + n_d),
+         #     ],
+         #     dim=1)  # (|E|, 4, out_channels)
+         # return m.view(E4, self.out_channels)
+
+     def update(self, inputs: Tensor, x: Tensor) -> Tensor:
+         r"""The UPDATE step, in reference to the UPDATE and AGGREGATE
+         formulation of message passing convolution.
+
+         Args:
+             inputs (torch.Tensor): The :attr:`out_channels`-dimensional
+                 aggregated messages returned by the aggregation step.
+             x (torch.Tensor): :math:`X^{(k)}`. The original input to
+                 this layer.
+
+         Returns:
+             torch.Tensor: :math:`X^{(k+1)}`. The output of this layer,
+             which has shape :obj:`(|E|, out_channels)`.
+         """
+         return self.kernels[0](x) + inputs
+
+     def _assert_kernels(self, kernels: ModuleList):
+         r"""Ensures that :obj:`kernels` is a list of 5
+         :obj:`torch.nn.Module` modules (i.e. networks). In addition, it
+         also ensures that each network takes input of dimension
+         :attr:`in_channels` and returns output of dimension
+         :attr:`out_channels`. This method throws an error otherwise.
+
+         .. warning::
+             This method throws an error if :obj:`kernels` is not valid.
+             (Otherwise this method returns nothing.)
+         """
+         assert isinstance(kernels, ModuleList), \
+             f"Parameter 'kernels' must be a torch.nn.ModuleList with " \
+             f"5 members, but we got {type(kernels)}."
+         assert len(kernels) == 5, \
+             "Parameter 'kernels' must be a torch.nn.ModuleList with " \
+             "exactly 5 members."
+
+         for i, network in enumerate(kernels):
+             assert isinstance(network, Module), \
+                 f"kernels[{i}] must be a torch.nn.Module, got " \
+                 f"{type(network)}."
+             if not hasattr(network, "in_channels") and \
+                     not hasattr(network, "in_features"):
+                 warnings.warn(
+                     f"kernels[{i}] has neither attribute 'in_channels' "
+                     f"nor 'in_features'. The network must take as input "
+                     f"a {self.in_channels}-dimensional tensor.",
+                     stacklevel=2)
+             else:
+                 input_dimension = getattr(
+                     network, "in_channels",
+                     getattr(network, "in_features", None))
+                 assert input_dimension == self.in_channels, \
+                     f"The input dimension of the neural network in " \
+                     f"kernels[{i}] must be equal to 'in_channels', but " \
+                     f"input_dimension={input_dimension} and " \
+                     f"self.in_channels={self.in_channels}."
+
+             if not hasattr(network, "out_channels") and \
+                     not hasattr(network, "out_features"):
+                 warnings.warn(
+                     f"kernels[{i}] has neither attribute 'out_channels' "
+                     f"nor 'out_features'. The network must return a "
+                     f"{self.out_channels}-dimensional tensor.",
+                     stacklevel=2)
+             else:
+                 output_dimension = getattr(
+                     network, "out_channels",
+                     getattr(network, "out_features", None))
+                 assert output_dimension == self.out_channels, \
+                     f"The output dimension of the neural network in " \
+                     f"kernels[{i}] must be equal to 'out_channels', but " \
+                     f"output_dimension={output_dimension} and " \
+                     f"self.out_channels={self.out_channels}."
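
A minimal usage sketch of the new layer (not part of the diff). It assumes MeshCNNConv is importable from torch_geometric.nn.conv (the module added above) and, for brevity, fills the (a, b, c, d) neighbor indices randomly instead of deriving them from the mesh faces as the docstring describes:

import torch
from torch_geometric.nn.conv import MeshCNNConv  # assumed export location

E = 6  # number of mesh edges (e.g. a tetrahedron)
x = torch.randn(E, 8)  # 8-dimensional input feature per edge

# Build the (2, 4*E) edge adjacency tensor in the required column order:
# columns 4*i .. 4*i+3 pair edge i with its (a, b, c, d) neighbors.
# Hypothetical random neighbors; a real pipeline derives them from faces.
neighbors = torch.randint(0, E, (E, 4))
src = torch.arange(E).repeat_interleave(4)
edge_index = torch.stack([src, neighbors.reshape(-1)], dim=0)

conv = MeshCNNConv(in_channels=8, out_channels=16)
out = conv(x, edge_index)
print(out.shape)  # torch.Size([6, 16])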
torch_geometric/nn/conv/message_passing.py
@@ -204,7 +204,7 @@ class MessagePassing(torch.nn.Module):
      def _check_input(
          self,
          edge_index: Union[Tensor, SparseTensor],
-         size: Optional[Tuple[int, int]],
+         size: Optional[Tuple[Optional[int], Optional[int]]],
      ) -> List[Optional[int]]:

          if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
@@ -276,7 +276,7 @@ class MessagePassing(torch.nn.Module):
                      f"{index.min().item()}). Please ensure that all "
                      f"indices in 'edge_index' point to valid indices "
                      f"in the interval [0, {src.size(self.node_dim)}) in "
-                     f"your node feature matrix and try again.")
+                     f"your node feature matrix and try again.") from e

          if (index.numel() > 0 and index.max() >= src.size(self.node_dim)):
              raise IndexError(
@@ -285,7 +285,7 @@ class MessagePassing(torch.nn.Module):
                      f"{index.max().item()}). Please ensure that all "
                      f"indices in 'edge_index' point to valid indices "
                      f"in the interval [0, {src.size(self.node_dim)}) in "
-                     f"your node feature matrix and try again.")
+                     f"your node feature matrix and try again.") from e

          raise e

@@ -1029,6 +1029,7 @@ class MessagePassing(torch.nn.Module):
          :meth:`jittable` is deprecated and a no-op from :pyg:`PyG` 2.5
          onwards.
          """
-         warnings.warn(f"'{self.__class__.__name__}.jittable' is deprecated "
-                       f"and a no-op. Please remove its usage.")
+         warnings.warn(
+             f"'{self.__class__.__name__}.jittable' is deprecated "
+             f"and a no-op. Please remove its usage.", stacklevel=2)
          return self
torch_geometric/nn/conv/mixhop_conv.py
@@ -14,7 +14,7 @@ from torch_geometric.utils import spmm

  class MixHopConv(MessagePassing):
      r"""The Mix-Hop graph convolutional operator from the
-     `"MixHop: Higher-Order Graph Convolutional Architecturesvia Sparsified
+     `"MixHop: Higher-Order Graph Convolutional Architectures via Sparsified
      Neighborhood Mixing" <https://arxiv.org/abs/1905.00067>`_ paper.

      .. math::
torch_geometric/nn/conv/rgcn_conv.py
@@ -120,7 +120,8 @@ class RGCNConv(MessagePassing):
              in_channels = (in_channels, in_channels)
          self.in_channels_l = in_channels[0]

-         self._use_segment_matmul_heuristic_output: Optional[bool] = None
+         self._use_segment_matmul_heuristic_output: torch.jit.Attribute(
+             None, Optional[float])

          if num_bases is not None:
              self.weight = Parameter(
torch_geometric/nn/conv/sg_conv.py
@@ -90,7 +90,7 @@ class SGConv(MessagePassing):
                  edge_index, edge_weight, x.size(self.node_dim), False,
                  self.add_self_loops, self.flow, dtype=x.dtype)

-         for k in range(self.K):
+         for _ in range(self.K):
              # propagate_type: (x: Tensor, edge_weight: OptTensor)
              x = self.propagate(edge_index, x=x, edge_weight=edge_weight)
              if self.cached:
torch_geometric/nn/conv/spline_conv.py
@@ -132,7 +132,8 @@ class SplineConv(MessagePassing):
          if not x[0].is_cuda:
              warnings.warn(
                  'We do not recommend using the non-optimized CPU version of '
-                 '`SplineConv`. If possible, please move your data to GPU.')
+                 '`SplineConv`. If possible, please move your data to GPU.',
+                 stacklevel=2)

          # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
          out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
torch_geometric/nn/conv/ssg_conv.py
@@ -100,7 +100,7 @@ class SSGConv(MessagePassing):
                  self.add_self_loops, self.flow, dtype=x.dtype)

          h = x * self.alpha
-         for k in range(self.K):
+         for _ in range(self.K):
              # propagate_type: (x: Tensor, edge_weight: OptTensor)
              x = self.propagate(edge_index, x=x, edge_weight=edge_weight)
              h = h + (1 - self.alpha) / self.K * x
torch_geometric/nn/conv/transformer_conv.py
@@ -126,9 +126,11 @@ class TransformerConv(MessagePassing):
          if isinstance(in_channels, int):
              in_channels = (in_channels, in_channels)

-         self.lin_key = Linear(in_channels[0], heads * out_channels)
-         self.lin_query = Linear(in_channels[1], heads * out_channels)
-         self.lin_value = Linear(in_channels[0], heads * out_channels)
+         self.lin_key = Linear(in_channels[0], heads * out_channels, bias=bias)
+         self.lin_query = Linear(in_channels[1], heads * out_channels,
+                                 bias=bias)
+         self.lin_value = Linear(in_channels[0], heads * out_channels,
+                                 bias=bias)
          if edge_dim is not None:
              self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
          else:
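
For context (not part of the diff): this hunk threads the constructor's bias flag through to the key/query/value projections, which previously always used a bias. A quick sketch of the resulting behavior, assuming the documented TransformerConv constructor and that its internal Linear layers expose a bias attribute that is None when disabled:

import torch
from torch_geometric.nn import TransformerConv

conv = TransformerConv(in_channels=16, out_channels=8, heads=2, bias=False)
# After this change, the internal projections respect bias=False as well:
assert conv.lin_key.bias is None
assert conv.lin_query.bias is None
assert conv.lin_value.bias is None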
torch_geometric/nn/data_parallel.py
@@ -57,10 +57,11 @@ class DataParallel(torch.nn.DataParallel):
                   follow_batch=None, exclude_keys=None):
          super().__init__(module, device_ids, output_device)

-         warnings.warn("'DataParallel' is usually much slower than "
-                       "'DistributedDataParallel' even on a single machine. "
-                       "Please consider switching to 'DistributedDataParallel' "
-                       "for multi-GPU training.")
+         warnings.warn(
+             "'DataParallel' is usually much slower than "
+             "'DistributedDataParallel' even on a single machine. "
+             "Please consider switching to 'DistributedDataParallel' "
+             "for multi-GPU training.", stacklevel=2)

          self.src_device = torch.device(f'cuda:{self.device_ids[0]}')
          self.follow_batch = follow_batch or []
torch_geometric/nn/dense/linear.py
@@ -1,4 +1,3 @@
- import copy
  import math
  import sys
  import time
@@ -58,7 +57,7 @@ def reset_bias_(bias: Optional[Tensor], in_channels: int,


  class Linear(torch.nn.Module):
-     r"""Applies a linear tranformation to the incoming data.
+     r"""Applies a linear transformation to the incoming data.

      .. math::
          \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}
@@ -114,25 +113,6 @@ class Linear(torch.nn.Module):

          self.reset_parameters()

-     def __deepcopy__(self, memo):
-         # PyTorch<1.13 cannot handle deep copies of uninitialized parameters :(
-         # TODO Drop this code once PyTorch 1.12 is no longer supported.
-         out = Linear(
-             self.in_channels,
-             self.out_channels,
-             self.bias is not None,
-             self.weight_initializer,
-             self.bias_initializer,
-         ).to(self.weight.device)
-
-         if self.in_channels > 0:
-             out.weight = copy.deepcopy(self.weight, memo)
-
-         if self.bias is not None:
-             out.bias = copy.deepcopy(self.bias, memo)
-
-         return out
-
      def reset_parameters(self):
          r"""Resets all learnable parameters of the module."""
          reset_weight_(self.weight, self.in_channels, self.weight_initializer)
@@ -192,7 +172,7 @@ class Linear(torch.nn.Module):


  class HeteroLinear(torch.nn.Module):
-     r"""Applies separate linear tranformations to the incoming data according
+     r"""Applies separate linear transformations to the incoming data according
      to types.

      For type :math:`\kappa`, it computes
@@ -365,7 +345,8 @@ class HeteroLinear(torch.nn.Module):


  class HeteroDictLinear(torch.nn.Module):
-     r"""Applies separate linear tranformations to the incoming data dictionary.
+     r"""Applies separate linear transformations to the incoming data
+     dictionary.

      For key :math:`\kappa`, it computes
@@ -479,7 +460,7 @@ class HeteroDictLinear(torch.nn.Module):
              lin = self.lins[key]
              if is_uninitialized_parameter(lin.weight):
                  self.lins[key].initialize_parameters(None, x)
-                 self.reset_parameters()
+                 self.lins[key].reset_parameters()
          self._hook.remove()
          self.in_channels = {key: x.size(-1) for key, x in input[0].items()}
          delattr(self, '_hook')
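
For context (not part of the diff): this last hunk fixes the lazy-initialization hook so that only the sub-module that was just materialized is reset, rather than re-resetting the weights of every key. A minimal sketch of the lazy path it touches, assuming the standard HeteroDictLinear API where in_channels=-1 defers weight creation until the first forward pass:

import torch
from torch_geometric.nn import HeteroDictLinear

# in_channels=-1 defers per-key weight creation until the first forward:
lin = HeteroDictLinear(in_channels=-1, out_channels=16,
                       types=['paper', 'author'])
x_dict = {
    'paper': torch.randn(4, 32),   # inferred in_channels: 32
    'author': torch.randn(3, 8),   # inferred in_channels: 8
}
out = lin(x_dict)  # the initialization hook patched above runs here
print(out['paper'].shape, out['author'].shape)
# torch.Size([4, 16]) torch.Size([3, 16])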