morphomatics 4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphomatics/__init__.py +13 -0
- morphomatics/geom/__init__.py +16 -0
- morphomatics/geom/bezier_spline.py +361 -0
- morphomatics/geom/misc.py +104 -0
- morphomatics/geom/surface.py +208 -0
- morphomatics/graph/__init__.py +13 -0
- morphomatics/graph/operators.py +124 -0
- morphomatics/manifold/__init__.py +46 -0
- morphomatics/manifold/bezierfold.py +500 -0
- morphomatics/manifold/connection.py +105 -0
- morphomatics/manifold/cubic_bezierfold.py +305 -0
- morphomatics/manifold/differential_coords.py +197 -0
- morphomatics/manifold/discrete_ops.py +56 -0
- morphomatics/manifold/euclidean.py +213 -0
- morphomatics/manifold/fundamental_coords.py +440 -0
- morphomatics/manifold/gl_p_coords.py +149 -0
- morphomatics/manifold/gl_p_n.py +201 -0
- morphomatics/manifold/grassmann.py +174 -0
- morphomatics/manifold/hyperbolic_space.py +271 -0
- morphomatics/manifold/kendall.py +269 -0
- morphomatics/manifold/lie_group.py +102 -0
- morphomatics/manifold/manifold.py +162 -0
- morphomatics/manifold/manopt_wrapper.py +185 -0
- morphomatics/manifold/metric.py +110 -0
- morphomatics/manifold/point_distribution_model.py +143 -0
- morphomatics/manifold/power_manifold.py +413 -0
- morphomatics/manifold/product_manifold.py +381 -0
- morphomatics/manifold/se_3.py +419 -0
- morphomatics/manifold/shape_space.py +57 -0
- morphomatics/manifold/so_3.py +494 -0
- morphomatics/manifold/spd.py +524 -0
- morphomatics/manifold/sphere.py +241 -0
- morphomatics/manifold/tangent_bundle.py +337 -0
- morphomatics/manifold/util.py +126 -0
- morphomatics/nn/__init__.py +15 -0
- morphomatics/nn/flow_layers.py +219 -0
- morphomatics/nn/tangent_layers.py +176 -0
- morphomatics/nn/train.py +202 -0
- morphomatics/nn/wFM_layers.py +152 -0
- morphomatics/opt/__init__.py +14 -0
- morphomatics/opt/riemannian_newton_raphson.py +65 -0
- morphomatics/opt/riemannian_steepest_descent.py +61 -0
- morphomatics/stats/__init__.py +18 -0
- morphomatics/stats/biinvariant_statistics.py +190 -0
- morphomatics/stats/exponential_barycenter.py +78 -0
- morphomatics/stats/geometric_median.py +89 -0
- morphomatics/stats/principal_geodesic_analysis.py +135 -0
- morphomatics/stats/riemannian_regression.py +317 -0
- morphomatics/stats/statistical_shape_model.py +99 -0
- morphomatics-4.0.dist-info/LICENSE +9 -0
- morphomatics-4.0.dist-info/METADATA +55 -0
- morphomatics-4.0.dist-info/RECORD +54 -0
- morphomatics-4.0.dist-info/WHEEL +5 -0
- morphomatics-4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,500 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
from functools import partial
|
|
15
|
+
|
|
16
|
+
from typing import Tuple
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
import jax
|
|
21
|
+
import jax.numpy as jnp
|
|
22
|
+
|
|
23
|
+
from morphomatics.geom.bezier_spline import BezierSpline, full_set, indep_set
|
|
24
|
+
from morphomatics.manifold import Manifold, Metric, PowerManifold
|
|
25
|
+
from morphomatics.opt import RiemannianSteepestDescent, RiemannianNewtonRaphson
|
|
26
|
+
from morphomatics.stats import ExponentialBarycenter
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Bezierfold(Manifold):
    """Manifold of Bézier splines (of fixed degrees).

    Elements are handled via their independent control points, i.e. arrays of shape
    ``(K + 1, *M.point_shape)`` where ``M`` is the base manifold (see :attr:`K`).
    """

    def __init__(self, M: Manifold, n_segments: int, degree: int, isscycle: bool=False,
                 n_steps: int=10, n_samples: int | None=None, structure='FunctionalBased'):
        """Manifold of Bézier splines of constant segment degree

        :arg M: base manifold in which the curves lie
        :arg n_segments: number of spline segments
        :arg degree: degree of segment (same for each one)
        :arg isscycle: boolean indicating whether the splines are closed
            (the parameter name is spelled with a double 's'; kept for backward compatibility)
        :arg n_steps: number of steps (i.e. segments) for approximation of geodesics in Bezierfold
        :arg n_samples: number of samples for quadrature of curve distance in L²(I, M);
            defaults to K+1 if not given (or 0)
        :arg structure: type of geometric structure; the method ``init<structure>Structure`` is
            invoked unless this is None/empty
        :raises ValueError: if the number of samples does not exceed K
        """

        self._M = M
        self._degrees = np.full(n_segments, degree)
        self._nsteps = n_steps

        if isscycle:
            name = 'Manifold of closed Bézier splines of degree {d} through '.format(d=degree) + str(M)
            K = np.sum(self._degrees - 1) - 1
        else:
            name = 'Manifold of non-closed Bézier splines of degrees {d} through '.format(d=degree) + str(M)
            K = np.sum(self._degrees - 1) + 1

        self._nsamples = n_samples if n_samples else K + 1
        # explicit check instead of assert s.t. validation survives `python -O`
        if self._nsamples <= K:
            raise ValueError(f'n_samples must be larger than K={K} (got {self._nsamples}).')

        dimension = (K + 1) * M.dim
        point_shape = (K + 1, *M.point_shape)
        self._K = K
        super().__init__(name, dimension, point_shape)

        self._iscycle = isscycle

        if structure:
            # dispatch by name, e.g. 'FunctionalBased' -> initFunctionalBasedStructure()
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration.

        Extends the result of the base class: the base manifold is appended to the children,
        the constructor arguments (number of segments, degree, closedness, number of steps and
        samples) are appended to the auxiliary data.
        """
        children, aux = super().tree_flatten()
        aux += (self.nsegments, self.degrees[0], self.iscycle, self.nsteps, self.nsamples)
        return children + (self.M,), aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        # undo what tree_flatten appended (last child / last five aux entries)
        *children, M = children
        *aux_data, n_seg, d, c, n_st, n_sam = aux_data
        # structure=None: skip structure setup in the constructor; the remaining state is
        # handed to tree_unflatten_instance
        obj = cls(M, n_seg, d, c, n_st, n_sam, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initFunctionalBasedStructure(self):
        """
        Instantiate functional-based structure with discrete methods.
        """
        # the same object serves as metric and as connection
        structure = Bezierfold.FunctionalBasedStructure(self)
        self._metric = structure
        self._connec = structure

    @property
    def M(self) -> Manifold:
        """Return the underlying manifold
        """
        return self._M

    @property
    def degrees(self) -> np.ndarray:
        """Return vector of segment degrees
        """
        return self._degrees

    @property
    def nsegments(self) -> int:
        """Returns the number of spline segments."""
        return len(self._degrees)

    @property
    def K(self) -> int:
        """Return the generalized degree of a Bezier spline, i.e., the number of independent control points - 1
        """
        return self._K

    @property
    def iscycle(self) -> bool:
        """Return whether the Bezierfold consists of non-closed or closed splines
        """
        return self._iscycle

    @property
    def nsamples(self) -> int:
        """Returns the number of samples for quadrature of curve distance in L²(I, M)."""
        return self._nsamples

    @property
    def nsteps(self) -> int:
        """Returns the number of steps (i.e. segments) for approximation of geodesics in Bezierfold"""
        return self._nsteps

    def correct_type(self, B: BezierSpline) -> bool:
        """Check whether B has the right segment degrees

        :param B: Bézier spline
        :return: True iff B has the same number of segments as prescribed by this manifold
            and all segment degrees agree
        """
        degrees = jnp.atleast_1d(B.degrees)
        # A differing number of segments must yield False; the elementwise comparison below
        # would broadcast (or fail for incompatible shapes) instead.
        if degrees.shape != self.degrees.shape:
            return False
        return bool(jnp.all(degrees == self.degrees))

    def rand(self, key: jax.Array) -> BezierSpline:
        """Return random Bézier spline

        :param key: PRNG key; split into one sub-key per independent control point
        :return: spline whose independent control points are drawn via M.rand
        """
        subkeys = jax.random.split(key, self.K + 1)
        # go through from_coords s.t. the spline is flagged as closed/non-closed consistently
        return self.from_coords(jax.vmap(self.M.rand)(subkeys))

    def randvec(self, B: BezierSpline, key: jax.Array) -> jax.Array:
        """Return random vector for every independent control point

        :param B: Bézier spline
        :param key: PRNG key; split into one sub-key per independent control point
        :return: array of tangent vectors (drawn via M.randvec), one per independent control point
        """
        # NOTE(review): to_coords() passes B.control_points to indep_set whereas B itself is passed
        # here -- confirm that indep_set accepts both.
        pts = indep_set(B, self.iscycle)
        subkeys = jax.random.split(key, len(pts))
        return jax.vmap(self.M.randvec)(pts, subkeys)

    def zerovec(self) -> jax.Array:
        """Return zero vector for every independent control point"""
        # replicate the zero vector of M along a new leading axis of length K+1
        return jnp.tile(self.M.zerovec(), (self.K + 1,) + (1,)*len(self.M.point_shape))

    def to_coords(self, B: BezierSpline) -> jax.Array:
        """
        :param B: Bézier spline
        :return: Array of independent control points.
        """
        return indep_set(B.control_points, self.iscycle)

    def from_coords(self, pts: jax.Array) -> BezierSpline:
        """
        :param pts: independent control points
        :return: Bézier spline
        """
        pts = full_set(self.M, pts, self.degrees, self.iscycle)
        return BezierSpline(self.M, pts, self.iscycle)

    def proj(self, X, H):
        """Returns H unchanged (no projection is applied)."""
        return H

    ############################## Functional-based structure ##############################
    class FunctionalBasedStructure(Metric):
        """
        Functional-based metric structure
        """

        def __init__(self, Bf: Bezierfold):
            """
            Constructor.

            :param Bf: Bezierfold this structure belongs to
            """
            self._Bf = Bf

        # plain method (not a property): str(.) requires __str__ to be callable
        def __str__(self):
            return "Bézierfold-functional-based structure"

        def inner(self, p: jax.Array, X: jax.Array, Y: jax.Array):
            """Functional-based metric, i.e. L²(I, TBM).

            :arg p: Bézier spline in M
            :arg X: tangent vector at p
            :arg Y: tangent vector at p
            :return: inner product of X and Y at p
            """

            M, deg, cyclic = self._Bf.M, self._Bf.degrees, self._Bf.iscycle

            def full(q, V):
                """Map independent control points q and tangent vector V to the full set of control points."""
                f = lambda pts: jnp.array(full_set(M, pts, deg, cyclic))
                # fwd-diff. of full_set
                q_full, V_full = jax.jvp(f, (q,), (V,))
                # proj. to tangent space
                vproj = jax.vmap(jax.vmap(M.proj))
                return q_full, vproj(q_full, V_full)

            # map p, X, Y to all control points
            p_full, X_full = full(p, X)
            _, Y_full = full(p, Y)

            # sample spline and generalized Jacobi fields for X, Y
            t = jnp.linspace(0., self._Bf.nsegments, self._Bf.nsamples)
            spln = BezierSpline(M, p_full, cyclic)
            vDpB = jax.vmap(spln.DpB, (0, None))
            B, Jx = vDpB(t, X_full)
            _, Jy = vDpB(t, Y_full)

            # eval inner products (summed over all samples)
            return jax.vmap(M.metric.inner)(B, Jx, Jy).sum()

        @property
        def typicaldist(self) -> float:
            """K times the typical distance of the base manifold."""
            # approximations via control points
            return self._Bf.K * self._Bf.M.metric.typicaldist

        def dist(self, a: jax.Array, b: jax.Array) -> float:
            """Approximate the distance between two Bézier splines

            :param a: independent control points of a Bézier spline
            :param b: independent control points of a Bézier spline
            :return: length of n-geodesic between A and B (approximation of the distance)
            """
            return jnp.sqrt(self.squared_dist(a, b))

        def squared_dist_extrinsic(self, p, q):
            """Sum of squared distances (in M) between corresponding samples of two splines.

            :param p: independent control points of a Bézier spline
            :param q: independent control points of a Bézier spline
            :return: sum over nsamples equidistant parameters in [0, nsegments] of the squared distances
            """
            t = jnp.linspace(0., self._Bf.nsegments, self._Bf.nsamples)
            d2 = jax.vmap(self._Bf.M.metric.squared_dist)
            return d2(sample(self._Bf, p, t), sample(self._Bf, q, t)).sum()

        def squared_dist(self, p, q):
            """Approximate the squared distance between two Bézier splines.

            :param p: independent control points of a Bézier spline
            :param q: independent control points of a Bézier spline
            :return: n times the sum of the extrinsic squared distances of consecutive nodes of the
                discrete n-geodesic (n = nsteps) between p and q
            """
            n = self._Bf.nsteps
            gamma = self.discgeodesic(self._Bf, p, q, n=n)
            return jax.vmap(self.squared_dist_extrinsic)(gamma[:-1], gamma[1:]).sum() * n

        @staticmethod
        @partial(jax.jit, static_argnames=['Bf'])
        def discexp(Bf, a: jax.Array, b: jax.Array):
            """
            Compute c such that [a,b,c] is a discrete 2-geodesic.
            :param Bf: Bezierfold a and b live in
            :param a: Bézier spline in manifold M (i.e. independent control points thereof)
            :param b: Bézier spline in manifold M (i.e. independent control points thereof)
            :return: c
            """

            t = jnp.linspace(0., Bf.nsegments, Bf.nsamples)

            # initial guess for c: extrapolate geodesics through corresponding control points
            c = jax.vmap(Bf.M.connec.geopoint, (0, 0, None))(a, b, 2.)

            # gradient of sum-of-squared-distances between samples along alpha and beta w.r.t. ctrl. pts. of alpha
            def G(alpha, beta):
                egrad = jax.grad(lambda x: jax.vmap(Bf.M.metric.squared_dist)(sample(Bf, x, t), sample(Bf, beta, t)).sum())
                return jax.vmap(Bf.M.metric.egrad2rgrad)(alpha, egrad(alpha))

            # gradient for b w.r.t. a
            G_a = G(b, a)

            # discrete Euler-Lagrange cnd. of path energy for [a,b,c]
            def F(x):
                return G(b, x) + G_a

            # solve F(x) = 0
            N = PowerManifold(Bf.M, Bf.K+1)
            return RiemannianNewtonRaphson.solve(N, F, c, stepsize=.1, maxiter=min(Bf.dim, 1000))

        def exp(self, p: jax.Array, X: jax.Array) -> jax.Array:
            """Discrete exponential map: shoot a discrete n-geodesic (n = nsteps) from p in direction X.

            :param p: Bézier spline in manifold M (i.e. independent control points thereof)
            :param X: tangent vector (i.e. tangent vectors at the independent control points)
            :return: (independent control points of) the end node of the discrete geodesic
            """
            n = self._Bf.nsteps

            def body(carry, _):
                a, b = carry
                # compute c s.t. [a,b,c] is discrete 2-geodesic
                c = self.discexp(self._Bf, a, b)
                return (b, c), None

            # first node: step of length 1/n (matches log(), which scales the first leg by n)
            q = jax.vmap(self._Bf.M.connec.exp)(p, X/n)
            # the first node is already in place, hence n-1 extrapolation steps reach the n-th node
            (_, q), _ = jax.lax.scan(body, (p, q), None, length=n - 1)

            return q

        def log(self, p: jax.Array, q: jax.Array) -> jax.Array:
            """Discrete logarithmic map.

            :param p: Bézier spline in manifold M (i.e. independent control points thereof)
            :param q: Bézier spline in manifold M (i.e. independent control points thereof)
            :return: n times the control-point-wise logarithm (in M) from p to the first inner node of
                the discrete n-geodesic (n = nsteps) between p and q
            """
            n = self._Bf.nsteps
            gamma = self.discgeodesic(self._Bf, p, q, n=n)
            return jax.vmap(self._Bf.M.connec.log)(p, gamma[1]) * n

        @staticmethod
        @partial(jax.jit, static_argnames=['Bf', 'n'])
        def discgeodesic(Bf: Bezierfold, p: jax.Array, q: jax.Array, n: int = 5, maxiter: int = 100, minchange: float = 1e-6) -> jax.Array:
            """Discrete shortest path through space of Bézier splines.

            :param Bf: Bezierfold p and q live in
            :param p: Bézier spline in manifold M (i.e. independent control points thereof)
            :param q: Bézier spline in manifold M (i.e. independent control points thereof)
            :param n: create discrete n-geodesic
            :param maxiter: max. number of iterations
            :param minchange: min. change in coordinates to declare convergence
            :return: control points of the Bézier splines along the shortest path
            """

            # Initialize inner splines of path

            # logs between corresponding control points of A and B (save repeated computations)
            X = jax.vmap(Bf.M.connec.log)(p, q)
            # exps
            t_exp = lambda t: jax.vmap(Bf.M.connec.exp)(p, t * X)
            H = jax.vmap(t_exp)(jnp.linspace(0., 1., n + 1)[1:-1])
            # add start-/endpt.
            H = jnp.concatenate((jnp.expand_dims(p, axis=0), H, jnp.expand_dims(q, axis=0)))

            # Discrete path shortening flow
            def body(args):
                x, _, i = args
                x, d = curve_shortening_step(Bf, x)
                # jax.debug.print("{}: {}", i, d)
                return x, d, i + 1

            # check convergence: continue while last update was large enough and budget is left
            def cond(args):
                _, d, i = args
                c = jnp.array([d > minchange, i < maxiter])
                return jnp.all(c)

            H, *_ = jax.lax.while_loop(cond, body, (H, jnp.array(1.), jnp.array(0)))

            return H

        @staticmethod
        #@partial(jax.jit, static_argnames=['Bf'])
        def mean(Bf, B, maxiter: int = 500, minchange: float = 1e-5):
            """Discrete mean of a set of Bézier splines

            :param Bf: Bezierfold
            :param B: array of splines (i.e. independent control points thereof)
            :param maxiter: max. number of iterations
            :param minchange: min. change in coordinates to declare convergence
            :return: (independent control points of) mean curve, and the discrete paths
                from the mean to the input curves
            """
            # times at which to sample splines
            t = jnp.linspace(0, Bf.nsegments, Bf.nsamples)

            # setup 'regression' problem for mean (where there are len(B) targets for each time pt.)

            # search space: k-fold product of M
            N = PowerManifold(Bf.M, Bf.K+1)

            # sum-of-squared-distances (averaged over the two leading axes of Y)
            def ssd(pts, Y, param):
                x = sample(Bf, pts, param)
                d = jax.vmap(jax.vmap(Bf.M.metric.squared_dist), (None, 0))(x, Y)
                return jnp.sum(d) / np.prod(Y.shape[:2])

            # compute mean spline

            # initialize i-th control point of the mean as the mean of the i-th control points of the data
            mean = lambda b: ExponentialBarycenter.compute(Bf.M, b)
            init = jax.vmap(mean, 1)(B)

            # init legs, i.e. n-geodesics between mean and input curves B
            discgeodesic = Bezierfold.FunctionalBasedStructure.discgeodesic
            F_init = jax.vmap(discgeodesic, (None, None, 0, None))(Bf, init, B, Bf.nsteps)

            def body(args):
                x, F, change, i = args

                # update x via regression (targets: samples of the first inner node of each leg)
                Y = jax.vmap(sample, (None, 0, None))(Bf, F[:, 1], t)
                opt = RiemannianSteepestDescent.fixedpoint(N, lambda a: ssd(a, Y, t), x)
                change = jnp.abs(opt - x).max()
                #change = jnp.linalg.norm((opt - x).ravel(), np.inf)

                # update legs of 'polygonal spider'
                F = F.at[:, 0].set(opt)
                F, d = jax.vmap(curve_shortening_step, (None, 0))(Bf, F)
                change = jnp.array([change, jnp.abs(d).max()]).max()
                #change = jnp.array([change, jnp.linalg.norm(d.ravel(), np.inf)]).max()

                jax.debug.print("{}: {}", i, change)
                return opt, F, change, i + 1

            def cond(args):
                _, _, change, i = args
                c = jnp.array([change > minchange, i < maxiter])
                return jnp.all(c)

            mu, F_mu, *_ = jax.lax.while_loop(cond, body, (init, F_init, 1., 0))

            return mu, F_mu

        def gram(self, B_mean: jax.Array, F: jax.Array):
            """Approximates the Gram matrix for a curve data set.

            :param B_mean: mean of curves in B (as returned by #mean)
            :param F: discrete spider, i.e, discrete paths from mean to data (as returned by #mean)
            :return G: Gram matrix
            """
            n = len(F)
            d2 = self.squared_dist_extrinsic
            # distances from the mean to the first inner node of every leg (shared by all entries)
            d_mean = [d2(B_mean, s[1]) for s in F]

            G = jnp.zeros((n, n))
            for i, si in enumerate(F):
                for j, sj in enumerate(F[i:], start=i):
                    # NOTE(review): the factor uses n = len(F) (number of legs); squared_dist() scales by
                    # the number of steps instead -- confirm the intended scaling.
                    g = n / 2 * (d_mean[i] + d_mean[j] - d2(si[1], sj[1]))
                    # fill upper and lower triangle at once (G is symmetric)
                    G = G.at[i, j].set(g).at[j, i].set(g)

            return G

        def egrad2rgrad(self, p: jax.Array, X: jax.Array) -> jax.Array:
            """Convert gradient control-point-wise via egrad2rgrad of the base manifold.

            :param p: Bézier spline in manifold M (i.e. independent control points thereof)
            :param X: tangent vector (i.e. tangent vectors at the independent control points)
            """
            return jax.vmap(self._Bf.M.metric.egrad2rgrad)(p, X)

        def retr(self, R, X):
            """Retraction; delegates to exp."""
            return self.exp(R, X)

        ### not implemented ###

        def ehess2rhess(self, p, G, H, X):
            """Converts the Euclidean gradient P_G and Hessian H of a function at
            a point p along a tangent vector X to the Riemannian Hessian
            along X on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def curvature_tensor(self, p, X, Y, Z):
            raise NotImplementedError('This function has not been implemented yet.')

        def transp(self, R, Q, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def jacobiField(self, R, Q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, R, Q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def flat(self, p, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def sharp(self, p, dX):
            raise NotImplementedError('This function has not been implemented yet.')
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def sample(Bf: Bezierfold, pts: jnp.array, t: jnp.array) -> jnp.array:
    """Evaluate the spline given by independent control points at several parameters.

    :param Bf: Bezierfold used to turn the control points into a spline (via from_coords)
    :param pts: independent control points of a Bézier spline
    :param t: array of parameters; evaluation is vectorized over its leading axis
    :return: stacked evaluations of the spline, one per entry of t
    """

    def eval_at(ctrl_pts, time):
        # assemble spline from its independent control points and evaluate it
        return Bf.from_coords(ctrl_pts).eval(time)

    # control points are shared among all evaluations, only the parameter is mapped over
    return jax.vmap(eval_at, in_axes=(None, 0))(pts, t)
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def curve_shortening_step(Bf: Bezierfold, x: jnp.array) -> Tuple[jnp.array, float]:
    """One sweep of the discrete curve shortening flow.

    Every inner node of the path is re-fitted (single regression iteration) to samples of its
    predecessor and successor, i.e. replaced by an average of its neighbours (s.t. it's the
    midpoint of the connecting 2-geodesic). Nodes are processed in order and an updated node
    serves as predecessor of the next one; first and last node are kept.

    :param Bf: Bezierfold
    :param x: Discrete path in Bf (i.e. independent control points of nodes)
    :return: updated nodes, inf-norm of update
    """
    # imported here (not at module level) to avoid cyclic dependencies
    from morphomatics.stats.riemannian_regression import RiemannianRegression

    n_segments = Bf.nsegments
    degree = Bf.degrees[0]

    # sampling parameters; duplicated since samples of two neighbours are fitted at once
    times = jnp.linspace(0., n_segments, Bf.nsamples)
    times_twice = jnp.concatenate([times, times])

    def relax_node(state, pair):
        previous, max_change = state
        current, following = pair
        # targets: samples of both neighbours
        targets = jnp.concatenate([sample(Bf, previous, times), sample(Bf, following, times)])
        # fit current node to the targets
        fitted = RiemannianRegression.fit(Bf.M, targets, times_twice, current, degree, n_segments,
                                          maxiter=1, iscycle=Bf.iscycle)
        # keep track of the largest coordinate change so far
        largest = jnp.array([max_change, jnp.abs(fitted - current).max()]).max()
        return (fitted, largest), fitted

    # pair every inner node with its successor
    pairs = jnp.stack([x[1:-1], x[2:]], axis=1)

    # sweep over the inner nodes, starting with the first node as predecessor
    (_, d), interior = jax.lax.scan(relax_node, (x[0], 0.), pairs)

    return jnp.concatenate((x[None, 0], interior, x[None, -1])), d
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
# postponed evaluation of annotations (PEP 563) to circumvent cyclic dependencies
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import abc
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
|
|
20
|
+
class Connection(metaclass=abc.ABCMeta):
    """
    Abstract base class: template for a connection on the tangent bundle of a manifold.

    Subclasses provide exp, retr, log, transp, curvature_tensor, jacobiField and __str__;
    geopoint, dxgeo and dygeo are derived from those.
    """

    def __init__(self, M: Manifold):
        """ Set up the connection.
        :param M: manifold the connection is defined on
        """
        self._M = M

    @abc.abstractmethod
    def __str__(self):
        """String representation of the concrete connection."""

    @abc.abstractmethod
    def exp(self, p, X):
        """Exponential map of the connection: maps the tangent vector X at p to the manifold.
        """

    @abc.abstractmethod
    def retr(self, p, X):
        """Retraction: maps a vector X from the tangent space at p to the manifold.
        """

    @abc.abstractmethod
    def log(self, p, q):
        """Logarithmic map of the connection at p evaluated at q.
        """

    def geopoint(self, p, q, t):
        """Point at time t on the geodesic from p to q.
        """
        # scale the direction from p towards q by t, then shoot
        direction = self.log(p, q)
        return self.exp(p, t * direction)

    @abc.abstractmethod
    def transp(self, p, q, X):
        """Vector transport: moves a vector X from the tangent space at p
        to the tangent space at q.
        """

    @abc.abstractmethod
    def curvature_tensor(self, p, X, Y, Z):
        """Curvature tensor R of the connection at p applied to the vectors X, Y, Z. The convention
            R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
        is used, where nabla_X Y is the covariant derivative of Y in direction X and [] is the Lie bracket.
        """

    @abc.abstractmethod
    def jacobiField(self, p, q, t, X):
        """
        Jacobi field along the geodesic gam from p to q, determined by the boundary conditions
        X at the start (time 0) and 0 at the end (time 1).
        :param p: element of the Riemannian manifold
        :param q: element of the Riemannian manifold
        :param t: scalar in [0,1]
        :param X: tangent vector at p
        :return: [b, J] where J is the Jacobi field at t and b the corresponding basepoint
        """

    def dxgeo(self, p, q, t, X):
        """Differential of the geodesic gam from p to q w.r.t. the starting point p, applied to X,
        i.e, d_p gamma(t; ., q)(X); the result is an element of the tangent space at gam(t).
        """
        # second entry of the jacobiField result is the field itself
        field = self.jacobiField(p, q, t, X)
        return field[1]

    def dygeo(self, p, q, t, X):
        """Differential of the geodesic gam from p to q w.r.t. the end point q, applied to X,
        i.e, d_q gamma(t; p, .)(X); the result is an element of the tangent space at gam(t).
        """
        # same as dxgeo for the reversed geodesic (from q to p) at the mirrored time
        field = self.jacobiField(q, p, 1 - t, X)
        return field[1]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _eval_jacobi_embed(C: Connection, p, q, t, X):
|
|
98
|
+
""" Implementation of eval_jacobi for isometrically embedded manifolds using (forward-mode) automatic
|
|
99
|
+
differentiation of geopoint(..).
|
|
100
|
+
|
|
101
|
+
ATTENTION: the result must be projected to the tangent space!
|
|
102
|
+
"""
|
|
103
|
+
f = lambda O: C.geopoint(O, q, t)
|
|
104
|
+
|
|
105
|
+
return jax.jvp(f, (p,), (X,))
|