morphomatics 4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. morphomatics-4.0/LICENSE +9 -0
  2. morphomatics-4.0/PKG-INFO +55 -0
  3. morphomatics-4.0/README.md +28 -0
  4. morphomatics-4.0/morphomatics/__init__.py +13 -0
  5. morphomatics-4.0/morphomatics/geom/__init__.py +16 -0
  6. morphomatics-4.0/morphomatics/geom/bezier_spline.py +361 -0
  7. morphomatics-4.0/morphomatics/geom/misc.py +104 -0
  8. morphomatics-4.0/morphomatics/geom/surface.py +208 -0
  9. morphomatics-4.0/morphomatics/graph/__init__.py +13 -0
  10. morphomatics-4.0/morphomatics/graph/operators.py +124 -0
  11. morphomatics-4.0/morphomatics/manifold/__init__.py +46 -0
  12. morphomatics-4.0/morphomatics/manifold/bezierfold.py +500 -0
  13. morphomatics-4.0/morphomatics/manifold/connection.py +105 -0
  14. morphomatics-4.0/morphomatics/manifold/cubic_bezierfold.py +305 -0
  15. morphomatics-4.0/morphomatics/manifold/differential_coords.py +197 -0
  16. morphomatics-4.0/morphomatics/manifold/discrete_ops.py +56 -0
  17. morphomatics-4.0/morphomatics/manifold/euclidean.py +213 -0
  18. morphomatics-4.0/morphomatics/manifold/fundamental_coords.py +440 -0
  19. morphomatics-4.0/morphomatics/manifold/gl_p_coords.py +149 -0
  20. morphomatics-4.0/morphomatics/manifold/gl_p_n.py +201 -0
  21. morphomatics-4.0/morphomatics/manifold/grassmann.py +174 -0
  22. morphomatics-4.0/morphomatics/manifold/hyperbolic_space.py +271 -0
  23. morphomatics-4.0/morphomatics/manifold/kendall.py +269 -0
  24. morphomatics-4.0/morphomatics/manifold/lie_group.py +102 -0
  25. morphomatics-4.0/morphomatics/manifold/manifold.py +162 -0
  26. morphomatics-4.0/morphomatics/manifold/manopt_wrapper.py +185 -0
  27. morphomatics-4.0/morphomatics/manifold/metric.py +110 -0
  28. morphomatics-4.0/morphomatics/manifold/point_distribution_model.py +143 -0
  29. morphomatics-4.0/morphomatics/manifold/power_manifold.py +413 -0
  30. morphomatics-4.0/morphomatics/manifold/product_manifold.py +381 -0
  31. morphomatics-4.0/morphomatics/manifold/se_3.py +419 -0
  32. morphomatics-4.0/morphomatics/manifold/shape_space.py +57 -0
  33. morphomatics-4.0/morphomatics/manifold/so_3.py +494 -0
  34. morphomatics-4.0/morphomatics/manifold/spd.py +524 -0
  35. morphomatics-4.0/morphomatics/manifold/sphere.py +241 -0
  36. morphomatics-4.0/morphomatics/manifold/tangent_bundle.py +337 -0
  37. morphomatics-4.0/morphomatics/manifold/util.py +126 -0
  38. morphomatics-4.0/morphomatics/nn/__init__.py +15 -0
  39. morphomatics-4.0/morphomatics/nn/flow_layers.py +219 -0
  40. morphomatics-4.0/morphomatics/nn/tangent_layers.py +176 -0
  41. morphomatics-4.0/morphomatics/nn/train.py +202 -0
  42. morphomatics-4.0/morphomatics/nn/wFM_layers.py +152 -0
  43. morphomatics-4.0/morphomatics/opt/__init__.py +14 -0
  44. morphomatics-4.0/morphomatics/opt/riemannian_newton_raphson.py +65 -0
  45. morphomatics-4.0/morphomatics/opt/riemannian_steepest_descent.py +61 -0
  46. morphomatics-4.0/morphomatics/stats/__init__.py +18 -0
  47. morphomatics-4.0/morphomatics/stats/biinvariant_statistics.py +190 -0
  48. morphomatics-4.0/morphomatics/stats/exponential_barycenter.py +78 -0
  49. morphomatics-4.0/morphomatics/stats/geometric_median.py +89 -0
  50. morphomatics-4.0/morphomatics/stats/principal_geodesic_analysis.py +135 -0
  51. morphomatics-4.0/morphomatics/stats/riemannian_regression.py +317 -0
  52. morphomatics-4.0/morphomatics/stats/statistical_shape_model.py +99 -0
  53. morphomatics-4.0/morphomatics.egg-info/PKG-INFO +55 -0
  54. morphomatics-4.0/morphomatics.egg-info/SOURCES.txt +57 -0
  55. morphomatics-4.0/morphomatics.egg-info/dependency_links.txt +1 -0
  56. morphomatics-4.0/morphomatics.egg-info/requires.txt +8 -0
  57. morphomatics-4.0/morphomatics.egg-info/top_level.txt +1 -0
  58. morphomatics-4.0/setup.cfg +4 -0
  59. morphomatics-4.0/setup.py +57 -0
@@ -0,0 +1,9 @@
1
+ MIT License
2
+
3
+ Copyright (C) 2024 Zuse Institute Berlin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@@ -0,0 +1,55 @@
1
+ Metadata-Version: 2.1
2
+ Name: morphomatics
3
+ Version: 4.0
4
+ Summary: Geometric morphometrics in non-Euclidean shape spaces
5
+ Home-page: https://morphomatics.github.io/
6
+ Author: Christoph von Tycowicz et al.
7
+ Author-email: vontycowicz@zib.de
8
+ License: MIT License
9
+ Keywords: Shape Analysis,Morphometrics,Geometric Statistics
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Topic :: Software Development :: Build Tools
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3.9
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: jax>=0.4.25
21
+ Requires-Dist: jaxlib>=0.4.25
22
+ Requires-Dist: jraph
23
+ Requires-Dist: flax
24
+ Requires-Dist: optax
25
+ Provides-Extra: all
26
+ Requires-Dist: pymanopt>=2.0.1; extra == "all"
27
+
28
+ <div align="center">
29
+ <img src="https://github.com/morphomatics/morphomatics.github.io/blob/master/images/logo_cyan.png?raw=true" width="250" alt="Morphomatics"/>
30
+ </div>
31
+
32
+ # Morphomatics: Geometric morphometrics in non-Euclidean shape spaces
33
+
34
+ Morphomatics is an open-source Python library for (statistical) shape analysis developed within the [geometric data analysis and processing](https://www.zib.de/visual/geometric-data-analysis-and-processing) research group at Zuse Institute Berlin.
35
+ It contains prototype implementations of intrinsic manifold-based methods that are highly consistent and avoid the influence of unwanted effects such as bias due to arbitrary choices of coordinates.
36
+
37
+ Detailed information and tutorials can be found at https://morphomatics.github.io/
38
+
39
+ ## Installation
40
+
41
+ Morphomatics can be installed directly from GitHub using the following command:
42
+ ```
43
+ pip install git+https://github.com/morphomatics/morphomatics.git#egg=morphomatics
44
+ ```
45
+ For instructions on how to set up `jaxlib`, please refer to the [JAX install guide](https://github.com/google/jax#installation).
46
+
47
+ ## Dependencies
48
+ * jax/jaxlib
49
+ * jraph
50
+ * flax
51
+ * optax
52
+
53
+ Optional
54
+ * pymanopt
55
+ * sksparse
@@ -0,0 +1,28 @@
1
+ <div align="center">
2
+ <img src="https://github.com/morphomatics/morphomatics.github.io/blob/master/images/logo_cyan.png?raw=true" width="250" alt="Morphomatics"/>
3
+ </div>
4
+
5
+ # Morphomatics: Geometric morphometrics in non-Euclidean shape spaces
6
+
7
+ Morphomatics is an open-source Python library for (statistical) shape analysis developed within the [geometric data analysis and processing](https://www.zib.de/visual/geometric-data-analysis-and-processing) research group at Zuse Institute Berlin.
8
+ It contains prototype implementations of intrinsic manifold-based methods that are highly consistent and avoid the influence of unwanted effects such as bias due to arbitrary choices of coordinates.
9
+
10
+ Detailed information and tutorials can be found at https://morphomatics.github.io/
11
+
12
+ ## Installation
13
+
14
+ Morphomatics can be installed directly from GitHub using the following command:
15
+ ```
16
+ pip install git+https://github.com/morphomatics/morphomatics.git#egg=morphomatics
17
+ ```
18
+ For instructions on how to set up `jaxlib`, please refer to the [JAX install guide](https://github.com/google/jax#installation).
19
+
20
+ ## Dependencies
21
+ * jax/jaxlib
22
+ * jraph
23
+ * flax
24
+ * optax
25
+
26
+ Optional
27
+ * pymanopt
28
+ * sksparse
@@ -0,0 +1,13 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
# version string of the package; must match the released distribution version (PKG-INFO: Version 4.0)
__version__ = '4.0'
@@ -0,0 +1,16 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ from .misc import memoize, gradient_matrix_ambient
14
+
15
+ from .surface import Surface
16
+ from .bezier_spline import BezierSpline
@@ -0,0 +1,361 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ # postponed evaluation of annotations to circumvent cyclic dependencies (will be default behavior in Python 4.0)
14
+ from __future__ import annotations
15
+ # from morphomatics.manifold import Manifold
16
+
17
+ import numpy as np
18
+ import jax
19
+ import jax.numpy as jnp
20
+ import jax.lax as lax
21
+
22
+
23
+ from typing import Tuple, List
24
+
25
+
26
class BezierSpline:
    """Manifold-valued spline that consists of Bézier curves."""

    def __init__(self, M: Manifold, control_points: jnp.array, iscycle: bool = False):
        """
        :arg M: manifold in which the curve lies
        :arg control_points: array of control points of the Bézier spline; the L >= 1 segments must be sorted along
            the first axis and all segments must have the same degree k, i.e., the input must be an
            [L, k+1, M.point_shape] array
        :arg iscycle: boolean indicating whether the spline is a closed curve
        """
        assert M is not None

        self._M = M

        self.control_points = jnp.asarray(control_points)

        self.iscycle = iscycle

    def __str__(self) -> str:
        return 'Bézier spline through ' + str(self._M)

    @property
    def nsegments(self) -> int:
        """Returns the number of segments."""
        return len(self.control_points)

    @property
    def degrees(self) -> np.ndarray:
        """Returns the degrees of the spline segments (number of control points per segment minus one)."""
        return np.array([np.shape(seg)[0] - 1 for seg in self.control_points], dtype=int)

    def length(self) -> float:
        """Length of the spline. Not implemented yet; currently returns None."""
        # TODO
        return

    def energy(self) -> float:
        """Energy of the spline. Not implemented yet; currently returns None."""
        # TODO
        return

    def tangent(self, t: float) -> Tuple[jnp.array, jnp.array]:
        """
        Compute the point and the tangent vector of the spline at parameter t.

        The velocity is obtained by forward-mode automatic differentiation of the generalized de Casteljau algorithm
        w.r.t. the local parameter of the segment containing t. The local parameter differs from t only by an integer
        offset, so the derivative w.r.t. t is the same. At a connecting point t in {1, ..., nsegments} the velocity
        at the end of the preceding segment is returned.

        :param t: scalar in [0, nsegments]
        :return: B(t), B'(t)
        """
        # get segment and local parameter
        ind, s = segmentize(t)
        P = self.control_points[ind]

        # differentiate the curve point w.r.t. the (floating point) local parameter
        s = jnp.asarray(s, dtype=float)
        f = lambda r: decasteljau(self._M, P, r)[0]
        Bt, dBt = jax.jvp(f, (s,), (jnp.ones_like(s),))
        return Bt, self._M.proj(Bt, dBt)

    def isC1(self, eps: float = 1e-5) -> bool:
        """
        Check whether the spline is (approximately) continuously differentiable.

        Every control point that connects two segments must be (up to eps) the geodesic midpoint of its two
        neighbouring control points. For closed splines (iscycle=True), the junction between the last and the first
        segment is checked as well.
        """
        cp = self.control_points

        # trivial case: only one segment -> infinitely often differentiable
        if len(cp) == 1:
            return True

        # pairs (previous segment, next segment) that share a connecting control point
        junctions = list(zip(cp[:-1], cp[1:]))
        if self.iscycle:
            junctions.append((cp[-1], cp[0]))

        for prev, seg in junctions:
            # midpoint between the second-to-last point of the previous and the second point of the next segment
            p = self._M.connec.geopoint(prev[-2], seg[1], 1/2)
            # if midpoint and connecting control point are further apart than epsilon return False
            if self._M.metric.dist(p, seg[0]) > eps:
                return False

        return True

    def geoshaped(self, eps: float = 1e-7) -> bool:
        """
        Return whether the spline is a reparametrized geodesic.

        This tests whether all tangent vectors from the first control point to the other control points are parallel
        (within a tolerance of eps).
        """
        cp = self.control_points

        # trivial case
        if len(cp) == 1 and len(cp[0]) == 2:
            return True

        M = self._M
        c = cp[0][0]
        v0 = M.connec.log(c, cp[0][1])
        # remaining control points: rest of the first segment, then all other segments
        # (jax arrays are immutable, so they are gathered rather than written back into cp)
        for seg in (cp[0][2:], *cp[1:]):
            for cc in seg:
                # ignore almost equal points---the test is unstable for them and their influence on non-geodesicity
                # is negligible
                if M.metric.dist(c, cc) > 1e-7:
                    v = M.connec.log(c, cc)
                    par = M.metric.inner(c, v0, v) / (M.metric.norm(c, v0) * M.metric.norm(c, v))

                    if -1 + eps < par < 1 - eps:
                        # v and v0 are not parallel
                        return False
        # all vectors were (almost) parallel
        return True

    def eval(self, t: float) -> jnp.array:
        """Evaluates the Bézier spline at time t in [0, nsegments]."""

        # choose correct control points
        ind, t = segmentize(t)
        P = self.control_points[ind]

        return decasteljau(self._M, P, t)[0]

    def DpB(self, t: float, X: jnp.array) -> jnp.array:
        """Compute derivative of Bézier curve B(t) w.r.t. its control points applied to vector X, i.e.
        the generalized Jacobi field J(t).
        :param t: time in [0, nSegments]
        :param X: tangent vectors for each control point
        :return: B(t), J(t)
        """
        # choose correct control points
        ind, t = segmentize(t)
        P = self.control_points[ind]

        # (forward-mode) automatic differentiation of decasteljau(..)
        f = lambda a: decasteljau(self._M, a, t)[0]
        Bt, Jt = jax.jvp(f, (P,), (X[ind],))
        return Bt, self._M.proj(Bt, Jt)

    def adjDpB(self, t: float, X: jnp.array) -> jnp.array:
        """Compute the value of the adjoint derivative of a Bézier curve B with respect to its control points applied
        to the vector X.
        :param t: scalar in [0, nSegments]
        :param X: tangent vector at B(t)
        :return: vectors at the control points (same shape as control_points; only the segment containing t is
            non-zero)
        """

        M = self._M
        siz = list(X.shape)
        # insert 1 in front
        siz.insert(0, 1)

        # t indicates which element of P to choose
        ind, t = segmentize(t)
        P = self.control_points[ind]

        # number of control points of corresponding segment
        k = len(P)

        _, B = decasteljau(M, P, t)
        # want to go backwards from B(t) to control points
        B.reverse()

        # initialize list for intermediate vectors (one array per level of the de Casteljau tree)
        D = []
        s = siz.copy()
        for i in range(1, len(B) + 1):
            s[0] = i + 1
            D.append(jnp.zeros(s))

        # transport X backwards along the "tree of geodesics" defined by the generalized de Casteljau algorithm.
        # We iterate over the depth of the tree and add vectors from the same tangent space.
        for i in range(k - 1):
            if i == 0:
                D_old = jnp.zeros(siz)
                D_old = D_old.at[0].set(X)
            else:
                D_old = D[i - 1]

            siz = np.array(D_old.shape)
            siz[0] *= 2
            D_tilde = jnp.zeros(siz)
            for jj in range(siz[0] // 2):
                # transport to starting point of the geodesic
                D_tilde = D_tilde.at[2 * jj].set(M.connec.adjDxgeo(B[i][jj], B[i][jj + 1], t, D_old[jj]))
                # and to the endpoint
                D_tilde = D_tilde.at[2 * jj + 1].set(M.connec.adjDygeo(B[i][jj], B[i][jj + 1], t, D_old[jj]))

            D[i] = D[i].at[0].set(D_tilde[0])
            D[i] = D[i].at[-1].set(D_tilde[-1])

            # inner nodes receive contributions from both adjacent geodesics
            for jj in range(1, D[i].shape[0] - 1):
                D[i] = D[i].at[jj].set(D_tilde[2 * jj - 1] + D_tilde[2 * jj])

        grad = jnp.zeros_like(self.control_points)

        # update the entries corresponding to the ind-th segment
        grad = grad.at[ind].set(D[-1])

        return grad
285
+
286
+
287
def segmentize(t: float) -> Tuple[int, float]:
    """Choose the correct segment and local parameter for the spline parameter t.

    :param t: scalar in [0, nsegments]
    :return: index of the segment (first axis of the spline's control points) containing t and the local parameter
        in [0, 1]; an integer t > 0 is mapped to the end (local parameter 1) of segment t-1
    """
    # Work in floating point so that all branches of lax.cond produce outputs of identical types;
    # an integer t would otherwise yield int vs. float local parameters in different branches.
    t = jnp.asarray(t, dtype=float)

    def startpoint(t):
        return jnp.zeros((), dtype=int), t

    def connecting_point(t):
        return t.astype(int) - 1, jnp.ones_like(t)

    def inner_point(t):
        fl = jnp.floor(t)
        return fl.astype(int), t - fl

    return lax.cond(t == 0, startpoint,
                    lambda s: lax.cond(s == jnp.round(s), connecting_point, inner_point, s), t)
303
+
304
+
305
def decasteljau(M: Manifold, P: jnp.array, t: float) -> Tuple[jnp.array, List[jnp.array]]:
    """Generalized de Casteljau algorithm.

    All points of the algorithm are kept in one flat array (a linearized tree). Level 0 holds the k control points;
    level l holds k-l points, each obtained via M.connec.geopoint(., ., t) from two neighbouring points of level l-1.

    :param M: manifold
    :param P: control points of curve beta (k of them along the first axis)
    :param t: scalar in [0,1]
    :return: beta(t) and the list of levels 0, ..., k-2 of intermediate points (level l has k-l points)
    """
    n_pts = len(P)

    # start index of level lvl = 0, ..., n_pts-2 in the flat array:
    # n_pts + (n_pts-1) + ... + (n_pts-lvl+1) = lvl*n_pts - lvl*(lvl-1)/2
    starts = [lvl * n_pts - (lvl * (lvl - 1)) // 2 for lvl in range(n_pts - 1)]

    # flat array with room for all n_pts*(n_pts+1)/2 points; entries after the first n_pts are placeholders of the
    # right shape that get overwritten below
    tree = jnp.concatenate([jnp.asarray(P)[j:] for j in range(n_pts)])

    # left parent of every lower-level point (the right parent is the next index), listed level by level;
    # children are written consecutively right after the control points
    parents = np.concatenate([start + np.arange(n_pts - 1 - lvl) for lvl, start in enumerate(starts)])
    children = n_pts + np.arange(len(parents))

    def step(tree, pair):
        left, child = pair[0], pair[1]
        return tree.at[child].set(M.connec.geopoint(tree[left], tree[left + 1], t)), None

    tree, _ = lax.scan(step, tree, np.c_[parents, children])

    levels = [tree[start:start + n_pts - lvl] for lvl, start in enumerate(starts)]
    return tree[-1], levels
325
+
326
+
327
def full_set(M: Manifold, P, degrees, iscycle):
    """Compute all control points of a C^1 Bézier spline from the independent ones.

    :param M: manifold
    :param P: independent control points, stacked along the first axis
    :param degrees: degree of every segment
    :param iscycle: whether the spline is closed
    :return: list with the control points of every segment
    """
    segments = []
    pos = 0  # index of the next unused independent control point in P
    for seg_idx, deg in enumerate(degrees):
        if seg_idx == 0 and not iscycle:
            # open spline: all control points of the first segment are independent
            segments.append(P[:deg + 1])
            pos += deg + 1
            continue

        # The first two control points follow from C^1 continuity: the segment starts at the junction, and its second
        # point is geopoint(before, junction, 2), i.e. 'before' reflected through the junction along the geodesic.
        # For the first segment of a closed spline, the last two independent points play the role of 'before' and
        # 'junction'.
        # NOTE(review): P[-2] equals the last segment's second-to-last control point only for degree >= 3 -- confirm
        # the intended behaviour for closed quadratic splines.
        if seg_idx == 0:
            before, junction = P[-2], P[-1]
        else:
            before, junction = segments[-1][-2], segments[-1][-1]
        reflected = M.connec.geopoint(before, junction, 2)
        segments.append(jnp.vstack([jnp.expand_dims(junction, axis=0),
                                    jnp.expand_dims(reflected, axis=0),
                                    P[pos:pos + deg - 1]]))
        pos += deg - 1

    return segments
351
+
352
+
353
def indep_set(obj, iscycle):
    """Return array with independent control points or gradients from full set.

    This is the counterpart of full_set: for every segment, the first two entries (fixed by C^1 continuity) are
    dropped. The exception is the first segment of an open spline, which is kept entirely.

    :param obj: array indexed as obj[segment, point, ...]
    :param iscycle: whether the spline is closed
    :return: the kept entries stacked along the first axis
    """
    kept = [obj[l] if (l == 0 and not iscycle) else obj[l, 2:] for l in range(len(obj))]
    return jnp.vstack(kept)
@@ -0,0 +1,104 @@
1
+ ################################################################################
2
+ # #
3
+ # This file is part of the Morphomatics library #
4
+ # see https://github.com/morphomatics/morphomatics #
5
+ # #
6
+ # Copyright (C) 2024 Zuse Institute Berlin #
7
+ # #
8
+ # Morphomatics is distributed under the terms of the MIT License. #
9
+ # see $MORPHOMATICS/LICENSE #
10
+ # #
11
+ ################################################################################
12
+
13
+ import functools
14
+
15
+ import numpy as np
16
+ from scipy import sparse
17
+
18
def memoize(cache_name):
    """Helper decorator that memoizes the given zero-argument method.

    The result is stored in a dict held by the instance attribute `cache_name`. The dict is created on first use if
    the attribute is missing or None. Repeated calls return the cached value instead of recomputing it, which is
    really helpful for memoizing properties so they don't have to be recomputed dozens of times.

    Note: extra positional/keyword arguments are accepted but ignored; the cache key depends only on the function.
    """
    def memo_decorator(fn):
        @functools.wraps(fn)
        def memofn(self, *args, **kwargs):
            cache = getattr(self, cache_name, None)
            if cache is None:
                # lazily create the cache so decorated classes need not initialize it themselves
                cache = {}
                setattr(self, cache_name, cache)
            key = id(fn)
            if key not in cache:
                cache[key] = fn(self)
            return cache[key]

        return memofn
    return memo_decorator
33
+
34
def gradient_matrix_ambient(verts, cells):
    """
    Compute the gradient matrix (gradients represented in ambient space) for the Lagrange basis on a k-manifold
    simplicial geom with vertices `verts` and k-simplices `cells`.

    :param verts: (n, d) array of vertex positions
    :param cells: (m, k+1) integer array of vertex indices per simplex
    :return: sparse (d*m)-by-n gradient matrix, where d is dim. of vertices and m (n) is the number of simplices
        (vertices); row c*d + j holds the j-th ambient component of the basis gradients on simplex c
    """
    n = len(verts)
    m = len(cells)
    d = verts.shape[1]
    k = cells.shape[1] - 1

    # edge vectors of every simplex (relative to its last vertex) as columns: (m, d, k)
    E = np.stack([verts[cells[:, i]] - verts[cells[:, k]] for i in range(k)], axis=2)
    # Gram matrix of the edge vectors (metric of the simplex chart): (m, k, k)
    M = np.matmul(np.swapaxes(E, 1, 2), E)

    # partial derivatives of the barycentric basis functions on the reference simplex
    partials = np.zeros((k, k + 1))
    partials[:, :k] = np.eye(k)
    partials[:, k] = -1

    # gradient = E * inv(M) * partials; solve M X = partials per simplex instead of forming inv(M)
    X = np.linalg.solve(M, np.broadcast_to(partials, (m, k, k + 1)))
    D = np.matmul(E, X).ravel()

    I = np.repeat(np.arange(d * m), k + 1)
    J = np.repeat(cells, d, axis=0).ravel()
    return sparse.csr_matrix((D, (I, J)), shape=(d * m, n))
60
+
61
def gradient_matrix_local(verts, cells):
    """
    Compute the gradient matrix for the Lagrange basis on a d-manifold simplicial geom with vertices `verts` and
    d-simplices `cells`. Gradients are represented in a (d-dim.) local chart of each simplex in which the simplex
    metric becomes the standard inner product.

    :param verts: (n, D) array of vertex positions
    :param cells: (m, d+1) integer array of vertex indices per simplex
    :return: sparse (d*m)-by-n gradient matrix, where m (n) is the number of simplices (vertices),
        and the volumes of the d-simplices (length-m array)
    """
    import math

    n = len(verts)
    m = len(cells)
    d = cells.shape[1] - 1

    E = [verts[cells[:, i]] - verts[cells[:, d]] for i in range(d)]
    # metric (Gram matrix of the edge vectors), shape (m, d, d)
    M = np.matmul(np.stack(E, axis=1), np.stack(E, axis=2))
    # (lower) cholesky factor of M
    L = np.linalg.cholesky(M)

    # partial derivatives for reference simplex
    partials = np.zeros((d, d + 1))
    partials[:d, :d] = np.eye(d)
    partials[:, d] = -1

    # gradient = inv(M)*partials
    # change of variables: x -> L^T*x (s.t. M-inner product becomes standard one)
    # together: L^T * inv(M) = inv(L), i.e. solve L X = partials for every simplex
    # (np.linalg.solve broadcasts over the leading simplex axis; row c*d + i holds component i of simplex c)
    D = np.linalg.solve(L, np.broadcast_to(partials, (m, d, d + 1))).reshape(d * m, d + 1)

    # set up gradient matrix
    I = np.repeat(np.arange(d * m), d + 1)
    J = np.repeat(cells, d, axis=0).ravel()
    grad = sparse.csr_matrix((D.ravel(), (I, J)), shape=(d * m, n))

    # volumes of d-simplices: sqrt(det(M)) / d!, re-using L since det(M) = prod(diag(L))^2
    vol = np.diagonal(L, axis1=1, axis2=2).prod(axis=1) / math.factorial(d)

    return grad, vol