morphomatics 4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphomatics/__init__.py +13 -0
- morphomatics/geom/__init__.py +16 -0
- morphomatics/geom/bezier_spline.py +361 -0
- morphomatics/geom/misc.py +104 -0
- morphomatics/geom/surface.py +208 -0
- morphomatics/graph/__init__.py +13 -0
- morphomatics/graph/operators.py +124 -0
- morphomatics/manifold/__init__.py +46 -0
- morphomatics/manifold/bezierfold.py +500 -0
- morphomatics/manifold/connection.py +105 -0
- morphomatics/manifold/cubic_bezierfold.py +305 -0
- morphomatics/manifold/differential_coords.py +197 -0
- morphomatics/manifold/discrete_ops.py +56 -0
- morphomatics/manifold/euclidean.py +213 -0
- morphomatics/manifold/fundamental_coords.py +440 -0
- morphomatics/manifold/gl_p_coords.py +149 -0
- morphomatics/manifold/gl_p_n.py +201 -0
- morphomatics/manifold/grassmann.py +174 -0
- morphomatics/manifold/hyperbolic_space.py +271 -0
- morphomatics/manifold/kendall.py +269 -0
- morphomatics/manifold/lie_group.py +102 -0
- morphomatics/manifold/manifold.py +162 -0
- morphomatics/manifold/manopt_wrapper.py +185 -0
- morphomatics/manifold/metric.py +110 -0
- morphomatics/manifold/point_distribution_model.py +143 -0
- morphomatics/manifold/power_manifold.py +413 -0
- morphomatics/manifold/product_manifold.py +381 -0
- morphomatics/manifold/se_3.py +419 -0
- morphomatics/manifold/shape_space.py +57 -0
- morphomatics/manifold/so_3.py +494 -0
- morphomatics/manifold/spd.py +524 -0
- morphomatics/manifold/sphere.py +241 -0
- morphomatics/manifold/tangent_bundle.py +337 -0
- morphomatics/manifold/util.py +126 -0
- morphomatics/nn/__init__.py +15 -0
- morphomatics/nn/flow_layers.py +219 -0
- morphomatics/nn/tangent_layers.py +176 -0
- morphomatics/nn/train.py +202 -0
- morphomatics/nn/wFM_layers.py +152 -0
- morphomatics/opt/__init__.py +14 -0
- morphomatics/opt/riemannian_newton_raphson.py +65 -0
- morphomatics/opt/riemannian_steepest_descent.py +61 -0
- morphomatics/stats/__init__.py +18 -0
- morphomatics/stats/biinvariant_statistics.py +190 -0
- morphomatics/stats/exponential_barycenter.py +78 -0
- morphomatics/stats/geometric_median.py +89 -0
- morphomatics/stats/principal_geodesic_analysis.py +135 -0
- morphomatics/stats/riemannian_regression.py +317 -0
- morphomatics/stats/statistical_shape_model.py +99 -0
- morphomatics-4.0.dist-info/LICENSE +9 -0
- morphomatics-4.0.dist-info/METADATA +55 -0
- morphomatics-4.0.dist-info/RECORD +54 -0
- morphomatics-4.0.dist-info/WHEEL +5 -0
- morphomatics-4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import abc
|
|
14
|
+
import jax
|
|
15
|
+
from jax.tree_util import register_pytree_node
|
|
16
|
+
|
|
17
|
+
import operator as op
|
|
18
|
+
|
|
19
|
+
from morphomatics.manifold import Metric, Connection, LieGroup
|
|
20
|
+
|
|
21
|
+
class ManifoldMeta(abc.ABCMeta):
    """Metaclass for abstract manifold base classes whose subclasses become JAX PyTrees.

    Every class created through this metaclass is registered with JAX at class-creation
    time: instances are flattened by calling their ``tree_flatten()`` method and rebuilt
    with the class's ``tree_unflatten`` classmethod.
    """

    def __new__(mcls, name, bases, namespace, **kwargs):
        new_cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        # flatten via the instance method, unflatten via the (class-)method of new_cls
        flatten = op.methodcaller('tree_flatten')
        register_pytree_node(new_cls, flatten, new_cls.tree_unflatten)
        return new_cls
|
|
27
|
+
|
|
28
|
+
class Manifold(metaclass=ManifoldMeta):
    """
    Abstract base class setting out a template for manifold classes.

    Morphomatics's Lie group and Riemannian manifold classes inherit from Manifold.
    A manifold may optionally carry a connection, a metric and/or a Lie group structure;
    each is stored as given (``None`` if absent) and exposed through a read-only property.
    """

    def __init__(self, name, dimension: int, point_shape,
                 connec: Connection = None, metric: Metric = None, group: LieGroup = None):
        """
        :arg name: name of the manifold (returned by ``str()``)
        :arg dimension: dimension of the manifold
        :arg point_shape: shape of the arrays representing points of the manifold
        :arg connec: (optional) connection on the tangent bundle
        :arg metric: (optional) metric on the tangent bundle
        :arg group: (optional) group operation turning the manifold into a Lie group
        """
        self._name = name
        self._dimension = dimension
        self._point_shape = point_shape
        # (possibly) define a connection on the tangent bundle
        self._connec = connec
        # (possibly) define a metric on the tangent bundle
        self._metric = metric
        # (possibly) define a group operation turning the manifold into a Lie group
        self._group = group

    @classmethod
    @abc.abstractmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""

    def tree_unflatten_instance(self, aux_data, children):
        """Specifies an unflattening recipe given an instance (possibly of a subclass).

        Rebuilds the connection, metric and group of ``self`` from the output of
        :meth:`tree_flatten`: each present structure is re-created as ``cls(self)`` and its
        public attributes are restored from the corresponding dict.

        :arg aux_data: classes of (connection, metric, group); ``NoneType`` marks an absent one
        :arg children: dicts of public attributes of (connection, metric, group)
        """
        C, M, G = aux_data
        Cdict, Mdict, Gdict = children

        def setup(cls, attrs):
            # tree_flatten records None.__class__ (NoneType) for structures that are absent
            if cls is None.__class__:
                return None
            obj = cls(self)
            obj.__dict__.update(attrs)
            return obj

        self._connec = setup(C, Cdict)
        self._metric = setup(M, Mdict)
        self._group = setup(G, Gdict)

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration.

        :returns: ``(children, aux_data)``, where ``children`` holds one dict of public
            (non-underscore) instance attributes for each of connection, metric and group
            (empty for absent ones) and ``aux_data`` holds their classes.
        """
        aux_data = (self._connec.__class__, self._metric.__class__, self._group.__class__)

        def public_attrs(o):
            if o is None:
                return {}
            return {k: v for k, v in o.__dict__.items() if k[0] != '_'}

        children = (public_attrs(self._connec), public_attrs(self._metric), public_attrs(self._group))
        return (children, aux_data)

    def __str__(self):
        return self._name

    def __repr__(self):
        """Returns a string representation of the particular manifold."""
        structures = (('metric', self._metric), ('connection', self._connec), ('group', self._group))
        conf = ' '.join(label + '=' + str(obj) for label, obj in structures if obj).strip()
        if not conf:
            return self._name
        return f'{self._name} ({conf})'

    @property
    def dim(self):
        """The dimension of the manifold"""
        return self._dimension

    @property
    def point_shape(self):
        """Dimensions of elements of the manifold.

        Tuple of dimension, e.g., if an element is given by a 3-by-3 matrix, then its point shape is [3, 3].
        """
        return self._point_shape

    @property
    def metric(self) -> Metric:
        """The metric on the tangent bundle (``None`` if not defined)."""
        return self._metric

    @property
    def connec(self) -> Connection:
        """The connection on the tangent bundle (``None`` if not defined)."""
        return self._connec

    @property
    def group(self) -> LieGroup:
        """The Lie group structure (``None`` if not defined)."""
        return self._group

    @abc.abstractmethod
    def rand(self, key: jax.Array):
        """Returns a random point of the manifold."""

    @abc.abstractmethod
    def randvec(self, p, key: jax.Array):
        """Returns a random vector in the tangent space at p."""

    @abc.abstractmethod
    def zerovec(self):
        """Returns the zero vector in any tangent space."""

    @abc.abstractmethod
    def proj(self, p, X):
        """Projects a vector X in the ambient space on the tangent space at
        p.
        """

    def projToGeodesic(self, X, Y, P, max_iter=10, tol=1e-6):
        '''
        Project P onto geodesic from X to Y.

        See:
        Felix Ambellan, Stefan Zachow, Christoph von Tycowicz.
        Geodesic B-Score for Improved Assessment of Knee Osteoarthritis.
        Proc. Information Processing in Medical Imaging (IPMI), LNCS, 2021.

        Starting at X, the estimate is repeatedly moved along the (unit) direction towards Y
        by the inner product of that direction with the log of P, until that inner product
        drops below ``tol`` in magnitude or ``max_iter`` steps were taken. The loop uses
        Python control flow on computed values, so it must be run eagerly (not under jit).

        :arg X, Y: manifold coords defining geodesic X->Y.
        :arg P: manifold coords to be projected to X->Y.
        :arg max_iter: maximal number of update steps.
        :arg tol: stop once |<v, log(Pi, P)>| < tol.
        :returns: manifold coords of projection of P to X->Y
        '''
        assert X.shape == Y.shape
        assert Y.shape == P.shape
        assert self.connec
        assert self.metric

        # initial guess
        Pi = X.copy()

        # solver loop
        for _ in range(max_iter):
            v = self.connec.log(Pi, Y)
            nrm = self.metric.norm(Pi, v)
            if nrm == 0:
                # Pi coincides with Y (e.g. X == Y): the direction is undefined and
                # normalizing would turn every subsequent iterate into NaN.
                break
            v = v / nrm
            w = self.connec.log(Pi, P)
            d = self.metric.inner(Pi, v, w)

            if abs(d) < tol:
                break

            Pi = self.connec.exp(Pi, d * v)

        return Pi
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import functools
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
import jax
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
import jax.tree_util as tree
|
|
20
|
+
|
|
21
|
+
from .manifold import Manifold
|
|
22
|
+
|
|
23
|
+
try:
    import pymanopt

    from pymanopt.manifolds.manifold import Manifold as ManoptManifold
    from pymanopt.autodiff.backends._backend import Backend
    from pymanopt.autodiff import backend_decorator_factory

    def conjugate_result(function):
        """Wrap ``function`` so that every leaf of its (pytree) result is complex-conjugated."""
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            result = function(*args, **kwargs)
            return tree.tree_map(jnp.conj, result)

        return wrapper

    class JaxBackend(Backend):
        """Minimal pymanopt autodiff backend based on ``jax.grad`` (gradients only, no Hessians)."""

        def __init__(self):
            super().__init__("Jax")

        @staticmethod
        def is_available():
            return jax is not None

        @Backend._assert_backend_available
        def prepare_function(self, function):
            return function

        @Backend._assert_backend_available
        def generate_gradient_operator(self, function, num_arguments):
            if num_arguments == 1:
                grad_fn = jax.grad(function)
            else:
                # differentiate w.r.t. all positional arguments at once
                grad_fn = jax.grad(function, argnums=list(range(num_arguments)))
            return conjugate_result(grad_fn)

        @staticmethod
        def _hessian_vector_product(function, argnum):
            raise NotImplementedError()

        @Backend._assert_backend_available
        def generate_hessian_operator(self, function, num_arguments):
            raise NotImplementedError()

    # expose the backend as pymanopt.function.jax unless pymanopt already ships one
    if "jax" not in pymanopt.autodiff.backends.__all__:
        factory = backend_decorator_factory(JaxBackend)
        pymanopt.autodiff.backends.jax = factory
        pymanopt.function.jax = factory
        pymanopt.autodiff.backends.__all__.append("jax")
        pymanopt.function.__all__.append("jax")

except ImportError:
    # pymanopt (or one of the modules above) is unavailable: fall back to a plain base class
    _has_manopt = False
    ManoptManifold = object
else:
    _has_manopt = True
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class ManoptWrap(ManoptManifold):
    """
    Wrapper for pymanopt to make manifolds from morphomatics compatible.

    Operations are delegated to the wrapped manifold ``M``: metric quantities to ``M.metric``,
    exp/log/transport to ``M.connec``, and projection, sampling and the zero vector to ``M``.
    """

    def __init__(self, M: Manifold):
        if not _has_manopt:
            # without pymanopt the base class is plain ``object`` and cannot be initialized
            raise ImportError("ManoptWrap requires pymanopt to be installed.")
        self._M = M
        super().__init__(str(M), M.dim)  # as of v0.2.6rc1

    @staticmethod
    def _random_key():
        """Returns a fresh JAX PRNG key seeded from NumPy's global random state."""
        # explicit int64: the default integer type is 32 bits on some platforms
        # (e.g. Windows with NumPy < 2), where an upper bound of 2**32 is rejected
        return jax.random.PRNGKey(int(np.random.randint(1 << 32, dtype=np.int64)))

    # Manifold properties that subclasses can define

    @property
    def typicaldist(self):
        """Returns the "scale" of the manifold. This is used by the
        trust-regions solver to determine default initial and maximal
        trust-region radii.
        """
        return self._M.metric.typicaldist

    @property
    def typical_dist(self):
        """Alias of :attr:`typicaldist` under the pymanopt 2.x name."""
        return self.typicaldist

    def inner_product(self, X, G, H):
        """Returns the inner product (i.e., the Riemannian metric) between two
        tangent vectors `G` and `H` in the tangent space at `X`.
        """
        return self._M.metric.inner(X, G, H)

    def dist(self, X, Y):
        """
        Geodesic distance on the manifold
        """
        return self._M.metric.dist(X, Y)

    def projection(self, X, G):
        """Projects a vector `G` in the ambient space on the tangent space at
        `X`.
        """
        # proj is part of the Manifold interface (not of Metric)
        return self._M.proj(X, G)

    def euclidean_to_riemannian_gradient(self, X, G):
        """Maps the Euclidean gradient G in the ambient space on the tangent
        space of the manifold at X.
        """
        return self._M.metric.egrad2rgrad(X, G)

    def norm(self, X, G):
        """Computes the norm of a tangent vector `G` in the tangent space at
        `X`.
        """
        return self._M.metric.norm(X, G)

    def exp(self, X, U):
        """
        The exponential map of the manifold's connection, applied to the tangent
        vector U at X.
        """
        return self._M.connec.exp(X, U)

    def retraction(self, X, G):
        """
        A retraction mapping from the tangent space at X to the manifold.
        See Absil for definition of retraction. Implemented by the exponential map.
        """
        return self.exp(X, G)

    def log(self, X, Y):
        """
        The logarithm of the manifold's connection, i.e. the tangent vector at X
        pointing to Y. This is the inverse of exp.
        """
        return self._M.connec.log(X, Y)

    def transport(self, x1, x2, d):
        """
        Transports d, which is a tangent vector at x1, into the tangent
        space at x2.
        """
        return self._M.connec.transp(x1, x2, d)

    def random_point(self):
        """Returns a random point on the manifold."""
        return self._M.rand(self._random_key())

    def random_tangent_vector(self, X):
        """Returns a random vector in the tangent space at `X`. This does not
        follow a specific distribution.
        """
        return self._M.randvec(X, self._random_key())

    def zero_vector(self, X):
        """Returns the zero vector in the tangent space at X."""
        # Manifold.zerovec takes no point: the zero vector is the same in every tangent space
        return self._M.zerovec()

    def euclidean_to_riemannian_hessian(self, p, grad, Hess, X):
        """
        Convert the Euclidean gradient ``grad`` and Hessian ``Hess`` at p, applied
        along the tangent vector X, into the Riemannian Hessian along X.
        """
        return self._M.metric.ehess2rhess(p, grad, Hess, X)

    def pair_mean(self, X, Y):
        """
        Computes the intrinsic mean of X and Y, that is, a point that lies
        mid-way between X and Y on the geodesic arc joining them.
        """
        return self.exp(X, 0.5 * self.log(X, Y))
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import abc
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
|
|
18
|
+
from morphomatics.manifold.connection import Connection
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Metric(Connection):
    """
    Interface for a metric on the tangent bundle of a manifold.

    A metric is modelled as a subclass of its Levi-Civita connection, so every metric
    also offers the connection interface.
    """

    @property
    @abc.abstractmethod
    def typicaldist(self):
        """The "scale" of the manifold; trust-region solvers derive their default
        initial and maximal trust-region radii from it.
        """

    @abc.abstractmethod
    def dist(self, p, q):
        """Geodesic distance between the points p and q of the manifold."""

    def squared_dist(self, p, q):
        """Squared geodesic distance between p and q (default: ``dist(p, q) ** 2``)."""
        return self.dist(p, q) ** 2

    @abc.abstractmethod
    def inner(self, p, X, Y):
        """Inner product (the Riemannian metric) of the tangent vectors X and Y at p."""

    @abc.abstractmethod
    def egrad2rgrad(self, p, X):
        """Map the Euclidean gradient X from the ambient space to the tangent space at p."""

    @abc.abstractmethod
    def ehess2rhess(self, p, G, H, X):
        """Convert the Euclidean gradient G and Hessian H of a function at p, taken along
        the tangent vector X, into the Riemannian Hessian along X.
        """

    def norm(self, p, X):
        """Norm of the tangent vector X at p, i.e. the square root of <X, X>_p."""
        squared_norm = self.inner(p, X, X)
        return jnp.sqrt(squared_norm)

    @abc.abstractmethod
    def flat(self, p, X):
        """Lower the vector X at p to a covector using the metric."""

    @abc.abstractmethod
    def sharp(self, p, dX):
        """Raise the covector dX at p to a vector using the metric."""

    @abc.abstractmethod
    def adjJacobi(self, p, q, t, X):
        """Evaluate an adjoint Jacobi field of the geodesic gam from p to q at p.

        :param p: element of the Riemannian manifold
        :param q: element of the Riemannian manifold
        :param t: scalar in [0,1]
        :param X: tangent vector at gam(t)
        :return: tangent vector at p
        """

    def adjDxgeo(self, p, q, t, X):
        """Adjoint of the differential of the geodesic gamma from p to q w.r.t. its start point p,
        i.e. the adjoint of d_p gamma(t; ., q) applied to X; the result lies in the tangent space at p.
        """
        return self.adjJacobi(p, q, t, X)

    def adjDygeo(self, p, q, t, X):
        """Adjoint of the differential of the geodesic gamma from p to q w.r.t. its end point q,
        i.e. the adjoint of d_q gamma(t; p, .) applied to X; the result lies in the tangent space at q.
        """
        # the geodesic traversed backwards: from q to p, evaluated at parameter 1 - t
        reversed_t = 1 - t
        return self.adjJacobi(q, p, reversed_t, X)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _eval_adjJacobi_embed(g: 'Metric', p, q, t, X):
    """Implementation of adjJacobi for isometrically embedded manifolds using (reverse-mode) automatic
    differentiation of geopoint(..).

    The map ``p -> g.geopoint(p, q, t)`` is differentiated with :func:`jax.vjp`; X is lowered with
    ``g.flat`` at the geodesic point, pulled back through the vjp, and the resulting covector is
    raised with ``g.sharp`` at p.

    ATTENTION: the result must be projected to the tangent space!

    :param g: metric providing ``geopoint``, ``flat`` and ``sharp``
    :param p: start point of the geodesic
    :param q: end point of the geodesic
    :param t: scalar geodesic parameter (presumably in [0, 1])
    :param X: tangent vector at the geodesic point gam(t)
    :return: vector at p (not yet projected onto the tangent space)
    """
    def geo_from(start):
        return g.geopoint(start, q, t)

    gam, pullback = jax.vjp(geo_from, p)
    co_X = g.flat(gam, X)
    return g.sharp(p, pullback(co_X)[0])
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import jax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
from ..geom import Surface
|
|
18
|
+
from . import ShapeSpace, Metric
|
|
19
|
+
from .util import align
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PointDistributionModel(ShapeSpace, Metric):
    """ Linear manifold space model.

    Shapes are arrays of vertex coordinates shaped like those of a reference surface.
    The geometry is flat: ``exp(p, X) = p + X``, ``log(p, q) = q - p`` and the inner product
    is the dot product of the flattened arrays. The instance is passed as its own connection
    and metric.
    """

    def __init__(self, reference: Surface):
        """
        :arg reference: Reference surface (shapes will be encoded as deformations thereof)
        """
        assert reference is not None
        self.ref = reference

        name = 'Point Distribution Model'
        dimension = reference.v.size
        point_shape = reference.v.shape
        # self serves as connection and metric; there is no group structure
        super().__init__(name, dimension, point_shape, self, self, None)

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration.

        There are no children; the reference vertices and faces (as nested lists) form the
        auxiliary data.
        """
        # NOTE(review): JAX expects aux_data to be hashable; nested lists are not, which may
        # break when a model is passed as an argument to jitted functions -- confirm whether
        # Surface also accepts tuples.
        return tuple(), (self.ref.v.tolist(), self.ref.f.tolist())

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        return cls(Surface(*aux_data))

    @property
    def typicaldist(self):
        """Standard deviation of the reference vertex coordinates."""
        return self.ref.v.std()

    def update_ref_geom(self, v):
        """Replace the reference vertex coordinates (dimension/point_shape are not updated)."""
        self.ref.v=v

    def to_coords(self, v):
        """Coordinates of vertex array v: v aligned to the reference vertices (see util.align)."""
        # align
        return align(v, self.ref.v)

    def from_coords(self, c):
        """Vertex array corresponding to coordinates c (as a NumPy array)."""
        return np.asarray(c)

    @property
    def ref_coords(self):
        """Coordinates of the reference shape."""
        return self.ref.v

    def dist(self, p, q):
        return jnp.sqrt(self.squared_dist(p, q))

    def squared_dist(self, p, q):
        return jnp.sum((p-q)**2)

    def inner(self, p, X, Y):
        """
        :arg X: tangent vector at p (array; it is flattened)
        :arg Y: tangent vector at p (array of the same size; it is flattened)
        :returns: inner product at p between X and Y, i.e. <X,Y>_p
        """
        return X.reshape(-1) @ Y.reshape(-1)

    def flat(self, p, X):
        return X

    def sharp(self, p, dX):
        return dX

    def proj(self, p, X):
        return X

    def egrad2rgrad(self, p, X):
        return X

    def ehess2rhess(self, p, egrad, ehess, X):
        return ehess

    def exp(self, p, X):
        return p + X

    retr = exp

    def log(self, p, q):
        return q - p

    def curvature_tensor(self, p, X, Y, Z):
        """The space is flat, so the curvature tensor vanishes."""
        return self.zerovec()

    def rand(self, key: jax.Array):
        """Random shape: standard-normal vertex coordinates, aligned to the reference."""
        v = jax.random.normal(key, self.ref.v.shape)
        return self.to_coords(v)

    def zerovec(self):
        """Returns the zero vector (same in every tangent space)."""
        return jnp.zeros(self.ref.v.shape)

    def transp(self, p, q, X):
        return X

    def jacobiField(self, p, q, t, X):
        """Evaluates a Jacobi field J (with boundary conditions J(0) = X, J(1) = 0) along the geodesic gam from p to q.
        :param p: point
        :param q: point
        :param t: scalar in [0,1]
        :param X: tangent vector at p
        :return: tangent vector at gam(t)
        """
        return (1-t) * X

    def adjJacobi(self, p, q, t, X):
        """Adjoint of the Jacobi field map X -> (1 - t) X of :meth:`jacobiField`.

        In flat space d_p gam(t; ., q) = (1 - t) Id, which is self-adjoint.
        :param p: point
        :param q: point
        :param t: scalar in [0, 1]
        :param X: vector at gam(p,q,t)
        :return: vector at p
        """
        return (1 - t) * X

    def projToGeodesic(self, p, q, m):
        '''
        :arg p, q: manifold coords defining geodesic p->q.
        :arg m: manifold coords to be projected to p->q.
        :returns: manifold coords of projection of m to p->q
        '''
        # the space is flat, so a single Newton-type step is exact
        return super().projToGeodesic(p, q, m, max_iter=1)

    def coords(self, X):
        """Coordinate map of the tangent space at the identity"""
        return X.reshape(-1)
|