morphomatics 4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphomatics/__init__.py +13 -0
- morphomatics/geom/__init__.py +16 -0
- morphomatics/geom/bezier_spline.py +361 -0
- morphomatics/geom/misc.py +104 -0
- morphomatics/geom/surface.py +208 -0
- morphomatics/graph/__init__.py +13 -0
- morphomatics/graph/operators.py +124 -0
- morphomatics/manifold/__init__.py +46 -0
- morphomatics/manifold/bezierfold.py +500 -0
- morphomatics/manifold/connection.py +105 -0
- morphomatics/manifold/cubic_bezierfold.py +305 -0
- morphomatics/manifold/differential_coords.py +197 -0
- morphomatics/manifold/discrete_ops.py +56 -0
- morphomatics/manifold/euclidean.py +213 -0
- morphomatics/manifold/fundamental_coords.py +440 -0
- morphomatics/manifold/gl_p_coords.py +149 -0
- morphomatics/manifold/gl_p_n.py +201 -0
- morphomatics/manifold/grassmann.py +174 -0
- morphomatics/manifold/hyperbolic_space.py +271 -0
- morphomatics/manifold/kendall.py +269 -0
- morphomatics/manifold/lie_group.py +102 -0
- morphomatics/manifold/manifold.py +162 -0
- morphomatics/manifold/manopt_wrapper.py +185 -0
- morphomatics/manifold/metric.py +110 -0
- morphomatics/manifold/point_distribution_model.py +143 -0
- morphomatics/manifold/power_manifold.py +413 -0
- morphomatics/manifold/product_manifold.py +381 -0
- morphomatics/manifold/se_3.py +419 -0
- morphomatics/manifold/shape_space.py +57 -0
- morphomatics/manifold/so_3.py +494 -0
- morphomatics/manifold/spd.py +524 -0
- morphomatics/manifold/sphere.py +241 -0
- morphomatics/manifold/tangent_bundle.py +337 -0
- morphomatics/manifold/util.py +126 -0
- morphomatics/nn/__init__.py +15 -0
- morphomatics/nn/flow_layers.py +219 -0
- morphomatics/nn/tangent_layers.py +176 -0
- morphomatics/nn/train.py +202 -0
- morphomatics/nn/wFM_layers.py +152 -0
- morphomatics/opt/__init__.py +14 -0
- morphomatics/opt/riemannian_newton_raphson.py +65 -0
- morphomatics/opt/riemannian_steepest_descent.py +61 -0
- morphomatics/stats/__init__.py +18 -0
- morphomatics/stats/biinvariant_statistics.py +190 -0
- morphomatics/stats/exponential_barycenter.py +78 -0
- morphomatics/stats/geometric_median.py +89 -0
- morphomatics/stats/principal_geodesic_analysis.py +135 -0
- morphomatics/stats/riemannian_regression.py +317 -0
- morphomatics/stats/statistical_shape_model.py +99 -0
- morphomatics-4.0.dist-info/LICENSE +9 -0
- morphomatics-4.0.dist-info/METADATA +55 -0
- morphomatics-4.0.dist-info/RECORD +54 -0
- morphomatics-4.0.dist-info/WHEEL +5 -0
- morphomatics-4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import jax.random
|
|
14
|
+
import numpy as np
|
|
15
|
+
import jax.numpy as jnp
|
|
16
|
+
|
|
17
|
+
from scipy import sparse
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
from sksparse.cholmod import cholesky as direct_solve
|
|
21
|
+
except:
|
|
22
|
+
from scipy.sparse.linalg import factorized as direct_solve
|
|
23
|
+
|
|
24
|
+
from ..geom import Surface
|
|
25
|
+
from . import GLpn
|
|
26
|
+
from . import ShapeSpace
|
|
27
|
+
from .util import align
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GLpCoords(ShapeSpace):
    """
    Shape space based on the group of matrices with positive determinant.

    Shapes are encoded as deformations of a fixed reference surface: to_coords() maps vertex positions to one
    3-by-3 matrix per block of the gradient operator (derived from the deformation gradients), and from_coords()
    recovers vertex positions by solving a Poisson system.

    See:
    Felix Ambellan, Stefan Zachow and Christoph von Tycowicz.
    An as-invariant-as-possible GL+(3)-based statistical shape model.
    Proc. 7th MICCAI workshop on Mathematical Foundations of Computational Anatomy, pp. 219--228, 2019.
    """

    def __init__(self, reference: Surface):
        """
        :arg reference: Reference surface (shapes will be encoded as deformations thereof)
        :raises ValueError: if no reference surface is given
        """
        # explicit check instead of `assert` (asserts are stripped when running with -O)
        if reference is None:
            raise ValueError('GLpCoords requires a reference surface.')
        self.ref = reference
        k = len(self.ref.f)

        # compute reference-dependent quantities (center of gravity, Poisson solver)
        self.update_ref_geom(self.ref.v)

        # NOTE(review): GLpn(k) constructs GL+(k), i.e. a single k-by-k matrix group with point_shape (k, k),
        # where k is the number of faces. to_coords()/from_coords(), however, operate on per-face 3-by-3
        # matrices (reshape(-1, 3, 3)), so dim / point_shape / ref_coords / rand of this space do not match the
        # coordinates produced by to_coords(). A product of k copies of GL+(3) looks intended -- confirm.
        self.GLp = GLpn(k, structure='AffineGroup')

        name = 'Shape Space based on the orientation preserving component of the general linear group'
        super().__init__(name, self.GLp.dim, self.GLp.point_shape, self.GLp.connec, None, self.GLp.group)

    def tree_flatten(self):
        """Flatten for PyTree registration: no array children; the reference geometry is static aux data."""
        # JAX hashes and compares aux_data (e.g. as part of jit cache keys), so the reference geometry is stored
        # as nested tuples -- nested lists would raise "unhashable type: 'list'".
        v = tuple(map(tuple, self.ref.v.tolist()))
        f = tuple(map(tuple, self.ref.f.tolist()))
        return tuple(), (v, f)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        # convert the (hashable) tuples back into arrays, the form the reference geometry had when flattened
        return cls(Surface(*(np.asarray(a) for a in aux_data)))

    @property
    def M(self):
        """Underlying group manifold."""
        return self.GLp

    @property
    def n_triangles(self):
        """Number of faces of the reference surface."""
        return len(self.ref.f)

    def update_ref_geom(self, v):
        """Replace the reference vertex positions and recompute the quantities derived from them.

        :arg v: #v-by-3 array of reference vertex coordinates
        """
        self.ref.v = v

        # center of gravity
        self.CoG = self.ref.v.mean(axis=0)

        # setup Poisson system
        S = self.ref.div @ self.ref.grad
        # add soft-constraint fixing translational DoF
        S += sparse.coo_matrix(([1.0], ([0], [0])), S.shape)  # make pos-def
        self.poisson = direct_solve(S.tocsc())

    def to_coords(self, v):
        """
        :arg v: #v-by-3 array of vertex coordinates
        :return: GLp coords. (stack of 3-by-3 matrices)
        """

        # align
        v = align(v, self.ref.v)

        # compute gradients
        D = self.ref.grad @ v

        # D holds transpose of def. grads.
        # decompose...
        U, S, Vt = np.linalg.svd(D.reshape(-1, 3, 3))

        # ...rotation: U @ Vt, with the last singular direction flipped where needed so that det(R) = +1
        R = np.einsum('...ij,...jk', U, Vt)
        W = np.ones_like(S)
        W[:, -1] = np.linalg.det(R)
        R = np.einsum('...ij,...j,...jk', U, W, Vt)

        # ...stretch
        S[:, -1] = 1  # no stretch (=1) in normal direction
        # for degenerate triangles
        # TODO: check which direction is normal in degenerate case
        S[S < 1e-6] = 1e-6
        U = np.einsum('...ij,...j,...kj', U, S, U)
        return jnp.einsum('...ij,...jl', R, U)

    def from_coords(self, D):
        """
        :arg D: GLp coords.
        :returns: #v-by-3 array of vertex coordinates
        """
        # solve Poisson system
        rhs = self.ref.div @ D.reshape(-1, 3)
        v = self.poisson(rhs)
        # move to CoG (the Poisson solution is only determined up to the soft translational constraint)
        v += self.CoG - v.mean(axis=0)

        return v

    @property
    def ref_coords(self):
        """ Identity coordinates (i.e., the reference shape).
        """
        return self.group.identity

    def rand(self, key: jax.Array):
        """Random set of coordinates (won't represent a 'nice' shape).
        """
        return self.GLp.rand(key)

    def randvec(self, A, key: jax.Array):
        """Random tangent vector at A (delegates to the underlying group)."""
        return self.GLp.randvec(A, key)

    def zerovec(self):
        """Zero tangent vector in any tangent space.
        """
        return self.GLp.zerovec()

    def proj(self, p, X):
        """Projection onto the tangent space at p (identity, as tangent spaces are full matrix spaces)."""
        return X

    def geopoint(self, A, B, t):
        """Point at time t on the geodesic from A to B of the underlying group's connection."""
        return self.GLp.connec.geopoint(A, B, t)
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
import jax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
from jax.scipy.linalg import expm, funm
|
|
16
|
+
|
|
17
|
+
from morphomatics.manifold import Manifold, LieGroup, Connection
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class GLpn(Manifold):
    """Returns the Lie group GL^+(n), i.e., the group of n-by-n matrices each with positive determinant.

        manifold = GLpn(n)

    Elements of GL^+(n) are represented as arrays of size nxn.

    # NOTE: Tangent vectors are represented as left translations in the Lie algebra, i.e., a tangent vector X at g is
    represented as d_gL_{g^(-1)}(X)
    # NOTE(review): exp(g, X) / log(g, h) of AffineGroupStructure use right multiplication (exp_g(X) = expm(X) g),
    which corresponds to right-translated vectors, whereas transp() assumes left translation -- confirm convention.
    """

    def __init__(self, n=3, structure='AffineGroup'):
        self._n = n

        name = 'Orientation preserving maps of R^n'

        super().__init__(name, n**2, point_shape=(n, n))

        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Flatten for PyTree registration; n is appended to the base class's aux data."""
        children, aux = super().tree_flatten()
        return children, aux+(self._n,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *aux_data, n = aux_data
        obj = cls(n, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    @property
    def n(self):
        """Matrix size n."""
        return self._n

    def rand(self, key: jax.Array):
        """Returns a random point in the Lie group. This does not
        follow a specific distribution."""
        A = jax.random.normal(key, self.point_shape)
        return expm(A)

    def randvec(self, A, key: jax.Array):
        """Returns a random vector in the tangent space at A.
        """
        return jax.random.normal(key, self.point_shape)

    def zerovec(self):
        """Returns the zero vector in the tangent space at g."""
        return jnp.zeros((self.n, self.n))

    def proj(self, p, X):
        """Projection onto the tangent space at p (identity, as tangent spaces are the full matrix space)."""
        return X

    def initAffineGroupStructure(self):
        """
        Standard group structure with canonical Cartan Shouten connection.
        """
        structure = GLpn.AffineGroupStructure(self)
        self._connec = structure
        self._group = structure

    class AffineGroupStructure(Connection, LieGroup):
        """
        Standard group structure on GL+(n) where the composition of two elements is given by component-wise matrix
        multiplication. The connection is the corresponding canonical Cartan Shouten (CCS) connection. No Riemannian
        metric is used.
        """

        def __init__(self, M):
            """
            Constructor.
            """
            self._M = M

        def __str__(self):
            return 'standard group structure on GL+(n) with CCS connection'

        # Group

        @property
        def identity(self):
            """Returns the identity element e of the Lie group."""
            return jnp.eye(self._M.n)

        def bracket(self, X, Y):
            """Lie bracket in Lie algebra."""
            return jnp.einsum('ij,jl->il', X, Y) - jnp.einsum('ij,jl->il', Y, X)

        def coords(self, X):
            """Coordinate map for the tangent space at the identity."""
            return jnp.reshape(X, (self._M.n ** 2, 1))

        def coords_inverse(self, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def lefttrans(self, g, f):
            """Left translation of g by f.
            """
            return jnp.einsum('ij,jl->il', f, g)

        def righttrans(self, g, f):
            """Right translation of g by f.
            """
            return jnp.einsum('ij,jl->il', g, f)

        def inverse(self, g):
            """Inverse map of the Lie group.
            """
            return jnp.linalg.inv(g)

        def exp(self, *argv):
            """Computes the Lie-theoretic and connection exponential map
            (depending on signature, i.e. whether footpoint is given as well)

            exp(X) = expm(X); exp(g, X) = expm(X) g.
            """
            # The number of arguments is static, so dispatch in Python instead of tracing both branches via lax.cond.
            E = expm(argv[-1])
            if len(argv) == 1:
                return E  # group exp
            return jnp.einsum('ij,jk', E, argv[0])  # exp of CCS connection

        retr = exp

        def log(self, *argv):
            """Computes the Lie-theoretic and connection logarithm map
            (depending on signature, i.e. whether footpoint is given as well)

            log(h) = logm(h); log(g, h) = logm(h g^{-1}), the inverse of exp(g, X) = expm(X) g.
            """
            # NOTE: as logm() is not available in jax we apply log via funm() (so far this is CPU only; not as stable as
            # logm in numpy)
            logm = lambda m: jnp.real(funm(m, jnp.log))
            if len(argv) == 1:
                return logm(argv[-1])
            g, h = argv[0], argv[-1]
            # Use the inverse of g (not its transpose -- the two only coincide for orthogonal matrices) so that
            # exp(g, log(g, h)) = h holds on all of GL+(n).
            return logm(jnp.einsum('ij,jk', h, self.inverse(g)))

        def curvature_tensor(self, f, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at f on the vectors X, Y, Z. With nabla_X Y denoting
            the covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            return - 1 / 4 * self.bracket(self.bracket(X, Y), Z)

        def dleft(self, f, X):
            """Derivative of the left translation by f applied to the tangent vector X at the identity.
            """
            return jnp.einsum('ij,jl->il', f, X)

        def dright(self, f, X):
            """Derivative of the right translation by f applied to the tangent vector X at the identity.
            """
            return jnp.einsum('ij,jl->il', X, f)

        def dleft_inv(self, f, X):
            """Derivative of the left translation by f^{-1} at f applied to the tangent vector X.
            """
            return jnp.einsum('ij,jl->il', self.inverse(f), X)

        def dright_inv(self, f, X):
            """Derivative of the right translation by f^{-1} at f applied to the tangent vector X.
            """
            return jnp.einsum('ij,jl->il', X, self.inverse(f))

        def adjrep(self, g, X):
            """Adjoint representation of g applied to the tangent vector X at the identity.
            """
            return jnp.einsum('ij,jl,lm->im', g, X, self.inverse(g))

        def transp(self, f, g, X):
            """
            Parallel transport of the CCS connection along one-parameter subgroups; see Sec. 5.3.3 of
            X. Pennec and M. Lorenzi,
            "Beyond Riemannian geometry: The affine connection setting for transformation groups."

            """
            f_invg = self.lefttrans(g, self.inverse(f))
            # midpoint of the one-parameter subgroup from e to f^{-1} g
            h = self.geopoint(self.identity, f_invg, .5)

            return self.dleft_inv(f_invg, self.dleft(h, self.dright(h, X)))

        def jacobiField(self, R, Q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
################################################################################
|
|
2
|
+
# #
|
|
3
|
+
# This file is part of the Morphomatics library #
|
|
4
|
+
# see https://github.com/morphomatics/morphomatics #
|
|
5
|
+
# #
|
|
6
|
+
# Copyright (C) 2024 Zuse Institute Berlin #
|
|
7
|
+
# #
|
|
8
|
+
# Morphomatics is distributed under the terms of the MIT License. #
|
|
9
|
+
# see $MORPHOMATICS/LICENSE #
|
|
10
|
+
# #
|
|
11
|
+
################################################################################
|
|
12
|
+
|
|
13
|
+
from morphomatics.manifold import Manifold, Metric
|
|
14
|
+
from morphomatics.manifold.discrete_ops import pole_ladder
|
|
15
|
+
|
|
16
|
+
import jax
|
|
17
|
+
import jax.numpy as jnp
|
|
18
|
+
from jax.numpy.linalg import svd
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Grassmann(Manifold):
    """ The Grassmannian.
    Manifold of m-dimensional subspaces of n-dimensional real vector space.
    Elements are represented as n x m matrices (with orthonormal columns spanning the subspace).
    """
    def __init__(self, n=3, m=1, structure='Canonical'):
        if n < m or m < 1:
            raise ValueError(
                f"Need n >= m >= 1. Values supplied were n = {n} and m = {m}"
            )

        name = 'Grassmann manifold Gr({n},{m})'.format(n=n, m=m)
        dimension = int(n * m - m ** 2)
        self._n = n
        self._m = m
        super().__init__(name, dimension, point_shape=(n, m))
        if structure:
            getattr(self, f'init{structure}Structure')()

    def tree_flatten(self):
        """Flatten for PyTree registration; the point shape (n, m) is appended to the base class's aux data."""
        children, aux = super().tree_flatten()
        return children, aux+(self.point_shape,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Specifies an unflattening recipe for PyTree registration."""
        *aux_data, shape = aux_data
        obj = cls(*shape, structure=None)
        obj.tree_unflatten_instance(aux_data, children)
        return obj

    def initCanonicalStructure(self):
        """
        Instantiate Grassmannian with canonical structure.
        """
        structure = Grassmann.CanonicalStructure(self)
        self._metric = structure
        self._connec = structure

    # Generate random Grassmann point using qr of random normally distributed matrix.
    def rand(self, key: jax.Array):
        Q, _ = jnp.linalg.qr(jax.random.normal(key, self.point_shape))
        return Q

    def randvec(self, p, key: jax.Array):
        """Random unit-norm tangent vector at p."""
        U = jax.random.normal(key, p.shape)
        U = U - jnp.einsum('ij,kj,kl', p, p, U)  # remove component in span(p): U - p p^T U
        U = U / jnp.linalg.norm(U)
        return U

    def zerovec(self):
        return jnp.zeros(self.point_shape)

    @staticmethod
    def project(p):
        """Project arbitrary matrix to the manifold."""
        return jnp.linalg.qr(p)[0]

    def proj(self, p, U):
        """Project ambient tangent vector to tangent space at p."""
        return U - jnp.einsum('ij,kj,kl', p, p, U)

    class CanonicalStructure(Metric):
        """
        The Riemannian metric used is the induced metric from the embedding space R^nxm, i.e., this manifold is a
        Riemannian submanifold of R^nxm endowed with the usual trace inner product.
        """
        def __init__(self, M):
            """
            Constructor.
            """
            self._M = M

        def __str__(self):
            return "Canonical structure of Grassmannian"

        @property
        def typicaldist(self):
            return jnp.sqrt(jnp.prod(self._M.point_shape[-2:]))

        # Geodesic distance for Grassmann
        def dist(self, p, q):
            """Geodesic distance sqrt(sum_i theta_i^2), theta_i being the principal angles between span(p) and span(q).

            The principal angles are the arccos of the singular values of p^T q; this agrees with the norm of log(p, q).
            """
            s = svd(jnp.einsum('ji,jk', p, q), compute_uv=False)
            # round-off can push singular values slightly above 1, where arccos would return nan
            theta = jnp.arccos(jnp.clip(s, 0., 1.))
            return jnp.linalg.norm(theta)

        def inner(self, p, G, H):
            # Inner product (Riemannian metric) on the tangent space
            # For the Grassmann this is the Frobenius inner product.
            return jnp.einsum('ij,ij', G, H)

        def flat(self, p, G):
            raise NotImplementedError('This function has not been implemented yet.')

        def sharp(self, p, dG):
            raise NotImplementedError('This function has not been implemented yet.')

        def egrad2rgrad(self, p, X):
            return self._M.proj(p, X)

        def ehess2rhess(self, p, G, H, X):
            # Convert Euclidean into Riemannian Hessian.
            # G: Euclidean gradient, H: Euclidean Hessian applied to X, X: tangent direction at p.
            xpG = jnp.einsum('ij,kj,kl', X, p, G)  # X (p^T G)
            return self._M.proj(p, H) - xpG

        def retr(self, X, G):
            # We do not need to worry about flipping signs of columns here,
            # since only the column space is important, not the actual
            # columns. Compare this with the Stiefel manifold.

            # Compute the polar factorization of Y = X+P_G
            u, _, vt = svd(X + G, full_matrices=False)
            return u @ vt

        def norm(self, p, G):
            # Norm on the tangent space is simply the Euclidean norm.
            return jnp.linalg.norm(G)

        def transp(self, p1, p2, d):
            # parallel transport approximated by pole ladder
            return pole_ladder(self._M, p1, p2, d)

        def exp(self, p, U):
            """Riemannian exponential: p V cos(S) V^T + u sin(S) V^T for the thin SVD U = u S V^T."""
            u, s, vt = svd(U, full_matrices=False)

            Y = jnp.einsum('ij,kj,k,kl', p, vt, jnp.cos(s), vt) + \
                jnp.einsum('ij,j,jk', u, jnp.sin(s), vt)

            # From numerical experiments, it seems necessary to
            # re-orthonormalize. This is overall quite expensive.
            Y, _ = jnp.linalg.qr(Y)
            return Y

        def log(self, p, q):
            """Riemannian logarithm of q at p (requires q^T p to be invertible)."""
            qtp = q.T @ p
            At = q.T - qtp @ p.T
            Bt = jnp.linalg.solve(qtp, At)
            u, s, vt = svd(Bt.T, full_matrices=False)

            return jnp.einsum('ij,j,jk', u, jnp.arctan(s), vt)

        def curvature_tensor(self, p, X, Y, Z):
            """Evaluates the curvature tensor R of the connection at p on the vectors X, Y, Z. With nabla_X Y denoting the
            covariant derivative of Y in direction X and [] being the Lie bracket, the convention
                R(X,Y)Z = (nabla_X nabla_Y) Z - (nabla_Y nabla_X) Z - nabla_[X,Y] Z
            is used.
            """
            raise NotImplementedError('This function has not been implemented yet.')

        def jacobiField(self, p, q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')

        def adjJacobi(self, p, q, t, X):
            raise NotImplementedError('This function has not been implemented yet.')
|