morphomatics 4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. morphomatics/__init__.py +13 -0
  2. morphomatics/geom/__init__.py +16 -0
  3. morphomatics/geom/bezier_spline.py +361 -0
  4. morphomatics/geom/misc.py +104 -0
  5. morphomatics/geom/surface.py +208 -0
  6. morphomatics/graph/__init__.py +13 -0
  7. morphomatics/graph/operators.py +124 -0
  8. morphomatics/manifold/__init__.py +46 -0
  9. morphomatics/manifold/bezierfold.py +500 -0
  10. morphomatics/manifold/connection.py +105 -0
  11. morphomatics/manifold/cubic_bezierfold.py +305 -0
  12. morphomatics/manifold/differential_coords.py +197 -0
  13. morphomatics/manifold/discrete_ops.py +56 -0
  14. morphomatics/manifold/euclidean.py +213 -0
  15. morphomatics/manifold/fundamental_coords.py +440 -0
  16. morphomatics/manifold/gl_p_coords.py +149 -0
  17. morphomatics/manifold/gl_p_n.py +201 -0
  18. morphomatics/manifold/grassmann.py +174 -0
  19. morphomatics/manifold/hyperbolic_space.py +271 -0
  20. morphomatics/manifold/kendall.py +269 -0
  21. morphomatics/manifold/lie_group.py +102 -0
  22. morphomatics/manifold/manifold.py +162 -0
  23. morphomatics/manifold/manopt_wrapper.py +185 -0
  24. morphomatics/manifold/metric.py +110 -0
  25. morphomatics/manifold/point_distribution_model.py +143 -0
  26. morphomatics/manifold/power_manifold.py +413 -0
  27. morphomatics/manifold/product_manifold.py +381 -0
  28. morphomatics/manifold/se_3.py +419 -0
  29. morphomatics/manifold/shape_space.py +57 -0
  30. morphomatics/manifold/so_3.py +494 -0
  31. morphomatics/manifold/spd.py +524 -0
  32. morphomatics/manifold/sphere.py +241 -0
  33. morphomatics/manifold/tangent_bundle.py +337 -0
  34. morphomatics/manifold/util.py +126 -0
  35. morphomatics/nn/__init__.py +15 -0
  36. morphomatics/nn/flow_layers.py +219 -0
  37. morphomatics/nn/tangent_layers.py +176 -0
  38. morphomatics/nn/train.py +202 -0
  39. morphomatics/nn/wFM_layers.py +152 -0
  40. morphomatics/opt/__init__.py +14 -0
  41. morphomatics/opt/riemannian_newton_raphson.py +65 -0
  42. morphomatics/opt/riemannian_steepest_descent.py +61 -0
  43. morphomatics/stats/__init__.py +18 -0
  44. morphomatics/stats/biinvariant_statistics.py +190 -0
  45. morphomatics/stats/exponential_barycenter.py +78 -0
  46. morphomatics/stats/geometric_median.py +89 -0
  47. morphomatics/stats/principal_geodesic_analysis.py +135 -0
  48. morphomatics/stats/riemannian_regression.py +317 -0
  49. morphomatics/stats/statistical_shape_model.py +99 -0
  50. morphomatics-4.0.dist-info/LICENSE +9 -0
  51. morphomatics-4.0.dist-info/METADATA +55 -0
  52. morphomatics-4.0.dist-info/RECORD +54 -0
  53. morphomatics-4.0.dist-info/WHEEL +5 -0
  54. morphomatics-4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,126 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+ import numpy as np
16
+
17
+ from morphomatics.geom.surface import Surface
18
+
19
+
20
def align(src, ref):
    """ (Constrained) Procrustes alignment of src to ref using Kabsch algorithm

    Computes the proper rotation (determinant +1) and the translation that map src onto ref
    in the least-squares sense and applies them to src. The guard against reflections is
    expressed with array operations only (no Python branching on array values), hence the
    function can be traced by jax.jit and jax.vmap.

    :arg src: n-by-d array of vertex coordinates of source object (typically d = 3)
    :arg ref: n-by-d array of vertex coordinates of reference
    :returns: aligned coords.
    """
    n = len(ref)
    # cross-covariance matrix
    c_s = src.mean(axis=0)
    c_r = ref.mean(axis=0)
    xCov = (ref.T @ src) / n - jnp.outer(c_r, c_s)
    # optimal rotation
    U, S, Vt = jnp.linalg.svd(xCov)
    V = Vt.T
    # If V @ U.T is a reflection, flip the direction that belongs to the smallest singular
    # value (the last one); scaling the columns of V equals V @ diag(1, ..., 1, sign).
    sign = jnp.where(jnp.linalg.det(V @ U.T) < 0, -1., 1.)
    flip = jnp.ones_like(S).at[-1].set(sign)
    R = (V * flip) @ U.T
    # return aligned coords.
    return src @ R + (c_r - c_s @ R)
38
+
39
+
40
def generalized_procrustes(surf):
    """ Generalized Procrustes analysis.

    Alternates between aligning every surface to a reference and replacing the vertices of
    the reference by the vertex-wise mean of the aligned surfaces. At least one sweep is
    performed; iteration stops as soon as the reference changes by no more than 1e-11
    (norm of the vertex difference) between consecutive sweeps or after 1000 sweeps.
    The vertex coordinates of the given surfaces are overwritten; nothing is returned.

    :arg surf: list of surfaces to be aligned. The meshes must be in correspondence.
    """
    tol, max_steps = 1e-11, 1000

    # the first surface serves as initial reference
    first = surf[0]
    ref = Surface(jnp.copy(first.v), jnp.copy(first.f))
    old_ref = Surface(jnp.copy(ref.v), jnp.copy(ref.f))

    n_steps = 0
    while True:
        # the very first sweep is unconditional
        if n_steps > 0:
            still_moving = jnp.linalg.norm(ref.v - old_ref.v) > tol
            if not still_moving or n_steps >= max_steps:
                break

        n_steps += 1
        # remember current reference for the convergence check
        old_ref = Surface(jnp.copy(ref.v), jnp.copy(ref.f))

        # align meshes to reference
        for s in surf:
            s.v = align(s.v, ref.v)

        # compute new reference
        ref.v = jnp.mean(jnp.array([s.v for s in surf]), axis=0)
59
+
60
+
61
def preshape(v):
    """ Center point cloud at origin and normalize its size

    The centroid is subtracted from every vertex and the result is divided by its norm
    (as computed by jnp.linalg.norm on the whole array).

    :arg v: n-by-3 array of vertex coordinates
    :returns: n-by-3 array of adjusted vertex coordinates
    """
    n_pts = v.shape[0]
    # center: replicate the coordinate sums once per vertex and scale them to the centroid
    coord_sums = jnp.sum(v, axis=0)
    centered = v - jnp.tile(coord_sums, (n_pts, 1)) * (1 / n_pts)
    # normalize
    return centered / jnp.linalg.norm(centered)
71
+
72
+
73
def multiprod(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray:
    """Vectorized matrix - matrix multiplication.

    The last two axes of A and B are multiplied as matrices; all leading axes are
    treated as (broadcastable) batch axes.
    """
    return jnp.einsum('...ij,...jk->...ik', A, B)
76
+
77
+
78
def multitransp(A):
    """Vectorized matrix transpose: swaps the last two axes of A, leading axes are kept."""
    return jnp.swapaxes(A, -1, -2)
81
+
82
+
83
def multiskew(A):
    """Vectorized skew-symmetric part: 0.5 * (A - A^T) w.r.t. the last two axes of A."""
    transposed = jnp.swapaxes(A, -1, -2)
    return 0.5 * (A - transposed)
85
+
86
+
87
def multisym(A):
    """Vectorized symmetric part: 0.5 * (A + A^T) w.r.t. the last two axes of A."""
    transposed = jnp.swapaxes(A, -1, -2)
    return 0.5 * (A + transposed)
89
+
90
+
91
def vectime3d(x, A):
    """
    Scale every slice of a stack of matrices by its own coefficient.

    :param x: vector of length k (given as 1d array, row vector or column vector)
    :param A: array of size k x n x m
    :return: k x n x m array such that the j-th n x m slice of A is multiplied with the j-th element of x

    If x is a scalar and A a single (2d) matrix, x * A is returned.
    """
    if jnp.isscalar(x) and A.ndim == 2:
        return x * A

    coeffs = jnp.atleast_2d(x)
    # x must be (at most) a row or column vector and A a stack of matrices
    assert coeffs.ndim <= 2 and A.ndim == 3
    n_rows, n_cols = coeffs.shape
    assert n_rows == 1 or n_cols == 1
    assert n_rows == A.shape[0] or n_cols == A.shape[0]

    # put the coefficients on the leading axis and let broadcasting scale the slices
    return coeffs.reshape(-1, 1, 1) * A
112
+
113
+
114
def gram_schmidt(A):
    """Orthogonalize a set of vectors stored as the columns of matrix A.

    Uses the modified Gram-Schmidt scheme: each column is made orthogonal to the already
    processed columns one after the other and is then normalized.

    :param A: matrix (any array-like) whose columns are the vectors to be orthonormalized.
        Integer or boolean input is promoted to floating point.
    :return: array of the same shape with orthonormal columns. Columns that are linearly
        dependent on their predecessors are not treated specially (they are divided by
        a vanishing norm).
    """
    A = jnp.asarray(A)
    if not jnp.issubdtype(A.dtype, jnp.inexact):
        # projections and normalization are not representable in integer arithmetic
        A = A.astype(float)

    basis = []
    for j in range(A.shape[1]):
        q = A[:, j]
        # To orthogonalize the vector in column j with respect to the
        # previous vectors, subtract from it its projection onto
        # each of the previous vectors.
        for b in basis:
            q = q - jnp.dot(b, q) * b
        basis.append(q / jnp.linalg.norm(q))

    # nothing to orthogonalize for a matrix without columns
    return jnp.stack(basis, axis=1) if basis else A
@@ -0,0 +1,15 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ from .tangent_layers import TangentMLP, TangentInvariant
14
+ from .wFM_layers import MfdFC, MfdInvariant
15
+ from .flow_layers import FlowLayer, MfdGcnBlock
@@ -0,0 +1,219 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ from typing import Iterable, Callable
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+ from jax.scipy.sparse.linalg import bicgstab
18
+
19
+ import jraph
20
+ from flax import linen as nn
21
+
22
+ from morphomatics.manifold import Manifold, PowerManifold
23
+ from morphomatics.opt import RiemannianNewtonRaphson
24
+ from morphomatics.graph.operators import mfdg_laplace
25
+ from morphomatics.nn.tangent_layers import TangentMLP
26
+
27
+
28
class FlowLayer(nn.Module):
    """
    Graph flow layer for graphs with manifold-valued features. The flow equation is integrated explicitly by default,
    but an implicit scheme is also available.

    See

    M. Hanik, G. Steidl, C. v. Tycowicz. "Manifold GCN: Diffusion-based Convolutional Neural Network for
    Manifold-valued Graphs" (https://arxiv.org/abs/2401.14381)

    for a detailed description.

    Inputs:
    :param M: manifold in which the features lie
    :param n_steps: number of Euler steps used to approximate the flow
    :param implicit: boolean indicating whether to use implicit or explicit Euler integration
    :param max_step_length: maximum step size for (explicit) Euler integration
    :param t_init: initializer of the per-channel parameter "t_sqrt" (its square is the diffusion time)
    :param delta_init: initializer of the per-channel parameter "delta_sqrt" (its square is the activation threshold)

    Note: Too long Euler steps can lead to numerical instabilities with some manifolds, e.g., the hyperbolic space.
    In this case, a maximal step length should be used.

    """

    M: Manifold
    n_steps: int = 1
    implicit: bool = False
    max_step_length: float = jnp.inf
    t_init: Callable = lambda *args: nn.initializers.truncated_normal(stddev=1.)(*args) + 1.
    delta_init: Callable = lambda *args: nn.initializers.truncated_normal(stddev=1.)(*args) + 1.

    def _single_euler_step(self, G: jraph.GraphsTuple, time: jnp.ndarray, delta: jnp.ndarray) -> jraph.GraphsTuple:
        """Single step of the explicit Euler method for diffusion

        :param G: graph with manifold valued vectors as features; length of vector must equal the flow layer width
        :param time: vector of time parameters (same length as number of features)
        :param delta: vector of "minimal step sizes"
        :return: updated graph
        """

        def laplace_of_channel(channel):
            # graph Laplacian of a single feature channel
            return mfdg_laplace(self.M, G._replace(nodes=channel))

        def damp(feature, vector, d):
            # length of the tangent vector (eps keeps the square root away from zero)
            length = jnp.sqrt(self.M.metric.inner(feature, vector, vector) + jnp.finfo(jnp.float64).eps)
            gate = jax.nn.sigmoid(length - d)

            # make sure that the step size is not larger than max_step_length
            return jax.lax.cond(length * gate <= self.max_step_length,
                                lambda w: gate * w,
                                lambda w: w * self.max_step_length / length,
                                vector)

        nodes = G.nodes
        direction = jax.vmap(laplace_of_channel, in_axes=1, out_axes=1)(nodes)

        # ReLU-type activation (one copy of the thresholds per node)
        thresholds = jnp.stack([delta, ] * direction.shape[0])
        direction = jax.vmap(jax.vmap(damp))(nodes, direction, thresholds)

        # scale channel-wise by the (negative) time and walk along the geodesics
        per_channel = (1, -1) + (1,) * (direction.ndim - 2)
        direction = -direction * time.reshape(per_channel)
        moved = jax.vmap(jax.vmap(self.M.connec.exp))(nodes, direction)
        return G._replace(nodes=moved)

    def _implicit_euler_step(self, G: jraph.GraphsTuple, time: jnp.ndarray, delta=None) -> jraph.GraphsTuple:
        """Single step of the implicit Euler method for diffusion

        The step is approximated by a single Newton-Raphson iteration whose linear system is solved with bicgstab.

        :param G: graph with manifold valued vectors as features; length of vector must equal the flow layer width
        :param time: vector of time parameters (same length as number of features)
        :param delta: unused; only present so that both step methods share one signature
        :return: updated graph
        """
        # n_nodes x n_channels x point_shape
        n, c, *shape = G.nodes.shape

        # power manifold
        P = PowerManifold(self.M, n * c)

        # current state
        x_cur = G.nodes.reshape(-1, *shape)

        # time parameters broadcastable against n_nodes x n_channels x point_shape
        time_per_channel = time.reshape((1, -1) + (1,) * len(shape))

        def laplace_of_channel(a):
            return mfdg_laplace(self.M, G._replace(nodes=a))

        # root of F characterizes solution to implicit Euler step
        def F(x: jnp.array):
            Lx = jax.vmap(laplace_of_channel, in_axes=1, out_axes=1)(x.reshape(n, c, *shape))
            tLx = Lx * time_per_channel
            return P.connec.log(x, x_cur) - tLx.reshape(-1, *shape)

        # Unrolled single iteration of Newton-Raphson (cf. RiemannianNewtonRaphson.solve with maxiter=1):
        # solve for update direction: v = -J⁻¹F(x)
        def J(v):
            _, tangent_out = jax.jvp(F, (x_cur,), (v,))
            return tangent_out

        v, _ = bicgstab(J, -F(x_cur))
        # step
        x_next = P.connec.exp(x_cur, v)

        return G._replace(nodes=x_next.reshape(n, c, *shape))

    @nn.compact
    def __call__(self, G: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """
        :param G: graphs tuple with features of shape: num_nodes * num_channels * point_shape
        :return: graphs tuple with features of shape: num_nodes * num_channels * point_shape

        Apply discretized diffusion flow (with final activation) to each channel.
        """
        if self.implicit:
            integrate = self._implicit_euler_step
        else:
            integrate = self._single_euler_step

        # number of channels
        width = G.nodes.shape[1]

        t_sqrt = self.param("t_sqrt", self.t_init, (width,), G.nodes.dtype)
        delta_sqrt = self.param("delta_sqrt", self.delta_init, (width,), G.nodes.dtype)

        # map to non-negative weights
        t = t_sqrt ** 2
        delta = delta_sqrt ** 2

        def advance(graph, _):
            # every step covers an equal share of the total time
            return integrate(graph, t / self.n_steps, delta), None

        G, _ = jax.lax.scan(advance, G, None, self.n_steps, unroll=self.n_steps)

        return G
152
+
153
+
154
class MfdGcnBlock(nn.Module):
    """
    Manifold convolution network block as proposed in
    M. Hanik, G. Steidl, C. v. Tycowicz. "Manifold GCN: Diffusion-based Convolutional Neural Network for
    Manifold-valued Graphs" (https://arxiv.org/abs/2401.14381)

    The block is a sequence of units, each consisting of a flow layer followed by a tangent MLP.

    An explicit skip connection can be added that attaches the input to the output as additional channels.
    Note: If skip connections between each flow-tMLP unit are wanted, use several blocks with only one layer.

    :param M: manifold constituting the signal domain
    :param channel_sizes: sequence of channel sizes
    :param n_steps: number of Euler steps that are performed in the flow layer
    :param implicit: boolean indicating whether to use implicit or explicit Euler integration
    :param max_step_length: maximum step size for Euler integration (see the flow layer)
    :param explicit_skip: boolean indicating whether additionally to perform an explicit skip connection
    :param inputs_are_copies: when true, only the first input channel is passed through by the skip connection
    """

    M: Manifold
    channel_sizes: Iterable[int]
    n_steps: int = 1
    implicit: bool = False
    max_step_length: float = jnp.inf
    explicit_skip: bool = False
    inputs_are_copies: bool = False

    def setup(self):
        # one (flow layer, tangent MLP) unit per requested channel size
        self.layers = tuple(
            (
                FlowLayer(self.M, self.n_steps, self.implicit, self.max_step_length),
                TangentMLP(self.M, (channel_size,))
            )
            for channel_size in tuple(self.channel_sizes)
        )

    def __call__(self, G: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """
        :param G: graphs tuple with features of shape: num_nodes * num_channels * point_shape
        :return: graphs tuple with features of shape: num_nodes * out_channels * point_shape

        We use jraph pooling; hence, the batches are combined in the same graph and thus "hidden" in num_nodes.
        The number of output channels is the number of output channels of the last (tangentMLP) layer plus, if
        activated, the number of channels that are passed through by the skip connection (either num_channels or 1).
        """

        # save input for the skip connection
        if self.explicit_skip:
            if self.inputs_are_copies:
                # All channels coincide, so keep the first one only (as a single-channel array).
                # Taking channel 0 also works for inputs that have just one channel.
                z = G.nodes[:, :1]
            else:
                z = G.nodes

        for flow, tangent_mlp in self.layers:
            # flow layer
            G = flow(G)
            # tangent MLP (expects a leading batch axis, which is added and stripped again)
            G = G._replace(nodes=tangent_mlp(G.nodes[None])[0])

        # skip connection: prepend the saved input along the channel axis
        if self.explicit_skip:
            G = G._replace(nodes=jax.lax.concatenate([z, G.nodes], 1))

        return G
@@ -0,0 +1,176 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ from typing import Sequence
14
+
15
+ import numpy as np
16
+ import jax
17
+ import jax.numpy as jnp
18
+
19
+ import flax.linen as nn
20
+
21
+ from morphomatics.manifold import Manifold
22
+
23
+
24
class TangentMLP(nn.Module):
    """
    Generalized multi-layer perceptron for manifold-valued features as proposed in
    M. Hanik, G. Steidl, C. v. Tycowicz. "Manifold GCN: Diffusion-based Convolutional Neural Network for
    Manifold-valued Graphs" (https://arxiv.org/abs/2401.14381).

    The first input channel serves as base point: the remaining channels are lifted to its tangent space,
    processed by a vector neuron MLP and mapped back to the manifold.

    :param M: Manifold input signal takes values in
    :param out_sizes: number of output feature channels (sequence thereof)
    """

    M: Manifold
    out_sizes: Sequence[int]

    @nn.compact
    def __call__(self, x):
        """
        Apply tangent MLP layer.
        :param x: input sequence with shape: batch * sequence_length * in_channel * M.point_shape
            (M being the underlying manifold)
        :return: output with shape: batch * sequence_length * out_channel * M.point_shape
        """

        n_batch, n_seq, n_in, *pt_shape = x.shape

        # flatten first two axes -> shape: (batch * sequence_length) * in_channel * M.point_shape
        pts = x.reshape(-1, n_in, *pt_shape)
        base, others = pts[:, 0], pts[:, 1:]

        # compute logs -> tangent vectors at the first channel
        log_per_node = jax.vmap(self.M.connec.log, in_axes=(None, 0))
        tangents = jax.vmap(log_per_node)(base, others)

        # shape into vectors
        coords = tangents.reshape(n_batch, n_seq, n_in - 1, -1)

        # apply vector neuron MLP
        coords = VectorNeuronMLP(self.out_sizes)(coords)

        # shape back into tangent vectors
        tangents = coords.reshape(n_batch * n_seq, -1, *pt_shape)

        # map back to manifold
        exp_per_node = jax.vmap(self.M.connec.exp, in_axes=(None, 0))
        y = jax.vmap(exp_per_node)(base, tangents)

        return y.reshape(n_batch, n_seq, -1, *pt_shape)
69
+
70
+
71
class TangentInvariant(nn.Module):
    """
    Invariant layer for manifold-valued features extending the TangentMLP layer.
    Specifically, computes inner products of the input (linearized via Log) with
    a set of tangent vectors obtained from them (using a vector neuron MLP).
    Finally, the products are passed through a fully connected layer to match desired output size.

    :param M: Manifold input signal takes values in
    :param out_channel: number of output feature channels
    :param vec_sizes: sequence of widths for the vector neuron MLP
    :param use_bias: whether to use bias in the final fully connected layer
    """

    M: Manifold
    out_channel: int
    vec_sizes: Sequence[int] = (3,)
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        """
        Apply invariant layer.
        :param x: input sequence with shape: batch * sequence_length * in_channel * M.point_shape
            (M being the underlying manifold)
        :return: output with shape: batch * sequence_length * out_channel
        """

        n_batch, n_seq, n_in, *pt_shape = x.shape

        # flatten first two axes -> shape: (batch * sequence_length) * in_channel * M.point_shape
        pts = x.reshape(-1, n_in, *pt_shape)
        base = pts[:, 0]

        # compute logs -> tangent vectors at the first channel
        log_batched = jax.vmap(self.M.connec.log, in_axes=(None, 0))
        v = jax.vmap(log_batched)(base, pts[:, 1:])

        # shape into vectors, apply vector neuron MLP, shape back into tangent vectors
        w = VectorNeuronMLP(self.vec_sizes)(v.reshape(n_batch, n_seq, n_in - 1, -1))
        w = w.reshape(n_batch * n_seq, -1, *pt_shape)

        # lower indices of tangent vectors (either v or w)
        flat_batched = jax.vmap(jax.vmap(self.M.metric.flat, in_axes=(None, 0)))
        if w.shape[1] > n_in:
            v = flat_batched(base, v)
        else:
            w = flat_batched(base, w)

        # compute inner products between all pairs of (flattened) vectors of v and w
        v_mat = v.reshape(*v.shape[:2], -1)
        w_mat = w.reshape(*w.shape[:2], -1)
        y = jnp.einsum('...ij,...kj', v_mat, w_mat)

        # fully connected layer on the collected products
        dense = nn.Dense(self.out_channel, use_bias=self.use_bias)
        return dense(y.reshape(n_batch, n_seq, -1))
127
+
128
+
129
class VectorNeuronMLP(nn.Module):
    """
    Vector Neuron MLP layer as described in
    Deng, C., Litany, O., Duan, Y., Poulenard, A., Tagliasacchi, A., & Guibas, L. J. (2021).
    Vector neurons: A general framework for SO(3)-equivariant networks.
    In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 12200-12209).

    :param output_sizes: sequence of layer widths (one entry per layer)
    :param negative_slope: weight of the linear part in the leaky ReLU
    """

    output_sizes: Sequence[int]
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x: jnp.array):
        """
        Apply layer.
        :param x: input sequence with shape: batch * sequence_length * in_channel * vector_dim
        :return: output with shape: batch * sequence_length * out_channel * vector_shape
        """

        eps = np.finfo(np.float64).eps
        slope = self.negative_slope
        n_in = x.shape[2]

        # apply layers
        for i, n_out in enumerate(self.output_sizes):
            # initialize weights (U: directions, W: features); both mix the channel axis
            u_init = nn.initializers.truncated_normal(stddev=np.sqrt(2/n_in))
            w_init = lambda *args: nn.initializers.truncated_normal(stddev=1. / np.sqrt(n_in))(*args) + 1. / n_in

            U = self.param(f"U_{i}", u_init, (n_out, n_in), x.dtype)
            W = self.param(f"W_{i}", w_init, (n_out, n_in), x.dtype)

            # width of this layer is the input width of the next one
            n_in = n_out

            # compute direction k (and its squared norm)
            k = jnp.einsum('ij,...jk', U, x)
            sqnrm_k = jnp.sum(k**2, axis=-1, keepdims=True) + eps

            # compute feature q
            q = jnp.einsum('ij,...jk', W, x)

            # Rotation-equivariant ReLU: remove the part of q that points against k
            dot_qk = jnp.sum(q * k, axis=-1, keepdims=True)
            rectified = q + k * jax.nn.relu(-dot_qk) / sqnrm_k
            # Leaky ReLU (weighted average of q and ReLU(q))
            x = slope * q + (1 - slope) * rectified

        return x