morphomatics 4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. morphomatics/__init__.py +13 -0
  2. morphomatics/geom/__init__.py +16 -0
  3. morphomatics/geom/bezier_spline.py +361 -0
  4. morphomatics/geom/misc.py +104 -0
  5. morphomatics/geom/surface.py +208 -0
  6. morphomatics/graph/__init__.py +13 -0
  7. morphomatics/graph/operators.py +124 -0
  8. morphomatics/manifold/__init__.py +46 -0
  9. morphomatics/manifold/bezierfold.py +500 -0
  10. morphomatics/manifold/connection.py +105 -0
  11. morphomatics/manifold/cubic_bezierfold.py +305 -0
  12. morphomatics/manifold/differential_coords.py +197 -0
  13. morphomatics/manifold/discrete_ops.py +56 -0
  14. morphomatics/manifold/euclidean.py +213 -0
  15. morphomatics/manifold/fundamental_coords.py +440 -0
  16. morphomatics/manifold/gl_p_coords.py +149 -0
  17. morphomatics/manifold/gl_p_n.py +201 -0
  18. morphomatics/manifold/grassmann.py +174 -0
  19. morphomatics/manifold/hyperbolic_space.py +271 -0
  20. morphomatics/manifold/kendall.py +269 -0
  21. morphomatics/manifold/lie_group.py +102 -0
  22. morphomatics/manifold/manifold.py +162 -0
  23. morphomatics/manifold/manopt_wrapper.py +185 -0
  24. morphomatics/manifold/metric.py +110 -0
  25. morphomatics/manifold/point_distribution_model.py +143 -0
  26. morphomatics/manifold/power_manifold.py +413 -0
  27. morphomatics/manifold/product_manifold.py +381 -0
  28. morphomatics/manifold/se_3.py +419 -0
  29. morphomatics/manifold/shape_space.py +57 -0
  30. morphomatics/manifold/so_3.py +494 -0
  31. morphomatics/manifold/spd.py +524 -0
  32. morphomatics/manifold/sphere.py +241 -0
  33. morphomatics/manifold/tangent_bundle.py +337 -0
  34. morphomatics/manifold/util.py +126 -0
  35. morphomatics/nn/__init__.py +15 -0
  36. morphomatics/nn/flow_layers.py +219 -0
  37. morphomatics/nn/tangent_layers.py +176 -0
  38. morphomatics/nn/train.py +202 -0
  39. morphomatics/nn/wFM_layers.py +152 -0
  40. morphomatics/opt/__init__.py +14 -0
  41. morphomatics/opt/riemannian_newton_raphson.py +65 -0
  42. morphomatics/opt/riemannian_steepest_descent.py +61 -0
  43. morphomatics/stats/__init__.py +18 -0
  44. morphomatics/stats/biinvariant_statistics.py +190 -0
  45. morphomatics/stats/exponential_barycenter.py +78 -0
  46. morphomatics/stats/geometric_median.py +89 -0
  47. morphomatics/stats/principal_geodesic_analysis.py +135 -0
  48. morphomatics/stats/riemannian_regression.py +317 -0
  49. morphomatics/stats/statistical_shape_model.py +99 -0
  50. morphomatics-4.0.dist-info/LICENSE +9 -0
  51. morphomatics-4.0.dist-info/METADATA +55 -0
  52. morphomatics-4.0.dist-info/RECORD +54 -0
  53. morphomatics-4.0.dist-info/WHEEL +5 -0
  54. morphomatics-4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,419 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+
16
+ from morphomatics.manifold import Manifold, Connection, LieGroup, Metric, SO3, GLpn, Euclidean
17
+ from morphomatics.manifold.so_3 import logm as SO3_logm, expm as SO3_expm
18
+ from morphomatics.manifold.util import multiskew
19
+
20
# SE(3) is an affine submanifold of GL+(4), so its (affine) group methods are reused from there.
GLp4 = GLpn(n=4, structure='AffineGroup')
# Rotations and R^3 supply the per-factor methods used by the product structure.
SO, R3 = SO3(), Euclidean()
25
+
26
class SE3(Manifold):
    """Returns the product manifold SE(3), i.e., rigid body motions.

        manifold = SE3()

    Elements of SE(3) are represented as matrices of size 4x4, where the upper-left 3x3 block is the rotational part,
    the upper-right 3x1 part is the translational part, and the lower row is [0 0 0 1]. Tangent vectors, consequently,
    follow the same 'layout' (with a zero lower row).

    To improve efficiency, tangent vectors are always represented in the Lie Algebra.
    """

    def __init__(self, structure='AffineGroup'):
        """
        Constructor.

        :param structure: suffix selecting the initializer ``init<structure>Structure`` that is invoked on
            construction, e.g. 'AffineGroup' or 'CanonicalRiemannian'. A falsy value skips initialization.
        """
        name = 'Rigid motions'
        dimension = 6
        point_shape = (4, 4)
        super().__init__(name, dimension, point_shape)

        if structure:
            getattr(self, f'init{structure}Structure')()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        # No structure is initialized here; the flattened state is handed to tree_unflatten_instance instead.
        obj = cls(structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initAffineGroupStructure(self):
        """
        Instantiate SE(3) with standard Lie group structure and canonical Cartan-Schouten connection.
        """
        self._connec = SE3.CartanShoutenConnection(self)
        self._group = SE3.GroupStructure(self)

    def initCanonicalRiemannianStructure(self):
        """
        Instantiate SE(3) with the canonical left-invariant Riemannian metric (product of the bi-invariant metric
        on SO(3) and the Euclidean metric on R^3), the corresponding connection, and the group structure.
        """
        self._metric = SE3.CanonicalRiemannianStructure(self)
        self._connec = SE3.CanonicalRiemannianStructure(self)
        self._group = SE3.GroupStructure(self)

    def rand(self, key: jax.Array):
        """Random element: random SO(3) rotation and standard-normal translation."""
        k1, k2 = jax.random.split(key, 2)
        return jnp.zeros(self.point_shape) \
            .at[:3, :3].set(SO.rand(k1)) \
            .at[:3, 3].set(jax.random.normal(k2, (3,))) \
            .at[3, 3].set(1)

    def randvec(self, P, key: jax.Array):
        """Random tangent vector at P: random SO(3) tangent block and standard-normal translational part."""
        k1, k2 = jax.random.split(key, 2)
        return jnp.zeros(self.point_shape) \
            .at[:3, :3].set(SO.randvec(P[:3, :3], k1)) \
            .at[:3, 3].set(jax.random.normal(k2, (3,)))

    def zerovec(self):
        """Zero tangent vector (4x4 zero matrix)."""
        return jnp.zeros(self.point_shape)

    def proj(self, P, X):
        """Project the 4x4 matrix X onto the tangent space at P (rotational block via SO(3), lower row zeroed)."""
        X = X.at[:3, :3].set(SO.proj(P[:3, :3], X[:3, :3]))
        return X.at[3, :].set(0)

    def get_so3(self, P):
        """get SO3-part of P in SE3 or se3"""
        return P[:3, :3]

    def get_r3(self, P):
        """get R3-part of P in SE3 or se3"""
        return P[:3, -1].squeeze()

    def homogeneous_coords(self, M, X):
        """create SE3-element from M in SO3 and X in R^3"""
        P = jnp.zeros(self.point_shape)
        P = P.at[:3, :3].set(M)
        P = P.at[:3, -1].set(X)
        P = P.at[3, 3].set(1.)
        return P

    class GroupStructure(LieGroup):
        """Semi-direct product group structure of SE(3); most operations are delegated to GL+(4)."""

        def __init__(self, M):
            """
            Constructor.

            :param M: the owning SE3 manifold
            """
            self._M = M

        def __str__(self):
            return "Semi-direct (product) group structure"

        @property
        def identity(self):
            """Identity element of SE(3)"""
            return jnp.eye(4)

        def lefttrans(self, P, S):
            """Left-translation of P by S"""
            return GLp4.group.lefttrans(P, S)

        def righttrans(self, P, S):
            """Right translation of P by S."""
            return GLp4.group.righttrans(P, S)

        def dleft(self, P, X):
            """Derivative of the left translation by P at the identity applied to the tangent vector X."""
            return GLp4.group.dleft(P, X)

        def dright(self, P, X):
            """Derivative of the right translation by P at the identity applied to the tangent vector X."""
            return GLp4.group.dright(P, X)

        def dleft_inv(self, P, X):
            """Derivative of the left translation by P^{-1} applied to the tangent vector X."""
            return GLp4.group.dleft_inv(P, X)

        def dright_inv(self, P, X):
            """Derivative of the right translation by P^{-1} applied to the tangent vector X."""
            return GLp4.group.dright_inv(P, X)

        def inverse(self, P):
            """Inverse map of the Lie group: (R, t) -> (R^T, -R^T t)."""
            Rt = jnp.einsum('ij->ji', P[:3, :3])
            return P.at[:3, :3].set(Rt) \
                .at[:3, 3].set(jnp.einsum('ij,j', -Rt, P[:3, 3]))

        def coords(self, X):
            """Coordinate map for the tangent space at the identity.

            :param X: 4x4 Lie algebra element
            :return: (6, 1) column: sqrt(2) * (X01, X02, X12) followed by the translational part
            """
            x123 = jnp.stack((X[0, 1], X[0, 2], X[1, 2])) * 2 ** .5
            x456 = X[:3, 3]
            x = jnp.concatenate((x123, x456), axis=0)
            return x.reshape((-1, 1), order='F')

        def coords_inverse(self, X):
            """Inverse of :meth:`coords`.

            :param X: 6 coordinates, shape (6,) or (6, 1) as returned by :meth:`coords`
            :return: 4x4 Lie algebra element (lower row zero)
            """
            # accept the column vector produced by coords as well as a flat vector
            x = jnp.ravel(X)
            # Lie algebra elements have a zero lower row, hence start from zeros (not the identity).
            Y = jnp.zeros((4, 4))
            # NOTE(review): assumes SO3's coords_inverse accepts a flat length-3 vector and undoes the
            # sqrt(2) scaling applied in coords -- confirm against so_3.py
            Y = Y.at[:3, :3].set(SO.group.coords_inverse(x[:3]))
            return Y.at[:3, 3].set(x[3:])

        def bracket(self, X, Y):
            """Lie bracket in Lie algebra."""
            return GLp4.group.bracket(X, Y)

        def adjrep(self, P, X):
            """Adjoint representation of P applied to the tangent vector X at the identity."""
            return GLp4.group.adjrep(P, X)

        def exp(self, X):
            """Computes the Lie-theoretic exponential map at X"""
            return expm(X)

        def log(self, S):
            """Computes the Lie-theoretic logarithm map at S"""
            return logm(S)

    class CartanShoutenConnection(Connection):
        """
        Canonical Cartan-Schouten connection on SE(3).

        Tangent vectors are Lie algebra elements acting by left multiplication: exp(S, X) = expm(X) @ S.
        """

        def __init__(self, M):
            """
            Constructor.

            :param M: the owning SE3 manifold
            """
            self._M = M

        def __str__(self):
            return "Canonical Cartan-Shouten connection"

        def retr(self, S, X):
            """Retraction; coincides with the connection exponential."""
            return self.exp(S, X)

        def exp(self, S, X):
            """Computes connection exponential map: expm(X) @ S."""
            return jnp.einsum('...ij,...jk', expm(X), S)

        def log(self, P, S):
            """Computes the connection logarithm map: logm(S @ P^{-1}) (inverse of exp)."""
            return logm(jnp.einsum('...ij,...jk', S, GLp4.group.inverse(P)))

        def curvature_tensor(self, p, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at p on the vectors X, Y, Z. With nabla_X Y denoting
            the covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            # delegate to GL+(4); the result must be returned, not raised
            return GLp4.connec.curvature_tensor(p, X, Y, Z)

        def transp(self, P, S, X):
            """Parallel transport for SE(3)^k.
            :param P: element of SE(3)^k
            :param S: element of SE(3)^k
            :param X: tangent vector at P
            :return: parallel transport of X at S
            """
            return GLp4.connec.transp(P, S, X)

        def pairmean(self, P, S):
            """Midpoint of the geodesic from P to S."""
            return self.exp(P, 0.5 * self.log(P, S))

        def jacobiField(self, R, Q, t, X):
            """Jacobi field along the geodesic from R to Q, evaluated at t.

            :return: tuple (geodesic point at t, Jacobi field value at t)
            """
            # ### using (forward-mode) automatic differentiation of geopoint(..)
            f = lambda O: self.geopoint(O, Q, t)
            geo, J = jax.jvp(f, (R,), (self._M.group.dright(R, X),))
            J = self._M.group.dright_inv(geo, J)
            # re-symmetrize the rotational block to remove numerical drift out of so(3)
            J = J.at[:-1, :-1].set(multiskew(J[:-1, :-1]))
            return geo, J

    class CanonicalRiemannianStructure(Metric):
        """
        Standard (product) Riemannian structure with the canonical left-invariant metric that is the product of the
        canonical (bi-invariant) metric on SO(3) and the Euclidean metric on R^3. The resulting geodesics are
        products of their geodesics. For a reference, see, e.g.,

        Zefran et al., "Choice of Riemannian Metrics for Rigid Body Kinematics."
        """

        def __init__(self, M):
            """
            Constructor.

            :param M: the owning SE3 manifold
            """
            self._M = M

        def __str__(self):
            return "Canonical left invariant metric"

        @property
        def typicaldist(self):
            return jnp.pi * jnp.sqrt(3)

        def inner(self, R, X, Y):
            """Product of canonical bi-invariant metrics of SO(3) and R3 (Frobenius inner product of X and Y)."""
            return jnp.einsum('ij,ij', X, Y)

        def flat(self, R, X):
            """Lower vector X at R with the metric"""
            raise NotImplementedError('This function has not been implemented yet.')

        def sharp(self, R, dX):
            """Raise covector dX at R with the metric"""
            raise NotImplementedError('This function has not been implemented yet.')

        def egrad2rgrad(self, S, X):
            """Convert the Euclidean gradient X at S into the Riemannian gradient."""
            # translational part is already "Riemannian"; the rotational block is converted by SO(3)
            Y = X.at[:3, :3].set(SO.metric.egrad2rgrad(S[:3, :3], X[:3, :3]))
            # the lower row is not part of the tangent space (cf. SE3.proj)
            return Y.at[3, :].set(0)

        def ehess2rhess(self, p, G, H, X):
            """Converts the Euclidean gradient P_G and Hessian H of a function at
            a point p along a tangent vector X to the Riemannian Hessian
            along X on the manifold.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def retr(self, S, X):
            """Retraction; coincides with the Riemannian exponential."""
            return self.exp(S, X)

        def exp(self, S, X):
            """Computes the Riemannian exponential map (SO(3) exponential times translation by X[:3, 3])."""
            P = jnp.zeros_like(S)
            P = P.at[3, 3].set(1)
            P = P.at[:3, :3].set(SO.connec.exp(S[:3, :3], X[:3, :3]))
            return P.at[:3, 3].set(S[:3, 3] + X[:3, 3])

        def log(self, S, P):
            """Computes the Riemannian logarithm map

            Note that tangent vectors are always represented in the Lie Algebra. Thus, the Riemannian and group
            operation coincide.
            """
            X = jnp.zeros_like(S)
            X = X.at[:3, :3].set(SO.connec.log(S[:3, :3], P[:3, :3]))
            return X.at[:3, 3].set(P[:3, 3] - S[:3, 3])

        def curvature_tensor(self, S, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at S on the vectors X, Y, Z. With nabla_X Y denoting
            the covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            V = jnp.zeros_like(X)
            # translational part is flat
            return V.at[:3, :3].set(SO.connec.curvature_tensor(S[:3, :3], X[:3, :3], Y[:3, :3], Z[:3, :3]))

        def transp(self, S, P, X):
            """Parallel transport for the canonical Riemannian structure.
            :param S: element of SE(3)
            :param P: element of SE(3)
            :param X: tangent vector at S
            :return: parallel transport of X at P
            """
            # translational part has the identity as parallel transport
            return X.at[:3, :3].set(SO.connec.transp(S[:3, :3], P[:3, :3], X[:3, :3]))

        def pairmean(self, S, P):
            """Midpoint of the geodesic from S to P."""
            return self.exp(S, 0.5 * self.log(S, P))

        def dist(self, S, P):
            """product distance function"""
            return jnp.sqrt(self.squared_dist(S, P))

        def squared_dist(self, S, P):
            """product squared distance function"""
            return (SO.metric.squared_dist(S[:3, :3], P[:3, :3])
                    + R3.metric.squared_dist(S[:3, 3], P[:3, 3]))

        def projToGeodesic(self, S, P, Q, max_iter=10):
            '''
            :arg S, P: elements of SE(3) defining geodesic S->P.
            :arg Q: element of SE(3) to be projected to S->P.
            :returns: projection of Q to S->P
            '''
            # NOTE(review): the factors are projected independently, so the rotational and translational
            # parts may correspond to different geodesic parameters.
            # Start from [0 0 0 1] in the lower row so the result is a valid SE(3) element.
            Pi = jnp.zeros_like(S).at[3, 3].set(1)
            Pi = Pi.at[:3, :3].set(SO.metric.projToGeodesic(S[:3, :3], P[:3, :3], Q[:3, :3], max_iter))
            return Pi.at[:3, 3].set(R3.metric.projToGeodesic(S[:3, 3], P[:3, 3], Q[:3, 3], max_iter))

        def jacobiField(self, S, P, t, X):
            """Jacobi field along the geodesic from S to P, evaluated at t, computed factor-wise.

            :return: tuple (geodesic point at t, Jacobi field value at t), the same convention as
                     CartanShoutenConnection.jacobiField
            """
            # NOTE(review): assumes the SO(3) and R^3 connections also return (point, field) pairs
            geo_rot, J_rot = SO.connec.jacobiField(S[:3, :3], P[:3, :3], t, X[:3, :3])
            geo_trans, J_trans = R3.connec.jacobiField(S[:3, 3], P[:3, 3], t, X[:3, 3])
            geo = jnp.zeros_like(S).at[3, 3].set(1).at[:3, :3].set(geo_rot).at[:3, 3].set(geo_trans)
            J = jnp.zeros_like(X).at[:3, :3].set(J_rot).at[:3, 3].set(J_trans)
            return geo, J

        def adjJacobi(self, S, P, t, X):
            """Adjoint Jacobi field, computed factor-wise for SO(3) and R^3."""
            J = jnp.zeros_like(X)
            J = J.at[:3, :3].set(SO.metric.adjJacobi(S[:3, :3], P[:3, :3], t, X[:3, :3]))
            return J.at[:3, 3].set(R3.metric.adjJacobi(S[:3, 3], P[:3, 3], t, X[:3, 3]))
370
+
371
+
372
+
373
def logm(P):
    """
    Matrix logarithm of a rigid motion P (4x4), returned as a 4x4 Lie algebra element with zero lower row.

    The rotational block is the SO(3) logarithm omega; the translational part is V^{-1} t with
    V^{-1} = I - omega/2 + c(theta) * omega^2, c(theta) = (1 - (theta/2) cot(theta/2)) / theta^2.

    Blanco, J. L. (2010). A tutorial on SE(3) transformation parameterizations and on-manifold optimization.
    University of Malaga, Tech. Rep, 3, 6.
    """
    omega = SO3_logm(P[:3, :3])

    # squared rotation angle of the skew-symmetric omega
    angle_sq = .5 * jnp.sum(omega ** 2)
    # eps keeps the square root finite (and differentiable) at the identity
    angle = jnp.sqrt(angle_sq + jnp.finfo(jnp.float64).eps)

    def taylor(_, a2):
        # series expansion of c(theta) around zero
        return 1 / 12 + a2 / 720 + a2 ** 2 / 30240

    def closed_form(a, a2):
        # cos(a/2) / sinc(a/(2 pi)) == (a/2) * cot(a/2)
        return (1 - jnp.cos(.5 * a) / jnp.sinc(.5 * a / jnp.pi)) / a2

    coeff = jax.lax.cond(angle < 1e-6, taylor, closed_form, angle, angle_sq)
    V_inv = jnp.eye(3) - .5 * omega + coeff * (omega @ omega)

    trans = jnp.einsum('ij,j', V_inv, P[:3, 3])
    out = P.at[:3, :3].set(omega)
    out = out.at[:3, 3].set(trans)
    return out.at[3, 3].set(0)
395
+
396
+
397
def expm(X):
    """
    Group exponential of a 4x4 Lie algebra element X, returned as a 4x4 element of SE(3).

    With Omega = X[:3, :3] and t = X[:3, 3], the result is [[exp(Omega), V t], [0 0 0 1]], where
    V = I + (1 - cos theta)/theta^2 * Omega + (theta - sin theta)/theta^3 * Omega^2.

    Blanco, J. L. (2010). "A tutorial on SE(3) transformation parameterizations and on-manifold optimization."
    University of Malaga, Tech. Rep, 3, 6.
    """
    Omega = X[:3, :3]
    R = SO3_expm(Omega)

    theta2 = .5 * jnp.sum(Omega ** 2)
    # eps keeps the square root finite (and differentiable) at theta = 0
    theta = jnp.sqrt(theta2 + jnp.finfo(jnp.float64).eps)

    # (1 - cos theta)/theta^2 == 0.5 * (sin(theta/2) / (theta/2))^2. The sinc form avoids the catastrophic
    # cancellation in 1 - cos(theta) for small (but above-threshold) angles, which in single precision
    # (JAX's default) e.g. yields 0 instead of ~0.5 for theta = 1e-4.
    a = .5 * jnp.sinc(.5 * theta / jnp.pi) ** 2
    # (theta - sin theta)/theta^3; its rounding error is damped by the Omega^2 factor it multiplies
    b = jax.lax.cond(theta < 1e-6,
                     lambda _, _theta2: 1 / 6 - _theta2 / 120 + _theta2 ** 2 / 5040,
                     lambda _theta, _theta2: (_theta - jnp.sin(_theta)) / (_theta2 * _theta),
                     theta, theta2)

    V = jnp.eye(3) + a * Omega + b * (Omega @ Omega)

    return X.at[:3, :3].set(R) \
        .at[:3, 3].set(jnp.einsum('ij,j', V, X[:3, 3])) \
        .at[3, 3].set(1)
@@ -0,0 +1,57 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import abc
14
+
15
+ import jax.random
16
+
17
+ from morphomatics.manifold import Manifold
18
+
19
+
20
class ShapeSpace(Manifold):
    """Abstract base class for shape spaces.

    Subclasses translate between vertex coordinates (#n-by-3 arrays) and coordinates on the
    shape manifold via :meth:`to_coords` and :meth:`from_coords`.
    """

    @abc.abstractmethod
    def update_ref_geom(self, v):
        """
        :arg v: #n-by-3 array of vertex coordinates
        """

    @abc.abstractmethod
    def to_coords(self, v):
        """
        :arg v: #n-by-3 array of vertex coordinates
        :return: manifold coordinates
        """

    @abc.abstractmethod
    def from_coords(self, c):
        """
        :arg c: manifold coords.
        :returns: #n-by-3 array of vertex coordinates
        """

    @property
    @abc.abstractmethod
    def ref_coords(self):
        """:returns: Coordinates of reference shape"""

    @property
    def M(self) -> Manifold:
        """
        :returns: Manifold of shape coordinates. Defaults to this object; subclasses may return a different
            (e.g. more JIT-friendly) manifold with fewer dependencies.
        """
        return self

    def randvec(self, X, key: jax.Array):
        """Random tangent vector at X: connection logarithm from X towards a random point."""
        target = self.rand(key)
        return self.connec.log(X, target)