xax 0.3.12__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.12"
15
+ __version__ = "0.3.14"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -53,6 +53,7 @@ __all__ = [
53
53
  "quat_mul",
54
54
  "quat_to_euler",
55
55
  "quat_to_rotmat",
56
+ "quat_to_yaw",
56
57
  "rotate_vector_by_quat",
57
58
  "rotation6d_to_rotation_matrix",
58
59
  "rotation_matrix_to_quat",
@@ -100,9 +101,9 @@ __all__ = [
100
101
  "Task",
101
102
  "collate",
102
103
  "collate_non_null",
103
- "breakpoint_if_nan",
104
+ "breakpoint_if_nonfinite",
104
105
  "get_named_leaves",
105
- "log_if_nan",
106
+ "log_if_nonfinite",
106
107
  "BaseFileDownloader",
107
108
  "ContextTimer",
108
109
  "CumulativeTimer",
@@ -198,7 +199,10 @@ if "XLA_FLAGS" in os.environ:
198
199
  # If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
199
200
  # Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
200
201
  if shutil.which("nvidia-smi") is not None:
201
- xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler=true", "--xla_gpu_enable_triton_gemm=false"]
202
+ xla_flags += [
203
+ "--xla_gpu_enable_latency_hiding_scheduler=true",
204
+ "--xla_gpu_enable_triton_gemm=false",
205
+ ]
202
206
  os.environ["XLA_FLAGS"] = " ".join(xla_flags)
203
207
 
204
208
  # If this flag is set, eagerly imports the entire package (not recommended).
@@ -246,6 +250,7 @@ NAME_MAP: dict[str, str] = {
246
250
  "quat_mul": "nn.geom",
247
251
  "quat_to_euler": "nn.geom",
248
252
  "quat_to_rotmat": "nn.geom",
253
+ "quat_to_yaw": "nn.geom",
249
254
  "rotate_vector_by_quat": "nn.geom",
250
255
  "rotation6d_to_rotation_matrix": "nn.geom",
251
256
  "rotation_matrix_to_quat": "nn.geom",
@@ -293,9 +298,9 @@ NAME_MAP: dict[str, str] = {
293
298
  "Task": "task.task",
294
299
  "collate": "utils.data.collate",
295
300
  "collate_non_null": "utils.data.collate",
296
- "breakpoint_if_nan": "utils.debugging",
301
+ "breakpoint_if_nonfinite": "utils.debugging",
297
302
  "get_named_leaves": "utils.debugging",
298
- "log_if_nan": "utils.debugging",
303
+ "log_if_nonfinite": "utils.debugging",
299
304
  "BaseFileDownloader": "utils.experiments",
300
305
  "ContextTimer": "utils.experiments",
301
306
  "CumulativeTimer": "utils.experiments",
@@ -443,6 +448,7 @@ if IMPORT_ALL or TYPE_CHECKING:
443
448
  quat_mul,
444
449
  quat_to_euler,
445
450
  quat_to_rotmat,
451
+ quat_to_yaw,
446
452
  rotate_vector_by_quat,
447
453
  rotation6d_to_rotation_matrix,
448
454
  rotation_matrix_to_quat,
@@ -486,7 +492,11 @@ if IMPORT_ALL or TYPE_CHECKING:
486
492
  from xax.task.script import Script, ScriptConfig
487
493
  from xax.task.task import Config, Task
488
494
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
489
- from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
495
+ from xax.utils.debugging import (
496
+ breakpoint_if_nonfinite,
497
+ get_named_leaves,
498
+ log_if_nonfinite,
499
+ )
490
500
  from xax.utils.experiments import (
491
501
  BaseFileDownloader,
492
502
  ContextTimer,
xax/nn/geom.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Defines geometry functions."""
2
2
 
3
3
  import chex
4
+ import jax
4
5
  from jax import numpy as jnp
5
6
  from jaxtyping import Array
6
7
 
@@ -15,30 +16,53 @@ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
15
16
  Returns:
16
17
  The roll, pitch, yaw angles with shape (*, 3).
17
18
  """
18
- quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)
19
- w, x, y, z = jnp.split(quat_4, 4, axis=-1)
19
+ # Normalize with clamping
20
+ norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
21
+ inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
22
+ quat_4 = quat_4 * inv_norm
23
+
24
+ w, x, y, z = jnp.unstack(quat_4, axis=-1)
20
25
 
21
26
  # Roll (x-axis rotation)
22
27
  sinr_cosp = 2.0 * (w * x + y * z)
23
28
  cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
24
- roll = jnp.arctan2(sinr_cosp, cosr_cosp)
29
+ roll = jax.lax.atan2(sinr_cosp, cosr_cosp)
25
30
 
26
31
  # Pitch (y-axis rotation)
27
32
  sinp = 2.0 * (w * y - z * x)
28
-
29
- # Handle edge cases where |sinp| >= 1
30
- pitch = jnp.where(
31
- jnp.abs(sinp) >= 1.0,
32
- jnp.sign(sinp) * jnp.pi / 2.0, # Use 90 degrees if out of range
33
- jnp.arcsin(sinp),
34
- )
33
+ sinp = jnp.clip(sinp, -1.0, 1.0) # Clamp to valid domain
34
+ pitch = jax.lax.asin(sinp)
35
35
 
36
36
  # Yaw (z-axis rotation)
37
37
  siny_cosp = 2.0 * (w * z + x * y)
38
38
  cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
39
- yaw = jnp.arctan2(siny_cosp, cosy_cosp)
39
+ yaw = jax.lax.atan2(siny_cosp, cosy_cosp)
40
+
41
+ return jnp.stack([roll, pitch, yaw], axis=-1)
42
+
43
+
44
+ def quat_to_yaw(quat_4: Array, eps: float = 1e-6) -> Array:
45
+ """Converts a quaternion to a yaw angle.
46
+
47
+ Args:
48
+ quat_4: The quaternion to convert, shape (*, 4).
49
+ eps: A small epsilon value to avoid division by zero.
50
+
51
+ Returns:
52
+ The yaw angle, shape (*).
53
+ """
54
+ # Normalize using a max + safe norm to handle extremely small values robustly
55
+ norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
56
+ inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
57
+ quat_4 = quat_4 * inv_norm
58
+
59
+ w, x, y, z = jnp.unstack(quat_4, axis=-1)
60
+
61
+ # Compute components with clamping to avoid rounding errors near limits
62
+ siny_cosp = 2.0 * (w * z + x * y)
63
+ cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
40
64
 
41
- return jnp.concatenate([roll, pitch, yaw], axis=-1)
65
+ return jax.lax.atan2(siny_cosp, cosy_cosp)
42
66
 
43
67
 
44
68
  def euler_to_quat(euler_3: Array) -> Array:
@@ -89,7 +113,12 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
89
113
  return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
90
114
 
91
115
 
92
- def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
116
+ def rotate_vector_by_quat(
117
+ vector: Array,
118
+ quat: Array,
119
+ inverse: bool = False,
120
+ eps: float = 1e-6,
121
+ ) -> Array:
93
122
  """Rotates a vector by a quaternion.
94
123
 
95
124
  Args:
@@ -82,7 +82,7 @@ class ArtifactsMixin(BaseTask[Config]):
82
82
  return self._exp_dir
83
83
 
84
84
  def get_exp_dir(run_id: int) -> Path:
85
- return self.run_dir / f"run_{run_id}"
85
+ return self.run_dir / f"run_{run_id:03d}"
86
86
 
87
87
  run_id = 0
88
88
  while (exp_dir := get_exp_dir(run_id)).is_dir():
xax/task/mixins/train.py CHANGED
@@ -678,7 +678,7 @@ class TrainMixin(
678
678
 
679
679
  def log_state(self) -> None:
680
680
  logger.log(LOG_STATUS, self.task_path)
681
- logger.log(LOG_STATUS, self.task_name)
681
+ logger.log(LOG_STATUS, self.exp_dir)
682
682
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
683
683
  self.logger.log_file("state.txt", get_state_file_string(self))
684
684
  self.logger.log_file("training_code.py", get_training_code(self))
xax/utils/debugging.py CHANGED
@@ -51,9 +51,25 @@ def get_named_leaves(
51
51
  return ret
52
52
 
53
53
 
54
- def breakpoint_if_nan(x: Array) -> None:
55
- jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
54
+ def breakpoint_if_nonfinite(x: Array) -> None:
55
+ is_finite = jnp.isfinite(x).all()
56
56
 
57
+ def true_fn(x: Array) -> None:
58
+ pass
57
59
 
58
- def log_if_nan(x: Array, loc: str) -> None:
59
- jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
60
+ def false_fn(x: Array) -> None:
61
+ jax.debug.breakpoint()
62
+
63
+ jax.lax.cond(is_finite, true_fn, false_fn, x)
64
+
65
+
66
+ def log_if_nonfinite(x: Array, loc: str) -> None:
67
+ is_finite = jnp.isfinite(x).all()
68
+
69
+ def true_fn(x: Array) -> None:
70
+ pass
71
+
72
+ def false_fn(x: Array) -> None:
73
+ jax.debug.print("=== NaNs: {loc} ===", loc=loc)
74
+
75
+ jax.lax.cond(is_finite, true_fn, false_fn, x)
xax/utils/pytree.py CHANGED
@@ -274,6 +274,9 @@ def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
274
274
 
275
275
  # Handles dataclasses.
276
276
  if is_dataclass(tree_a) and is_dataclass(tree_b):
277
+ if type(tree_a) is not type(tree_b):
278
+ diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
279
+ return diffs
277
280
  for field in fields(tree_a):
278
281
  attr_a, attr_b = getattr(tree_a, field.name), getattr(tree_b, field.name)
279
282
  diffs.extend(diff_pytree(attr_a, attr_b, prefix + f"{field.name}."))
@@ -330,10 +333,5 @@ def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
330
333
  diffs.append(f"{prefix}: aval {aval_a} vs {aval_b}")
331
334
  return diffs
332
335
 
333
- # Handle mismatched types
334
- elif type(tree_a) is not type(tree_b):
335
- diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
336
- return diffs
337
-
338
336
  else:
339
337
  raise ValueError(f"Unknown type: {type(tree_a)}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.12
3
+ Version: 0.3.14
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=HXD6tR7Bz1b5ImFyRyR1kAok-dx5g8eBDpO_lCIP8rk,16782
1
+ xax/__init__.py,sha256=l3acv85D5Sq8IEv1tajSuCVY_eTGt8iJGnu_JuONB48,16944
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -12,7 +12,7 @@ xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
12
12
  xax/nn/distributions.py,sha256=6YOjyiPOC7XLDaMYpFNBlLCu3eLgDAeqIg9FoKfYLL4,6497
13
13
  xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
14
14
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
15
- xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
15
+ xax/nn/geom.py,sha256=ataKbQFXTebK9fM10CFyxsHOPGXhn26P4jakoc9Wqek,11424
16
16
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
17
17
  xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
18
18
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
@@ -33,7 +33,7 @@ xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,140
33
33
  xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
34
34
  xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
35
35
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
36
- xax/task/mixins/artifacts.py,sha256=R-y3p7__zJHlHDqwDVAZysg2ZmebCJbqAx_xGT2Xpd0,3857
36
+ xax/task/mixins/artifacts.py,sha256=UN26TW22ARduO6Bjs0yRu4-V6-Md9MPbXLKDnS28m44,3861
37
37
  xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
38
38
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
39
39
  xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
@@ -43,16 +43,16 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
43
43
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
44
44
  xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
45
45
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
46
- xax/task/mixins/train.py,sha256=hwAR_G1kgvhXgrE5ZRNL4Jn-Teflx65_1bdk6aULXEg,32814
46
+ xax/task/mixins/train.py,sha256=qb0zpsyeCk_U8Sk8THxtXkUVwj5r0lOlMLNRTctvcWU,32812
47
47
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
- xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
48
+ xax/utils/debugging.py,sha256=85JYIdnzLnvXsuli-4YHei_3tE3DnX3rmDSARKW2u1M,2192
49
49
  xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
50
50
  xax/utils/jax.py,sha256=6cP95-rcjkRt1fefkZWJQhJhH0uUYWJB3w4NP1-aDp0,10136
51
51
  xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
52
52
  xax/utils/logging.py,sha256=Kkyma_LJXqrN2HTQ214gRP_9ih3_bKk115MWC60lQWM,6656
53
53
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
54
54
  xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
55
- xax/utils/pytree.py,sha256=e8T5DY0ZhPcbvS3EuOsac0Oprra46lN05WEIhVN-3V0,12670
55
+ xax/utils/pytree.py,sha256=qC7OfCydX3N5yDIgcWwiXFIdpQZg3uxgBP2H85eNmzQ,12649
56
56
  xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
57
57
  xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
58
58
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -60,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
60
60
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
61
61
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
62
62
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
63
- xax-0.3.12.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
64
- xax-0.3.12.dist-info/METADATA,sha256=RACxHJ_iF4r0BTTTgyTI1ExYF_-aXRWrsq3NlQC7l9A,1247
65
- xax-0.3.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
66
- xax-0.3.12.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
67
- xax-0.3.12.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
68
- xax-0.3.12.dist-info/RECORD,,
63
+ xax-0.3.14.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
64
+ xax-0.3.14.dist-info/METADATA,sha256=eb-f4GhyPyCizmbj87lEg7CE7ufKtAf7uEGLI9mBms4,1247
65
+ xax-0.3.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
66
+ xax-0.3.14.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
67
+ xax-0.3.14.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
68
+ xax-0.3.14.dist-info/RECORD,,
File without changes