morphomatics 4.1__tar.gz → 4.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {morphomatics-4.1 → morphomatics-4.1.2}/PKG-INFO +1 -1
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/__init__.py +2 -2
- morphomatics-4.1.2/morphomatics/manifold/diffeomorphism.py +167 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/biinvariant_dissimilarity_measures.py +80 -77
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics.egg-info/PKG-INFO +1 -1
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics.egg-info/SOURCES.txt +1 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/setup.py +1 -1
- {morphomatics-4.1 → morphomatics-4.1.2}/LICENSE +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/README.md +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/correspondence/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/correspondence/convert.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/correspondence/laplacian.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/correspondence/refine.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/correspondence/util.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/geom/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/geom/bezier_spline.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/geom/misc.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/geom/surface.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/graph/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/graph/operators.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/bezierfold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/connection.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/cubic_bezierfold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/differential_coords.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/discrete_ops.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/euclidean.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/fundamental_coords.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/gl_p_coords.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/gl_p_n.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/grassmann.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/hyperbolic_space.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/kendall.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/lie_group.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/manifold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/manopt_wrapper.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/metric.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/point_distribution_model.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/power_manifold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/product_manifold.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/se_3.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/shape_space.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/simplex.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/size_and_shape.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/so_3.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/spd.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/sphere.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/tangent_bundle.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/manifold/util.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/nn/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/nn/euclidean_layers.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/nn/flow_layers.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/nn/tangent_layers.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/nn/train.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/nn/wFM_layers.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/opt/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/opt/riemannian_newton_raphson.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/opt/riemannian_steepest_descent.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/__init__.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/biinvariant_regression.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/exponential_barycenter.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/geometric_median.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/principal_geodesic_analysis.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/riemannian_regression.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/statistical_shape_model.py +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics.egg-info/dependency_links.txt +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics.egg-info/requires.txt +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/morphomatics.egg-info/top_level.txt +0 -0
- {morphomatics-4.1 → morphomatics-4.1.2}/setup.cfg +0 -0
|
@@ -3,11 +3,11 @@
|
|
|
3
3
|
# This file is part of the Morphomatics library #
|
|
4
4
|
# see https://github.com/morphomatics/morphomatics #
|
|
5
5
|
# #
|
|
6
|
-
# Copyright (C)
|
|
6
|
+
# Copyright (C) 2025 Zuse Institute Berlin #
|
|
7
7
|
# #
|
|
8
8
|
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
9
|
# see $MORPHOMATICS/LICENSE #
|
|
10
10
|
# #
|
|
11
11
|
################################################################################
|
|
12
12
|
|
|
13
|
-
# Package version string.
__version__ = "4.1.2"
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2025 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import jax
|
|
15
|
+
import jax.numpy as jnp
|
|
16
|
+
|
|
17
|
+
from jax.experimental.ode import odeint
|
|
18
|
+
|
|
19
|
+
from morphomatics.manifold import Manifold, LieGroup
|
|
20
|
+
from .util import LazyKernel
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Diffeomorphism(Manifold):
    """The diffeomorphism group, i.e. the manifold of smooth invertible automorphisms of ambient space.

    Diffeomorphisms model plausible deformations of the ambient space
    and, hence, of any objects embedded therein. The parameter is a
    discrete vector field attached to the control points (referred to as the momentum).
    The diffeomorphism is given by integrating the smooth vector field obtained from the
    momentum by (Gaussian kernel) smoothing.
    """

    def __init__(self, control_pts: jax.Array, scale: float = 0.1, structure="Group"):
        """Initialize the Diffeomorphism manifold.

        Args:
            control_pts (jax.Array): Control points defining the diffeomorphism;
                must be 2D (shape: [n_points, n_features]).
            scale (float): Scale of the Gaussian kernel.
            structure (str): Name of the structure to set up; the method
                ``init<structure>Structure`` is called. A falsy value (e.g. None)
                skips the initialization.

        Raises:
            ValueError: If ``control_pts`` is not a 2D array.
        """
        if control_pts.ndim != 2:
            raise ValueError("Control points must be a 2D array (shape: [n_points, n_features]).")
        self._control_pts = control_pts
        self._scale = scale

        # a point (momentum) has the same shape as the control points
        dim = np.prod(control_pts.shape)

        super().__init__('Diffeomorphism group', dim, point_shape=control_pts.shape)

        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Specifies a flattening recipe for PyTree registration.

        Control points and kernel scale are appended to the children of the base class.
        """
        children, aux = super().tree_flatten()
        return children + (self._control_pts, self._scale), aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        # the last two children are the ones appended in tree_flatten
        *children, pts, scale = children
        # structure is not re-initialized here; the remaining children and the aux data
        # are handed to tree_unflatten_instance
        obj = cls(pts, scale, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initGroupStructure(self):
        """Initialize the group structure of the diffeomorphism manifold."""
        self._group = Diffeomorphism.GroupStructure(self)

    def proj(self, p, X):
        """Project X onto the tangent space at p; momenta form a vector space, so X is returned as is."""
        return X

    def rand(self, key: jax.Array):
        """Random point: momentum with i.i.d. standard normal entries."""
        return jax.random.normal(key, self.point_shape)

    def randvec(self, X, key: jax.Array):
        """Random tangent vector at X with i.i.d. standard normal entries."""
        return jax.random.normal(key, self.point_shape)

    def zerovec(self, X):
        """Zero tangent vector with the shape of X."""
        return jnp.zeros_like(X)

    class GroupStructure(LieGroup):
        """Group structure for the diffeomorphism manifold."""

        def __init__(self, M: 'Diffeomorphism'):
            super().__init__(M)

        def __str__(self):
            return "Group structure on diffeomorphism manifold"

        @property
        def identity(self):
            """Identity diffeomorphism, i.e. the zero momentum."""
            return jnp.zeros_like(self._M._control_pts)

        def coords(self, X):
            """Coordinates of a tangent vector (momenta are used directly)."""
            return X

        def coords_inv(self, X):
            """Tangent vector from coordinates (momenta are used directly)."""
            return X

        def bracket(self, X, Y):
            raise NotImplementedError("Bracket operation is not implemented for diffeomorphisms.")

        def lefttrans(self, g, f):
            raise NotImplementedError("Left translation is not implemented for diffeomorphisms.")

        def righttrans(self, g, f):
            raise NotImplementedError("Right translation is not implemented for diffeomorphisms.")

        def inverse(self, g):
            raise NotImplementedError("Inverse is not implemented for diffeomorphisms.")

        def exp(self, *argv):
            """Group exponential map.

            Only the one-argument form exp(X) is supported; the two-argument form
            (CCS connection exponential) raises NotImplementedError.
            """
            if len(argv) == 2:
                raise NotImplementedError("CCS connection exponential map is not implemented for diffeomorphisms.")
            # we represent diffeomorphisms via momenta, so just return input momentum
            return argv[0]

        def log(self, *argv):
            """Group logarithm.

            Only the one-argument form log(g) is supported; the two-argument form
            (CCS connection logarithm) raises NotImplementedError.
            """
            if len(argv) == 2:
                raise NotImplementedError("CCS connection logarithmic map is not implemented for diffeomorphisms.")
            # we represent diffeomorphisms via momenta, so just return input diffeomorphism
            return argv[0]

        def retr(self, p, X):
            raise NotImplementedError("Retraction is not implemented for diffeomorphisms.")

        def curvature_tensor(self, f, X, Y, Z):
            raise NotImplementedError("Curvature tensor is not implemented for diffeomorphisms.")

        def transp(self, f, g, X):
            raise NotImplementedError("Parallel transport is not implemented for diffeomorphisms.")

        def adjrep(self, g, X):
            raise NotImplementedError("Adjoint representation is not implemented for diffeomorphisms.")

        def jacobiField(self, p, q, t, X):
            raise NotImplementedError("Jacobi field is not implemented for diffeomorphisms.")

        def action(self, g, x, t=None):
            """Apply the diffeomorphism to a point set x in the ambient space.

            The momentum g and the control points evolve according to Hamilton's equations
            for H(v, c) = 1/2 <v, K(c, c) v> (K: Gaussian kernel), while x is carried along
            by the kernel-smoothed velocity field.

            Args:
                g (jax.Array): Diffeomorphism (represented by a momentum, one vector per control point).
                x (jax.Array): Points in the ambient space to be transformed.
                t (jax.Array, optional): Time points for the integration (strictly increasing);
                    defaults to ``jnp.linspace(0, 1, 2)``, i.e. [0, 1].
                    NOTE: The initial state is placed at t[0]. The dynamics do not depend on time
                    explicitly, so only the offsets t - t[0] matter (use t[0] = 0 for the time-1 flow).
            Returns:
                jax.Array: Transformed points at every time in t (new leading axis of length len(t));
                the first entry equals x.
            """
            if t is None:
                # built on demand rather than as a default argument, so that no JAX array
                # is allocated when the class is defined (i.e. at import time)
                t = jnp.linspace(0, 1, 2)

            # Gaussian kernel k(a, b) = exp(-|a - b|^2 / (2 scale^2))
            h = 1 / (2 * self._M._scale ** 2)  # inverse kernel bandwidth
            gaussian = lambda a, b: jnp.exp(-jnp.sum((a - b) ** 2) * h)

            def F(y, t, *args):
                """ODE right-hand side for the state (momentum v, control points c, points p)."""
                v, c, p = y

                # the points are advected by the kernel-smoothed velocity field
                p_dot = LazyKernel(p, c, gaussian) @ v

                # Hamiltonian function; Hamilton's equations: v' = -dH/dc, c' = dH/dv
                H = lambda v, c: .5 * jnp.sum(v * (LazyKernel(c, c, gaussian) @ v))
                Gv, Gc = jax.grad(H, argnums=(0, 1))(v, c)

                return -Gc, Gv, p_dot

            # integrate F; only the trajectory of the transported points is returned
            _, _, x_morphed = odeint(F, (g, self._M._control_pts, x), t)

            return x_morphed
|
{morphomatics-4.1 → morphomatics-4.1.2}/morphomatics/stats/biinvariant_dissimilarity_measures.py
RENAMED
|
@@ -12,12 +12,13 @@
|
|
|
12
12
|
|
|
13
13
|
# postponed evaluation of annotations to circumvent cyclic dependencies (will be default behavior in Python 4.0)
|
|
14
14
|
from __future__ import annotations
|
|
15
|
+
from typing import Tuple
|
|
16
|
+
|
|
17
|
+
import jax
|
|
15
18
|
|
|
16
19
|
import jax.numpy as jnp
|
|
17
20
|
import jax.numpy.linalg as jla
|
|
18
|
-
from jax import random
|
|
19
21
|
|
|
20
|
-
# from morphomatics.manifold import LieGroup
|
|
21
22
|
from morphomatics.stats import ExponentialBarycenter as Mean
|
|
22
23
|
|
|
23
24
|
|
|
@@ -27,7 +28,7 @@ class BiinvariantDissimilarityMeasures(object):
|
|
|
27
28
|
Hanik, Hege, and von Tycowicz (2020): Bi-invariant Two-Sample Tests in Lie Groups for Shape Analysis
|
|
28
29
|
"""
|
|
29
30
|
|
|
30
|
-
def __init__(self, G: Manifold, variant='left'):
|
|
31
|
+
def __init__(self, G: Manifold, variant: str ='left'):
|
|
31
32
|
"""
|
|
32
33
|
:param G: Lie group
|
|
33
34
|
:param variant: indicate whether all tangent vectors are left (variants='left') or right (variants='right')
|
|
@@ -36,9 +37,17 @@ class BiinvariantDissimilarityMeasures(object):
|
|
|
36
37
|
|
|
37
38
|
self.G = G
|
|
38
39
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
40
|
+
if variant == 'left':
|
|
41
|
+
self.translation = G.group.lefttrans
|
|
42
|
+
else:
|
|
43
|
+
self.translation = G.group.righttrans
|
|
44
|
+
|
|
45
|
+
def two_sample_test(self,
|
|
46
|
+
data_A: jnp.array,
|
|
47
|
+
data_B: jnp.array,
|
|
48
|
+
measure: str,
|
|
49
|
+
n_permutations: int,
|
|
50
|
+
key: jax.random.key) -> Tuple[jnp.array, jnp.array, jnp.array]:
|
|
42
51
|
"""Bi-invariant two-sample permutation test for data in G.
|
|
43
52
|
Null hypothesis: 'Means of distributions underlying the 2 data sets are equal' if Hotelling T2 statistic is used
|
|
44
53
|
'Means and covariance underlying the 2 data sets are equal' if Bhattacharyya distance is used
|
|
@@ -46,6 +55,7 @@ class BiinvariantDissimilarityMeasures(object):
|
|
|
46
55
|
:param data_B: data array of second set; data is sorted along first axis
|
|
47
56
|
:param measure: indicate which measure to use; 'hotelling' and 'bhattacharyya' are possible
|
|
48
57
|
:param n_permutations: number of permutations performed for the test
|
|
58
|
+
:param key: random key
|
|
49
59
|
:return: p-value, original distance d_orig between data, vector d-perm of distances between permuted data sets
|
|
50
60
|
"""
|
|
51
61
|
if measure == 'hotelling':
|
|
@@ -59,82 +69,73 @@ class BiinvariantDissimilarityMeasures(object):
|
|
|
59
69
|
d_orig = distMeasure(data_A, data_B)
|
|
60
70
|
|
|
61
71
|
D = jnp.concatenate((data_A, data_B), axis=0)
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
# def shuffle(A):
|
|
67
|
-
# A_perm = np.random.permutation(A)
|
|
68
|
-
# return distMeasure(A_perm[:n], A_perm[n:])
|
|
69
|
-
#
|
|
70
|
-
# with Parallel(n_jobs=-1, prefer='threads', verbose=0) as parallel:
|
|
71
|
-
# d_perm = parallel(delayed(shuffle)(D) for _ in range(n_permutations))
|
|
72
|
-
|
|
73
|
-
key = random.PRNGKey(0)
|
|
74
|
-
# permute and recompute
|
|
75
|
-
for i in range(n_permutations):
|
|
76
|
-
key, subkey = random.split(key)
|
|
77
|
-
# permute along first axis
|
|
78
|
-
D_perm = random.permutation(key, D)
|
|
72
|
+
|
|
73
|
+
def permute_and_recompute(key_):
|
|
74
|
+
# mix data
|
|
75
|
+
D_perm = jax.random.permutation(key_, D)
|
|
79
76
|
# distance between shuffled groups
|
|
80
|
-
|
|
81
|
-
d_perm.append(d_perm_i)
|
|
77
|
+
return distMeasure(D_perm[:n], D_perm[n:])
|
|
82
78
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
79
|
+
# vectorize
|
|
80
|
+
random_keys = jax.random.split(key, n_permutations)
|
|
81
|
+
d_perm = jax.vmap(permute_and_recompute)(random_keys)
|
|
86
82
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
counter = counter + 1
|
|
83
|
+
# p-value, i.e., approximate probability of observing d_orig under the null hypothesis
|
|
84
|
+
p_value = jnp.count_nonzero(d_perm > d_orig) / (n_permutations + 1)
|
|
90
85
|
|
|
91
|
-
return
|
|
86
|
+
return p_value, d_orig, d_perm
|
|
92
87
|
|
|
93
|
-
def groupmean(self, data: jnp.array, max_iter: int = 100) -> jnp.array:
    """Group mean (exponential barycenter) of data in G.

    :param data: array of (sufficiently close) elements in G
    :param max_iter: maximum number of iterations passed to the barycenter computation
                     (default 100, the previously hard-coded value)
    :return: group mean of data points
    """
    return Mean.compute(self.G, data, max_iter=max_iter)
|
|
99
94
|
|
|
100
|
-
def mahalanobisdist(self, A: jnp.array, g: jnp.array) -> jnp.array:
    """ Bi-invariant Mahalanobis distance in G

    The difference between the group mean of A and g is taken near the identity,
    expressed in coordinates, and measured w.r.t. the inverse sample covariance of A.

    :param A: array of data points in G
    :param g: element in G
    :return: Mahalanobis distance of g to the distribution of the data points in A
    """
    cov, mu = self.centralized_sample_covariance(A)
    offset = self.G.group.coords(self.diff_at_e(mu, g))

    # S^{-1} c via a linear solve instead of an explicit inverse
    whitened = jla.solve(cov, offset)
    squared_dist = jnp.inner(offset.squeeze(), whitened.squeeze())
    return jnp.sqrt(squared_dist)
|
|
111
107
|
|
|
112
|
-
def hotellingT2(self, A: jnp.array, B: jnp.array) -> jnp.array:
    """ Bi-invariant Hotelling T^2 statistic in G

    T^2 = m n / (m + n) * c^T S_pool^{-1} c, where c are the coordinates of the
    difference of the two group means taken near the identity.

    :param A: array of data points in G
    :param B: array of data points in G
    :return: Hotelling T^2 statistic between the distribution of the samples in A and B
    """
    size_a, size_b = len(A), len(B)
    S_pool, _, _, mean_A, mean_B = self.pooled_sample_covariance(A, B)

    delta = self.G.group.coords(self.diff_at_e(mean_A, mean_B))
    solved = jla.solve(S_pool, delta)

    weight = size_a * size_b / (size_a + size_b)
    return weight * jnp.inner(delta.squeeze(), solved.squeeze())
|
|
123
121
|
|
|
124
|
-
def bhattacharyya(self, A: jnp.array, B: jnp.array) -> jnp.array:
    """ Bi-invariant Bhattacharyya distance in G

    D_B = 1/8 c^T S_avg^{-1} c + 1/2 log(det(S_avg) / sqrt(det(S_A) det(S_B))),
    where c are the coordinates of the difference of the group means (taken near the
    identity) and S_A, S_B, S_avg are the centralized and averaged sample covariances.
    The log-determinant term is evaluated via slogdet, because the determinants
    themselves easily under- or overflow for higher-dimensional groups.

    :param A: array of data points in G
    :param B: array of data points in G
    :return: Bhattacharyya distance between the distribution of the samples in A and B;
             NaN if one of the covariance determinants is negative
    """
    S_avg, S_A, S_B, mean_A, mean_B = self.averaged_sample_covariance(A, B)

    c = self.G.group.coords(self.diff_at_e(mean_A, mean_B))
    x = jla.solve(S_avg, c)

    sign_avg, logdet_avg = jla.slogdet(S_avg)
    sign_A, logdet_A = jla.slogdet(S_A)
    sign_B, logdet_B = jla.slogdet(S_B)
    # log(det(S_avg) / sqrt(det(S_A) * det(S_B))) without forming the determinants;
    # singular matrices yield logdet = -inf, giving the same +-inf/NaN as the plain ratio
    log_ratio = logdet_avg - 0.5 * (logdet_A + logdet_B)
    # negative determinants do not stem from valid covariances (jnp.where stays jit/vmap-friendly)
    invalid = (sign_avg < 0) | (sign_A < 0) | (sign_B < 0)
    log_ratio = jnp.where(invalid, jnp.nan, log_ratio)

    D_B = 1 / 8 * jnp.inner(c.squeeze(), x.squeeze()) + 1 / 2 * log_ratio

    return D_B
|
|
137
|
+
|
|
138
|
+
def centralized_sample_covariance(self, A: jnp.array) -> jnp.array:
|
|
138
139
|
""" Centralized sample covariance of G–valued data
|
|
139
140
|
:param A: array of data points in G
|
|
140
141
|
:return: covariance matrix defined on (coordinate representations of) tangent vectors at the identity
|
|
@@ -142,18 +143,22 @@ class BiinvariantDissimilarityMeasures(object):
|
|
|
142
143
|
m = len(A)
|
|
143
144
|
# mean of data
|
|
144
145
|
mean = self.groupmean(A)
|
|
145
|
-
# inverse only once
|
|
146
|
+
# compute the inverse only once
|
|
146
147
|
mean_inv = self.G.group.inverse(mean)
|
|
147
|
-
# set up covariance matrix
|
|
148
|
-
C = jnp.zeros((self.G.dim, self.G.dim))
|
|
149
148
|
|
|
150
|
-
|
|
151
|
-
x = self.
|
|
152
|
-
|
|
149
|
+
def outer_prod(a):
|
|
150
|
+
x = self.translation(a, mean_inv)
|
|
151
|
+
x = self.G.group.coords(self.G.group.log(x))
|
|
152
|
+
return jnp.outer(x, x)
|
|
153
|
+
|
|
154
|
+
# covariance matrix
|
|
155
|
+
S = jax.vmap(outer_prod)(A)
|
|
156
|
+
S = S.sum(axis=0) / m
|
|
153
157
|
|
|
154
|
-
return
|
|
158
|
+
return S, mean
|
|
155
159
|
|
|
156
|
-
def pooled_sample_covariance(self, A, B)
|
|
160
|
+
def pooled_sample_covariance(self, A: jnp.array, B: jnp.array) \
|
|
161
|
+
-> Tuple[jnp.array, jnp.array, jnp.array, jnp.array, jnp.array]:
|
|
157
162
|
"""Pooled sample covariance of two data sets in G.
|
|
158
163
|
:param A: array of data points
|
|
159
164
|
:param B: array of data points
|
|
@@ -161,33 +166,31 @@ class BiinvariantDissimilarityMeasures(object):
|
|
|
161
166
|
"""
|
|
162
167
|
m, n = len(A), len(B)
|
|
163
168
|
|
|
164
|
-
|
|
165
|
-
|
|
169
|
+
S_A, mean_A = self.centralized_sample_covariance(A)
|
|
170
|
+
S_B, mean_B = self.centralized_sample_covariance(B)
|
|
166
171
|
|
|
167
|
-
|
|
168
|
-
return
|
|
172
|
+
S_pool = 1 / (m + n - 2) * (m * S_A + n * S_B)
|
|
173
|
+
return S_pool, S_A, S_B, mean_A, mean_B
|
|
169
174
|
|
|
170
|
-
def averaged_sample_covariance(self, A: jnp.array, B: jnp.array) \
        -> Tuple[jnp.array, jnp.array, jnp.array, jnp.array, jnp.array]:
    """Averaged sample covariance of two data sets in G.

    The centralized sample covariances of A and B are computed separately (A first)
    and their arithmetic mean is formed.

    :param A: array of data points
    :param B: array of data points
    :return: covariance operator acting on vectors in the tangent space at the identity,
             followed by the individual covariances and group means of A and B
    """
    (S_A, mean_A), (S_B, mean_B) = (self.centralized_sample_covariance(X) for X in (A, B))

    S_avg = (S_A + S_B) / 2
    return S_avg, S_A, S_B, mean_A, mean_B
|
|
181
187
|
|
|
182
|
-
def diff_at_e(self, f: jnp.array, g: jnp.array) -> jnp.array:
    """ "Difference vector" between two elements in G after translating to a neighborhood of the
    identity e.

    g is translated by f^(-1) (left or right, as selected by the variant) and the group
    logarithm of the result is returned.

    :param f: element of G
    :param g: element of G
    :return: group logarithm after translating with f^(-1).
    """
    group = self.G.group
    translated = self.translation(g, group.inverse(f))
    return group.log(translated)
|
|
@@ -22,6 +22,7 @@ morphomatics/manifold/__init__.py
|
|
|
22
22
|
morphomatics/manifold/bezierfold.py
|
|
23
23
|
morphomatics/manifold/connection.py
|
|
24
24
|
morphomatics/manifold/cubic_bezierfold.py
|
|
25
|
+
morphomatics/manifold/diffeomorphism.py
|
|
25
26
|
morphomatics/manifold/differential_coords.py
|
|
26
27
|
morphomatics/manifold/discrete_ops.py
|
|
27
28
|
morphomatics/manifold/euclidean.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|