pyg-nightly 2.7.0.dev20241009__py3-none-any.whl → 2.8.0.dev20251207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pyg-nightly might be problematic.
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/METADATA +77 -53
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/RECORD +226 -189
- {pyg_nightly-2.7.0.dev20241009.dist-info → pyg_nightly-2.8.0.dev20251207.dist-info}/WHEEL +1 -1
- pyg_nightly-2.8.0.dev20251207.dist-info/licenses/LICENSE +19 -0
- torch_geometric/__init__.py +14 -2
- torch_geometric/_compile.py +9 -3
- torch_geometric/_onnx.py +214 -0
- torch_geometric/config_mixin.py +5 -3
- torch_geometric/config_store.py +1 -1
- torch_geometric/contrib/__init__.py +1 -1
- torch_geometric/contrib/explain/pgm_explainer.py +1 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +1 -3
- torch_geometric/data/data.py +109 -5
- torch_geometric/data/database.py +4 -0
- torch_geometric/data/dataset.py +14 -11
- torch_geometric/data/extract.py +1 -1
- torch_geometric/data/feature_store.py +17 -22
- torch_geometric/data/graph_store.py +3 -2
- torch_geometric/data/hetero_data.py +139 -7
- torch_geometric/data/hypergraph_data.py +2 -2
- torch_geometric/data/in_memory_dataset.py +2 -2
- torch_geometric/data/lightning/datamodule.py +42 -28
- torch_geometric/data/storage.py +9 -1
- torch_geometric/datasets/__init__.py +18 -1
- torch_geometric/datasets/actor.py +7 -9
- torch_geometric/datasets/airfrans.py +15 -17
- torch_geometric/datasets/airports.py +8 -10
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +8 -9
- torch_geometric/datasets/amazon_products.py +7 -9
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +8 -10
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/city.py +157 -0
- torch_geometric/datasets/dbp15k.py +1 -1
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/hgb_dataset.py +2 -2
- torch_geometric/datasets/hm.py +1 -1
- torch_geometric/datasets/instruct_mol_dataset.py +134 -0
- torch_geometric/datasets/md17.py +3 -3
- torch_geometric/datasets/medshapenet.py +145 -0
- torch_geometric/datasets/modelnet.py +1 -1
- torch_geometric/datasets/molecule_gpt_dataset.py +492 -0
- torch_geometric/datasets/molecule_net.py +3 -2
- torch_geometric/datasets/ppi.py +2 -1
- torch_geometric/datasets/protein_mpnn_dataset.py +451 -0
- torch_geometric/datasets/qm7.py +1 -1
- torch_geometric/datasets/qm9.py +1 -1
- torch_geometric/datasets/snap_dataset.py +8 -4
- torch_geometric/datasets/tag_dataset.py +462 -0
- torch_geometric/datasets/teeth3ds.py +269 -0
- torch_geometric/datasets/web_qsp_dataset.py +310 -209
- torch_geometric/datasets/wikics.py +2 -1
- torch_geometric/deprecation.py +1 -1
- torch_geometric/distributed/__init__.py +13 -0
- torch_geometric/distributed/dist_loader.py +2 -2
- torch_geometric/distributed/partition.py +2 -2
- torch_geometric/distributed/rpc.py +3 -3
- torch_geometric/edge_index.py +18 -14
- torch_geometric/explain/algorithm/attention_explainer.py +219 -29
- torch_geometric/explain/algorithm/base.py +2 -2
- torch_geometric/explain/algorithm/captum.py +1 -1
- torch_geometric/explain/algorithm/captum_explainer.py +2 -1
- torch_geometric/explain/algorithm/gnn_explainer.py +406 -69
- torch_geometric/explain/algorithm/graphmask_explainer.py +8 -8
- torch_geometric/explain/algorithm/pg_explainer.py +305 -47
- torch_geometric/explain/explainer.py +2 -2
- torch_geometric/explain/explanation.py +87 -3
- torch_geometric/explain/metric/faithfulness.py +1 -1
- torch_geometric/graphgym/config.py +3 -2
- torch_geometric/graphgym/imports.py +15 -4
- torch_geometric/graphgym/logger.py +1 -1
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/models/encoder.py +2 -2
- torch_geometric/graphgym/models/layer.py +1 -1
- torch_geometric/graphgym/utils/comp_budget.py +4 -3
- torch_geometric/hash_tensor.py +798 -0
- torch_geometric/index.py +14 -5
- torch_geometric/inspector.py +4 -0
- torch_geometric/io/fs.py +5 -4
- torch_geometric/llm/__init__.py +9 -0
- torch_geometric/llm/large_graph_indexer.py +741 -0
- torch_geometric/llm/models/__init__.py +23 -0
- torch_geometric/{nn → llm}/models/g_retriever.py +77 -45
- torch_geometric/llm/models/git_mol.py +336 -0
- torch_geometric/llm/models/glem.py +397 -0
- torch_geometric/{nn/nlp → llm/models}/llm.py +179 -31
- torch_geometric/llm/models/llm_judge.py +158 -0
- torch_geometric/llm/models/molecule_gpt.py +222 -0
- torch_geometric/llm/models/protein_mpnn.py +333 -0
- torch_geometric/llm/models/sentence_transformer.py +188 -0
- torch_geometric/llm/models/txt2kg.py +353 -0
- torch_geometric/llm/models/vision_transformer.py +38 -0
- torch_geometric/llm/rag_loader.py +154 -0
- torch_geometric/llm/utils/__init__.py +10 -0
- torch_geometric/llm/utils/backend_utils.py +443 -0
- torch_geometric/llm/utils/feature_store.py +169 -0
- torch_geometric/llm/utils/graph_store.py +199 -0
- torch_geometric/llm/utils/vectorrag.py +125 -0
- torch_geometric/loader/cluster.py +4 -4
- torch_geometric/loader/ibmb_loader.py +4 -4
- torch_geometric/loader/link_loader.py +1 -1
- torch_geometric/loader/link_neighbor_loader.py +2 -1
- torch_geometric/loader/mixin.py +6 -5
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +3 -2
- torch_geometric/loader/temporal_dataloader.py +2 -2
- torch_geometric/loader/utils.py +10 -10
- torch_geometric/metrics/__init__.py +14 -0
- torch_geometric/metrics/link_pred.py +745 -92
- torch_geometric/nn/__init__.py +1 -0
- torch_geometric/nn/aggr/base.py +1 -1
- torch_geometric/nn/aggr/equilibrium.py +1 -1
- torch_geometric/nn/aggr/fused.py +1 -1
- torch_geometric/nn/aggr/patch_transformer.py +8 -2
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/utils.py +9 -4
- torch_geometric/nn/attention/__init__.py +9 -1
- torch_geometric/nn/attention/polynormer.py +107 -0
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/attention/sgformer.py +99 -0
- torch_geometric/nn/conv/__init__.py +2 -0
- torch_geometric/nn/conv/appnp.py +1 -1
- torch_geometric/nn/conv/cugraph/gat_conv.py +8 -2
- torch_geometric/nn/conv/cugraph/rgcn_conv.py +3 -0
- torch_geometric/nn/conv/cugraph/sage_conv.py +3 -0
- torch_geometric/nn/conv/dna_conv.py +1 -1
- torch_geometric/nn/conv/eg_conv.py +7 -7
- torch_geometric/nn/conv/gen_conv.py +1 -1
- torch_geometric/nn/conv/gravnet_conv.py +2 -1
- torch_geometric/nn/conv/hetero_conv.py +2 -1
- torch_geometric/nn/conv/meshcnn_conv.py +487 -0
- torch_geometric/nn/conv/message_passing.py +5 -4
- torch_geometric/nn/conv/rgcn_conv.py +2 -1
- torch_geometric/nn/conv/sg_conv.py +1 -1
- torch_geometric/nn/conv/spline_conv.py +2 -1
- torch_geometric/nn/conv/ssg_conv.py +1 -1
- torch_geometric/nn/conv/transformer_conv.py +5 -3
- torch_geometric/nn/data_parallel.py +5 -4
- torch_geometric/nn/dense/linear.py +0 -20
- torch_geometric/nn/encoding.py +17 -3
- torch_geometric/nn/fx.py +14 -12
- torch_geometric/nn/model_hub.py +2 -15
- torch_geometric/nn/models/__init__.py +11 -2
- torch_geometric/nn/models/attentive_fp.py +1 -1
- torch_geometric/nn/models/attract_repel.py +148 -0
- torch_geometric/nn/models/basic_gnn.py +2 -1
- torch_geometric/nn/models/captum.py +1 -1
- torch_geometric/nn/models/deep_graph_infomax.py +1 -1
- torch_geometric/nn/models/dimenet.py +2 -2
- torch_geometric/nn/models/dimenet_utils.py +4 -2
- torch_geometric/nn/models/gpse.py +1083 -0
- torch_geometric/nn/models/graph_unet.py +13 -4
- torch_geometric/nn/models/lpformer.py +783 -0
- torch_geometric/nn/models/metapath2vec.py +1 -1
- torch_geometric/nn/models/mlp.py +4 -2
- torch_geometric/nn/models/node2vec.py +1 -1
- torch_geometric/nn/models/polynormer.py +206 -0
- torch_geometric/nn/models/rev_gnn.py +3 -3
- torch_geometric/nn/models/sgformer.py +219 -0
- torch_geometric/nn/models/signed_gcn.py +1 -1
- torch_geometric/nn/models/visnet.py +2 -2
- torch_geometric/nn/norm/batch_norm.py +17 -7
- torch_geometric/nn/norm/diff_group_norm.py +7 -2
- torch_geometric/nn/norm/graph_norm.py +9 -4
- torch_geometric/nn/norm/instance_norm.py +5 -1
- torch_geometric/nn/norm/layer_norm.py +15 -7
- torch_geometric/nn/norm/msg_norm.py +8 -2
- torch_geometric/nn/pool/__init__.py +8 -4
- torch_geometric/nn/pool/cluster_pool.py +3 -4
- torch_geometric/nn/pool/connect/base.py +1 -3
- torch_geometric/nn/pool/knn.py +13 -10
- torch_geometric/nn/pool/select/base.py +1 -4
- torch_geometric/nn/to_hetero_module.py +4 -3
- torch_geometric/nn/to_hetero_transformer.py +3 -3
- torch_geometric/nn/to_hetero_with_bases_transformer.py +4 -4
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/utils.py +20 -5
- torch_geometric/sampler/__init__.py +2 -1
- torch_geometric/sampler/base.py +336 -7
- torch_geometric/sampler/hgt_sampler.py +11 -1
- torch_geometric/sampler/neighbor_sampler.py +296 -23
- torch_geometric/sampler/utils.py +93 -5
- torch_geometric/testing/__init__.py +4 -0
- torch_geometric/testing/decorators.py +35 -5
- torch_geometric/testing/distributed.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_gpse.py +49 -0
- torch_geometric/transforms/add_metapaths.py +8 -6
- torch_geometric/transforms/add_positional_encoding.py +2 -2
- torch_geometric/transforms/base_transform.py +2 -1
- torch_geometric/transforms/delaunay.py +65 -15
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -8
- torch_geometric/transforms/largest_connected_components.py +1 -1
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/normalize_features.py +3 -3
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_duplicated_edges.py +4 -2
- torch_geometric/transforms/rooted_subgraph.py +1 -1
- torch_geometric/typing.py +70 -17
- torch_geometric/utils/__init__.py +4 -1
- torch_geometric/utils/_lexsort.py +0 -9
- torch_geometric/utils/_negative_sampling.py +27 -12
- torch_geometric/utils/_scatter.py +132 -195
- torch_geometric/utils/_sort_edge_index.py +0 -2
- torch_geometric/utils/_spmm.py +16 -14
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_trim_to_layer.py +2 -2
- torch_geometric/utils/convert.py +17 -10
- torch_geometric/utils/cross_entropy.py +34 -13
- torch_geometric/utils/embedding.py +91 -2
- torch_geometric/utils/geodesic.py +4 -3
- torch_geometric/utils/influence.py +279 -0
- torch_geometric/utils/map.py +13 -9
- torch_geometric/utils/nested.py +1 -1
- torch_geometric/utils/smiles.py +3 -3
- torch_geometric/utils/sparse.py +7 -14
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +250 -5
- torch_geometric/warnings.py +11 -2
- torch_geometric/nn/nlp/__init__.py +0 -7
- torch_geometric/nn/nlp/sentence_transformer.py +0 -101
@@ -0,0 +1,487 @@
+# The below is to suppress the warning on torch.nn.conv.MeshCNNConv::update
+# pyright: reportIncompatibleMethodOverride=false
+import warnings
+from typing import Optional
+
+import torch
+from torch.nn import Linear, Module, ModuleList
+
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.typing import Tensor
+
+
+class MeshCNNConv(MessagePassing):
+    r"""The convolutional layer introduced by the paper
+    `"MeshCNN: A Network With An Edge" <https://arxiv.org/abs/1809.05910>`_.
+
+    Recall that, given a set of categories :math:`C`,
+    MeshCNN is a function that takes as its input a triangular mesh
+    :math:`\mathcal{m} = (V, F) \in \mathbb{R}^{|V| \times 3} \times
+    \{0,...,|V|-1\}^{3 \times |F|}`, and returns as its output
+    a :math:`|C|`-dimensional vector, whose :math:`i` th component denotes
+    the probability of the input mesh belonging to category :math:`c_i \in C`.
+
+    Let :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}`
+    denote the output value of the prior (e.g. :math:`k` th)
+    layer of our neural network. The :math:`i` th row of :math:`X^{(k)}` is a
+    :math:`\text{Dim-Out}(k)`-dimensional vector that represents the features
+    computed by the :math:`k` th layer for edge :math:`e_i` of the input mesh
+    :math:`\mathcal{m}`. Let :math:`A \in \{0, ..., |E|-1\}^{2 \times 4*|E|}`
+    denote the *edge adjacency* matrix of our input mesh :math:`\mathcal{m}`.
+    The :math:`j` th column of :math:`A` holds a pair of indices
+    :math:`k,l \in \{0,...,|E|-1\}`, which means that edge
+    :math:`e_k` is adjacent to edge :math:`e_l`
+    in our input mesh :math:`\mathcal{m}`.
+    The definition of edge adjacency in a triangular
+    mesh is illustrated in Figure 1.
+    In a triangular
+    mesh, each edge :math:`e_i` is expected to be adjacent to exactly
+    :math:`4` neighboring edges, hence the number of columns of :math:`A`:
+    :math:`4*|E|`.
+    We write *the neighborhood* of edge :math:`e_i` as
+    :math:`\mathcal{N}(i) = (a(i), b(i), c(i), d(i))`, where
+
+    1. :math:`a(i)` denotes the index of the *first* counter-clockwise
+       edge of the face *above* :math:`e_i`.
+
+    2. :math:`b(i)` denotes the index of the *second* counter-clockwise
+       edge of the face *above* :math:`e_i`.
+
+    3. :math:`c(i)` denotes the index of the *first* counter-clockwise edge
+       of the face *below* :math:`e_i`.
+
+    4. :math:`d(i)` denotes the index of the *second*
+       counter-clockwise edge of the face *below* :math:`e_i`.
+
+    .. figure:: ../_figures/meshcnn_edge_adjacency.svg
+        :align: center
+        :width: 80%
+
+        **Figure 1:** The neighbors of edge :math:`\mathbf{e_1}`
+        are :math:`\mathbf{e_2}, \mathbf{e_3}, \mathbf{e_4}` and
+        :math:`\mathbf{e_5}`, respectively.
+        We write this as
+        :math:`\mathcal{N}(1) = (a(1), b(1), c(1), d(1)) = (2, 3, 4, 5)`.
+
+
+    Because of this ordering constraint, :obj:`MeshCNNConv` **requires
+    that the columns of** :math:`A`
+    **be ordered in the following way**:
+
+    .. math::
+        &A[:,0] = (0, \text{The index of the "a" edge for edge } 0) \\
+        &A[:,1] = (0, \text{The index of the "b" edge for edge } 0) \\
+        &A[:,2] = (0, \text{The index of the "c" edge for edge } 0) \\
+        &A[:,3] = (0, \text{The index of the "d" edge for edge } 0) \\
+        \vdots \\
+        &A[:,4*|E|-4] = \bigl(|E|-1, a\bigl(|E|-1\bigr)\bigr) \\
+        &A[:,4*|E|-3] = \bigl(|E|-1, b\bigl(|E|-1\bigr)\bigr) \\
+        &A[:,4*|E|-2] = \bigl(|E|-1, c\bigl(|E|-1\bigr)\bigr) \\
+        &A[:,4*|E|-1] = \bigl(|E|-1, d\bigl(|E|-1\bigr)\bigr)
+
+
+    Stated a bit more compactly, for every edge :math:`e_i` in the input
+    mesh, :math:`A` should have the following entries:
+
+    .. math::
+        A[:, 4*i] &= (i, a(i)) \\
+        A[:, 4*i + 1] &= (i, b(i)) \\
+        A[:, 4*i + 2] &= (i, c(i)) \\
+        A[:, 4*i + 3] &= (i, d(i))
+
+    To summarize so far, we have defined 3 things:
+
+    1. The activation of the prior (e.g. :math:`k` th) layer,
+       :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}`
+
+    2. The edge adjacency matrix and the definition of edge adjacency,
+       :math:`A \in \{0,...,|E|-1\}^{2 \times 4*|E|}`
+
+    3. The way the columns of :math:`A` must be ordered.
+
+
+    We are now finally able to define the :obj:`MeshCNNConv` class/layer.
+    In the following definition
+    we assume :obj:`MeshCNNConv` is at the :math:`k+1` th layer of our
+    neural network.
+
+    The :obj:`MeshCNNConv` layer is a function,
+
+    .. math::
+        \text{MeshCNNConv}^{(k+1)}(X^{(k)}, A) = X^{(k+1)},
+
+    that, given the prior layer's output
+    :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}`
+    and the edge adjacency matrix :math:`A`
+    of the input mesh (graph) :math:`\mathcal{m}`,
+    returns a new edge feature tensor
+    :math:`X^{(k+1)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k+1)}`,
+    where the :math:`i` th row of :math:`X^{(k+1)}`, denoted by
+    :math:`x^{(k+1)}_i`,
+    represents the :math:`\text{Dim-Out}(k+1)`-dimensional feature vector
+    of edge :math:`e_i`, **and is defined as follows**:
+
+    .. math::
+        x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\
+        &+ W^{(k+1)}_1 \bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \bigr| \\
+        &+ W^{(k+1)}_2 \bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \bigr) \\
+        &+ W^{(k+1)}_3 \bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \bigr| \\
+        &+ W^{(k+1)}_4 \bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \bigr).
+
+    :math:`W_0^{(k+1)},W_1^{(k+1)},W_2^{(k+1)},W_3^{(k+1)}, W_4^{(k+1)}
+    \in \mathbb{R}^{\text{Dim-Out}(k+1) \times \text{Dim-Out}(k)}`
+    are trainable linear functions (i.e. "the weights" of this layer).
+    :math:`x^{(k)}_i` is the :math:`\text{Dim-Out}(k)`-dimensional feature
+    vector of edge :math:`e_i` computed by the prior (:math:`k` th) layer.
+    :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`, and
+    :math:`x^{(k)}_{d(i)}` are the :math:`\text{Dim-Out}(k)`-feature vectors,
+    computed in the :math:`k` th layer, that are associated with the
+    :math:`4` neighboring edges of :math:`e_i`.
+
+
+    Args:
+        in_channels (int): Corresponds to :math:`\text{Dim-Out}(k)`
+            in the above overview. This
+            represents the output dimension of the prior layer. For the given
+            input mesh :math:`\mathcal{m} = (V, F)`, the prior layer is
+            expected to output a
+            :math:`X \in \mathbb{R}^{|E| \times \textit{in_channels}}`
+            feature matrix.
+            Assuming the instance of this class
+            is situated at layer :math:`k+1`, we write that
+            :math:`X^{(k)} \in \mathbb{R}^{|E| \times \textit{in_channels}}`.
+        out_channels (int): Corresponds to :math:`\text{Dim-Out}(k+1)` in the
+            above overview. This represents the output dimension of this
+            layer. Assuming the instance of this class
+            is situated at layer :math:`k+1`, we write that
+            :math:`X^{(k+1)}
+            \in \mathbb{R}^{|E| \times \textit{out_channels}}`.
+        kernels (torch.nn.ModuleList, optional): A list of length 5, where
+            each element is a :class:`torch.nn.Module` (i.e. a neural
+            network) that MUST take as input a vector
+            of dimension :obj:`in_channels` and return a vector of dimension
+            :obj:`out_channels`. In particular,
+            :obj:`kernels[0]` is :math:`W^{(k+1)}_0` in the above overview
+            (see :obj:`MeshCNNConv`), :obj:`kernels[1]` is
+            :math:`W^{(k+1)}_1`,
+            :obj:`kernels[2]` is :math:`W^{(k+1)}_2`,
+            :obj:`kernels[3]` is :math:`W^{(k+1)}_3`, and
+            :obj:`kernels[4]` is :math:`W^{(k+1)}_4`.
+            This input is optional, in which case
+            each of the 5 elements in the kernels will be a linear
+            neural network :class:`torch.nn.Linear`,
+            correctly configured to take as input
+            :attr:`in_channels`-dimensional vectors and return
+            a vector of dimension :attr:`out_channels`.
+
+    Discussion:
+        The key difference that separates :obj:`MeshCNNConv` from a
+        traditional message passing graph neural network is that
+        :obj:`MeshCNNConv` requires the set of neighbors for a node,
+        :math:`\mathcal{N}(u) = (v_1, v_2, ...)`,
+        to *be an ordered set* (i.e. a tuple). In
+        fact, :obj:`MeshCNNConv` goes further, requiring
+        that :math:`\mathcal{N}(u)` always return a set of size :math:`4`.
+        This is different from most message passing graph neural networks,
+        which assume that :math:`\mathcal{N}(u) = \{v_1, v_2, ...\}` returns
+        an *unordered* set. This lends :obj:`MeshCNNConv` more expressive
+        power, at the cost of no longer being invariant to permutations in
+        :math:`\mathbb{S}_4`. Put more plainly, in traditional message
+        passing GNNs, the network is *unable* to distinguish one neighboring
+        node from another.
+        In contrast, in :obj:`MeshCNNConv`, each of the 4 neighbors has a
+        "role", either the "a", "b", "c", or "d" neighbor. We encode this
+        fact by requiring that :math:`\mathcal{N}` return the 4-tuple,
+        where the first component is the "a" neighbor, and so on.
+
+        To summarize this comparison, it may help to re-define
+        :obj:`MeshCNNConv` in terms of :math:`\text{UPDATE}` and
+        :math:`\text{AGGREGATE}`
+        functions, which is a general way to define a traditional GNN layer.
+        If we let :math:`x_i^{(k+1)}`
+        denote the output of a GNN layer for node :math:`i` at
+        layer :math:`k+1`, and let
+        :math:`\mathcal{N}(i)` denote the set of nodes adjacent
+        to node :math:`i`,
+        then we can describe the :math:`k+1` th layer of a traditional GNN
+        as
+
+        .. math::
+            x_i^{(k+1)} = \text{UPDATE}^{(k+1)}\bigl(x^{(k)}_i,
+            \text{AGGREGATE}^{(k+1)}\bigl(\mathcal{N}(i)\bigr)\bigr).
+
+        Here, :math:`\text{UPDATE}^{(k+1)}` is a function of :math:`2`
+        :math:`\text{Dim-Out}(k)`-dimensional vectors, and returns a
+        :math:`\text{Dim-Out}(k+1)`-dimensional vector.
+        The :math:`\text{AGGREGATE}^{(k+1)}` function
+        is a function of an *unordered set*
+        of nodes that are neighbors of node :math:`i`, as defined by
+        :math:`\mathcal{N}(i)`. Usually the size of this set varies across
+        different nodes :math:`i`, and one of the most basic examples
+        of such a function is the "sum aggregation", defined as
+        :math:`\text{AGGREGATE}^{(k+1)}(\mathcal{N}(i)) =
+        \sum_{j \in \mathcal{N}(i)} x^{(k)}_j`.
+        See
+        :class:`SumAggregation <torch_geometric.nn.aggr.basic.SumAggregation>`
+        for more.
+
+        In contrast, while :obj:`MeshCNNConv` 's :math:`\text{UPDATE}`
+        function follows
+        a traditional GNN, its :math:`\text{AGGREGATE}` is a function of a
+        tuple (i.e. an ordered set) of neighbors
+        rather than an unordered set of neighbors.
+        In particular, while the :math:`\text{UPDATE}`
+        function of :obj:`MeshCNNConv` for :math:`e_i` is
+
+        .. math::
+            x_i^{(k+1)} = \text{UPDATE}^{(k+1)}(x_i^{(k)}, s_i^{(k+1)})
+            = W_0^{(k+1)}x_i^{(k)} + s_i^{(k+1)},
+
+        :obj:`MeshCNNConv` 's :math:`\text{AGGREGATE}` function is
+
+        .. math::
+            s_i^{(k+1)} = \text{AGGREGATE}^{(k+1)}(A, B, C, D)
+            &= W_1^{(k+1)}\bigl|A - C \bigr| \\
+            &+ W_2^{(k+1)}\bigl(A + C \bigr) \\
+            &+ W_3^{(k+1)}\bigl|B - D \bigr| \\
+            &+ W_4^{(k+1)}\bigl(B + D \bigr),
+
+        where :math:`A=x_{a(i)}^{(k)}, B=x_{b(i)}^{(k)}, C=x_{c(i)}^{(k)},`
+        and :math:`D=x_{d(i)}^{(k)}`.
+
+    ..
+
+        The :math:`i` th row of
+        :math:`V \in \mathbb{R}^{|V| \times 3}`
+        holds the cartesian :math:`xyz`
+        coordinates for node :math:`v_i` in the mesh, and the :math:`j` th
+        column of :math:`F \in \{0,...,|V|-1\}^{3 \times |F|}`
+        holds the :math:`3` indices
+        :math:`(k,l,m)` that correspond to the :math:`3` nodes
+        :math:`(v_k, v_l, v_m)` that construct face :math:`j` of the mesh.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernels: Optional[ModuleList] = None):
+        super().__init__(aggr='add')
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        if kernels is None:
+            self.kernels = ModuleList(
+                [Linear(in_channels, out_channels) for _ in range(5)])
+        else:
+            # Ensures that 'kernels' is properly formed, otherwise throws
+            # the appropriate error.
+            self._assert_kernels(kernels)
+            self.kernels = kernels
+
+    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
+        r"""Forward pass.
+
+        Args:
+            x (torch.Tensor): :math:`X^{(k)} \in
+                \mathbb{R}^{|E| \times \textit{in_channels}}`.
+                The edge feature tensor returned by the prior layer
+                (e.g. :math:`k`). The tensor is of shape
+                :math:`|E| \times \text{Dim-Out}(k)`, or equivalently,
+                :obj:`(|E|, self.in_channels)`.
+            edge_index (torch.Tensor):
+                :math:`A \in \{0,...,|E|-1\}^{2 \times 4*|E|}`.
+                The edge adjacency tensor of the network's input mesh
+                :math:`\mathcal{m} = (V, F)`. The edge adjacency tensor
+                **MUST** have the following form:
+
+                .. math::
+                    &A[:,0] = (0,
+                    \text{The index of the "a" edge for edge } 0) \\
+                    &A[:,1] = (0,
+                    \text{The index of the "b" edge for edge } 0) \\
+                    &A[:,2] = (0,
+                    \text{The index of the "c" edge for edge } 0) \\
+                    &A[:,3] = (0,
+                    \text{The index of the "d" edge for edge } 0) \\
+                    \vdots \\
+                    &A[:,4*|E|-4] = \bigl(|E|-1, a\bigl(|E|-1\bigr)\bigr) \\
+                    &A[:,4*|E|-3] = \bigl(|E|-1, b\bigl(|E|-1\bigr)\bigr) \\
+                    &A[:,4*|E|-2] = \bigl(|E|-1, c\bigl(|E|-1\bigr)\bigr) \\
+                    &A[:,4*|E|-1] = \bigl(|E|-1, d\bigl(|E|-1\bigr)\bigr)
+
+                See :obj:`MeshCNNConv` for what
+                "index of the 'a' (b, c, d) edge for edge i" means, and also
+                for the general definition of edge adjacency in MeshCNN.
+                These definitions are also provided in the
+                `paper <https://arxiv.org/abs/1809.05910>`_ itself.
+
+        Returns:
+            torch.Tensor:
+            :math:`X^{(k+1)} \in
+            \mathbb{R}^{|E| \times \textit{out_channels}}`.
+            The edge feature tensor for this (e.g. the :math:`k+1` th) layer.
+            The :math:`i` th row of :math:`X^{(k+1)}` is computed according
+            to the formula
+
+            .. math::
+                x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\
+                &+ W^{(k+1)}_1 \bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \bigr| \\
+                &+ W^{(k+1)}_2 \bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \bigr) \\
+                &+ W^{(k+1)}_3 \bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \bigr| \\
+                &+ W^{(k+1)}_4 \bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \bigr),
+
+            where :math:`W_0^{(k+1)}, W_1^{(k+1)},
+            W_2^{(k+1)}, W_3^{(k+1)}, W_4^{(k+1)}
+            \in \mathbb{R}^{\text{Dim-Out}(k+1) \times \text{Dim-Out}(k)}`
+            are the trainable linear functions (i.e. the trainable
+            "weights") of this layer, and
+            :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`, and
+            :math:`x^{(k)}_{d(i)}` are the
+            :math:`\text{Dim-Out}(k)`-dimensional edge feature vectors
+            computed by the prior (:math:`k` th) layer
+            that are associated with the :math:`4`
+            neighboring edges of :math:`e_i`.
+        """
+        return self.propagate(edge_index, x=x)
+
+    def message(self, x_j: Tensor) -> Tensor:
+        r"""The message passing step of :obj:`MeshCNNConv`.
+
+        Args:
+            x_j (torch.Tensor): A :obj:`[4*|E|, in_channels]` tensor.
+                Its :math:`i` th row holds the feature vector, computed by
+                the previous layer, of the source edge of adjacency column
+                :math:`i`.
+
+        Returns:
+            A :obj:`[4*|E|, out_channels]` tensor, whose :math:`i` th row
+            holds the message that the target edge of adjacency column
+            :math:`i` will receive. The :obj:`'add'` aggregation then sums
+            the :math:`4` messages of each target edge.
+        """
+        # The following variable names are taken from the paper.
+        # MeshCNN computes the features associated with edge e from
+        # (|a - c|, a + c, |b - d|, b + d), where a, b, c, d are the
+        # neighboring edges of e: a being the first edge of the upper face,
+        # b being the second edge of the upper face, c being the first edge
+        # of the lower face,
+        # and d being the second edge of the lower face of the input mesh.
+
+        # TODO: It is unclear if view is faster. If it is not,
+        # then we should prefer the strided method commented out below.
+
+        E4, in_channels = x_j.size()  # E4 = 4|E|, num edges in line graph
+        # Option 1
+        n_a = x_j[0::4]  # shape: |E| x in_channels
+        n_b = x_j[1::4]  # shape: |E| x in_channels
+        n_c = x_j[2::4]  # shape: |E| x in_channels
+        n_d = x_j[3::4]  # shape: |E| x in_channels
+        # Allocate on the same device/dtype as the input so the layer also
+        # works for GPU and non-default-dtype inputs:
+        m = torch.empty(E4, self.out_channels, device=x_j.device,
+                        dtype=x_j.dtype)
+        m[0::4] = self.kernels[1].forward(torch.abs(n_a - n_c))
+        m[1::4] = self.kernels[2].forward(n_a + n_c)
+        m[2::4] = self.kernels[3].forward(torch.abs(n_b - n_d))
+        m[3::4] = self.kernels[4].forward(n_b + n_d)
+        return m
+
+        # Option 2
+        # E4, in_channels = x_j.size()
+        # E = E4 // 4
+        # x_j = x_j.view(E, 4, in_channels)  # shape: (|E| x 4 x in_channels)
+        # n_a, n_b, n_c, n_d = x_j.unbind(
+        #     dim=1)  # shape: (4 x |E| x in_channels)
+        # m = torch.stack(
+        #     [
+        #         (n_a - n_c).abs(),  # shape: |E| x in_channels
+        #         n_a + n_c,
+        #         (n_b - n_d).abs(),
+        #         n_b + n_d,
+        #     ],
+        #     dim=1)  # shape: (|E| x 4 x in_channels)
+        # m = m.view(E4, in_channels)  # shape: 4*|E| x in_channels
+        # return m
+
+    def update(self, inputs: Tensor, x: Tensor) -> Tensor:
+        r"""The UPDATE step, in reference to the UPDATE and AGGREGATE
+        formulation of message passing convolution.
+
+        Args:
+            inputs (torch.Tensor): The :obj:`(|E|, out_channels)` tensor
+                returned by the aggregation step.
+            x (torch.Tensor): :math:`X^{(k)}`, the original input to this
+                layer.
+
+        Returns:
+            torch.Tensor: :math:`X^{(k+1)}`, the output of this layer, which
+            has shape :obj:`(|E|, out_channels)`.
+        """
+        return self.kernels[0].forward(x) + inputs
+
+    def _assert_kernels(self, kernels: ModuleList):
+        r"""Ensures that :obj:`kernels` is a list of 5 :obj:`torch.nn.Module`
+        modules (i.e. networks). In addition, it also ensures that each
+        network takes input of dimension :attr:`in_channels` and returns
+        output of dimension :attr:`out_channels`.
+        This method throws an error otherwise.
+
+        .. warning::
+            This method throws an error if :obj:`kernels` is
+            not valid. (Otherwise this method returns nothing.)
+        """
+        assert isinstance(kernels, ModuleList), (
+            f"Parameter 'kernels' must be a torch.nn.ModuleList with 5 "
+            f"members, but we got {type(kernels)}.")
+
+        assert len(kernels) == 5, (
+            "Parameter 'kernels' must be a torch.nn.ModuleList with "
+            "exactly 5 members.")
+
+        for i, network in enumerate(kernels):
+            assert isinstance(network, Module), (
+                f"kernels[{i}] must be a torch.nn.Module, got "
+                f"{type(network)}")
+            if not hasattr(network, "in_channels") and \
+                    not hasattr(network, "in_features"):
+                warnings.warn(
+                    f"kernels[{i}] has neither an 'in_channels' nor an "
+                    f"'in_features' attribute. The network must take as "
+                    f"input a {self.in_channels}-dimensional tensor.",
+                    stacklevel=2)
+            else:
+                # Use a nested getattr so the fallback is not evaluated
+                # eagerly when only 'in_channels' is present:
+                input_dimension = getattr(
+                    network, "in_channels",
+                    getattr(network, "in_features", None))
+                assert input_dimension == self.in_channels, (
+                    f"The input dimension of the neural network in "
+                    f"kernels[{i}] must be equal to 'in_channels', but "
+                    f"input_dimension={input_dimension} and "
+                    f"self.in_channels={self.in_channels}.")
+
+            if not hasattr(network, "out_channels") and \
+                    not hasattr(network, "out_features"):
+                warnings.warn(
+                    f"kernels[{i}] has neither an 'out_channels' nor an "
+                    f"'out_features' attribute. The network must return a "
+                    f"{self.out_channels}-dimensional tensor.",
+                    stacklevel=2)
+            else:
+                output_dimension = getattr(
+                    network, "out_channels",
+                    getattr(network, "out_features", None))
+                assert output_dimension == self.out_channels, (
+                    f"The output dimension of the neural network in "
+                    f"kernels[{i}] must be equal to 'out_channels', but "
+                    f"output_dimension={output_dimension} and "
+                    f"self.out_channels={self.out_channels}.")
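For orientation, here is a minimal usage sketch of the new layer. The import path follows the updated torch_geometric/nn/conv/__init__.py above but is otherwise an assumption, and the toy neighborhoods below are illustrative only; a real mesh must supply columns in the documented (a, b, c, d) order:

    import torch
    from torch_geometric.nn.conv import MeshCNNConv  # assumed export

    E = 4  # number of mesh edges (toy values, not a valid mesh)
    x = torch.randn(E, 8)  # one 8-dimensional feature per edge

    # Toy neighborhoods N(i) = (a(i), b(i), c(i), d(i)) for each edge i:
    neighbors = torch.tensor([
        [1, 2, 3, 2],
        [0, 3, 2, 0],
        [3, 0, 1, 1],
        [2, 1, 0, 3],
    ])
    # Column 4*i + k of the adjacency tensor must be (i, N(i)[k]):
    edge_id = torch.arange(E).repeat_interleave(4)
    edge_index = torch.stack([edge_id, neighbors.reshape(-1)], dim=0)

    conv = MeshCNNConv(in_channels=8, out_channels=16)
    out = conv(x, edge_index)
    print(out.shape)  # torch.Size([4, 16])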
@@ -276,7 +276,7 @@ class MessagePassing(torch.nn.Module):
                     f"{index.min().item()}). Please ensure that all "
                     f"indices in 'edge_index' point to valid indices "
                     f"in the interval [0, {src.size(self.node_dim)}) in "
-                    f"your node feature matrix and try again.")
+                    f"your node feature matrix and try again.") from e

         if (index.numel() > 0 and index.max() >= src.size(self.node_dim)):
             raise IndexError(
@@ -285,7 +285,7 @@ class MessagePassing(torch.nn.Module):
                     f"{index.max().item()}). Please ensure that all "
                     f"indices in 'edge_index' point to valid indices "
                     f"in the interval [0, {src.size(self.node_dim)}) in "
-                    f"your node feature matrix and try again.")
+                    f"your node feature matrix and try again.") from e

         raise e
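Both hunks above replace a bare re-raise with "raise ... from e", which records the original lookup failure as the new error's __cause__ instead of discarding it. A standalone Python illustration of the difference:

    try:
        try:
            [][0]  # stand-in for the failing index_select
        except IndexError as e:
            raise IndexError("edge_index points outside the node "
                             "feature matrix") from e
    except IndexError as err:
        # The original error travels along for debugging:
        assert isinstance(err.__cause__, IndexError)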
@@ -1029,6 +1029,7 @@ class MessagePassing(torch.nn.Module):
             :meth:`jittable` is deprecated and a no-op from :pyg:`PyG` 2.5
             onwards.
         """
-        warnings.warn(
-
+        warnings.warn(
+            f"'{self.__class__.__name__}.jittable' is deprecated "
+            f"and a no-op. Please remove its usage.", stacklevel=2)
         return self
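This hunk, like several below, also passes stacklevel=2 to warnings.warn, which attributes the warning to the caller rather than to PyG internals. A minimal sketch of the effect (jittable_like is a hypothetical stand-in):

    import warnings

    def jittable_like():
        # With stacklevel=2, the reported file/line is the caller's,
        # so users see their own code in the warning, not this helper:
        warnings.warn("deprecated and a no-op", stacklevel=2)

    jittable_like()  # the emitted warning points at this line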
@@ -120,7 +120,8 @@ class RGCNConv(MessagePassing):
             in_channels = (in_channels, in_channels)
         self.in_channels_l = in_channels[0]

-        self._use_segment_matmul_heuristic_output:
+        self._use_segment_matmul_heuristic_output: torch.jit.Attribute(
+            None, Optional[float])

         if num_bases is not None:
             self.weight = Parameter(
@@ -90,7 +90,7 @@ class SGConv(MessagePassing):
                     edge_index, edge_weight, x.size(self.node_dim), False,
                     self.add_self_loops, self.flow, dtype=x.dtype)

-            for
+            for _ in range(self.K):
                 # propagate_type: (x: Tensor, edge_weight: OptTensor)
                 x = self.propagate(edge_index, x=x, edge_weight=edge_weight)
                 if self.cached:
@@ -132,7 +132,8 @@ class SplineConv(MessagePassing):
         if not x[0].is_cuda:
             warnings.warn(
                 'We do not recommend using the non-optimized CPU version of '
-                '`SplineConv`. If possible, please move your data to GPU.'
+                '`SplineConv`. If possible, please move your data to GPU.',
+                stacklevel=2)

         # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
         out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
@@ -100,7 +100,7 @@ class SSGConv(MessagePassing):
                     self.add_self_loops, self.flow, dtype=x.dtype)

             h = x * self.alpha
-            for
+            for _ in range(self.K):
                 # propagate_type: (x: Tensor, edge_weight: OptTensor)
                 x = self.propagate(edge_index, x=x, edge_weight=edge_weight)
                 h = h + (1 - self.alpha) / self.K * x
@@ -126,9 +126,11 @@ class TransformerConv(MessagePassing):
         if isinstance(in_channels, int):
             in_channels = (in_channels, in_channels)

-        self.lin_key = Linear(in_channels[0], heads * out_channels)
-        self.lin_query = Linear(in_channels[1], heads * out_channels
-
+        self.lin_key = Linear(in_channels[0], heads * out_channels, bias=bias)
+        self.lin_query = Linear(in_channels[1], heads * out_channels,
+                                bias=bias)
+        self.lin_value = Linear(in_channels[0], heads * out_channels,
+                                bias=bias)
         if edge_dim is not None:
             self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
         else:
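The hunk above routes the constructor's bias flag into the key/query/value projections (the removed lin_query and lin_value lines appear truncated in this view). A hedged smoke test, assuming TransformerConv is exported from torch_geometric.nn as usual:

    import torch
    from torch_geometric.nn import TransformerConv

    conv = TransformerConv(in_channels=16, out_channels=32, heads=2,
                           bias=False)
    x = torch.randn(10, 16)
    edge_index = torch.randint(0, 10, (2, 40))
    out = conv(x, edge_index)
    print(out.shape)  # [10, 64]: heads * out_channels with concat=True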
@@ -57,10 +57,11 @@ class DataParallel(torch.nn.DataParallel):
                  follow_batch=None, exclude_keys=None):
         super().__init__(module, device_ids, output_device)

-        warnings.warn(
-
-
-
+        warnings.warn(
+            "'DataParallel' is usually much slower than "
+            "'DistributedDataParallel' even on a single machine. "
+            "Please consider switching to 'DistributedDataParallel' "
+            "for multi-GPU training.", stacklevel=2)

         self.src_device = torch.device(f'cuda:{self.device_ids[0]}')
         self.follow_batch = follow_batch or []
@@ -1,4 +1,3 @@
-import copy
 import math
 import sys
 import time
@@ -114,25 +113,6 @@ class Linear(torch.nn.Module):

         self.reset_parameters()

-    def __deepcopy__(self, memo):
-        # PyTorch<1.13 cannot handle deep copies of uninitialized parameters :(
-        # TODO Drop this code once PyTorch 1.12 is no longer supported.
-        out = Linear(
-            self.in_channels,
-            self.out_channels,
-            self.bias is not None,
-            self.weight_initializer,
-            self.bias_initializer,
-        ).to(self.weight.device)
-
-        if self.in_channels > 0:
-            out.weight = copy.deepcopy(self.weight, memo)
-
-        if self.bias is not None:
-            out.bias = copy.deepcopy(self.bias, memo)
-
-        return out
-
     def reset_parameters(self):
         r"""Resets all learnable parameters of the module."""
         reset_weight_(self.weight, self.in_channels, self.weight_initializer)
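With the workaround removed, Linear relies on PyTorch's default deep-copy path; the special-casing was only needed for uninitialized (lazy) parameters on PyTorch < 1.13. A hedged sanity check of the default behavior:

    import copy
    from torch_geometric.nn.dense.linear import Linear

    lin = Linear(16, 32)
    clone = copy.deepcopy(lin)  # no custom __deepcopy__ needed anymore
    # The clone owns independent storage rather than aliasing the original:
    assert clone.weight.data_ptr() != lin.weight.data_ptr()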