xax 0.3.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {xax-0.3.0/xax.egg-info → xax-0.3.1}/PKG-INFO +1 -1
  2. {xax-0.3.0 → xax-0.3.1}/xax/__init__.py +4 -1
  3. {xax-0.3.0 → xax-0.3.1}/xax/nn/geom.py +21 -0
  4. {xax-0.3.0 → xax-0.3.1/xax.egg-info}/PKG-INFO +1 -1
  5. {xax-0.3.0 → xax-0.3.1}/LICENSE +0 -0
  6. {xax-0.3.0 → xax-0.3.1}/MANIFEST.in +0 -0
  7. {xax-0.3.0 → xax-0.3.1}/README.md +0 -0
  8. {xax-0.3.0 → xax-0.3.1}/pyproject.toml +0 -0
  9. {xax-0.3.0 → xax-0.3.1}/setup.cfg +0 -0
  10. {xax-0.3.0 → xax-0.3.1}/setup.py +0 -0
  11. {xax-0.3.0 → xax-0.3.1}/xax/cli/__init__.py +0 -0
  12. {xax-0.3.0 → xax-0.3.1}/xax/cli/edit_config.py +0 -0
  13. {xax-0.3.0 → xax-0.3.1}/xax/core/__init__.py +0 -0
  14. {xax-0.3.0 → xax-0.3.1}/xax/core/conf.py +0 -0
  15. {xax-0.3.0 → xax-0.3.1}/xax/core/state.py +0 -0
  16. {xax-0.3.0 → xax-0.3.1}/xax/nn/__init__.py +0 -0
  17. {xax-0.3.0 → xax-0.3.1}/xax/nn/attention.py +0 -0
  18. {xax-0.3.0 → xax-0.3.1}/xax/nn/embeddings.py +0 -0
  19. {xax-0.3.0 → xax-0.3.1}/xax/nn/functions.py +0 -0
  20. {xax-0.3.0 → xax-0.3.1}/xax/nn/losses.py +0 -0
  21. {xax-0.3.0 → xax-0.3.1}/xax/nn/metrics.py +0 -0
  22. {xax-0.3.0 → xax-0.3.1}/xax/nn/parallel.py +0 -0
  23. {xax-0.3.0 → xax-0.3.1}/xax/nn/ssm.py +0 -0
  24. {xax-0.3.0 → xax-0.3.1}/xax/py.typed +0 -0
  25. {xax-0.3.0 → xax-0.3.1}/xax/requirements-dev.txt +0 -0
  26. {xax-0.3.0 → xax-0.3.1}/xax/requirements.txt +0 -0
  27. {xax-0.3.0 → xax-0.3.1}/xax/task/__init__.py +0 -0
  28. {xax-0.3.0 → xax-0.3.1}/xax/task/base.py +0 -0
  29. {xax-0.3.0 → xax-0.3.1}/xax/task/launchers/__init__.py +0 -0
  30. {xax-0.3.0 → xax-0.3.1}/xax/task/launchers/base.py +0 -0
  31. {xax-0.3.0 → xax-0.3.1}/xax/task/launchers/cli.py +0 -0
  32. {xax-0.3.0 → xax-0.3.1}/xax/task/launchers/single_process.py +0 -0
  33. {xax-0.3.0 → xax-0.3.1}/xax/task/logger.py +0 -0
  34. {xax-0.3.0 → xax-0.3.1}/xax/task/loggers/__init__.py +0 -0
  35. {xax-0.3.0 → xax-0.3.1}/xax/task/loggers/callback.py +0 -0
  36. {xax-0.3.0 → xax-0.3.1}/xax/task/loggers/json.py +0 -0
  37. {xax-0.3.0 → xax-0.3.1}/xax/task/loggers/state.py +0 -0
  38. {xax-0.3.0 → xax-0.3.1}/xax/task/loggers/stdout.py +0 -0
  39. {xax-0.3.0 → xax-0.3.1}/xax/task/loggers/tensorboard.py +0 -0
  40. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/checkpointing.py +0 -0
  43. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/compile.py +0 -0
  44. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/process.py +0 -0
  49. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.3.0 → xax-0.3.1}/xax/task/mixins/train.py +0 -0
  52. {xax-0.3.0 → xax-0.3.1}/xax/task/script.py +0 -0
  53. {xax-0.3.0 → xax-0.3.1}/xax/task/task.py +0 -0
  54. {xax-0.3.0 → xax-0.3.1}/xax/utils/__init__.py +0 -0
  55. {xax-0.3.0 → xax-0.3.1}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.3.0 → xax-0.3.1}/xax/utils/data/collate.py +0 -0
  57. {xax-0.3.0 → xax-0.3.1}/xax/utils/debugging.py +0 -0
  58. {xax-0.3.0 → xax-0.3.1}/xax/utils/experiments.py +0 -0
  59. {xax-0.3.0 → xax-0.3.1}/xax/utils/jax.py +0 -0
  60. {xax-0.3.0 → xax-0.3.1}/xax/utils/jaxpr.py +0 -0
  61. {xax-0.3.0 → xax-0.3.1}/xax/utils/logging.py +0 -0
  62. {xax-0.3.0 → xax-0.3.1}/xax/utils/numpy.py +0 -0
  63. {xax-0.3.0 → xax-0.3.1}/xax/utils/profile.py +0 -0
  64. {xax-0.3.0 → xax-0.3.1}/xax/utils/pytree.py +0 -0
  65. {xax-0.3.0 → xax-0.3.1}/xax/utils/tensorboard.py +0 -0
  66. {xax-0.3.0 → xax-0.3.1}/xax/utils/text.py +0 -0
  67. {xax-0.3.0 → xax-0.3.1}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.3.0 → xax-0.3.1}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.3.0 → xax-0.3.1}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.3.0 → xax-0.3.1}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.3.0 → xax-0.3.1}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.3.0 → xax-0.3.1}/xax.egg-info/entry_points.txt +0 -0
  73. {xax-0.3.0 → xax-0.3.1}/xax.egg-info/requires.txt +0 -0
  74. {xax-0.3.0 → xax-0.3.1}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.0"
15
+ __version__ = "0.3.1"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -44,6 +44,7 @@ __all__ = [
44
44
  "normalize",
45
45
  "quat_to_euler",
46
46
  "quat_to_rotmat",
47
+ "quat_mul",
47
48
  "rotate_vector_by_quat",
48
49
  "rotation6d_to_rotation_matrix",
49
50
  "rotation_matrix_to_rotation6d",
@@ -225,6 +226,7 @@ NAME_MAP: dict[str, str] = {
225
226
  "normalize": "nn.geom",
226
227
  "quat_to_euler": "nn.geom",
227
228
  "quat_to_rotmat": "nn.geom",
229
+ "quat_mul": "nn.geom",
228
230
  "rotate_vector_by_quat": "nn.geom",
229
231
  "rotation6d_to_rotation_matrix": "nn.geom",
230
232
  "rotation_matrix_to_rotation6d": "nn.geom",
@@ -398,6 +400,7 @@ if IMPORT_ALL or TYPE_CHECKING:
398
400
  euler_to_quat,
399
401
  get_projected_gravity_vector_from_quat,
400
402
  normalize,
403
+ quat_mul,
401
404
  quat_to_euler,
402
405
  quat_to_rotmat,
403
406
  rotate_vector_by_quat,
@@ -251,3 +251,24 @@ def rotation_matrix_to_rotation6d(rotation_matrix: jnp.ndarray) -> jnp.ndarray:
251
251
  # Simply concatenate a1 and a2 from SO(3)
252
252
  r6d = jnp.concatenate([rotation_matrix[..., 0], rotation_matrix[..., 1]], axis=-1)
253
253
  return r6d.reshape(shape[:-2] + (6,))
254
+
255
+
256
def quat_mul(q2: jnp.ndarray, q1: jnp.ndarray) -> jnp.ndarray:
    """Compute the Hamilton product ``q2 * q1`` of two quaternions (supports batching).

    Both inputs use scalar-first ``(w, x, y, z)`` ordering. Leading (batch)
    dimensions are broadcast against each other, so a single quaternion can be
    multiplied with a batch of quaternions.

    Note the argument order: ``q2`` is the left-hand factor. For unit quaternions
    under the usual ``v' = q v q*`` rotation convention, the product applies
    ``q1`` first and ``q2`` second.

    Args:
        q2: Left-hand quaternion (w, x, y, z), shape (..., 4).
        q1: Right-hand quaternion (w, x, y, z), shape (..., 4).

    Returns:
        Product quaternion (w, x, y, z). Its shape is the broadcast of the inputs'
        leading dimensions followed by a trailing 4.

    Raises:
        ValueError: If either input is a scalar or its last dimension is not 4.
    """
    # Without this check, jnp.split(q, 4) happily accepts any last dimension
    # divisible by 4 (e.g. 8) and silently returns a malformed result.
    for arg_name, quat in (("q2", q2), ("q1", q1)):
        quat_shape = jnp.shape(quat)
        if len(quat_shape) == 0 or quat_shape[-1] != 4:
            raise ValueError(f"{arg_name} must have shape (..., 4), got {quat_shape}")

    # Each component keeps a trailing singleton axis so the final concatenate
    # restores the (..., 4) layout.
    w1, x1, y1, z1 = jnp.split(q1, 4, axis=-1)
    w2, x2, y2, z2 = jnp.split(q2, 4, axis=-1)

    w = w2 * w1 - x2 * x1 - y2 * y1 - z2 * z1
    x = w2 * x1 + x2 * w1 + y2 * z1 - z2 * y1
    y = w2 * y1 - x2 * z1 + y2 * w1 + z2 * x1
    z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1

    return jnp.concatenate([w, x, y, z], axis=-1)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes