xax 0.3.2__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {xax-0.3.2/xax.egg-info → xax-0.3.3}/PKG-INFO +1 -1
  2. {xax-0.3.2 → xax-0.3.3}/xax/__init__.py +6 -3
  3. {xax-0.3.2 → xax-0.3.3}/xax/nn/geom.py +82 -0
  4. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/runnable.py +1 -2
  5. {xax-0.3.2 → xax-0.3.3/xax.egg-info}/PKG-INFO +1 -1
  6. {xax-0.3.2 → xax-0.3.3}/LICENSE +0 -0
  7. {xax-0.3.2 → xax-0.3.3}/MANIFEST.in +0 -0
  8. {xax-0.3.2 → xax-0.3.3}/README.md +0 -0
  9. {xax-0.3.2 → xax-0.3.3}/pyproject.toml +0 -0
  10. {xax-0.3.2 → xax-0.3.3}/setup.cfg +0 -0
  11. {xax-0.3.2 → xax-0.3.3}/setup.py +0 -0
  12. {xax-0.3.2 → xax-0.3.3}/xax/cli/__init__.py +0 -0
  13. {xax-0.3.2 → xax-0.3.3}/xax/cli/edit_config.py +0 -0
  14. {xax-0.3.2 → xax-0.3.3}/xax/core/__init__.py +0 -0
  15. {xax-0.3.2 → xax-0.3.3}/xax/core/conf.py +0 -0
  16. {xax-0.3.2 → xax-0.3.3}/xax/core/state.py +0 -0
  17. {xax-0.3.2 → xax-0.3.3}/xax/nn/__init__.py +0 -0
  18. {xax-0.3.2 → xax-0.3.3}/xax/nn/attention.py +0 -0
  19. {xax-0.3.2 → xax-0.3.3}/xax/nn/embeddings.py +0 -0
  20. {xax-0.3.2 → xax-0.3.3}/xax/nn/functions.py +0 -0
  21. {xax-0.3.2 → xax-0.3.3}/xax/nn/losses.py +0 -0
  22. {xax-0.3.2 → xax-0.3.3}/xax/nn/metrics.py +0 -0
  23. {xax-0.3.2 → xax-0.3.3}/xax/nn/parallel.py +0 -0
  24. {xax-0.3.2 → xax-0.3.3}/xax/nn/ssm.py +0 -0
  25. {xax-0.3.2 → xax-0.3.3}/xax/py.typed +0 -0
  26. {xax-0.3.2 → xax-0.3.3}/xax/requirements-dev.txt +0 -0
  27. {xax-0.3.2 → xax-0.3.3}/xax/requirements.txt +0 -0
  28. {xax-0.3.2 → xax-0.3.3}/xax/task/__init__.py +0 -0
  29. {xax-0.3.2 → xax-0.3.3}/xax/task/base.py +0 -0
  30. {xax-0.3.2 → xax-0.3.3}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.3.2 → xax-0.3.3}/xax/task/launchers/base.py +0 -0
  32. {xax-0.3.2 → xax-0.3.3}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.3.2 → xax-0.3.3}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.3.2 → xax-0.3.3}/xax/task/logger.py +0 -0
  35. {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/json.py +0 -0
  38. {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/state.py +0 -0
  39. {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/tensorboard.py +0 -0
  41. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/artifacts.py +0 -0
  43. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/checkpointing.py +0 -0
  44. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/compile.py +0 -0
  45. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/process.py +0 -0
  50. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/train.py +0 -0
  52. {xax-0.3.2 → xax-0.3.3}/xax/task/script.py +0 -0
  53. {xax-0.3.2 → xax-0.3.3}/xax/task/task.py +0 -0
  54. {xax-0.3.2 → xax-0.3.3}/xax/utils/__init__.py +0 -0
  55. {xax-0.3.2 → xax-0.3.3}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.3.2 → xax-0.3.3}/xax/utils/data/collate.py +0 -0
  57. {xax-0.3.2 → xax-0.3.3}/xax/utils/debugging.py +0 -0
  58. {xax-0.3.2 → xax-0.3.3}/xax/utils/experiments.py +0 -0
  59. {xax-0.3.2 → xax-0.3.3}/xax/utils/jax.py +0 -0
  60. {xax-0.3.2 → xax-0.3.3}/xax/utils/jaxpr.py +0 -0
  61. {xax-0.3.2 → xax-0.3.3}/xax/utils/logging.py +0 -0
  62. {xax-0.3.2 → xax-0.3.3}/xax/utils/numpy.py +0 -0
  63. {xax-0.3.2 → xax-0.3.3}/xax/utils/profile.py +0 -0
  64. {xax-0.3.2 → xax-0.3.3}/xax/utils/pytree.py +0 -0
  65. {xax-0.3.2 → xax-0.3.3}/xax/utils/tensorboard.py +0 -0
  66. {xax-0.3.2 → xax-0.3.3}/xax/utils/text.py +0 -0
  67. {xax-0.3.2 → xax-0.3.3}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.3.2 → xax-0.3.3}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.3.2 → xax-0.3.3}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.3.2 → xax-0.3.3}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.3.2 → xax-0.3.3}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.3.2 → xax-0.3.3}/xax.egg-info/entry_points.txt +0 -0
  73. {xax-0.3.2 → xax-0.3.3}/xax.egg-info/requires.txt +0 -0
  74. {xax-0.3.2 → xax-0.3.3}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.2"
15
+ __version__ = "0.3.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -42,11 +42,12 @@ __all__ = [
42
42
  "euler_to_quat",
43
43
  "get_projected_gravity_vector_from_quat",
44
44
  "normalize",
45
+ "quat_mul",
45
46
  "quat_to_euler",
46
47
  "quat_to_rotmat",
47
- "quat_mul",
48
48
  "rotate_vector_by_quat",
49
49
  "rotation6d_to_rotation_matrix",
50
+ "rotation_matrix_to_quat",
50
51
  "rotation_matrix_to_rotation6d",
51
52
  "cross_entropy",
52
53
  "cast_norm_type",
@@ -224,11 +225,12 @@ NAME_MAP: dict[str, str] = {
224
225
  "euler_to_quat": "nn.geom",
225
226
  "get_projected_gravity_vector_from_quat": "nn.geom",
226
227
  "normalize": "nn.geom",
228
+ "quat_mul": "nn.geom",
227
229
  "quat_to_euler": "nn.geom",
228
230
  "quat_to_rotmat": "nn.geom",
229
- "quat_mul": "nn.geom",
230
231
  "rotate_vector_by_quat": "nn.geom",
231
232
  "rotation6d_to_rotation_matrix": "nn.geom",
233
+ "rotation_matrix_to_quat": "nn.geom",
232
234
  "rotation_matrix_to_rotation6d": "nn.geom",
233
235
  "cross_entropy": "nn.losses",
234
236
  "cast_norm_type": "nn.metrics",
@@ -405,6 +407,7 @@ if IMPORT_ALL or TYPE_CHECKING:
405
407
  quat_to_rotmat,
406
408
  rotate_vector_by_quat,
407
409
  rotation6d_to_rotation_matrix,
410
+ rotation_matrix_to_quat,
408
411
  rotation_matrix_to_rotation6d,
409
412
  )
410
413
  from xax.nn.losses import cross_entropy
@@ -272,3 +272,85 @@ def quat_mul(q2: Array, q1: Array) -> Array:
272
272
  z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1
273
273
 
274
274
  return jnp.concatenate([w, x, y, z], axis=-1)
275
+
276
+
277
def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
    """Converts a rotation matrix to a unit quaternion ``(w, x, y, z)``.

    Uses Shepperd's method: four candidate quaternions are computed, each one
    obtained by first recovering a different component (``w``, ``x``, ``y`` or
    ``z``) from the diagonal and then dividing the off-diagonal terms by it.
    Per matrix, the candidate is chosen from the sign of the trace and the
    largest diagonal entry, which keeps the divisor of the chosen candidate
    away from zero. All candidates are evaluated and merged with ``jnp.where``,
    so there is no Python control flow and the function works under ``jit`` /
    ``vmap``.

    The square roots use a "double ``where``" so that candidates whose radicand
    is zero (e.g. the x/y/z candidates for the identity matrix, or the trace
    candidate for a 180 degree rotation) do not turn the gradient of the
    selected candidate into NaN.

    Args:
        rotation_matrix: The rotation matrix, shape ``(*, 3, 3)``.
        eps: A small epsilon value used to avoid division by zero, both in the
            per-candidate divisions and in the final normalisation.

    Returns:
        A quaternion with shape ``(*, 4)``. The sign is not canonicalised, so
        ``w`` may be negative.
    """
    chex.assert_shape(rotation_matrix, (..., 3, 3))

    def _scale(radicand: Array) -> Array:
        # Computes ``2 * sqrt(max(radicand, 0))``. The inner ``where`` keeps the
        # argument of ``sqrt`` strictly positive so its derivative stays finite;
        # with a plain ``sqrt(clip(radicand, 0))`` the zero cotangent flowing back
        # from an unselected candidate is multiplied by ``inf`` and becomes NaN.
        positive = radicand > 0.0
        safe = jnp.where(positive, radicand, 1.0)
        return jnp.where(positive, jnp.sqrt(safe), 0.0) * 2.0

    def _div(numerator: Array, denominator: Array) -> Array:
        # The selected candidate always has a radicand >= 1 (hence divisor >= 2),
        # so only discarded candidates hit this guard; dividing by 1 there just
        # keeps their values finite.
        return numerator / jnp.where(denominator < eps, 1.0, denominator)

    m00 = rotation_matrix[..., 0, 0]
    m01 = rotation_matrix[..., 0, 1]
    m02 = rotation_matrix[..., 0, 2]
    m10 = rotation_matrix[..., 1, 0]
    m11 = rotation_matrix[..., 1, 1]
    m12 = rotation_matrix[..., 1, 2]
    m20 = rotation_matrix[..., 2, 0]
    m21 = rotation_matrix[..., 2, 1]
    m22 = rotation_matrix[..., 2, 2]

    trace = m00 + m11 + m22

    # Candidate 0: trace is positive; S = 4 * qw.
    s0 = _scale(trace + 1.0)
    q0 = jnp.stack(
        [0.25 * s0, _div(m21 - m12, s0), _div(m02 - m20, s0), _div(m10 - m01, s0)],
        axis=-1,
    )

    # Candidate 1: m00 is the largest diagonal term; S = 4 * qx.
    s1 = _scale(1.0 + m00 - m11 - m22)
    q1 = jnp.stack(
        [_div(m21 - m12, s1), 0.25 * s1, _div(m01 + m10, s1), _div(m02 + m20, s1)],
        axis=-1,
    )

    # Candidate 2: m11 is the largest diagonal term; S = 4 * qy.
    s2 = _scale(1.0 + m11 - m00 - m22)
    q2 = jnp.stack(
        [_div(m02 - m20, s2), _div(m01 + m10, s2), 0.25 * s2, _div(m12 + m21, s2)],
        axis=-1,
    )

    # Candidate 3: m22 is the largest diagonal term; S = 4 * qz.
    s3 = _scale(1.0 + m22 - m00 - m11)
    q3 = jnp.stack(
        [_div(m10 - m01, s3), _div(m02 + m20, s3), _div(m12 + m21, s3), 0.25 * s3],
        axis=-1,
    )

    # Selection priority: positive trace first, then the largest diagonal entry.
    # The trailing axis lets the per-matrix masks broadcast over the 4 components.
    use_w = (trace > 0.0)[..., None]
    use_x = ((m00 > m11) & (m00 > m22))[..., None]
    use_y = (m11 > m22)[..., None]
    quat = jnp.where(use_w, q0, jnp.where(use_x, q1, jnp.where(use_y, q2, q3)))

    # Re-normalise to absorb rounding error and slightly non-orthonormal inputs.
    # ``maximum`` (rather than ``norm + eps``) yields an exactly unit-norm result
    # for non-degenerate inputs while still guarding against a zero norm.
    norm = jnp.linalg.norm(quat, axis=-1, keepdims=True)
    return quat / jnp.maximum(norm, eps)
@@ -10,6 +10,7 @@ import jax
10
10
 
11
11
  from xax.task.base import BaseConfig, BaseTask, RawConfigType
12
12
  from xax.task.launchers.base import BaseLauncher
13
+ from xax.task.launchers.cli import CliLauncher
13
14
 
14
15
 
15
16
  @jax.tree_util.register_dataclass
@@ -45,8 +46,6 @@ class RunnableMixin(BaseTask[Config], ABC):
45
46
  use_cli: bool | list[str] = True,
46
47
  ) -> None:
47
48
  if launcher is None:
48
- from xax.task.launchers.cli import CliLauncher
49
-
50
49
  launcher = CliLauncher()
51
50
  launcher.launch(cls, *cfgs, use_cli=use_cli)
52
51
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes