morphomatics-4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. morphomatics/__init__.py +13 -0
  2. morphomatics/geom/__init__.py +16 -0
  3. morphomatics/geom/bezier_spline.py +361 -0
  4. morphomatics/geom/misc.py +104 -0
  5. morphomatics/geom/surface.py +208 -0
  6. morphomatics/graph/__init__.py +13 -0
  7. morphomatics/graph/operators.py +124 -0
  8. morphomatics/manifold/__init__.py +46 -0
  9. morphomatics/manifold/bezierfold.py +500 -0
  10. morphomatics/manifold/connection.py +105 -0
  11. morphomatics/manifold/cubic_bezierfold.py +305 -0
  12. morphomatics/manifold/differential_coords.py +197 -0
  13. morphomatics/manifold/discrete_ops.py +56 -0
  14. morphomatics/manifold/euclidean.py +213 -0
  15. morphomatics/manifold/fundamental_coords.py +440 -0
  16. morphomatics/manifold/gl_p_coords.py +149 -0
  17. morphomatics/manifold/gl_p_n.py +201 -0
  18. morphomatics/manifold/grassmann.py +174 -0
  19. morphomatics/manifold/hyperbolic_space.py +271 -0
  20. morphomatics/manifold/kendall.py +269 -0
  21. morphomatics/manifold/lie_group.py +102 -0
  22. morphomatics/manifold/manifold.py +162 -0
  23. morphomatics/manifold/manopt_wrapper.py +185 -0
  24. morphomatics/manifold/metric.py +110 -0
  25. morphomatics/manifold/point_distribution_model.py +143 -0
  26. morphomatics/manifold/power_manifold.py +413 -0
  27. morphomatics/manifold/product_manifold.py +381 -0
  28. morphomatics/manifold/se_3.py +419 -0
  29. morphomatics/manifold/shape_space.py +57 -0
  30. morphomatics/manifold/so_3.py +494 -0
  31. morphomatics/manifold/spd.py +524 -0
  32. morphomatics/manifold/sphere.py +241 -0
  33. morphomatics/manifold/tangent_bundle.py +337 -0
  34. morphomatics/manifold/util.py +126 -0
  35. morphomatics/nn/__init__.py +15 -0
  36. morphomatics/nn/flow_layers.py +219 -0
  37. morphomatics/nn/tangent_layers.py +176 -0
  38. morphomatics/nn/train.py +202 -0
  39. morphomatics/nn/wFM_layers.py +152 -0
  40. morphomatics/opt/__init__.py +14 -0
  41. morphomatics/opt/riemannian_newton_raphson.py +65 -0
  42. morphomatics/opt/riemannian_steepest_descent.py +61 -0
  43. morphomatics/stats/__init__.py +18 -0
  44. morphomatics/stats/biinvariant_statistics.py +190 -0
  45. morphomatics/stats/exponential_barycenter.py +78 -0
  46. morphomatics/stats/geometric_median.py +89 -0
  47. morphomatics/stats/principal_geodesic_analysis.py +135 -0
  48. morphomatics/stats/riemannian_regression.py +317 -0
  49. morphomatics/stats/statistical_shape_model.py +99 -0
  50. morphomatics-4.0.dist-info/LICENSE +9 -0
  51. morphomatics-4.0.dist-info/METADATA +55 -0
  52. morphomatics-4.0.dist-info/RECORD +54 -0
  53. morphomatics-4.0.dist-info/WHEEL +5 -0
  54. morphomatics-4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,149 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import jax.random
14
+ import numpy as np
15
+ import jax.numpy as jnp
16
+
17
+ from scipy import sparse
18
+
19
+ try:
20
+ from sksparse.cholmod import cholesky as direct_solve
21
+ except:
22
+ from scipy.sparse.linalg import factorized as direct_solve
23
+
24
+ from ..geom import Surface
25
+ from . import GLpn
26
+ from . import ShapeSpace
27
+ from .util import align
28
+
29
+
30
class GLpCoords(ShapeSpace):
    """
    Shape space based on the group of matrices with positive determinant.

    Shapes are encoded per triangle of a reference surface, as 3-by-3 matrices derived from the deformation
    gradients (see to_coords / from_coords).

    See:
    Felix Ambellan, Stefan Zachow and Christoph von Tycowicz.
    An as-invariant-as-possible GL+(3)-based statistical shape model.
    Proc. 7th MICCAI workshop on Mathematical Foundations of Computational Anatomy, pp. 219--228, 2019.
    """

    def __init__(self, reference: Surface):
        """
        :arg reference: Reference surface (shapes will be encoded as deformations thereof)
        """
        assert reference is not None
        self.ref = reference
        k = len(self.ref.f)

        self.update_ref_geom(self.ref.v)

        # NOTE(review): to_coords() yields one 3-by-3 matrix per triangle (shape (k, 3, 3)), whereas GLpn(k) models a
        #  single k-by-k matrix (point_shape (k, k), dim k**2). A product/power of k copies of GLpn(3) looks intended
        #  -- confirm against the rest of the library before relying on dim/point_shape.
        self.GLp = GLpn(k, structure='AffineGroup')

        name = 'Shape Space based on the orientation preserving component of the general linear group'
        super().__init__(name, self.GLp.dim, self.GLp.point_shape, self.GLp.connec, None, self.GLp.group)

    def tree_flatten(self):
        """PyTree flattening: no dynamic children; the reference mesh (vertices, faces) is static aux data."""
        return tuple(), (self.ref.v.tolist(), self.ref.f.tolist())

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        return cls(Surface(*aux_data))

    @property
    def M(self):
        """Underlying group manifold."""
        return self.GLp

    @property
    def n_triangles(self):
        """Number of faces of the reference surface."""
        return len(self.ref.f)

    def update_ref_geom(self, v):
        """
        Replace the reference vertex positions and rebuild the dependent quantities
        (center of gravity and the factorized Poisson system used by from_coords).

        :arg v: #v-by-3 array of vertex coordinates
        """
        self.ref.v = v

        # center of gravity
        self.CoG = self.ref.v.mean(axis=0)

        # setup Poisson system
        S = self.ref.div @ self.ref.grad
        # add soft-constraint fixing translational DoF
        S += sparse.coo_matrix(([1.0], ([0], [0])), S.shape)  # make pos-def
        self.poisson = direct_solve(S.tocsc())

    def to_coords(self, v):
        """
        :arg v: #v-by-3 array of vertex coordinates
        :return: GLp coords, i.e. an array of shape (#triangles, 3, 3).
        """

        # align
        v = align(v, self.ref.v)

        # compute gradients
        D = self.ref.grad @ v

        # D holds transpose of def. grads.
        # decompose D = U S Vt = P R (left polar decomposition) with
        # rotation R = U Vt and stretch P = U S Ut
        U, S, Vt = np.linalg.svd(D.reshape(-1, 3, 3))

        # ...rotation (flip last singular direction if needed so that det(R) = +1)
        R = np.einsum('...ij,...jk', U, Vt)
        W = np.ones_like(S)
        W[:, -1] = np.linalg.det(R)
        R = np.einsum('...ij,...j,...jk', U, W, Vt)

        # ...stretch
        S[:, -1] = 1  # no stretch (=1) in normal direction
        # for degenerate triangles
        # TODO: check which direction is normal in degenerate case
        S[S < 1e-6] = 1e-6
        P = np.einsum('...ij,...j,...kj', U, S, U)

        # Recompose in the order of the polar decomposition (D = P R) so that the result stays in the same
        # convention as D, which from_coords feeds back into the Poisson system.
        return jnp.einsum('...ij,...jl', P, R)

    def from_coords(self, D):
        """
        :arg D: GLp coords (per-triangle 3-by-3 matrices, reshaped internally to (-1, 3)).
        :returns: #v-by-3 array of vertex coordinates
        """
        # solve Poisson system
        rhs = self.ref.div @ D.reshape(-1, 3)
        v = self.poisson(rhs)
        # move to CoG
        v += self.CoG - v.mean(axis=0)

        return v

    @property
    def ref_coords(self):
        """ Identity coordinates (i.e., the reference shape).
        """
        return self.group.identity

    def rand(self, key: jax.Array):
        """Random set of coordinates (won't represent a 'nice' shape).
        """
        return self.GLp.rand(key)

    def randvec(self, A, key: jax.Array):
        """Random tangent vector at A (delegates to the underlying group)."""
        return self.GLp.randvec(A, key)

    def zerovec(self):
        """Zero tangent vector in any tangent space.
        """
        return self.GLp.zerovec()

    def proj(self, p, X):
        """Projection onto the tangent space at p; the identity since tangent spaces are unconstrained."""
        return X

    def geopoint(self, A, B, t):
        """Point at parameter t on the geodesic from A to B (w.r.t. the group's connection)."""
        return self.GLp.connec.geopoint(A, B, t)
@@ -0,0 +1,201 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+ from jax.scipy.linalg import expm, funm
16
+
17
+ from morphomatics.manifold import Manifold, LieGroup, Connection
18
+
19
+
20
class GLpn(Manifold):
    """Returns the Lie group GL^+(n), i.e., the group of n-by-n matrices each with positive determinant.

    manifold = GLpn(n)

    Elements of GL^+(n) are represented as arrays of size nxn.

    # NOTE: Tangent vectors are represented as left translations in the Lie algebra, i.e., a tangent vector X at g is
    represented as d_gL_{g^(-1)}(X)
    # NOTE(review): AffineGroupStructure.exp/log compose the Lie-algebra element on the left of the footpoint
    (exp(g, X) = expm(X) g, log(g, q) = logm(q g^(-1))), which corresponds to right-trivialized tangent vectors;
    reconcile with the note above.
    """

    def __init__(self, n=3, structure='AffineGroup'):
        """
        :arg n: matrix size
        :arg structure: name of the structure to initialize (calls init<structure>Structure); falsy to skip
        """
        self._n = n

        name = 'Orientation preserving maps of R^n'

        super().__init__(name, n**2, point_shape=(n, n))

        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """PyTree flattening: extends the base aux data with the matrix size n."""
        children, aux = super().tree_flatten()
        return children, aux+(self._n,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *aux_data, n = aux_data
        obj = cls(n, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    @property
    def n(self):
        """Matrix size n."""
        return self._n

    def rand(self, key: jax.Array):
        """Returns a random point in the Lie group. This does not
        follow a specific distribution."""
        # the matrix exponential always has positive determinant
        A = jax.random.normal(key, self.point_shape)
        return expm(A)

    def randvec(self, A, key: jax.Array):
        """Returns a random vector in the tangent space at A.
        """
        return jax.random.normal(key, self.point_shape)

    def zerovec(self):
        """Returns the zero vector in the tangent space at g."""
        return jnp.zeros((self.n, self.n))

    def proj(self, p, X):
        """Projection onto the tangent space at p; the identity as GL+(n) is open in R^{n x n}."""
        return X

    def initAffineGroupStructure(self):
        """
        Standard group structure with canonical Cartan-Schouten connection.
        """
        structure = GLpn.AffineGroupStructure(self)
        self._connec = structure
        self._group = structure

    class AffineGroupStructure(Connection, LieGroup):
        """
        Standard group structure on GL+(n) where the composition of two elements is given by matrix
        multiplication. The connection is the corresponding canonical Cartan-Schouten (CCS) connection. No Riemannian
        metric is used.
        """

        def __init__(self, M):
            """
            Constructor.

            :arg M: the GLpn instance this structure belongs to
            """
            self._M = M

        def __str__(self):
            return 'standard group structure on GL+(n) with CCS connection'

        # Group

        @property
        def identity(self):
            """Returns the identity element e of the Lie group."""
            return jnp.eye(self._M.n)

        def bracket(self, X, Y):
            """Lie bracket in Lie algebra (matrix commutator XY - YX)."""
            return jnp.einsum('ij,jl->il', X, Y) - jnp.einsum('ij,jl->il', Y, X)

        def coords(self, X):
            """Coordinate map for the tangent space at the identity."""
            return jnp.reshape(X, (self._M.n ** 2, 1))

        def coords_inverse(self, X):
            """Inverse of coords(): maps the n^2 coordinates back to an n-by-n element of the Lie algebra."""
            return jnp.reshape(X, (self._M.n, self._M.n))

        def lefttrans(self, g, f):
            """Left translation of g by f.
            """
            return jnp.einsum('ij,jl->il', f, g)

        def righttrans(self, g, f):
            """Right translation of g by f.
            """
            return jnp.einsum('ij,jl->il', g, f)

        def inverse(self, g):
            """Inverse map of the Lie group.
            """
            return jnp.linalg.inv(g)

        def exp(self, *argv):
            """Computes the Lie-theoretic and connection exponential map
            (depending on signature, i.e. whether footpoint is given as well)

            exp(X)    -> expm(X)       (group exponential)
            exp(g, X) -> expm(X) @ g   (exponential of the CCS connection)
            """
            # The number of arguments is static, so plain Python dispatch suffices; lax.cond would trace both
            # branches for nothing.
            E = expm(argv[-1])
            if len(argv) == 1:
                return E
            return jnp.einsum('ij,jk', E, argv[0])

        retr = exp

        def log(self, *argv):
            """Computes the Lie-theoretic and connection logarithm map
            (depending on signature, i.e. whether footpoint is given as well)

            log(g)    -> logm(g)             (group logarithm)
            log(g, q) -> logm(q @ g^(-1))    (inverse of exp(g, X) = expm(X) @ g)
            """
            # NOTE: as logm() is not available in jax we apply log via funm() (so far this is CPU only; not as stable as
            # logm in numpy)
            logm = lambda m: jnp.real(funm(m, jnp.log))
            if len(argv) == 1:
                return logm(argv[-1])
            # Undo the right-multiplication by the footpoint used in exp(). On GL+(n) the inverse is NOT the
            # transpose (that only holds on SO(n)), so an explicit inverse is required.
            return logm(jnp.einsum('ij,jk', argv[-1], self.inverse(argv[0])))

        def curvature_tensor(self, f, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at f on the vectors X, Y, Z. With nabla_X Y denoting
            the covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            return - 1 / 4 * self.bracket(self.bracket(X, Y), Z)

        def dleft(self, f, X):
            """Derivative of the left translation by f applied to the tangent vector X at the identity.
            """
            return jnp.einsum('ij,jl->il', f, X)

        def dright(self, f, X):
            """Derivative of the right translation by f at g applied to the tangent vector X.
            """
            return jnp.einsum('ij,jl->il', X, f)

        def dleft_inv(self, f, X):
            """Derivative of the left translation by f^{-1} at f applied to the tangent vector X.
            """
            return jnp.einsum('ij,jl->il', self.inverse(f), X)

        def dright_inv(self, f, X):
            """Derivative of the right translation by f^{-1} at f applied to the tangent vector X.
            """
            return jnp.einsum('ij,jl->il', X, self.inverse(f))

        def adjrep(self, g, X):
            """Adjoint representation of g applied to the tangent vector X at the identity.
            """
            return jnp.einsum('ij,jl,lm->im', g, X, self.inverse(g))

        def transp(self, f, g, X):
            """
            Parallel transport of the CCS connection along one-parameter subgroups; see Sec. 5.3.3 of
            X. Pennec and M. Lorenzi,
            "Beyond Riemannian geometry: The affine connection setting for transformation groups."

            """
            f_invg = self.lefttrans(g, self.inverse(f))
            # midpoint h of the geodesic from the identity to f^{-1} g
            h = self.geopoint(self.identity, f_invg, .5)

            return self.dleft_inv(f_invg, self.dleft(h, self.dright(h, X)))

        def jacobiField(self, R, Q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')
@@ -0,0 +1,174 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ from morphomatics.manifold import Manifold, Metric
14
+ from morphomatics.manifold.discrete_ops import pole_ladder
15
+
16
+ import jax
17
+ import jax.numpy as jnp
18
+ from jax.numpy.linalg import svd
19
+
20
+
21
class Grassmann(Manifold):
    """ The Grassmannian.
    Manifold of m-dimensional subspaces of n dimensional real vector space
    Elements are represented as n x m matrices (with orthonormal columns spanning the subspace).
    """
    def __init__(self, n=3, m=1, structure='Canonical'):
        """
        :arg n: dimension of the ambient vector space
        :arg m: dimension of the subspaces
        :arg structure: name of the structure to initialize (calls init<structure>Structure); falsy to skip
        :raises ValueError: if not n >= m >= 1
        """
        if n < m or m < 1:
            raise ValueError(
                f"Need n >= m >= 1. Values supplied were n = {n} and m = {m}"
            )

        name = 'Grassmann manifold Gr({n},{m})'.format(n=n, m=m)
        dimension = int(n * m - m ** 2)
        self._n = n
        self._m = m
        super().__init__(name, dimension, point_shape=(n, m))
        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """PyTree flattening: extends the base aux data with the point shape (n, m)."""
        children, aux = super().tree_flatten()
        return children, aux+(self.point_shape,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *aux_data, shape = aux_data
        obj = cls(*shape, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initCanonicalStructure(self):
        """
        Instantiate Grassmannian with canonical structure.
        """
        structure = Grassmann.CanonicalStructure(self)
        self._metric = structure
        self._connec = structure

    def rand(self, key: jax.Array):
        """Random Grassmann point: orthonormal factor of the QR decomposition of a normally distributed matrix."""
        Q, _ = jnp.linalg.qr(jax.random.normal(key, self.point_shape))
        return Q

    def randvec(self, p, key: jax.Array):
        """Random unit-norm tangent vector at p (Gaussian sample projected by (I - p p^T))."""
        U = jax.random.normal(key, p.shape)
        U = U - jnp.einsum('ij,kj,kl', p, p, U)
        U = U / jnp.linalg.norm(U)
        return U

    def zerovec(self):
        """Zero tangent vector."""
        return jnp.zeros(self.point_shape)

    @staticmethod
    def project(p):
        """Project arbitrary matrix to the manifold."""
        return jnp.linalg.qr(p)[0]

    def proj(self, p, U):
        """Project ambient tangent vector to tangent space at p, i.e., compute (I - p p^T) U."""
        return U - jnp.einsum('ij,kj,kl', p, p, U)

    class CanonicalStructure(Metric):
        """
        The Riemannian metric used is the induced metric from the embedding space R^{n x m}, i.e., tangent vectors
        are n x m matrices endowed with the usual trace (Frobenius) inner product.
        """
        def __init__(self, M):
            """
            Constructor.

            :arg M: the Grassmann instance this structure belongs to
            """
            self._M = M

        def __str__(self):
            return "Canonical structure of Grassmannian"

        @property
        def typicaldist(self):
            """Typical distance scale: sqrt(n * m)."""
            return jnp.sqrt(jnp.prod(self._M.point_shape[-2:]))

        def dist(self, p, q):
            """Geodesic distance, i.e., the 2-norm of the principal angles between span(p) and span(q)."""
            # singular values of p^T q are the cosines of the principal angles
            s = svd(jnp.einsum('ji,jk', p, q), compute_uv=False)
            # clip round-off slightly above 1, which would otherwise make arccos return NaN
            angles = jnp.arccos(jnp.clip(s, -1., 1.))
            # all principal angles contribute (not just the largest), consistent with norm(log(p, q))
            return jnp.linalg.norm(angles)

        def inner(self, p, G, H):
            # Inner product (Riemannian metric) on the tangent space
            # For the Grassmann this is the Frobenius inner product.
            return jnp.einsum('ij,ij', G, H)

        def flat(self, p, G):
            raise NotImplementedError('This function has not been implemented yet.')

        def sharp(self, p, dG):
            raise NotImplementedError('This function has not been implemented yet.')

        def egrad2rgrad(self, p, X):
            """Convert Euclidean into Riemannian gradient by projecting onto the tangent space at p."""
            return self._M.proj(p, X)

        def ehess2rhess(self, p, G, H, X):
            """Convert Euclidean into Riemannian Hessian: proj_p(H) - X (p^T G)."""
            xpG = jnp.einsum('ij,kj,kl', X, p, G)
            return self._M.proj(p, H) - xpG

        def retr(self, X, G):
            # We do not need to worry about flipping signs of columns here,
            # since only the column space is important, not the actual
            # columns. Compare this with the Stiefel manifold.

            # Compute the polar factorization of Y = X+P_G
            u, _, vt = svd(X + G, full_matrices=False)
            return u @ vt

        def norm(self, p, G):
            # Norm on the tangent space is simply the Euclidean norm.
            return jnp.linalg.norm(G)

        def transp(self, p1, p2, d):
            """Parallel transport of d from p1 to p2, approximated by pole ladder."""
            return pole_ladder(self._M, p1, p2, d)

        def exp(self, p, U):
            """Exponential map: p V cos(S) V^T + U' sin(S) V^T with U = U' S V^T (thin SVD)."""
            u, s, vt = svd(U, full_matrices=False)

            Y = jnp.einsum('ij,kj,k,kl', p, vt, jnp.cos(s), vt) + \
                jnp.einsum('ij,j,jk', u, jnp.sin(s), vt)

            # From numerical experiments, it seems necessary to
            # re-orthonormalize. This is overall quite expensive.
            Y, _ = jnp.linalg.qr(Y)
            return Y

        def log(self, p, q):
            """Logarithm map: tangent vector at p pointing to q (arctan of the singular values)."""
            qtp = q.T @ p
            At = q.T - qtp @ p.T
            Bt = jnp.linalg.solve(qtp, At)
            u, s, vt = svd(Bt.T, full_matrices=False)

            return jnp.einsum('ij,j,jk', u, jnp.arctan(s), vt)

        def curvature_tensor(self, p, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at p on the vectors X, Y, Z. With nabla_X Y denoting the
            covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def jacobiField(self, p, q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, p, q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')