xax 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/nn/geom.py CHANGED
@@ -272,3 +272,85 @@ def quat_mul(q2: Array, q1: Array) -> Array:
272
272
  z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1
273
273
 
274
274
  return jnp.concatenate([w, x, y, z], axis=-1)
275
+
276
+
277
def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
    """Converts a rotation matrix to a unit quaternion ``(w, x, y, z)``.

    Uses the branch-wise conversion: the trace branch is used when the trace
    is positive, otherwise the branch keyed on the largest diagonal entry,
    which keeps the divisor away from zero. All four branches are evaluated
    for every element and the result is selected with ``jnp.where``, so the
    function works under ``jit``/``vmap`` and on arbitrary leading batch dims.

    The square roots are computed so that branches which are *not* selected
    cannot inject NaN/inf gradients (``d sqrt(t) / dt`` is infinite at
    ``t = 0``, which happens e.g. for the identity matrix).

    Note that ``q`` and ``-q`` encode the same rotation and the sign is not
    canonicalised: the trace branch yields ``w >= 0``, while the other
    branches may yield ``w < 0``.

    Args:
        rotation_matrix: The rotation matrix, shape ``(*, 3, 3)``.
        eps: A small epsilon value to avoid division by zero when normalising.

    Returns:
        A quaternion with shape ``(*, 4)``.
    """
    chex.assert_shape(rotation_matrix, (..., 3, 3))

    m00 = rotation_matrix[..., 0, 0]
    m01 = rotation_matrix[..., 0, 1]
    m02 = rotation_matrix[..., 0, 2]
    m10 = rotation_matrix[..., 1, 0]
    m11 = rotation_matrix[..., 1, 1]
    m12 = rotation_matrix[..., 1, 2]
    m20 = rotation_matrix[..., 2, 0]
    m21 = rotation_matrix[..., 2, 1]
    m22 = rotation_matrix[..., 2, 2]

    def twice_sqrt(t: Array) -> Array:
        # Computes 2 * sqrt(max(t, 0)). The inner `where` keeps sqrt away from
        # t <= 0 so its derivative stays finite; the outer `where` then zeroes
        # those entries. A plain sqrt(clip(t, 0)) has an infinite derivative at
        # t == 0, which turns into NaN (0 * inf) once `jnp.where` discards the
        # branch. NaN inputs still propagate (NaN <= 0 is False).
        nonpositive = t <= 0.0
        safe_t = jnp.where(nonpositive, 1.0, t)
        return jnp.where(nonpositive, 0.0, jnp.sqrt(safe_t) * 2.0)

    def safe_div(num: Array, den: Array) -> Array:
        # Tiny divisors are replaced by 1 so unselected branches stay finite;
        # for a proper rotation matrix the selected branch has a divisor >= 2.
        return num / jnp.where(den < eps, 1.0, den)

    trace = m00 + m11 + m22

    # Case 0: trace is positive (S = 4 * qw).
    s0 = twice_sqrt(trace + 1.0)
    q0 = jnp.stack(
        [
            0.25 * s0,
            safe_div(m21 - m12, s0),
            safe_div(m02 - m20, s0),
            safe_div(m10 - m01, s0),
        ],
        axis=-1,
    )

    # Case 1: m00 is the largest diagonal term (S = 4 * qx).
    s1 = twice_sqrt(1.0 + m00 - m11 - m22)
    q1 = jnp.stack(
        [
            safe_div(m21 - m12, s1),
            0.25 * s1,
            safe_div(m01 + m10, s1),
            safe_div(m02 + m20, s1),
        ],
        axis=-1,
    )

    # Case 2: m11 is the largest diagonal term (S = 4 * qy).
    s2 = twice_sqrt(1.0 + m11 - m00 - m22)
    q2 = jnp.stack(
        [
            safe_div(m02 - m20, s2),
            safe_div(m01 + m10, s2),
            0.25 * s2,
            safe_div(m12 + m21, s2),
        ],
        axis=-1,
    )

    # Case 3: m22 is the largest diagonal term (S = 4 * qz).
    s3 = twice_sqrt(1.0 + m22 - m00 - m11)
    q3 = jnp.stack(
        [
            safe_div(m10 - m01, s3),
            safe_div(m02 + m20, s3),
            safe_div(m12 + m21, s3),
            0.25 * s3,
        ],
        axis=-1,
    )

    # Trailing axis added so the per-matrix conditions broadcast over (w, x, y, z).
    cond0 = (trace > 0.0)[..., None]
    cond1 = ((m00 > m11) & (m00 > m22))[..., None]
    cond2 = (m11 > m22)[..., None]

    quat = jnp.where(cond0, q0, jnp.where(cond1, q1, jnp.where(cond2, q2, q3)))

    # Dividing by max(norm, eps) (rather than norm + eps) gives an exactly
    # unit-norm result while still guarding against division by zero.
    norm = jnp.linalg.norm(quat, axis=-1, keepdims=True)
    return quat / jnp.maximum(norm, eps)
@@ -10,6 +10,7 @@ import jax
10
10
 
11
11
  from xax.task.base import BaseConfig, BaseTask, RawConfigType
12
12
  from xax.task.launchers.base import BaseLauncher
13
+ from xax.task.launchers.cli import CliLauncher
13
14
 
14
15
 
15
16
  @jax.tree_util.register_dataclass
@@ -45,8 +46,6 @@ class RunnableMixin(BaseTask[Config], ABC):
45
46
  use_cli: bool | list[str] = True,
46
47
  ) -> None:
47
48
  if launcher is None:
48
- from xax.task.launchers.cli import CliLauncher
49
-
50
49
  launcher = CliLauncher()
51
50
  launcher.launch(cls, *cfgs, use_cli=use_cli)
52
51
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=7XPE3H-5F7BX-ArnD8SnuXu8Wx-ih2-2xcQsPUQDQMM,15713
1
+ xax/__init__.py,sha256=LJFB4xQplzC08tkbkZMxaCd-7jIB7aJZzBMcs9AuqiM,16240
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -8,10 +8,10 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
9
9
  xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
10
10
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- xax/nn/attention.py,sha256=0essK90OO3x9FxnUqU0DhufwXKRMN41zMtRCki5iKzQ,24742
11
+ xax/nn/attention.py,sha256=aIEtrM7vAQtaXTPKmsqGcYqt03CyiUQMccXj8Cjw3vc,29514
12
12
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
13
13
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
14
- xax/nn/geom.py,sha256=7_cUeoJIfCS_YbQ6O7zElnpYsAn4XrVm4m7PaT9UcYw,8005
14
+ xax/nn/geom.py,sha256=6rBQrZRX1miG08VG-s8phPjA6MEFxUAfQVPt5F0RQQI,10645
15
15
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
16
16
  xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
17
17
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
@@ -40,7 +40,7 @@ xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-go
40
40
  xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
41
41
  xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
42
42
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
43
- xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
+ xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
44
44
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
45
45
  xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
46
46
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
59
59
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
60
60
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
61
61
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
62
- xax-0.3.2.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
- xax-0.3.2.dist-info/METADATA,sha256=-TMz7GJvEpVb59UiAHjNZi2EbbgdafhDYyhySzEgRFQ,1246
64
- xax-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
- xax-0.3.2.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
- xax-0.3.2.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
- xax-0.3.2.dist-info/RECORD,,
62
+ xax-0.3.4.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
+ xax-0.3.4.dist-info/METADATA,sha256=j_UQdK4iPYbhzMH0osmHm5XJnYnFY1A_Z5MwSJwXr-4,1246
64
+ xax-0.3.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
+ xax-0.3.4.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
+ xax-0.3.4.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
+ xax-0.3.4.dist-info/RECORD,,
File without changes