nanomanifold 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nanomanifold might be problematic. Click here for more details.

@@ -0,0 +1,149 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from nanomanifold.common import get_namespace
6
+
7
+ from ..canonicalize import canonicalize
8
+ from . import matrix
9
+
10
+
11
def to_euler(q: Float[Any, "... 4"], convention: str = "ZYX") -> Float[Any, "... 3"]:
    """Convert quaternions to Euler angles.

    The quaternion is canonicalized, turned into a rotation matrix, and the
    angles are then read off that matrix.

    Args:
        q: Quaternions with their four components on the last axis.
        convention: Three-letter axis sequence. An all-lowercase string is
            treated as extrinsic by the matrix extraction step; any other
            casing is treated as intrinsic.

    Returns:
        Euler angles with three values on the last axis, ordered like the
        letters of ``convention``.
    """
    rotation = matrix.to_matrix(canonicalize(q))
    return _matrix_to_euler(rotation, convention)
17
+
18
+
19
def from_euler(euler: Float[Any, "... 3"], convention: str = "ZYX") -> Float[Any, "... 4"]:
    """Convert Euler angles to quaternions.

    Args:
        euler: Euler angles with three values on the last axis, one per
            letter of ``convention``.
        convention: Axis sequence. An all-lowercase string is treated as
            extrinsic; any other casing is treated as intrinsic.

    Returns:
        Quaternions with four components on the last axis, as produced by
        ``matrix.from_matrix``.
    """
    rotation = _euler_to_matrix(euler, convention)
    quaternion = matrix.from_matrix(rotation)
    return quaternion
23
+
24
+
25
def _euler_to_matrix(euler, convention):
    """Convert Euler angles to a rotation matrix.

    One elementary rotation is built per letter of ``convention`` from the
    matching entry ``euler[..., i]`` and the three are multiplied together.
    An all-lowercase convention is extrinsic (each new rotation is applied on
    the left); any other casing is intrinsic (applied on the right).

    Args:
        euler: Array of angles with three values on the last axis.
        convention: Three axis letters drawn from x, y, z in either case.

    Returns:
        Array of shape ``euler.shape[:-1] + (3, 3)``.

    Raises:
        ValueError: If ``convention`` does not have exactly three letters, or
            if ``_rotation_matrix`` rejects one of the letters.
    """
    # Without this check a two-letter convention would silently ignore the
    # third angle, and a longer one would index beyond the supplied angles.
    # The message matches the one raised by the matrix -> Euler direction.
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")

    xp = get_namespace(euler)

    eye = xp.eye(3, dtype=euler.dtype)
    output_shape = euler.shape[:-1] + (3, 3)
    R = xp.broadcast_to(eye, output_shape)

    is_extrinsic = convention.islower()
    conv = convention.lower()

    for i, axis in enumerate(conv):
        angle = euler[..., i]
        R_axis = _rotation_matrix(angle, axis)

        if is_extrinsic:
            # Fixed-frame rotations accumulate on the left.
            R = xp.matmul(R_axis, R)
        else:
            # Body-frame rotations accumulate on the right.
            R = xp.matmul(R, R_axis)

    return R
46
+
47
+
48
def _rotation_matrix(angle, axis):
    """Build the elementary rotation about a single coordinate axis.

    Args:
        angle: Array of rotation angles; the result gains two trailing axes
            of size 3.
        axis: Lowercase axis letter, one of "x", "y" or "z".

    Returns:
        Stacked 3x3 rotation matrices, one per entry of ``angle``.

    Raises:
        ValueError: If ``axis`` is not one of the three letters.
    """
    xp = get_namespace(angle)

    c = xp.cos(angle)
    s = xp.sin(angle)
    zeros = xp.zeros_like(c)
    unit = xp.ones_like(c)

    # Each entry is one matrix row, listed left to right.
    if axis == "x":
        rows = ((unit, zeros, zeros), (zeros, c, -s), (zeros, s, c))
    elif axis == "y":
        rows = ((c, zeros, s), (zeros, unit, zeros), (-s, zeros, c))
    elif axis == "z":
        rows = ((c, -s, zeros), (s, c, zeros), (zeros, zeros, unit))
    else:
        raise ValueError(f"Invalid axis: {axis}")

    return xp.stack([xp.stack(list(row), axis=-1) for row in rows], axis=-2)
76
+
77
+
78
def _matrix_to_euler(matrix, convention):
    """Convert a rotation matrix to Euler angles.

    An all-lowercase ``convention`` is handled as extrinsic: it is uppercased
    and reversed, the angles are extracted for that sequence, and the three
    angles are then returned in reversed order. Any other casing is passed to
    the extraction unchanged.
    """
    xp = get_namespace(matrix)

    if not convention.islower():
        return _matrix_to_euler_angles(matrix, convention)

    reversed_sequence = convention.upper()[::-1]
    angles = _matrix_to_euler_angles(matrix, reversed_sequence)

    # Undo the reversal so the angles line up with the caller's letters.
    return xp.stack([angles[..., k] for k in (2, 1, 0)], axis=-1)
94
+
95
+
96
def _matrix_to_euler_angles(matrix, convention):
    """Extract Euler angles from a rotation matrix for an uppercase convention.

    Args:
        matrix: Array whose last two axes hold 3x3 rotation matrices.
        convention: Three letters from "X", "Y", "Z"; the middle letter must
            differ from both of its neighbours.

    Returns:
        Array with the first, central and third angle on the last axis.

    Raises:
        ValueError: If ``convention`` fails any of the checks above.
    """
    xp = get_namespace(matrix)

    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    first_letter, middle_letter, last_letter = convention[0], convention[1], convention[2]
    if middle_letter == first_letter or middle_letter == last_letter:
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    row = _index_from_letter(first_letter)
    col = _index_from_letter(last_letter)
    is_tait_bryan = row != col

    if is_tait_bryan:
        # Three distinct axes: the central angle comes from a single signed entry.
        flip = -1.0 if (row - col) in (-1, 2) else 1.0
        sine = matrix[..., row, col] * flip

        unit = xp.ones_like(sine)
        margin = xp.finfo(sine.dtype).eps * unit
        # Clip just inside (-1, 1) before taking the arcsine.
        middle = xp.arcsin(xp.clip(sine, -unit + margin, unit - margin))
    else:
        # First and last axis coincide: the central angle comes from a diagonal entry.
        middle = xp.arccos(xp.clip(matrix[..., row, row], -1, 1))

    leading = _angle_from_tan(first_letter, middle_letter, matrix[..., col], False, is_tait_bryan, xp)
    trailing = _angle_from_tan(last_letter, middle_letter, matrix[..., row, :], True, is_tait_bryan, xp)

    return xp.stack([leading, middle, trailing], axis=-1)
126
+
127
+
128
def _angle_from_tan(axis, other_axis, data, horizontal, tait_bryan, xp):
    """Recover one outer Euler angle from two entries of a matrix slice.

    Args:
        axis: Uppercase letter of the axis whose angle is wanted.
        other_axis: Uppercase letter of the neighbouring (central) axis.
        data: Array whose last axis has the three entries to pick from.
        horizontal: When true, the two picked indices are swapped.
        tait_bryan: Selects which sign pattern is used when ``horizontal``
            differs from the parity of the axis pair.
        xp: Array namespace providing ``atan2``.

    Returns:
        The angle computed by ``xp.atan2`` from the two selected entries.
    """
    first, second = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        first, second = second, first

    is_even = (axis + other_axis) in ("XY", "YZ", "ZX")
    a = data[..., first]
    b = data[..., second]

    if horizontal == is_even:
        return xp.atan2(a, b)
    if tait_bryan:
        return xp.atan2(-b, a)
    return xp.atan2(b, -a)
139
+
140
+
141
def _index_from_letter(letter):
    """Return the matrix index (0, 1 or 2) for the axis letter X, Y or Z.

    Raises:
        ValueError: If ``letter`` is not one of the three uppercase letters.
    """
    for position, name in enumerate(("X", "Y", "Z")):
        if letter == name:
            return position
    raise ValueError("letter must be either X, Y or Z.")
@@ -0,0 +1,79 @@
1
+ """Matrix conversions for SO(3) rotations."""
2
+
3
+ from typing import Any
4
+
5
+ from jaxtyping import Float
6
+
7
+ from nanomanifold.common import get_namespace
8
+
9
+ from ..canonicalize import canonicalize
10
+
11
+
12
def to_matrix(q: Float[Any, "... 4"]) -> Float[Any, "... 3 3"]:
    """Convert quaternions in [w, x, y, z] order to 3x3 rotation matrices.

    The input is canonicalized first; the matrix entries are then assembled
    row by row from products of the quaternion components.
    """
    xp = get_namespace(q)
    q = canonicalize(q)
    w, x, y, z = (q[..., k] for k in range(4))

    top = [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]
    middle = [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]
    bottom = [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]

    return xp.stack([xp.stack(row, axis=-1) for row in (top, middle, bottom)], axis=-2)
27
+
28
+
29
def from_matrix(R: Float[Any, "... 3 3"]) -> Float[Any, "... 4"]:
    """Convert 3x3 rotation matrices to quaternions in [w, x, y, z] order.

    Four candidate quaternions are computed, each built around a different
    component (w, x, y or z) taken from a square root of diagonal terms. The
    candidate is then chosen per element with ``where`` based on the trace
    and the relative size of the diagonal entries, and the result is
    canonicalized.

    Args:
        R: Array whose last two axes hold 3x3 rotation matrices.

    Returns:
        Canonicalized quaternions with four components on the last axis.
    """
    xp = get_namespace(R)

    trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]

    zero = trace * 0
    one = zero + 1
    # The division guard must be representable in the working dtype: a fixed
    # 1e-10 rounds to 0 in float16, which would leave the guard inactive and
    # let the non-selected branches divide by zero. Using at least the
    # machine epsilon keeps the guard effective for every float width.
    guard_dtype = (one * 1e-10).dtype
    eps = one * max(1e-10, float(xp.finfo(guard_dtype).eps))
    quarter = one * 0.25
    two = one * 2

    def _scale(radicand):
        """Return (s, s_safe): twice the clamped square root and its guarded divisor."""
        s = xp.sqrt(xp.maximum(zero, radicand)) * two
        return s, xp.where(s < eps, eps, s)

    s1, s1_safe = _scale(trace + one)  # s = 4 * w
    w1 = quarter * s1
    x1 = (R[..., 2, 1] - R[..., 1, 2]) / s1_safe
    y1 = (R[..., 0, 2] - R[..., 2, 0]) / s1_safe
    z1 = (R[..., 1, 0] - R[..., 0, 1]) / s1_safe

    s2, s2_safe = _scale(one + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])  # s = 4 * x
    w2 = (R[..., 2, 1] - R[..., 1, 2]) / s2_safe
    x2 = quarter * s2
    y2 = (R[..., 0, 1] + R[..., 1, 0]) / s2_safe
    z2 = (R[..., 0, 2] + R[..., 2, 0]) / s2_safe

    s3, s3_safe = _scale(one + R[..., 1, 1] - R[..., 0, 0] - R[..., 2, 2])  # s = 4 * y
    w3 = (R[..., 0, 2] - R[..., 2, 0]) / s3_safe
    x3 = (R[..., 0, 1] + R[..., 1, 0]) / s3_safe
    y3 = quarter * s3
    z3 = (R[..., 1, 2] + R[..., 2, 1]) / s3_safe

    s4, s4_safe = _scale(one + R[..., 2, 2] - R[..., 0, 0] - R[..., 1, 1])  # s = 4 * z
    w4 = (R[..., 1, 0] - R[..., 0, 1]) / s4_safe
    x4 = (R[..., 0, 2] + R[..., 2, 0]) / s4_safe
    y4 = (R[..., 1, 2] + R[..., 2, 1]) / s4_safe
    z4 = quarter * s4

    # Branch selection: positive trace first, otherwise the largest diagonal entry.
    cond1 = trace > 0
    cond2 = (R[..., 0, 0] > R[..., 1, 1]) & (R[..., 0, 0] > R[..., 2, 2])
    cond3 = R[..., 1, 1] > R[..., 2, 2]

    w = xp.where(cond1, w1, xp.where(cond2, w2, xp.where(cond3, w3, w4)))
    x = xp.where(cond1, x1, xp.where(cond2, x2, xp.where(cond3, x3, x4)))
    y = xp.where(cond1, y1, xp.where(cond2, y2, xp.where(cond3, y3, y4)))
    z = xp.where(cond1, z1, xp.where(cond2, z2, xp.where(cond3, z3, z4)))

    q = xp.stack([w, x, y, z], axis=-1)

    return canonicalize(q)
@@ -0,0 +1,52 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from nanomanifold.common import get_namespace
6
+
7
+
8
def distance(q1: Float[Any, "... 4"], q2: Float[Any, "... 4"]) -> Float[Any, "..."]:
    """Return the rotation angle separating two quaternions, in radians.

    Both inputs are normalized, ``q2`` is sign-flipped where needed so that
    its dot product with ``q1`` is non-negative, and the angle is read from
    the relative quaternion as ``2 * atan2(|vector part|, scalar part)``.

    Args:
        q1: First quaternion in [w, x, y, z] format
        q2: Second quaternion in [w, x, y, z] format

    Returns:
        Angular distance in radians, in range [0, π]
    """
    xp = get_namespace(q1)

    unit1 = q1 / xp.sqrt(xp.sum(q1**2, axis=-1, keepdims=True))
    unit2 = q2 / xp.sqrt(xp.sum(q2**2, axis=-1, keepdims=True))

    # q and -q describe the same rotation; pick the representative of q2 on
    # the same hemisphere as q1 so the shorter arc is measured.
    alignment = xp.sum(unit1 * unit2, axis=-1, keepdims=True)
    unit2 = xp.where(alignment < 0, -unit2, unit2)

    s1, u1 = unit1[..., :1], unit1[..., 1:]
    s2, u2 = unit2[..., :1], unit2[..., 1:]

    ax, ay, az = u1[..., 0], u1[..., 1], u1[..., 2]
    bx, by, bz = u2[..., 0], u2[..., 1], u2[..., 2]
    cross = xp.stack([ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx], axis=-1)

    # Vector and scalar parts of the relative quaternion.
    rel_vec = s1 * u2 - s2 * u1 - cross
    rel_vec_norm = xp.sqrt(xp.sum(rel_vec**2, axis=-1))
    rel_scalar = s1 * s2 + xp.sum(u1 * u2, axis=-1, keepdims=True)

    return 2.0 * xp.atan2(rel_vec_norm, rel_scalar[..., 0])
@@ -0,0 +1,25 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from .conversions.axis_angle import from_axis_angle
6
+
7
+
8
def exp(tangent_vector: Float[Any, "... 3"]) -> Float[Any, "... 4"]:
    """Map a tangent vector in so(3) to a rotation quaternion.

    This is a direct delegation to ``from_axis_angle``: the tangent vector is
    interpreted as an axis-angle vector and converted to a quaternion. It is
    the counterpart of ``log()``.

    Args:
        tangent_vector: Tangent vector in so(3) (axis-angle representation)
            The magnitude is the rotation angle, direction is the rotation axis

    Returns:
        Quaternion in [w, x, y, z] format representing the rotation
    """
    quaternion = from_axis_angle(tangent_vector)
    return quaternion
@@ -0,0 +1,36 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from ..common import get_namespace
6
+
7
+
8
def hat(w: Float[Any, "... 3"]) -> Float[Any, "... 3 3"]:
    """Hat operator: turn 3-vectors into skew-symmetric 3x3 matrices.

    Args:
        w: (..., 3) array representing tangent vectors in so(3)

    Returns:
        (..., 3, 3) skew-symmetric matrices laid out as
            [[  0, -w3,  w2],
             [ w3,   0, -w1],
             [-w2,  w1,   0]]
    """
    xp = get_namespace(w)

    a, b, c = w[..., 0], w[..., 1], w[..., 2]
    # Zero with the same shape and dtype as a component, for the diagonal.
    nil = a * 0

    rows = ((nil, -c, b), (c, nil, -a), (-b, a, nil))
    return xp.stack([xp.stack(list(row), axis=-1) for row in rows], axis=-2)
@@ -0,0 +1,16 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from nanomanifold.common import get_namespace
6
+
7
+ from .canonicalize import canonicalize
8
+
9
+
10
def inverse(q: Float[Any, "... 4"]) -> Float[Any, "... 4"]:
    """Return the inverse rotation as a canonicalized quaternion.

    The input is canonicalized, its x, y and z components are negated (the
    w component is kept), and the result is canonicalized again.
    """
    xp = get_namespace(q)
    unit = canonicalize(q)

    conjugate = xp.stack([unit[..., 0], -unit[..., 1], -unit[..., 2], -unit[..., 3]], axis=-1)
    return canonicalize(conjugate)
@@ -0,0 +1,25 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from .conversions.axis_angle import to_axis_angle
6
+
7
+
8
def log(q: Float[Any, "... 4"]) -> Float[Any, "... 3"]:
    """Map a rotation quaternion to its tangent vector in so(3).

    This is a direct delegation to ``to_axis_angle``: the quaternion is
    converted to an axis-angle vector, which serves as the tangent vector.
    It is the counterpart of ``exp()``.

    Args:
        q: Quaternion in [w, x, y, z] format representing a rotation

    Returns:
        Tangent vector in so(3) (axis-angle representation)
        The magnitude is the rotation angle, direction is the rotation axis
    """
    tangent_vector = to_axis_angle(q)
    return tangent_vector
@@ -0,0 +1,40 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from nanomanifold.common import get_namespace
6
+
7
+ from .canonicalize import canonicalize
8
+
9
+
10
def multiply(q1: Float[Any, "... 4"], q2: Float[Any, "... 4"]) -> Float[Any, "... 4"]:
    """Compose two rotations given as quaternions.

    Both inputs are canonicalized, combined with the quaternion product
    ``q1 * q2``, and the product is canonicalized. The order follows matrix
    multiplication: multiply(q1, q2) corresponds to
    to_matrix(q1) @ to_matrix(q2), i.e. q2 is applied first, then q1.

    Args:
        q1: First quaternion in [w, x, y, z] format
        q2: Second quaternion in [w, x, y, z] format

    Returns:
        Product quaternion representing the composed rotation
    """
    xp = get_namespace(q1)

    left = canonicalize(q1)
    right = canonicalize(q2)

    w1, x1, y1, z1 = (left[..., k] for k in range(4))
    w2, x2, y2, z2 = (right[..., k] for k in range(4))

    components = [
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ]

    return canonicalize(xp.stack(components, axis=-1))
@@ -0,0 +1,50 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from nanomanifold.common import get_namespace
6
+
7
+ from .canonicalize import canonicalize
8
+
9
+
10
def rotate_points(q: Float[Any, "... 4"], points: Float[Any, "... N 3"]) -> Float[Any, "... N 3"]:
    """Apply the rotation described by a quaternion to a set of points.

    Args:
        q: Quaternion in [w, x, y, z] format of shape (..., 4)
        points: Points to rotate of shape (..., N, 3)

    Returns:
        Rotated points of shape (..., N, 3)
    """
    xp = get_namespace(q)
    q = canonicalize(q)

    # Give every quaternion component a trailing axis so it broadcasts
    # against the N points.
    w, x, y, z = (q[..., k][..., None] for k in range(4))
    px, py, pz = (points[..., k] for k in range(3))

    # Evaluates p' = p + 2 * cross(v, cross(v, p) + w * p) with v = [x, y, z],
    # which equals q * p * q^(-1) without forming quaternion products.
    inner_x = y * pz - z * py + w * px
    inner_y = z * px - x * pz + w * py
    inner_z = x * py - y * px + w * pz

    outer_x = y * inner_z - z * inner_y
    outer_y = z * inner_x - x * inner_z
    outer_z = x * inner_y - y * inner_x

    return xp.stack([px + 2 * outer_x, py + 2 * outer_y, pz + 2 * outer_z], axis=-1)
@@ -0,0 +1,69 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from nanomanifold.common import get_namespace
6
+
7
+ from .canonicalize import canonicalize
8
+
9
+
10
def slerp(q1: Float[Any, "... 4"], q2: Float[Any, "... 4"], t: Float[Any, "... N"]) -> Float[Any, "... N 4"]:
    """Spherical linear interpolation between two quaternions representing SO(3) rotations.

    Interpolates along the shorter arc between the two rotations: ``q2`` is
    sign-flipped when its dot product with ``q1`` is negative. When the two
    quaternions are nearly parallel (dot product above 0.9995) a normalized
    linear blend is used instead of the spherical formula.

    Args:
        q1: Start quaternion in [w, x, y, z] format
        q2: End quaternion in [w, x, y, z] format
        t: Array of interpolation parameters, intended to lie in [0, 1]. Last
            dimension N represents the number of interpolation points. For a
            single point, use shape [..., 1]. t values: 0 returns q1, 1 returns q2

    Returns:
        Interpolated quaternions with shape [..., N, 4] where N is the number of
        interpolation points from the last dimension of t

    Note:
        ``t`` is not validated. Values outside [0, 1] are not rejected; they
        are passed straight into the interpolation formulas.
    """
    xp = get_namespace(q1)

    q1 = canonicalize(q1)
    q2 = canonicalize(q2)

    # Insert the interpolation axis so the quaternions broadcast against t.
    q1_expanded = xp.expand_dims(q1, axis=-2)
    q2_expanded = xp.expand_dims(q2, axis=-2)

    dot_product = xp.sum(q1_expanded * q2_expanded, axis=-1, keepdims=True)

    # q and -q are the same rotation; use the representative closer to q1.
    q2_corrected = xp.where(dot_product < 0, -q2_expanded, q2_expanded)
    dot_product = xp.where(dot_product < 0, -dot_product, dot_product)

    dot_product = xp.clip(dot_product, 0.0, 1.0)

    threshold = 0.9995

    t_expanded = xp.expand_dims(t, axis=-1)

    use_linear = dot_product > threshold

    linear_result = (1.0 - t_expanded) * q1_expanded + t_expanded * q2_corrected
    linear_norm = xp.sqrt(xp.sum(linear_result**2, axis=-1, keepdims=True))
    linear_result = linear_result / linear_norm

    omega = xp.acos(dot_product)
    sin_omega = xp.sin(omega)

    # Guard the division in the branch that `where` discards. A fixed 1e-8
    # rounds to 0 in float16, so use at least the machine epsilon of the
    # working dtype. The spherical branch is only selected when sin(omega)
    # is far above either value, so selected results are unaffected.
    eps = max(1e-8, float(xp.finfo(sin_omega.dtype).eps))
    sin_omega_safe = xp.where(xp.abs(sin_omega) < eps, eps, sin_omega)

    weight1 = xp.sin((1.0 - t_expanded) * omega) / sin_omega_safe
    weight2 = xp.sin(t_expanded * omega) / sin_omega_safe

    spherical_result = weight1 * q1_expanded + weight2 * q2_corrected

    result = xp.where(use_linear, linear_result, spherical_result)

    return canonicalize(result)
@@ -0,0 +1,28 @@
1
+ from typing import Any
2
+
3
+ from jaxtyping import Float
4
+
5
+ from ..common import get_namespace
6
+
7
+
8
def vee(W: Float[Any, "... 3 3"]) -> Float[Any, "... 3"]:
    """Vee operator: read the 3-vector out of skew-symmetric 3x3 matrices.

    For a matrix laid out as
        [[  0, -w3,  w2],
         [ w3,   0, -w1],
         [-w2,  w1,   0]]
    the entries at (2, 1), (0, 2) and (1, 0) are returned as (w1, w2, w3).

    Args:
        W: (..., 3, 3) skew-symmetric matrices

    Returns:
        (..., 3) tangent vectors in so(3)
    """
    xp = get_namespace(W)

    positions = ((2, 1), (0, 2), (1, 0))
    return xp.stack([W[..., r, c] for r, c in positions], axis=-1)
@@ -0,0 +1,89 @@
1
+ from typing import Any, Sequence
2
+
3
+ from jaxtyping import Float
4
+
5
+ from nanomanifold.common import get_namespace
6
+
7
+ from .canonicalize import canonicalize
8
+
9
+
10
def weighted_mean(quaternions: Sequence[Float[Any, "... 4"]], weights: Float[Any, "... N"]) -> Float[Any, "... 4"]:
    """Compute the weighted mean of SO(3) rotations represented as quaternions.

    The average is obtained with the quaternion outer-product method: the
    weighted sum of outer products ``q q^T`` is formed, and the eigenvector
    belonging to its largest eigenvalue is returned.

    Args:
        quaternions: Non-empty sequence of quaternions in [w, x, y, z] format.
            Each quaternion should have shape [..., 4] where the last
            dimension contains the quaternion components.
        weights: Array of weights with shape [..., N] where N is the number of quaternions.
            The weights are normalized internally.

    Returns:
        Weighted mean quaternion with shape [..., 4] in [w, x, y, z] format,
        passed through ``canonicalize``.

    Raises:
        ValueError: If ``quaternions`` is empty.

    Note:
        This implementation follows the algorithm from:
        "Averaging Quaternions" by F. Landis Markley et al.
    """
    # Fail with the same exception type as mean() instead of an IndexError
    # from quaternions[0] below.
    if len(quaternions) == 0:
        raise ValueError("Cannot compute weighted mean of empty quaternion sequence")

    xp = get_namespace(quaternions[0])
    original_dtype = quaternions[0].dtype

    # Stack the N quaternions on a new second-to-last axis: (..., N, 4).
    quats = xp.stack([xp.asarray(q, dtype=original_dtype) for q in quaternions], axis=-2)
    weights_array = xp.asarray(weights, dtype=original_dtype)

    # Normalize each quaternion, guarding against (near-)zero norms.
    norms = xp.linalg.norm(quats, axis=-1, keepdims=True)
    eps = xp.finfo(original_dtype).eps * 10
    safe_norms = xp.where(norms < eps, eps, norms)
    quats_normalized = quats / safe_norms

    # Flip quaternions with negative w onto the w >= 0 hemisphere.
    sign_mask = quats_normalized[..., 0:1] < 0
    quats_canonical = xp.where(sign_mask, -quats_normalized, quats_normalized)

    weights_normalized = weights_array / xp.sum(weights_array, axis=-1, keepdims=True)

    weighted_quats = weights_normalized[..., :, None] * quats_canonical

    # M = sum_n w_n * q_n q_n^T, a symmetric 4x4 matrix per batch element.
    M = xp.einsum("...nj,...nk->...jk", weighted_quats, quats_canonical)

    if original_dtype == xp.float16:
        # Run the eigendecomposition in float32 and cast the vectors back.
        M_compute = xp.asarray(M, dtype=xp.float32)
        _, eigenvectors = xp.linalg.eigh(M_compute)
        eigenvectors = xp.asarray(eigenvectors, dtype=original_dtype)
    else:
        _, eigenvectors = xp.linalg.eigh(M)

    # Taking the last column relies on eigh returning eigenvalues in
    # ascending order, so the last eigenvector pairs with the largest one.
    avg_quat = eigenvectors[..., :, -1]

    avg_quat = avg_quat / xp.linalg.norm(avg_quat, axis=-1, keepdims=True)

    return canonicalize(avg_quat)
64
+
65
+
66
def mean(quaternions: Sequence[Float[Any, "... 4"]]) -> Float[Any, "... 4"]:
    """Compute the unweighted mean of SO(3) rotations given as quaternions.

    Builds an all-ones weight array matching the batch shape of the first
    quaternion and delegates to ``weighted_mean``.

    Args:
        quaternions: Sequence of quaternions in [w, x, y, z] format. Each quaternion
            should have shape [..., 4] where the last dimension contains the
            quaternion components.

    Returns:
        Mean quaternion with shape [..., 4] in [w, x, y, z] format, as
        returned by ``weighted_mean``.

    Raises:
        ValueError: If ``quaternions`` is empty.
    """
    count = len(quaternions)
    if count == 0:
        raise ValueError("Cannot compute mean of empty quaternion sequence")

    reference = quaternions[0]
    xp = get_namespace(reference)

    uniform = xp.ones(reference.shape[:-1] + (count,), dtype=reference.dtype)
    return weighted_mean(quaternions, uniform)
@@ -0,0 +1,3 @@
1
+ from . import SE3, SO3
2
+
3
+ __all__ = ["SO3", "SE3"]