morphomatics 4.1__tar.gz → 4.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {morphomatics-4.1 → morphomatics-4.1.1}/PKG-INFO +1 -1
- morphomatics-4.1.1/morphomatics/manifold/diffeomorphism.py +156 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics.egg-info/PKG-INFO +1 -1
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics.egg-info/SOURCES.txt +1 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/setup.py +1 -1
- {morphomatics-4.1 → morphomatics-4.1.1}/LICENSE +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/README.md +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/correspondence/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/correspondence/convert.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/correspondence/laplacian.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/correspondence/refine.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/correspondence/util.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/geom/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/geom/bezier_spline.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/geom/misc.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/geom/surface.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/graph/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/graph/operators.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/bezierfold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/connection.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/cubic_bezierfold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/differential_coords.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/discrete_ops.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/euclidean.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/fundamental_coords.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/gl_p_coords.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/gl_p_n.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/grassmann.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/hyperbolic_space.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/kendall.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/lie_group.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/manifold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/manopt_wrapper.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/metric.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/point_distribution_model.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/power_manifold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/product_manifold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/se_3.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/shape_space.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/simplex.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/size_and_shape.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/so_3.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/spd.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/sphere.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/tangent_bundle.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/manifold/util.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/nn/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/nn/euclidean_layers.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/nn/flow_layers.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/nn/tangent_layers.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/nn/train.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/nn/wFM_layers.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/opt/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/opt/riemannian_newton_raphson.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/opt/riemannian_steepest_descent.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/biinvariant_dissimilarity_measures.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/biinvariant_regression.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/exponential_barycenter.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/geometric_median.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/principal_geodesic_analysis.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/riemannian_regression.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/statistical_shape_model.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics.egg-info/dependency_links.txt +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics.egg-info/requires.txt +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/morphomatics.egg-info/top_level.txt +0 -0
- {morphomatics-4.1 → morphomatics-4.1.1}/setup.cfg +0 -0
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
|
|
5
|
+
from jax.experimental.ode import odeint
|
|
6
|
+
|
|
7
|
+
from morphomatics.manifold import Manifold, LieGroup
|
|
8
|
+
from .util import LazyKernel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Diffeomorphism(Manifold):
    """The diffeomorphism group, i.e. the manifold of smooth invertible automorphisms of ambient space.

    Diffeomorphisms model plausible deformations of the ambient space
    and, hence, of any objects embedded therein. The parameter is a
    discrete vector field (referred to as the momentum) attached to a fixed
    set of control points. The diffeomorphism is given by integrating the
    smooth vector field obtained from the momentum by (Gaussian kernel) smoothing.
    """

    def __init__(self, control_pts: jax.Array, scale: float = 0.1, structure="Group"):
        """Initialize the Diffeomorphism manifold.

        Args:
            control_pts (jax.Array): Control points defining the diffeomorphism,
                a 2D array of shape (n_points, n_features).
            scale (float): Scale of the Gaussian kernel, i.e. the kernel is
                exp(-|a - b|^2 / (2 * scale^2)).
            structure (str | None): Name of the structure to initialize; the method
                ``init{structure}Structure`` is invoked. Pass a falsy value to skip.

        Raises:
            ValueError: If ``control_pts`` is not a 2D array.
        """
        if control_pts.ndim != 2:
            raise ValueError("Control points must be a 2D array (shape: [n_points, n_features]).")
        self._control_pts = control_pts
        self._scale = scale
        # one momentum vector per control point -> as many dofs as control-point coordinates
        dim = int(np.prod(control_pts.shape))

        name = 'Diffeomorphism group'
        super().__init__(name, dim, point_shape=control_pts.shape)

        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Flatten for PyTree registration; control points and kernel scale become leaves."""
        children, aux = super().tree_flatten()
        return children + (self._control_pts, self._scale), aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        # control points and scale were appended last by tree_flatten
        *children, pts, scale = children
        # NOTE(review): this re-runs __init__ (including the ndim check); JAX may
        # unflatten with placeholder leaves in some transformations -- confirm safe.
        obj = cls(pts, scale, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initGroupStructure(self):
        """Initialize the group structure of the diffeomorphism manifold."""
        self._group = Diffeomorphism.GroupStructure(self)

    def proj(self, p, X):
        """Project X onto the tangent space at p (identity: momenta are unconstrained)."""
        return X

    def rand(self, key: jax.Array):
        """Draw a random point (momentum) with i.i.d. standard normal entries."""
        return jax.random.normal(key, self.point_shape)

    def randvec(self, X, key: jax.Array):
        """Draw a random tangent vector with i.i.d. standard normal entries (X is unused)."""
        return jax.random.normal(key, self.point_shape)

    def zerovec(self, X):
        """Return the zero tangent vector with the same shape as X."""
        return jnp.zeros_like(X)

    class GroupStructure(LieGroup):
        """Group structure for the diffeomorphism manifold."""

        def __init__(self, M: 'Diffeomorphism'):
            super().__init__(M)

        def __str__(self):
            return "Group structure on diffeomorphism manifold"

        @property
        def identity(self):
            """Identity element: the zero momentum on all control points."""
            return jnp.zeros_like(self._M._control_pts)

        def coords(self, X):
            """Coordinates of X (momenta are used directly as coordinates)."""
            return X

        def coords_inv(self, X):
            """Inverse of :meth:`coords` (identity)."""
            return X

        def bracket(self, X, Y):
            raise NotImplementedError("Bracket operation is not implemented for diffeomorphisms.")

        def lefttrans(self, g, f):
            raise NotImplementedError("Left translation is not implemented for diffeomorphisms.")

        def righttrans(self, g, f):
            raise NotImplementedError("Right translation is not implemented for diffeomorphisms.")

        def inverse(self, g):
            raise NotImplementedError("Inverse is not implemented for diffeomorphisms.")

        def exp(self, *argv):
            """Group exponential; the two-argument (CCS connection) form is not supported."""
            if len(argv) == 2:
                raise NotImplementedError("CCS connection exponential map is not implemented for diffeomorphisms.")
            # we represent diffeomorphisms via momenta, so just return input momentum
            return argv[0]

        def log(self, *argv):
            """Group logarithm; the two-argument (CCS connection) form is not supported."""
            if len(argv) == 2:
                raise NotImplementedError("CCS connection logarithmic map is not implemented for diffeomorphisms.")
            # we represent diffeomorphisms via momenta, so just return input diffeomorphism
            return argv[0]

        def retr(self, p, X):
            raise NotImplementedError("Retraction is not implemented for diffeomorphisms.")

        def curvature_tensor(self, f, X, Y, Z):
            raise NotImplementedError("Curvature tensor is not implemented for diffeomorphisms.")

        def transp(self, f, g, X):
            raise NotImplementedError("Parallel transport is not implemented for diffeomorphisms.")

        def adjrep(self, g, X):
            raise NotImplementedError("Adjoint representation is not implemented for diffeomorphisms.")

        def jacobiField(self, p, q, t, X):
            raise NotImplementedError("Jacobi field is not implemented for diffeomorphisms.")

        def action(self, g, x, t=None):
            """Apply the diffeomorphism to a point set x in the ambient space.

            Momenta ``v`` (initialized with ``g``) and control points ``c`` evolve by
            Hamilton's equations for H(v, c) = 1/2 * <v, K(c, c) v> with a Gaussian
            kernel K. The points ``x`` are advected by the velocity K(p, c) v.

            Args:
                g (jax.Array): Diffeomorphism (represented by a momentum).
                x (jax.Array): Points in the ambient space to be transformed.
                t (jax.Array, optional): Strictly increasing time points for the
                    integration. Defaults to ``jnp.linspace(0, 1, 2)``, i.e. ``[0, 1]``.
                    NOTE: ``x`` is taken as the state at ``t[0]``. The ODE does not depend
                    on time explicitly, so the result at ``t[i]`` is the flow over the
                    duration ``t[i] - t[0]``; use ``t[0] == 0`` for flow times equal to ``t``.

            Returns:
                jax.Array: Transformed points at every time in ``t``, stacked along a new
                leading time axis (entry 0 is ``x`` itself).
            """
            # build the default lazily: a jnp default argument would be created (and a
            # JAX backend initialized) as soon as this module is imported
            if t is None:
                t = jnp.linspace(0, 1, 2)

            # Gaussian kernel with standard deviation `scale`
            h = 1 / (2 * self._M._scale**2)  # inverse kernel bandwidth
            gaussian = lambda a, b: jnp.exp(-jnp.sum((a - b) ** 2) * h)

            def hamiltonian(v, c):
                """Kinetic energy of momenta v at control points c."""
                return .5 * jnp.sum(v * (LazyKernel(c, c, gaussian) @ v))

            # gradient w.r.t. (momenta, control points); built once, not per ODE step
            grad_hamiltonian = jax.grad(hamiltonian, argnums=(0, 1))

            def F(y, t, *args):
                """ODE right-hand side for the state (momenta, control points, points)."""
                v, c, p = y

                # advect the points by the kernel-smoothed velocity field
                p_dot = LazyKernel(p, c, gaussian) @ v

                # Hamilton's equations: v' = -dH/dc, c' = dH/dv
                Gv, Gc = grad_hamiltonian(v, c)

                return -Gc, Gv, p_dot

            # integrate F
            _, _, x_morphed = odeint(F, (g, self._M._control_pts, x), t)

            return x_morphed
|
|
156
|
+
|
|
@@ -22,6 +22,7 @@ morphomatics/manifold/__init__.py
|
|
|
22
22
|
morphomatics/manifold/bezierfold.py
|
|
23
23
|
morphomatics/manifold/connection.py
|
|
24
24
|
morphomatics/manifold/cubic_bezierfold.py
|
|
25
|
+
morphomatics/manifold/diffeomorphism.py
|
|
25
26
|
morphomatics/manifold/differential_coords.py
|
|
26
27
|
morphomatics/manifold/discrete_ops.py
|
|
27
28
|
morphomatics/manifold/euclidean.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{morphomatics-4.1 → morphomatics-4.1.1}/morphomatics/stats/biinvariant_dissimilarity_measures.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|