nanomanifold 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nanomanifold might be problematic. Click here for more details.
- nanomanifold/SE3/__init__.py +21 -0
- nanomanifold/SE3/canonicalize.py +25 -0
- nanomanifold/SE3/conversions/__init__.py +9 -0
- nanomanifold/SE3/conversions/matrix.py +57 -0
- nanomanifold/SE3/conversions/rt.py +30 -0
- nanomanifold/SE3/exp.py +73 -0
- nanomanifold/SE3/inverse.py +37 -0
- nanomanifold/SE3/log.py +75 -0
- nanomanifold/SE3/multiply.py +48 -0
- nanomanifold/SE3/transform_points.py +34 -0
- nanomanifold/SO3/__init__.py +35 -0
- nanomanifold/SO3/canonicalize.py +17 -0
- nanomanifold/SO3/conversions/__init__.py +14 -0
- nanomanifold/SO3/conversions/axis_angle.py +63 -0
- nanomanifold/SO3/conversions/euler.py +149 -0
- nanomanifold/SO3/conversions/matrix.py +79 -0
- nanomanifold/SO3/distance.py +52 -0
- nanomanifold/SO3/exp.py +25 -0
- nanomanifold/SO3/hat.py +36 -0
- nanomanifold/SO3/inverse.py +16 -0
- nanomanifold/SO3/log.py +25 -0
- nanomanifold/SO3/multiply.py +40 -0
- nanomanifold/SO3/rotate_points.py +50 -0
- nanomanifold/SO3/slerp.py +69 -0
- nanomanifold/SO3/vee.py +28 -0
- nanomanifold/SO3/weighted_mean.py +89 -0
- nanomanifold/__init__.py +3 -0
- nanomanifold/common.py +25 -0
- nanomanifold-0.1.1.dist-info/METADATA +112 -0
- nanomanifold-0.1.1.dist-info/RECORD +31 -0
- nanomanifold-0.1.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from nanomanifold.common import get_namespace
|
|
6
|
+
|
|
7
|
+
from ..canonicalize import canonicalize
|
|
8
|
+
from . import matrix
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def to_euler(q: Float[Any, "... 4"], convention: str = "ZYX") -> Float[Any, "... 3"]:
    """Convert quaternions to Euler angles.

    The quaternion is canonicalized, turned into a rotation matrix, and the
    angles are then read off that matrix.

    Args:
        q: Quaternions in [w, x, y, z] layout, shape (..., 4).
        convention: Three-letter axis sequence. Upper case (e.g. "ZYX") selects
            intrinsic rotations; all lower case (e.g. "zyx") selects extrinsic
            rotations.

    Returns:
        Euler angles in radians, shape (..., 3), ordered like ``convention``.
    """
    rotation = matrix.to_matrix(canonicalize(q))
    return _matrix_to_euler(rotation, convention)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def from_euler(euler: Float[Any, "... 3"], convention: str = "ZYX") -> Float[Any, "... 4"]:
    """Convert Euler angles to quaternions.

    The angles are first composed into a rotation matrix, which is then
    converted to a quaternion.

    Args:
        euler: Euler angles in radians, shape (..., 3); ``euler[..., i]`` is the
            angle about axis ``convention[i]``.
        convention: Three-letter axis sequence. Upper case (e.g. "ZYX") selects
            intrinsic rotations; all lower case (e.g. "zyx") selects extrinsic
            rotations.

    Returns:
        Quaternions in [w, x, y, z] layout, shape (..., 4).
    """
    rotation = _euler_to_matrix(euler, convention)
    quaternion = matrix.from_matrix(rotation)
    return quaternion
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _euler_to_matrix(euler, convention):
|
|
26
|
+
"""Convert Euler angles to rotation matrix."""
|
|
27
|
+
xp = get_namespace(euler)
|
|
28
|
+
|
|
29
|
+
eye = xp.eye(3, dtype=euler.dtype)
|
|
30
|
+
output_shape = euler.shape[:-1] + (3, 3)
|
|
31
|
+
R = xp.broadcast_to(eye, output_shape)
|
|
32
|
+
|
|
33
|
+
is_extrinsic = convention.islower()
|
|
34
|
+
conv = convention.lower()
|
|
35
|
+
|
|
36
|
+
for i, axis in enumerate(conv):
|
|
37
|
+
angle = euler[..., i]
|
|
38
|
+
R_axis = _rotation_matrix(angle, axis)
|
|
39
|
+
|
|
40
|
+
if is_extrinsic:
|
|
41
|
+
R = xp.matmul(R_axis, R)
|
|
42
|
+
else:
|
|
43
|
+
R = xp.matmul(R, R_axis)
|
|
44
|
+
|
|
45
|
+
return R
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _rotation_matrix(angle, axis):
    """Build elementary rotation matrices about a single coordinate axis.

    Args:
        angle: Rotation angles in radians (any batch shape).
        axis: One of "x", "y" or "z" (lower case).

    Returns:
        Array of shape ``angle.shape + (3, 3)``.

    Raises:
        ValueError: If ``axis`` is not "x", "y" or "z".
    """
    xp = get_namespace(angle)

    c = xp.cos(angle)
    s = xp.sin(angle)
    nil = xp.zeros_like(c)
    unit = xp.ones_like(c)

    # Rows of the right-handed elementary rotation about the requested axis.
    if axis == "x":
        rows = ((unit, nil, nil), (nil, c, -s), (nil, s, c))
    elif axis == "y":
        rows = ((c, nil, s), (nil, unit, nil), (-s, nil, c))
    elif axis == "z":
        rows = ((c, -s, nil), (s, c, nil), (nil, nil, unit))
    else:
        raise ValueError(f"Invalid axis: {axis}")

    return xp.stack([xp.stack(list(row), axis=-1) for row in rows], axis=-2)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _matrix_to_euler(matrix, convention):
    """Convert rotation matrices to Euler angles.

    Intrinsic conventions are passed straight through. An all lower-case
    (extrinsic) convention "abc" is evaluated as the intrinsic convention "CBA",
    and the resulting angle triple is reversed.

    Args:
        matrix: Rotation matrices, shape (..., 3, 3).
        convention: Three-letter axis sequence (upper case = intrinsic,
            all lower case = extrinsic).

    Returns:
        Euler angles in radians, shape (..., 3), ordered like ``convention``.
    """
    xp = get_namespace(matrix)

    if not convention.islower():
        return _matrix_to_euler_angles(matrix, convention)

    reversed_intrinsic = convention.upper()[::-1]
    angles = _matrix_to_euler_angles(matrix, reversed_intrinsic)
    return xp.stack([angles[..., 2], angles[..., 1], angles[..., 0]], axis=-1)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _matrix_to_euler_angles(matrix, convention):
    """Extract intrinsic Euler angles from rotation matrices.

    Args:
        matrix: Rotation matrices, shape (..., 3, 3).
        convention: Upper-case three-letter axis sequence, either Tait-Bryan
            (three distinct axes, e.g. "ZYX") or proper Euler (first == last,
            e.g. "ZXZ").

    Returns:
        Angles ``(first, central, third)`` in radians, shape (..., 3).

    Raises:
        ValueError: If ``convention`` is not 3 letters long, if its middle
            letter equals the first or last letter, or if it contains a letter
            other than X, Y or Z.
    """
    xp = get_namespace(matrix)

    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    first, middle, last = convention
    if middle in (first, last):
        raise ValueError(f"Invalid convention {convention}.")
    invalid = [letter for letter in convention if letter not in ("X", "Y", "Z")]
    if invalid:
        raise ValueError(f"Invalid letter {invalid[0]} in convention string.")

    i0 = _index_from_letter(first)
    i2 = _index_from_letter(last)
    tait_bryan = i0 != i2

    if tait_bryan:
        # Central angle from arcsin of one off-diagonal entry; the sign depends
        # on whether (i0, i2) is in cyclic or anti-cyclic order.
        sign = -1.0 if i0 - i2 in (-1, 2) else 1.0
        x = matrix[..., i0, i2] * sign
        one = xp.ones_like(x)
        eps = xp.finfo(x.dtype).eps * one
        # Clip strictly inside (-1, 1). This presumably keeps arcsin and its
        # gradient finite at gimbal lock.
        central_angle = xp.arcsin(xp.clip(x, -one + eps, one - eps))
    else:
        # Proper Euler: central angle from arccos of a diagonal entry.
        central_angle = xp.arccos(xp.clip(matrix[..., i0, i0], -1, 1))

    # First angle comes from column i2, third angle from row i0.
    first_angle = _angle_from_tan(first, middle, matrix[..., i2], False, tait_bryan, xp)
    third_angle = _angle_from_tan(last, middle, matrix[..., i0, :], True, tait_bryan, xp)

    return xp.stack([first_angle, central_angle, third_angle], axis=-1)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _angle_from_tan(axis, other_axis, data, horizontal, tait_bryan, xp):
|
|
129
|
+
"""Compute angle from tangent using systematic indexing."""
|
|
130
|
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
|
131
|
+
if horizontal:
|
|
132
|
+
i2, i1 = i1, i2
|
|
133
|
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
|
134
|
+
if horizontal == even:
|
|
135
|
+
return xp.atan2(data[..., i1], data[..., i2])
|
|
136
|
+
if tait_bryan:
|
|
137
|
+
return xp.atan2(-data[..., i2], data[..., i1])
|
|
138
|
+
return xp.atan2(data[..., i2], -data[..., i1])
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _index_from_letter(letter):
|
|
142
|
+
"""Convert axis letter to index."""
|
|
143
|
+
if letter == "X":
|
|
144
|
+
return 0
|
|
145
|
+
if letter == "Y":
|
|
146
|
+
return 1
|
|
147
|
+
if letter == "Z":
|
|
148
|
+
return 2
|
|
149
|
+
raise ValueError("letter must be either X, Y or Z.")
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Matrix conversions for SO(3) rotations."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from jaxtyping import Float
|
|
6
|
+
|
|
7
|
+
from nanomanifold.common import get_namespace
|
|
8
|
+
|
|
9
|
+
from ..canonicalize import canonicalize
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def to_matrix(q: Float[Any, "... 4"]) -> Float[Any, "... 3 3"]:
    """Convert quaternions to 3x3 rotation matrices.

    Args:
        q: Quaternions in [w, x, y, z] layout, shape (..., 4). The input is
            passed through ``canonicalize`` first. The formula below is the
            unit-quaternion one; whether ``canonicalize`` also normalizes is not
            visible here.

    Returns:
        Rotation matrices of shape (..., 3, 3).
    """
    xp = get_namespace(q)
    q = canonicalize(q)
    qw, qx, qy, qz = (q[..., k] for k in range(4))

    rows = (
        (1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qw * qz), 2 * (qx * qz + qw * qy)),
        (2 * (qx * qy + qw * qz), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qw * qx)),
        (2 * (qx * qz - qw * qy), 2 * (qy * qz + qw * qx), 1 - 2 * (qx * qx + qy * qy)),
    )
    return xp.stack([xp.stack(list(row), axis=-1) for row in rows], axis=-2)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def from_matrix(R: Float[Any, "... 3 3"]) -> Float[Any, "... 4"]:
    """Convert 3x3 rotation matrices to quaternions.

    Uses a Shepperd-style branch selection. Four candidate quaternions are
    computed, each obtained by solving first for one component (w, x, y or z)
    and dividing the others by it. The candidate to keep is then chosen
    elementwise from the trace and diagonal. Every candidate is evaluated, so
    there is no data-dependent Python control flow.

    Args:
        R: Rotation matrices, shape (..., 3, 3).

    Returns:
        Quaternions in [w, x, y, z] layout, shape (..., 4), passed through
        ``canonicalize``.
    """
    xp = get_namespace(R)

    r00, r01, r02 = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    r10, r11, r12 = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    r20, r21, r22 = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]

    trace = r00 + r11 + r22

    # Constants derived from `trace` so they share its dtype and batch shape.
    zero = trace * 0
    one = zero + 1
    eps = one * 1e-10
    quarter = one * 0.25
    two = one * 2

    def pivot(radicand):
        # s = 4 * (pivot component); the second value is s kept away from 0 for division.
        s = xp.sqrt(xp.maximum(zero, radicand)) * two
        return s, xp.where(s < eps, eps, s)

    s_w, d_w = pivot(trace + one)
    s_x, d_x = pivot(one + r00 - r11 - r22)
    s_y, d_y = pivot(one + r11 - r00 - r22)
    s_z, d_z = pivot(one + r22 - r00 - r11)

    # Candidate (w, x, y, z) for each choice of pivot component.
    from_w = (quarter * s_w, (r21 - r12) / d_w, (r02 - r20) / d_w, (r10 - r01) / d_w)
    from_x = ((r21 - r12) / d_x, quarter * s_x, (r01 + r10) / d_x, (r02 + r20) / d_x)
    from_y = ((r02 - r20) / d_y, (r01 + r10) / d_y, quarter * s_y, (r12 + r21) / d_y)
    from_z = ((r10 - r01) / d_z, (r02 + r20) / d_z, (r12 + r21) / d_z, quarter * s_z)

    pick_w = trace > 0
    pick_x = (r00 > r11) & (r00 > r22)
    pick_y = r11 > r22

    components = [
        xp.where(pick_w, cw, xp.where(pick_x, cx, xp.where(pick_y, cy, cz)))
        for cw, cx, cy, cz in zip(from_w, from_x, from_y, from_z)
    ]
    return canonicalize(xp.stack(components, axis=-1))
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from nanomanifold.common import get_namespace
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def distance(q1: Float[Any, "... 4"], q2: Float[Any, "... 4"]) -> Float[Any, "..."]:
    """Angular (geodesic) distance between rotations given as quaternions.

    Both inputs are normalized. ``q2`` is flipped onto ``q1``'s hemisphere
    (q and -q are the same rotation). The result is the angle of the relative
    rotation ``conj(q1) * q2``, computed with atan2 for numerical robustness.

    Args:
        q1: First quaternion in [w, x, y, z] format
        q2: Second quaternion in [w, x, y, z] format

    Returns:
        Angular distance in radians, in range [0, π]
    """
    xp = get_namespace(q1)

    def _unit(q):
        return q / xp.sqrt(xp.sum(q**2, axis=-1, keepdims=True))

    a = _unit(q1)
    b = _unit(q2)

    # Choose the representative of q2 that gives the shorter path on the double cover.
    alignment = xp.sum(a * b, axis=-1, keepdims=True)
    b = xp.where(alignment < 0, -b, b)

    aw, av = a[..., :1], a[..., 1:]
    bw, bv = b[..., :1], b[..., 1:]

    av_cross_bv = xp.stack(
        [
            av[..., 1] * bv[..., 2] - av[..., 2] * bv[..., 1],
            av[..., 2] * bv[..., 0] - av[..., 0] * bv[..., 2],
            av[..., 0] * bv[..., 1] - av[..., 1] * bv[..., 0],
        ],
        axis=-1,
    )

    # conj(a) * b = (aw*bw + av.bv, aw*bv - bw*av - av x bv)
    rel_vec = aw * bv - bw * av - av_cross_bv
    rel_w = aw * bw + xp.sum(av * bv, axis=-1, keepdims=True)

    return 2.0 * xp.atan2(xp.sqrt(xp.sum(rel_vec**2, axis=-1)), rel_w[..., 0])
|
nanomanifold/SO3/exp.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from .conversions.axis_angle import from_axis_angle
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def exp(tangent_vector: Float[Any, "... 3"]) -> Float[Any, "... 4"]:
    """Map so(3) tangent vectors to SO(3) rotations (exponential map).

    On SO(3) the exponential map is exactly the axis-angle to quaternion
    conversion, so this delegates to ``from_axis_angle``. It is the inverse of
    ``log()``.

    Args:
        tangent_vector: Axis-angle vectors, shape (..., 3). The norm is the
            rotation angle and the direction is the rotation axis.

    Returns:
        Quaternion in [w, x, y, z] format representing the rotation.
    """
    quaternion = from_axis_angle(tangent_vector)
    return quaternion
|
nanomanifold/SO3/hat.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from ..common import get_namespace
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def hat(w: Float[Any, "... 3"]) -> Float[Any, "... 3 3"]:
    """Lift so(3) vectors to skew-symmetric matrices (hat operator).

    The result satisfies ``hat(w) @ v == cross(w, v)``.

    Args:
        w: (..., 3) array of tangent vectors in so(3)

    Returns:
        (..., 3, 3) skew-symmetric matrices laid out as

            [[  0, -w3,  w2],
             [ w3,   0, -w1],
             [-w2,  w1,   0]]
    """
    xp = get_namespace(w)

    a, b, c = w[..., 0], w[..., 1], w[..., 2]
    # Zeros derived from the first component so they match its shape and dtype.
    nil = a * 0

    rows = [
        [nil, -c, b],
        [c, nil, -a],
        [-b, a, nil],
    ]
    return xp.stack([xp.stack(row, axis=-1) for row in rows], axis=-2)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from nanomanifold.common import get_namespace
|
|
6
|
+
|
|
7
|
+
from .canonicalize import canonicalize
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def inverse(q: Float[Any, "... 4"]) -> Float[Any, "... 4"]:
    """Invert rotations given as quaternions.

    Returns the conjugate ``[w, -x, -y, -z]`` of the canonicalized input,
    canonicalized again. The conjugate is the inverse for unit quaternions;
    whether ``canonicalize`` normalizes is not visible here.

    Args:
        q: Quaternions in [w, x, y, z] layout, shape (..., 4).

    Returns:
        Inverse quaternions, shape (..., 4).
    """
    xp = get_namespace(q)
    unit = canonicalize(q)
    conjugate = xp.stack([unit[..., 0], -unit[..., 1], -unit[..., 2], -unit[..., 3]], axis=-1)
    return canonicalize(conjugate)
|
nanomanifold/SO3/log.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from .conversions.axis_angle import to_axis_angle
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def log(q: Float[Any, "... 4"]) -> Float[Any, "... 3"]:
    """Map SO(3) rotations to so(3) tangent vectors (logarithmic map).

    On SO(3) the logarithmic map is exactly the quaternion to axis-angle
    conversion, so this delegates to ``to_axis_angle``. It is the inverse of
    ``exp()``.

    Args:
        q: Quaternion in [w, x, y, z] format representing a rotation.

    Returns:
        Axis-angle vectors, shape (..., 3). The norm is the rotation angle and
        the direction is the rotation axis.
    """
    tangent = to_axis_angle(q)
    return tangent
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from nanomanifold.common import get_namespace
|
|
6
|
+
|
|
7
|
+
from .canonicalize import canonicalize
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def multiply(q1: Float[Any, "... 4"], q2: Float[Any, "... 4"]) -> Float[Any, "... 4"]:
    """Compose two rotations with the Hamilton product.

    The order matches rotation-matrix multiplication: ``multiply(q1, q2)``
    represents ``to_matrix(q1) @ to_matrix(q2)``, i.e. q2 is applied first,
    then q1. Inputs and output are passed through ``canonicalize``.

    Args:
        q1: First quaternion in [w, x, y, z] format
        q2: Second quaternion in [w, x, y, z] format

    Returns:
        Product quaternion representing the composed rotation
    """
    xp = get_namespace(q1)

    a = canonicalize(q1)
    b = canonicalize(q2)
    aw, ax, ay, az = (a[..., k] for k in range(4))
    bw, bx, by, bz = (b[..., k] for k in range(4))

    product = [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]
    return canonicalize(xp.stack(product, axis=-1))
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from nanomanifold.common import get_namespace
|
|
6
|
+
|
|
7
|
+
from .canonicalize import canonicalize
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def rotate_points(q: Float[Any, "... 4"], points: Float[Any, "... N 3"]) -> Float[Any, "... N 3"]:
    """Rotate point sets by quaternions.

    Uses the expanded form of ``q * p * q^-1``:
    ``p' = p + 2 * (v x (v x p + w * p))`` with ``v = [x, y, z]``.

    Args:
        q: Quaternion in [w, x, y, z] format of shape (..., 4)
        points: Points to rotate of shape (..., N, 3)

    Returns:
        Rotated points of shape (..., N, 3)
    """
    xp = get_namespace(q)
    q = canonicalize(q)

    # Trailing singleton axis so each quaternion broadcasts over its N points.
    w, x, y, z = (q[..., k][..., None] for k in range(4))
    px, py, pz = (points[..., k] for k in range(3))

    # t = v x p + w * p
    tx = y * pz - z * py + w * px
    ty = z * px - x * pz + w * py
    tz = x * py - y * px + w * pz

    # p' = p + 2 * (v x t)
    return xp.stack(
        [
            px + 2 * (y * tz - z * ty),
            py + 2 * (z * tx - x * tz),
            pz + 2 * (x * ty - y * tx),
        ],
        axis=-1,
    )
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from nanomanifold.common import get_namespace
|
|
6
|
+
|
|
7
|
+
from .canonicalize import canonicalize
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def slerp(q1: Float[Any, "... 4"], q2: Float[Any, "... 4"], t: Float[Any, "... N"]) -> Float[Any, "... N 4"]:
    """Spherical linear interpolation between two quaternions representing SO(3) rotations.

    Performs geodesic interpolation on the SO(3) manifold along the shortest
    path. q2 is negated where needed so it lies in q1's hemisphere. When the
    rotations are nearly identical, a normalized linear interpolation is used
    instead to avoid dividing by sin(omega) ~ 0.

    Args:
        q1: Start quaternion in [w, x, y, z] format, shape (..., 4).
        q2: End quaternion in [w, x, y, z] format, shape (..., 4).
        t: Interpolation parameters, shape (..., N). The last dimension N is the
            number of interpolation points; for a single point, use shape
            (..., 1). t = 0 returns q1 and t = 1 returns q2 (up to the q ~ -q
            sign ambiguity). Values are intended to lie in [0, 1]. They are not
            validated, and values outside that range extrapolate past the
            endpoints.

    Returns:
        Interpolated quaternions with shape (..., N, 4), passed through
        ``canonicalize``.
    """
    xp = get_namespace(q1)

    q1 = canonicalize(q1)
    q2 = canonicalize(q2)

    # Insert the interpolation-point axis: (..., 1, 4).
    q1_expanded = xp.expand_dims(q1, axis=-2)
    q2_expanded = xp.expand_dims(q2, axis=-2)

    dot_product = xp.sum(q1_expanded * q2_expanded, axis=-1, keepdims=True)

    # q and -q encode the same rotation; flip q2 so we interpolate along the short arc.
    q2_corrected = xp.where(dot_product < 0, -q2_expanded, q2_expanded)
    dot_product = xp.where(dot_product < 0, -dot_product, dot_product)

    dot_product = xp.clip(dot_product, 0.0, 1.0)

    threshold = 0.9995
    use_linear = dot_product > threshold

    t_expanded = xp.expand_dims(t, axis=-1)

    linear_result = (1.0 - t_expanded) * q1_expanded + t_expanded * q2_corrected
    linear_norm = xp.sqrt(xp.sum(linear_result**2, axis=-1, keepdims=True))
    linear_result = linear_result / linear_norm

    # Where the linear branch is selected, feed acos a harmless value.
    # acos has an infinite derivative at 1, and under autodiff frameworks an
    # unselected `where` branch still contributes 0 * inf = NaN to gradients.
    # Forward values are unaffected because those entries take linear_result.
    acos_input = xp.where(use_linear, xp.zeros_like(dot_product), dot_product)
    omega = xp.acos(acos_input)
    sin_omega = xp.sin(omega)

    eps = 1e-8
    sin_omega_safe = xp.where(xp.abs(sin_omega) < eps, eps, sin_omega)

    weight1 = xp.sin((1.0 - t_expanded) * omega) / sin_omega_safe
    weight2 = xp.sin(t_expanded * omega) / sin_omega_safe

    spherical_result = weight1 * q1_expanded + weight2 * q2_corrected

    result = xp.where(use_linear, linear_result, spherical_result)

    return canonicalize(result)
|
nanomanifold/SO3/vee.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from ..common import get_namespace
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def vee(W: Float[Any, "... 3 3"]) -> Float[Any, "... 3"]:
    """Read the so(3) vector out of a skew-symmetric matrix (vee operator).

    This is the inverse of the hat operator for the layout

        [[  0, -w3,  w2],
         [ w3,   0, -w1],
         [-w2,  w1,   0]]

    Only the entries W[2, 1], W[0, 2] and W[1, 0] are read; skew-symmetry is
    not checked.

    Args:
        W: (..., 3, 3) skew-symmetric matrices

    Returns:
        (..., 3) tangent vectors in so(3)
    """
    xp = get_namespace(W)
    components = [W[..., row, col] for row, col in ((2, 1), (0, 2), (1, 0))]
    return xp.stack(components, axis=-1)
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from typing import Any, Sequence
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
|
|
5
|
+
from nanomanifold.common import get_namespace
|
|
6
|
+
|
|
7
|
+
from .canonicalize import canonicalize
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def weighted_mean(quaternions: Sequence[Float[Any, "... 4"]], weights: Float[Any, "... N"]) -> Float[Any, "... 4"]:
    """Compute the weighted mean of SO(3) rotations represented as quaternions.

    Uses the eigenvector method of Markley et al. The mean is the unit
    eigenvector for the largest eigenvalue of the weighted outer-product matrix
    ``M = sum_i w_i q_i q_i^T``. This minimizes the weighted sum of squared
    chordal (Frobenius) distances between the corresponding rotation matrices.
    It is not the geodesic (Riemannian/Karcher) mean, although the two agree
    closely for tightly clustered rotations.

    Args:
        quaternions: Sequence of quaternions in [w, x, y, z] format. Each
            quaternion should have shape [..., 4]. Inputs are normalized
            internally, and their sign does not matter.
        weights: Array of weights with shape [..., N], where N is the number of
            quaternions. Weights are divided by their sum, so they must not sum
            to zero.

    Returns:
        Weighted mean quaternion with shape [..., 4] in [w, x, y, z] format,
        passed through ``canonicalize``.

    Raises:
        ValueError: If ``quaternions`` is empty.

    Note:
        This implementation follows the algorithm from:
        "Averaging Quaternions" by F. Landis Markley et al.
    """
    # Fail clearly instead of with an IndexError on quaternions[0].
    if len(quaternions) == 0:
        raise ValueError("Cannot compute weighted mean of empty quaternion sequence")

    xp = get_namespace(quaternions[0])
    original_dtype = quaternions[0].dtype

    quats = xp.stack([xp.asarray(q, dtype=original_dtype) for q in quaternions], axis=-2)
    weights_array = xp.asarray(weights, dtype=original_dtype)

    # Normalize each quaternion, guarding against (near-)zero norms.
    norms = xp.linalg.norm(quats, axis=-1, keepdims=True)
    eps = xp.finfo(original_dtype).eps * 10
    safe_norms = xp.where(norms < eps, eps, norms)
    quats_normalized = quats / safe_norms

    weights_normalized = weights_array / xp.sum(weights_array, axis=-1, keepdims=True)

    # q q^T is invariant under q -> -q, so no hemisphere alignment is needed.
    weighted_quats = weights_normalized[..., :, None] * quats_normalized
    M = xp.einsum("...nj,...nk->...jk", weighted_quats, quats_normalized)

    if original_dtype == xp.float16:
        # eigh implementations commonly reject float16, so solve in float32 and cast back.
        _, eigenvectors = xp.linalg.eigh(xp.asarray(M, dtype=xp.float32))
        eigenvectors = xp.asarray(eigenvectors, dtype=original_dtype)
    else:
        _, eigenvectors = xp.linalg.eigh(M)

    # Assumes ascending eigenvalue order (NumPy/PyTorch/JAX eigh), so the last
    # column is the dominant eigenvector.
    avg_quat = eigenvectors[..., :, -1]
    avg_quat = avg_quat / xp.linalg.norm(avg_quat, axis=-1, keepdims=True)

    return canonicalize(avg_quat)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def mean(quaternions: Sequence[Float[Any, "... 4"]]) -> Float[Any, "... 4"]:
    """Average SO(3) rotations with equal weights.

    Thin wrapper around ``weighted_mean`` that gives every input a weight of one.

    Args:
        quaternions: Sequence of quaternions in [w, x, y, z] format. Each
            quaternion should have shape [..., 4], and all share the batch shape
            of the first one.

    Returns:
        Mean quaternion with shape [..., 4] in [w, x, y, z] format.
        The result is canonicalized to ensure w >= 0.

    Raises:
        ValueError: If ``quaternions`` is empty.
    """
    count = len(quaternions)
    if count == 0:
        raise ValueError("Cannot compute mean of empty quaternion sequence")

    reference = quaternions[0]
    xp = get_namespace(reference)
    uniform = xp.ones(reference.shape[:-1] + (count,), dtype=reference.dtype)
    return weighted_mean(quaternions, uniform)
|
nanomanifold/__init__.py
ADDED