morphomatics 4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphomatics/__init__.py +13 -0
- morphomatics/geom/__init__.py +16 -0
- morphomatics/geom/bezier_spline.py +361 -0
- morphomatics/geom/misc.py +104 -0
- morphomatics/geom/surface.py +208 -0
- morphomatics/graph/__init__.py +13 -0
- morphomatics/graph/operators.py +124 -0
- morphomatics/manifold/__init__.py +46 -0
- morphomatics/manifold/bezierfold.py +500 -0
- morphomatics/manifold/connection.py +105 -0
- morphomatics/manifold/cubic_bezierfold.py +305 -0
- morphomatics/manifold/differential_coords.py +197 -0
- morphomatics/manifold/discrete_ops.py +56 -0
- morphomatics/manifold/euclidean.py +213 -0
- morphomatics/manifold/fundamental_coords.py +440 -0
- morphomatics/manifold/gl_p_coords.py +149 -0
- morphomatics/manifold/gl_p_n.py +201 -0
- morphomatics/manifold/grassmann.py +174 -0
- morphomatics/manifold/hyperbolic_space.py +271 -0
- morphomatics/manifold/kendall.py +269 -0
- morphomatics/manifold/lie_group.py +102 -0
- morphomatics/manifold/manifold.py +162 -0
- morphomatics/manifold/manopt_wrapper.py +185 -0
- morphomatics/manifold/metric.py +110 -0
- morphomatics/manifold/point_distribution_model.py +143 -0
- morphomatics/manifold/power_manifold.py +413 -0
- morphomatics/manifold/product_manifold.py +381 -0
- morphomatics/manifold/se_3.py +419 -0
- morphomatics/manifold/shape_space.py +57 -0
- morphomatics/manifold/so_3.py +494 -0
- morphomatics/manifold/spd.py +524 -0
- morphomatics/manifold/sphere.py +241 -0
- morphomatics/manifold/tangent_bundle.py +337 -0
- morphomatics/manifold/util.py +126 -0
- morphomatics/nn/__init__.py +15 -0
- morphomatics/nn/flow_layers.py +219 -0
- morphomatics/nn/tangent_layers.py +176 -0
- morphomatics/nn/train.py +202 -0
- morphomatics/nn/wFM_layers.py +152 -0
- morphomatics/opt/__init__.py +14 -0
- morphomatics/opt/riemannian_newton_raphson.py +65 -0
- morphomatics/opt/riemannian_steepest_descent.py +61 -0
- morphomatics/stats/__init__.py +18 -0
- morphomatics/stats/biinvariant_statistics.py +190 -0
- morphomatics/stats/exponential_barycenter.py +78 -0
- morphomatics/stats/geometric_median.py +89 -0
- morphomatics/stats/principal_geodesic_analysis.py +135 -0
- morphomatics/stats/riemannian_regression.py +317 -0
- morphomatics/stats/statistical_shape_model.py +99 -0
- morphomatics-4.0.dist-info/LICENSE +9 -0
- morphomatics-4.0.dist-info/METADATA +55 -0
- morphomatics-4.0.dist-info/RECORD +54 -0
- morphomatics-4.0.dist-info/WHEEL +5 -0
- morphomatics-4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import jax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
from morphomatics.geom.surface import Surface
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def align(src, ref):
    """ (Constrained) Procrustes alignment of src to ref using Kabsch algorithm

    Computes the proper rotation R (det R = +1; no reflection, no scaling) and the translation that
    best map src onto ref in the least-squares sense, and applies them to src.
    The function uses only jax operations (no Python branching on array values), so it can be jit-compiled.

    :arg src: n-by-d array of vertex coordinates of source object (typically d = 3)
    :arg ref: n-by-d array of vertex coordinates of reference (same ordering as src)
    :returns: aligned coords., i.e. src @ R + (c_r - c_s @ R) with c_s, c_r the centroids
    """
    n = len(ref)
    # centroids and cross-covariance matrix of the centered point sets
    c_s = src.mean(axis=0)
    c_r = ref.mean(axis=0)
    xCov = (ref.T @ src) / n - jnp.outer(c_r, c_s)
    # optimal rotation
    U, S, Vt = jnp.linalg.svd(xCov)
    # If V U^T is a reflection, flip the axis of the smallest singular value.
    # jnp.where (instead of a Python `if`) keeps this traceable under jax.jit,
    # and the correction vector adapts to any dimension d.
    flip = jnp.where(jnp.linalg.det(Vt.T @ U.T) < 0, -1., 1.)
    D = jnp.ones_like(S).at[-1].set(flip)
    R = (Vt.T * D) @ U.T  # == V diag(D) U^T
    # return aligned coords.
    return src @ R + (c_r - c_s @ R)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def generalized_procrustes(surf, tol=1e-11, max_iter=1000):
    """ Generalized Procrustes analysis.

    Alternates between rigidly aligning every surface to a reference (initially a copy of the vertex
    coordinates of surf[0]) and replacing the reference by the mean of the aligned vertex coordinates.
    At least one iteration is performed. Iteration stops once the reference changes by no more than
    `tol` (Frobenius norm), or after `max_iter` iterations.

    :arg surf: list of surfaces to be aligned. The meshes must be in correspondence. The vertex
        coordinates ``s.v`` of every surface are overwritten with the aligned ones; nothing is returned.
    :arg tol: convergence threshold on the change of the reference coordinates. NOTE: the default is
        very tight and may not be reachable in single precision; then all `max_iter` iterations run.
    :arg max_iter: maximum number of iterations (values smaller than 1 are treated as 1).
    """
    # only the reference vertex coordinates are needed
    ref_v = jnp.copy(surf[0].v)

    for _ in range(max(1, max_iter)):
        old_ref_v = ref_v
        # align meshes to reference
        for s in surf:
            s.v = align(s.v, ref_v)

        # compute new reference
        ref_v = jnp.mean(jnp.array([s.v for s in surf]), axis=0)

        # written as `not (change > tol)` so that a NaN change also stops the iteration
        if not jnp.linalg.norm(ref_v - old_ref_v) > tol:
            break
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def preshape(v):
    """ Map a point cloud to pre-shape space: move its centroid to the origin and scale it to unit (Frobenius) norm

    :arg v: n-by-3 array of vertex coordinates
    :returns: n-by-3 array of adjusted vertex coordinates
    """
    num_points = v.shape[0]
    # centroid replicated once per row
    centroid_rows = (1 / num_points) * jnp.tile(jnp.sum(v, axis=0), (num_points, 1))
    centered = v - centroid_rows
    return centered / jnp.linalg.norm(centered)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def multiprod(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray:
    """Matrix product over the last two axes, batched (with broadcasting) over all leading axes."""
    return jnp.einsum('...ij,...jk->...ik', A, B)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def multitransp(A):
    """Transpose the last two axes of A, batched over all leading axes."""
    return jnp.swapaxes(A, -1, -2)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def multiskew(A):
    """Skew-symmetric part (A - A^T) / 2 over the last two axes, batched over leading axes."""
    A_T = jnp.swapaxes(A, -1, -2)
    return 0.5 * (A - A_T)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def multisym(A):
    """Symmetric part (A + A^T) / 2 over the last two axes, batched over leading axes."""
    A_T = jnp.swapaxes(A, -1, -2)
    return 0.5 * (A + A_T)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def vectime3d(x, A):
    """
    Multiply every 2d slice of a 3d array with the corresponding entry of a vector.

    :param x: vector of length k (1d, or a 1-by-k / k-by-1 2d array); a scalar is accepted when k = 1
    :param A: array of size k x n x m
    :return: k x n x m array such that the j-th n x m slice of A is multiplied with the j-th element of x

    In case of k=1, x * A is returned. If x is a scalar and A is a 2d array, x * A is returned as well.
    """
    if jnp.isscalar(x) and A.ndim == 2:
        return x * A

    x = jnp.atleast_2d(x)
    assert x.ndim <= 2 and A.ndim == 3
    # x must be a row or column vector ...
    assert x.shape[0] == 1 or x.shape[1] == 1
    # ... with exactly one entry per slice of A (otherwise it would silently broadcast)
    assert x.size == A.shape[0]

    # scale the j-th slice by x[j]
    return x.reshape(-1, 1, 1) * A
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def gram_schmidt(A):
    """Orthonormalize the columns of matrix A (modified Gram-Schmidt).

    Columns are processed from left to right. The projections onto the already
    orthonormalized columns are removed one after the other, always from the
    partially updated vector, and the result is scaled to unit norm.

    :param A: 2d jax array whose columns are the vectors (presumably linearly independent;
        a column that becomes zero yields NaNs)
    :return: array of the same shape with orthonormalized columns
    """
    Q = A
    for col in range(Q.shape[1]):
        vec = Q[:, col]
        for prev in range(col):
            basis_vec = Q[:, prev]
            vec = vec - jnp.dot(basis_vec, vec) * basis_vec
        Q = Q.at[:, col].set(vec / jnp.linalg.norm(vec))
    return Q
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
from .tangent_layers import TangentMLP, TangentInvariant
|
|
14
|
+
from .wFM_layers import MfdFC, MfdInvariant
|
|
15
|
+
from .flow_layers import FlowLayer, MfdGcnBlock
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
from typing import Iterable, Callable
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
from jax.scipy.sparse.linalg import bicgstab
|
|
18
|
+
|
|
19
|
+
import jraph
|
|
20
|
+
from flax import linen as nn
|
|
21
|
+
|
|
22
|
+
from morphomatics.manifold import Manifold, PowerManifold
|
|
23
|
+
from morphomatics.opt import RiemannianNewtonRaphson
|
|
24
|
+
from morphomatics.graph.operators import mfdg_laplace
|
|
25
|
+
from morphomatics.nn.tangent_layers import TangentMLP
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class FlowLayer(nn.Module):
    """
    Graph flow layer for graphs with manifold-valued features. The flow equation is integrated explicitly by default,
    but an implicit scheme is also available.

    See

    M. Hanik, G. Steidl, C. v. Tycowicz. "Manifold GCN: Diffusion-based Convolutional Neural Network for
    Manifold-valued Graphs" (https://arxiv.org/abs/2401.14381)

    for a detailed description.

    Inputs:
    :param M: manifold in which the features lie
    :param n_steps: number of Euler steps (explicit or implicit) used to approximate the flow
    :param implicit: boolean indicating whether to use implicit or explicit Euler integration
    :param max_step_length: maximum step size for explicit Euler integration (not used by the implicit scheme)
    :param t_init: initializer of the per-channel parameter whose square is the diffusion time
    :param delta_init: initializer of the per-channel parameter whose square is the activation threshold
        (only used by the explicit scheme)

    Note: Too long Euler steps can lead to numerical instabilities with some manifolds, e.g., the hyperbolic space.
    In this case, a maximal step length should be used.

    """

    M: Manifold
    n_steps: int = 1
    implicit: bool = False
    max_step_length: float = jnp.inf
    t_init: Callable = lambda *args: nn.initializers.truncated_normal(stddev=1.)(*args) + 1.
    delta_init: Callable = lambda *args: nn.initializers.truncated_normal(stddev=1.)(*args) + 1.

    def _single_euler_step(self, G: jraph.GraphsTuple, time: jnp.ndarray, delta: jnp.ndarray) -> jraph.GraphsTuple:
        """Single step of the explicit Euler method for diffusion

        :param G: graph with manifold valued vectors as features; length of vector must equal the flow layer width
        :param time: vector of time parameters (same length as number of features)
        :param delta: vector of "minimal step sizes" (activation thresholds, one per channel)
        :return: updated graph
        """
        # Regularizer for the square root below. It is a Python float, which JAX treats as weakly typed,
        # so adding it keeps the feature dtype. A NumPy float64 scalar would promote float32 features to
        # float64 under jax_enable_x64, which can break the type-preserving carry of the lax.scan in __call__.
        eps = float(jnp.finfo(jnp.float64).eps)

        def _multi_laplace(channel):
            # graph Laplacian of a single feature channel
            return mfdg_laplace(self.M, G._replace(nodes=channel))

        def _activation(feature, vector, d):
            # gate the tangent vector by sigmoid(|vector| - d), |.| being the norm induced by M's metric
            nrm = jnp.sqrt(self.M.metric.inner(feature, vector, vector) + eps)
            alp = jax.nn.sigmoid(nrm - d)

            # make sure that the step size is not larger than max_step_length
            return jax.lax.cond(nrm * alp <= self.max_step_length,
                                lambda w: alp * w,
                                lambda w: w * self.max_step_length / nrm, vector)

        # Laplacian applied channel-wise (axis 1 of the node features)
        v = jax.vmap(_multi_laplace, in_axes=1, out_axes=1)(G.nodes)

        # ReLU-type activation (the same per-channel thresholds for every node)
        delta = jnp.stack([delta, ] * v.shape[0])
        v = jax.vmap(jax.vmap(_activation))(G.nodes, v, delta)

        # Euler step x <- exp_x(-time * v), with time broadcast along the channel axis
        v = -v * time.reshape((1, -1) + (1,) * (v.ndim - 2))
        x = jax.vmap(jax.vmap(self.M.connec.exp))(G.nodes, v)
        return G._replace(nodes=x)

    def _implicit_euler_step(self, G: jraph.GraphsTuple, time: jnp.ndarray, delta=None) -> jraph.GraphsTuple:
        """Single step of the implicit Euler method for diffusion

        :param G: graph with manifold valued vectors as features; length of vector must equal the flow layer width
        :param time: vector of time parameters (same length as number of features)
        :param delta: unused; only present so that both step methods share the same signature
        :return: updated graph
        """
        # n_nodes x n_channels x point_shape
        n, c, *shape = G.nodes.shape

        # power manifold
        P = PowerManifold(self.M, n * c)

        # current state
        x_cur = G.nodes.reshape(-1, *shape)

        # root of F characterizes solution to implicit Euler step
        def F(x: jnp.array):
            L = lambda a: mfdg_laplace(self.M, G._replace(nodes=a))
            Lx = jax.vmap(L, in_axes=1, out_axes=1)(x.reshape(n, c, *shape))
            tLx = Lx * time.reshape((1, -1) + (1,) * len(shape))
            diff = P.connec.log(x, x_cur)
            return diff - tLx.reshape(-1, *shape)

        # A single Newton iteration, i.e. RiemannianNewtonRaphson.solve(P, F, x_cur, maxiter=1), unrolled here:
        # solve for update direction: v = -J⁻¹F(x)
        J = lambda v: jax.jvp(F, (x_cur,), (v,))[1]
        v, _ = bicgstab(J, -F(x_cur))
        # step
        x_next = P.connec.exp(x_cur, v)

        return G._replace(nodes=x_next.reshape(n, c, *shape))

    @nn.compact
    def __call__(self, G: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """
        :param G: graphs tuple with features of shape: num_nodes * num_channels * point_shape
        :return: graphs tuple with features of shape: num_nodes * num_channels * point_shape

        Apply discretized diffusion flow (with final activation) to each channel.
        """
        step_method = self._implicit_euler_step if self.implicit else self._single_euler_step

        width = G.nodes.shape[1]  # number of channels
        ####################
        t = self.param("t_sqrt", self.t_init, (width,), G.nodes.dtype)
        delta = self.param("delta_sqrt", self.delta_init, (width,), G.nodes.dtype)

        # map to non-negative weights
        t = t ** 2
        delta = delta ** 2

        def step(graph, _):
            # the total time t is split evenly over the n_steps steps
            graph = step_method(graph, t / self.n_steps, delta)
            return graph, None

        G, _ = jax.lax.scan(step, G, None, self.n_steps, unroll=self.n_steps)

        return G
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class MfdGcnBlock(nn.Module):
    """
    Manifold convolution network block as proposed in
    M. Hanik, G. Steidl, C. v. Tycowicz. "Manifold GCN: Diffusion-based Convolutional Neural Network for
    Manifold-valued Graphs" (https://arxiv.org/abs/2401.14381)

    Each layer of the block is a flow layer followed by a tangent MLP. An explicit skip connection can be added;
    it attaches the input to the output as additional channels.
    Note: If skip connections between each flow-tMLP unit are wanted, use several blocks with only one layer.

    :param M: manifold constituting the signal domain
    :param channel_sizes: sequence of channel sizes (one flow-tMLP unit per entry)
    :param n_steps: number of Euler steps that are performed in the flow layer
    :param implicit: boolean indicating whether to use implicit or explicit Euler integration
    :param max_step_length: maximum step size for Euler integration (see the flow layer)
    :param explicit_skip: boolean indicating whether additionally to perform an explicit skip connection
    :param inputs_are_copies: when true, only the first input channel is passed through by the skip connection
    """

    M: Manifold
    channel_sizes: Iterable[int]
    n_steps: int = 1
    implicit: bool = False
    max_step_length: float = jnp.inf
    explicit_skip: bool = False
    inputs_are_copies: bool = False

    def setup(self):
        # tuple of (flow layer, tangent MLP) pairs, one per channel size
        layers = []
        channel_sizes = tuple(self.channel_sizes)
        for channel_size in channel_sizes:
            layers.append(
                (
                    FlowLayer(self.M, self.n_steps, self.implicit, self.max_step_length),
                    TangentMLP(self.M, (channel_size,))
                )
            )
        self.layers = tuple(layers)

    def __call__(self, G: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """
        :param G: graphs tuple with features of shape: num_nodes * num_channels * point_shape
        :return: graphs tuple with features of shape: num_nodes * out_channels * point_shape

        We use jraph pooling; hence, the batches are combined in the same graph and thus "hidden" in num_nodes.
        The number of output channels is the number of output channels of the last (tangentMLP) layer plus, if
        activated, the number of channels that are passed through by the skip connection (either num_channels or 1).
        """

        # save input for the skip connection
        if self.explicit_skip:
            if self.inputs_are_copies:
                # all input channels are identical: keep only the first one (also valid for a single channel)
                z = G.nodes[:, 0, None]
            else:
                z = G.nodes

        for layer_unit in self.layers:
            # flow layer
            G = layer_unit[0](G)
            # tangent MLP (expects a leading batch axis, hence [None] / [0])
            G = G._replace(nodes=layer_unit[1](G.nodes[None])[0])

        # skip connection: prepend the saved input channels
        if self.explicit_skip:
            G = G._replace(nodes=jax.lax.concatenate([z, G.nodes], 1))

        return G
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
from typing import Sequence
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import jax
|
|
17
|
+
import jax.numpy as jnp
|
|
18
|
+
|
|
19
|
+
import flax.linen as nn
|
|
20
|
+
|
|
21
|
+
from morphomatics.manifold import Manifold
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TangentMLP(nn.Module):
    """
    Generalized multi-layer perceptron for manifold-valued features as proposed in
    M. Hanik, G. Steidl, C. v. Tycowicz. "Manifold GCN: Diffusion-based Convolutional Neural Network for
    Manifold-valued Graphs" (https://arxiv.org/abs/2401.14381).

    The first input channel serves as base point. The remaining channels are mapped into its tangent space
    (logarithm), processed by a vector neuron MLP, and mapped back to the manifold (exponential).

    :param M: Manifold input signal takes values in
    :param out_sizes: number of output feature channels (sequence thereof)
    """

    M: Manifold
    out_sizes: Sequence[int]

    @nn.compact
    def __call__(self, x):
        """
        Apply tangent MLP layer.
        :param x: input sequence with shape: batch * sequence_length * in_channel * M.point_shape
            (M being the underlying manifold)
        :return: output with shape: batch * sequence_length * out_channel * M.point_shape,
            where out_channel is the last entry of out_sizes
        """
        n_batch, n_seq, n_in, *pt_shape = x.shape

        # merge batch and sequence axes -> (batch * sequence_length) * in_channel * M.point_shape
        points = x.reshape(-1, n_in, *pt_shape)
        base_pts = points[:, 0]
        other_pts = points[:, 1:]

        # log / exp at the base point, mapped over all channels and then over all merged items
        log_at_base = jax.vmap(jax.vmap(self.M.connec.log, in_axes=(None, 0)))
        exp_at_base = jax.vmap(jax.vmap(self.M.connec.exp, in_axes=(None, 0)))

        # tangent vectors, flattened to plain vectors for the vector neuron MLP
        flat_vecs = log_at_base(base_pts, other_pts).reshape(n_batch, n_seq, n_in - 1, -1)
        flat_vecs = VectorNeuronMLP(self.out_sizes)(flat_vecs)

        # restore tangent-vector shape and map back onto the manifold
        tangent_vecs = flat_vecs.reshape(n_batch * n_seq, -1, *pt_shape)
        mapped = exp_at_base(base_pts, tangent_vecs)

        return mapped.reshape(n_batch, n_seq, -1, *pt_shape)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class TangentInvariant(nn.Module):
    """
    Invariant layer for manifold-valued features extending the TangentMLP layer.
    Specifically, computes inner products of the input (linearized via Log) with
    a set of tangent vectors obtained from them (using TangentMLP).
    Finally, the products are passed through a fully connected layer to match desired output size.

    :param M: Manifold input signal takes values in
    :param out_channel: number of output feature channels
    :param vec_sizes: sequence of widths for TangentMLP
    :param use_bias: whether to use bias in the final fully connected layer
    """

    M: Manifold
    out_channel: int
    vec_sizes: Sequence[int] = (3,)
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        """
        Apply tangent invariant layer.
        :param x: input sequence with shape: batch * sequence_length * in_channel * M.point_shape
            (M being the underlying manifold)
        :return: output with shape: batch * sequence_length * out_channel
        """

        n_batch, n_seq, n_in, *pt_shape = x.shape

        # flatten first two axes -> shape: (batch * sequence_length) * in_channel * M.point_shape
        x = x.reshape(-1, n_in, *pt_shape)

        # compute logs at the first channel -> n_in - 1 tangent vectors per item
        log_batched = jax.vmap(self.M.connec.log, in_axes=(None, 0))
        v = jax.vmap(log_batched)(x[:, 0], x[:, 1:])

        # shape into vectors
        w = v.reshape(n_batch, n_seq, n_in - 1, -1)

        # apply vector neuron MLP
        w = VectorNeuronMLP(self.vec_sizes)(w)

        # shape back into tangent vectors
        w = w.reshape(n_batch * n_seq, -1, *pt_shape)

        # Lower the indices of whichever set (v or w) has fewer vectors; the inner products are the same
        # either way. v holds n_in - 1 vectors, so compare against v.shape[1] (not n_in).
        flat_batched = jax.vmap(self.M.metric.flat, in_axes=(None, 0))
        if w.shape[1] > v.shape[1]:
            v = jax.vmap(flat_batched)(x[:, 0], v)
        else:
            w = jax.vmap(flat_batched)(x[:, 0], w)
        # compute inner products -> (batch * sequence_length) * (n_in - 1) * vec_sizes[-1]
        y = jnp.einsum('...ij,...kj', v.reshape(*v.shape[:2], -1), w.reshape(*w.shape[:2], -1))

        f = nn.Dense(self.out_channel, use_bias=self.use_bias)
        return f(y.reshape(n_batch, n_seq, -1))
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class VectorNeuronMLP(nn.Module):
    """
    Vector Neuron MLP layer as described in
    Deng, C., Litany, O., Duan, Y., Poulenard, A., Tagliasacchi, A., & Guibas, L. J. (2021).
    Vector neurons: A general framework for SO(3)-equivariant networks.
    In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 12200-12209).

    :param output_sizes: sequence of layer widths (one vector-neuron layer per entry)
    :param negative_slope: leaky-ReLU mixing weight; every layer outputs
        negative_slope * q + (1 - negative_slope) * ReLU(q)
    """

    output_sizes: Sequence[int]
    negative_slope: float = 0.2

    @nn.compact
    def __call__(self, x: jnp.array):
        """
        Apply layer.
        :param x: input sequence with shape: batch * sequence_length * in_channel * vector_dim
        :return: output with shape: batch * sequence_length * out_channel * vector_dim,
            where out_channel is the last entry of output_sizes
        """
        # Regularizer for the division below. It is a Python float, which JAX treats as weakly typed, so it
        # keeps the dtype of x. A NumPy float64 scalar would promote float32 inputs to float64 under jax_enable_x64.
        eps = float(np.finfo(np.float64).eps)

        n_in = x.shape[2]

        # apply layers
        for i, n_out in enumerate(self.output_sizes):
            # initialize weights; the scales depend on the current fan-in n_in and are bound eagerly
            # (as Python floats) so the initializers do not depend on later values of n_in
            u_std = float(np.sqrt(2 / n_in))
            w_std = float(1. / np.sqrt(n_in))
            w_shift = 1. / n_in
            u_init = nn.initializers.truncated_normal(stddev=u_std)
            w_init = lambda *args, _std=w_std, _shift=w_shift: \
                nn.initializers.truncated_normal(stddev=_std)(*args) + _shift

            U = self.param(f"U_{i}", u_init, (n_out, n_in), x.dtype)
            W = self.param(f"W_{i}", w_init, (n_out, n_in), x.dtype)

            n_in = n_out

            # compute direction k (and its squared norm); channels are mixed along axis -2
            k = jnp.einsum('ij,...jk', U, x)
            sqnrm_k = jnp.sum(k**2, axis=-1, keepdims=True) + eps

            # compute feature q
            q = jnp.einsum('ij,...jk', W, x)

            # Rotation-equivariant ReLU: where <q, k> < 0, remove the component of q along k
            dot_qk = jnp.sum(q * k, axis=-1, keepdims=True)
            x = q + k * jax.nn.relu(-dot_qk) / sqnrm_k
            # Leaky ReLU (weighted average of q and ReLU(q))
            x = self.negative_slope * q + (1 - self.negative_slope) * x

        return x
|