morphomatics 4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphomatics/__init__.py +13 -0
- morphomatics/geom/__init__.py +16 -0
- morphomatics/geom/bezier_spline.py +361 -0
- morphomatics/geom/misc.py +104 -0
- morphomatics/geom/surface.py +208 -0
- morphomatics/graph/__init__.py +13 -0
- morphomatics/graph/operators.py +124 -0
- morphomatics/manifold/__init__.py +46 -0
- morphomatics/manifold/bezierfold.py +500 -0
- morphomatics/manifold/connection.py +105 -0
- morphomatics/manifold/cubic_bezierfold.py +305 -0
- morphomatics/manifold/differential_coords.py +197 -0
- morphomatics/manifold/discrete_ops.py +56 -0
- morphomatics/manifold/euclidean.py +213 -0
- morphomatics/manifold/fundamental_coords.py +440 -0
- morphomatics/manifold/gl_p_coords.py +149 -0
- morphomatics/manifold/gl_p_n.py +201 -0
- morphomatics/manifold/grassmann.py +174 -0
- morphomatics/manifold/hyperbolic_space.py +271 -0
- morphomatics/manifold/kendall.py +269 -0
- morphomatics/manifold/lie_group.py +102 -0
- morphomatics/manifold/manifold.py +162 -0
- morphomatics/manifold/manopt_wrapper.py +185 -0
- morphomatics/manifold/metric.py +110 -0
- morphomatics/manifold/point_distribution_model.py +143 -0
- morphomatics/manifold/power_manifold.py +413 -0
- morphomatics/manifold/product_manifold.py +381 -0
- morphomatics/manifold/se_3.py +419 -0
- morphomatics/manifold/shape_space.py +57 -0
- morphomatics/manifold/so_3.py +494 -0
- morphomatics/manifold/spd.py +524 -0
- morphomatics/manifold/sphere.py +241 -0
- morphomatics/manifold/tangent_bundle.py +337 -0
- morphomatics/manifold/util.py +126 -0
- morphomatics/nn/__init__.py +15 -0
- morphomatics/nn/flow_layers.py +219 -0
- morphomatics/nn/tangent_layers.py +176 -0
- morphomatics/nn/train.py +202 -0
- morphomatics/nn/wFM_layers.py +152 -0
- morphomatics/opt/__init__.py +14 -0
- morphomatics/opt/riemannian_newton_raphson.py +65 -0
- morphomatics/opt/riemannian_steepest_descent.py +61 -0
- morphomatics/stats/__init__.py +18 -0
- morphomatics/stats/biinvariant_statistics.py +190 -0
- morphomatics/stats/exponential_barycenter.py +78 -0
- morphomatics/stats/geometric_median.py +89 -0
- morphomatics/stats/principal_geodesic_analysis.py +135 -0
- morphomatics/stats/riemannian_regression.py +317 -0
- morphomatics/stats/statistical_shape_model.py +99 -0
- morphomatics-4.0.dist-info/LICENSE +9 -0
- morphomatics-4.0.dist-info/METADATA +55 -0
- morphomatics-4.0.dist-info/RECORD +54 -0
- morphomatics-4.0.dist-info/WHEEL +5 -0
- morphomatics-4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
|
|
18
|
+
from morphomatics.geom.bezier_spline import BezierSpline, full_set, indep_set
|
|
19
|
+
from morphomatics.manifold import Metric
|
|
20
|
+
from morphomatics.manifold import Manifold, TangentBundle, PowerManifold
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CubicBezierfold(Manifold):
    """Manifold of _cubic_ Bézier splines.

    A spline is represented by its velocity coordinates (see :meth:`to_coords`): an array of
    tangent-bundle elements ``[point, velocity]``, one per junction of the spline, so that
    ``point_shape == (n_junctions, 2) + M.point_shape``. By default the geometry is given by the
    generalized Sasaki structure, which treats these coordinates as a point of a power manifold
    of the tangent bundle TM.
    """

    def __init__(self, M: Manifold, n_segments: int, isscycle: bool = False, structure='GeneralizedSasaki'):
        """Manifold of cubic Bézier splines.

        :arg M: base manifold in which the curves lie
        :arg n_segments: number of segments
        :arg isscycle: boolean indicating whether the splines are closed
        :arg structure: type of geometric structure; if truthy, ``init{structure}Structure`` is called
            to set up metric and connection (pass ``None`` to skip this step)
        """

        self._M = M

        # all segments are cubic
        self._degrees = np.full(n_segments, 3)

        # K + 1 is the number of independent control points of a C^1 spline with cubic segments:
        # 2 per segment for a closed spline, plus 2 for the free ends of an open one
        if isscycle:
            name = 'Manifold of closed, cubic Bézier splines through ' + str(M)
            K = 2*n_segments - 1
        else:
            name = 'Manifold of cubic Bézier splines through ' + str(M)
            K = 2*n_segments + 1

        dimension = (K + 1) * M.dim
        # (K + 1) // 2 junctions, each stored as a [point, velocity] pair
        point_shape = ((K + 1)//2, 2) + M.point_shape
        self._K = K
        super().__init__(name, dimension, point_shape)

        self._iscycle = isscycle

        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration."""
        children, aux = super().tree_flatten()
        return children+(self.M,), aux+(self.nsegments, self.iscycle)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *children, M = children
        *aux_data, nsegments, iscycle = aux_data
        # skip structure setup here; the remaining state is handed to tree_unflatten_instance
        obj = cls(M, nsegments, iscycle, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initGeneralizedSasakiStructure(self):
        """
        Instantiate generalized Sasaki structure with discrete methods.
        """
        structure = CubicBezierfold.GeneralizedSasakiStructure(self)
        self._metric = structure
        self._connec = structure

    @property
    def M(self) -> Manifold:
        """Return the underlying manifold
        """
        return self._M

    @property
    def nsegments(self) -> int:
        """Returns the number of segments."""
        return len(self._degrees)

    @property
    def degrees(self) -> np.ndarray:
        """Return the degrees of the segments (all equal to 3)."""
        return self._degrees

    # TODO: likely not needed here anymore
    @property
    def K(self) -> int:
        """Return the generalized degree of a Bezier spline, i.e., the number of independent control points - 1
        """
        return self._K

    @property
    def iscycle(self) -> bool:
        """Return whether the Bezierfold consists of non-closed or closed splines
        """
        return self._iscycle

    def correct_type(self, B: BezierSpline) -> bool:
        """Check whether B has exactly ``nsegments`` segments, all of degree 3."""
        degrees = jnp.atleast_1d(B.degrees)
        # compare shapes first: broadcasting would otherwise accept e.g. a single cubic segment
        # for any number of segments (or fail for incompatible lengths)
        if degrees.shape != (self.nsegments,):
            return False
        return bool(jnp.all(degrees == 3))

    def rand(self, key: jax.Array) -> BezierSpline: #TODO: velocity repr.
        """Return random Bézier spline"""
        subkeys = jax.random.split(key, self.K + 1)
        return BezierSpline(self.M, full_set(self.M, jax.vmap(self.M.rand)(subkeys),
                                             self.degrees, self.iscycle))

    def randvec(self, B: BezierSpline, key: jax.Array) -> jnp.array: #TODO: velocity repr.
        """Return random vector for every independent control point"""
        pts = indep_set(B, self.iscycle)
        subkeys = jax.random.split(key, len(pts))
        return jax.vmap(self.M.randvec)(pts, subkeys)

    def zerovec(self) -> jnp.array: #TODO: velocity repr.
        """Return zero vector for every independent control point"""
        return jnp.array([self.M.zerovec() for _ in range(self.K + 1)])

    def to_coords(self, B: BezierSpline) -> jnp.array:
        """Get initial and final velocities (elements of the tangent bundle) of the segments of a C^1 Bézier spline with cubic
        segments. velocities at connections are identified and returned only once.

        :param B: Bézier spline with cubic segments through a Riemannian manifold M
        :return: array of elements (ordered along the first dimension) of the tangent bundle TM

        ATTENTION: made for splines with cubic segments only!
        """
        assert jnp.all(B.degrees == 3)

        # end point of each segment together with its (outgoing) velocity there
        def f(p):
            return jnp.array([p[-1], -B._M.connec.log(p[-1], p[-2])])

        b = jax.vmap(f)(B.control_points)

        if not self.iscycle:
            # open splines additionally need the start point and initial velocity of the first segment
            a = jnp.array([
                B.control_points[0, 0], B._M.connec.log(B.control_points[0, 0], B.control_points[0, 1])
            ])

            return jnp.concatenate((a[None, ...], b))
        else:
            return b

    def from_coords(self, Q: jnp.array) -> BezierSpline:
        """Compute the cubic-only Bézier spline corresponding to the given velocities

        :param Q: array of velocities (tangent-bundle elements [point, velocity], as returned by to_coords)
        :return: Bézier spline with cubic segments that corresponds to Q

        ATTENTION: made for velocities of splines with cubic segments only!
        """

        # control points of one segment from the velocities at its two ends
        def f(pu, qw):
            p, u = pu[0], pu[1]
            q, w = qw[0], qw[1]
            return jnp.array([p, self._M.connec.exp(p, u), self._M.connec.exp(q, -w), q])

        if self.iscycle:
            # last velocity vector is also first velocity vector
            Q = jnp.concatenate([Q[None, -1], Q])

        P = jax.vmap(f)(Q[:-1], Q[1:])

        return BezierSpline(self._M, P, iscycle=self.iscycle)

    def proj(self, pu, vw):
        raise NotImplementedError('This function has not been implemented yet.')

    ############################## Sasaki structure ##############################
    class GeneralizedSasakiStructure(Metric):
        """
        This class implements the generalization of the Sasaki metric to Bézier splines with cubic segments
        """

        def __init__(self, Bf: Manifold, Ns: int = 3):
            """
            Constructor.

            :param Bf: Bézierfold object
            :param Ns: scalar that determines the number of discretization steps used in the approximation of the
            exponential and logarithm maps in the tangent bundle (stored; not used by the methods below)
            """
            # one tangent-bundle element per junction, matching the leading axis of Bf.point_shape:
            # nsegments + 1 for open splines, nsegments for closed ones
            self._tangent_bundle_power = PowerManifold(TangentBundle(Bf.M), (Bf.K + 1) // 2)
            self._Bf = Bf
            self.Ns = Ns

        def __str__(self):
            return "Generalized Sasaki structure"

        @property
        def typicaldist(self) -> float:
            return self._tangent_bundle_power.metric.typicaldist

        def inner(self, p_B: jnp.array, v: jnp.array, w: jnp.array) -> float:
            """Generalized Sasaki metric

            :param p_B: velocities of a Bézier spline
            :param v: tangent vector in the tangent space of the velocities of B
            :param w: tangent vector in the tangent space of the velocities of B
            :return: inner product between v and w
            """

            return self._tangent_bundle_power.metric.inner(p_B, v, w)

        def flat(self, p_B, v):
            """Lower vector v at p_B with the metric"""
            return self._tangent_bundle_power.metric.flat(p_B, v)

        def sharp(self, p_B, dv):
            """Raise covector dv at p_B with the metric"""
            return self._tangent_bundle_power.metric.sharp(p_B, dv)

        def egrad2rgrad(self, p, X):
            """Convert the Euclidean gradient X at p to the Riemannian gradient by projection.

            Delegates to the Bézierfold's ``proj``, which currently raises NotImplementedError.
            """
            return self._Bf.proj(p, X)

        def ehess2rhess(self, pu, G, H, vw):
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point pu along a tangent vector vw to the Riemannian Hessian
            along vw on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def exp(self, p_B: jnp.array, v: jnp.array) -> jnp.array:
            """Exponential map

            :param p_B: velocities of a Bézier spline
            :param v: tangent vector in the tangent space of the velocities of B
            :return: velocities of the Bézier spline at time 1 on the geodesic with initial velocity v
            """
            return self._tangent_bundle_power.connec.exp(p_B, v)

        def log(self, p_A: jnp.array, p_B: jnp.array) -> jnp.array:
            """Riemannian logarithm map

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :return: tangent vector in the tangent space of the velocities of A pointing to the velocities of B
            """
            return self._tangent_bundle_power.connec.log(p_A, p_B)

        def curvature_tensor(self, p_B: jnp.array, v: jnp.array, w: jnp.array, x: jnp.array) -> jnp.array:
            """Riemannian curvature tensor at a point of the Bézierfold

            :param p_B: velocities of a Bézier spline
            :param v: tangent vector in the tangent space of the velocities of B
            :param w: tangent vector in the tangent space of the velocities of B
            :param x: tangent vector in the tangent space of the velocities of B
            :return: tangent vector in the tangent space of the velocities of B
            """
            return self._tangent_bundle_power.connec.curvature_tensor(p_B, v, w, x)

        def geopoint(self, p_A: jnp.array, p_B: jnp.array, t: float) -> jnp.array:
            """Evaluate the geodesic through the Bézierfold between A and B at time t

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :param t: scalar between 0 and 1
            :return: Bézier spline at time t on the geodesic from A to B
            """
            return self.exp(p_A, t * self.log(p_A, p_B))

        def retr(self, p_B: jnp.array, v: jnp.array) -> jnp.array:
            """Retraction; uses the exponential map."""
            return self.exp(p_B, v)

        def transp(self, p_A: jnp.array, p_B: jnp.array, v: jnp.array) -> jnp.array:
            """Parallel transport along a geodesic

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :param v: tangent vector in the tangent space of the velocities of A
            :return: tangent vector in the tangent space of the velocities of B: parallel transport of v along the
            geodesic from A to B
            """
            return self._tangent_bundle_power.connec.transp(p_A, p_B, v)

        def pairmean(self, p_A: jnp.array, p_B: jnp.array) -> jnp.array:
            """Fréchet mean of 2 splines

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :return: velocities of the mean of A and B
            """
            return self.geopoint(p_A, p_B, .5)

        def dist(self, p_A: jnp.array, p_B: jnp.array) -> float:
            """Distance function that is induced on the Bézierfold by the generalized Sasaki metric

            :param p_A: velocities of a Bézier spline A
            :param p_B: velocities of a Bézier spline B
            :return: distance between A and B
            """
            return self._tangent_bundle_power.metric.dist(p_A, p_B)

        def jacobiField(self, p_A: jnp.array, p_B: jnp.array, t: float, X: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, p_A: jnp.array, p_B: jnp.array, t: float, X: jnp.array) -> jnp.array:
            raise NotImplementedError('This function has not been implemented yet.')
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import jax
|
|
15
|
+
import jax.numpy as jnp
|
|
16
|
+
|
|
17
|
+
from scipy import sparse
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
from sksparse.cholmod import cholesky as direct_solve
|
|
21
|
+
except:
|
|
22
|
+
from scipy.sparse.linalg import factorized as direct_solve
|
|
23
|
+
|
|
24
|
+
from ..geom import Surface
|
|
25
|
+
from . import SO3, SPD
|
|
26
|
+
from . import ProductManifold, PowerManifold
|
|
27
|
+
from . import ShapeSpace, Metric
|
|
28
|
+
from .util import align
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DifferentialCoords(ShapeSpace):
    """
    Shape space based on differential coordinates.

    Each face of the reference surface carries a (rotation, stretch) pair obtained from a polar
    decomposition of the deformation gradient; coordinates live on the product of a power of SO(3)
    and a power of SPD(3).

    See:
    Christoph von Tycowicz, Felix Ambellan, Anirban Mukhopadhyay, and Stefan Zachow.
    An Efficient Riemannian Statistical Shape Model using Differential Coordinates.
    Medical Image Analysis, Volume 43, January 2018.
    """

    def __init__(self, reference: Surface, commensuration_weights=(1.0, 1.0)):
        """
        :arg reference: Reference surface (shapes will be encoded as deformations thereof)
        :arg commensuration_weights: weights (rotation, stretch) for commensuration between rotational and stretch parts
        """
        assert reference is not None
        self.ref = reference

        # rotation and stretch manifolds (one factor per face of the reference)
        self.SPD = PowerManifold(SPD(3), len(self.ref.f))
        self.SO = PowerManifold(SO3(), len(self.ref.f))
        self._M = ProductManifold([self.SO, self.SPD], jnp.asarray(commensuration_weights))

        self.update_ref_geom(self.ref.v)

        name = 'Differential Coordinates Shape Space'
        super().__init__(name, self.M.dim, self.M.point_shape, self.M.connec, self.M.metric, None)

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration."""
        return (self.M,), (self.ref.v.tolist(), self.ref.f.tolist())

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        M = children[0]
        obj = cls(Surface(*aux_data))
        obj._M = M
        obj.SO, obj.SPD = M.manifolds
        return obj

    @property
    def M(self):
        """Return the underlying product manifold (rotations x stretches)."""
        return self._M

    def update_ref_geom(self, v):
        """Replace the reference vertex positions and recompute the dependent quantities.

        :arg v: new vertex coordinates of the reference surface
        """
        self.ref.v=v

        # center of gravity
        self.CoG = self.ref.v.mean(axis=0)

        # setup Poisson system
        S = self.ref.div @ self.ref.grad
        # add soft-constraint fixing translational DoF (adds 1 to entry (0, 0))
        S += sparse.coo_matrix(([1.0], ([0], [0])), S.shape) # make pos-def
        self.poisson = direct_solve(S.tocsc())

        # set metric weights
        w = jnp.asarray(self.ref.face_areas)
        self.SO.metric_weights = self.SPD.metric_weights = w

    def to_coords(self, v):
        """
        :arg v: #v-by-3 array of vertex coordinates
        :return: differential coords.
        """

        # align
        v = align(v, self.ref.v)

        # compute gradients
        D = self.ref.grad @ v

        # D holds transpose of def. grads.
        # -> compute left polar decomposition for right stretch tensor

        # decompose...
        U, S, Vt = np.linalg.svd(D.reshape(-1, 3, 3))

        # ...rotation
        R = np.einsum('...ij,...jk', U, Vt)
        # flip the last singular direction where U @ Vt is a reflection, so that det(R) = +1
        W = np.ones_like(S)
        W[:, -1] = np.linalg.det(R)
        R = np.einsum('...ij,...j,...jk', U, W, Vt)

        # ...stretch
        S[:, -1] = 1 # no stretch (=1) in normal direction
        # for degenerate triangles
        # TODO: check which direction is normal in degenerate case
        S[S < 1e-6] = 1e-6
        # stretch tensor U diag(S) U^T
        U = np.einsum('...ij,...j,...kj', U, S, U)

        return self.M.entangle([R, U])

    def from_coords(self, c):
        """
        :arg c: differential coords.
        :returns: #v-by-3 array of vertex coordinates
        """
        # compose
        R, U = self.M.disentangle(c)
        D = jnp.einsum('...ij,...jk', U, R) # <-- from left polar decomp.

        # solve Poisson system
        rhs = self.ref.div @ D.reshape(-1, 3)
        v = self.poisson(rhs)
        # move to CoG
        v += self.CoG - v.mean(axis=0)

        return v

    @property
    def ref_coords(self):
        """Coordinates of the reference shape (identity rotation and stretch for every face)."""
        return jnp.tile(jnp.eye(3), (2*len(self.ref.f), 1)).reshape(self.point_shape)

    def rand(self, key: jax.Array):
        """Return a random point of the underlying product manifold."""
        return self.M.rand(key)

    def zerovec(self):
        """Returns the zero vector in any tangent space."""
        return self.M.zerovec()

    def projToGeodesic(self, X, Y, P, max_iter = 10):
        '''
        Project P onto geodesic from X to Y.

        See:
        Felix Ambellan, Stefan Zachow, Christoph von Tycowicz.
        Geodesic B-Score for Improved Assessment of Knee Osteoarthritis.
        Proc. Information Processing in Medical Imaging (IPMI), LNCS, 2021.

        :arg X, Y: manifold coords defining geodesic X->Y.
        :arg P: manifold coords to be projected to X->Y.
        :arg max_iter: maximum number of fixed-point iterations
        :returns: manifold coords of projection of P to X->Y
        '''

        # all tangent vectors in common space i.e. algebra
        v = self.connec.log(X, Y)
        v = v / self.metric.norm(X, v)

        # initial guess
        Pi = X

        # solver loop: move along the geodesic by the component of log(Pi, P) along v
        for _ in range(max_iter):
            w = self.connec.log(Pi, P)
            d = self.metric.inner(Pi, v, w)

            # print(f'|<v, w>|={d}')
            if abs(d) < 1e-6: break

            Pi = self.connec.exp(Pi, d * v)

        return Pi

    def proj(self, X, A):
        """orthogonal (with respect to the euclidean inner product) projection of ambient
        vector (i.e. (2,k,3,3) array) onto the tangentspace at X"""
        # disentangle coords. into rotations and stretches (same product-manifold API as to/from_coords)
        R, U = self.M.disentangle(X)
        r, u = self.M.disentangle(A)

        # project in each component
        r = self.SO.proj(R, r)
        u = self.SPD.proj(U, u)

        return self.M.entangle([r, u])
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import jax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
from morphomatics.manifold import Manifold
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def pole_ladder(M: "Manifold", p: jnp.array, q: jnp.array, v: jnp.array, n_step: int = 1) -> jnp.array:
    """Pole Ladder algorithm to approximate parallel transport along geodesics in affine manifolds

    See

    Numerical Accuracy of Ladder Schemes for Parallel Transport on Manifolds, Nicolas Guigui, Xavier Pennec
    Foundations of Computational Mathematics (2022) 22:757–790,

    for details. The method is exact in Symmetric Spaces.

    :param M: Manifold (only ``M.connec.exp`` and ``M.connec.log`` are used)
    :param p: Point in M
    :param q: Point in M
    :param v: Vector in the tangent space at p
    :param n_step: Number of steps (a positive Python int)
    :return: Vector in the tangent space at q
    :raises ValueError: if n_step < 1
    """
    if n_step < 1:
        raise ValueError(f'n_step must be a positive integer, got {n_step}')

    # scaling speeds up convergence (undone at the end by the factor n_step**2)
    v = v / n_step**2

    # midpoints of the n_step sub-geodesics of p -> q, at times (2k + 1) / (2 n_step)
    U = M.connec.log(p, q)
    t = (2 * np.arange(n_step) + 1) / (2 * n_step)
    tU = t.reshape((-1,) + (1,)*U.ndim) * U[None]
    midpoints = jax.vmap(M.connec.exp, (None, 0))(p, tU)

    def reflect(x, m):
        # geodesic symmetry of x through the midpoint m
        return M.connec.exp(m, -M.connec.log(m, x)), None

    # reflect the tip exp(p, v) successively through all midpoints
    tip, _ = jax.lax.scan(reflect, M.connec.exp(p, v), midpoints)

    # every reflection flips the direction of the transported vector
    return (-1)**n_step * n_step**2 * M.connec.log(q, tip)
|