xax 0.2.15__tar.gz → 0.2.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.15/xax.egg-info → xax-0.2.16}/PKG-INFO +1 -1
  2. {xax-0.2.15 → xax-0.2.16}/xax/__init__.py +1 -1
  3. {xax-0.2.15 → xax-0.2.16}/xax/nn/geom.py +5 -1
  4. {xax-0.2.15 → xax-0.2.16/xax.egg-info}/PKG-INFO +1 -1
  5. {xax-0.2.15 → xax-0.2.16}/LICENSE +0 -0
  6. {xax-0.2.15 → xax-0.2.16}/MANIFEST.in +0 -0
  7. {xax-0.2.15 → xax-0.2.16}/README.md +0 -0
  8. {xax-0.2.15 → xax-0.2.16}/pyproject.toml +0 -0
  9. {xax-0.2.15 → xax-0.2.16}/setup.cfg +0 -0
  10. {xax-0.2.15 → xax-0.2.16}/setup.py +0 -0
  11. {xax-0.2.15 → xax-0.2.16}/xax/core/__init__.py +0 -0
  12. {xax-0.2.15 → xax-0.2.16}/xax/core/conf.py +0 -0
  13. {xax-0.2.15 → xax-0.2.16}/xax/core/state.py +0 -0
  14. {xax-0.2.15 → xax-0.2.16}/xax/nn/__init__.py +0 -0
  15. {xax-0.2.15 → xax-0.2.16}/xax/nn/embeddings.py +0 -0
  16. {xax-0.2.15 → xax-0.2.16}/xax/nn/equinox.py +0 -0
  17. {xax-0.2.15 → xax-0.2.16}/xax/nn/export.py +0 -0
  18. {xax-0.2.15 → xax-0.2.16}/xax/nn/functions.py +0 -0
  19. {xax-0.2.15 → xax-0.2.16}/xax/nn/losses.py +0 -0
  20. {xax-0.2.15 → xax-0.2.16}/xax/nn/metrics.py +0 -0
  21. {xax-0.2.15 → xax-0.2.16}/xax/nn/parallel.py +0 -0
  22. {xax-0.2.15 → xax-0.2.16}/xax/nn/ssm.py +0 -0
  23. {xax-0.2.15 → xax-0.2.16}/xax/py.typed +0 -0
  24. {xax-0.2.15 → xax-0.2.16}/xax/requirements-dev.txt +0 -0
  25. {xax-0.2.15 → xax-0.2.16}/xax/requirements.txt +0 -0
  26. {xax-0.2.15 → xax-0.2.16}/xax/task/__init__.py +0 -0
  27. {xax-0.2.15 → xax-0.2.16}/xax/task/base.py +0 -0
  28. {xax-0.2.15 → xax-0.2.16}/xax/task/launchers/__init__.py +0 -0
  29. {xax-0.2.15 → xax-0.2.16}/xax/task/launchers/base.py +0 -0
  30. {xax-0.2.15 → xax-0.2.16}/xax/task/launchers/cli.py +0 -0
  31. {xax-0.2.15 → xax-0.2.16}/xax/task/launchers/single_process.py +0 -0
  32. {xax-0.2.15 → xax-0.2.16}/xax/task/logger.py +0 -0
  33. {xax-0.2.15 → xax-0.2.16}/xax/task/loggers/__init__.py +0 -0
  34. {xax-0.2.15 → xax-0.2.16}/xax/task/loggers/callback.py +0 -0
  35. {xax-0.2.15 → xax-0.2.16}/xax/task/loggers/json.py +0 -0
  36. {xax-0.2.15 → xax-0.2.16}/xax/task/loggers/state.py +0 -0
  37. {xax-0.2.15 → xax-0.2.16}/xax/task/loggers/stdout.py +0 -0
  38. {xax-0.2.15 → xax-0.2.16}/xax/task/loggers/tensorboard.py +0 -0
  39. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/__init__.py +0 -0
  40. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/artifacts.py +0 -0
  41. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/checkpointing.py +0 -0
  42. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/compile.py +0 -0
  43. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/cpu_stats.py +0 -0
  44. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/data_loader.py +0 -0
  45. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/gpu_stats.py +0 -0
  46. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/logger.py +0 -0
  47. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/process.py +0 -0
  48. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/runnable.py +0 -0
  49. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/step_wrapper.py +0 -0
  50. {xax-0.2.15 → xax-0.2.16}/xax/task/mixins/train.py +0 -0
  51. {xax-0.2.15 → xax-0.2.16}/xax/task/script.py +0 -0
  52. {xax-0.2.15 → xax-0.2.16}/xax/task/task.py +0 -0
  53. {xax-0.2.15 → xax-0.2.16}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.15 → xax-0.2.16}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.15 → xax-0.2.16}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.15 → xax-0.2.16}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.15 → xax-0.2.16}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.15 → xax-0.2.16}/xax/utils/jax.py +0 -0
  59. {xax-0.2.15 → xax-0.2.16}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.15 → xax-0.2.16}/xax/utils/logging.py +0 -0
  61. {xax-0.2.15 → xax-0.2.16}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.15 → xax-0.2.16}/xax/utils/profile.py +0 -0
  63. {xax-0.2.15 → xax-0.2.16}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.15 → xax-0.2.16}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.15 → xax-0.2.16}/xax/utils/text.py +0 -0
  66. {xax-0.2.15 → xax-0.2.16}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.15 → xax-0.2.16}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.15 → xax-0.2.16}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.15 → xax-0.2.16}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.15 → xax-0.2.16}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.15 → xax-0.2.16}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.15 → xax-0.2.16}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.15
3
+ Version: 0.2.16
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.15"
15
+ __version__ = "0.2.16"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -102,12 +102,13 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
102
102
  return jnp.concatenate([gx, gy, -gz], axis=-1)
103
103
 
104
104
 
105
- def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
105
+ def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
106
106
  """Rotates a vector by a quaternion.
107
107
 
108
108
  Args:
109
109
  vector: The vector to rotate, shape (*, 3).
110
110
  quat: The quaternion to rotate by, shape (*, 4).
111
+ inverse: If True, rotate the vector by the conjugate of the quaternion.
111
112
  eps: A small epsilon value to avoid division by zero.
112
113
 
113
114
  Returns:
@@ -117,6 +118,9 @@ def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Arra
117
118
  quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
118
119
  w, x, y, z = jnp.split(quat, 4, axis=-1)
119
120
 
121
+ if inverse:
122
+ x, y, z = -x, -y, -z
123
+
120
124
  # Extract vector components
121
125
  vx, vy, vz = jnp.split(vector, 3, axis=-1)
122
126
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.15
3
+ Version: 0.2.16
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes