FastQuat 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastquat/__init__.py +3 -0
- fastquat/quaternion.py +586 -0
- fastquat-0.2.dist-info/METADATA +127 -0
- fastquat-0.2.dist-info/RECORD +5 -0
- fastquat-0.2.dist-info/WHEEL +4 -0
fastquat/__init__.py
ADDED
fastquat/quaternion.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
from jax import Array
|
|
8
|
+
from jax.tree_util import register_pytree_node_class
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@register_pytree_node_class
|
|
12
|
+
class Quaternion:
|
|
13
|
+
"""Class for manipulating quaternion tensors with JAX.
|
|
14
|
+
|
|
15
|
+
A quaternion is represented by [w, x, y, z] where w is the scalar part
|
|
16
|
+
and (x, y, z) is the vector part.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
w: float | jnp.ndarray,
|
|
22
|
+
x: float | jnp.ndarray,
|
|
23
|
+
y: float | jnp.ndarray,
|
|
24
|
+
z: float | jnp.ndarray,
|
|
25
|
+
) -> None:
|
|
26
|
+
"""Initialize a tensor of quaternions.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
w, x, y, z: components of the quaternions.
|
|
30
|
+
"""
|
|
31
|
+
w, x, y, z = jnp.broadcast_arrays(w, x, y, z)
|
|
32
|
+
self.wxyz = jnp.stack([w, x, y, z], axis=-1)
|
|
33
|
+
|
|
34
|
+
def tree_flatten(self) -> tuple[tuple[Any, ...], Any]:
|
|
35
|
+
"""Flatten the Quaternion PyTree."""
|
|
36
|
+
return (self.wxyz,), None
|
|
37
|
+
|
|
38
|
+
@classmethod
|
|
39
|
+
def tree_unflatten(cls, aux_data, children) -> Quaternion:
|
|
40
|
+
"""Unflatten The Quaternion PyTree"""
|
|
41
|
+
# Create an instance directly without going through from_array to avoid tracer issues
|
|
42
|
+
instance = cls.__new__(cls)
|
|
43
|
+
instance.wxyz = children[0]
|
|
44
|
+
return instance
|
|
45
|
+
|
|
46
|
+
@classmethod
|
|
47
|
+
def from_array(cls, array: Array) -> Quaternion:
|
|
48
|
+
"""Create a Quaternion array from a numeric array of shape (..., 4).
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
array: array of shape (..., 4) where the last dimension is [w, x, y, z]
|
|
52
|
+
"""
|
|
53
|
+
# Handle JAX tracers and arrays properly
|
|
54
|
+
if not isinstance(array, jnp.ndarray):
|
|
55
|
+
array = jnp.asarray(array)
|
|
56
|
+
|
|
57
|
+
if array.shape[-1:] != (4,):
|
|
58
|
+
raise ValueError(f'Array must have shape (..., 4), got {array.shape}')
|
|
59
|
+
|
|
60
|
+
instance = cls.__new__(cls)
|
|
61
|
+
instance.wxyz = array
|
|
62
|
+
return instance
|
|
63
|
+
|
|
64
|
+
@classmethod
|
|
65
|
+
def from_scalar_vector(cls, scalar: Array, vector: Array) -> Quaternion:
|
|
66
|
+
"""Create a quaternion from scalar and vector parts.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
scalar: Array of shape (...,) for the scalar part.
|
|
70
|
+
vector: Array of shape (..., 3) for the vector part.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Quaternion
|
|
74
|
+
"""
|
|
75
|
+
if vector.shape[-1:] != (3,):
|
|
76
|
+
raise ValueError(f'Vector must have shape (..., 3), got {vector.shape}')
|
|
77
|
+
scalar = jnp.expand_dims(scalar, axis=-1)
|
|
78
|
+
return cls.from_array(jnp.concatenate([scalar, vector], axis=-1))
|
|
79
|
+
|
|
80
|
+
@classmethod
|
|
81
|
+
def from_rotation_matrix(cls, rot: Array) -> Quaternion:
|
|
82
|
+
"""Create the quaternion associated to a rotation matrix.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
rot: Array of shape (..., 3, 3) representing the rotation matrix
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
The normalized Quaternion tensor representing the rotation matrix.
|
|
89
|
+
"""
|
|
90
|
+
if rot.shape[-2:] != (3, 3):
|
|
91
|
+
raise ValueError(f'Rotation matrix must have shape (..., 3, 3), got {rot.shape}')
|
|
92
|
+
|
|
93
|
+
# Implémentation de la conversion matrice -> quaternion
|
|
94
|
+
trace = jnp.trace(rot, axis1=-2, axis2=-1)
|
|
95
|
+
|
|
96
|
+
# Cas où trace > 0
|
|
97
|
+
s = jnp.sqrt(trace + 1.0) * 2 # s = 4 * w
|
|
98
|
+
w = 0.25 * s
|
|
99
|
+
x = (rot[..., 2, 1] - rot[..., 1, 2]) / s
|
|
100
|
+
y = (rot[..., 0, 2] - rot[..., 2, 0]) / s
|
|
101
|
+
z = (rot[..., 1, 0] - rot[..., 0, 1]) / s
|
|
102
|
+
|
|
103
|
+
return cls.from_array(jnp.stack([w, x, y, z], axis=-1))
|
|
104
|
+
|
|
105
|
+
@classmethod
|
|
106
|
+
def zeros(cls, shape: tuple[int, ...] = (), dtype: jnp.dtype = jnp.float32) -> Quaternion:
|
|
107
|
+
"""Create quaternions with all components set to 0.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
shape: Shape of the tensor (without the last dimension).
|
|
111
|
+
dtype: Data type of the quaternion components.
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
Quaternion with all components equal to 0.
|
|
115
|
+
"""
|
|
116
|
+
data = jnp.zeros(shape + (4,), dtype=dtype)
|
|
117
|
+
return cls.from_array(data)
|
|
118
|
+
|
|
119
|
+
@classmethod
|
|
120
|
+
def ones(cls, shape: tuple[int, ...] = (), dtype: jnp.dtype = jnp.float32) -> Quaternion:
|
|
121
|
+
"""Create quaternions with scalar component set to 1 and vector components set to 0.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
shape: Shape of the tensor (without the last dimension).
|
|
125
|
+
dtype: Data type of the quaternion components.
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
Quaternions with w=1 and x=y=z=0.
|
|
129
|
+
"""
|
|
130
|
+
data = jnp.zeros(shape + (4,), dtype=dtype)
|
|
131
|
+
data = data.at[..., 0].set(1.0)
|
|
132
|
+
return cls.from_array(data)
|
|
133
|
+
|
|
134
|
+
@classmethod
|
|
135
|
+
def full(
|
|
136
|
+
cls, shape: tuple[int, ...], fill_value: float, dtype: jnp.dtype = jnp.float32
|
|
137
|
+
) -> Quaternion:
|
|
138
|
+
"""Create quaternions with scalar component set to a value and vector components set to 0.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
shape: Shape of the tensor (without the last dimension).
|
|
142
|
+
fill_value: Value to fill the scalar component with.
|
|
143
|
+
dtype: Data type of the quaternion components.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
Quaternions with w=fill_value and x=y=z=0.
|
|
147
|
+
"""
|
|
148
|
+
data = jnp.zeros(shape + (4,), dtype=dtype)
|
|
149
|
+
data = data.at[..., 0].set(fill_value)
|
|
150
|
+
return cls.from_array(data)
|
|
151
|
+
|
|
152
|
+
@classmethod
|
|
153
|
+
def random(cls, key: jax.random.PRNGKey, shape: tuple[int, ...] = ()) -> Quaternion:
|
|
154
|
+
"""Generate normalized random quaternions.
|
|
155
|
+
|
|
156
|
+
Args:
|
|
157
|
+
key: Key PRNG.
|
|
158
|
+
shape: Shape of the tensor (without the last dimension).
|
|
159
|
+
|
|
160
|
+
Returns:
|
|
161
|
+
Normalized Quaternion.
|
|
162
|
+
"""
|
|
163
|
+
data = jax.random.normal(key, shape + (4,))
|
|
164
|
+
return Quaternion.from_array(data).normalize()
|
|
165
|
+
|
|
166
|
+
@property
|
|
167
|
+
def w(self) -> Array:
|
|
168
|
+
return self.wxyz[..., 0]
|
|
169
|
+
|
|
170
|
+
@property
|
|
171
|
+
def x(self) -> Array:
|
|
172
|
+
return self.wxyz[..., 1]
|
|
173
|
+
|
|
174
|
+
@property
|
|
175
|
+
def y(self) -> Array:
|
|
176
|
+
return self.wxyz[..., 2]
|
|
177
|
+
|
|
178
|
+
@property
|
|
179
|
+
def z(self) -> Array:
|
|
180
|
+
return self.wxyz[..., 3]
|
|
181
|
+
|
|
182
|
+
@property
|
|
183
|
+
def vector(self) -> Array:
|
|
184
|
+
"""Vector part (..., 3)"""
|
|
185
|
+
return self.wxyz[..., 1:]
|
|
186
|
+
|
|
187
|
+
def norm(self) -> Array:
|
|
188
|
+
"""Quaternion norm."""
|
|
189
|
+
return jnp.sqrt(jnp.sum(self.wxyz**2, axis=-1))
|
|
190
|
+
|
|
191
|
+
def normalize(self) -> Quaternion:
|
|
192
|
+
"""Normalize the quaternion.
|
|
193
|
+
|
|
194
|
+
Returns the normalized quaternion. If the quaternion has zero norm,
|
|
195
|
+
returns the zero quaternion [0, 0, 0, 0].
|
|
196
|
+
"""
|
|
197
|
+
norm = self.norm()
|
|
198
|
+
# Avoid division by zero
|
|
199
|
+
safe_norm = jnp.where(norm == 0, 1.0, norm)
|
|
200
|
+
return Quaternion.from_array(self.wxyz / jnp.expand_dims(safe_norm, axis=-1))
|
|
201
|
+
|
|
202
|
+
def _inverse(self) -> Quaternion:
|
|
203
|
+
"""Quaternion inverse (private method - use 1/q instead)."""
|
|
204
|
+
conj = self.conj()
|
|
205
|
+
norm_sq = self.norm() ** 2
|
|
206
|
+
return Quaternion.from_array(conj.wxyz / jnp.expand_dims(norm_sq, axis=-1))
|
|
207
|
+
|
|
208
|
+
def to_components(self) -> tuple[Array, Array, Array, Array]:
|
|
209
|
+
return self.w, self.x, self.y, self.z
|
|
210
|
+
|
|
211
|
+
def to_rotation_matrix(self) -> Array:
|
|
212
|
+
"""Convert quaternion to rotation matrix.
|
|
213
|
+
|
|
214
|
+
Returns:
|
|
215
|
+
Array of shape (..., 3, 3)
|
|
216
|
+
"""
|
|
217
|
+
# Normalize the quaternion
|
|
218
|
+
q = self.normalize()
|
|
219
|
+
w, x, y, z = q.to_components()
|
|
220
|
+
|
|
221
|
+
# Calculate matrix elements
|
|
222
|
+
xx, yy, zz = x * x, y * y, z * z
|
|
223
|
+
xy, xz, yz = x * y, x * z, y * z
|
|
224
|
+
wx, wy, wz = w * x, w * y, w * z
|
|
225
|
+
|
|
226
|
+
rot = jnp.stack(
|
|
227
|
+
[
|
|
228
|
+
jnp.stack([1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], axis=-1),
|
|
229
|
+
jnp.stack([2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], axis=-1),
|
|
230
|
+
jnp.stack([2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], axis=-1),
|
|
231
|
+
],
|
|
232
|
+
axis=-2,
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
return rot
|
|
236
|
+
|
|
237
|
+
def rotate_vector(self, v: Array) -> Array:
|
|
238
|
+
"""Apply quaternion rotation to a vector.
|
|
239
|
+
|
|
240
|
+
Args:
|
|
241
|
+
v: Array of shape (..., 3) representing vectors
|
|
242
|
+
|
|
243
|
+
Returns:
|
|
244
|
+
Array of shape (..., 3) representing rotated vectors
|
|
245
|
+
"""
|
|
246
|
+
# Convert vector to pure quaternion
|
|
247
|
+
v_quat = Quaternion(0, v[..., 0], v[..., 1], v[..., 2])
|
|
248
|
+
# Apply rotation: q * v * q^-1
|
|
249
|
+
result = self * v_quat * self._inverse()
|
|
250
|
+
|
|
251
|
+
return result.vector
|
|
252
|
+
|
|
253
|
+
def __repr__(self) -> str:
|
|
254
|
+
if self.shape == ():
|
|
255
|
+
w, x, y, z = self.wxyz
|
|
256
|
+
return f'{w} + {x}i + {y}j + {z}k'
|
|
257
|
+
return f'Quaternion(shape={self.shape}, dtype={self.dtype})'
|
|
258
|
+
|
|
259
|
+
#######################
|
|
260
|
+
# JAX array interface #
|
|
261
|
+
#######################
|
|
262
|
+
|
|
263
|
+
def __len__(self):
|
|
264
|
+
"""Length of the first axis."""
|
|
265
|
+
if self.ndim == 0:
|
|
266
|
+
raise TypeError('len() of unsized object')
|
|
267
|
+
return self.shape[0]
|
|
268
|
+
|
|
269
|
+
def __iter__(self):
|
|
270
|
+
"""Iterate over the first axis."""
|
|
271
|
+
if self.ndim == 0:
|
|
272
|
+
raise TypeError('iteration over a 0-d quaternion')
|
|
273
|
+
for i in range(self.shape[0]):
|
|
274
|
+
yield Quaternion.from_array(self.wxyz[i])
|
|
275
|
+
|
|
276
|
+
def __add__(self, other: Any) -> Quaternion:
|
|
277
|
+
"""Quaternion addition."""
|
|
278
|
+
if isinstance(other, Quaternion):
|
|
279
|
+
return Quaternion.from_array(self.wxyz + other.wxyz)
|
|
280
|
+
|
|
281
|
+
if isinstance(other, int | float | jnp.ndarray):
|
|
282
|
+
return Quaternion.from_scalar_vector(self.w + other, self.vector)
|
|
283
|
+
|
|
284
|
+
raise NotImplementedError
|
|
285
|
+
|
|
286
|
+
def __radd__(self, other: Any) -> Quaternion:
|
|
287
|
+
"""Quaternion addition."""
|
|
288
|
+
return self.__add__(other)
|
|
289
|
+
|
|
290
|
+
def __sub__(self, other: Any) -> Quaternion:
|
|
291
|
+
"""Quaternion subtraction."""
|
|
292
|
+
if isinstance(other, Quaternion):
|
|
293
|
+
return Quaternion.from_array(self.wxyz - other.wxyz)
|
|
294
|
+
|
|
295
|
+
if isinstance(other, int | float | jnp.ndarray):
|
|
296
|
+
return Quaternion.from_scalar_vector(self.w - other, self.vector)
|
|
297
|
+
|
|
298
|
+
raise NotImplementedError
|
|
299
|
+
|
|
300
|
+
def __rsub__(self, other: Any) -> Quaternion:
|
|
301
|
+
"""Quaternion subtraction."""
|
|
302
|
+
if isinstance(other, int | float | jnp.ndarray):
|
|
303
|
+
return Quaternion.from_scalar_vector(other - self.w, -self.vector)
|
|
304
|
+
|
|
305
|
+
raise NotImplementedError
|
|
306
|
+
|
|
307
|
+
def __mul__(self, other: Any) -> Quaternion:
|
|
308
|
+
"""Quaternion multiplication."""
|
|
309
|
+
if isinstance(other, Quaternion):
|
|
310
|
+
w1, x1, y1, z1 = self.to_components()
|
|
311
|
+
w2, x2, y2, z2 = other.to_components()
|
|
312
|
+
|
|
313
|
+
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
|
|
314
|
+
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
|
|
315
|
+
y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
|
|
316
|
+
z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
|
|
317
|
+
|
|
318
|
+
return Quaternion(w, x, y, z)
|
|
319
|
+
|
|
320
|
+
if isinstance(other, int | float):
|
|
321
|
+
return Quaternion.from_array(self.wxyz * other)
|
|
322
|
+
|
|
323
|
+
if isinstance(other, jnp.ndarray):
|
|
324
|
+
return Quaternion.from_array(self.wxyz * jnp.expand_dims(other, axis=-1))
|
|
325
|
+
|
|
326
|
+
return NotImplemented
|
|
327
|
+
|
|
328
|
+
def __rmul__(self, other: Any) -> Quaternion:
|
|
329
|
+
"""Quaternion multiplication."""
|
|
330
|
+
if isinstance(other, int | float):
|
|
331
|
+
return Quaternion.from_array(other * self.wxyz)
|
|
332
|
+
|
|
333
|
+
if isinstance(other, jnp.ndarray):
|
|
334
|
+
return Quaternion.from_array(jnp.expand_dims(other, axis=-1) * self.wxyz)
|
|
335
|
+
|
|
336
|
+
return NotImplemented
|
|
337
|
+
|
|
338
|
+
def __truediv__(self, other: Any) -> Quaternion:
|
|
339
|
+
"""Quaternion division."""
|
|
340
|
+
if isinstance(other, Quaternion):
|
|
341
|
+
return self * other._inverse()
|
|
342
|
+
|
|
343
|
+
if isinstance(other, int | float):
|
|
344
|
+
return Quaternion.from_array(self.wxyz / other)
|
|
345
|
+
|
|
346
|
+
if isinstance(other, jnp.ndarray):
|
|
347
|
+
return Quaternion.from_array(self.wxyz / jnp.expand_dims(other, axis=-1))
|
|
348
|
+
|
|
349
|
+
return NotImplemented
|
|
350
|
+
|
|
351
|
+
def __rtruediv__(self, other: Any) -> Quaternion:
|
|
352
|
+
"""Quaternion division."""
|
|
353
|
+
if isinstance(other, int | float) and other == 1:
|
|
354
|
+
return self._inverse()
|
|
355
|
+
|
|
356
|
+
if isinstance(other, int | float | jnp.ndarray):
|
|
357
|
+
return other * self._inverse()
|
|
358
|
+
|
|
359
|
+
return NotImplemented
|
|
360
|
+
|
|
361
|
+
def __neg__(self) -> Quaternion:
|
|
362
|
+
"""Quaternion negation."""
|
|
363
|
+
return Quaternion.from_array(-self.wxyz)
|
|
364
|
+
|
|
365
|
+
def log(self) -> Quaternion:
|
|
366
|
+
"""Compute quaternion logarithm.
|
|
367
|
+
|
|
368
|
+
For a quaternion q = |q| * (cos(θ) + sin(θ)v), the logarithm is:
|
|
369
|
+
log(q) = log(|q|) + θ * v
|
|
370
|
+
|
|
371
|
+
Returns:
|
|
372
|
+
The logarithm of the quaternion
|
|
373
|
+
"""
|
|
374
|
+
# Get norm and handle zero quaternion
|
|
375
|
+
q_norm = self.norm()
|
|
376
|
+
|
|
377
|
+
# Normalize to get unit quaternion (handles zero quaternions safely)
|
|
378
|
+
unit_q = self.normalize()
|
|
379
|
+
|
|
380
|
+
# For unit quaternion q = cos(θ) + sin(θ)v, compute θ and v
|
|
381
|
+
# θ = arccos(w) and v = vector/|vector|
|
|
382
|
+
theta = jnp.arccos(jnp.clip(unit_q.w, -1.0, 1.0))
|
|
383
|
+
vector_norm = jnp.linalg.norm(unit_q.vector, axis=-1)
|
|
384
|
+
|
|
385
|
+
# Handle case where vector is zero (real quaternion)
|
|
386
|
+
safe_vector_norm = jnp.where(vector_norm == 0, 1.0, vector_norm)
|
|
387
|
+
unit_vector = unit_q.vector / jnp.expand_dims(safe_vector_norm, -1)
|
|
388
|
+
|
|
389
|
+
# log(q) = log(|q|) + θ * v
|
|
390
|
+
log_norm = jnp.log(jnp.maximum(q_norm, 1e-10)) # Avoid log(0)
|
|
391
|
+
theta_expanded = jnp.expand_dims(theta, -1)
|
|
392
|
+
log_q_vector = theta_expanded * unit_vector
|
|
393
|
+
|
|
394
|
+
# Handle zero vector case (real quaternion)
|
|
395
|
+
log_q_vector = jnp.where(
|
|
396
|
+
jnp.expand_dims(vector_norm == 0, -1), jnp.zeros_like(log_q_vector), log_q_vector
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
return Quaternion.from_scalar_vector(log_norm, log_q_vector)
|
|
400
|
+
|
|
401
|
+
def exp(self) -> Quaternion:
|
|
402
|
+
"""Compute quaternion exponential.
|
|
403
|
+
|
|
404
|
+
For a quaternion q = s + v, the exponential is:
|
|
405
|
+
exp(q) = exp(s) * (cos(|v|) + sin(|v|) * v/|v|)
|
|
406
|
+
|
|
407
|
+
Returns:
|
|
408
|
+
The exponential of the quaternion
|
|
409
|
+
"""
|
|
410
|
+
scalar_part = self.w
|
|
411
|
+
vector_part = self.vector
|
|
412
|
+
vector_norm = jnp.linalg.norm(vector_part, axis=-1)
|
|
413
|
+
|
|
414
|
+
# exp(s + v) = exp(s) * (cos(|v|) + sin(|v|) * v/|v|)
|
|
415
|
+
exp_scalar = jnp.exp(scalar_part)
|
|
416
|
+
cos_vnorm = jnp.cos(vector_norm)
|
|
417
|
+
sin_vnorm = jnp.sin(vector_norm)
|
|
418
|
+
|
|
419
|
+
# Handle case where |v| = 0 (real quaternion)
|
|
420
|
+
safe_vector_norm = jnp.where(vector_norm == 0, 1.0, vector_norm)
|
|
421
|
+
unit_v = vector_part / jnp.expand_dims(safe_vector_norm, -1)
|
|
422
|
+
|
|
423
|
+
result_w = exp_scalar * cos_vnorm
|
|
424
|
+
result_vector = exp_scalar * jnp.expand_dims(sin_vnorm, -1) * unit_v
|
|
425
|
+
|
|
426
|
+
# Handle zero vector case
|
|
427
|
+
result_vector = jnp.where(
|
|
428
|
+
jnp.expand_dims(vector_norm == 0, -1), jnp.zeros_like(result_vector), result_vector
|
|
429
|
+
)
|
|
430
|
+
|
|
431
|
+
return Quaternion.from_scalar_vector(result_w, result_vector)
|
|
432
|
+
|
|
433
|
+
def __pow__(self, exponent: float | int | Array) -> Quaternion:
|
|
434
|
+
"""Quaternion exponentiation q^n.
|
|
435
|
+
|
|
436
|
+
For integer exponents, uses optimized special cases.
|
|
437
|
+
For non-integer exponents, uses the general formula: q^n = exp(n * log(q))
|
|
438
|
+
|
|
439
|
+
Args:
|
|
440
|
+
exponent: The exponent (scalar or array)
|
|
441
|
+
|
|
442
|
+
Returns:
|
|
443
|
+
The quaternion raised to the given power
|
|
444
|
+
"""
|
|
445
|
+
# Handle special cases for integer exponents only
|
|
446
|
+
if isinstance(exponent, int):
|
|
447
|
+
if exponent == -2:
|
|
448
|
+
q_inv = self._inverse()
|
|
449
|
+
return q_inv * q_inv
|
|
450
|
+
elif exponent == -1:
|
|
451
|
+
return self._inverse()
|
|
452
|
+
elif exponent == 0:
|
|
453
|
+
return Quaternion.ones(self.shape, self.dtype)
|
|
454
|
+
elif exponent == 1:
|
|
455
|
+
return self
|
|
456
|
+
elif exponent == 2:
|
|
457
|
+
return self * self
|
|
458
|
+
|
|
459
|
+
# General case: q^n = exp(n * log(q))
|
|
460
|
+
log_q = self.log()
|
|
461
|
+
n_log_q = exponent * log_q
|
|
462
|
+
return n_log_q.exp()
|
|
463
|
+
|
|
464
|
+
@property
|
|
465
|
+
def nbytes(self) -> int:
|
|
466
|
+
"""Number of bytes in the tensor."""
|
|
467
|
+
return self.wxyz.nbytes
|
|
468
|
+
|
|
469
|
+
@property
|
|
470
|
+
def itemsize(self) -> int:
|
|
471
|
+
"""Size of one quaternion element in bytes."""
|
|
472
|
+
return self.wxyz.itemsize * 4
|
|
473
|
+
|
|
474
|
+
@property
|
|
475
|
+
def shape(self) -> tuple[int, ...]:
|
|
476
|
+
"""Shape of the tensor."""
|
|
477
|
+
return self.wxyz.shape[:-1]
|
|
478
|
+
|
|
479
|
+
@property
|
|
480
|
+
def ndim(self):
|
|
481
|
+
"""Number of dimensions of the quaternion tensor (without the quaternion dimension)."""
|
|
482
|
+
return self.wxyz.ndim - 1
|
|
483
|
+
|
|
484
|
+
@property
|
|
485
|
+
def size(self):
|
|
486
|
+
"""Total number of quaternions."""
|
|
487
|
+
return self.wxyz.size >> 2
|
|
488
|
+
|
|
489
|
+
@property
|
|
490
|
+
def dtype(self) -> jnp.dtype:
|
|
491
|
+
"""Data type."""
|
|
492
|
+
return self.wxyz.dtype
|
|
493
|
+
|
|
494
|
+
def reshape(self, *shape) -> Quaternion:
|
|
495
|
+
"""Redimensionne le tableau de quaternions"""
|
|
496
|
+
if len(shape) == 0:
|
|
497
|
+
raise ValueError('Must specify at least one dimension')
|
|
498
|
+
if isinstance(shape[0], tuple):
|
|
499
|
+
if len(shape) > 1:
|
|
500
|
+
raise ValueError('Cannot specify more than one shape')
|
|
501
|
+
shape = shape[0]
|
|
502
|
+
new_shape = shape + (4,)
|
|
503
|
+
return self.from_array(self.wxyz.reshape(new_shape))
|
|
504
|
+
|
|
505
|
+
def flatten(self) -> Quaternion:
|
|
506
|
+
"""Aplatis le tableau de quaternions"""
|
|
507
|
+
return self.from_array(self.wxyz.reshape(-1, 4))
|
|
508
|
+
|
|
509
|
+
def ravel(self) -> Quaternion:
|
|
510
|
+
"""Aplatis le tableau de quaternions"""
|
|
511
|
+
return self.flatten()
|
|
512
|
+
|
|
513
|
+
def squeeze(self, axis=None) -> Quaternion:
|
|
514
|
+
"""Supprime les dimensions de taille 1"""
|
|
515
|
+
return Quaternion.from_array(jnp.squeeze(self.wxyz, axis=axis))
|
|
516
|
+
|
|
517
|
+
def conjugate(self) -> Quaternion:
|
|
518
|
+
"""Quaternion conjugate."""
|
|
519
|
+
sign = jnp.array([1, -1, -1, -1])
|
|
520
|
+
return Quaternion.from_array(self.wxyz * sign)
|
|
521
|
+
|
|
522
|
+
def conj(self) -> Quaternion:
|
|
523
|
+
"""Quaternion conjugate."""
|
|
524
|
+
return self.conjugate()
|
|
525
|
+
|
|
526
|
+
def block_until_ready(self) -> None:
|
|
527
|
+
"""Block until all pending computations are done."""
|
|
528
|
+
self.wxyz.block_until_ready()
|
|
529
|
+
|
|
530
|
+
@property
|
|
531
|
+
def device(self) -> jax.Device[Any]:
|
|
532
|
+
return self.wxyz.device
|
|
533
|
+
|
|
534
|
+
def devices(self) -> set[jax.Device[Any]]:
|
|
535
|
+
return self.wxyz.devices()
|
|
536
|
+
|
|
537
|
+
def slerp(self, other: Quaternion, t: float | Array) -> Quaternion:
|
|
538
|
+
"""Spherical linear interpolation between two quaternions.
|
|
539
|
+
|
|
540
|
+
Args:
|
|
541
|
+
other: Target quaternion to interpolate towards
|
|
542
|
+
t: Interpolation parameter in [0, 1]. t=0 returns self, t=1 returns other
|
|
543
|
+
|
|
544
|
+
Returns:
|
|
545
|
+
Interpolated quaternion
|
|
546
|
+
"""
|
|
547
|
+
# Ensure both quaternions are normalized
|
|
548
|
+
q1 = self.normalize()
|
|
549
|
+
q2 = other.normalize()
|
|
550
|
+
|
|
551
|
+
# Compute dot product
|
|
552
|
+
dot = jnp.sum(q1.wxyz * q2.wxyz, axis=-1)
|
|
553
|
+
|
|
554
|
+
# If dot product is negative, slerp won't take the shorter path.
|
|
555
|
+
# Note that this is necessary to handle the double cover of SO(3)
|
|
556
|
+
# by unit quaternions: q and -q represent the same rotation.
|
|
557
|
+
q2_corrected = jnp.where(jnp.expand_dims(dot < 0, -1), -q2.wxyz, q2.wxyz)
|
|
558
|
+
dot = jnp.abs(dot)
|
|
559
|
+
|
|
560
|
+
# If quaternions are very close, use linear interpolation to avoid numerical issues
|
|
561
|
+
threshold = 0.9995
|
|
562
|
+
use_linear = dot > threshold
|
|
563
|
+
|
|
564
|
+
# Linear interpolation case
|
|
565
|
+
result_linear = q1.wxyz + jnp.expand_dims(t * (1 - t), -1) * (q2_corrected - q1.wxyz)
|
|
566
|
+
result_linear = Quaternion.from_array(result_linear).normalize()
|
|
567
|
+
|
|
568
|
+
# Spherical interpolation case
|
|
569
|
+
theta = jnp.arccos(jnp.clip(dot, 0.0, 1.0))
|
|
570
|
+
sin_theta = jnp.sin(theta)
|
|
571
|
+
|
|
572
|
+
# Avoid division by zero
|
|
573
|
+
safe_sin_theta = jnp.where(sin_theta == 0, 1.0, sin_theta)
|
|
574
|
+
|
|
575
|
+
factor1 = jnp.sin((1 - t) * theta) / safe_sin_theta
|
|
576
|
+
factor2 = jnp.sin(t * theta) / safe_sin_theta
|
|
577
|
+
|
|
578
|
+
result_slerp = (
|
|
579
|
+
jnp.expand_dims(factor1, -1) * q1.wxyz + jnp.expand_dims(factor2, -1) * q2_corrected
|
|
580
|
+
)
|
|
581
|
+
result_slerp = Quaternion.from_array(result_slerp)
|
|
582
|
+
|
|
583
|
+
# Choose between linear and spherical interpolation
|
|
584
|
+
result = jnp.where(jnp.expand_dims(use_linear, -1), result_linear.wxyz, result_slerp.wxyz)
|
|
585
|
+
|
|
586
|
+
return Quaternion.from_array(result)
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: FastQuat
|
|
3
|
+
Version: 0.2
|
|
4
|
+
Summary: High-performance quaternions with JAX support
|
|
5
|
+
Author-email: Pierre Chanial <chanial@apc.in2p3.fr>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Requires-Python: >=3.10
|
|
8
|
+
Requires-Dist: jax[cuda12]>=0.4.0
|
|
9
|
+
Requires-Dist: jaxlib>=0.4.0
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
12
|
+
# FastQuat - High-Performance Quaternions with JAX
|
|
13
|
+
|
|
14
|
+
[](https://pypi.org/project/fastquat/)
|
|
15
|
+
[](https://pypi.org/project/fastquat/)
|
|
16
|
+
[](https://github.com/CMBSciPol/fastquat/actions)
|
|
17
|
+
|
|
18
|
+
FastQuat provides optimized quaternion operations with full JAX compatibility, featuring:
|
|
19
|
+
|
|
20
|
+
- 🚀 **Hardware-accelerated** computations (CPU/GPU/TPU)
|
|
21
|
+
- 🔄 **Automatic differentiation** support
|
|
22
|
+
- 🧩 **Seamless integration** with JAX transformations (`jit`, `grad`, `vmap`)
|
|
23
|
+
- 📦 **Efficient storage** using interleaved memory layout
|
|
24
|
+
|
|
25
|
+
## Installation
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
pip install fastquat
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
## Quick Start
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
import jax.numpy as jnp
|
|
35
|
+
from fastquat import Quaternion
|
|
36
|
+
|
|
37
|
+
# Create quaternions
|
|
38
|
+
q1 = Quaternion.ones() # Identity quaternion
|
|
39
|
+
q2 = Quaternion(0.7071, 0.7071, 0.0, 0.0) # 90° rotation around x-axis
|
|
40
|
+
|
|
41
|
+
# Quaternion operations
|
|
42
|
+
q3 = q1 * q2 # Multiplication
|
|
43
|
+
q_inv = q1.inverse() # Inverse
|
|
44
|
+
q_norm = q1.normalize() # Normalization
|
|
45
|
+
|
|
46
|
+
# Rotate vectors
|
|
47
|
+
vector = jnp.array([1.0, 0.0, 0.0])
|
|
48
|
+
rotated = q2.rotate_vector(vector)
|
|
49
|
+
|
|
50
|
+
# Spherical interpolation (SLERP)
|
|
51
|
+
interpolated = q1.slerp(q2, t=0.5) # Halfway between q1 and q2
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Features
|
|
55
|
+
|
|
56
|
+
### Core Operations
|
|
57
|
+
- **Quaternion arithmetic**: Addition, multiplication, conjugation, inverse
|
|
58
|
+
- **Normalization**: Efficient unit quaternion computation
|
|
59
|
+
- **Conversion**: To/from rotation matrices, Euler angles
|
|
60
|
+
- **Vector rotation**: Direct vector transformation
|
|
61
|
+
|
|
62
|
+
### Advanced Interpolation
|
|
63
|
+
- **SLERP (Spherical Linear Interpolation)**: Smooth rotation interpolation
|
|
64
|
+
- Automatically handles shortest path selection
|
|
65
|
+
- Numerically stable for close quaternions
|
|
66
|
+
- Supports batched operations and array-valued parameters
|
|
67
|
+
|
|
68
|
+
### JAX Integration
|
|
69
|
+
- **JIT compilation**: Compile quaternion operations for maximum performance
|
|
70
|
+
- **Automatic differentiation**: Compute gradients through quaternion operations
|
|
71
|
+
- **Vectorization**: Process batches of quaternions efficiently
|
|
72
|
+
- **Device support**: Run on CPU, GPU, or TPU
|
|
73
|
+
|
|
74
|
+
## Performance
|
|
75
|
+
|
|
76
|
+
FastQuat is optimized for high-performance computing:
|
|
77
|
+
- Memory-efficient interleaved storage
|
|
78
|
+
- SIMD-optimized operations on supported hardware
|
|
79
|
+
- Zero-copy integration with JAX arrays
|
|
80
|
+
- Minimal Python overhead through JIT compilation
|
|
81
|
+
|
|
82
|
+
## Examples
|
|
83
|
+
|
|
84
|
+
### Basic Usage
|
|
85
|
+
```python
|
|
86
|
+
import jax
|
|
87
|
+
import jax.numpy as jnp
|
|
88
|
+
from fastquat import Quaternion
|
|
89
|
+
|
|
90
|
+
# Create random quaternions
|
|
91
|
+
key = jax.random.PRNGKey(42)
|
|
92
|
+
q_batch = Quaternion.random(key, shape=(1000,))
|
|
93
|
+
|
|
94
|
+
# JIT-compiled batch operations
|
|
95
|
+
@jax.jit
|
|
96
|
+
def batch_rotate(quaternions, vectors):
|
|
97
|
+
return quaternions.rotate_vector(vectors)
|
|
98
|
+
|
|
99
|
+
vectors = jax.random.normal(key, (1000, 3))
|
|
100
|
+
rotated_batch = batch_rotate(q_batch, vectors)
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
### SLERP Animation
|
|
104
|
+
```python
|
|
105
|
+
# Smooth rotation interpolation
|
|
106
|
+
q_start = Quaternion.ones()
|
|
107
|
+
q_end = Quaternion.from_rotation_matrix(rotation_matrix)
|
|
108
|
+
|
|
109
|
+
# Generate smooth interpolation
|
|
110
|
+
t_values = jnp.linspace(0, 1, 100)
|
|
111
|
+
interpolated_rotations = q_start.slerp(q_end, t_values)
|
|
112
|
+
|
|
113
|
+
# Apply to object vertices for smooth animation
|
|
114
|
+
animated_vertices = interpolated_rotations.rotate_vector(object_vertices)
|
|
115
|
+
```
|
|
116
|
+
|
|
117
|
+
## Documentation
|
|
118
|
+
|
|
119
|
+
Full documentation is available at [fastquat.readthedocs.io](https://fastquat.readthedocs.io)
|
|
120
|
+
|
|
121
|
+
## Contributing
|
|
122
|
+
|
|
123
|
+
Contributions are welcome! Please see our [development guide](https://fastquat.readthedocs.io/en/latest/development.html) for details.
|
|
124
|
+
|
|
125
|
+
## License
|
|
126
|
+
|
|
127
|
+
MIT License - see [LICENSE](LICENSE) file for details.
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
fastquat/__init__.py,sha256=IrR5dmDWrjoxCiSv0q59P263OWZcZg1eegAExPpw9zo,61
|
|
2
|
+
fastquat/quaternion.py,sha256=jtQvQlIpBbA7zxBQBT3o7UQaHrp-4qPz--5zlefPBrg,19740
|
|
3
|
+
fastquat-0.2.dist-info/METADATA,sha256=oJm1wK5sDCwcpiK_918YMzFXenVLUWm0hasV8FR7e0E,3877
|
|
4
|
+
fastquat-0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
fastquat-0.2.dist-info/RECORD,,
|