xax 0.2.16__tar.gz → 0.2.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.16/xax.egg-info → xax-0.2.17}/PKG-INFO +1 -1
  2. {xax-0.2.16 → xax-0.2.17}/xax/__init__.py +1 -1
  3. {xax-0.2.16 → xax-0.2.17}/xax/nn/geom.py +1 -14
  4. {xax-0.2.16 → xax-0.2.17/xax.egg-info}/PKG-INFO +1 -1
  5. {xax-0.2.16 → xax-0.2.17}/LICENSE +0 -0
  6. {xax-0.2.16 → xax-0.2.17}/MANIFEST.in +0 -0
  7. {xax-0.2.16 → xax-0.2.17}/README.md +0 -0
  8. {xax-0.2.16 → xax-0.2.17}/pyproject.toml +0 -0
  9. {xax-0.2.16 → xax-0.2.17}/setup.cfg +0 -0
  10. {xax-0.2.16 → xax-0.2.17}/setup.py +0 -0
  11. {xax-0.2.16 → xax-0.2.17}/xax/core/__init__.py +0 -0
  12. {xax-0.2.16 → xax-0.2.17}/xax/core/conf.py +0 -0
  13. {xax-0.2.16 → xax-0.2.17}/xax/core/state.py +0 -0
  14. {xax-0.2.16 → xax-0.2.17}/xax/nn/__init__.py +0 -0
  15. {xax-0.2.16 → xax-0.2.17}/xax/nn/embeddings.py +0 -0
  16. {xax-0.2.16 → xax-0.2.17}/xax/nn/equinox.py +0 -0
  17. {xax-0.2.16 → xax-0.2.17}/xax/nn/export.py +0 -0
  18. {xax-0.2.16 → xax-0.2.17}/xax/nn/functions.py +0 -0
  19. {xax-0.2.16 → xax-0.2.17}/xax/nn/losses.py +0 -0
  20. {xax-0.2.16 → xax-0.2.17}/xax/nn/metrics.py +0 -0
  21. {xax-0.2.16 → xax-0.2.17}/xax/nn/parallel.py +0 -0
  22. {xax-0.2.16 → xax-0.2.17}/xax/nn/ssm.py +0 -0
  23. {xax-0.2.16 → xax-0.2.17}/xax/py.typed +0 -0
  24. {xax-0.2.16 → xax-0.2.17}/xax/requirements-dev.txt +0 -0
  25. {xax-0.2.16 → xax-0.2.17}/xax/requirements.txt +0 -0
  26. {xax-0.2.16 → xax-0.2.17}/xax/task/__init__.py +0 -0
  27. {xax-0.2.16 → xax-0.2.17}/xax/task/base.py +0 -0
  28. {xax-0.2.16 → xax-0.2.17}/xax/task/launchers/__init__.py +0 -0
  29. {xax-0.2.16 → xax-0.2.17}/xax/task/launchers/base.py +0 -0
  30. {xax-0.2.16 → xax-0.2.17}/xax/task/launchers/cli.py +0 -0
  31. {xax-0.2.16 → xax-0.2.17}/xax/task/launchers/single_process.py +0 -0
  32. {xax-0.2.16 → xax-0.2.17}/xax/task/logger.py +0 -0
  33. {xax-0.2.16 → xax-0.2.17}/xax/task/loggers/__init__.py +0 -0
  34. {xax-0.2.16 → xax-0.2.17}/xax/task/loggers/callback.py +0 -0
  35. {xax-0.2.16 → xax-0.2.17}/xax/task/loggers/json.py +0 -0
  36. {xax-0.2.16 → xax-0.2.17}/xax/task/loggers/state.py +0 -0
  37. {xax-0.2.16 → xax-0.2.17}/xax/task/loggers/stdout.py +0 -0
  38. {xax-0.2.16 → xax-0.2.17}/xax/task/loggers/tensorboard.py +0 -0
  39. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/__init__.py +0 -0
  40. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/artifacts.py +0 -0
  41. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/checkpointing.py +0 -0
  42. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/compile.py +0 -0
  43. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/cpu_stats.py +0 -0
  44. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/data_loader.py +0 -0
  45. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/gpu_stats.py +0 -0
  46. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/logger.py +0 -0
  47. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/process.py +0 -0
  48. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/runnable.py +0 -0
  49. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/step_wrapper.py +0 -0
  50. {xax-0.2.16 → xax-0.2.17}/xax/task/mixins/train.py +0 -0
  51. {xax-0.2.16 → xax-0.2.17}/xax/task/script.py +0 -0
  52. {xax-0.2.16 → xax-0.2.17}/xax/task/task.py +0 -0
  53. {xax-0.2.16 → xax-0.2.17}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.16 → xax-0.2.17}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.16 → xax-0.2.17}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.16 → xax-0.2.17}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.16 → xax-0.2.17}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.16 → xax-0.2.17}/xax/utils/jax.py +0 -0
  59. {xax-0.2.16 → xax-0.2.17}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.16 → xax-0.2.17}/xax/utils/logging.py +0 -0
  61. {xax-0.2.16 → xax-0.2.17}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.16 → xax-0.2.17}/xax/utils/profile.py +0 -0
  63. {xax-0.2.16 → xax-0.2.17}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.16 → xax-0.2.17}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.16 → xax-0.2.17}/xax/utils/text.py +0 -0
  66. {xax-0.2.16 → xax-0.2.17}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.16 → xax-0.2.17}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.16 → xax-0.2.17}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.16 → xax-0.2.17}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.16 → xax-0.2.17}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.16 → xax-0.2.17}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.16 → xax-0.2.17}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  1   1   Metadata-Version: 2.4
  2   2   Name: xax
  3     - Version: 0.2.16
      3 + Version: 0.2.17
  4   4   Summary: A library for fast Jax experimentation
  5   5   Home-page: https://github.com/kscalelabs/xax
  6   6   Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
 12  12   python -m scripts.update_api --inplace
 13  13   """
 14  14
 15     - __version__ = "0.2.16"
     15 + __version__ = "0.2.17"
 16  16
 17  17   # This list shouldn't be modified by hand; instead, run the update script.
 18  18   __all__ = [
@@ -86,20 +86,7 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
 86  86   Returns:
 87  87   A 3D vector representing the gravity in the local frame, shape (*, 3).
 88  88   """
 89     - # Normalize quaternion
 90     - quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
 91     - w, x, y, z = jnp.split(quat, 4, axis=-1)
 92     -
 93     - # Gravity vector in world frame is [0, 0, -1] (pointing down)
 94     - # Rotate gravity vector using quaternion rotation
 95     -
 96     - # Calculate quaternion rotation: q * [0,0,-1] * q^-1
 97     - gx = 2 * (x * z - w * y)
 98     - gy = 2 * (y * z + w * x)
 99     - gz = w * w - x * x - y * y + z * z
100     -
101     - # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
102     - return jnp.concatenate([gx, gy, -gz], axis=-1)
     89 + return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
103  90
104  91
105  92   def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
@@ -1,6 +1,6 @@
  1   1   Metadata-Version: 2.4
  2   2   Name: xax
  3     - Version: 0.2.16
      3 + Version: 0.2.17
  4   4   Summary: A library for fast Jax experimentation
  5   5   Home-page: https://github.com/kscalelabs/xax
  6   6   Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes