morphomatics-4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. morphomatics/__init__.py +13 -0
  2. morphomatics/geom/__init__.py +16 -0
  3. morphomatics/geom/bezier_spline.py +361 -0
  4. morphomatics/geom/misc.py +104 -0
  5. morphomatics/geom/surface.py +208 -0
  6. morphomatics/graph/__init__.py +13 -0
  7. morphomatics/graph/operators.py +124 -0
  8. morphomatics/manifold/__init__.py +46 -0
  9. morphomatics/manifold/bezierfold.py +500 -0
  10. morphomatics/manifold/connection.py +105 -0
  11. morphomatics/manifold/cubic_bezierfold.py +305 -0
  12. morphomatics/manifold/differential_coords.py +197 -0
  13. morphomatics/manifold/discrete_ops.py +56 -0
  14. morphomatics/manifold/euclidean.py +213 -0
  15. morphomatics/manifold/fundamental_coords.py +440 -0
  16. morphomatics/manifold/gl_p_coords.py +149 -0
  17. morphomatics/manifold/gl_p_n.py +201 -0
  18. morphomatics/manifold/grassmann.py +174 -0
  19. morphomatics/manifold/hyperbolic_space.py +271 -0
  20. morphomatics/manifold/kendall.py +269 -0
  21. morphomatics/manifold/lie_group.py +102 -0
  22. morphomatics/manifold/manifold.py +162 -0
  23. morphomatics/manifold/manopt_wrapper.py +185 -0
  24. morphomatics/manifold/metric.py +110 -0
  25. morphomatics/manifold/point_distribution_model.py +143 -0
  26. morphomatics/manifold/power_manifold.py +413 -0
  27. morphomatics/manifold/product_manifold.py +381 -0
  28. morphomatics/manifold/se_3.py +419 -0
  29. morphomatics/manifold/shape_space.py +57 -0
  30. morphomatics/manifold/so_3.py +494 -0
  31. morphomatics/manifold/spd.py +524 -0
  32. morphomatics/manifold/sphere.py +241 -0
  33. morphomatics/manifold/tangent_bundle.py +337 -0
  34. morphomatics/manifold/util.py +126 -0
  35. morphomatics/nn/__init__.py +15 -0
  36. morphomatics/nn/flow_layers.py +219 -0
  37. morphomatics/nn/tangent_layers.py +176 -0
  38. morphomatics/nn/train.py +202 -0
  39. morphomatics/nn/wFM_layers.py +152 -0
  40. morphomatics/opt/__init__.py +14 -0
  41. morphomatics/opt/riemannian_newton_raphson.py +65 -0
  42. morphomatics/opt/riemannian_steepest_descent.py +61 -0
  43. morphomatics/stats/__init__.py +18 -0
  44. morphomatics/stats/biinvariant_statistics.py +190 -0
  45. morphomatics/stats/exponential_barycenter.py +78 -0
  46. morphomatics/stats/geometric_median.py +89 -0
  47. morphomatics/stats/principal_geodesic_analysis.py +135 -0
  48. morphomatics/stats/riemannian_regression.py +317 -0
  49. morphomatics/stats/statistical_shape_model.py +99 -0
  50. morphomatics-4.0.dist-info/LICENSE +9 -0
  51. morphomatics-4.0.dist-info/METADATA +55 -0
  52. morphomatics-4.0.dist-info/RECORD +54 -0
  53. morphomatics-4.0.dist-info/WHEEL +5 -0
  54. morphomatics-4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,241 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import numpy as np
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+
18
+ from morphomatics.manifold import Manifold, Metric
19
+ from morphomatics.manifold.euclidean import euclidean_inner
20
+ from morphomatics.manifold.metric import _eval_adjJacobi_embed
21
+
22
+
23
class Sphere(Manifold):
    """The unit sphere in the space of real arrays of shape ``point_shape``.

    Points are arrays of shape ``point_shape`` with unit Frobenius norm. With n = prod(point_shape),
    this is the (n-1)-dimensional sphere embedded in R^n; the default ``point_shape=(3,)`` gives the
    2-sphere in R^3.
    """

    def __init__(self, point_shape=(3,), structure='Canonical'):
        name = 'Points with unit Frobenius norm in ' +\
            'x'.join(map(str, point_shape)) + '-dim. space.'
        dimension = np.prod(point_shape)-1
        super().__init__(name, dimension, point_shape)
        if structure:
            # e.g. structure='Canonical' dispatches to initCanonicalStructure()
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        children, aux = super().tree_flatten()
        # point_shape is static data needed by tree_unflatten to rebuild the object
        return children, aux+(self.point_shape,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *aux_data, shape = aux_data
        obj = cls(shape, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initCanonicalStructure(self):
        """
        Instantiate Sphere with canonical structure.
        """
        structure = Sphere.CanonicalStructure(self)
        self._metric = structure
        self._connec = structure

    def rand(self, key: jax.Array):
        """Random point: a standard-normal sample normalized to unit norm."""
        p = jax.random.normal(key, self.point_shape)
        return p / jnp.linalg.norm(p)

    def randvec(self, X, key: jax.Array):
        """Random tangent vector at the point X (X is assumed to have unit norm)."""
        H = jax.random.normal(key, self.point_shape)
        # remove the normal component <X, H> X
        return H - jnp.dot(X.reshape(-1), H.reshape(-1)) * X

    def zerovec(self):
        """Zero tangent vector (identical for all base points)."""
        return jnp.zeros(self.point_shape)

    @staticmethod
    def antipode(p):
        """Antipodal point -p."""
        return -p

    @staticmethod
    def normalize(X):
        """Return Frobenius-normalized version of X in ambient space."""
        return X / jnp.sqrt((X**2).sum() + np.finfo(np.float64).eps)

    def proj(self, p, X):
        """Orthogonal projection of the ambient vector X onto the tangent space at p."""
        return X - euclidean_inner(p, X) * p

    class CanonicalStructure(Metric):
        """
        The Riemannian metric induced by the embedding into the ambient space R^n, n = prod(point_shape),
        endowed with the Euclidean (Frobenius) inner product, i.e., the sphere is a Riemannian submanifold
        of the ambient space. The connection is the corresponding Levi-Civita connection.
        """

        def __init__(self, M):
            """
            Constructor.

            :param M: the Sphere instance this structure belongs to
            """
            self._M = M

        def __str__(self):
            return "Canonical structure"

        @property
        def typicaldist(self):
            # diameter of the unit sphere
            return np.pi

        def inner(self, p, X, Y):
            return euclidean_inner(X, Y)

        def norm(self, p, X):
            return jnp.sqrt(self.inner(p, X, X))

        def flat(self, p, X):
            """Lower vector X at p with the metric"""
            return X

        def sharp(self, p, dX):
            """Raise covector dX at p with the metric"""
            return dX

        def egrad2rgrad(self, p, X):
            return self._M.proj(p, X)

        def ehess2rhess(self, p, G, H, X):
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point p along a tangent vector X to the Riemannian Hessian
            along X on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def retr(self, p, X):
            return self.exp(p, X)

        def exp(self, p, X):
            """Riemannian exponential Exp_p(X) = cos(|X|) p + sin(|X|)/|X| X.

            For small |X| the trigonometric factors are replaced by their Taylor expansions.
            """
            # numerical safeguard
            p = Sphere.normalize(p)
            X = self._M.proj(p, X)

            def full_exp(sqn):
                n = jnp.sqrt(sqn + jnp.finfo(jnp.float64).eps)
                # sinc(n / pi) = sin(n) / n
                return jnp.cos(n) * p + jnp.sinc(n/jnp.pi) * X

            def trunc_exp(sqn):
                # 4th-order approximation of cos(n) and sin(n)/n in terms of sqn = n^2
                return (1-sqn/2+sqn**2/24) * p + (1-sqn/6+sqn**2/120) * X

            sq_norm = (X ** 2).sum()
            q = jax.lax.cond(sq_norm < 1e-6, trunc_exp, full_exp, sq_norm)
            return Sphere.normalize(q)

        def log(self, p, q):
            """Riemannian logarithm Log_p(q) = a / sin(a) * (q - cos(a) p) with a = dist(p, q).

            For nearby points the factors are replaced by their Taylor expansions. The logarithm is
            undefined for antipodal points.
            """

            def full_log(a2):
                a = jnp.sqrt(a2 + jnp.finfo(jnp.float64).eps)
                # 1 / sinc(a / pi) = a / sin(a)
                return 1/jnp.sinc(a/jnp.pi) * q - a/jnp.tan(a) * p

            def trunc_log(a2):
                # a / sin(a) = 1 + a^2/6 + 7 a^4/360 + 31 a^6/15120 + O(a^8)
                # a * cot(a) = 1 - a^2/3 - a^4/45 - 2 a^6/945 + O(a^8)
                return (1 + a2/6 + 7*a2**2/360 + 31*a2**3/15120) * q \
                    - (1 - a2/3 - a2**2/45 - 2*a2**3/945) * p

            sqd = self.squared_dist(p, q)
            return jax.lax.cond(sqd < 1e-6, trunc_log, full_log, sqd)

        def curvature_tensor(self, p, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at p on the vectors X, Y, Z. With nabla_X Y denoting the
            covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            return (Y*Z).sum() * X - (X*Z).sum() * Y

        def geopoint(self, p, q, t):
            return self.exp(p, t * self.log(p, q))

        def transp(self, p, q, X):
            """Parallel transport of X from p to q along the connecting geodesic."""
            d2 = self.squared_dist(p, q)

            def do_transp(V):
                log_p_q = self.log(p, q)
                return V - self.inner(p, log_p_q, V)/d2 * (log_p_q + self.log(q, p))

            # (numerically) coinciding points: transport is the identity
            return jax.lax.cond(d2 < 1e-6, lambda X: X, do_transp, X)

        def pairmean(self, p, q):
            return self.geopoint(p, q, .5)

        def dist(self, p, q):
            inner = (p * q).sum()
            return jax.lax.cond(jnp.abs(inner) >= 1, lambda i: (i < 0)*jnp.pi, jnp.arccos, jnp.clip(inner, -1, 1))

        def squared_dist(self, p, q):
            """Squared geodesic distance; uses the chordal approximation |q - p|^2 for nearby points."""
            inner = (p * q).sum()
            # Clip from below as well: rounding can push <p, q> slightly below -1 for (nearly)
            # antipodal points, where arccos would otherwise return NaN.
            return jax.lax.cond(inner > 1-1e-6,
                                lambda _: jnp.sum((q-p)**2),
                                lambda i: jnp.arccos(i)**2,
                                jnp.clip(inner, -1, 1-1e-6))

        def jacobiField(self, p, q, t, X):
            """Evaluate the Jacobi field J along the geodesic from p to q with J(0) = X and J(1) = 0.

            :param p: point on the sphere
            :param q: point on the sphere
            :param t: scalar in [0,1]
            :param X: tangent vector at p
            :returns: tuple (gamma(t), J(t)) with gamma the geodesic from p to q
            """
            phi = self.dist(p, q)
            v = self.log(p, q)
            gamTS = self.exp(p, t*v)

            def _generic(X):
                u = v / phi  # unit tangent of the geodesic at p
                Xtan_norm = self.inner(p, X, u)
                Xorth = X - Xtan_norm * u
                # tangential component of J: (1-t) * transp(p, gamTS, Xtan)
                Jtan = Xtan_norm / phi * self.log(gamTS, q)
                return (jnp.sin((1 - t) * phi) / jnp.sin(phi)) * Xorth + Jtan

            def _degenerate(X):
                # p == q (numerically): both coefficients above tend to (1 - t); avoids 0/0 -> NaN
                return (1 - t) * X

            return gamTS, jax.lax.cond(phi > 1e-6, _generic, _degenerate, X)

        def _adjJacobi(self, p, q, t, w):
            # alternative version to adjJacobi relying on automatic differentiation
            return self._M.proj(p, _eval_adjJacobi_embed(self, p, q, t, w))

        def adjJacobi(self, p, q, t, w):
            """Evaluate an adjoint Jacobi field.

            The decomposition of the curvature operator and the fact that only two of its eigenvectors are necessary is
            used in the algorithm; the coefficients coincide with those used in jacobiField.

            :param p: point on the sphere
            :param q: point on the sphere
            :param t: scalar in [0,1]
            :param w: tangent vector at gamma(t;p,q)
            :returns: tangent vector at p
            """
            dist = self.dist(p, q)

            def _eval(dW):
                d, W = dW
                # all computations can be done at p, so only w has to be parallel translated
                W = self.transp(self.geopoint(p, q, t), p, W)

                # first eigenvector is normalized tangent of the geodesic -> corresponding eigenvalue is 0
                T = self.log(p, q) / jnp.clip(d, 1e-6)  # clipping only for NAN debugging

                # second eigenvector is Gram Schmidt orthonormalization of W against T -> corresponding eigenvalue is 1
                b1 = self.inner(p, W, T)
                U = W - b1 * T
                # Check whether W is (numerically) parallel to T;
                # then, the adjoint Jacobi field is only a scaling of the parallel transported tangent.
                U = jax.lax.cond(jnp.linalg.norm(U) > 1e-3,
                                 lambda v: v / jnp.clip(self.norm(p, v), 1e-3),  # clipping only for NAN debugging
                                 lambda v: jnp.zeros_like(v), U)

                b2 = self.inner(p, W, U)

                a1 = 1 - t  # corresponds to the eigenvalue 0
                a2 = jnp.sin((1 - t) * d) / jnp.sin(d)  # corresponds to the eigenvalue 1

                return a1 * b1 * T + a2 * b2 * U

            # For p == q (numerically) both a1 and a2 tend to (1 - t) and the transport is the identity,
            # so the adjoint Jacobi field reduces to (1 - t) * w (continuous limit of _eval).
            return jax.lax.cond(dist > 1e-6, _eval, lambda args: (1 - t) * args[-1], (dist, w))
@@ -0,0 +1,337 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+
16
+ from typing import Callable, List
17
+
18
+
19
+ from morphomatics.manifold import Manifold, Metric
20
+
21
+
22
class TangentBundle(Manifold):
    """Tangent Bundle TM of a smooth manifold M

    Elements of TM are modelled as arrays of shape ``(2, *M.point_shape)``: index 0 holds the
    base point p in M and index 1 a tangent vector u at p.
    """

    def __init__(self, M: Manifold, structure: str = 'Sasaki'):
        point_shape = tuple([2, *M.point_shape])
        name = 'Tangent bundle of ' + M.__str__() + '.'
        dimension = 2 * M.dim
        super().__init__(name, dimension, point_shape)
        self._base_manifold = M
        if structure:
            # e.g. structure='Sasaki' dispatches to initSasakiStructure()
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        children, aux = super().tree_flatten()
        return children+(self.base_manifold,), aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *children, M = children
        obj = cls(M, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    @property
    def base_manifold(self) -> Manifold:
        """Return the base manifold M whose tangent bundle is TM"""
        return self._base_manifold

    def bundle_projection(self, pu: jnp.array) -> jnp.array:
        """Canonical projection of the tangent bundle onto its base manifold"""
        return pu[0]

    def initSasakiStructure(self):
        """
        Instantiate the tangent bundle with Sasaki structure.
        """
        structure = TangentBundle.SasakiStructure(self)
        self._metric = structure
        self._connec = structure

    def rand(self, key: jax.Array) -> jnp.array:
        """Random element of TM: a random base point p together with a random tangent vector at p"""
        k1, k2 = jax.random.split(key, 2)
        p = self._base_manifold.rand(k1)
        # randvec takes the base point first, cf. randvec below
        u = self._base_manifold.randvec(p, k2)
        return jnp.stack((p, u))

    def randvec(self, pu: jnp.array, key: jax.Array) -> jnp.array:
        """Random vector in the tangent space of the point pu

        :param pu: element of TM
        :return: tangent vector at pu, stacked as (horizontal part, vertical part)
        """
        k1, k2 = jax.random.split(key, 2)
        p = self.bundle_projection(pu)
        return jnp.stack((self._base_manifold.randvec(p, k1), self._base_manifold.randvec(p, k2)))

    def zerovec(self) -> jnp.array:
        """Zero vector in any tangent space
        """
        return jnp.zeros(self.point_shape)

    def proj(self, pu: jnp.array, vw: jnp.array) -> jnp.array:
        raise NotImplementedError('This function has not been implemented yet.')

    class SasakiStructure(Metric):
        """
        This class implements the Sasaki metric: The natural metric on the tangent bundle TM of a Riemannian manifold M.

        The Sasaki metric is characterized by the following three properties:
         * the canonical projection of TM becomes a Riemannian submersion,
         * parallel vector fields along curves are orthogonal to their fibres, and
         * its restriction to any tangent space is Euclidean.

        Geodesic computations are realized via a discrete formulation of the geodesic equation on TM that involve
        geodesics, parallel translation, and the curvature tensor on the base manifold M; for details see
            Muralidharan, P., & Fletcher, P. T. (2012, June).
            Sasaki metrics for analysis of longitudinal data on manifolds.
            In 2012 IEEE conference on computer vision and pattern recognition (pp. 1027-1034). IEEE.
            https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4270017/
        """

        def __init__(self, TM: Manifold, Ns: int = 10):
            """
            Constructor.

            :param TM: TangentBundle object
            :param Ns: number of discretization steps used in the approximation of the
                exponential and logarithm maps
            """
            self._TM = TM
            self.Ns = Ns

        def __str__(self) -> str:
            return "Sasaki structure"

        @property
        def typicaldist(self):
            raise NotImplementedError('This function has not been implemented yet.')

        def inner(self, pu: jnp.array, vw: jnp.array, xy: jnp.array) -> float:
            """Inner product between two tangent vectors at a point in TM.

            :param pu: element of TM
            :param vw: tangent vector at pu
            :param xy: tangent vector at pu
            :return: inner product (scalar) of vw and xy
            """
            p = self._TM.bundle_projection(pu)
            # Sasaki inner product = sum of the base-metric inner products of both components at p
            base_metric = self._TM.base_manifold.metric.inner
            return base_metric(p, vw[0], xy[0]) + base_metric(p, vw[1], xy[1])

        def flat(self, pu: jnp.array, xy: jnp.array) -> jnp.array:
            """Lower vector xy at pu with the metric"""
            raise NotImplementedError('This function has not been implemented yet.')

        def sharp(self, pu: jnp.array, d_xy: jnp.array) -> jnp.array:
            """Raise covector d_xy at pu with the metric"""
            raise NotImplementedError('This function has not been implemented yet.')

        def egrad2rgrad(self, pu: jnp.array, vw: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')

        def ehess2rhess(self, pu, G, H, vw):
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point pu along a tangent vector vw to the Riemannian Hessian
            along vw on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def geodesic_discrete(self, pu: jnp.array, qr: jnp.array) -> List[jnp.array]:
            """
            Compute Sasaki geodesic employing a variational time discretization.

            :param pu: element of TM
            :param qr: element of TM
            :return: array of shape [Ns + 1, 2, *M.point_shape]
                Discrete geodesic x(s)=(p(s), u(s)) in Sasaki metric connecting
                pu = x(0) and qr = x(1).
            """

            Ns = self.Ns
            base_connection = self._TM.base_manifold.connec
            par_trans = base_connection.transp
            p0, u0 = pu[0], pu[1]
            pL, uL = qr[0], qr[1]

            def grad(c):
                """ Gradient of path energy for discrete geodesic c (interior points only) """
                def grad_i(puP, pu, puN):
                    v, w = base_connection.log(pu[0], puN[0]), par_trans(puN[0], pu[0], puN[1]) - pu[1]
                    gp = .5 * (v + base_connection.log(pu[0], puP[0])) \
                        + base_connection.curvature_tensor(pu[0], pu[1], w, v)
                    gu = w + par_trans(puP[0], pu[0], puP[1]) - pu[1]
                    return jnp.array([gp, gu])
                return -Ns * jax.vmap(grad_i)(c[:-2], c[1:-1], c[2:])

            # Initial guess: base points along the base geodesic p0 -> pL, vectors linearly blended
            # after transporting u0 and uL to the respective base point
            v_pL = base_connection.log(p0, pL)
            s = jnp.linspace(0., 1., Ns + 1)

            def init(t):
                p_ini = base_connection.exp(p0, t * v_pL)
                u_ini = (1 - t) * par_trans(p0, p_ini, u0) + t * par_trans(pL, p_ini, uL)
                return jnp.array([p_ini, u_ini])

            pu_ini = jax.vmap(init)(s[1:-1])
            pu_ini = jnp.vstack([pu[None], pu_ini, qr[None]])

            # Minimization by gradient descent (end points stay fixed)
            return _gradient_descent(self._TM, pu_ini, grad)

        def exp(self, pu: jnp.array, vw: jnp.array) -> jnp.array:
            """Compute the exponential of the Levi-Civita connection of the Sasaki metric.

            Exponential map at pu of vw computed by
            shooting a Sasaki geodesic using an Euler integration in TM.

            :param pu: element of TM
            :param vw: tangent vector at pu
            :return: point at time 1 of the geodesic that starts in pu with initial velocity vw
            """

            base_connection = self._TM.base_manifold.connec
            par_trans = base_connection.transp
            Ns = self.Ns
            eps = 1 / Ns

            def body(carry, _):
                p, u, v, w = carry
                p_ = base_connection.exp(p, eps * v)
                u_ = par_trans(p, p_, u + eps * w)
                v_ = par_trans(p, p_, v - eps * base_connection.curvature_tensor(p, u, w, v))
                w_ = par_trans(p, p_, w)
                return (p_, u_, v_, w_), None

            # Ns explicit Euler steps of size 1/Ns
            (p, u, *_), _ = jax.lax.scan(body, (*pu, *vw), None, length=Ns)

            return jnp.stack([p, u])

        def log(self, pu: jnp.array, qr: jnp.array) -> jnp.array:
            """Compute the logarithm of the Levi-Civita connection of the Sasaki metric.

            Logarithmic map at pu of qr computed by
            iteratively relaxing a discretized geodesic between pu and qr.

            For a derivation of the algorithm see https://opus4.kobv.de/opus4-zib/frontdoor/index/index/docId/8717.

            :param pu: element of TM
            :param qr: element of TM
            :return: tangent vector at pu (inverse of exp)
            """

            base_connection = self._TM.base_manifold.connec
            Ns = self.Ns

            x = self.geodesic_discrete(pu, qr)
            p0, u0 = pu[0], pu[1]
            # first step x[1] = (p1, u1) of the discrete geodesic, scaled by the number of steps
            p1, u1 = x[1][0], x[1][1]
            w = base_connection.transp(p1, p0, u1) - u0
            v = base_connection.log(p0, p1)
            return Ns * jnp.array([v, w])

        def curvature_tensor(self, pu, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at pu on the vectors X, Y, Z. With nabla_X Y denoting
            the covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            # TODO: Implement from "Curvature of the Induced Riemannian Metric on the Tangent Bundle of a Riemannian
            #  Manifold" (1971) by Kowalski
            raise NotImplementedError('This function has not been implemented yet.')

        def geopoint(self, pu: jnp.array, qr: jnp.array, t: float) -> jnp.array:
            """Evaluate geodesic in TM

            :param pu: element of TM
            :param qr: element of TM
            :param t: scalar between 0 and 1
            :return: element of the geodesic between pu and qr evaluated at time t
            """
            return self.exp(pu, t * self.log(pu, qr))

        def retr(self, pu: jnp.array, vw: jnp.array) -> jnp.array:
            return self.exp(pu, vw)

        def transp(self, pu, qr, vw):
            raise NotImplementedError('This function has not been implemented yet.')

        def pairmean(self, pu: jnp.array, qr: jnp.array) -> jnp.array:
            """Fréchet mean of 2 points in TM

            :param pu: element of TM
            :param qr: element of TM
            :return: Fréchet mean of pu and qr
            """
            return self.geopoint(pu, qr, .5)

        def dist(self, pu: jnp.array, qr: jnp.array) -> float:
            """Distance function that is induced on TM by the Sasaki metric

            :param pu: element of TM
            :param qr: element of TM
            :return: distance between pu and qr in TM
            """
            vw = self.log(pu, qr)
            return jnp.sqrt(self.inner(pu, vw, vw))

        def jacobiField(self, p_A: jnp.array, p_B: jnp.array, t: float, X: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, p_A: jnp.array, p_B: jnp.array, t: float, X: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')
305
+
306
+
307
def _gradient_descent(TM: 'Manifold', x_ini: jnp.array, grad: Callable, step_size: float = 0.1, max_iter: int = 100,
                      tol: float = 1e-6) -> jnp.array:
    """Apply a gradient descent to compute the discrete geodesic in TM with the Sasaki structure.

    The end points ``x_ini[0]`` and ``x_ini[-1]`` stay fixed; every interior point is moved along
    ``TM.connec.exp`` in the direction of the negative gradient returned by ``grad``.

    :param TM: Tangent bundle (only ``TM.connec.exp`` is used)
    :param x_ini: initial discrete curve; the first axis enumerates the curve points
    :param grad: gradient function mapping the whole curve to the gradient at its interior points
        (same shape as ``x[1:-1]``)
    :param step_size: step size
    :param max_iter: maximum number of iterations
    :param tol: iteration stops once the norm of the gradient is <= tol
    :return: relaxed discrete curve with the same shape as x_ini
    """
    exp = jax.vmap(TM.connec.exp)

    def body(carry):
        x, _, i = carry
        grad_x = grad(x)
        # JAX arrays are immutable: .at[].set returns a new array which has to be kept
        x = x.at[1:-1].set(exp(x[1:-1], -step_size * grad_x))
        return x, jnp.linalg.norm(grad_x), i + 1

    def cond(carry):
        _, grad_norm, i = carry
        return jnp.logical_and(grad_norm > tol, i < max_iter)

    # start with an infinite gradient norm so that at least one step is taken for any tol
    x, _, _ = jax.lax.while_loop(cond, body, (x_ini, jnp.array(jnp.inf), jnp.array(0)))

    return x