morphomatics 4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. morphomatics/__init__.py +13 -0
  2. morphomatics/geom/__init__.py +16 -0
  3. morphomatics/geom/bezier_spline.py +361 -0
  4. morphomatics/geom/misc.py +104 -0
  5. morphomatics/geom/surface.py +208 -0
  6. morphomatics/graph/__init__.py +13 -0
  7. morphomatics/graph/operators.py +124 -0
  8. morphomatics/manifold/__init__.py +46 -0
  9. morphomatics/manifold/bezierfold.py +500 -0
  10. morphomatics/manifold/connection.py +105 -0
  11. morphomatics/manifold/cubic_bezierfold.py +305 -0
  12. morphomatics/manifold/differential_coords.py +197 -0
  13. morphomatics/manifold/discrete_ops.py +56 -0
  14. morphomatics/manifold/euclidean.py +213 -0
  15. morphomatics/manifold/fundamental_coords.py +440 -0
  16. morphomatics/manifold/gl_p_coords.py +149 -0
  17. morphomatics/manifold/gl_p_n.py +201 -0
  18. morphomatics/manifold/grassmann.py +174 -0
  19. morphomatics/manifold/hyperbolic_space.py +271 -0
  20. morphomatics/manifold/kendall.py +269 -0
  21. morphomatics/manifold/lie_group.py +102 -0
  22. morphomatics/manifold/manifold.py +162 -0
  23. morphomatics/manifold/manopt_wrapper.py +185 -0
  24. morphomatics/manifold/metric.py +110 -0
  25. morphomatics/manifold/point_distribution_model.py +143 -0
  26. morphomatics/manifold/power_manifold.py +413 -0
  27. morphomatics/manifold/product_manifold.py +381 -0
  28. morphomatics/manifold/se_3.py +419 -0
  29. morphomatics/manifold/shape_space.py +57 -0
  30. morphomatics/manifold/so_3.py +494 -0
  31. morphomatics/manifold/spd.py +524 -0
  32. morphomatics/manifold/sphere.py +241 -0
  33. morphomatics/manifold/tangent_bundle.py +337 -0
  34. morphomatics/manifold/util.py +126 -0
  35. morphomatics/nn/__init__.py +15 -0
  36. morphomatics/nn/flow_layers.py +219 -0
  37. morphomatics/nn/tangent_layers.py +176 -0
  38. morphomatics/nn/train.py +202 -0
  39. morphomatics/nn/wFM_layers.py +152 -0
  40. morphomatics/opt/__init__.py +14 -0
  41. morphomatics/opt/riemannian_newton_raphson.py +65 -0
  42. morphomatics/opt/riemannian_steepest_descent.py +61 -0
  43. morphomatics/stats/__init__.py +18 -0
  44. morphomatics/stats/biinvariant_statistics.py +190 -0
  45. morphomatics/stats/exponential_barycenter.py +78 -0
  46. morphomatics/stats/geometric_median.py +89 -0
  47. morphomatics/stats/principal_geodesic_analysis.py +135 -0
  48. morphomatics/stats/riemannian_regression.py +317 -0
  49. morphomatics/stats/statistical_shape_model.py +99 -0
  50. morphomatics-4.0.dist-info/LICENSE +9 -0
  51. morphomatics-4.0.dist-info/METADATA +55 -0
  52. morphomatics-4.0.dist-info/RECORD +54 -0
  53. morphomatics-4.0.dist-info/WHEEL +5 -0
  54. morphomatics-4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,305 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import numpy as np
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+
18
+ from morphomatics.geom.bezier_spline import BezierSpline, full_set, indep_set
19
+ from morphomatics.manifold import Metric
20
+ from morphomatics.manifold import Manifold, TangentBundle, PowerManifold
21
+
22
+
23
class CubicBezierfold(Manifold):
    """Manifold of _cubic_ Bézier splines.

    Points of this manifold are stored as velocities (elements of the tangent bundle TM of the base manifold)
    at the segment end points; see :meth:`to_coords` and :meth:`from_coords` for the conversion from/to
    :class:`BezierSpline` objects.
    """

    def __init__(self, M: Manifold, n_segments: int, isscycle: bool = False, structure='GeneralizedSasaki'):
        """Manifold of cubic Bézier splines.

        :arg M: base manifold in which the curves lie
        :arg n_segments: number of segments
        :arg isscycle: boolean indicating whether the splines are closed (note the keyword spelling ``isscycle``)
        :arg structure: type of geometric structure; if truthy, ``init<structure>Structure`` is invoked
        """

        self._M = M

        # all segments are cubic
        self._degrees = np.full(n_segments, 3)

        # K + 1 is the number of independent control points; closed splines have two fewer than open ones
        if isscycle:
            name = 'Manifold of closed, cubic Bézier splines through ' + str(M)
            K = 2*n_segments - 1
        else:
            name = 'Manifold of cubic Bézier splines through ' + str(M)
            K = 2*n_segments + 1

        # a point is an array of (K + 1)//2 elements of TM, each stacked as (point, velocity)
        dimension = (K + 1) * M.dim
        point_shape = ((K + 1)//2, 2) + M.point_shape
        self._K = K
        super().__init__(name, dimension, point_shape)

        # must be set before the structure is initialized (the structure reads it)
        self._iscycle = isscycle

        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration.

        The base manifold is appended to the children; the number of segments and the closedness flag are
        appended to the auxiliary data.
        """
        children, aux = super().tree_flatten()
        return children+(self.M,), aux+(self.nsegments, self.iscycle)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *children, M = children
        *aux_data, nsegments, iscycle = aux_data
        # structure is restored from the flattened data instead of being re-initialized
        obj = cls(M, nsegments, iscycle, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initGeneralizedSasakiStructure(self):
        """
        Instantiate generalized Sasaki structure with discrete methods.
        """
        structure = CubicBezierfold.GeneralizedSasakiStructure(self)
        # the same object serves as metric and as connection
        self._metric = structure
        self._connec = structure

    @property
    def M(self) -> Manifold:
        """Return the underlying manifold
        """
        return self._M

    @property
    def nsegments(self) -> int:
        """Returns the number of segments."""
        return len(self._degrees)

    # TODO: likely not needed here anymore
    @property
    def K(self) -> int:
        """Return the generalized degree of a Bezier spline, i.e., the number of independent control points - 1
        """
        return self._K

    @property
    def iscycle(self) -> bool:
        """Return whether the Bezierfold consists of non-closed or closed splines
        """
        return self._iscycle

    def correct_type(self, B: BezierSpline) -> bool:
        """Check whether B has the right segment degrees, i.e., exactly ``nsegments`` cubic segments.

        :param B: Bézier spline to check
        :return: True iff B has ``nsegments`` segments that are all cubic
        """
        degrees = jnp.atleast_1d(jnp.asarray(B.degrees))
        # compare shapes first: an elementwise comparison alone would broadcast a single-segment spline
        # against all segments (and raise for other length mismatches)
        if degrees.shape != (self.nsegments,):
            return False
        return bool(jnp.all(degrees == 3))

    def rand(self, key: jax.Array) -> BezierSpline:  # TODO: velocity repr.
        """Return random Bézier spline.

        :param key: JAX PRNG key
        :return: spline whose K + 1 independent control points are drawn with ``M.rand``
        """
        subkeys = jax.random.split(key, self.K + 1)
        return BezierSpline(self.M, full_set(self.M, jax.vmap(self.M.rand)(subkeys),
                                             self._degrees, self.iscycle),
                            iscycle=self.iscycle)

    def randvec(self, B: BezierSpline, key: jax.Array) -> jnp.array:  # TODO: velocity repr.
        """Return random vector for every independent control point

        :param B: Bézier spline
        :param key: JAX PRNG key
        :return: stacked random tangent vectors, one per independent control point of B
        """
        pts = indep_set(B, self.iscycle)
        subkeys = jax.random.split(key, len(pts))
        return jax.vmap(self.M.randvec)(pts, subkeys)

    def zerovec(self) -> jnp.array:  # TODO: velocity repr.
        """Return zero vector for every independent control point (K + 1 stacked zero vectors of M)"""
        return jnp.array([self.M.zerovec() for _ in range(self.K + 1)])

    def to_coords(self, B: BezierSpline) -> jnp.array:
        """Get initial and final velocities (elements of the tangent bundle) of the segments of a C^1 Bézier spline with cubic
        segments. velocities at connections are identified and returned only once.

        :param B: Bézier spline with cubic segments through a Riemannian manifold M
        :return: array of elements (ordered along the first dimension) of the tangent bundle TM

        ATTENTION: made for splines with cubic segments only!
        """
        assert jnp.all(B.degrees == 3)

        def f(p):
            # end point of a segment together with its (outgoing) end velocity
            return jnp.array([p[-1], -B._M.connec.log(p[-1], p[-2])])

        b = jax.vmap(f)(B.control_points)

        if not self.iscycle:
            # open splines additionally need the start point and start velocity of the first segment
            a = jnp.array([
                B.control_points[0, 0], B._M.connec.log(B.control_points[0, 0], B.control_points[0, 1])
            ])

            return jnp.concatenate((a[None, ...], b))
        else:
            return b

    def from_coords(self, Q: jnp.array) -> BezierSpline:
        """Compute the cubic-only Bézier spline corresponding to the given velocities

        :param Q: array of velocities (elements of TM stacked along the first dimension)
        :return: Bézier spline with cubic segments that corresponds to Q

        ATTENTION: made for velocities of splines with cubic segments only!
        """

        def f(pu, qw):
            # inverse of to_coords: inner control points from start/end velocities of one segment
            p, u = pu[0], pu[1]
            q, w = qw[0], qw[1]
            return jnp.array([p, self._M.connec.exp(p, u), self._M.connec.exp(q, -w), q])

        if self.iscycle:
            # last velocity vector is also first velocity vector
            Q = jnp.concatenate([Q[None, -1], Q])

        P = jax.vmap(f)(Q[:-1], Q[1:])

        return BezierSpline(self._M, P, iscycle=self.iscycle)

    def proj(self, pu, vw):
        raise NotImplementedError('This function has not been implemented yet.')

    ############################## Sasaki structure ##############################
    class GeneralizedSasakiStructure(Metric):
        """
        This class implements the generalization of the Sasaki metric to Bézier splines with cubic segments
        """

        def __init__(self, Bf: Manifold, Ns: int = 3):
            """
            Constructor.

            :param Bf: Bézierfold object
            :param Ns: scalar that determines the number of discretization steps used in the approximation of the
            exponential and logarithm maps in the tangent bundle
            """
            # number of TM elements per spline, matching Bf.to_coords / Bf.point_shape:
            # closed splines have one velocity per segment, open ones an additional one at the start
            n_velocities = Bf.nsegments if Bf.iscycle else Bf.nsegments + 1
            self._tangent_bundle_power = PowerManifold(TangentBundle(Bf.M), n_velocities)
            self._Bf = Bf
            self.Ns = Ns

        def __str__(self):
            return "Generalized Sasaki structure"

        @property
        def typicaldist(self) -> float:
            return self._tangent_bundle_power.metric.typicaldist

        def inner(self, p_B: jnp.array, v: jnp.array, w: jnp.array) -> float:
            """Generalized Sasaki metric

            :param p_B: velocities of a Bézier spline
            :param v: tangent vector in the tangent space of the velocities of B
            :param w: tangent vector in the tangent space of the velocities of B
            :return: inner product between v and w
            """

            return self._tangent_bundle_power.metric.inner(p_B, v, w)

        def flat(self, p_B, v):
            """Lower vector v at p_B with the metric"""
            return self._tangent_bundle_power.metric.flat(p_B, v)

        def sharp(self, p_B, dv):
            """Raise covector dv at p_B with the metric"""
            return self._tangent_bundle_power.metric.sharp(p_B, dv)

        def egrad2rgrad(self, p, X):
            """Convert the Euclidean gradient X at p into the Riemannian gradient (via projection onto the
            tangent space; raises NotImplementedError as long as ``proj`` is not implemented)."""
            return self._Bf.proj(p, X)

        def ehess2rhess(self, pu, G, H, vw):
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point pu along a tangent vector vw to the Riemannian Hessian
            along vw on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def exp(self, p_B: jnp.array, v: jnp.array) -> jnp.array:
            """Exponential map

            :param p_B: velocities of a Bézier spline
            :param v: tangent vector in the tangent space of the velocities of B
            :return: velocities of the Bézier spline at time 1 on the geodesic with initial velocity v
            """
            return self._tangent_bundle_power.connec.exp(p_B, v)

        def log(self, p_A: jnp.array, p_B: jnp.array) -> jnp.array:
            """Riemannian logarithm map

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :return: tangent vector in the tangent space of the velocities of A pointing to the velocities of B
            """
            return self._tangent_bundle_power.connec.log(p_A, p_B)

        def curvature_tensor(self, p_B: jnp.array, v: jnp.array, w: jnp.array, x: jnp.array) -> jnp.array:
            """Riemannian curvature tensor at a point of the Bézierfold

            :param p_B: velocities of a Bézier spline
            :param v: tangent vector in the tangent space of the velocities of B
            :param w: tangent vector in the tangent space of the velocities of B
            :param x: tangent vector in the tangent space of the velocities of B
            :return: tangent vector in the tangent space of the velocities of B
            """
            return self._tangent_bundle_power.connec.curvature_tensor(p_B, v, w, x)

        def geopoint(self, p_A: jnp.array, p_B: jnp.array, t: float) -> jnp.array:
            """Evaluate the geodesic through the Bézierfold between A and B at time t

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :param t: scalar between 0 and 1
            :return: Bézier spline at time t on the geodesic from A to B
            """
            return self.exp(p_A, t * self.log(p_A, p_B))

        def retr(self, p_B: jnp.array, v: jnp.array) -> jnp.array:
            """Retraction; uses the exponential map."""
            return self.exp(p_B, v)

        def transp(self, p_A: jnp.array, p_B: jnp.array, v: jnp.array) -> jnp.array:
            """Parallel transport along a geodesic

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :param v: tangent vector in the tangent space of the velocities of A
            :return: tangent vector in the tangent space of the velocities of B: parallel transport of v along the
            geodesic from A to B
            """
            return self._tangent_bundle_power.connec.transp(p_A, p_B, v)

        def pairmean(self, p_A: jnp.array, p_B: jnp.array) -> jnp.array:
            """Fréchet mean of 2 splines

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :return: velocities of the mean of A and B (midpoint of the connecting geodesic)
            """
            return self.geopoint(p_A, p_B, .5)

        def dist(self, p_A: jnp.array, p_B: jnp.array) -> float:
            """Distance function that is induced on the Bézierfold by the generalized Sasaki metric

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :return: distance between A and B
            """
            return self._tangent_bundle_power.metric.dist(p_A, p_B)

        def jacobiField(self, p_A: jnp.array, p_B: jnp.array, t: float, X: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, p_A: jnp.array, p_B: jnp.array, t: float, X: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')
@@ -0,0 +1,197 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import numpy as np
14
+ import jax
15
+ import jax.numpy as jnp
16
+
17
+ from scipy import sparse
18
+
19
+ try:
20
+ from sksparse.cholmod import cholesky as direct_solve
21
+ except:
22
+ from scipy.sparse.linalg import factorized as direct_solve
23
+
24
+ from ..geom import Surface
25
+ from . import SO3, SPD
26
+ from . import ProductManifold, PowerManifold
27
+ from . import ShapeSpace, Metric
28
+ from .util import align
29
+
30
+
31
class DifferentialCoords(ShapeSpace):
    """
    Shape space based on differential coordinates.

    Every face of the reference surface carries a rotation (SO(3)) and a stretch (SPD(3)) obtained from the
    polar decomposition of the deformation gradient.

    See:
    Christoph von Tycowicz, Felix Ambellan, Anirban Mukhopadhyay, and Stefan Zachow.
    An Efficient Riemannian Statistical Shape Model using Differential Coordinates.
    Medical Image Analysis, Volume 43, January 2018.
    """

    def __init__(self, reference: Surface, commensuration_weights=(1.0, 1.0)):
        """
        :arg reference: Reference surface (shapes will be encoded as deformations thereof)
        :arg commensuration_weights: weights (rotation, stretch) for commensuration between rotational and stretch parts
        """
        assert reference is not None
        self.ref = reference

        # rotation and stretch manifolds (one factor per face)
        self.SPD = PowerManifold(SPD(3), len(self.ref.f))
        self.SO = PowerManifold(SO3(), len(self.ref.f))
        self._M = ProductManifold([self.SO, self.SPD], jnp.asarray(commensuration_weights))

        self.update_ref_geom(self.ref.v)

        name = 'Differential Coordinates Shape Space'
        super().__init__(name, self.M.dim, self.M.point_shape, self.M.connec, self.M.metric, None)

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration (product manifold as child, reference
        vertices/faces as auxiliary data)."""
        return (self.M,), (self.ref.v.tolist(), self.ref.f.tolist())

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        M = children[0]
        obj = cls(Surface(*aux_data))
        obj._M = M
        obj.SO, obj.SPD = M.manifolds
        return obj

    @property
    def M(self):
        """Underlying product manifold SO(3)^k x SPD(3)^k."""
        return self._M

    def update_ref_geom(self, v):
        """Set the vertex coordinates of the reference surface and recompute the dependent quantities.

        :arg v: vertex coordinates of the reference surface
        """
        self.ref.v=v

        # center of gravity
        self.CoG = self.ref.v.mean(axis=0)

        # setup Poisson system
        S = self.ref.div @ self.ref.grad
        # add soft-constraint fixing translational DoF
        S += sparse.coo_matrix(([1.0], ([0], [0])), S.shape) # make pos-def
        self.poisson = direct_solve(S.tocsc())

        # set metric weights
        w = jnp.asarray(self.ref.face_areas)
        self.SO.metric_weights = self.SPD.metric_weights = w

    def to_coords(self, v):
        """
        :arg v: #v-by-3 array of vertex coordinates
        :return: differential coords.
        """

        # align
        v = align(v, self.ref.v)

        # compute gradients
        D = self.ref.grad @ v

        # D holds transpose of def. grads.
        # -> compute left polar decomposition for right stretch tensor

        # decompose...
        U, S, Vt = np.linalg.svd(D.reshape(-1, 3, 3))

        # ...rotation
        R = np.einsum('...ij,...jk', U, Vt)
        # flip the last singular direction where det(R) = -1 so that R is a proper rotation
        W = np.ones_like(S)
        W[:, -1] = np.linalg.det(R)
        R = np.einsum('...ij,...j,...jk', U, W, Vt)

        # ...stretch
        S[:, -1] = 1 # no stretch (=1) in normal direction
        # for degenerate triangles
        # TODO: check which direction is normal in degenerate case
        S[S < 1e-6] = 1e-6
        U = np.einsum('...ij,...j,...kj', U, S, U)

        return self.M.entangle([R, U])

    def from_coords(self, c):
        """
        :arg c: differential coords.
        :returns: #v-by-3 array of vertex coordinates
        """
        # compose
        R, U = self.M.disentangle(c)
        D = jnp.einsum('...ij,...jk', U, R) # <-- from left polar decomp.

        # solve Poisson system
        rhs = self.ref.div @ D.reshape(-1, 3)
        v = self.poisson(rhs)
        # move to CoG
        v += self.CoG - v.mean(axis=0)

        return v

    @property
    def ref_coords(self):
        """Coordinates of the reference shape: identity rotation and identity stretch for every face."""
        return jnp.tile(jnp.eye(3), (2*len(self.ref.f), 1)).reshape(self.point_shape)

    def rand(self, key: jax.Array):
        """Random point of the underlying product manifold."""
        return self.M.rand(key)

    def zerovec(self):
        """Returns the zero vector in any tangent space."""
        return self.M.zerovec()

    def projToGeodesic(self, X, Y, P, max_iter=10):
        '''
        Project P onto geodesic from X to Y.

        See:
        Felix Ambellan, Stefan Zachow, Christoph von Tycowicz.
        Geodesic B-Score for Improved Assessment of Knee Osteoarthritis.
        Proc. Information Processing in Medical Imaging (IPMI), LNCS, 2021.

        :arg X, Y: manifold coords defining geodesic X->Y.
        :arg P: manifold coords to be projected to X->Y.
        :arg max_iter: maximum number of fixed-point iterations
        :returns: manifold coords of projection of P to X->Y
        '''

        # all tangent vectors in common space i.e. algebra
        v = self.connec.log(X, Y)
        nrm = self.metric.norm(X, v)
        if nrm == 0:
            # degenerate geodesic (X == Y): it consists of X only; avoids division by zero
            return X
        v = v / nrm

        # initial guess
        Pi = X

        # solver loop: move along the geodesic by the component of log(Pi, P) in direction v
        for _ in range(max_iter):
            w = self.connec.log(Pi, P)
            d = self.metric.inner(Pi, v, w)

            if abs(d) < 1e-6: break

            Pi = self.connec.exp(Pi, d * v)

        return Pi

    def proj(self, X, A):
        """orthogonal (with respect to the euclidean inner product) projection of ambient
        vector (i.e. (2,k,3,3) array) onto the tangentspace at X"""
        # disentangle coords. into rotations and stretches (same API as in to_coords/from_coords)
        R, U = self.M.disentangle(X)
        r, u = self.M.disentangle(A)

        # project in each component
        r = self.SO.proj(R, r)
        u = self.SPD.proj(U, u)

        return self.M.entangle([r, u])
@@ -0,0 +1,56 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+ import numpy as np
16
+
17
+ from morphomatics.manifold import Manifold
18
+
19
+
20
def pole_ladder(M: "Manifold", p: jnp.array, q: jnp.array, v: jnp.array, n_step: int = 1) -> jnp.array:
    """Pole Ladder algorithm to approximate parallel transport along geodesics in affine manifolds
    See

    Numerical Accuracy of Ladder Schemes for Parallel Transport on Manifolds, Nicolas Guigui, Xavier Pennec
    Foundations of Computational Mathematics (2022) 22:757–790,

    for details. The method is exact in Symmetric Spaces.

    :param M: Manifold
    :param p: Point in M
    :param q: Point in M
    :param v: Vector in the tangent space at p
    :param n_step: Number of steps (Python int >= 1)
    :return: Vector in the tangent space at q
    :raises ValueError: if n_step < 1
    """
    if n_step < 1:
        raise ValueError(f'n_step must be a positive integer, got {n_step}')

    # scaling speeds up convergence (undone by the factor n_step**2 at the end)
    v = v / n_step**2

    def reflect(p_pr, m):
        # geodesic symmetry of the current point p_pr through the midpoint m
        return M.connec.exp(m, -M.connec.log(m, p_pr)), None

    # midpoints of the n_step sub-geodesics of p -> q, at times (2i+1)/(2 n_step)
    U = M.connec.log(p, q)
    t = np.array([i/(2*n_step) for i in range(1, 2*n_step, 2)])
    tU = t.reshape((-1,) + (1,)*U.ndim) * U[None]
    midpoints = jax.vmap(M.connec.exp, (None, 0))(p, tU)

    p_pr = M.connec.exp(p, v)

    # reflect successively through every midpoint
    q_pr, _ = jax.lax.scan(reflect, p_pr, midpoints)

    # every reflection flips the sign; undo it together with the initial scaling
    return (-1)**n_step * n_step**2 * M.connec.log(q, q_pr)