morphomatics 4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. morphomatics/__init__.py +13 -0
  2. morphomatics/geom/__init__.py +16 -0
  3. morphomatics/geom/bezier_spline.py +361 -0
  4. morphomatics/geom/misc.py +104 -0
  5. morphomatics/geom/surface.py +208 -0
  6. morphomatics/graph/__init__.py +13 -0
  7. morphomatics/graph/operators.py +124 -0
  8. morphomatics/manifold/__init__.py +46 -0
  9. morphomatics/manifold/bezierfold.py +500 -0
  10. morphomatics/manifold/connection.py +105 -0
  11. morphomatics/manifold/cubic_bezierfold.py +305 -0
  12. morphomatics/manifold/differential_coords.py +197 -0
  13. morphomatics/manifold/discrete_ops.py +56 -0
  14. morphomatics/manifold/euclidean.py +213 -0
  15. morphomatics/manifold/fundamental_coords.py +440 -0
  16. morphomatics/manifold/gl_p_coords.py +149 -0
  17. morphomatics/manifold/gl_p_n.py +201 -0
  18. morphomatics/manifold/grassmann.py +174 -0
  19. morphomatics/manifold/hyperbolic_space.py +271 -0
  20. morphomatics/manifold/kendall.py +269 -0
  21. morphomatics/manifold/lie_group.py +102 -0
  22. morphomatics/manifold/manifold.py +162 -0
  23. morphomatics/manifold/manopt_wrapper.py +185 -0
  24. morphomatics/manifold/metric.py +110 -0
  25. morphomatics/manifold/point_distribution_model.py +143 -0
  26. morphomatics/manifold/power_manifold.py +413 -0
  27. morphomatics/manifold/product_manifold.py +381 -0
  28. morphomatics/manifold/se_3.py +419 -0
  29. morphomatics/manifold/shape_space.py +57 -0
  30. morphomatics/manifold/so_3.py +494 -0
  31. morphomatics/manifold/spd.py +524 -0
  32. morphomatics/manifold/sphere.py +241 -0
  33. morphomatics/manifold/tangent_bundle.py +337 -0
  34. morphomatics/manifold/util.py +126 -0
  35. morphomatics/nn/__init__.py +15 -0
  36. morphomatics/nn/flow_layers.py +219 -0
  37. morphomatics/nn/tangent_layers.py +176 -0
  38. morphomatics/nn/train.py +202 -0
  39. morphomatics/nn/wFM_layers.py +152 -0
  40. morphomatics/opt/__init__.py +14 -0
  41. morphomatics/opt/riemannian_newton_raphson.py +65 -0
  42. morphomatics/opt/riemannian_steepest_descent.py +61 -0
  43. morphomatics/stats/__init__.py +18 -0
  44. morphomatics/stats/biinvariant_statistics.py +190 -0
  45. morphomatics/stats/exponential_barycenter.py +78 -0
  46. morphomatics/stats/geometric_median.py +89 -0
  47. morphomatics/stats/principal_geodesic_analysis.py +135 -0
  48. morphomatics/stats/riemannian_regression.py +317 -0
  49. morphomatics/stats/statistical_shape_model.py +99 -0
  50. morphomatics-4.0.dist-info/LICENSE +9 -0
  51. morphomatics-4.0.dist-info/METADATA +55 -0
  52. morphomatics-4.0.dist-info/RECORD +54 -0
  53. morphomatics-4.0.dist-info/WHEEL +5 -0
  54. morphomatics-4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,213 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import numpy as np
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+
18
+ from morphomatics.manifold import Manifold, Metric, LieGroup
19
+
20
+
21
class Euclidean(Manifold):
    """The Euclidean space of arrays with a fixed shape.

    With the default ``structure='Canonical'`` a single structure object is
    installed that serves as metric, connection and (additive) Lie group.
    """

    def __init__(self, point_shape=(3,), structure='Canonical'):
        """
        :arg point_shape: shape of the arrays representing points
        :arg structure: name of the structure to set up; ``init<structure>Structure`` is called
            unless the value is falsy (tree_unflatten passes ``None`` to skip the setup)
        """
        name = 'Euclidean space of dimension ' + 'x'.join(map(str, point_shape))
        dimension = np.prod(point_shape)
        super().__init__(name, dimension, point_shape)
        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration.

        The point shape is appended to the auxiliary data of the base class.
        """
        children, aux = super().tree_flatten()
        return children, aux+(self.point_shape,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        # the last aux entry is the point shape appended by tree_flatten
        *aux_data, shape = aux_data
        obj = cls(shape, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initCanonicalStructure(self):
        """
        Instantiate Euclidean space with canonical structure.
        """
        # one object plays the role of metric, connection and group
        structure = Euclidean.CanonicalStructure(self)
        self._metric = structure
        self._connec = structure
        self._group = structure

    def rand(self, key: jax.Array):
        """Draw a point with standard normal entries."""
        return jax.random.normal(key, self.point_shape)

    def randvec(self, X, key: jax.Array):
        """Draw a tangent vector with standard normal entries (the footpoint X is not used)."""
        return jax.random.normal(key, self.point_shape)

    def zerovec(self):
        """Returns the zero vector in any tangent space."""
        return jnp.zeros(self.point_shape)

    def proj(self, x, X):
        """Projection onto the tangent space at x; returns X unchanged."""
        return X

    class CanonicalStructure(Metric, LieGroup):
        """
        Canonical structure of the Euclidean space: the metric is the standard (entry-wise) inner product and
        the group operation is addition, with the zero array as identity and negation as inverse.
        """

        def __init__(self, M):
            """
            Constructor.

            :arg M: the Euclidean space this structure belongs to
            """
            self._M = M

        def __str__(self):
            return "Canonical euclidean structure"

        @property
        def typicaldist(self):
            """Square root of the dimension of the underlying space."""
            return jnp.sqrt(self._M.dim)

        def inner(self, x, X, Y):
            """Inner product of the tangent vectors X and Y (independent of the footpoint x)."""
            return euclidean_inner(X, Y)

        def flat(self, p, X):
            """Lower the index of X; returns X unchanged."""
            return X

        def sharp(self, p, dX):
            """Raise the index of dX; returns dX unchanged."""
            return dX

        def norm(self, x, X):
            """Norm of the tangent vector X as computed by jnp.linalg.norm."""
            return jnp.linalg.norm(X)

        def egrad2rgrad(self, x, X):
            """Converts the Euclidean gradient X to the Riemannian gradient; returns X unchanged."""
            return X

        def ehess2rhess(self, x, G, H, X):
            """Converts the Euclidean gradient G and Hessian H of a function at
            a point x along a tangent vector X to the Riemannian Hessian
            along X on the manifold; returns H unchanged.
            """
            return H

        def retr(self, x, X):
            """Retraction; same as the exponential map."""
            return self.exp(x, X)

        def exp(self, *argv):
            """Computes the Lie-theoretic and Riemannian exponential map
            (depending on signature, i.e. whether footpoint is given as well)

            Note that tangent vectors are always represented in the Lie Algebra. Thus, the Riemannian and group
            operation coincide.

            :arg argv: either ``(X,)`` or ``(x, X)`` with footpoint x and tangent vector X
            :returns: X if no footpoint is given, X + x otherwise
            """
            # the number of arguments is static, so plain Python branching suffices
            X = jnp.asarray(argv[-1])
            if len(argv) == 1:
                return X
            return X + argv[0]

        def log(self, *argv):
            """Computes the Lie-theoretic and Riemannian logarithmic map
            (depending on signature, i.e. whether footpoint is given as well)

            Note that tangent vectors are always represented in the Lie Algebra. Thus, the Riemannian and group
            operation coincide.

            :arg argv: either ``(y,)`` or ``(x, y)`` with footpoint x and target point y
            :returns: y if no footpoint is given, y - x otherwise
            """
            # the number of arguments is static, so plain Python branching suffices
            y = jnp.asarray(argv[-1])
            if len(argv) == 1:
                return y
            return y - argv[0]

        def curvature_tensor(self, x, X, Y, Z):
            """Curvature tensor; always a zero array of the point shape."""
            return jnp.zeros(self._M.point_shape)

        def geopoint(self, x, y, t):
            """Point at parameter t on the straight line from x (t=0) to y (t=1)."""
            return x + t * (y - x)

        @property
        def identity(self):
            """Identity element of the group: the zero array."""
            return jnp.zeros(self._M.point_shape)

        def transp(self, x, y, X):
            """Parallel transport of X from x to y; returns X unchanged."""
            return X

        def pairmean(self, x, y):
            """Midpoint of x and y."""
            return self.geopoint(x, y, .5)

        def dist(self, x, y):
            """Distance between x and y as computed by jnp.linalg.norm."""
            return jnp.linalg.norm(y - x)

        def squared_dist(self, x, y):
            """Sum of squared entry-wise differences of x and y."""
            return jnp.sum((y-x)**2)

        def jacobiField(self, x, y, t, X):
            """Returns the pair [geopoint(x, y, t), (1-t) * X]."""
            return [self.geopoint(x, y, t), (1-t) * X]

        def adjJacobi(self, x, y, t, X):
            """Returns X scaled by 1/(1-t); undefined for t = 1."""
            # NOTE(review): jacobiField scales by (1-t); the adjoint of that map would
            # also be (1-t) * X, whereas this is its inverse -- confirm what callers expect.
            return 1/(1-t) * X

        def dleft(self, f, X):
            """Derivative of the left translation by f at e applied to the tangent vector X.
            """
            return X

        def dright(self, f, X):
            """Derivative of the right translation by f at e applied to the tangent vector X.
            """
            return self.dleft(f,X)

        def dleft_inv(self, f, X):
            """Derivative of the left translation by f^{-1} at f applied to the tangent vector X.
            """
            return self.dleft(-f, X)

        def dright_inv(self, f, X):
            """Derivative of the right translation by f^{-1} at f applied to the tangent vector X.
            """
            return self.dleft_inv(f,X)

        def lefttrans(self, g, f):
            """Left translation of g by f.

            The group operation is addition (identity is the zero array, inverse is negation),
            hence the translation is f + g.
            """
            return f + g

        def righttrans(self, g, f):
            """Right translation of g by f.

            Addition is commutative, so this agrees with the left translation.
            """
            return g + f

        def inverse(self, g):
            """Inverse map of the Lie group.
            """
            return -g

        def coords(self, X):
            """Coordinate map for the tangent space at the identity."""
            return X

        def coords_inverse(self, c):
            """Inverse of coords"""
            return self.coords(c)

        def bracket(self, X, Y):
            """Lie bracket; always the zero array."""
            return self.identity

        def adjrep(self, g, X):
            """Adjoint representation of g applied to the tangent vector X at the identity.
            """
            raise NotImplementedError('This function has not been implemented yet.')
211
+
212
def euclidean_inner(X, Y):
    """Return the sum over all entries of the entry-wise product of X and Y."""
    entrywise = X * Y
    return entrywise.sum()
@@ -0,0 +1,440 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import os
14
+
15
+ import numpy as np
16
+ import jax
17
+ import jax.numpy as jnp
18
+
19
+ from scipy import sparse
20
+
21
+ try:
22
+ from sksparse.cholmod import cholesky as direct_solve
23
+ except:
24
+ from scipy.sparse.linalg import factorized as direct_solve
25
+
26
+ from ..geom import Surface
27
+ from . import SO3, SPD
28
+ from . import PowerManifold, ProductManifold
29
+ from . import ShapeSpace
30
+
31
+
32
+ class FundamentalCoords(ShapeSpace):
33
+ """
34
+ Shape space based on fundamental coordinates.
35
+
36
+ See:
37
+ Felix Ambellan, Stefan Zachow, and Christoph von Tycowicz.
38
+ A Surface-Theoretic Approach for Statistical Shape Modeling.
39
+ Proc. Medical Image Computing and Computer Assisted Intervention (MICCAI), LNCS, 2019.
40
+ """
41
+
42
def __init__(self, reference: Surface, structure='product', metric_weights=(1.0, 1.0)):
    """
    Set up the shape space for deformations of a reference surface.

    Several settings can be overridden through environment variables
    (FCM_INIT_FACE, FCM_INIT_VERT, FCM_INTEGRATION_TOL, FCM_INTEGRATION_ITER,
    FCM_WEIGHT_ROTATION, FCM_WEIGHT_STRETCH).

    :arg reference: Reference surface (shapes will be encoded as deformations thereof)
    :arg structure: label that only enters the name of the space
    :arg metric_weights: weights (rotation, stretch) for commensuration between rotational and stretch parts
    """
    assert reference is not None
    self.ref = reference

    env = os.getenv
    # initial face for spanning tree path / id of fixed vertex
    self.init_face = int(env('FCM_INIT_FACE', 0))
    self.init_vert = int(env('FCM_INIT_VERT', 0))

    # tolerance and max. number of iterations of the local/global solver
    self.integration_tol = float(env('FCM_INTEGRATION_TOL', 1e-05))
    self.integration_iter = int(env('FCM_INTEGRATION_ITER', 3))

    # environment takes precedence over the weights passed as argument
    rotation_weight = float(env('FCM_WEIGHT_ROTATION', metric_weights[0]))
    stretch_weight = float(env('FCM_WEIGHT_STRETCH', metric_weights[1]))
    metric_weights = (rotation_weight, stretch_weight)

    self.spanning_tree_path = self.setup_spanning_tree_path()

    # relative (transition) rotations: half the number of stored entries of inner_edges
    n_transitions = self.ref.inner_edges.getnnz() // 2
    self.SO = PowerManifold(SO3(), n_transitions)
    # stretch w.r.t. tangent space: one 2x2 SPD matrix per face
    n_faces = self.ref.f.shape[0]
    self.SPD = PowerManifold(SPD(2), n_faces)
    # product of both
    self._M = ProductManifold([self.SO, self.SPD], jnp.asarray(metric_weights))

    self.update_ref_geom(self.ref.v)

    name = f'Fundamental Coordinates Shape Space ({structure})'
    product = self.M
    super().__init__(name, product.dim, product.point_shape, product.connec, product.metric, None)
73
+
74
def tree_flatten(self):
    """Specifies a flattening recipe for PyTree registration.

    The product manifold is the only child; vertices and faces of the
    reference surface are stored as nested lists in the auxiliary data.
    """
    vertices = self.ref.v.tolist()
    faces = self.ref.f.tolist()
    children = (self.M,)
    return children, (vertices, faces)
76
+
77
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Specifies an unflattening recipe for PyTree registration.

    The object is rebuilt from the reference surface stored in the auxiliary
    data; afterwards its manifolds are replaced by the ones from the children.
    """
    obj = cls(Surface(*aux_data))
    product = children[0]
    obj._M = product
    rotations, stretches = product.manifolds
    obj.SO = rotations
    obj.SPD = stretches
    return obj
85
+
86
@property
def M(self):
    """Underlying product manifold (rotations x stretches)."""
    product = self._M
    return product
89
+
90
@property
def n_triangles(self):
    """Number of triangles of the reference surface
    """
    faces = self.ref.f
    return len(faces)
95
+
96
def update_ref_geom(self, v):
    """
    Replace the vertex coordinates of the reference surface and refresh everything derived from them:
    center of gravity, factorized Poisson system, reference frame field, metric weights and reference coords.

    :arg v: new vertex coordinates of the reference surface
    """
    ref = self.ref
    ref.v = v

    # center of gravity
    self.CoG = ref.v.mean(axis=0)

    # Poisson system; the extra 1 at entry (0, 0) is a soft-constraint
    # fixing the translational DoF (makes the matrix pos-def)
    system = ref.div @ ref.grad
    system += sparse.coo_matrix(([1.0], ([0], [0])), system.shape)
    self.poisson = direct_solve(system.tocsc())

    self.ref_frame_field = self.setup_frame_field()

    # area fractions of edges and faces serve as metric weights
    edge_weights = np.divide(ref.edge_areas, np.sum(ref.edge_areas))
    face_weights = np.divide(ref.face_areas, np.sum(ref.face_areas))
    self.SO.metric_weights = edge_weights
    self.SPD.metric_weights = face_weights

    self._identity = self.to_coords(ref.v)
118
+
119
def to_coords(self, v):
    """
    Encode vertex coordinates as fundamental coordinates, i.e. transition rotations
    for the inner edges together with per-face 2x2 stretches.

    :arg v: #v-by-3 array of vertex coordinates
    :return: fundamental coords.
    """
    # compute gradients
    D = self.ref.grad @ v

    # decompose... (one 3x3 matrix per face)
    U, S, Vt = np.linalg.svd(D.reshape(-1, 3, 3))

    # D holds transpose of def. grads.
    # -> compute left polar decomposition for right stretch tensor

    # ...rotation
    R = np.einsum('...ij,...jk', U, Vt)
    # scale the last singular direction by det(U Vt) so that R = U W Vt has determinant +1
    W = np.ones_like(S)
    W[:, -1] = np.linalg.det(R)
    R = np.einsum('...ij,...j,...jk', U, W, Vt)

    # ...stretch
    S[:, -1] = 1 # no stretch (=1) in normal direction
    # for degenerate triangles: clamp singular values from below
    # TODO: check which direction is normal in degenerate case
    S[S < 1e-6] = 1e-6
    # U diag(S) U^T
    U = np.einsum('...ij,...j,...kj', U, S, U)

    # frame field on actual shape pushed over from reference shape
    frame = np.einsum('...ji,...jk', R, self.ref_frame_field)

    # setup ...transition rotations for every inner edge
    # (upper triangle only; the stored values of inner_edges are used as indices into C)
    e = sparse.triu(self.ref.inner_edges).tocoo()
    C = np.zeros((e.getnnz(), 3, 3))
    # frame[row]^T frame[col]
    C[e.data[:]] = np.einsum('...ji,...jk', frame[e.row[:]], frame[e.col[:]])

    # transform ...stretch from global (standard) coordinates to tangential Ulocal
    # frame.T * U * frame
    Ulocal = np.einsum('...ji,...jk,...kl', self.ref_frame_field, U, self.ref_frame_field)
    # drop last row and column (normal direction) -> 2x2 blocks
    Ulocal = Ulocal[:,0:-1, 0:-1]

    return self.M.entangle((C, Ulocal))
160
+
161
def from_coords(self, c):
    """
    Reconstruct vertex coordinates from fundamental coordinates.

    Per-face rotations are initialized by walking along the spanning tree path and are then refined by
    alternating a global step (Poisson solve for the vertices) with a local step (re-estimation of the
    rotations) for at most self.integration_iter iterations.

    :arg c: fundamental coords.
    :returns: #v-by-3 array of vertex coordinates
    """
    ################################################################################################################
    # initialization with spanning tree path #######################################################################
    C, Ulocal = self.M.disentangle(np.asarray(c))

    # columns of the path: edge id, source face, target face
    eIds = self.spanning_tree_path[:,0]
    fsourceId = self.spanning_tree_path[:, 1]
    ftargetId = self.spanning_tree_path[:, 2]

    # organize transition rotations along the path
    # (taken as stored if source id < target id, transposed otherwise)
    CoI = C[eIds[:]]
    CC = np.zeros_like(CoI)
    BB = (fsourceId < ftargetId)
    CC[BB] = CoI[BB]
    CC[~BB] = np.einsum("...ij->...ji", CoI[~BB])

    # one rotation per face, initialized with the identity
    R= np.repeat(np.eye(3)[np.newaxis, :, :], len(self.ref.f), axis=0)

    # walk along path and initialize rotations
    # frame[source] * CC * frame[target].T
    CC = np.einsum('...jk,...kl,...ml', self.ref_frame_field[fsourceId], CC, self.ref_frame_field[ftargetId])
    for l in range(eIds.shape[0]):
        R[ftargetId[l]] = R[fsourceId[l]] @ CC[l]

    # transform (tangential) Ulocal to global (standard) coordinates
    # (entries belonging to the normal direction stay zero)
    U = np.zeros_like(R)
    U[:, 0:-1, 0:-1] = Ulocal
    # frame * U * frame.T
    U = np.einsum('...ij,...jk,...lk', self.ref_frame_field, U, self.ref_frame_field)

    # NOTE(review): layout of self.ref.neighbors is defined in Surface -- confirm meaning of idx_*/n_* there
    idx_1, idx_2, idx_3, n_1, n_2, n_3 = self.ref.neighbors

    e = sparse.triu(self.ref.inner_edges).tocoo(); f = sparse.tril(self.ref.inner_edges).tocoo()

    # shift edge ids by one; slot 0 of CC/CCt stays a zero matrix
    # (presumably selected for face pairs without a stored entry -- TODO confirm)
    e.data += 1; f.data += 1

    CC = np.zeros((C.shape[0] + 1, 3, 3)); CCt = np.zeros((C.shape[0] + 1, 3, 3))
    CC[e.data] = C[e.data - 1]; CCt[f.data] = np.einsum("...ij->...ji", C[f.data - 1])

    e = e.tocsr(); f = f.tocsr()

    Dijk = R.copy()
    n_iter = 0
    v = np.asarray(self.ref.v.copy())
    vk = np.asarray(self.ref.v.copy())
    sqrt_tol = np.sqrt(self.integration_tol)
    while n_iter < self.integration_iter:
        ################################################################################################################
        # global step ##################################################################################################

        # setup gradient matrix and solve Poisson system
        D = np.einsum('...ij,...kj', U, R)  # <-- from left polar decomp.
        rhs = self.ref.div @ D.reshape(-1, 3)
        vk = v
        v = self.poisson(rhs)
        # keep the center of gravity of the reference surface
        v += self.CoG - v.mean(axis=0)
        # largest absolute coordinate change vs. tolerance relative to the previous iterate
        errCoord = np.amax(np.abs((v - vk)))
        errCoordTol = sqrt_tol * (1.0 + np.amax(np.abs((vk))))

        ################################################################################################################
        # local step ###################################################################################################
        # stop after the last global step or once the vertices stopped moving
        if (n_iter + 1 == self.integration_iter) or (errCoord < errCoordTol):
            break

        # compute gradients again
        D = (self.ref.grad @ v).reshape(-1, 3, 3)

        # first group overwrites Dijk, the second and third group are added on top
        Dijk[idx_1] = np.einsum('...ji,...jk,...kl,...lm,...nm', D[n_1[:, 0]], U[n_1[:, 0]], self.ref_frame_field[n_1[:, 0]], CCt[e[idx_1, n_1[:, 0]]] + CC[f[idx_1, n_1[:, 0]]], self.ref_frame_field[idx_1])
        if n_2.shape[0] > 0 :
            Dijk[idx_2] = Dijk[idx_2] + np.einsum('...ji,...jk,...kl,...lm,...nm', D[n_2[:, 1]], U[n_2[:, 1]], self.ref_frame_field[n_2[:, 1]], CCt[e[idx_2, n_2[:, 1]]] + CC[f[idx_2, n_2[:, 1]]], self.ref_frame_field[idx_2])
        if n_3.shape[0] > 0 :
            Dijk[idx_3] = Dijk[idx_3] + np.einsum('...ji,...jk,...kl,...lm,...nm', D[n_3[:, 2]], U[n_3[:, 2]], self.ref_frame_field[n_3[:, 2]], CC[f[idx_3, n_3[:, 2]]] + CCt[e[idx_3, n_3[:, 2]]], self.ref_frame_field[idx_3])

        # new rotations from the SVD of Dijk; the last singular direction is
        # scaled by det(U Vt) so that the result has determinant +1
        Uijk, Sijk, Vtijk = np.linalg.svd(Dijk)
        R = np.einsum('...ij,...jk', Uijk, Vtijk)
        Wijk = np.ones_like(Sijk)
        Wijk[:, -1] = np.linalg.det(R)
        R = np.einsum('...ij,...j,...jk', Uijk, Wijk, Vtijk)

        n_iter += 1

    # orient w.r.t. fixed frame and move to fixed node
    v[:] = (self.ref_frame_field[self.init_face] @ FundamentalCoords.frame_of_face(v, self.ref.f, [self.init_face]).T @ v[:].T).T
    v += self.ref.v[self.init_vert] - v[self.init_vert]
    # print("v:\n", v)
    return v
250
+
251
@property
def ref_coords(self):
    """Fundamental coordinates of the reference surface (as stored by update_ref_geom)."""
    identity = self._identity
    return identity
254
+
255
def rand(self, key: jax.Array):
    """Draw a random element; delegates to the underlying product manifold."""
    product = self.M
    return product.rand(key)
257
+
258
def zerovec(self):
    """Returns the zero vector in any tangent space (delegates to the underlying product manifold)."""
    product = self.M
    return product.zerovec()
261
+
262
def proj(self, X, A):
    """Orthogonal (with respect to the euclidean inner product) projection of the ambient
    vector A (vectorized (2,k,3,3) array) onto the tangent space at X; delegates to the
    underlying product manifold."""
    product = self.M
    return product.proj(X, A)
266
+
267
def projToGeodesic(self, X, Y, P, max_iter = 10):
    """
    Project P onto geodesic from X to Y.

    See:
    Felix Ambellan, Stefan Zachow, Christoph von Tycowicz.
    Geodesic B-Score for Improved Assessment of Knee Osteoarthritis.
    Proc. Information Processing in Medical Imaging (IPMI), LNCS, 2021.

    :arg X, Y: manifold coords defining geodesic X->Y.
    :arg P: manifold coords to be projected to X->Y.
    :arg max_iter: maximal number of solver iterations
    :returns: manifold coords of projection of P to X->Y
    """
    log_map = self.connec.log
    exp_map = self.connec.exp

    # unit direction of the geodesic
    # (all tangent vectors in common space i.e. algebra)
    direction = log_map(X, Y)
    direction = direction / self.metric.norm(X, direction)

    # start at X and repeatedly move along the direction by the
    # component of log(foot, P) in that direction
    foot = X.copy()
    for _ in range(max_iter):
        residual = log_map(foot, P)
        step = self.metric.inner(foot, direction, residual)

        # converged: residual has no component along the geodesic
        if abs(step) < 1e-6:
            break

        foot = exp_map(foot, step * direction)

    return foot
299
+
300
def setup_spanning_tree_path(self):
    """
    Setup a path across spanning tree of the reference surface beginning at self.init_face.

    The faces are visited breadth-first via the rows of self.ref.inner_edges; every face that is
    reached for the first time contributes one row to the path.

    :return: n x 3 - array holding column wise an edge id and the respective neighbouring faces
        (source face, newly reached face). If no other face is reachable from the initial face,
        an empty array of shape (0, 3) is returned.
    """
    from collections import deque

    adjacency = self.ref.inner_edges

    # -1 marks faces that have not been reached yet
    depth = [-1] * len(self.ref.f)
    depth[self.init_face] = 0

    # FIFO queue of faces whose neighbours still have to be inspected
    # (deque gives O(1) removal at the front)
    queue = deque([self.init_face])

    spanningTreePath = []
    while queue:
        idx = queue.popleft()
        d = depth[idx] + 1
        neighs = adjacency.getrow(idx).tocoo()

        for neigh, edge in zip(neighs.col, neighs.data):
            if depth[neigh] >= 0:
                continue
            depth[neigh] = d
            queue.append(neigh)

            spanningTreePath.append([edge, idx, neigh])

    if not spanningTreePath:
        # keep the result two-dimensional so that column slicing still works
        return np.empty((0, 3), dtype=int)
    return np.asarray(spanningTreePath)
325
+
326
def setup_frame_field(self):
    """
    Compute an orthonormal frame for every face of the reference surface.

    The first axis points along the edge from the second to the third vertex of the face, the second
    axis is the orthogonalized edge from the third to the first vertex and the third axis is their
    cross product.

    :return: n x 3 x 3 - array holding one frame for every face, column wise organized with c1, c2 tangential and c3 normal.
    """
    verts = self.ref.v
    faces = self.ref.f

    edge_a = verts[faces[:, 2]] - verts[faces[:, 1]]
    edge_b = verts[faces[:, 0]] - verts[faces[:, 2]]

    # Gram-Schmidt: remove from edge_b its component along edge_a (row-wise)
    coeff = np.divide(np.einsum('ij,ij->i', edge_b, edge_a), np.einsum('ij,ij->i', edge_a, edge_a))
    edge_b = edge_b - coeff[:, np.newaxis] * edge_a

    # normalize and calculation of normal
    c1 = edge_a / np.linalg.norm(edge_a, axis=1, keepdims=True)
    c2 = edge_b / np.linalg.norm(edge_b, axis=1, keepdims=True)
    c3 = np.cross(c1, c2, axis=1)

    # shape as n x 3 x 3 with basis vectors as cols
    return np.stack((c1, c2, c3), axis=2)
350
+
351
+ @staticmethod
352
+ def frame_of_face(v, f, fId : int):
353
+ """
354
+ :arg fId: id of face to caluclate frame for
355
+ :return: frame (colunm wise) with c1, c2 tangential and c3 normal.
356
+ """
357
+ v1 = v[f[fId, 2]] - v[f[fId, 1]]
358
+ v2 = v[f[fId, 0]] - v[f[fId, 2]]
359
+
360
+ # orthonormal basis for face plane
361
+ v2 = v2 - (np.dot(v2, v1.T) / np.dot(v1, v1.T)) * v1
362
+
363
+ # normalize and calculation of normal
364
+ v1 = v1 / np.linalg.norm(v1)
365
+ v2 = v2 / np.linalg.norm(v2)
366
+ v3 = np.cross(v1, v2)
367
+
368
+ return np.column_stack((v1.T, v2.T, v3.T))
369
+
370
def flatCoords(self, X):
    """
    Project shape X isometrically to flat configuration.

    The stretches of X are kept, whereas the transition rotations are recomputed from the normals
    of the reference frame field only.

    :param X: element of the space of fundamental coordinates
    :returns: Flattened configuration.
    """
    # only the stretch part of X is used
    _, Ulocal = self.M.disentangle(np.asarray(X))

    inner_edge = sparse.triu(self.ref.inner_edges)

    C = np.zeros((self.SO.k,3,3))

    for l in range(inner_edge.data.shape[0]):
        # transition rotations are directed from triangle with lower to triangle with higher id
        i = inner_edge.row[l]
        j = inner_edge.col[l]

        ###### calc quaternion representing rotation from nj to ni ######

        # normals are the third columns of the reference frames
        ni = self.ref_frame_field[i][:, 2]
        nj = self.ref_frame_field[j][:, 2]

        # scalar part: |ni||nj| + <ni, nj>
        lni_lnj = np.sqrt(np.dot(ni,ni)*np.dot(nj,nj))
        qw = lni_lnj + np.dot(ni,nj)

        # check for anti-parallelism of ni and nj
        if (qw < 1.0e-7 * lni_lnj):
            # cross product would vanish; use a vector orthogonal to ni as axis instead
            qw=0.0
            if(np.abs(ni[0]) > np.abs(ni[2])):
                qxyz = np.array([ -ni[1], ni[0], 0.0 ])
            else:
                qxyz = np.array([ 0.0, -ni[2], ni[1]])
        else:
            qxyz = np.cross(nj, ni)

        # normalize quaternion
        lq = np.sqrt(qw*qw + np.dot(qxyz, qxyz))
        qw = qw / lq
        qxyz = qxyz / lq

        ########## get rotation matrix from (unit) quaternion ##########

        Rninj = np.eye(3)

        # pairwise products of the quaternion components
        qwqw = qw * qw
        qxqx = qxyz[0] * qxyz[0]
        qyqy = qxyz[1] * qxyz[1]
        qzqz = qxyz[2] * qxyz[2]
        qxqy = qxyz[0] * qxyz[1]
        qzqw = qxyz[2] * qw
        qxqz = qxyz[0] * qxyz[2]
        qyqw = qxyz[1] * qw
        qyqz = qxyz[1] * qxyz[2]
        qxqw = qxyz[0] * qw

        Rninj[0, 0] = qxqx - qyqy - qzqz + qwqw
        Rninj[1, 1] = -qxqx + qyqy - qzqz + qwqw
        Rninj[2, 2] = -qxqx - qyqy + qzqz + qwqw
        Rninj[1, 0] = 2.0 * (qxqy + qzqw)
        Rninj[0, 1] = 2.0 * (qxqy - qzqw)
        Rninj[2, 0] = 2.0 * (qxqz - qyqw)
        Rninj[0, 2] = 2.0 * (qxqz + qyqw)
        Rninj[2, 1] = 2.0 * (qyqz + qxqw)
        Rninj[1, 2] = 2.0 * (qyqz - qxqw)

        #################################################################

        # update transition rotations (the stored value of inner_edges is the index into C)
        C[inner_edge.data[l]] = self.ref_frame_field[i].T @ Rninj @ self.ref_frame_field[j]

    # NOTE(review): to_coords assembles its result via self.M.entangle, here plain
    # concatenation is used -- presumably equivalent; confirm against ProductManifold
    return np.concatenate([np.ravel(C), np.ravel(Ulocal)]).reshape(-1)