xax 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.2"
15
+ __version__ = "0.3.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -42,11 +42,12 @@ __all__ = [
42
42
  "euler_to_quat",
43
43
  "get_projected_gravity_vector_from_quat",
44
44
  "normalize",
45
+ "quat_mul",
45
46
  "quat_to_euler",
46
47
  "quat_to_rotmat",
47
- "quat_mul",
48
48
  "rotate_vector_by_quat",
49
49
  "rotation6d_to_rotation_matrix",
50
+ "rotation_matrix_to_quat",
50
51
  "rotation_matrix_to_rotation6d",
51
52
  "cross_entropy",
52
53
  "cast_norm_type",
@@ -224,11 +225,12 @@ NAME_MAP: dict[str, str] = {
224
225
  "euler_to_quat": "nn.geom",
225
226
  "get_projected_gravity_vector_from_quat": "nn.geom",
226
227
  "normalize": "nn.geom",
228
+ "quat_mul": "nn.geom",
227
229
  "quat_to_euler": "nn.geom",
228
230
  "quat_to_rotmat": "nn.geom",
229
- "quat_mul": "nn.geom",
230
231
  "rotate_vector_by_quat": "nn.geom",
231
232
  "rotation6d_to_rotation_matrix": "nn.geom",
233
+ "rotation_matrix_to_quat": "nn.geom",
232
234
  "rotation_matrix_to_rotation6d": "nn.geom",
233
235
  "cross_entropy": "nn.losses",
234
236
  "cast_norm_type": "nn.metrics",
@@ -405,6 +407,7 @@ if IMPORT_ALL or TYPE_CHECKING:
405
407
  quat_to_rotmat,
406
408
  rotate_vector_by_quat,
407
409
  rotation6d_to_rotation_matrix,
410
+ rotation_matrix_to_quat,
408
411
  rotation_matrix_to_rotation6d,
409
412
  )
410
413
  from xax.nn.losses import cross_entropy
xax/nn/geom.py CHANGED
@@ -272,3 +272,85 @@ def quat_mul(q2: Array, q1: Array) -> Array:
272
272
  z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1
273
273
 
274
274
  return jnp.concatenate([w, x, y, z], axis=-1)
275
+
276
+
277
def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
    """Converts a rotation matrix to a unit quaternion ``(w, x, y, z)``.

    Four candidate quaternions are computed, one per pivot component (w when the
    trace is positive, otherwise x, y or z for the largest diagonal entry). The
    pivot is recovered from a square root and is non-negative; the other three
    components come from dividing off-diagonal sums/differences by it. The
    candidate to return is chosen with ``jnp.where``, so the function is
    jittable and operates on arbitrary leading batch dimensions.

    Square roots and divisions are guarded so that candidates which are computed
    but not selected never produce infinite derivatives; ``jnp.where`` would
    otherwise turn those into NaN gradients for the whole output.

    Args:
        rotation_matrix: The rotation matrix, shape ``(*, 3, 3)``.
        eps: Small value guarding divisions by near-zero denominators and the
            final normalisation.

    Returns:
        A unit quaternion with shape ``(*, 4)``, ordered ``(w, x, y, z)``.
    """
    chex.assert_shape(rotation_matrix, (..., 3, 3))

    m00 = rotation_matrix[..., 0, 0]
    m01 = rotation_matrix[..., 0, 1]
    m02 = rotation_matrix[..., 0, 2]
    m10 = rotation_matrix[..., 1, 0]
    m11 = rotation_matrix[..., 1, 1]
    m12 = rotation_matrix[..., 1, 2]
    m20 = rotation_matrix[..., 2, 0]
    m21 = rotation_matrix[..., 2, 1]
    m22 = rotation_matrix[..., 2, 2]

    def safe_sqrt(value: Array) -> Array:
        # sqrt(max(value, 0)) with a finite derivative everywhere. Feeding the root a
        # dummy 1.0 for non-positive inputs avoids sqrt's infinite slope at zero, which
        # would otherwise become NaN (0 * inf) when jnp.where discards this branch.
        positive = value > 0.0
        return jnp.where(positive, jnp.sqrt(jnp.where(positive, value, 1.0)), 0.0)

    def safe_div(numerator: Array, denominator: Array) -> Array:
        # Denominators below eps are replaced by 1. For a proper rotation matrix the
        # selected candidate always has a denominator of at least 2, so this only
        # affects candidates that are discarded below.
        return numerator / jnp.where(denominator < eps, 1.0, denominator)

    trace = m00 + m11 + m22

    # Case 0: trace is positive; pivot on w (s0 = 4 * qw).
    s0 = safe_sqrt(trace + 1.0) * 2.0
    q0 = jnp.stack(
        [0.25 * s0, safe_div(m21 - m12, s0), safe_div(m02 - m20, s0), safe_div(m10 - m01, s0)],
        axis=-1,
    )

    # Case 1: m00 is the largest diagonal term; pivot on x (s1 = 4 * qx).
    s1 = safe_sqrt(1.0 + m00 - m11 - m22) * 2.0
    q1 = jnp.stack(
        [safe_div(m21 - m12, s1), 0.25 * s1, safe_div(m01 + m10, s1), safe_div(m02 + m20, s1)],
        axis=-1,
    )

    # Case 2: m11 is the largest diagonal term; pivot on y (s2 = 4 * qy).
    s2 = safe_sqrt(1.0 + m11 - m00 - m22) * 2.0
    q2 = jnp.stack(
        [safe_div(m02 - m20, s2), safe_div(m01 + m10, s2), 0.25 * s2, safe_div(m12 + m21, s2)],
        axis=-1,
    )

    # Case 3: m22 is the largest diagonal term; pivot on z (s3 = 4 * qz).
    s3 = safe_sqrt(1.0 + m22 - m00 - m11) * 2.0
    q3 = jnp.stack(
        [safe_div(m10 - m01, s3), safe_div(m02 + m20, s3), safe_div(m12 + m21, s3), 0.25 * s3],
        axis=-1,
    )

    # Selection order matches the case numbering; the trailing axis lets the scalar
    # conditions broadcast against the (..., 4) candidates.
    cond0 = (trace > 0.0)[..., None]
    cond1 = ((m00 > m11) & (m00 > m22))[..., None]
    cond2 = (m11 > m22)[..., None]
    quat = jnp.where(cond0, q0, jnp.where(cond1, q1, jnp.where(cond2, q2, q3)))

    # Renormalise to absorb round-off and slightly non-orthogonal inputs. Clamping the
    # norm (instead of adding eps to it) does not shrink already-unit quaternions.
    norm = jnp.linalg.norm(quat, axis=-1, keepdims=True)
    return quat / jnp.maximum(norm, eps)
@@ -10,6 +10,7 @@ import jax
10
10
 
11
11
  from xax.task.base import BaseConfig, BaseTask, RawConfigType
12
12
  from xax.task.launchers.base import BaseLauncher
13
+ from xax.task.launchers.cli import CliLauncher
13
14
 
14
15
 
15
16
  @jax.tree_util.register_dataclass
@@ -45,8 +46,6 @@ class RunnableMixin(BaseTask[Config], ABC):
45
46
  use_cli: bool | list[str] = True,
46
47
  ) -> None:
47
48
  if launcher is None:
48
- from xax.task.launchers.cli import CliLauncher
49
-
50
49
  launcher = CliLauncher()
51
50
  launcher.launch(cls, *cfgs, use_cli=use_cli)
52
51
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=7XPE3H-5F7BX-ArnD8SnuXu8Wx-ih2-2xcQsPUQDQMM,15713
1
+ xax/__init__.py,sha256=ffVd9_qSVuEAIPn6eK_6N8qEfDALRJigArHZQGy6y1o,15819
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -11,7 +11,7 @@ xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  xax/nn/attention.py,sha256=0essK90OO3x9FxnUqU0DhufwXKRMN41zMtRCki5iKzQ,24742
12
12
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
13
13
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
14
- xax/nn/geom.py,sha256=7_cUeoJIfCS_YbQ6O7zElnpYsAn4XrVm4m7PaT9UcYw,8005
14
+ xax/nn/geom.py,sha256=6rBQrZRX1miG08VG-s8phPjA6MEFxUAfQVPt5F0RQQI,10645
15
15
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
16
16
  xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
17
17
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
@@ -40,7 +40,7 @@ xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-go
40
40
  xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
41
41
  xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
42
42
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
43
- xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
+ xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
44
44
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
45
45
  xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
46
46
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
59
59
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
60
60
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
61
61
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
62
- xax-0.3.2.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
- xax-0.3.2.dist-info/METADATA,sha256=-TMz7GJvEpVb59UiAHjNZi2EbbgdafhDYyhySzEgRFQ,1246
64
- xax-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
- xax-0.3.2.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
- xax-0.3.2.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
- xax-0.3.2.dist-info/RECORD,,
62
+ xax-0.3.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
+ xax-0.3.3.dist-info/METADATA,sha256=mjIzoFZDSR3V1-2LHbvup6wDVa4vLiqbqiNWLsKCXY8,1246
64
+ xax-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
+ xax-0.3.3.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
+ xax-0.3.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
+ xax-0.3.3.dist-info/RECORD,,
File without changes