morphomatics 4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphomatics/__init__.py +13 -0
- morphomatics/geom/__init__.py +16 -0
- morphomatics/geom/bezier_spline.py +361 -0
- morphomatics/geom/misc.py +104 -0
- morphomatics/geom/surface.py +208 -0
- morphomatics/graph/__init__.py +13 -0
- morphomatics/graph/operators.py +124 -0
- morphomatics/manifold/__init__.py +46 -0
- morphomatics/manifold/bezierfold.py +500 -0
- morphomatics/manifold/connection.py +105 -0
- morphomatics/manifold/cubic_bezierfold.py +305 -0
- morphomatics/manifold/differential_coords.py +197 -0
- morphomatics/manifold/discrete_ops.py +56 -0
- morphomatics/manifold/euclidean.py +213 -0
- morphomatics/manifold/fundamental_coords.py +440 -0
- morphomatics/manifold/gl_p_coords.py +149 -0
- morphomatics/manifold/gl_p_n.py +201 -0
- morphomatics/manifold/grassmann.py +174 -0
- morphomatics/manifold/hyperbolic_space.py +271 -0
- morphomatics/manifold/kendall.py +269 -0
- morphomatics/manifold/lie_group.py +102 -0
- morphomatics/manifold/manifold.py +162 -0
- morphomatics/manifold/manopt_wrapper.py +185 -0
- morphomatics/manifold/metric.py +110 -0
- morphomatics/manifold/point_distribution_model.py +143 -0
- morphomatics/manifold/power_manifold.py +413 -0
- morphomatics/manifold/product_manifold.py +381 -0
- morphomatics/manifold/se_3.py +419 -0
- morphomatics/manifold/shape_space.py +57 -0
- morphomatics/manifold/so_3.py +494 -0
- morphomatics/manifold/spd.py +524 -0
- morphomatics/manifold/sphere.py +241 -0
- morphomatics/manifold/tangent_bundle.py +337 -0
- morphomatics/manifold/util.py +126 -0
- morphomatics/nn/__init__.py +15 -0
- morphomatics/nn/flow_layers.py +219 -0
- morphomatics/nn/tangent_layers.py +176 -0
- morphomatics/nn/train.py +202 -0
- morphomatics/nn/wFM_layers.py +152 -0
- morphomatics/opt/__init__.py +14 -0
- morphomatics/opt/riemannian_newton_raphson.py +65 -0
- morphomatics/opt/riemannian_steepest_descent.py +61 -0
- morphomatics/stats/__init__.py +18 -0
- morphomatics/stats/biinvariant_statistics.py +190 -0
- morphomatics/stats/exponential_barycenter.py +78 -0
- morphomatics/stats/geometric_median.py +89 -0
- morphomatics/stats/principal_geodesic_analysis.py +135 -0
- morphomatics/stats/riemannian_regression.py +317 -0
- morphomatics/stats/statistical_shape_model.py +99 -0
- morphomatics-4.0.dist-info/LICENSE +9 -0
- morphomatics-4.0.dist-info/METADATA +55 -0
- morphomatics-4.0.dist-info/RECORD +54 -0
- morphomatics-4.0.dist-info/WHEEL +5 -0
- morphomatics-4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
|
|
18
|
+
from morphomatics.manifold import Manifold, Metric
|
|
19
|
+
from morphomatics.manifold.metric import _eval_adjJacobi_embed
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class HyperbolicSpace(Manifold):
    """n-dimensional hyperbolic space represented by the hyperboloid model. Elements are represented as (row) vectors of
    length n + 1 whose last entry is the 'time-like' coordinate, i.e. points x satisfy minkowski_inner(x, x) = -1
    with x[-1] > 0.
    """

    def __init__(self, point_shape=(3,), structure='Canonical'):
        """
        :param point_shape: shape of the embedding vectors, i.e. (n + 1,) for n-dimensional hyperbolic space.
        :param structure: name of the structure to set up (calls ``init<structure>Structure``); a falsy value skips
            the set-up (used when unflattening PyTrees).
        """
        name = 'Points in ' +\
            'x'.join(map(str, point_shape)) + '-dim. space ' + 'for which the Minkowski quadratic form is -1.'
        dimension = np.prod(point_shape)-1
        super().__init__(name, dimension, point_shape)
        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration; the point shape is appended to the aux data."""
        children, aux = super().tree_flatten()
        return children, aux+(self.point_shape,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *aux_data, shape = aux_data
        obj = cls(shape, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initCanonicalStructure(self):
        """
        Instantiate hyperbolic space with canonical structure.
        """
        structure = HyperbolicSpace.CanonicalStructure(self)
        self._metric = structure
        self._connec = structure

    @staticmethod
    def minkowski_inner(x, y):
        """Minkowski semi-Riemannian metric on R^(n+1). The hyperboloid model of n-dimensional
        hyperbolic space consists of all points x with minkowski_inner(x, x) = -1. The tangent space at x consists of
        all vectors v in R^(n+1) with minkowski_inner(x, v) = 0."""
        return -x[-1] * y[-1] + (x[:-1] * y[:-1]).sum()

    @staticmethod
    def project_to_manifold(x):
        """Projection onto the hyperboloid: keeps the first n coordinates and recomputes the last one such that
        minkowski_inner(x, x) = -1 (upper sheet)."""
        return x.at[-1].set(jnp.sqrt(1 + jnp.sum(x[:-1] ** 2)))

    @staticmethod
    def regularize(p):
        """Regularize/project points that are slightly off the hyperboloid to avoid numerical problems. Only apply when
        minkowski_inner(p, p) < 0."""
        return p / jnp.sqrt(jnp.abs(HyperbolicSpace.minkowski_inner(p, p)) + np.finfo(np.float64).eps)

    def rand(self, key: jax.Array):
        """Random point: Gaussian sample lifted onto the hyperboloid."""
        p = jax.random.normal(key, self.point_shape)
        return self.project_to_manifold(p)

    def randvec(self, p, key: jax.Array):
        """Random tangent vector at p: Gaussian sample projected onto the tangent space at p."""
        H = jax.random.normal(key, self.point_shape)
        # same formula as proj(): removes the component Minkowski-normal to the hyperboloid
        return H + HyperbolicSpace.minkowski_inner(p, H) * p

    def zerovec(self):
        """Zero tangent vector."""
        return jnp.zeros(self.point_shape)

    def pole(self):
        """The point (0, ..., 0, 1) of the hyperboloid."""
        o = jnp.zeros(self.point_shape)
        o = o.at[-1].set(1)
        return o

    def proj(self, p, H):
        """Minkowski-orthogonal projection of an ambient vector H onto the tangent space at p
        (uses minkowski_inner(p, p) = -1)."""
        return H + HyperbolicSpace.minkowski_inner(p, H) * p

    class CanonicalStructure(Metric):
        """
        The Riemannian metric used is the metric induced by the Minkowski semi-Riemannian metric of the embedding
        space R^(n+1); restricted to the tangent spaces of the hyperboloid it is positive definite.
        """

        def __init__(self, M):
            """
            Constructor.

            :param M: the HyperbolicSpace instance this structure belongs to.
            """
            self._M = M

        def __str__(self):
            return "Canonical structure"

        @property
        def typicaldist(self):
            return jnp.sqrt(self._M.dim)

        def inner(self, p, X, Y):
            """Inner product of tangent vectors X, Y (independent of the base point p)."""
            return HyperbolicSpace.minkowski_inner(X, Y)

        @property
        def metric_matrix(self):
            """Matrix representation of the Minkowski metric"""
            G = jnp.eye(self._M.dim + 1)
            G = G.at[-1, -1].set(-1)
            return G

        def norm(self, p, X):
            # inner can be smaller than 0 due to cancelling of digits when 2 similar numbers are subtracted
            return jnp.sqrt(jax.nn.relu(self.inner(p, X, X)))

        def dist(self, p, q):
            return jnp.sqrt(self.squared_dist(p, q))

        def squared_dist(self, p, q):
            """Squared geodesic distance arccosh(-<p, q>)^2; a Taylor expansion is used for nearby points."""
            # in theory minkowski_inner(p, q) <= -1 always holds, but for numerical reasons we clip
            mink_inner_neg = jnp.maximum(-HyperbolicSpace.minkowski_inner(p, q), 1.)

            def trunc_dist_sq(x):
                # 4th order approximation of arccosh**2 around 1
                return 2*(x-1) - 1/3*(x-1)**2 + 4/45*(x-1)**3

            d2 = jax.lax.cond(mink_inner_neg > 1+1e-6, lambda i: jnp.acosh(i)**2, trunc_dist_sq, mink_inner_neg)

            return jax.nn.relu(d2)

        def egrad2rgrad(self, p, H):
            """Convert a Euclidean gradient H into the Riemannian gradient: raise index with the Minkowski metric,
            then project onto the tangent space at p."""
            H = H.at[-1].set(-H[-1])
            return self._M.proj(p, H)

        def ehess2rhess(self, p, G, H, X):
            """Converts the Euclidean gradient P_G and Hessian H of a function at
            a point p along a tangent vector X to the Riemannian Hessian
            along X on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def exp(self, p, X):
            """Riemannian exponential cosh(|X|) p + sinh(|X|) X/|X|; a Taylor expansion is used for short X."""
            # numerical safeguard
            X = self._M.proj(p, X)

            def full_exp(n2):
                n = jnp.sqrt(n2)
                return jnp.cosh(n) * p + jnp.sinh(n) * X/n

            def trunc_exp(n2):
                # 6th order approximation of cosh(n)*p + sinh(n)*X/n around 0
                return p+X + 1/6 * n2 * (3*p+X) + 1/120 * n2**2 * (5*p+X)

            sqnorm_X = self.inner(p, X, X)
            sqnorm_X = jnp.maximum(sqnorm_X, 0.)

            p = jax.lax.cond(sqnorm_X < 1e-6, trunc_exp, full_exp, sqnorm_X)

            return HyperbolicSpace.project_to_manifold(p)

        def retr(self, p, X):
            return self.exp(p, X)

        def log(self, p, q):
            """Riemannian logarithm d/sinh(d) (q - cosh(d) p); a Taylor expansion is used for nearby points."""
            sqd = self.squared_dist(p, q)

            def full_log(d2):
                # For the formula see Eq. (13) in Pennec, "Hessian of the Riemannian squared distance", HAL INRIA, 2017.
                d = jnp.sqrt(d2)
                return d/jnp.sinh(d) * (q - jnp.cosh(d) * p)

            def trunc_log(d2):
                # see also Pennec, "Hessian of the Riemannian squared distance", HAL INRIA, 2017.
                return (1 - d2/6) * (q - (1 + d2/2 + d2**2/24) * p)

            v = jax.lax.cond(sqd < 1e-6, trunc_log, full_log, sqd)

            return self._M.proj(p, v)

        def geopoint(self, p, q, t):
            """Evaluate gam(t;p,q) = exp_p(t*log_p(q))."""
            dist = self.dist(p, q)

            def full_geopoint(d):
                r = q / jnp.sinh(d) - 1/jnp.tanh(d) * p  # log_p(q) / d
                return jnp.cosh(t*d) * p + jnp.sinh(t*d) * r

            # for (numerically) coinciding points fall back to linear interpolation in the embedding space
            return jax.lax.cond(dist > 1e-6, full_geopoint, lambda _: (1-t)*p + t*q, dist)

        def flat(self, p, X):
            """Lower vector with the metric"""
            dX = X.at[-1].set(-X[-1])
            return dX

        def sharp(self, p, dX):
            """Raise covector with the metric"""
            X = dX.at[-1].set(-dX[-1])
            return X

        def curvature_tensor(self, p, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at p on the vectors X, Y, Z. With nabla_X Y denoting the
            covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            # constant sectional curvature -1
            return -1 * (HyperbolicSpace.minkowski_inner(Z, Y) * X - HyperbolicSpace.minkowski_inner(Z, X) * Y)

        def transp(self, p, q, X):
            """Parallel transport of X from p to q along the connecting geodesic."""
            sqd = self.squared_dist(p, q)

            def do_transp(X):
                # X - <log_p q, X>/d^2 * (log_p q + log_q p); log_p q is only evaluated once
                v = self.log(p, q)
                sum_l = v + self.log(q, p)
                return X - self.inner(p, v, X)/sqd * sum_l

            return jax.lax.cond(sqd < 1e-6, lambda X: X, do_transp, X)

        def jacobiField(self, p, q, t, X):
            """Not implemented for hyperbolic space.

            :raises NotImplementedError: always.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def _adjJacobi(self, p, q, t, w):
            # alternative version to adjJacobi relying on automatic differentiation
            return self._M.proj(p, _eval_adjJacobi_embed(self, p, q, t, w))

        def adjJacobi(self, p, q, t, w):
            """Evaluate an adjoint jacobi field.

            The decomposition of the curvature operator and the fact that only two of its eigenvectors are necessary is
            used in the algorithm.

            :param p: element of the hyperboloid
            :param q: element of the hyperboloid
            :param t: scalar in [0,1]
            :param w: tangent vector at gamma(t;p,q)
            :returns V: tangent vector at p
            """
            dist = self.dist(p, q)

            def _eval(dW):
                d, W = dW
                # all computations can be done at p, so only w has to be parallel translated
                W = self.transp(self.geopoint(p, q, t), p, W)

                # first eigenvector is normalized tangent of the geodesic -> corresponding eigenvalue is 0
                T = self.log(p, q) / d

                # second eigenvector is Gram Schmidt orthonormalization of W against T -> corresponding eigenvalue is -1
                b1 = self.inner(p, W, T)
                U = W - b1 * T
                # Check whether W is (numerically) parallel to T;
                # then, the adjoint Jacobi field is only a scaling of the parallel transported tangent.
                U = jax.lax.cond(jnp.linalg.norm(U) > 1e-6,
                                 lambda v: v / self.norm(p, v),
                                 lambda v: jnp.zeros_like(v), U)

                b2 = self.inner(p, W, U)

                a1 = 1-t  # corresponds to the eigenvalue 0
                a2 = jnp.sinh((1-t)*d) / jnp.sinh(d)  # corresponds to the eigenvalue -1

                return a1 * b1 * T + a2 * b2 * U

            # For p ~ q parallel transport is the identity and both coefficients a1 = 1-t and
            # a2 = sinh((1-t)d)/sinh(d) tend to (1-t), so the field reduces to (1-t) * w
            # (the previous 1/(1-t) was discontinuous and divided by zero at t = 1).
            return jax.lax.cond(dist > 1e-6, _eval, lambda args: (1 - t) * args[-1], (dist, w))
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from typing import Sequence
|
|
15
|
+
|
|
16
|
+
import jax
|
|
17
|
+
import jax.numpy as jnp
|
|
18
|
+
|
|
19
|
+
from morphomatics.manifold import ShapeSpace, Metric, Sphere
|
|
20
|
+
from morphomatics.manifold.discrete_ops import pole_ladder
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Kendall(ShapeSpace):
    """
    Kendall's shape space: (SO_m)-equivalence classes of preshape points (projection of centered landmarks onto the sphere)
    """

    def __init__(self, shape: Sequence[int], structure='Canonical'):
        """
        :param shape: shape of a landmark configuration; the last entry m is the dimension of the ambient space R^m,
            the leading entries index the landmarks.
        :param structure: name of the structure to set up (calls ``init<structure>Structure``); a falsy value skips
            the set-up (used when unflattening PyTrees).
        :raises TypeError: if ``shape`` is empty.
        """
        if len(shape) == 0:
            raise TypeError("Need shape parameters.")

        # Pre-Shape space (sphere)
        self._S = Sphere(shape)
        m = shape[-1]
        # Quotient out the m translational degrees of freedom (configurations are centered, see project()) and the
        # m(m-1)/2 rotational ones (SO_m-equivalence), e.g. planar triangles yield a 2-dim. shape space.
        # NOTE(review): assumes Sphere(shape).dim == prod(shape) - 1.
        dimension = int(self._S.dim - m - m * (m - 1) / 2)

        self.ref = None

        name = 'Kendall shape space of ' + 'x'.join(map(str, shape[:-1])) + ' Landmarks in R^' + str(shape[-1])
        super().__init__(name, dimension, shape)
        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration; the point shape is appended to the aux data."""
        children, aux = super().tree_flatten()
        return children, aux + (self.point_shape,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *aux_data, shape = aux_data
        obj = cls(shape, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def update_ref_geom(self, v):
        """Set the reference shape from landmark coordinates v."""
        self.ref = self.to_coords(v)

    def to_coords(self, v):
        '''
        :arg v: array of landmark coordinates
        :return: manifold coordinates
        '''
        return Kendall.project(v)

    def from_coords(self, c):
        '''
        :arg c: manifold coords.
        :returns: array of landmark coordinates
        '''
        return c

    @property
    def ref_coords(self):
        """ :returns: Coordinates of reference shape """
        return self.ref

    def rand(self, key: jax.Array):
        """Random shape: Gaussian landmark sample projected to the preshape sphere."""
        p = jax.random.normal(key, self.point_shape)
        return Kendall.project(p)

    def randvec(self, p, key: jax.Array):
        """Random horizontal tangent vector at p."""
        v = jax.random.normal(key, self.point_shape)
        return Kendall.horizontal(p, self._S.proj(p, v))

    def zerovec(self):
        """Zero tangent vector."""
        return jnp.zeros(self.point_shape)

    @staticmethod
    def wellpos(x, y):
        """
        Rotate y such that it aligns to x.
        :param x: (centered) reference landmark configuration.
        :param y: (centered) landmarks to be aligned.
        :returns: y well-positioned to x.
        """
        return y @ Kendall.opt_rot(x, y)

    @staticmethod
    def opt_rot(x, y):
        """
        Compute the rotation that optimally aligns y to x (Kabsch algorithm, reflections excluded).
        :param x: (centered) reference landmark configuration.
        :param y: (centered) landmarks to be aligned.
        :returns: rotation R such that y*R is well-positioned to x.
        """
        m = x.shape[-1]
        sigma = jnp.ones(m)
        # full_matrices=False equals full_matrices=True for quadratic input but allows for auto diff
        u, _, v = jnp.linalg.svd(x.reshape(-1, m).T @ y.reshape(-1, m), full_matrices=False)
        # flip the last singular direction if needed so that det(R) = +1
        sigma = sigma.at[-1].set(jnp.sign(jnp.linalg.det(u @ v)))
        # R = v^T diag(sigma) u^T
        return jnp.einsum('ji,j,kj', v, sigma, u)

    @staticmethod
    def center(x):
        """
        Remove mean from x.
        """
        mean = x.reshape(-1, x.shape[-1]).mean(axis=0)
        return x - mean

    @staticmethod
    def project(x):
        """
        Project to pre-shape space.
        :param x: Point to project.
        :returns: Projected x.
        """
        x = Kendall.center(x)
        return x / jnp.linalg.norm(x)

    def proj(self, p, X):
        """ Project a vector X from the ambient Euclidean space onto the tangent space at p. """
        # TODO: think about naming convention.
        return Kendall.horizontal(p, self._S.proj(p, X))

    @staticmethod
    def vertical(p, X):
        """
        Compute vertical component of X at base point p by solving the sylvester equation
        App^T+pp^TA = Xp^T-pX^T for A. If p has full rank (det(pp^T) > 0), then there exists a unique solution
        A, it is skew-symmetric and Ap is the vertical component of X
        """
        d = p.shape[-1]
        S = p.reshape(-1, d).T @ p.reshape(-1, d)
        rhs = X.reshape(-1, d).T @ p.reshape(-1, d)
        rhs = rhs.T - rhs
        # vectorized Sylvester operator A -> A S + S A
        S = jnp.kron(jnp.eye(d), S) + jnp.kron(S, jnp.eye(d))
        A, *_ = jnp.linalg.lstsq(S, rhs.reshape(-1))
        return jnp.einsum('...i,ij', p, A.reshape(d, d))

    @staticmethod
    def horizontal(p, X):
        """
        compute horizontal component of X.
        """
        X = Kendall.center(X)
        return X - Kendall.vertical(p, X)

    def initCanonicalStructure(self):
        """
        Instantiate the preshape sphere with canonical structure.
        """
        structure = Kendall.CanonicalStructure(self)
        self._metric = structure
        self._connec = structure

    class CanonicalStructure(Metric):
        """
        Canonical (quotient) structure: the metric is inherited from the preshape sphere, i.e. from the Frobenius
        inner product of the embedding space of landmark configurations; distances and logarithms are evaluated after
        well-positioning (optimally rotating) the second argument to the first.
        """

        def __init__(self, M):
            """
            Constructor.

            :param M: the Kendall shape space this structure belongs to.
            """
            self._M = M
            self._S = M._S

        def __str__(self):
            return "canonical structure"

        @property
        def typicaldist(self):
            return np.pi/2

        def inner(self, p, X, Y):
            return self._S.metric.inner(p, X, Y)

        def norm(self, p, X):
            return self._S.metric.norm(p, X)

        def flat(self, p, X):
            """Lower vector X at p with the metric"""
            return self._S.metric.flat(p, X)

        def sharp(self, p, dX):
            """Raise covector dX at p with the metric"""
            return self._S.metric.sharp(p, dX)

        def egrad2rgrad(self, p, X):
            return self._M.proj(p, X)

        def ehess2rhess(self, p, G, H, X):
            """ Convert the Euclidean gradient P_G and Hessian H of a function at
            a point p along a tangent vector X to the Riemannian Hessian
            along X on the manifold.
            """
            Y = self._S.metric.ehess2rhess(p, G, H, X)
            return Kendall.horizontal(p, Y)

        def exp(self, p, X):
            return self._S.connec.exp(p, X)

        retr = exp

        def log(self, p, q):
            """Logarithm at p of the shape q: sphere logarithm towards the representative of q well-positioned to p."""
            q = Kendall.wellpos(p, q)
            return self._S.connec.log(p, q)

        def curvature_tensor(self, p, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at p on the vectors X, Y, Z. With nabla_X Y denoting
            the covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def geopoint(self, p, q, t):
            return self.exp(p, t * self.log(p, q))

        def transp(self, p, q, X):
            # pole ladder transports along horizontal geodesic, thus, map (p, X) to the well-positioned representative
            R = Kendall.opt_rot(q, p)
            return pole_ladder(self._M, p @ R, q, X @ R, 10)

        def pairmean(self, p, q):
            return self.geopoint(p, q, .5)

        def dist(self, p, q):
            q = Kendall.wellpos(p, q)
            return self._S.metric.dist(p, q)

        def squared_dist(self, p, q):
            q = Kendall.wellpos(p, q)
            return self._S.metric.squared_dist(p, q)

        def jacobiField(self, p, q, t, X):
            """Not implemented for Kendall's shape space.

            :raises NotImplementedError: always.
            """
            # TODO: evaluate the Jacobi field of the preshape sphere along the geodesic from p to the representative of
            #  q well-positioned to p, and take the horizontal part of the result.
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, p, q, t, X):
            """Not implemented for Kendall's shape space.

            :raises NotImplementedError: always.
            """
            # TODO: candidate approach is the horizontal projection of the preshape sphere's adjoint Jacobi field.
            raise NotImplementedError('This function has not been implemented yet.')
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
# postponed evaluation of annotations to circumvent cyclic dependencies (will be default behavior in Python 4.0)
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import abc
|
|
17
|
+
|
|
18
|
+
class LieGroup(metaclass=abc.ABCMeta):
    """
    Interface setting out a template for Lie group classes.
    """

    def __init__(self, M: Manifold):
        """ Construct Lie group structure.
        :param M: underlying manifold
        """
        self._M = M

    @abc.abstractmethod
    def __str__(self):
        """Returns a string representation of the particular group."""

    # TODO: the group operation (composition g*f) is not part of the interface yet.

    @property
    @abc.abstractmethod
    def identity(self):
        """Returns the identity element e of the Lie group."""

    @abc.abstractmethod
    def coords(self, X):
        """Coordinate map for the tangent space at the identity.

        :param X: element of the Lie algebra (tangent vector at e).
        """

    @abc.abstractmethod
    def coords_inverse(self, X):
        """Inverse coordinate map for the tangent space at the identity.

        :param X: coordinates of a Lie algebra element.
        """

    @abc.abstractmethod
    def bracket(self, X, Y):
        """Lie bracket in Lie algebra.

        :param X: Lie algebra element.
        :param Y: Lie algebra element.
        """

    @abc.abstractmethod
    def lefttrans(self, g, f):
        """Left translation of g by f.

        :param g: group element to be translated.
        :param f: translating group element.
        """

    @abc.abstractmethod
    def righttrans(self, g, f):
        """Right translation of g by f.

        :param g: group element to be translated.
        :param f: translating group element.
        """

    @abc.abstractmethod
    def inverse(self, g):
        """Inverse map of the Lie group.

        :param g: group element.
        """

    @abc.abstractmethod
    def exp(self, X):
        """Computes the Lie-theoretic exponential map of a tangent vector X at e.

        :param X: Lie algebra element.
        """

    @abc.abstractmethod
    def log(self, g):
        """Computes the Lie-theoretic logarithm of g. This is the inverse of `exp`.

        :param g: group element.
        """

    @abc.abstractmethod
    def dleft(self, f, X):
        """Derivative of the left translation by f at e applied to the tangent vector X.

        :param f: group element.
        :param X: tangent vector at e.
        """

    @abc.abstractmethod
    def dright(self, f, X):
        """Derivative of the right translation by f at e applied to the tangent vector X.

        :param f: group element.
        :param X: tangent vector at e.
        """

    @abc.abstractmethod
    def dleft_inv(self, f, X):
        """Derivative of the left translation by f^{-1} at f applied to the tangent vector X.

        :param f: group element.
        :param X: tangent vector at f.
        """

    @abc.abstractmethod
    def dright_inv(self, f, X):
        """Derivative of the right translation by f^{-1} at f applied to the tangent vector X.

        :param f: group element.
        :param X: tangent vector at f.
        """

    @abc.abstractmethod
    def adjrep(self, g, X):
        """Adjoint representation of g applied to the tangent vector X at the identity.

        :param g: group element.
        :param X: Lie algebra element.
        """
|