morphomatics 4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. morphomatics/__init__.py +13 -0
  2. morphomatics/geom/__init__.py +16 -0
  3. morphomatics/geom/bezier_spline.py +361 -0
  4. morphomatics/geom/misc.py +104 -0
  5. morphomatics/geom/surface.py +208 -0
  6. morphomatics/graph/__init__.py +13 -0
  7. morphomatics/graph/operators.py +124 -0
  8. morphomatics/manifold/__init__.py +46 -0
  9. morphomatics/manifold/bezierfold.py +500 -0
  10. morphomatics/manifold/connection.py +105 -0
  11. morphomatics/manifold/cubic_bezierfold.py +305 -0
  12. morphomatics/manifold/differential_coords.py +197 -0
  13. morphomatics/manifold/discrete_ops.py +56 -0
  14. morphomatics/manifold/euclidean.py +213 -0
  15. morphomatics/manifold/fundamental_coords.py +440 -0
  16. morphomatics/manifold/gl_p_coords.py +149 -0
  17. morphomatics/manifold/gl_p_n.py +201 -0
  18. morphomatics/manifold/grassmann.py +174 -0
  19. morphomatics/manifold/hyperbolic_space.py +271 -0
  20. morphomatics/manifold/kendall.py +269 -0
  21. morphomatics/manifold/lie_group.py +102 -0
  22. morphomatics/manifold/manifold.py +162 -0
  23. morphomatics/manifold/manopt_wrapper.py +185 -0
  24. morphomatics/manifold/metric.py +110 -0
  25. morphomatics/manifold/point_distribution_model.py +143 -0
  26. morphomatics/manifold/power_manifold.py +413 -0
  27. morphomatics/manifold/product_manifold.py +381 -0
  28. morphomatics/manifold/se_3.py +419 -0
  29. morphomatics/manifold/shape_space.py +57 -0
  30. morphomatics/manifold/so_3.py +494 -0
  31. morphomatics/manifold/spd.py +524 -0
  32. morphomatics/manifold/sphere.py +241 -0
  33. morphomatics/manifold/tangent_bundle.py +337 -0
  34. morphomatics/manifold/util.py +126 -0
  35. morphomatics/nn/__init__.py +15 -0
  36. morphomatics/nn/flow_layers.py +219 -0
  37. morphomatics/nn/tangent_layers.py +176 -0
  38. morphomatics/nn/train.py +202 -0
  39. morphomatics/nn/wFM_layers.py +152 -0
  40. morphomatics/opt/__init__.py +14 -0
  41. morphomatics/opt/riemannian_newton_raphson.py +65 -0
  42. morphomatics/opt/riemannian_steepest_descent.py +61 -0
  43. morphomatics/stats/__init__.py +18 -0
  44. morphomatics/stats/biinvariant_statistics.py +190 -0
  45. morphomatics/stats/exponential_barycenter.py +78 -0
  46. morphomatics/stats/geometric_median.py +89 -0
  47. morphomatics/stats/principal_geodesic_analysis.py +135 -0
  48. morphomatics/stats/riemannian_regression.py +317 -0
  49. morphomatics/stats/statistical_shape_model.py +99 -0
  50. morphomatics-4.0.dist-info/LICENSE +9 -0
  51. morphomatics-4.0.dist-info/METADATA +55 -0
  52. morphomatics-4.0.dist-info/RECORD +54 -0
  53. morphomatics-4.0.dist-info/WHEEL +5 -0
  54. morphomatics-4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,413 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+
16
+ from morphomatics.manifold import Manifold, Metric, LieGroup
17
+
18
+
19
class PowerManifold(Manifold):
    """ Product manifold M^k consisting of k copies of a single (atom) manifold M

    Points and tangent vectors are stored as arrays of shape ``(k, *M.point_shape)``; the i-th slice along the first
    axis is the i-th component.
    """

    def __init__(self, M: Manifold, k: int, metric_weights: jnp.array = None, structure: str = 'Product'):
        """
        :param M: atom manifold
        :param k: number of copies of M
        :param metric_weights: optional length-k sequence of positive weights scaling the atom metric in each
            component (None means all weights equal 1)
        :param structure: name of the structure to set up (calls ``init<structure>Structure``); a falsy value skips
            the set-up (used when unflattening a PyTree)
        """
        assert metric_weights is None or len(metric_weights) == k

        point_shape = (k, *M.point_shape)
        name = f'Product of {k} copies of ' + M.__str__() + '.'
        dimension = M.dim * k
        super().__init__(name, dimension, point_shape)
        self._atom_manifold = M
        self._k = k
        self.metric_weights = metric_weights
        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration (atom manifold and weights are children)."""
        children, aux = super().tree_flatten()
        return children + (self.atom_manifold, self.metric_weights), aux + (self.k,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *children, M, w = children
        *aux_data, k = aux_data
        obj = cls(M, k, w, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    @property
    def atom_manifold(self) -> Manifold:
        """Return the atom manifold M"""
        return self._atom_manifold

    @property
    def k(self) -> int:
        """Return the power k"""
        return self._k

    def ith_component(self, x: jnp.array, i: int) -> jnp.array:
        """Projection to the i-th element for both points and tangent vectors of M^k"""
        return x[i]

    def initProductStructure(self):
        """
        Instantiate the power manifold with product structure.

        Metric, connection and group structure are only provided if the atom manifold provides them.
        """
        structure = PowerManifold.ProductStructure(self)
        self._metric = structure if self.atom_manifold.metric is not None else None
        self._connec = structure if self.atom_manifold.connec is not None else None
        self._group = structure if self.atom_manifold.group is not None else None

    def rand(self, key: jax.Array) -> jnp.array:
        """ Random element of the power manifold

        :param key: a PRNG key
        :return: element of M^k whose components are drawn independently
        """
        subkeys = jax.random.split(key, self.k)
        return jax.vmap(self.atom_manifold.rand)(subkeys)

    def randvec(self, p: jnp.array, key: jax.Array) -> jnp.array:
        """Random vector in the tangent space of the point p

        :param p: element of M^k
        :param key: a PRNG key
        :return: random tangent vector at p
        """
        subkeys = jax.random.split(key, self.k)
        return jax.vmap(self.atom_manifold.randvec)(p, subkeys)

    def zerovec(self) -> jnp.array:
        """Zero vector in any tangent space
        """
        return jnp.zeros(self.point_shape)

    def proj(self, p, z):
        """Project ambient vector onto the power manifold

        :param p: element of M^k
        :param z: ambient vector
        :return: projection of z to the tangent space at p
        """
        return jax.vmap(self.atom_manifold.proj)(p, z)

    class ProductStructure(Metric, LieGroup):
        """ Product structure, i.e., product metric, product connection, and, if applicable, product Lie group structure
        on M^k. Every operation is applied component-wise (via ``jax.vmap``) with the atom manifold's structure.
        """
        def __init__(self, M):
            self._M: PowerManifold = M

        def __str__(self) -> str:
            return "Product structure"

        @property
        def atom_mfd(self):
            return self._M.atom_manifold

        @property
        def weights(self) -> jnp.array:
            """Metric weights as a JAX array, or None if the metric is unweighted"""
            w = self._M.metric_weights
            # normalize so that plain Python sequences work in the array arithmetic below
            return None if w is None else jnp.asarray(w)

        def _component_weights(self) -> jnp.array:
            """Weights reshaped to (k, 1, ..., 1) so that they broadcast against stacked components."""
            return self.weights.reshape((-1,) + (1,) * len(self.atom_mfd.point_shape))

        #### metric interface ####

        @property
        def typicaldist(self) -> jnp.array:
            """Typical distance in the product manifold: sqrt(sum_i w_i) * (typical distance of M), w_i = 1 if unweighted"""
            td2 = self.atom_mfd.metric.typicaldist ** 2
            total_weight = self._M.k if self.weights is None else jnp.sum(self.weights)
            return jnp.sqrt(total_weight * td2)

        def inner(self, p: jnp.array, v: jnp.array, w: jnp.array) -> jnp.array:
            """Product metric

            :param p: element of M^k
            :param v: tangent vector at p
            :param w: tangent vector at p
            :return: inner product of v and w
            """
            i = jax.vmap(self.atom_mfd.metric.inner)(p, v, w)
            if self.weights is not None:
                i = i * self.weights

            return jnp.sum(i)

        def dist(self, p: jnp.array, q: jnp.array) -> jnp.array:
            """Distance function of the product metric

            :param p: element of M^k
            :param q: element of M^k
            :return: distance between p and q
            """
            return jnp.sqrt(self.squared_dist(p, q))

        def squared_dist(self, p: jnp.array, q: jnp.array) -> jnp.array:
            """Squared distance function of the product metric

            :param p: element of M^k
            :param q: element of M^k
            :return: squared distance between p and q
            """
            d2 = jax.vmap(self.atom_mfd.metric.squared_dist)(p, q)
            if self.weights is not None:
                d2 = d2 * self.weights

            return jnp.sum(d2)

        def flat(self, p: jnp.array, v: jnp.array) -> jnp.array:
            """Lower vector v at p with the metric

            :param p: element of M^k
            :param v: tangent vector at p
            :return: covector at p
            """
            dv = jax.vmap(self.atom_mfd.metric.flat)(p, v)
            if self.weights is not None:
                dv = dv * self._component_weights()

            return dv

        def sharp(self, p: jnp.array, dv: jnp.array) -> jnp.array:
            """Raise covector dv at p with the metric (inverse of flat)

            :param p: element of M^k
            :param dv: covector at p
            :return: tangent vector at p
            """
            v = jax.vmap(self.atom_mfd.metric.sharp)(p, dv)
            if self.weights is not None:
                # flat multiplies by the weights, so sharp has to divide by them
                v = v / self._component_weights()

            return v

        def adjJacobi(self, p: jnp.array, q: jnp.array, t: float, v: jnp.array) -> jnp.array:
            """
            Evaluates an adjoint Jacobi field along the geodesic gam from p to q. v is a vector at gam(t)

            :param p: element of M^k
            :param q: element of M^k
            :param t: scalar in [0,1]
            :param v: tangent vector at gam(t)
            :return: tangent vector at p
            :raises NotImplementedError: if metric weights are set
            """
            if self.weights is None:
                return jax.vmap(self.atom_mfd.metric.adjJacobi, (0, 0, None, 0))(p, q, t, v)
            else:
                raise NotImplementedError('This function has not been implemented yet for non-trivial metric weights.')

        def egrad2rgrad(self, p: jnp.array, z: jnp.array) -> jnp.array:
            """Transform the Euclidean gradient of a function into the corresponding Riemannian gradient, i.e.,
            directions pointing away from the manifold are removed

            :param p: element of M^k
            :param z: Euclidean gradient at p
            :return: Riemannian gradient at p
            """
            g = jax.vmap(self.atom_mfd.metric.egrad2rgrad)(p, z)
            if self.weights is not None:
                # scaling the metric by w scales the gradient by 1/w
                g = g / self._component_weights()

            return g

        def ehess2rhess(self, pu: jnp.array, G: jnp.array, H: jnp.array, vw: jnp.array) -> jnp.array:
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point pu along a tangent vector vw to the Riemannian Hessian
            along vw on the manifold.

            :raises NotImplementedError: if metric weights are set
            """
            if self.weights is None:
                return jax.vmap(self.atom_mfd.metric.ehess2rhess)(pu, G, H, vw)
            else:
                raise NotImplementedError('This function has not been implemented yet for non-trivial metric weights.')

        #### connection interface ####

        # Note that the Levi-Civita connection does not change under a constant re-scaling of the metric.
        # (This can, e.g., be deduced from the Koszul formula.) Therefore, all notions that only depend on the metric
        # implicitly through the connection are not influenced by metric weights.

        def exp(self, p: jnp.array, v: jnp.array = None) -> jnp.array:
            """Riemannian exponential; with a single argument, the group exponential

            :param p: element of M^k (or, if v is omitted, tangent vector at the identity)
            :param v: tangent vector at p
            :return: point at time 1 of the geodesic that starts at p with initial velocity v
                (if v is omitted: group exponential of p)
            """
            if v is None:
                return jax.vmap(self.atom_mfd.group.exp)(p)
            return jax.vmap(self.atom_mfd.connec.exp)(p, v)

        retr = exp

        def log(self, p: jnp.array, q: jnp.array = None) -> jnp.array:
            """Riemannian logarithm; with a single argument, the group logarithm

            :param p: element of M^k
            :param q: element of M^k
            :return: vector at p with exp(p, v) = q (if q is omitted: group logarithm of p)
            """
            if q is None:
                return jax.vmap(self.atom_mfd.group.log)(p)
            return jax.vmap(self.atom_mfd.connec.log)(p, q)

        def geopoint(self, p: jnp.array, q: jnp.array, t: float) -> jnp.array:
            """Geodesic map

            :param p: element of M^k
            :param q: element of M^k
            :param t: scalar between 0 and 1
            :return: element of M^k that is reached on the geodesic between p and q at time t
            """
            return jax.vmap(self.atom_mfd.connec.geopoint, (0, 0, None))(p, q, t)

        def transp(self, p: jnp.array, q: jnp.array, v: jnp.array) -> jnp.array:
            """Parallel transport map

            :param p: element of M^k
            :param q: element of M^k
            :param v: tangent vector at p
            :return: tangent vector at q that is the parallel transport of v along the geodesic from p to q
            """
            return jax.vmap(self.atom_mfd.connec.transp)(p, q, v)

        def pairmean(self, p: jnp.array, q: jnp.array) -> jnp.array:
            """Pair-wise mean

            :param p: element of M^k
            :param q: element of M^k
            :return: mean of p and q
            """
            return jax.vmap(self.atom_mfd.connec.pairmean)(p, q)

        def curvature_tensor(self, p: jnp.array, v: jnp.array, w: jnp.array, x: jnp.array) -> jnp.array:
            """Curvature tensor

            :param p: element of M^k
            :param v: tangent vector at p
            :param w: tangent vector at p
            :param x: tangent vector at p
            :return: tangent vector at p that is the value R(v,w)x of the curvature tensor
            """
            return jax.vmap(self.atom_mfd.connec.curvature_tensor)(p, v, w, x)

        def jacobiField(self, p: jnp.array, q: jnp.array, t: float, X: jnp.array) -> jnp.array:
            """
            Evaluates a Jacobi field (with boundary conditions gam'(0) = X, gam'(1) = 0) along the geodesic gam from p to
            q.

            :param p: element of M^k
            :param q: element of M^k
            :param t: scalar in [0,1]
            :param X: tangent vector at p
            :return: [gam(t), J], where J is the value of the Jacobi field (which is an element-wise Jacobi field) at gam(t)
            """
            return jax.vmap(self.atom_mfd.connec.jacobiField, (0, 0, None, 0))(p, q, t, X)

        #### group interface ####

        def identity(self) -> jnp.array:
            """Identity element: k stacked copies of the atom group's identity"""
            e = self.atom_mfd.group.identity
            # NOTE(review): the atom group may expose its identity as a property or as a zero-argument method;
            # accept both.
            if callable(e):
                e = e()
            e = jnp.asarray(e)
            return jnp.broadcast_to(e, (self._M.k, *e.shape))

        def coords(self, v: jnp.array) -> jnp.array:
            """Coordinate map for the tangent space at the identity

            :param v: tangent vector at the identity
            :return: vector of coordinates (component-wise coordinates, concatenated)
            """
            c = jax.vmap(self.atom_mfd.group.coords)(v)
            return c.reshape(-1)

        def coords_inverse(self, X):
            """Inverse of coords

            :param X: vector of coordinates (as returned by coords)
            :return: tangent vector at the identity
            """
            return jax.vmap(self.atom_mfd.group.coords_inverse)(X.reshape(self._M.k, -1))

        def bracket(self, v: jnp.array, w: jnp.array) -> jnp.array:
            """Lie bracket in Lie algebra

            :param v: tangent vector at the identity
            :param w: tangent vector at the identity
            :return: tangent vector at the identity that is the Lie bracket of v and w
            """
            return jax.vmap(self.atom_mfd.group.bracket)(v, w)

        def lefttrans(self, g: jnp.array, f: jnp.array) -> jnp.array:
            """Left translation of g by f

            :param g: element of the Lie group M^k
            :param f: element of the Lie group M^k
            :return: left-translated element
            """
            return jax.vmap(self.atom_mfd.group.lefttrans)(g, f)

        def righttrans(self, g: jnp.array, f: jnp.array) -> jnp.array:
            """Right translation of g by f

            :param g: element of the Lie group M^k
            :param f: element of the Lie group M^k
            :return: right-translated element
            """
            return jax.vmap(self.atom_mfd.group.righttrans)(g, f)

        def inverse(self, g: jnp.array) -> jnp.array:
            """Inverse map of the Lie group

            :param g: element of the Lie group M^k
            :return: element of M^k that is inverse to g
            """
            return jax.vmap(self.atom_mfd.group.inverse)(g)

        # The group exponential and logarithm are provided by the connection's exp and log above, which fall back to
        # the atom group's exp/log when called with a single argument.

        def dleft(self, f: jnp.array, v: jnp.array) -> jnp.array:
            """Derivative of the left translation by f at the identity applied to the tangent vector v

            :param f: element of the Lie group M^k
            :param v: tangent vector at the identity
            :return: left-translated tangent vector represented at the identity
            """
            return jax.vmap(self.atom_mfd.group.dleft)(f, v)

        def dright(self, f: jnp.array, v: jnp.array) -> jnp.array:
            """Derivative of the right translation by f at e applied to the tangent vector v

            :param f: element of the Lie group M^k
            :param v: tangent vector at the identity
            :return: right-translated tangent vector represented at the identity
            """
            return jax.vmap(self.atom_mfd.group.dright)(f, v)

        def dleft_inv(self, f: jnp.array, v: jnp.array) -> jnp.array:
            """Derivative of the left translation by f^{-1} at f applied to the tangent vector v

            :param f: element of the Lie group M^k
            :param v: tangent vector at the identity
            :return: translated vector represented at the identity
            """
            return jax.vmap(self.atom_mfd.group.dleft_inv)(f, v)

        def dright_inv(self, f: jnp.array, v: jnp.array) -> jnp.array:
            """Derivative of the right translation by f^{-1} at f applied to the tangent vector v

            :param f: element of the Lie group M^k
            :param v: tangent vector at the identity
            :return: translated vector represented at the identity
            """
            return jax.vmap(self.atom_mfd.group.dright_inv)(f, v)

        def adjrep(self, g: jnp.array, v: jnp.array) -> jnp.array:
            """Adjoint representation of g applied to the tangent vector v at the identity

            :param g: element of the Lie group M^k
            :param v: tangent vector at the identity
            :return: tangent vector at the identity
            """
            return jax.vmap(self.atom_mfd.group.adjrep)(g, v)