morphomatics 4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphomatics-4.0/LICENSE +9 -0
- morphomatics-4.0/PKG-INFO +55 -0
- morphomatics-4.0/README.md +28 -0
- morphomatics-4.0/morphomatics/__init__.py +13 -0
- morphomatics-4.0/morphomatics/geom/__init__.py +16 -0
- morphomatics-4.0/morphomatics/geom/bezier_spline.py +361 -0
- morphomatics-4.0/morphomatics/geom/misc.py +104 -0
- morphomatics-4.0/morphomatics/geom/surface.py +208 -0
- morphomatics-4.0/morphomatics/graph/__init__.py +13 -0
- morphomatics-4.0/morphomatics/graph/operators.py +124 -0
- morphomatics-4.0/morphomatics/manifold/__init__.py +46 -0
- morphomatics-4.0/morphomatics/manifold/bezierfold.py +500 -0
- morphomatics-4.0/morphomatics/manifold/connection.py +105 -0
- morphomatics-4.0/morphomatics/manifold/cubic_bezierfold.py +305 -0
- morphomatics-4.0/morphomatics/manifold/differential_coords.py +197 -0
- morphomatics-4.0/morphomatics/manifold/discrete_ops.py +56 -0
- morphomatics-4.0/morphomatics/manifold/euclidean.py +213 -0
- morphomatics-4.0/morphomatics/manifold/fundamental_coords.py +440 -0
- morphomatics-4.0/morphomatics/manifold/gl_p_coords.py +149 -0
- morphomatics-4.0/morphomatics/manifold/gl_p_n.py +201 -0
- morphomatics-4.0/morphomatics/manifold/grassmann.py +174 -0
- morphomatics-4.0/morphomatics/manifold/hyperbolic_space.py +271 -0
- morphomatics-4.0/morphomatics/manifold/kendall.py +269 -0
- morphomatics-4.0/morphomatics/manifold/lie_group.py +102 -0
- morphomatics-4.0/morphomatics/manifold/manifold.py +162 -0
- morphomatics-4.0/morphomatics/manifold/manopt_wrapper.py +185 -0
- morphomatics-4.0/morphomatics/manifold/metric.py +110 -0
- morphomatics-4.0/morphomatics/manifold/point_distribution_model.py +143 -0
- morphomatics-4.0/morphomatics/manifold/power_manifold.py +413 -0
- morphomatics-4.0/morphomatics/manifold/product_manifold.py +381 -0
- morphomatics-4.0/morphomatics/manifold/se_3.py +419 -0
- morphomatics-4.0/morphomatics/manifold/shape_space.py +57 -0
- morphomatics-4.0/morphomatics/manifold/so_3.py +494 -0
- morphomatics-4.0/morphomatics/manifold/spd.py +524 -0
- morphomatics-4.0/morphomatics/manifold/sphere.py +241 -0
- morphomatics-4.0/morphomatics/manifold/tangent_bundle.py +337 -0
- morphomatics-4.0/morphomatics/manifold/util.py +126 -0
- morphomatics-4.0/morphomatics/nn/__init__.py +15 -0
- morphomatics-4.0/morphomatics/nn/flow_layers.py +219 -0
- morphomatics-4.0/morphomatics/nn/tangent_layers.py +176 -0
- morphomatics-4.0/morphomatics/nn/train.py +202 -0
- morphomatics-4.0/morphomatics/nn/wFM_layers.py +152 -0
- morphomatics-4.0/morphomatics/opt/__init__.py +14 -0
- morphomatics-4.0/morphomatics/opt/riemannian_newton_raphson.py +65 -0
- morphomatics-4.0/morphomatics/opt/riemannian_steepest_descent.py +61 -0
- morphomatics-4.0/morphomatics/stats/__init__.py +18 -0
- morphomatics-4.0/morphomatics/stats/biinvariant_statistics.py +190 -0
- morphomatics-4.0/morphomatics/stats/exponential_barycenter.py +78 -0
- morphomatics-4.0/morphomatics/stats/geometric_median.py +89 -0
- morphomatics-4.0/morphomatics/stats/principal_geodesic_analysis.py +135 -0
- morphomatics-4.0/morphomatics/stats/riemannian_regression.py +317 -0
- morphomatics-4.0/morphomatics/stats/statistical_shape_model.py +99 -0
- morphomatics-4.0/morphomatics.egg-info/PKG-INFO +55 -0
- morphomatics-4.0/morphomatics.egg-info/SOURCES.txt +57 -0
- morphomatics-4.0/morphomatics.egg-info/dependency_links.txt +1 -0
- morphomatics-4.0/morphomatics.egg-info/requires.txt +8 -0
- morphomatics-4.0/morphomatics.egg-info/top_level.txt +1 -0
- morphomatics-4.0/setup.cfg +4 -0
- morphomatics-4.0/setup.py +57 -0
morphomatics-4.0/LICENSE
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (C) 2024 Zuse Institute Berlin
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
|
6
|
+
|
|
7
|
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
|
8
|
+
|
|
9
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: morphomatics
|
|
3
|
+
Version: 4.0
|
|
4
|
+
Summary: Geometric morphometrics in non-Euclidean shape spaces
|
|
5
|
+
Home-page: https://morphomatics.github.io/
|
|
6
|
+
Author: Christoph von Tycowicz et al.
|
|
7
|
+
Author-email: vontycowicz@zib.de
|
|
8
|
+
License: MIT License
|
|
9
|
+
Keywords: Shape Analysis,Morphometrics,Geometric Statistics
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Topic :: Software Development :: Build Tools
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
License-File: LICENSE
|
|
20
|
+
Requires-Dist: jax>=0.4.25
|
|
21
|
+
Requires-Dist: jaxlib>=0.4.25
|
|
22
|
+
Requires-Dist: jraph
|
|
23
|
+
Requires-Dist: flax
|
|
24
|
+
Requires-Dist: optax
|
|
25
|
+
Provides-Extra: all
|
|
26
|
+
Requires-Dist: pymanopt>=2.0.1; extra == "all"
|
|
27
|
+
|
|
28
|
+
<div align="center">
|
|
29
|
+
<img src="https://github.com/morphomatics/morphomatics.github.io/blob/master/images/logo_cyan.png?raw=true" width="250" alt="Morphomatics"/>
|
|
30
|
+
</div>
|
|
31
|
+
|
|
32
|
+
# Morphomatics: Geometric morphometrics in non-Euclidean shape spaces
|
|
33
|
+
|
|
34
|
+
Morphomatics is an open-source Python library for (statistical) shape analysis developed within the [geometric data analysis and processing](https://www.zib.de/visual/geometric-data-analysis-and-processing) research group at Zuse Institute Berlin.
|
|
35
|
+
It contains prototype implementations of intrinsic manifold-based methods that are highly consistent and avoid the influence of unwanted effects such as bias due to arbitrary choices of coordinates.
|
|
36
|
+
|
|
37
|
+
Detailed information and tutorials can be found at https://morphomatics.github.io/
|
|
38
|
+
|
|
39
|
+
## Installation
|
|
40
|
+
|
|
41
|
+
Morphomatics can be installed directly from github using the following command:
|
|
42
|
+
```
|
|
43
|
+
pip install git+https://github.com/morphomatics/morphomatics.git#egg=morphomatics
|
|
44
|
+
```
|
|
45
|
+
For instructions on how to set up `jaxlib`, please refer to the [JAX install guide](https://github.com/google/jax#installation).
|
|
46
|
+
|
|
47
|
+
## Dependencies
|
|
48
|
+
* jax/jaxlib
|
|
49
|
+
* jraph
|
|
50
|
+
* flax
|
|
51
|
+
* optax
|
|
52
|
+
|
|
53
|
+
Optional
|
|
54
|
+
* pymanopt
|
|
55
|
+
* sksparse
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
<div align="center">
|
|
2
|
+
<img src="https://github.com/morphomatics/morphomatics.github.io/blob/master/images/logo_cyan.png?raw=true" width="250" alt="Morphomatics"/>
|
|
3
|
+
</div>
|
|
4
|
+
|
|
5
|
+
# Morphomatics: Geometric morphometrics in non-Euclidean shape spaces
|
|
6
|
+
|
|
7
|
+
Morphomatics is an open-source Python library for (statistical) shape analysis developed within the [geometric data analysis and processing](https://www.zib.de/visual/geometric-data-analysis-and-processing) research group at Zuse Institute Berlin.
|
|
8
|
+
It contains prototype implementations of intrinsic manifold-based methods that are highly consistent and avoid the influence of unwanted effects such as bias due to arbitrary choices of coordinates.
|
|
9
|
+
|
|
10
|
+
Detailed information and tutorials can be found at https://morphomatics.github.io/
|
|
11
|
+
|
|
12
|
+
## Installation
|
|
13
|
+
|
|
14
|
+
Morphomatics can be installed directly from GitHub using the following command:
|
|
15
|
+
```
|
|
16
|
+
pip install git+https://github.com/morphomatics/morphomatics.git#egg=morphomatics
|
|
17
|
+
```
|
|
18
|
+
For instructions on how to set up `jaxlib`, please refer to the [JAX install guide](https://github.com/google/jax#installation).
|
|
19
|
+
|
|
20
|
+
## Dependencies
|
|
21
|
+
* jax/jaxlib
|
|
22
|
+
* jraph
|
|
23
|
+
* flax
|
|
24
|
+
* optax
|
|
25
|
+
|
|
26
|
+
Optional
|
|
27
|
+
* pymanopt
|
|
28
|
+
* sksparse
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
# Version string exposed by the package at runtime.
# NOTE(review): the distribution metadata of this release declares "Version: 4.0";
# confirm whether the '.dev0' suffix is still intended here.
__version__ = '4.0.dev0'
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
from .misc import memoize, gradient_matrix_ambient
|
|
14
|
+
|
|
15
|
+
from .surface import Surface
|
|
16
|
+
from .bezier_spline import BezierSpline
|
|
@@ -0,0 +1,361 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
# postponed evaluation of annotations (PEP 563) to circumvent cyclic dependencies; requires the __future__ import below
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
# from morphomatics.manifold import Manifold
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
import jax.lax as lax
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
from typing import Tuple, List
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BezierSpline:
    """Manifold-valued spline that consists of Bézier curves.

    The control points are stored as one array whose first axis enumerates the
    segments and whose second axis enumerates the control points of a segment
    (degree + 1 of them, cf. :attr:`degrees`).
    """

    def __init__(self, M: Manifold, control_points: jnp.array, iscycle: bool = False):
        """
        :arg M: manifold in which the curve lies
        :arg control_points: array of control points of the Bézier spline; the L >= 1 segments must be sorted along
            the first axis and all segments must have the same degree k, i.e., the input must be convertible to an
            array of shape [L, k+1, M.point_shape]
        :arg iscycle: boolean indicating whether the spline is a closed curve
        """
        assert M is not None

        self._M = M
        # a single rectangular array: all segments share the same number of control points
        self.control_points = jnp.asarray(control_points)
        self.iscycle = iscycle

    def __str__(self) -> str:
        return 'Bézier spline through ' + str(self._M)

    @property
    def nsegments(self) -> int:
        """Returns the number of segments."""
        return len(self.control_points)

    @property
    def degrees(self) -> np.ndarray:
        """Returns the degrees of the spline segments (number of control points per segment minus one)."""
        cp = self.control_points
        if len(cp) == 0:
            return np.zeros(0, dtype=int)
        # control points form a rectangular array, hence all segments have the same degree
        return np.full(len(cp), cp.shape[1] - 1, dtype=int)

    def length(self) -> float:
        """Length of the spline. Not implemented yet; currently returns None."""
        # TODO
        return

    def energy(self) -> float:
        """Energy of the spline. Not implemented yet; currently returns None."""
        # TODO
        return

    def tangent(self, t: float) -> Tuple[jnp.array, jnp.array]:
        """
        Compute the tangent vector at the point of the spline corresponding to t.

        The derivative is taken w.r.t. the local parameter of the segment that contains t.

        :param t: scalar in [0, nsegments]
        :return: B(t) and the tangent vector of the spline at B(t)
        """
        M = self._M

        # get segment and local parameter in [0, 1]
        ind, s = segmentize(t)
        P = self.control_points[ind]
        # degree of the segment
        k = len(P) - 1

        # closed-form velocity at the ends of a Bézier curve
        if s == 0:
            return P[0], k * M.connec.log(P[0], P[1])
        if s == 1:
            return P[-1], -k * M.connec.log(P[-1], P[-2])

        # (forward-mode) automatic differentiation of decasteljau(..) w.r.t. the curve parameter,
        # analogous to the differentiation w.r.t. the control points in DpB(..)
        f = lambda r: decasteljau(M, P, r)[0]
        s = jnp.asarray(s, dtype=float)
        Bt, dBt = jax.jvp(f, (s,), (jnp.ones_like(s),))
        return Bt, M.proj(Bt, dBt)

    def isC1(self, eps: float = 1e-5) -> bool:
        """
        Check whether the spline is (approximately) continuously differentiable. For this, all control points that
        connect two segments must be in the middle of their neighbours.

        :param eps: maximal distance between a connecting control point and the midpoint of its neighbours
        """
        cp = self.control_points

        # trivial case: only one segment -> infinitely often differentiable
        if len(cp) == 1:
            return True

        # pair every segment with its direct predecessor
        for prev, seg in zip(cp[:-1], cp[1:]):
            mid = self._M.connec.geopoint(prev[-2], seg[1], 1 / 2)
            # if midpoint and connecting control point are further apart than epsilon return False
            if self._M.metric.dist(mid, seg[0]) > eps:
                return False

        return True

    def geoshaped(self, eps: float = 1e-7) -> bool:
        """
        Return whether the spline is a reparametrized geodesic. For this we test if all tangent vectors from the first
        control point to the other control points are parallel (within a tolerance of epsilon).

        :param eps: tolerance on the cosine of the angle between the tangent vectors
        """
        cp = self.control_points
        connec, metric = self._M.connec, self._M.metric

        # trivial case
        if len(cp) == 1 and len(cp[0]) == 2:
            return True

        c = cp[0][0]
        v0 = connec.log(c, cp[0][1])
        norm_v0 = metric.norm(c, v0)

        # all control points in order, without the first two (these define c and v0)
        others = cp.reshape((-1,) + cp.shape[2:])[2:]

        # check whether the logs at c to all other control points are parallel to v0
        for cc in others:
            # ignore almost equal points---the test is unstable for them and their influence on non-geodesicity is
            # negligible
            if metric.dist(c, cc) > 1e-7:
                v = connec.log(c, cc)
                par = metric.inner(c, v0, v) / (norm_v0 * metric.norm(c, v))

                if -1 + eps < par < 1 - eps:
                    # v and v0 are not parallel
                    return False

        # all vectors were (almost) parallel
        return True

    def eval(self, t: float) -> jnp.array:
        """Evaluates the Bézier spline at time t in [0, nsegments]."""
        # choose correct control points
        ind, t = segmentize(t)
        P = self.control_points[ind]

        return decasteljau(self._M, P, t)[0]

    def DpB(self, t: float, X: jnp.array) -> jnp.array:
        """Compute derivative of Bézier curve B(t) w.r.t. its control points applied to vector X, i.e.
        the generalized Jacobi field J(t).
        :param t: time in [0, nSegments]
        :param X: tangent vectors for each control point
        :return: B(t), J(t)
        """
        # choose correct control points
        ind, t = segmentize(t)
        P = self.control_points[ind]

        # (forward-mode) automatic differentiation of decasteljau(..)
        f = lambda a: decasteljau(self._M, a, t)[0]
        Bt, Jt = jax.jvp(f, (P,), (X[ind],))
        return Bt, self._M.proj(Bt, Jt)

    def adjDpB(self, t: float, X: jnp.array) -> jnp.array:
        """Compute the value of the adjoint derivative of a Bézier curve B with respect to its control points applied
        to the vector X.
        :param t: scalar in [0, nSegments]
        :param X: tangent vector at B(t)
        :return: vectors at the control points (array shaped like the control points; only the entries of the
            segment containing t are non-zero)
        """
        M = self._M

        # t indicates which segment to choose
        ind, t = segmentize(t)
        P = self.control_points[ind]

        # number of control points of corresponding segment
        k = len(P)

        _, levels = decasteljau(M, P, t)
        # want to go backwards from B(t) to control points: levels[i] then holds i + 2 points
        levels.reverse()

        # transport X backwards along the "tree of geodesics" defined by the generalized de Casteljau algorithm.
        # We iterate over the depth of the tree and add vectors from the same tangent space.
        D = jnp.expand_dims(jnp.asarray(X), axis=0)
        for i in range(k - 1):
            pts = levels[i]
            n = D.shape[0]
            # transport to starting point of each geodesic ...
            to_start = [M.connec.adjDxgeo(pts[j], pts[j + 1], t, D[j]) for j in range(n)]
            # ... and to the endpoint
            to_end = [M.connec.adjDygeo(pts[j], pts[j + 1], t, D[j]) for j in range(n)]

            # inner points are endpoint of one geodesic and starting point of the next -> add up vectors
            merged = [to_start[0]]
            merged += [to_end[j - 1] + to_start[j] for j in range(1, n)]
            merged.append(to_end[-1])
            D = jnp.stack(merged)

        grad = jnp.zeros_like(self.control_points)

        # update the entries corresponding to the ind-th segment
        return grad.at[ind].set(D)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def segmentize(t: float) -> Tuple[int, float]:
    """Split a global spline parameter into segment index and local parameter.

    :param t: scalar in [0, nsegments]
    :return: index of the segment (i.e., of its control points in the spline's control point array) and the
        adjusted value of t in [0,1]
    """

    def startpoint(s):
        # very beginning of the spline: first segment
        return int(0), s

    def connecting_point(s):
        # integral parameter > 0: treated as the end of the preceding segment
        return jnp.asarray(s, dtype=int) - 1, 1.

    def inner_point(s):
        # integer part selects the segment, fractional part is the local parameter
        return jnp.floor(s).astype(int), s - jnp.floor(s)

    def beyond_start(s):
        return lax.cond(t == jnp.round(t), connecting_point, inner_point, s)

    return lax.cond(t == 0, startpoint, beyond_start, t)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def decasteljau(M: Manifold, P: jnp.array, t: float) -> Tuple[jnp.array, List[jnp.array]]:
    """Generalized de Casteljau algorithm

    All points of the de Casteljau scheme are kept in one flat array ("linearized tree"): the k control points
    first, followed by the k-1 points of the next level, and so on down to the single final point.

    :param M: manifold
    :param P: control points of curve beta
    :param t: scalar in [0,1]
    :return beta(t), (B): result of the de Casteljau algorithm with control points P, (list with the points of
        every level of the scheme except the final one, starting with the control points themselves)
    """
    # number of control points
    k = len(P)

    # init linearized tree of control points; entries beyond the first k are placeholders of the right shape
    tree = jnp.concatenate([jnp.asarray(P)[i:] for i in range(k)])
    # position in the tree at which level n starts
    offset = [(2 * k * n - n * n + n) // 2 for n in range(k - 1)]
    # for lower-level points: index of the left parent (the right one is its successor)
    parents = np.concatenate([np.arange(k - 1 - i) + o for i, o in enumerate(offset)])
    # lower-level points are stored consecutively behind the control points
    children = k + np.arange(len(parents))

    def step(tree, pair):
        # write the point on the geodesic between both parents into the child's slot
        src, dst = pair[0], pair[1]
        return tree.at[dst].set(M.connec.geopoint(tree[src], tree[src + 1], t)), None

    # compute lower-level points
    tree = lax.scan(step, tree, np.c_[parents, children])[0]

    return tree[-1], [tree[o:o + k - i] for i, o in enumerate(offset)]
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def full_set(M: Manifold, P, degrees, iscycle):
    """Compute all control points of a C^1 Bézier spline from the independent ones.

    :param M: manifold
    :param P: array of independent control points
    :param degrees: degree of every segment
    :param iscycle: whether the spline is closed
    :return: list with one array of control points per segment
    """

    def as_row(x):
        return jnp.expand_dims(x, axis=0)

    control_points = []
    start = 0
    for l, deg in enumerate(degrees):
        if l == 0 and not iscycle:
            # all control points of the first segment are independent
            control_points.append(P[:deg + 1])
            start = start + deg + 1
            continue

        # the first two control points are determined by the preceding segment
        # (for the first segment of a closed spline: by the last independent points)
        if l == 0:
            before, joint = P[-2], P[-1]
        else:
            before, joint = control_points[-1][-2], control_points[-1][-1]

        # geodesic from `before` to `joint` evaluated at parameter 2
        successor = M.connec.geopoint(before, joint, 2)
        control_points.append(jnp.vstack([as_row(joint), as_row(successor), P[start:start + deg - 1]]))
        start = start + deg - 1

    return control_points
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def indep_set(obj, iscycle):
    """Return array with independent control points or gradients from full set.

    From every segment the first two entries are dropped; only the first segment of a
    non-closed spline is kept completely.

    :param obj: array whose first axis enumerates segments and whose second axis enumerates control points
    :param iscycle: whether the spline is closed
    """
    ind_pts = [
        obj[0] if (l == 0 and not iscycle) else obj[l, 2:]
        for l in range(len(obj))
    ]
    return jnp.vstack(ind_pts)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import functools
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from scipy import sparse
|
|
17
|
+
|
|
18
|
+
def memoize(cache_name):
    """Helper decorator memoizes the given zero-argument function.

    Really helpful for memoizing properties so they don't have to be recomputed
    dozens of times.

    :param cache_name: name of the instance attribute holding the cache dictionary; the
        dictionary is created on first use if the instance does not provide one
    """
    def memo_decorator(fn):
        @functools.wraps(fn)
        def memofn(self, *args, **kwargs):
            cache = getattr(self, cache_name, None)
            if cache is None:
                # instance has no cache (yet) -> create it instead of failing on the lookup below
                cache = {}
                setattr(self, cache_name, cache)

            # results are keyed by the wrapped function; arguments are ignored (zero-argument functions only)
            key = id(fn)
            if key not in cache:
                cache[key] = fn(self)
            return cache[key]

        return memofn
    return memo_decorator
|
|
33
|
+
|
|
34
|
+
def gradient_matrix_ambient(verts, cells):
    """
    Compute gradient (represented in ambient space) matrix for Lagrange basis
    on k-manifold simplicial geom with vertices \\a verts and k-simplices \\a cells
    :param verts: n-by-d array of vertex coordinates
    :param cells: m-by-(k+1) array of vertex indices
    :return: sparse (d*m)-by-n gradient matrix, where d is dim. of vertices,
        and m (n) is the number of triangles (vertices).
    """
    n = len(verts)
    m = len(cells)
    d = verts.shape[1]
    k = cells.shape[1] - 1

    # edge vectors emanating from the last vertex of each simplex: (m, k, d)
    E = np.stack([verts[cells[:, i]] - verts[cells[:, k]] for i in range(k)], axis=1)
    Et = E.transpose(0, 2, 1)
    # metric (Gram matrix of the edge vectors) of each simplex: (m, k, k)
    M = np.matmul(E, Et)

    # partial derivatives of the basis functions for the reference simplex
    partials = np.zeros((k, k + 1))
    partials[:k, :k] = np.eye(k)
    partials[:, k] = -1

    # inv(M) * partials via a linear solve (no explicit inverse); the right-hand side is
    # broadcast explicitly s.t. it is interpreted as a stack of matrices
    coeffs = np.linalg.solve(M, np.broadcast_to(partials, (m, k, k + 1)))
    # express gradients in ambient coordinates: (m, d, k+1)
    D = np.matmul(Et, coeffs).ravel()

    # row: (simplex, ambient coordinate), column: vertex
    I = np.repeat(np.arange(d * m), k + 1)
    J = np.repeat(cells, d, axis=0).ravel()
    return sparse.csr_matrix((D, (I, J)), shape=(d * m, n))
|
|
60
|
+
|
|
61
|
+
def gradient_matrix_local(verts, cells):
    """
    Compute gradient matrix for Lagrange basis on d-manifold simplicial geom
    with vertices \\a verts and d-simplices \\a cells.
    Gradients will be represented in (d-dim.) local chart of each simplex.
    :param verts: array of vertex coordinates (one row per vertex)
    :param cells: m-by-(d+1) array of vertex indices
    :return: sparse (d*m)-by-n gradient matrix, where m (n) is the number of simplices (vertices),
        and volumes of d-simplices
    """
    from math import factorial

    n = len(verts)
    m = len(cells)
    d = cells.shape[1] - 1

    # edge vectors emanating from the last vertex of each simplex: (m, d, ambient dim.)
    E = np.stack([verts[cells[:, i]] - verts[cells[:, d]] for i in range(d)], axis=1)
    # metric
    M = np.matmul(E, E.transpose(0, 2, 1))
    # (lower) cholesky factor of M
    L = np.linalg.cholesky(M)

    # partial derivatives for reference simplex
    partials = np.zeros((d, d + 1))
    partials[:d, :d] = np.eye(d)
    partials[:, d] = -1

    # gradient = inv(M)*partials
    # change of variables: x -> L^T*x (s.t. M-inner product becomes standard one)
    # together: L^T * inv(M) = inv(L)

    # unroll forward substitution, processing row i of all simplices at once
    D = np.tile(partials, (m, 1, 1))
    for i in range(d):
        for j in range(i):
            D[:, i] -= D[:, j] * L[:, i, j, None]
        D[:, i] /= L[:, i, i, None]

    # set up gradient matrix; row: (simplex, local coordinate), column: vertex
    I = np.repeat(np.arange(d * m), d + 1)
    J = np.repeat(cells, d, axis=0).ravel()
    grad = sparse.csr_matrix((D.ravel(), (I, J)), shape=(d * m, n))

    # volumes of d-simplices (computing sqrt. of det(M) re-using L)
    vol = np.diagonal(L, axis1=1, axis2=2).prod(axis=1) / factorial(d)

    return grad, vol
|