xax 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.15"
15
+ __version__ = "0.2.17"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/nn/geom.py CHANGED
@@ -86,28 +86,16 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
86
86
  Returns:
87
87
  A 3D vector representing the gravity in the local frame, shape (*, 3).
88
88
  """
89
- # Normalize quaternion
90
- quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
91
- w, x, y, z = jnp.split(quat, 4, axis=-1)
92
-
93
- # Gravity vector in world frame is [0, 0, -1] (pointing down)
94
- # Rotate gravity vector using quaternion rotation
89
+ return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
95
90
 
96
- # Calculate quaternion rotation: q * [0,0,-1] * q^-1
97
- gx = 2 * (x * z - w * y)
98
- gy = 2 * (y * z + w * x)
99
- gz = w * w - x * x - y * y + z * z
100
91
 
101
- # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
102
- return jnp.concatenate([gx, gy, -gz], axis=-1)
103
-
104
-
105
- def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
92
+ def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
106
93
  """Rotates a vector by a quaternion.
107
94
 
108
95
  Args:
109
96
  vector: The vector to rotate, shape (*, 3).
110
97
  quat: The quaternion to rotate by, shape (*, 4).
98
+ inverse: If True, rotate the vector by the conjugate of the quaternion.
111
99
  eps: A small epsilon value to avoid division by zero.
112
100
 
113
101
  Returns:
@@ -117,6 +105,9 @@ def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Arra
117
105
  quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
118
106
  w, x, y, z = jnp.split(quat, 4, axis=-1)
119
107
 
108
+ if inverse:
109
+ x, y, z = -x, -y, -z
110
+
120
111
  # Extract vector components
121
112
  vx, vy, vz = jnp.split(vector, 3, axis=-1)
122
113
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.15
3
+ Version: 0.2.17
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=JVxuGfbwBPHXiF4kSG0Pb73mzu3EIaRipjvt0Y-Z9W4,15733
1
+ xax/__init__.py,sha256=dBxZ_r1ck3C9ZH9VRM38-ApVkyUe7CnI_0SF9k07KcI,15733
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -10,7 +10,7 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=JZuSApD4bL0UK5W1nrQtucWYvNWUha07J6LTLk_RX-Y,4910
11
11
  xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
12
12
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
13
- xax/nn/geom.py,sha256=B8QE-L-xJWhf9KygTByPUAWe7Clpek4GlTABpsJFMBs,7702
13
+ xax/nn/geom.py,sha256=A7WPefMvgwUNReZC7_HX1GmvHPASyghbaXaKsuhwDrE,7382
14
14
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
15
15
  xax/nn/metrics.py,sha256=OAkeScwhi-wTBIJ59KHUhYbZTq4V4V-LG-mKlxMJ7bY,3238
16
16
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.2.15.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.2.15.dist-info/METADATA,sha256=6LJoiKOyNmF1MJSwVSdbEJATzSv1P77Amn4ZJCbWaP0,1880
63
- xax-0.2.15.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
64
- xax-0.2.15.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.2.15.dist-info/RECORD,,
61
+ xax-0.2.17.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.2.17.dist-info/METADATA,sha256=GQhyzReeHSrZkYrpxeSXt19z2271zD49-S6fwN6cagU,1880
63
+ xax-0.2.17.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
64
+ xax-0.2.17.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.2.17.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (79.0.0)
2
+ Generator: setuptools (80.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5