morphomatics 4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. morphomatics/__init__.py +13 -0
  2. morphomatics/geom/__init__.py +16 -0
  3. morphomatics/geom/bezier_spline.py +361 -0
  4. morphomatics/geom/misc.py +104 -0
  5. morphomatics/geom/surface.py +208 -0
  6. morphomatics/graph/__init__.py +13 -0
  7. morphomatics/graph/operators.py +124 -0
  8. morphomatics/manifold/__init__.py +46 -0
  9. morphomatics/manifold/bezierfold.py +500 -0
  10. morphomatics/manifold/connection.py +105 -0
  11. morphomatics/manifold/cubic_bezierfold.py +305 -0
  12. morphomatics/manifold/differential_coords.py +197 -0
  13. morphomatics/manifold/discrete_ops.py +56 -0
  14. morphomatics/manifold/euclidean.py +213 -0
  15. morphomatics/manifold/fundamental_coords.py +440 -0
  16. morphomatics/manifold/gl_p_coords.py +149 -0
  17. morphomatics/manifold/gl_p_n.py +201 -0
  18. morphomatics/manifold/grassmann.py +174 -0
  19. morphomatics/manifold/hyperbolic_space.py +271 -0
  20. morphomatics/manifold/kendall.py +269 -0
  21. morphomatics/manifold/lie_group.py +102 -0
  22. morphomatics/manifold/manifold.py +162 -0
  23. morphomatics/manifold/manopt_wrapper.py +185 -0
  24. morphomatics/manifold/metric.py +110 -0
  25. morphomatics/manifold/point_distribution_model.py +143 -0
  26. morphomatics/manifold/power_manifold.py +413 -0
  27. morphomatics/manifold/product_manifold.py +381 -0
  28. morphomatics/manifold/se_3.py +419 -0
  29. morphomatics/manifold/shape_space.py +57 -0
  30. morphomatics/manifold/so_3.py +494 -0
  31. morphomatics/manifold/spd.py +524 -0
  32. morphomatics/manifold/sphere.py +241 -0
  33. morphomatics/manifold/tangent_bundle.py +337 -0
  34. morphomatics/manifold/util.py +126 -0
  35. morphomatics/nn/__init__.py +15 -0
  36. morphomatics/nn/flow_layers.py +219 -0
  37. morphomatics/nn/tangent_layers.py +176 -0
  38. morphomatics/nn/train.py +202 -0
  39. morphomatics/nn/wFM_layers.py +152 -0
  40. morphomatics/opt/__init__.py +14 -0
  41. morphomatics/opt/riemannian_newton_raphson.py +65 -0
  42. morphomatics/opt/riemannian_steepest_descent.py +61 -0
  43. morphomatics/stats/__init__.py +18 -0
  44. morphomatics/stats/biinvariant_statistics.py +190 -0
  45. morphomatics/stats/exponential_barycenter.py +78 -0
  46. morphomatics/stats/geometric_median.py +89 -0
  47. morphomatics/stats/principal_geodesic_analysis.py +135 -0
  48. morphomatics/stats/riemannian_regression.py +317 -0
  49. morphomatics/stats/statistical_shape_model.py +99 -0
  50. morphomatics-4.0.dist-info/LICENSE +9 -0
  51. morphomatics-4.0.dist-info/METADATA +55 -0
  52. morphomatics-4.0.dist-info/RECORD +54 -0
  53. morphomatics-4.0.dist-info/WHEEL +5 -0
  54. morphomatics-4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,500 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ from __future__ import annotations
14
+ from functools import partial
15
+
16
+ from typing import Tuple
17
+
18
+ import numpy as np
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+
23
+ from morphomatics.geom.bezier_spline import BezierSpline, full_set, indep_set
24
+ from morphomatics.manifold import Manifold, Metric, PowerManifold
25
+ from morphomatics.opt import RiemannianSteepestDescent, RiemannianNewtonRaphson
26
+ from morphomatics.stats import ExponentialBarycenter
27
+
28
+
29
class Bezierfold(Manifold):
    """Manifold of Bézier splines (of fixed degrees).

    Splines are handled in two representations: as `BezierSpline` objects
    (see `rand`, `to_coords`, `from_coords`) and as arrays holding their K+1
    independent control points (all methods of the metric structure).
    """

    def __init__(self, M: Manifold, n_segments: int, degree: int, isscycle: bool = False,
                 n_steps: int = 10, n_samples: int = None, structure='FunctionalBased'):
        """Manifold of Bézier splines of constant segment degree

        :arg M: base manifold in which the curves lie
        :arg n_segments: number of spline segments
        :arg degree: degree of segment (same for each one)
        :arg isscycle: boolean indicating whether the splines are closed
        :arg n_steps: number of steps (i.e. segments) for approximation of geodesics in Bezierfold
        :arg n_samples: number of samples for quadrature of curve distance in L²(I, M);
            falls back to K+1 when None (or 0) is given
        :arg structure: type of geometric structure; the initializer `init<structure>Structure`
            is looked up by name, a falsy value skips the initialization
        """

        self._M = M
        self._degrees = np.full(n_segments, degree)
        self._nsteps = n_steps

        # K + 1 is the number of independent control points
        if isscycle:
            name = 'Manifold of closed Bézier splines of degree {d} through '.format(d=degree) + str(M)
            K = np.sum(self._degrees - 1) - 1
        else:
            name = 'Manifold of non-closed Bézier splines of degrees {d} through '.format(d=degree) + str(M)
            K = np.sum(self._degrees - 1) + 1

        self._nsamples = n_samples if n_samples else K + 1
        assert self._nsamples > K, 'n_samples must exceed the generalized degree K'

        dimension = (K + 1) * M.dim
        point_shape = (K + 1, *M.point_shape)
        self._K = K
        super().__init__(name, dimension, point_shape)

        self._iscycle = isscycle

        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration.

        The base manifold is appended to the children; the constructor
        arguments are appended to the auxiliary data of the parent class.
        """
        children, aux = super().tree_flatten()
        aux += (self.nsegments, self.degrees[0], self.iscycle, self.nsteps, self.nsamples)
        return children + (self.M,), aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        # split off the entries that tree_flatten appended
        *children, M = children
        *aux_data, n_seg, d, c, n_st, n_sam = aux_data
        # structure=None: skip the default structure, it is restored by the parent recipe
        obj = cls(M, n_seg, d, c, n_st, n_sam, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initFunctionalBasedStructure(self):
        """
        Instantiate functional-based structure with discrete methods.
        """
        structure = Bezierfold.FunctionalBasedStructure(self)
        # the same object serves as metric and as connection
        self._metric = structure
        self._connec = structure

    @property
    def M(self) -> Manifold:
        """Return the underlying manifold
        """
        return self._M

    @property
    def degrees(self) -> np.array:
        """Return vector of segment degrees
        """
        return self._degrees

    @property
    def nsegments(self) -> int:
        """Returns the number of spline segments."""
        return len(self._degrees)

    @property
    def K(self) -> int:
        """Return the generalized degree of a Bezier spline, i.e., the number of independent control points - 1
        """
        return self._K

    @property
    def iscycle(self) -> bool:
        """Return whether the Bezierfold consists of non-closed or closed splines
        """
        return self._iscycle

    @property
    def nsamples(self):
        """Returns the number of samples for quadrature of curve distance in L²(I, M)."""
        return self._nsamples

    @property
    def nsteps(self):
        """Returns the number of steps (i.e. segments) for approximation of geodesics in Bezierfold"""
        return self._nsteps

    def correct_type(self, B: BezierSpline) -> bool:
        """Check whether B has the right segment degrees"""
        return bool(jnp.all(jnp.atleast_1d(B.degrees) == self.degrees))

    def rand(self, key: jax.Array) -> BezierSpline:
        """Return random Bézier spline

        One random point of the base manifold is drawn per independent control
        point. The spline is assembled via `from_coords` so that the
        closed/non-closed flag is forwarded to the spline as well.
        """
        subkeys = jax.random.split(key, self.K + 1)
        return self.from_coords(jax.vmap(self.M.rand)(subkeys))

    def randvec(self, B: BezierSpline, key: jax.Array) -> jnp.array:
        """Return random vector for every independent control point

        :param B: Bézier spline; an array of control points is passed on to `indep_set` unchanged
        :param key: PRNG key
        """
        # extract the control points first, as `to_coords` does
        ctrl = B.control_points if isinstance(B, BezierSpline) else B
        pts = indep_set(ctrl, self.iscycle)
        subkeys = jax.random.split(key, len(pts))
        return jax.vmap(self.M.randvec)(pts, subkeys)

    def zerovec(self) -> jnp.array:
        """Return zero vector for every independent control point"""
        return jnp.tile(self.M.zerovec(), (self.K + 1,) + (1,) * len(self.M.point_shape))

    def to_coords(self, B: BezierSpline) -> jnp.array:
        """
        :param B: Bézier spline
        :return: Array of independent control points.
        """
        return indep_set(B.control_points, self.iscycle)

    def from_coords(self, pts: jnp.array) -> BezierSpline:
        """
        :param pts: independent control points
        :return: Bézier spline
        """
        pts = full_set(self.M, pts, self.degrees, self.iscycle)
        return BezierSpline(self.M, pts, self.iscycle)

    def proj(self, X, H):
        """Return H unchanged (no projection is performed here)."""
        return H

    ############################## Functional-based structure ##############################
    class FunctionalBasedStructure(Metric):
        """
        Functional-based metric structure
        """

        def __init__(self, Bf: Bezierfold):
            """
            Constructor.

            :param Bf: Bezierfold this structure belongs to
            """
            self._Bf = Bf

        def __str__(self):
            # must be a plain method: as a property, str(obj) fails with
            # "TypeError: 'property' object is not callable"
            return "Bézierfold-functional-based structure"

        def inner(self, p: jnp.array, X: jnp.array, Y: jnp.array):
            """Functional-based metric, i.e. L²(I, TBM).

            :arg p: Bézier spline in M
            :arg X: tangent vector at p
            :arg Y: tangent vector at p
            :return: inner product of X and Y at p
            """

            M, deg, cyclic = self._Bf.M, self._Bf.degrees, self._Bf.iscycle

            def full(q, V):
                f = lambda pts: jnp.array(full_set(M, pts, deg, cyclic))
                # fwd-diff. of full_set
                q_full, V_full = jax.jvp(f, (q,), (V,))
                # proj. to tangent space
                vproj = jax.vmap(jax.vmap(M.proj))
                return q_full, vproj(q_full, V_full)

            # map p, X, Y to all control points
            p_full, X_full = full(p, X)
            _, Y_full = full(p, Y)

            # sample spline and generalized Jacobi fields for X, Y
            t = jnp.linspace(0., self._Bf.nsegments, self._Bf.nsamples)
            spln = BezierSpline(M, p_full, cyclic)
            vDpB = jax.vmap(spln.DpB, (0, None))
            B, Jx = vDpB(t, X_full)
            _, Jy = vDpB(t, Y_full)

            # eval inner products (summed over all samples)
            return jax.vmap(self._Bf.M.metric.inner)(B, Jx, Jy).sum()

        @property
        def typicaldist(self) -> float:
            # approximations via control points
            return self._Bf.K * self._Bf.M.metric.typicaldist

        def dist(self, a: jnp.array, b: jnp.array) -> float:
            """Approximate the distance between two Bézier splines

            :param a: independent control points of a Bézier spline
            :param b: independent control points of a Bézier spline
            :return: length of n-geodesic between A and B (approximation of the distance)
            """
            return jnp.sqrt(self.squared_dist(a, b))

        def squared_dist_extrinsic(self, p, q):
            """Sum of squared distances (in M) between samples of the splines p and q.

            :param p: independent control points of a Bézier spline
            :param q: independent control points of a Bézier spline
            """
            t = jnp.linspace(0., self._Bf.nsegments, self._Bf.nsamples)
            d2 = jax.vmap(self._Bf.M.metric.squared_dist)
            return d2(sample(self._Bf, p, t), sample(self._Bf, q, t)).sum()

        def squared_dist(self, p, q):
            """Approximate the squared distance between the splines p and q.

            Sums the extrinsic squared distances of consecutive nodes of the
            discrete n-geodesic from p to q and scales the sum by n.
            """
            n = self._Bf.nsteps
            gamma = self.discgeodesic(self._Bf, p, q, n=n)
            return jax.vmap(self.squared_dist_extrinsic)(gamma[:-1], gamma[1:]).sum() * n

        @staticmethod
        @partial(jax.jit, static_argnames=['Bf'])
        def discexp(Bf, a: jnp.array, b: jnp.array):
            """
            Compute c such that [a,b,c] is a discrete 2-geodesic.
            :param Bf: Bezierfold a and b live in
            :param a: Bézier spline in manifold M (i.e. independent control points thereof)
            :param b: Bézier spline in manifold M (i.e. independent control points thereof)
            :return: c
            """

            t = jnp.linspace(0., Bf.nsegments, Bf.nsamples)

            # initial guess for c
            c = jax.vmap(Bf.M.connec.geopoint, (0, 0, None))(a, b, 2.)

            # gradient of sum-of-squared-distances between samples along alpha and beta w.r.t. ctrl. pts. of alpha
            def G(alpha, beta):
                egrad = jax.grad(lambda x: jax.vmap(Bf.M.metric.squared_dist)(sample(Bf, x, t), sample(Bf, beta, t)).sum())
                return jax.vmap(Bf.M.metric.egrad2rgrad)(alpha, egrad(alpha))

            # gradient for b w.r.t. a
            G_a = G(b, a)

            # discrete Euler-Lagrange cnd. of path energy for [a,b,c]
            def F(x):
                return G(b, x) + G_a

            # solve F(x) = 0
            N = PowerManifold(Bf.M, Bf.K + 1)
            return RiemannianNewtonRaphson.solve(N, F, c, stepsize=.1, maxiter=min(Bf.dim, 1000))

        def exp(self, p: jnp.array, X: jnp.array) -> jnp.array:
            """Discrete exponential map: shoot an n-step discrete geodesic from p in direction X.

            :param p: Bézier spline in manifold M (i.e. independent control points thereof)
            :param X: tangent vector (i.e. tangent vectors at the independent control points)
            :return: end node of the discrete geodesic
            """
            n = self._Bf.nsteps

            def body(carry, _):
                a, b = carry
                # compute c s.t. [a,b,c] is discrete 2-geodesic
                c = self.discexp(self._Bf, a, b)
                return (b, c), None

            # first node after p (counterpart of gamma[1] used in `log`)
            q = jax.vmap(self._Bf.M.connec.exp)(p, X / n)
            # the carry starts at nodes (0, 1); every step advances by one node,
            # hence n-1 steps are required to arrive at node n
            (_, q), _ = jax.lax.scan(body, (p, q), None, length=n - 1)

            return q

        def log(self, p: jnp.array, q: jnp.array) -> jnp.array:
            """Discrete logarithmic map: n times the log (in M) from p to the first inner
            node of the discrete n-geodesic from p to q.
            """
            n = self._Bf.nsteps
            gamma = self.discgeodesic(self._Bf, p, q, n=n)
            return jax.vmap(self._Bf.M.connec.log)(p, gamma[1]) * n

        @staticmethod
        @partial(jax.jit, static_argnames=['Bf', 'n'])
        def discgeodesic(Bf: Bezierfold, p: jnp.array, q: jnp.array, n: int = 5, maxiter: int = 100, minchange: float = 1e-6) -> jnp.array:
            """Discrete shortest path through space of Bézier splines.

            :param Bf: Bezierfold p and q live in
            :param p: Bézier spline in manifold M (i.e. independent control points thereof)
            :param q: Bézier spline in manifold M (i.e. independent control points thereof)
            :param n: create discrete n-geodesic
            :param maxiter: max. number of iterations
            :param minchange: min. change in coordinates to declare convergence
            :return: control points of the Bézier splines along the shortest path
            """

            # Initialize inner splines of path

            # logs between corresponding control points of A and B (save repeated computations)
            X = jax.vmap(Bf.M.connec.log)(p, q)
            # exps
            t_exp = lambda t: jax.vmap(Bf.M.connec.exp)(p, t * X)
            H = jax.vmap(t_exp)(jnp.linspace(0., 1., n + 1)[1:-1])
            # add start-/endpt.
            H = jnp.concatenate((jnp.expand_dims(p, axis=0), H, jnp.expand_dims(q, axis=0)))

            # Discrete path shortening flow
            def body(args):
                x, _, i = args
                x, d = curve_shortening_step(Bf, x)
                return x, d, i + 1

            # check convergence
            def cond(args):
                _, d, i = args
                c = jnp.array([d > minchange, i < maxiter])
                return jnp.all(c)

            H, *_ = jax.lax.while_loop(cond, body, (H, jnp.array(1.), jnp.array(0)))

            return H

        @staticmethod
        def mean(Bf, B, maxiter: int = 500, minchange: float = 1e-5):
            """Discrete mean of a set of Bézier splines

            :param Bf: Bezierfold
            :param B: array of splines (i.e. independent control points thereof)
            :param maxiter: max. number of iterations
            :param minchange: min. change in coordinates to declare convergence
            :return: (independent control points of) mean curve, and the discrete
                paths from the mean to the input curves
            """
            # times at which to sample splines
            t = jnp.linspace(0, Bf.nsegments, Bf.nsamples)

            # setup 'regression' problem for mean (where there are len(B) targets for each time pt.)

            # search space: k-fold product of M
            N = PowerManifold(Bf.M, Bf.K + 1)

            # sum-of-squared-distances
            def ssd(pts, Y, param):
                x = sample(Bf, pts, param)
                d = jax.vmap(jax.vmap(Bf.M.metric.squared_dist), (None, 0))(x, Y)
                return jnp.sum(d) / np.prod(Y.shape[:2])

            # compute mean spline

            # initialize i-th control point of the mean as the mean of the i-th control points of the data
            mean = lambda b: ExponentialBarycenter.compute(Bf.M, b)
            init = jax.vmap(mean, 1)(B)

            # init legs, i.e. n-geodesics between mean and input curves B
            discgeodesic = Bezierfold.FunctionalBasedStructure.discgeodesic
            F_init = jax.vmap(discgeodesic, (None, None, 0, None))(Bf, init, B, Bf.nsteps)

            def body(args):
                x, F, change, i = args

                # update x via regression
                Y = jax.vmap(sample, (None, 0, None))(Bf, F[:, 1], t)
                opt = RiemannianSteepestDescent.fixedpoint(N, lambda a: ssd(a, Y, t), x)
                change = jnp.abs(opt - x).max()

                # update legs of 'polygonal spider'
                F = F.at[:, 0].set(opt)
                F, d = jax.vmap(curve_shortening_step, (None, 0))(Bf, F)
                change = jnp.array([change, jnp.abs(d).max()]).max()

                # progress output: iteration index and inf-norm of the latest update
                jax.debug.print("{}: {}", i, change)
                return opt, F, change, i + 1

            def cond(args):
                _, _, change, i = args
                c = jnp.array([change > minchange, i < maxiter])
                return jnp.all(c)

            mu, F_mu, *_ = jax.lax.while_loop(cond, body, (init, F_init, 1., 0))

            return mu, F_mu

        def gram(self, B_mean: jnp.array, F: jnp.array):
            """Approximates the Gram matrix for a curve data set.

            :param B_mean: mean of curves in B (as returned by #mean)
            :param F: discrete spider, i.e, discrete paths from mean to data (as returned by #mean)
            :return G: Gram matrix
            """
            # NOTE(review): n is the number of paths in F, not the number of steps
            # per path -- confirm that this is the intended scaling factor.
            n = len(F)
            G = jnp.zeros((n, n))
            for i, si in enumerate(F):
                for j, sj in enumerate(F[i:], start=i):
                    G = G.at[i, j].set(n / 2 * (
                            self.squared_dist_extrinsic(B_mean, si[1])
                            + self.squared_dist_extrinsic(B_mean, sj[1])
                            - self.squared_dist_extrinsic(si[1], sj[1]))
                    )
                    # mirror the entry to keep G symmetric
                    G = G.at[j, i].set(G[i, j])

            return G

        def egrad2rgrad(self, p: jnp.array, X: jnp.array) -> jnp.array:
            """Apply egrad2rgrad of the base manifold at every independent control point.

            :param p: Bézier spline in manifold M (i.e. independent control points thereof)
            :param X: tangent vector (i.e. tangent vectors at the independent control points)
            """
            return jax.vmap(self._Bf.M.metric.egrad2rgrad)(p, X)

        def retr(self, R, X):
            """Retraction; delegates to the discrete exponential map."""
            return self.exp(R, X)

        ### not implemented ###

        def ehess2rhess(self, p, G, H, X):
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point p along a tangent vector X to the Riemannian Hessian
            along X on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def curvature_tensor(self, p, X, Y, Z):
            raise NotImplementedError('This function has not been implemented yet.')

        def transp(self, R, Q, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def jacobiField(self, R, Q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, R, Q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def flat(self, p, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def sharp(self, p, dX):
            raise NotImplementedError('This function has not been implemented yet.')
456
+
457
+
458
def sample(Bf: "Bezierfold", pts: jnp.array, t: jnp.array) -> jnp.array:
    """Evaluate the spline given by its independent control points at several times.

    :param Bf: Bezierfold used to assemble the spline (via `from_coords`)
    :param pts: independent control points of a Bézier spline
    :param t: array of parameter values
    :return: spline evaluations, stacked along the leading axis (one per entry of t)
    """

    def eval_at(ctrl, s):
        # assemble the spline from its coordinates and evaluate it at a single time
        return Bf.from_coords(ctrl).eval(s)

    # vectorize over the times only; the control points are shared
    return jax.vmap(eval_at, in_axes=(None, 0))(pts, t)
461
+
462
+
463
def curve_shortening_step(Bf: "Bezierfold", x: jnp.array) -> Tuple[jnp.array, float]:
    """Single step of discrete curve shortening flow: Replace inner node with
    average of its neighbours (s.t. it's the midpoint of the connecting 2-geodesic).

    Inner nodes are processed in order; each one is re-fitted to samples of its
    (already updated) predecessor and its (not yet updated) successor.

    :param Bf: Bezierfold
    :param x: Discrete path in Bf (i.e. independent control points of nodes)
    :return: updated nodes, inf-norm of update
    """
    # local import to avoid cyclic dependencies
    from morphomatics.stats.riemannian_regression import RiemannianRegression

    degree, n_seg = Bf.degrees[0], Bf.nsegments

    # sampling times; duplicated because predecessor and successor are both targets
    times = jnp.linspace(0., n_seg, Bf.nsamples)
    times_twice = jnp.concatenate([times, times])

    def relax(carry, pair):
        prev_node, max_change = carry
        node, next_node = pair
        # regression targets: samples of both neighbours
        targets = jnp.concatenate([sample(Bf, prev_node, times), sample(Bf, next_node, times)])
        # single regression iteration, started at the current node
        fitted = RiemannianRegression.fit(Bf.M, targets, times_twice, node, degree, n_seg,
                                          maxiter=1, iscycle=Bf.iscycle)
        # running inf-norm of the updates
        max_change = jnp.maximum(max_change, jnp.abs(fitted - node).max())
        return (fitted, max_change), fitted

    # pair every inner node with its successor
    pairs = jnp.stack([x[1:-1], x[2:]], axis=1)

    # sweep over the inner nodes, carrying along the latest predecessor
    (_, change), inner_nodes = jax.lax.scan(relax, (x[0], 0.), pairs)

    # end points stay fixed
    return jnp.concatenate((x[None, 0], inner_nodes, x[None, -1])), change
@@ -0,0 +1,105 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ # postponed evaluation of annotations (PEP 563) to circumvent cyclic dependencies
14
+ from __future__ import annotations
15
+
16
+ import abc
17
+
18
+ import jax
19
+
20
class Connection(metaclass=abc.ABCMeta):
    """
    Abstract template for a connection on the tangent bundle of a manifold.

    Subclasses provide exp, retr, log, transp, curvature_tensor and jacobiField;
    geopoint, dxgeo and dygeo are derived from those.
    """

    def __init__(self, M: "Manifold"):
        """ Construct connection.
        :param M: underlying manifold
        """
        self._M = M

    @abc.abstractmethod
    def __str__(self):
        """Returns a string representation of the particular connection."""

    @abc.abstractmethod
    def exp(self, p, X):
        """Exponential map of the connection at p applied to the tangent vector X.
        """

    @abc.abstractmethod
    def retr(self, p, X):
        """Computes a retraction mapping a vector X in the tangent space at
        p to the manifold.
        """

    @abc.abstractmethod
    def log(self, p, q):
        """Logarithmic map of the connection at p applied to q.
        """

    def geopoint(self, p, q, t):
        """Evaluates the geodesic between p and q at time t.
        """
        # scale the direction from p towards q, then shoot from p
        direction = self.log(p, q)
        return self.exp(p, t * direction)

    @abc.abstractmethod
    def transp(self, p, q, X):
        """Computes a vector transport which transports a vector X in the
        tangent space at p to the tangent space at q.
        """

    @abc.abstractmethod
    def curvature_tensor(self, p, X, Y, Z):
        """Evaluates the curvature tensor R of the connection at p on the vectors X, Y, Z. With nabla_X Y denoting the
        covariant derivative of Y in direction X and [] being the Lie bracket, the convention
            R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
        is used.
        """

    @abc.abstractmethod
    def jacobiField(self, p, q, t, X):
        """
        Evaluates a Jacobi field J (with boundary conditions J(0) = X, J(1) = 0) along the geodesic gam from p to q.
        :param p: element of the Riemannian manifold
        :param q: element of the Riemannian manifold
        :param t: scalar in [0,1]
        :param X: tangent vector at p
        :return: [b, J] with J and b being the Jacobi field at t and the corresponding basepoint
        """

    def dxgeo(self, p, q, t, X):
        """Evaluates the differential of the geodesic gam from p to q w.r.t. the starting point p at X,
        i.e, d_p gamma(t; ., q) applied to X; the result is an element of the tangent space at gam(t).
        """
        # entry 0 is the base point, entry 1 the Jacobi field
        result = self.jacobiField(p, q, t, X)
        return result[1]

    def dygeo(self, p, q, t, X):
        """Evaluates the differential of the geodesic gam from p to q w.r.t. the end point q at X,
        i.e, d_q gamma(t; p, .) applied to X; the result is an element of the tangent space at gam(t).
        """
        # same as dxgeo for the reversed geodesic (from q to p) at time 1 - t
        result = self.jacobiField(q, p, 1 - t, X)
        return result[1]
95
+
96
+
97
def _eval_jacobi_embed(C: "Connection", p, q, t, X):
    """ Implementation of eval_jacobi for isometrically embedded manifolds using (forward-mode) automatic
    differentiation of geopoint(..).

    The geodesic from a variable start point to the fixed end point q is
    differentiated at p in direction X.

    ATTENTION: the result must be projected to the tangent space!

    :return: pair (geopoint(p, q, t), directional derivative) as returned by jax.jvp
    """

    def geo_from(start):
        # geodesic point at time t as a function of the start point only
        return C.geopoint(start, q, t)

    return jax.jvp(geo_from, (p,), (X,))