morphomatics 4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphomatics/__init__.py +13 -0
- morphomatics/geom/__init__.py +16 -0
- morphomatics/geom/bezier_spline.py +361 -0
- morphomatics/geom/misc.py +104 -0
- morphomatics/geom/surface.py +208 -0
- morphomatics/graph/__init__.py +13 -0
- morphomatics/graph/operators.py +124 -0
- morphomatics/manifold/__init__.py +46 -0
- morphomatics/manifold/bezierfold.py +500 -0
- morphomatics/manifold/connection.py +105 -0
- morphomatics/manifold/cubic_bezierfold.py +305 -0
- morphomatics/manifold/differential_coords.py +197 -0
- morphomatics/manifold/discrete_ops.py +56 -0
- morphomatics/manifold/euclidean.py +213 -0
- morphomatics/manifold/fundamental_coords.py +440 -0
- morphomatics/manifold/gl_p_coords.py +149 -0
- morphomatics/manifold/gl_p_n.py +201 -0
- morphomatics/manifold/grassmann.py +174 -0
- morphomatics/manifold/hyperbolic_space.py +271 -0
- morphomatics/manifold/kendall.py +269 -0
- morphomatics/manifold/lie_group.py +102 -0
- morphomatics/manifold/manifold.py +162 -0
- morphomatics/manifold/manopt_wrapper.py +185 -0
- morphomatics/manifold/metric.py +110 -0
- morphomatics/manifold/point_distribution_model.py +143 -0
- morphomatics/manifold/power_manifold.py +413 -0
- morphomatics/manifold/product_manifold.py +381 -0
- morphomatics/manifold/se_3.py +419 -0
- morphomatics/manifold/shape_space.py +57 -0
- morphomatics/manifold/so_3.py +494 -0
- morphomatics/manifold/spd.py +524 -0
- morphomatics/manifold/sphere.py +241 -0
- morphomatics/manifold/tangent_bundle.py +337 -0
- morphomatics/manifold/util.py +126 -0
- morphomatics/nn/__init__.py +15 -0
- morphomatics/nn/flow_layers.py +219 -0
- morphomatics/nn/tangent_layers.py +176 -0
- morphomatics/nn/train.py +202 -0
- morphomatics/nn/wFM_layers.py +152 -0
- morphomatics/opt/__init__.py +14 -0
- morphomatics/opt/riemannian_newton_raphson.py +65 -0
- morphomatics/opt/riemannian_steepest_descent.py +61 -0
- morphomatics/stats/__init__.py +18 -0
- morphomatics/stats/biinvariant_statistics.py +190 -0
- morphomatics/stats/exponential_barycenter.py +78 -0
- morphomatics/stats/geometric_median.py +89 -0
- morphomatics/stats/principal_geodesic_analysis.py +135 -0
- morphomatics/stats/riemannian_regression.py +317 -0
- morphomatics/stats/statistical_shape_model.py +99 -0
- morphomatics-4.0.dist-info/LICENSE +9 -0
- morphomatics-4.0.dist-info/METADATA +55 -0
- morphomatics-4.0.dist-info/RECORD +54 -0
- morphomatics-4.0.dist-info/WHEEL +5 -0
- morphomatics-4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
|
|
18
|
+
from morphomatics.manifold import Manifold, Metric
|
|
19
|
+
from morphomatics.manifold.euclidean import euclidean_inner
|
|
20
|
+
from morphomatics.manifold.metric import _eval_adjJacobi_embed
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Sphere(Manifold):
    """The unit sphere of tensors of shape ``point_shape`` w.r.t. the Frobenius norm.

    Points are arrays of shape ``point_shape`` with unit Frobenius norm, i.e. each tensor is treated as a flattened
    vector in R^(n+1) with n + 1 = prod(point_shape). For ``point_shape=(n+1,)`` this is the standard n-sphere.
    """

    def __init__(self, point_shape=(3,), structure='Canonical'):
        """
        :param point_shape: shape of the ambient tensors
        :param structure: name of the structure to set up (calls ``init<structure>Structure``); a falsy value
            skips the setup (used when unflattening PyTrees)
        """
        name = 'Points with unit Frobenius norm in ' +\
               'x'.join(map(str, point_shape)) + '-dim. space.'
        dimension = np.prod(point_shape)-1
        super().__init__(name, dimension, point_shape)
        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Flatten for PyTree registration; the point shape is appended to the auxiliary data."""
        children, aux = super().tree_flatten()
        return children, aux+(self.point_shape,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *aux_data, shape = aux_data
        obj = cls(shape, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initCanonicalStructure(self):
        """
        Instantiate Sphere with canonical structure.
        """
        structure = Sphere.CanonicalStructure(self)
        self._metric = structure
        self._connec = structure

    def rand(self, key: jax.Array):
        """Random point: a standard normal sample normalized to unit Frobenius norm."""
        p = jax.random.normal(key, self.point_shape)
        return p / jnp.linalg.norm(p)

    def randvec(self, X, key: jax.Array):
        """Random tangent vector at X: a standard normal sample with its component along X removed."""
        H = jax.random.normal(key, self.point_shape)
        return H - jnp.dot(X.reshape(-1), H.reshape(-1)) * X

    def zerovec(self):
        """Zero tangent vector (independent of the base point)."""
        return jnp.zeros(self.point_shape)

    @staticmethod
    def antipode(p):
        """Antipodal point of p."""
        return -p

    @staticmethod
    def normalize(X):
        """Return Frobenius-normalized version of X in ambient space."""
        # eps guards against division by zero for X == 0
        return X / jnp.sqrt((X**2).sum() + np.finfo(np.float64).eps)

    def proj(self, p, X):
        """Orthogonal projection of the ambient vector X onto the tangent space at p."""
        return X - euclidean_inner(p, X) * p

    class CanonicalStructure(Metric):
        """
        The Riemannian metric used is the one induced by the embedding space, i.e., the space of tensors of shape
        ``point_shape`` (equivalently R^(n+1)) endowed with the Frobenius inner product; this manifold is a
        Riemannian submanifold of that Euclidean space. Geodesics are great circles.
        """

        def __init__(self, M):
            """
            Constructor.

            :param M: the Sphere instance this structure belongs to
            """
            self._M = M

        def __str__(self):
            return "Canonical structure"

        @property
        def typicaldist(self):
            return np.pi

        def inner(self, p, X, Y):
            """Inner product of tangent vectors X, Y (Frobenius inner product of the ambient space)."""
            return euclidean_inner(X, Y)

        def norm(self, p, X):
            """Norm of the tangent vector X."""
            return jnp.sqrt(self.inner(p, X, X))

        def flat(self, p, X):
            """Lower vector X at p with the metric"""
            return X

        def sharp(self, p, dX):
            """Raise covector dX at p with the metric"""
            return dX

        def egrad2rgrad(self, p, X):
            """Riemannian gradient: projection of the Euclidean gradient X onto the tangent space at p."""
            return self._M.proj(p, X)

        def ehess2rhess(self, p, G, H, X):
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point p along a tangent vector X to the Riemannian Hessian
            along X on the manifold.

            For the sphere as Riemannian submanifold of Euclidean space this is
                Hess f(p)[X] = proj_p(H) - <p, G> X
            (see Absil, Mahony & Sepulchre, "Optimization Algorithms on Matrix Manifolds").

            :param p: point on the sphere
            :param G: Euclidean gradient at p
            :param H: Euclidean Hessian at p applied to X
            :param X: tangent vector at p
            :returns: Riemannian Hessian at p applied to X
            """
            return self._M.proj(p, H) - euclidean_inner(p, G) * X

        def retr(self, p, X):
            """Retraction; the exact exponential map is used."""
            return self.exp(p, X)

        def exp(self, p, X):
            """Riemannian exponential: exp_p(X) = cos(|X|) p + sin(|X|)/|X| X."""
            # numerical safeguard
            p = Sphere.normalize(p)
            X = self._M.proj(p, X)

            def full_exp(sqn):
                n = jnp.sqrt(sqn + jnp.finfo(jnp.float64).eps)
                # sinc(n/pi) == sin(n)/n
                return jnp.cos(n) * p + jnp.sinc(n/jnp.pi) * X

            def trunc_exp(sqn):
                # 4th-order approximation
                return (1-sqn/2+sqn**2/24) * p + (1-sqn/6+sqn**2/120) * X

            sq_norm = (X ** 2).sum()
            q = jax.lax.cond(sq_norm < 1e-6, trunc_exp, full_exp, sq_norm)
            return Sphere.normalize(q)

        def log(self, p, q):
            """Riemannian logarithm: log_p(q) = a/sin(a) q - a cot(a) p with a = dist(p, q).

            A Taylor expansion is used for (numerically) coincident points.
            """

            def full_log(a2):
                a = jnp.sqrt(a2 + jnp.finfo(jnp.float64).eps)
                return 1/jnp.sinc(a/jnp.pi) * q - a/jnp.tan(a) * p

            def trunc_log(a2):
                # a/sin(a) = 1 + a^2/6 + 7a^4/360 + 31a^6/15120 + ...
                # a*cot(a) = 1 - a^2/3 - a^4/45 - 2a^6/945 - ...
                return (1 + a2/6 + 7*a2**2/360 + 31*a2**3/15120) * q - (1 - a2/3 - a2**2/45 - 2*a2**3/945) * p

            sqd = self.squared_dist(p, q)
            return jax.lax.cond(sqd < 1e-6, trunc_log, full_log, sqd)

        def curvature_tensor(self, p, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at p on the vectors X, Y, Z. With nabla_X Y denoting the
            covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            # constant sectional curvature 1: R(X,Y)Z = <Y,Z> X - <X,Z> Y
            return (Y*Z).sum() * X - (X*Z).sum() * Y

        def geopoint(self, p, q, t):
            """Point at time t on the geodesic from p (t=0) to q (t=1)."""
            return self.exp(p, t * self.log(p, q))

        def transp(self, p, q, X):
            """Parallel transport of the tangent vector X from p to q along the connecting geodesic."""
            d2 = self.squared_dist(p, q)

            def do_transp(V):
                log_p_q = self.log(p, q)
                return V - self.inner(p, log_p_q, V)/d2 * (log_p_q + self.log(q, p))

            # for (numerically) coincident points transport is the identity
            return jax.lax.cond(d2 < 1e-6, lambda X: X, do_transp, X)

        def pairmean(self, p, q):
            """Midpoint of the geodesic between p and q."""
            return self.geopoint(p, q, .5)

        def dist(self, p, q):
            """Geodesic distance arccos(<p, q>), with round-off beyond [-1, 1] mapped to 0 or pi."""
            inner = (p * q).sum()
            return jax.lax.cond(jnp.abs(inner) >= 1, lambda i: (i < 0)*jnp.pi, jnp.arccos, jnp.clip(inner, -1, 1))

        def squared_dist(self, p, q):
            """Squared geodesic distance; the chordal distance is used for nearly coincident points."""
            inner = (p * q).sum()
            return jax.lax.cond(inner > 1-1e-6, lambda _: jnp.sum((q-p)**2), lambda i: jnp.arccos(i)**2, jnp.clip(inner, None, 1-1e-6))

        def jacobiField(self, p, q, t, X):
            """Evaluate the Jacobi field along the geodesic gamma from p to q with J(0) = X and J(1) = 0.

            :param p: element of the sphere
            :param q: element of the sphere
            :param t: scalar in [0,1]
            :param X: tangent vector at p
            :returns: tuple (gamma(t), J(t))
            """
            phi = self.dist(p, q)
            v = self.log(p, q)
            gamTS = self.exp(p, t*v)

            def _eval(X):
                u = v / phi  # unit tangent of the geodesic at p
                Xtan_norm = self.inner(p, X, u)
                Xorth = X - Xtan_norm * u

                # tangential component of J: (1-t) * transp(p, gamTS, Xtan)
                Jtan = Xtan_norm / phi * self.log(gamTS, q)
                # assuming X is tangent at p, Xorth is orthogonal to the plane of the great circle and thus
                # parallel along it (constant in the embedding)
                return (jnp.sin((1 - t) * phi) / jnp.sin(phi)) * Xorth + Jtan

            # for (numerically) coincident points the field degenerates to (1-t) * X (avoids division by zero)
            J = jax.lax.cond(phi > 1e-6, _eval, lambda X: (1 - t) * X, X)
            return gamTS, J

        def _adjJacobi(self, p, q, t, w):
            # alternative version to adjJacobi relying on automatic differentiation
            return self._M.proj(p, _eval_adjJacobi_embed(self, p, q, t, w))

        def adjJacobi(self, p, q, t, w):
            """Evaluate an adjoint Jacobi field.

            The decomposition of the curvature operator and the fact that only two of its eigenvectors are necessary is
            used in the algorithm.

            :param p: element of the sphere
            :param q: element of the sphere
            :param t: scalar in [0,1]
            :param w: tangent vector at gamma(t;p,q)
            :returns: tangent vector at p
            """
            dist = self.dist(p, q)

            def _eval(dW):
                d, W = dW
                # all computations can be done at p, so only w has to be parallel translated
                W = self.transp(self.geopoint(p, q, t), p, W)

                # first eigenvector is normalized tangent of the geodesic -> corresponding eigenvalue is 0
                T = self.log(p, q) / jnp.clip(d, 1e-6)  # clipping only for NAN debugging

                # The remainder of W orthogonal to T lies in the eigenspace of eigenvalue 1. Scaling the
                # remainder directly (instead of normalizing it first) keeps it even when it is small in
                # absolute terms and needs no division.
                b1 = self.inner(p, W, T)
                U = W - b1 * T

                a1 = 1 - t  # corresponds to the eigenvalue 0
                a2 = jnp.sin((1 - t) * d) / jnp.sin(d)  # corresponds to the eigenvalue 1

                return a1 * b1 * T + a2 * U

            # For (numerically) coincident points both coefficients tend to (1 - t) and gamma(t) == p,
            # so the adjoint field is (1 - t) * w (matching the limit of _eval).
            return jax.lax.cond(dist > 1e-6, _eval, lambda args: (1 - t) * args[-1], (dist, w))
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import jax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
|
|
16
|
+
from typing import Callable, List
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
from morphomatics.manifold import Manifold, Metric
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TangentBundle(Manifold):
    """Tangent Bundle TM of a smooth manifold M

    Elements of TM are modelled as arrays of the form [2, M.point_shape]: entry 0 holds the base point p in M and
    entry 1 a tangent vector u at p. Tangent vectors of TM at (p, u) are modelled as arrays of the same shape whose
    two entries are tangent vectors of M at p.
    """

    def __init__(self, M: Manifold, structure: str = 'Sasaki'):
        """
        :param M: base manifold
        :param structure: name of the structure to set up (calls ``init<structure>Structure``); a falsy value
            skips the setup (used when unflattening PyTrees)
        """
        point_shape = tuple([2, *M.point_shape])
        name = 'Tangent bundle of ' + M.__str__() + '.'
        dimension = 2 * M.dim
        super().__init__(name, dimension, point_shape)
        self._base_manifold = M
        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Flatten for PyTree registration; the base manifold is appended to the children."""
        children, aux = super().tree_flatten()
        return children+(self.base_manifold,), aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *children, M = children
        obj = cls(M, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    @property
    def base_manifold(self) -> Manifold:
        """Return the base manifold M whose tangent bundle is TM."""
        return self._base_manifold

    def bundle_projection(self, pu: jnp.array) -> jnp.array:
        """Canonical projection of the tangent bundle onto its base manifold"""
        return pu[0]

    def initSasakiStructure(self):
        """
        Instantiate the tangent bundle with Sasaki structure.
        """
        structure = TangentBundle.SasakiStructure(self)
        self._metric = structure
        self._connec = structure

    def rand(self, key: jax.Array) -> jnp.array:
        """Random element of TM: a random base point together with a random tangent vector at it."""
        k1, k2 = jax.random.split(key, 2)
        p = self._base_manifold.rand(k1)
        # the base manifold's randvec expects the foot point as first argument
        u = self._base_manifold.randvec(p, k2)
        return jnp.stack((p, u))

    def randvec(self, pu: jnp.array, key: jax.Array) -> jnp.array:
        """Random vector in the tangent space of the point pu

        :param pu: element of TM
        :return: tangent vector at pu
        """
        k1, k2 = jax.random.split(key, 2)
        p = self.bundle_projection(pu)
        return jnp.stack((self._base_manifold.randvec(p, k1), self._base_manifold.randvec(p, k2)))

    def zerovec(self) -> jnp.array:
        """Zero vector in any tangent space
        """
        return jnp.zeros(self.point_shape)

    def proj(self, pu: jnp.array, vw: jnp.array) -> jnp.array:
        raise NotImplementedError('This function has not been implemented yet.')

    class SasakiStructure(Metric):
        """
        This class implements the Sasaki metric: The natural metric on the tangent bundle TM of a Riemannian manifold M.

        The Sasaki metric is characterized by the following three properties:
         * the canonical projection of TM becomes a Riemannian submersion,
         * parallel vector fields along curves are orthogonal to their fibres, and
         * its restriction to any tangent space is Euclidean.

        Geodesic computations are realized via a discrete formulation of the geodesic equation on TM that involve
        geodesics, parallel translation, and the curvature tensor on the base manifold M; for details see
            Muralidharan, P., & Fletcher, P. T. (2012, June).
            Sasaki metrics for analysis of longitudinal data on manifolds.
            In 2012 IEEE conference on computer vision and pattern recognition (pp. 1027-1034). IEEE.
            https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4270017/
        """

        def __init__(self, TM: Manifold, Ns: int = 10):
            """
            Constructor.

            :param TM: TangentBundle object
            :param Ns: scalar that determines the number of discretization steps used in the approximation of the
                exponential and logarithm maps
            """
            self._TM = TM
            self.Ns = Ns

        def __str__(self) -> str:
            return "Sasaki structure"

        @property
        def typicaldist(self):
            raise NotImplementedError('This function has not been implemented yet.')

        def inner(self, pu: jnp.array, vw: jnp.array, xy: jnp.array) -> float:
            """Inner product between two tangent vectors at point in TM.

            :param pu: element of TM
            :param vw: tangent vector at pu
            :param xy: tangent vector at pu
            :return: inner product (scalar) of vw and xy
            """
            p = self._TM.bundle_projection(pu)
            # compute Sasaki inner product via metric of the underlying manifold
            base_metric = self._TM.base_manifold.metric.inner
            return base_metric(p, vw[0], xy[0]) + base_metric(p, vw[1], xy[1])

        def flat(self, pu: jnp.array, xy: jnp.array) -> jnp.array:
            """Lower vector xy at pu with the metric"""
            raise NotImplementedError('This function has not been implemented yet.')

        def sharp(self, pu: jnp.array, d_xy: jnp.array) -> jnp.array:
            """Raise covector d_xy at pu with the metric"""
            raise NotImplementedError('This function has not been implemented yet.')

        def egrad2rgrad(self, pu: jnp.array, vw: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')

        def ehess2rhess(self, pu, G, H, vw):
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point pu along a tangent vector vw to the Riemannian Hessian
            along vw on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def geodesic_discrete(self, pu: jnp.array, qr: jnp.array) -> jnp.array:
            """
            Compute Sasaki geodesic employing a variational time discretization.

            :param pu: element of TM
            :param qr: element of TM
            :return: array-like, shape=[n_steps + 1, 2, M.shape]
                Discrete geodesic x(s)=(p(s), u(s)) in Sasaki metric connecting
                pu = x(0) and qr = x(1).
            """

            Ns = self.Ns
            base_connection = self._TM.base_manifold.connec
            par_trans = base_connection.transp
            p0, u0 = pu[0], pu[1]
            pL, uL = qr[0], qr[1]

            def grad(c):
                """Scaled residual of the discrete geodesic equations for the interior points of c.

                Both components are halved, so that a descent step of size 1/Ns is a (stable) Jacobi-type
                relaxation: in the flat case each interior point moves to the average of its neighbors.
                (Without halving the u-component, that step over-relaxes and amplifies high-frequency errors.)
                """
                def grad_i(puP, pu, puN):
                    v = base_connection.log(pu[0], puN[0])
                    w = par_trans(puN[0], pu[0], puN[1]) - pu[1]
                    gp = (.5 * (v + base_connection.log(pu[0], puP[0]))
                          + base_connection.curvature_tensor(pu[0], pu[1], w, v))
                    gu = .5 * (w + par_trans(puP[0], pu[0], puP[1]) - pu[1])
                    return jnp.array([gp, gu])
                return -Ns * jax.vmap(grad_i)(c[:-2], c[1:-1], c[2:])

            # Initial guess for gradient_descent: base geodesic, linear blend of the transported end vectors
            vw = base_connection.log(p0, pL)
            s = jnp.linspace(0., 1., Ns + 1)

            def init(t):
                p_ini = base_connection.exp(p0, t * vw)
                u_ini = (1 - t) * par_trans(p0, p_ini, u0) + t * par_trans(pL, p_ini, uL)
                return jnp.array([p_ini, u_ini])

            pu_ini = jax.vmap(init)(s[1:-1])
            pu_ini = jnp.vstack([pu[None], pu_ini, qr[None]])

            # Minimization by gradient descent; step 1/Ns keeps the relaxation factor independent of Ns
            return _gradient_descent(self._TM, pu_ini, grad, step_size=1. / Ns)

        def exp(self, pu: jnp.array, vw: jnp.array) -> jnp.array:
            """Compute the exponential of the Levi-Civita connection of the Sasaki metric.

            Exponential map at pu of vw computed by
            shooting a Sasaki geodesic using an Euler integration in TM.

            :param pu: element of TM
            :param vw: tangent vector at pu
            :return: point at time 1 of the geodesic that starts in pu with initial velocity vw
            """

            base_connection = self._TM.base_manifold.connec
            par_trans = base_connection.transp
            Ns = self.Ns
            eps = 1 / Ns

            def body(carry, _):
                p, u, v, w = carry
                p_ = base_connection.exp(p, eps * v)
                u_ = par_trans(p, p_, u + eps * w)
                v_ = par_trans(p, p_, v - eps * base_connection.curvature_tensor(p, u, w, v))
                w_ = par_trans(p, p_, w)
                return (p_, u_, v_, w_), None

            # Ns explicit Euler steps of size 1/Ns
            (p, u, _, _), _ = jax.lax.scan(body, (pu[0], pu[1], vw[0], vw[1]), None, length=Ns)

            return jnp.stack([p, u])

        def log(self, pu: jnp.array, qr: jnp.array) -> jnp.array:
            """Compute the logarithm of the Levi-Civita connection of the Sasaki metric.

            Logarithmic map at pu of qr computed by
            iteratively relaxing a discretized geodesic between pu and qr.

            For a derivation of the algorithm see https://opus4.kobv.de/opus4-zib/frontdoor/index/index/docId/8717.

            :param pu: element of TM
            :param qr: element of TM
            :return: tangent vector at pu (inverse of exp)
            """

            base_connection = self._TM.base_manifold.connec
            par_trans = base_connection.transp
            Ns = self.Ns

            # discrete geodesic x_0 = pu, ..., x_Ns = qr; the log is Ns times the first discrete velocity
            path = self.geodesic_discrete(pu, qr)
            p0, u0 = pu[0], pu[1]
            p1, u1 = path[1][0], path[1][1]
            v = base_connection.log(p0, p1)
            w = par_trans(p1, p0, u1) - u0
            return Ns * jnp.array([v, w])

        def curvature_tensor(self, pu, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at pu on the vectors X, Y, Z. With nabla_X Y denoting
            the covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            # TODO: Implement from "Curvature of the Induced Riemannian Metric on the Tangent Bundle of a Riemannian
            #  Manifold" (1971) by Kowalski
            raise NotImplementedError('This function has not been implemented yet.')

        def geopoint(self, pu: jnp.array, qr: jnp.array, t: float) -> jnp.array:
            """Evaluate geodesic in TM

            :param pu: element of TM
            :param qr: element of TM
            :param t: scalar between 0 and 1
            :return: element of the geodesic between pu and qr evaluated at time t
            """
            return self.exp(pu, t * self.log(pu, qr))

        def retr(self, pu: jnp.array, vw: jnp.array) -> jnp.array:
            """Retraction; the (approximate) exponential map is used."""
            return self.exp(pu, vw)

        def transp(self, pu, qr, vw):
            raise NotImplementedError('This function has not been implemented yet.')

        def pairmean(self, pu: jnp.array, qr: jnp.array) -> jnp.array:
            """Fréchet mean of 2 points in TM

            :param pu: element of TM
            :param qr: element of TM
            :return: Fréchet mean of pu and qr
            """
            return self.geopoint(pu, qr, .5)

        def dist(self, pu: jnp.array, qr: jnp.array) -> float:
            """Distance function that is induced on TM by the Sasaki metric

            :param pu: element of TM
            :param qr: element of TM
            :return: distance between pu and qr in TM
            """
            vw = self.log(pu, qr)
            return jnp.sqrt(self.inner(pu, vw, vw))

        def jacobiField(self, p_A: jnp.array, p_B: jnp.array, t: float, X: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, p_A: jnp.array, p_B: jnp.array, t: float, X: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')


def _gradient_descent(TM: Manifold, x_ini: jnp.array, grad: Callable, step_size: float = 0.1, max_iter: int = 100,
                      tol: float = 1e-6) -> jnp.array:
    """Apply a gradient descent to compute the discrete geodesic in TM with the Sasaki structure.

    Only the interior points x_ini[1:-1] are updated; the end points stay fixed. Iteration stops once the norm of
    the gradient drops to ``tol`` or below, or after ``max_iter`` iterations.

    :param TM: Tangent bundle
    :param x_ini: initial discrete curve as array of points in TM (first axis indexes the points)
    :param grad: gradient function; maps a curve to the gradient at its interior points
    :param step_size: step size
    :param max_iter: maximum number of iterations
    :param tol: tolerance for convergence
    :return: discrete geodesic as array of points in TM
    """

    def body(args):
        x, _, i = args
        grad_x = grad(x)
        # JAX arrays are immutable: `.at[].set()` returns an updated copy, which must be re-bound.
        x = x.at[1:-1].set(jax.vmap(TM.connec.exp)(x[1:-1], -step_size * grad_x))
        return x, jnp.linalg.norm(grad_x), i + 1

    def cond(args):
        _, g_norm, i = args
        return jnp.logical_and(g_norm > tol, i < max_iter)

    # the gradient-norm carry starts at 1 and takes the curve's dtype so loop carries stay type-consistent
    init = (x_ini, jnp.ones((), dtype=x_ini.dtype), jnp.array(0))
    x, _, _ = jax.lax.while_loop(cond, body, init)

    return x
|