FastQuat 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastquat/__init__.py ADDED
@@ -0,0 +1,3 @@
1
"""FastQuat: quaternion tensors manipulated with JAX."""

# Re-export the quaternion class at package level.
from .quaternion import Quaternion

# Names exported by `from fastquat import *`.
__all__ = ['Quaternion']
fastquat/quaternion.py ADDED
@@ -0,0 +1,586 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from jax import Array
8
+ from jax.tree_util import register_pytree_node_class
9
+
10
+
11
+ @register_pytree_node_class
12
+ class Quaternion:
13
+ """Class for manipulating quaternion tensors with JAX.
14
+
15
+ A quaternion is represented by [w, x, y, z] where w is the scalar part
16
+ and (x, y, z) is the vector part.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ w: float | jnp.ndarray,
22
+ x: float | jnp.ndarray,
23
+ y: float | jnp.ndarray,
24
+ z: float | jnp.ndarray,
25
+ ) -> None:
26
+ """Initialize a tensor of quaternions.
27
+
28
+ Args:
29
+ w, x, y, z: components of the quaternions.
30
+ """
31
+ w, x, y, z = jnp.broadcast_arrays(w, x, y, z)
32
+ self.wxyz = jnp.stack([w, x, y, z], axis=-1)
33
+
34
+ def tree_flatten(self) -> tuple[tuple[Any, ...], Any]:
35
+ """Flatten the Quaternion PyTree."""
36
+ return (self.wxyz,), None
37
+
38
+ @classmethod
39
+ def tree_unflatten(cls, aux_data, children) -> Quaternion:
40
+ """Unflatten The Quaternion PyTree"""
41
+ # Create an instance directly without going through from_array to avoid tracer issues
42
+ instance = cls.__new__(cls)
43
+ instance.wxyz = children[0]
44
+ return instance
45
+
46
+ @classmethod
47
+ def from_array(cls, array: Array) -> Quaternion:
48
+ """Create a Quaternion array from a numeric array of shape (..., 4).
49
+
50
+ Args:
51
+ array: array of shape (..., 4) where the last dimension is [w, x, y, z]
52
+ """
53
+ # Handle JAX tracers and arrays properly
54
+ if not isinstance(array, jnp.ndarray):
55
+ array = jnp.asarray(array)
56
+
57
+ if array.shape[-1:] != (4,):
58
+ raise ValueError(f'Array must have shape (..., 4), got {array.shape}')
59
+
60
+ instance = cls.__new__(cls)
61
+ instance.wxyz = array
62
+ return instance
63
+
64
+ @classmethod
65
+ def from_scalar_vector(cls, scalar: Array, vector: Array) -> Quaternion:
66
+ """Create a quaternion from scalar and vector parts.
67
+
68
+ Args:
69
+ scalar: Array of shape (...,) for the scalar part.
70
+ vector: Array of shape (..., 3) for the vector part.
71
+
72
+ Returns:
73
+ Quaternion
74
+ """
75
+ if vector.shape[-1:] != (3,):
76
+ raise ValueError(f'Vector must have shape (..., 3), got {vector.shape}')
77
+ scalar = jnp.expand_dims(scalar, axis=-1)
78
+ return cls.from_array(jnp.concatenate([scalar, vector], axis=-1))
79
+
80
+ @classmethod
81
+ def from_rotation_matrix(cls, rot: Array) -> Quaternion:
82
+ """Create the quaternion associated to a rotation matrix.
83
+
84
+ Args:
85
+ rot: Array of shape (..., 3, 3) representing the rotation matrix
86
+
87
+ Returns:
88
+ The normalized Quaternion tensor representing the rotation matrix.
89
+ """
90
+ if rot.shape[-2:] != (3, 3):
91
+ raise ValueError(f'Rotation matrix must have shape (..., 3, 3), got {rot.shape}')
92
+
93
+ # Implémentation de la conversion matrice -> quaternion
94
+ trace = jnp.trace(rot, axis1=-2, axis2=-1)
95
+
96
+ # Cas où trace > 0
97
+ s = jnp.sqrt(trace + 1.0) * 2 # s = 4 * w
98
+ w = 0.25 * s
99
+ x = (rot[..., 2, 1] - rot[..., 1, 2]) / s
100
+ y = (rot[..., 0, 2] - rot[..., 2, 0]) / s
101
+ z = (rot[..., 1, 0] - rot[..., 0, 1]) / s
102
+
103
+ return cls.from_array(jnp.stack([w, x, y, z], axis=-1))
104
+
105
+ @classmethod
106
+ def zeros(cls, shape: tuple[int, ...] = (), dtype: jnp.dtype = jnp.float32) -> Quaternion:
107
+ """Create quaternions with all components set to 0.
108
+
109
+ Args:
110
+ shape: Shape of the tensor (without the last dimension).
111
+ dtype: Data type of the quaternion components.
112
+
113
+ Returns:
114
+ Quaternion with all components equal to 0.
115
+ """
116
+ data = jnp.zeros(shape + (4,), dtype=dtype)
117
+ return cls.from_array(data)
118
+
119
+ @classmethod
120
+ def ones(cls, shape: tuple[int, ...] = (), dtype: jnp.dtype = jnp.float32) -> Quaternion:
121
+ """Create quaternions with scalar component set to 1 and vector components set to 0.
122
+
123
+ Args:
124
+ shape: Shape of the tensor (without the last dimension).
125
+ dtype: Data type of the quaternion components.
126
+
127
+ Returns:
128
+ Quaternions with w=1 and x=y=z=0.
129
+ """
130
+ data = jnp.zeros(shape + (4,), dtype=dtype)
131
+ data = data.at[..., 0].set(1.0)
132
+ return cls.from_array(data)
133
+
134
+ @classmethod
135
+ def full(
136
+ cls, shape: tuple[int, ...], fill_value: float, dtype: jnp.dtype = jnp.float32
137
+ ) -> Quaternion:
138
+ """Create quaternions with scalar component set to a value and vector components set to 0.
139
+
140
+ Args:
141
+ shape: Shape of the tensor (without the last dimension).
142
+ fill_value: Value to fill the scalar component with.
143
+ dtype: Data type of the quaternion components.
144
+
145
+ Returns:
146
+ Quaternions with w=fill_value and x=y=z=0.
147
+ """
148
+ data = jnp.zeros(shape + (4,), dtype=dtype)
149
+ data = data.at[..., 0].set(fill_value)
150
+ return cls.from_array(data)
151
+
152
+ @classmethod
153
+ def random(cls, key: jax.random.PRNGKey, shape: tuple[int, ...] = ()) -> Quaternion:
154
+ """Generate normalized random quaternions.
155
+
156
+ Args:
157
+ key: Key PRNG.
158
+ shape: Shape of the tensor (without the last dimension).
159
+
160
+ Returns:
161
+ Normalized Quaternion.
162
+ """
163
+ data = jax.random.normal(key, shape + (4,))
164
+ return Quaternion.from_array(data).normalize()
165
+
166
+ @property
167
+ def w(self) -> Array:
168
+ return self.wxyz[..., 0]
169
+
170
+ @property
171
+ def x(self) -> Array:
172
+ return self.wxyz[..., 1]
173
+
174
+ @property
175
+ def y(self) -> Array:
176
+ return self.wxyz[..., 2]
177
+
178
+ @property
179
+ def z(self) -> Array:
180
+ return self.wxyz[..., 3]
181
+
182
+ @property
183
+ def vector(self) -> Array:
184
+ """Vector part (..., 3)"""
185
+ return self.wxyz[..., 1:]
186
+
187
+ def norm(self) -> Array:
188
+ """Quaternion norm."""
189
+ return jnp.sqrt(jnp.sum(self.wxyz**2, axis=-1))
190
+
191
+ def normalize(self) -> Quaternion:
192
+ """Normalize the quaternion.
193
+
194
+ Returns the normalized quaternion. If the quaternion has zero norm,
195
+ returns the zero quaternion [0, 0, 0, 0].
196
+ """
197
+ norm = self.norm()
198
+ # Avoid division by zero
199
+ safe_norm = jnp.where(norm == 0, 1.0, norm)
200
+ return Quaternion.from_array(self.wxyz / jnp.expand_dims(safe_norm, axis=-1))
201
+
202
+ def _inverse(self) -> Quaternion:
203
+ """Quaternion inverse (private method - use 1/q instead)."""
204
+ conj = self.conj()
205
+ norm_sq = self.norm() ** 2
206
+ return Quaternion.from_array(conj.wxyz / jnp.expand_dims(norm_sq, axis=-1))
207
+
208
+ def to_components(self) -> tuple[Array, Array, Array, Array]:
209
+ return self.w, self.x, self.y, self.z
210
+
211
+ def to_rotation_matrix(self) -> Array:
212
+ """Convert quaternion to rotation matrix.
213
+
214
+ Returns:
215
+ Array of shape (..., 3, 3)
216
+ """
217
+ # Normalize the quaternion
218
+ q = self.normalize()
219
+ w, x, y, z = q.to_components()
220
+
221
+ # Calculate matrix elements
222
+ xx, yy, zz = x * x, y * y, z * z
223
+ xy, xz, yz = x * y, x * z, y * z
224
+ wx, wy, wz = w * x, w * y, w * z
225
+
226
+ rot = jnp.stack(
227
+ [
228
+ jnp.stack([1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)], axis=-1),
229
+ jnp.stack([2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)], axis=-1),
230
+ jnp.stack([2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)], axis=-1),
231
+ ],
232
+ axis=-2,
233
+ )
234
+
235
+ return rot
236
+
237
+ def rotate_vector(self, v: Array) -> Array:
238
+ """Apply quaternion rotation to a vector.
239
+
240
+ Args:
241
+ v: Array of shape (..., 3) representing vectors
242
+
243
+ Returns:
244
+ Array of shape (..., 3) representing rotated vectors
245
+ """
246
+ # Convert vector to pure quaternion
247
+ v_quat = Quaternion(0, v[..., 0], v[..., 1], v[..., 2])
248
+ # Apply rotation: q * v * q^-1
249
+ result = self * v_quat * self._inverse()
250
+
251
+ return result.vector
252
+
253
+ def __repr__(self) -> str:
254
+ if self.shape == ():
255
+ w, x, y, z = self.wxyz
256
+ return f'{w} + {x}i + {y}j + {z}k'
257
+ return f'Quaternion(shape={self.shape}, dtype={self.dtype})'
258
+
259
+ #######################
260
+ # JAX array interface #
261
+ #######################
262
+
263
+ def __len__(self):
264
+ """Length of the first axis."""
265
+ if self.ndim == 0:
266
+ raise TypeError('len() of unsized object')
267
+ return self.shape[0]
268
+
269
+ def __iter__(self):
270
+ """Iterate over the first axis."""
271
+ if self.ndim == 0:
272
+ raise TypeError('iteration over a 0-d quaternion')
273
+ for i in range(self.shape[0]):
274
+ yield Quaternion.from_array(self.wxyz[i])
275
+
276
+ def __add__(self, other: Any) -> Quaternion:
277
+ """Quaternion addition."""
278
+ if isinstance(other, Quaternion):
279
+ return Quaternion.from_array(self.wxyz + other.wxyz)
280
+
281
+ if isinstance(other, int | float | jnp.ndarray):
282
+ return Quaternion.from_scalar_vector(self.w + other, self.vector)
283
+
284
+ raise NotImplementedError
285
+
286
+ def __radd__(self, other: Any) -> Quaternion:
287
+ """Quaternion addition."""
288
+ return self.__add__(other)
289
+
290
+ def __sub__(self, other: Any) -> Quaternion:
291
+ """Quaternion subtraction."""
292
+ if isinstance(other, Quaternion):
293
+ return Quaternion.from_array(self.wxyz - other.wxyz)
294
+
295
+ if isinstance(other, int | float | jnp.ndarray):
296
+ return Quaternion.from_scalar_vector(self.w - other, self.vector)
297
+
298
+ raise NotImplementedError
299
+
300
+ def __rsub__(self, other: Any) -> Quaternion:
301
+ """Quaternion subtraction."""
302
+ if isinstance(other, int | float | jnp.ndarray):
303
+ return Quaternion.from_scalar_vector(other - self.w, -self.vector)
304
+
305
+ raise NotImplementedError
306
+
307
+ def __mul__(self, other: Any) -> Quaternion:
308
+ """Quaternion multiplication."""
309
+ if isinstance(other, Quaternion):
310
+ w1, x1, y1, z1 = self.to_components()
311
+ w2, x2, y2, z2 = other.to_components()
312
+
313
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
314
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
315
+ y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
316
+ z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
317
+
318
+ return Quaternion(w, x, y, z)
319
+
320
+ if isinstance(other, int | float):
321
+ return Quaternion.from_array(self.wxyz * other)
322
+
323
+ if isinstance(other, jnp.ndarray):
324
+ return Quaternion.from_array(self.wxyz * jnp.expand_dims(other, axis=-1))
325
+
326
+ return NotImplemented
327
+
328
+ def __rmul__(self, other: Any) -> Quaternion:
329
+ """Quaternion multiplication."""
330
+ if isinstance(other, int | float):
331
+ return Quaternion.from_array(other * self.wxyz)
332
+
333
+ if isinstance(other, jnp.ndarray):
334
+ return Quaternion.from_array(jnp.expand_dims(other, axis=-1) * self.wxyz)
335
+
336
+ return NotImplemented
337
+
338
+ def __truediv__(self, other: Any) -> Quaternion:
339
+ """Quaternion division."""
340
+ if isinstance(other, Quaternion):
341
+ return self * other._inverse()
342
+
343
+ if isinstance(other, int | float):
344
+ return Quaternion.from_array(self.wxyz / other)
345
+
346
+ if isinstance(other, jnp.ndarray):
347
+ return Quaternion.from_array(self.wxyz / jnp.expand_dims(other, axis=-1))
348
+
349
+ return NotImplemented
350
+
351
+ def __rtruediv__(self, other: Any) -> Quaternion:
352
+ """Quaternion division."""
353
+ if isinstance(other, int | float) and other == 1:
354
+ return self._inverse()
355
+
356
+ if isinstance(other, int | float | jnp.ndarray):
357
+ return other * self._inverse()
358
+
359
+ return NotImplemented
360
+
361
+ def __neg__(self) -> Quaternion:
362
+ """Quaternion negation."""
363
+ return Quaternion.from_array(-self.wxyz)
364
+
365
+ def log(self) -> Quaternion:
366
+ """Compute quaternion logarithm.
367
+
368
+ For a quaternion q = |q| * (cos(θ) + sin(θ)v), the logarithm is:
369
+ log(q) = log(|q|) + θ * v
370
+
371
+ Returns:
372
+ The logarithm of the quaternion
373
+ """
374
+ # Get norm and handle zero quaternion
375
+ q_norm = self.norm()
376
+
377
+ # Normalize to get unit quaternion (handles zero quaternions safely)
378
+ unit_q = self.normalize()
379
+
380
+ # For unit quaternion q = cos(θ) + sin(θ)v, compute θ and v
381
+ # θ = arccos(w) and v = vector/|vector|
382
+ theta = jnp.arccos(jnp.clip(unit_q.w, -1.0, 1.0))
383
+ vector_norm = jnp.linalg.norm(unit_q.vector, axis=-1)
384
+
385
+ # Handle case where vector is zero (real quaternion)
386
+ safe_vector_norm = jnp.where(vector_norm == 0, 1.0, vector_norm)
387
+ unit_vector = unit_q.vector / jnp.expand_dims(safe_vector_norm, -1)
388
+
389
+ # log(q) = log(|q|) + θ * v
390
+ log_norm = jnp.log(jnp.maximum(q_norm, 1e-10)) # Avoid log(0)
391
+ theta_expanded = jnp.expand_dims(theta, -1)
392
+ log_q_vector = theta_expanded * unit_vector
393
+
394
+ # Handle zero vector case (real quaternion)
395
+ log_q_vector = jnp.where(
396
+ jnp.expand_dims(vector_norm == 0, -1), jnp.zeros_like(log_q_vector), log_q_vector
397
+ )
398
+
399
+ return Quaternion.from_scalar_vector(log_norm, log_q_vector)
400
+
401
+ def exp(self) -> Quaternion:
402
+ """Compute quaternion exponential.
403
+
404
+ For a quaternion q = s + v, the exponential is:
405
+ exp(q) = exp(s) * (cos(|v|) + sin(|v|) * v/|v|)
406
+
407
+ Returns:
408
+ The exponential of the quaternion
409
+ """
410
+ scalar_part = self.w
411
+ vector_part = self.vector
412
+ vector_norm = jnp.linalg.norm(vector_part, axis=-1)
413
+
414
+ # exp(s + v) = exp(s) * (cos(|v|) + sin(|v|) * v/|v|)
415
+ exp_scalar = jnp.exp(scalar_part)
416
+ cos_vnorm = jnp.cos(vector_norm)
417
+ sin_vnorm = jnp.sin(vector_norm)
418
+
419
+ # Handle case where |v| = 0 (real quaternion)
420
+ safe_vector_norm = jnp.where(vector_norm == 0, 1.0, vector_norm)
421
+ unit_v = vector_part / jnp.expand_dims(safe_vector_norm, -1)
422
+
423
+ result_w = exp_scalar * cos_vnorm
424
+ result_vector = exp_scalar * jnp.expand_dims(sin_vnorm, -1) * unit_v
425
+
426
+ # Handle zero vector case
427
+ result_vector = jnp.where(
428
+ jnp.expand_dims(vector_norm == 0, -1), jnp.zeros_like(result_vector), result_vector
429
+ )
430
+
431
+ return Quaternion.from_scalar_vector(result_w, result_vector)
432
+
433
+ def __pow__(self, exponent: float | int | Array) -> Quaternion:
434
+ """Quaternion exponentiation q^n.
435
+
436
+ For integer exponents, uses optimized special cases.
437
+ For non-integer exponents, uses the general formula: q^n = exp(n * log(q))
438
+
439
+ Args:
440
+ exponent: The exponent (scalar or array)
441
+
442
+ Returns:
443
+ The quaternion raised to the given power
444
+ """
445
+ # Handle special cases for integer exponents only
446
+ if isinstance(exponent, int):
447
+ if exponent == -2:
448
+ q_inv = self._inverse()
449
+ return q_inv * q_inv
450
+ elif exponent == -1:
451
+ return self._inverse()
452
+ elif exponent == 0:
453
+ return Quaternion.ones(self.shape, self.dtype)
454
+ elif exponent == 1:
455
+ return self
456
+ elif exponent == 2:
457
+ return self * self
458
+
459
+ # General case: q^n = exp(n * log(q))
460
+ log_q = self.log()
461
+ n_log_q = exponent * log_q
462
+ return n_log_q.exp()
463
+
464
+ @property
465
+ def nbytes(self) -> int:
466
+ """Number of bytes in the tensor."""
467
+ return self.wxyz.nbytes
468
+
469
+ @property
470
+ def itemsize(self) -> int:
471
+ """Size of one quaternion element in bytes."""
472
+ return self.wxyz.itemsize * 4
473
+
474
+ @property
475
+ def shape(self) -> tuple[int, ...]:
476
+ """Shape of the tensor."""
477
+ return self.wxyz.shape[:-1]
478
+
479
+ @property
480
+ def ndim(self):
481
+ """Number of dimensions of the quaternion tensor (without the quaternion dimension)."""
482
+ return self.wxyz.ndim - 1
483
+
484
+ @property
485
+ def size(self):
486
+ """Total number of quaternions."""
487
+ return self.wxyz.size >> 2
488
+
489
+ @property
490
+ def dtype(self) -> jnp.dtype:
491
+ """Data type."""
492
+ return self.wxyz.dtype
493
+
494
+ def reshape(self, *shape) -> Quaternion:
495
+ """Redimensionne le tableau de quaternions"""
496
+ if len(shape) == 0:
497
+ raise ValueError('Must specify at least one dimension')
498
+ if isinstance(shape[0], tuple):
499
+ if len(shape) > 1:
500
+ raise ValueError('Cannot specify more than one shape')
501
+ shape = shape[0]
502
+ new_shape = shape + (4,)
503
+ return self.from_array(self.wxyz.reshape(new_shape))
504
+
505
+ def flatten(self) -> Quaternion:
506
+ """Aplatis le tableau de quaternions"""
507
+ return self.from_array(self.wxyz.reshape(-1, 4))
508
+
509
+ def ravel(self) -> Quaternion:
510
+ """Aplatis le tableau de quaternions"""
511
+ return self.flatten()
512
+
513
+ def squeeze(self, axis=None) -> Quaternion:
514
+ """Supprime les dimensions de taille 1"""
515
+ return Quaternion.from_array(jnp.squeeze(self.wxyz, axis=axis))
516
+
517
+ def conjugate(self) -> Quaternion:
518
+ """Quaternion conjugate."""
519
+ sign = jnp.array([1, -1, -1, -1])
520
+ return Quaternion.from_array(self.wxyz * sign)
521
+
522
+ def conj(self) -> Quaternion:
523
+ """Quaternion conjugate."""
524
+ return self.conjugate()
525
+
526
+ def block_until_ready(self) -> None:
527
+ """Block until all pending computations are done."""
528
+ self.wxyz.block_until_ready()
529
+
530
+ @property
531
+ def device(self) -> jax.Device[Any]:
532
+ return self.wxyz.device
533
+
534
+ def devices(self) -> set[jax.Device[Any]]:
535
+ return self.wxyz.devices()
536
+
537
+ def slerp(self, other: Quaternion, t: float | Array) -> Quaternion:
538
+ """Spherical linear interpolation between two quaternions.
539
+
540
+ Args:
541
+ other: Target quaternion to interpolate towards
542
+ t: Interpolation parameter in [0, 1]. t=0 returns self, t=1 returns other
543
+
544
+ Returns:
545
+ Interpolated quaternion
546
+ """
547
+ # Ensure both quaternions are normalized
548
+ q1 = self.normalize()
549
+ q2 = other.normalize()
550
+
551
+ # Compute dot product
552
+ dot = jnp.sum(q1.wxyz * q2.wxyz, axis=-1)
553
+
554
+ # If dot product is negative, slerp won't take the shorter path.
555
+ # Note that this is necessary to handle the double cover of SO(3)
556
+ # by unit quaternions: q and -q represent the same rotation.
557
+ q2_corrected = jnp.where(jnp.expand_dims(dot < 0, -1), -q2.wxyz, q2.wxyz)
558
+ dot = jnp.abs(dot)
559
+
560
+ # If quaternions are very close, use linear interpolation to avoid numerical issues
561
+ threshold = 0.9995
562
+ use_linear = dot > threshold
563
+
564
+ # Linear interpolation case
565
+ result_linear = q1.wxyz + jnp.expand_dims(t * (1 - t), -1) * (q2_corrected - q1.wxyz)
566
+ result_linear = Quaternion.from_array(result_linear).normalize()
567
+
568
+ # Spherical interpolation case
569
+ theta = jnp.arccos(jnp.clip(dot, 0.0, 1.0))
570
+ sin_theta = jnp.sin(theta)
571
+
572
+ # Avoid division by zero
573
+ safe_sin_theta = jnp.where(sin_theta == 0, 1.0, sin_theta)
574
+
575
+ factor1 = jnp.sin((1 - t) * theta) / safe_sin_theta
576
+ factor2 = jnp.sin(t * theta) / safe_sin_theta
577
+
578
+ result_slerp = (
579
+ jnp.expand_dims(factor1, -1) * q1.wxyz + jnp.expand_dims(factor2, -1) * q2_corrected
580
+ )
581
+ result_slerp = Quaternion.from_array(result_slerp)
582
+
583
+ # Choose between linear and spherical interpolation
584
+ result = jnp.where(jnp.expand_dims(use_linear, -1), result_linear.wxyz, result_slerp.wxyz)
585
+
586
+ return Quaternion.from_array(result)
@@ -0,0 +1,127 @@
1
+ Metadata-Version: 2.4
2
+ Name: FastQuat
3
+ Version: 0.2
4
+ Summary: High-performance quaternions with JAX support
5
+ Author-email: Pierre Chanial <chanial@apc.in2p3.fr>
6
+ License-Expression: MIT
7
+ Requires-Python: >=3.10
8
+ Requires-Dist: jax[cuda12]>=0.4.0
9
+ Requires-Dist: jaxlib>=0.4.0
10
+ Description-Content-Type: text/markdown
11
+
12
+ # FastQuat - High-Performance Quaternions with JAX
13
+
14
+ [![PyPI version](https://img.shields.io/pypi/v/fastquat)](https://pypi.org/project/fastquat/)
15
+ [![Python versions](https://img.shields.io/pypi/pyversions/fastquat)](https://pypi.org/project/fastquat/)
16
+ [![Tests](https://github.com/CMBSciPol/fastquat/actions/workflows/tests.yml/badge.svg)](https://github.com/CMBSciPol/fastquat/actions)
17
+
18
+ FastQuat provides optimized quaternion operations with full JAX compatibility, featuring:
19
+
20
+ - 🚀 **Hardware-accelerated** computations (CPU/GPU/TPU)
21
+ - 🔄 **Automatic differentiation** support
22
+ - 🧩 **Seamless integration** with JAX transformations (`jit`, `grad`, `vmap`)
23
+ - 📦 **Efficient storage** using interleaved memory layout
24
+
25
+ ## Installation
26
+
27
+ ```bash
28
+ pip install fastquat
29
+ ```
30
+
31
+ ## Quick Start
32
+
33
+ ```python
34
+ import jax.numpy as jnp
35
+ from fastquat import Quaternion
36
+
37
+ # Create quaternions
38
+ q1 = Quaternion.ones() # Identity quaternion
39
+ q2 = Quaternion(0.7071, 0.7071, 0.0, 0.0) # 90° rotation around x-axis
40
+
41
+ # Quaternion operations
42
+ q3 = q1 * q2 # Multiplication
43
+ q_inv = 1 / q1 # Inverse
44
+ q_norm = q1.normalize() # Normalization
45
+
46
+ # Rotate vectors
47
+ vector = jnp.array([1.0, 0.0, 0.0])
48
+ rotated = q2.rotate_vector(vector)
49
+
50
+ # Spherical interpolation (SLERP)
51
+ interpolated = q1.slerp(q2, t=0.5) # Halfway between q1 and q2
52
+ ```
53
+
54
+ ## Features
55
+
56
+ ### Core Operations
57
+ - **Quaternion arithmetic**: Addition, multiplication, conjugation, inverse
58
+ - **Normalization**: Efficient unit quaternion computation
59
+ - **Conversion**: To/from rotation matrices
60
+ - **Vector rotation**: Direct vector transformation
61
+
62
+ ### Advanced Interpolation
63
+ - **SLERP (Spherical Linear Interpolation)**: Smooth rotation interpolation
64
+ - Automatically handles shortest path selection
65
+ - Numerically stable for close quaternions
66
+ - Supports batched operations and array-valued parameters
67
+
68
+ ### JAX Integration
69
+ - **JIT compilation**: Compile quaternion operations for maximum performance
70
+ - **Automatic differentiation**: Compute gradients through quaternion operations
71
+ - **Vectorization**: Process batches of quaternions efficiently
72
+ - **Device support**: Run on CPU, GPU, or TPU
73
+
74
+ ## Performance
75
+
76
+ FastQuat is optimized for high-performance computing:
77
+ - Memory-efficient interleaved storage
78
+ - SIMD-optimized operations on supported hardware
79
+ - Zero-copy integration with JAX arrays
80
+ - Minimal Python overhead through JIT compilation
81
+
82
+ ## Examples
83
+
84
+ ### Basic Usage
85
+ ```python
86
+ import jax
87
+ import jax.numpy as jnp
88
+ from fastquat import Quaternion
89
+
90
+ # Create random quaternions
91
+ key = jax.random.PRNGKey(42)
92
+ q_batch = Quaternion.random(key, shape=(1000,))
93
+
94
+ # JIT-compiled batch operations
95
+ @jax.jit
96
+ def batch_rotate(quaternions, vectors):
97
+ return quaternions.rotate_vector(vectors)
98
+
99
+ vectors = jax.random.normal(key, (1000, 3))
100
+ rotated_batch = batch_rotate(q_batch, vectors)
101
+ ```
102
+
103
+ ### SLERP Animation
104
+ ```python
105
+ # Smooth rotation interpolation
106
+ q_start = Quaternion.ones()
107
+ q_end = Quaternion.from_rotation_matrix(rotation_matrix)
108
+
109
+ # Generate smooth interpolation
110
+ t_values = jnp.linspace(0, 1, 100)
111
+ interpolated_rotations = q_start.slerp(q_end, t_values)
112
+
113
+ # Apply to object vertices for smooth animation
114
+ animated_vertices = interpolated_rotations.rotate_vector(object_vertices)
115
+ ```
116
+
117
+ ## Documentation
118
+
119
+ Full documentation is available at [fastquat.readthedocs.io](https://fastquat.readthedocs.io)
120
+
121
+ ## Contributing
122
+
123
+ Contributions are welcome! Please see our [development guide](https://fastquat.readthedocs.io/en/latest/development.html) for details.
124
+
125
+ ## License
126
+
127
+ MIT License - see [LICENSE](LICENSE) file for details.
@@ -0,0 +1,5 @@
1
+ fastquat/__init__.py,sha256=IrR5dmDWrjoxCiSv0q59P263OWZcZg1eegAExPpw9zo,61
2
+ fastquat/quaternion.py,sha256=jtQvQlIpBbA7zxBQBT3o7UQaHrp-4qPz--5zlefPBrg,19740
3
+ fastquat-0.2.dist-info/METADATA,sha256=oJm1wK5sDCwcpiK_918YMzFXenVLUWm0hasV8FR7e0E,3877
4
+ fastquat-0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ fastquat-0.2.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any