xax 0.3.12__tar.gz → 0.3.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {xax-0.3.12/xax.egg-info → xax-0.3.13}/PKG-INFO +1 -1
  2. {xax-0.3.12 → xax-0.3.13}/xax/__init__.py +17 -7
  3. {xax-0.3.12 → xax-0.3.13}/xax/nn/geom.py +42 -13
  4. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/train.py +1 -1
  5. {xax-0.3.12 → xax-0.3.13}/xax/utils/debugging.py +20 -4
  6. {xax-0.3.12 → xax-0.3.13}/xax/utils/pytree.py +3 -5
  7. {xax-0.3.12 → xax-0.3.13/xax.egg-info}/PKG-INFO +1 -1
  8. {xax-0.3.12 → xax-0.3.13}/LICENSE +0 -0
  9. {xax-0.3.12 → xax-0.3.13}/MANIFEST.in +0 -0
  10. {xax-0.3.12 → xax-0.3.13}/README.md +0 -0
  11. {xax-0.3.12 → xax-0.3.13}/pyproject.toml +0 -0
  12. {xax-0.3.12 → xax-0.3.13}/setup.cfg +0 -0
  13. {xax-0.3.12 → xax-0.3.13}/setup.py +0 -0
  14. {xax-0.3.12 → xax-0.3.13}/xax/cli/__init__.py +0 -0
  15. {xax-0.3.12 → xax-0.3.13}/xax/cli/edit_config.py +0 -0
  16. {xax-0.3.12 → xax-0.3.13}/xax/core/__init__.py +0 -0
  17. {xax-0.3.12 → xax-0.3.13}/xax/core/conf.py +0 -0
  18. {xax-0.3.12 → xax-0.3.13}/xax/core/state.py +0 -0
  19. {xax-0.3.12 → xax-0.3.13}/xax/nn/__init__.py +0 -0
  20. {xax-0.3.12 → xax-0.3.13}/xax/nn/attention.py +0 -0
  21. {xax-0.3.12 → xax-0.3.13}/xax/nn/distributions.py +0 -0
  22. {xax-0.3.12 → xax-0.3.13}/xax/nn/embeddings.py +0 -0
  23. {xax-0.3.12 → xax-0.3.13}/xax/nn/functions.py +0 -0
  24. {xax-0.3.12 → xax-0.3.13}/xax/nn/losses.py +0 -0
  25. {xax-0.3.12 → xax-0.3.13}/xax/nn/metrics.py +0 -0
  26. {xax-0.3.12 → xax-0.3.13}/xax/nn/parallel.py +0 -0
  27. {xax-0.3.12 → xax-0.3.13}/xax/nn/ssm.py +0 -0
  28. {xax-0.3.12 → xax-0.3.13}/xax/py.typed +0 -0
  29. {xax-0.3.12 → xax-0.3.13}/xax/requirements-dev.txt +0 -0
  30. {xax-0.3.12 → xax-0.3.13}/xax/requirements.txt +0 -0
  31. {xax-0.3.12 → xax-0.3.13}/xax/task/__init__.py +0 -0
  32. {xax-0.3.12 → xax-0.3.13}/xax/task/base.py +0 -0
  33. {xax-0.3.12 → xax-0.3.13}/xax/task/launchers/__init__.py +0 -0
  34. {xax-0.3.12 → xax-0.3.13}/xax/task/launchers/base.py +0 -0
  35. {xax-0.3.12 → xax-0.3.13}/xax/task/launchers/cli.py +0 -0
  36. {xax-0.3.12 → xax-0.3.13}/xax/task/launchers/single_process.py +0 -0
  37. {xax-0.3.12 → xax-0.3.13}/xax/task/logger.py +0 -0
  38. {xax-0.3.12 → xax-0.3.13}/xax/task/loggers/__init__.py +0 -0
  39. {xax-0.3.12 → xax-0.3.13}/xax/task/loggers/callback.py +0 -0
  40. {xax-0.3.12 → xax-0.3.13}/xax/task/loggers/json.py +0 -0
  41. {xax-0.3.12 → xax-0.3.13}/xax/task/loggers/state.py +0 -0
  42. {xax-0.3.12 → xax-0.3.13}/xax/task/loggers/stdout.py +0 -0
  43. {xax-0.3.12 → xax-0.3.13}/xax/task/loggers/tensorboard.py +0 -0
  44. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/__init__.py +0 -0
  45. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/artifacts.py +0 -0
  46. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/checkpointing.py +0 -0
  47. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/compile.py +0 -0
  48. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/cpu_stats.py +0 -0
  49. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/data_loader.py +0 -0
  50. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/gpu_stats.py +0 -0
  51. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/logger.py +0 -0
  52. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/process.py +0 -0
  53. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/runnable.py +0 -0
  54. {xax-0.3.12 → xax-0.3.13}/xax/task/mixins/step_wrapper.py +0 -0
  55. {xax-0.3.12 → xax-0.3.13}/xax/task/script.py +0 -0
  56. {xax-0.3.12 → xax-0.3.13}/xax/task/task.py +0 -0
  57. {xax-0.3.12 → xax-0.3.13}/xax/utils/__init__.py +0 -0
  58. {xax-0.3.12 → xax-0.3.13}/xax/utils/data/__init__.py +0 -0
  59. {xax-0.3.12 → xax-0.3.13}/xax/utils/data/collate.py +0 -0
  60. {xax-0.3.12 → xax-0.3.13}/xax/utils/experiments.py +0 -0
  61. {xax-0.3.12 → xax-0.3.13}/xax/utils/jax.py +0 -0
  62. {xax-0.3.12 → xax-0.3.13}/xax/utils/jaxpr.py +0 -0
  63. {xax-0.3.12 → xax-0.3.13}/xax/utils/logging.py +0 -0
  64. {xax-0.3.12 → xax-0.3.13}/xax/utils/numpy.py +0 -0
  65. {xax-0.3.12 → xax-0.3.13}/xax/utils/profile.py +0 -0
  66. {xax-0.3.12 → xax-0.3.13}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.3.12 → xax-0.3.13}/xax/utils/text.py +0 -0
  68. {xax-0.3.12 → xax-0.3.13}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.3.12 → xax-0.3.13}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.3.12 → xax-0.3.13}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.3.12 → xax-0.3.13}/xax.egg-info/SOURCES.txt +0 -0
  72. {xax-0.3.12 → xax-0.3.13}/xax.egg-info/dependency_links.txt +0 -0
  73. {xax-0.3.12 → xax-0.3.13}/xax.egg-info/entry_points.txt +0 -0
  74. {xax-0.3.12 → xax-0.3.13}/xax.egg-info/requires.txt +0 -0
  75. {xax-0.3.12 → xax-0.3.13}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.12
3
+ Version: 0.3.13
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.12"
15
+ __version__ = "0.3.13"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -53,6 +53,7 @@ __all__ = [
53
53
  "quat_mul",
54
54
  "quat_to_euler",
55
55
  "quat_to_rotmat",
56
+ "quat_to_yaw",
56
57
  "rotate_vector_by_quat",
57
58
  "rotation6d_to_rotation_matrix",
58
59
  "rotation_matrix_to_quat",
@@ -100,9 +101,9 @@ __all__ = [
100
101
  "Task",
101
102
  "collate",
102
103
  "collate_non_null",
103
- "breakpoint_if_nan",
104
+ "breakpoint_if_nonfinite",
104
105
  "get_named_leaves",
105
- "log_if_nan",
106
+ "log_if_nonfinite",
106
107
  "BaseFileDownloader",
107
108
  "ContextTimer",
108
109
  "CumulativeTimer",
@@ -198,7 +199,10 @@ if "XLA_FLAGS" in os.environ:
198
199
  # If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
199
200
  # Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
200
201
  if shutil.which("nvidia-smi") is not None:
201
- xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler=true", "--xla_gpu_enable_triton_gemm=false"]
202
+ xla_flags += [
203
+ "--xla_gpu_enable_latency_hiding_scheduler=true",
204
+ "--xla_gpu_enable_triton_gemm=false",
205
+ ]
202
206
  os.environ["XLA_FLAGS"] = " ".join(xla_flags)
203
207
 
204
208
  # If this flag is set, eagerly imports the entire package (not recommended).
@@ -246,6 +250,7 @@ NAME_MAP: dict[str, str] = {
246
250
  "quat_mul": "nn.geom",
247
251
  "quat_to_euler": "nn.geom",
248
252
  "quat_to_rotmat": "nn.geom",
253
+ "quat_to_yaw": "nn.geom",
249
254
  "rotate_vector_by_quat": "nn.geom",
250
255
  "rotation6d_to_rotation_matrix": "nn.geom",
251
256
  "rotation_matrix_to_quat": "nn.geom",
@@ -293,9 +298,9 @@ NAME_MAP: dict[str, str] = {
293
298
  "Task": "task.task",
294
299
  "collate": "utils.data.collate",
295
300
  "collate_non_null": "utils.data.collate",
296
- "breakpoint_if_nan": "utils.debugging",
301
+ "breakpoint_if_nonfinite": "utils.debugging",
297
302
  "get_named_leaves": "utils.debugging",
298
- "log_if_nan": "utils.debugging",
303
+ "log_if_nonfinite": "utils.debugging",
299
304
  "BaseFileDownloader": "utils.experiments",
300
305
  "ContextTimer": "utils.experiments",
301
306
  "CumulativeTimer": "utils.experiments",
@@ -443,6 +448,7 @@ if IMPORT_ALL or TYPE_CHECKING:
443
448
  quat_mul,
444
449
  quat_to_euler,
445
450
  quat_to_rotmat,
451
+ quat_to_yaw,
446
452
  rotate_vector_by_quat,
447
453
  rotation6d_to_rotation_matrix,
448
454
  rotation_matrix_to_quat,
@@ -486,7 +492,11 @@ if IMPORT_ALL or TYPE_CHECKING:
486
492
  from xax.task.script import Script, ScriptConfig
487
493
  from xax.task.task import Config, Task
488
494
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
489
- from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
495
+ from xax.utils.debugging import (
496
+ breakpoint_if_nonfinite,
497
+ get_named_leaves,
498
+ log_if_nonfinite,
499
+ )
490
500
  from xax.utils.experiments import (
491
501
  BaseFileDownloader,
492
502
  ContextTimer,
@@ -1,6 +1,7 @@
1
1
  """Defines geometry functions."""
2
2
 
3
3
  import chex
4
+ import jax
4
5
  from jax import numpy as jnp
5
6
  from jaxtyping import Array
6
7
 
@@ -15,30 +16,53 @@ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
15
16
  Returns:
16
17
  The roll, pitch, yaw angles with shape (*, 3).
17
18
  """
18
- quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)
19
- w, x, y, z = jnp.split(quat_4, 4, axis=-1)
19
+ # Normalize with clamping
20
+ norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
21
+ inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
22
+ quat_4 = quat_4 * inv_norm
23
+
24
+ w, x, y, z = jnp.unstack(quat_4, axis=-1)
20
25
 
21
26
  # Roll (x-axis rotation)
22
27
  sinr_cosp = 2.0 * (w * x + y * z)
23
28
  cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
24
- roll = jnp.arctan2(sinr_cosp, cosr_cosp)
29
+ roll = jax.lax.atan2(sinr_cosp, cosr_cosp)
25
30
 
26
31
  # Pitch (y-axis rotation)
27
32
  sinp = 2.0 * (w * y - z * x)
28
-
29
- # Handle edge cases where |sinp| >= 1
30
- pitch = jnp.where(
31
- jnp.abs(sinp) >= 1.0,
32
- jnp.sign(sinp) * jnp.pi / 2.0, # Use 90 degrees if out of range
33
- jnp.arcsin(sinp),
34
- )
33
+ sinp = jnp.clip(sinp, -1.0, 1.0) # Clamp to valid domain
34
+ pitch = jax.lax.asin(sinp)
35
35
 
36
36
  # Yaw (z-axis rotation)
37
37
  siny_cosp = 2.0 * (w * z + x * y)
38
38
  cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
39
- yaw = jnp.arctan2(siny_cosp, cosy_cosp)
39
+ yaw = jax.lax.atan2(siny_cosp, cosy_cosp)
40
+
41
+ return jnp.stack([roll, pitch, yaw], axis=-1)
42
+
43
+
44
+ def quat_to_yaw(quat_4: Array, eps: float = 1e-6) -> Array:
45
+ """Converts a quaternion to a yaw angle.
46
+
47
+ Args:
48
+ quat_4: The quaternion to convert, shape (*, 4).
49
+ eps: A small epsilon value to avoid division by zero.
50
+
51
+ Returns:
52
+ The yaw angle, shape (*).
53
+ """
54
+ # Normalize using a max + safe norm to handle extremely small values robustly
55
+ norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
56
+ inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
57
+ quat_4 = quat_4 * inv_norm
58
+
59
+ w, x, y, z = jnp.unstack(quat_4, axis=-1)
60
+
61
+ # Compute components with clamping to avoid rounding errors near limits
62
+ siny_cosp = 2.0 * (w * z + x * y)
63
+ cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
40
64
 
41
- return jnp.concatenate([roll, pitch, yaw], axis=-1)
65
+ return jax.lax.atan2(siny_cosp, cosy_cosp)
42
66
 
43
67
 
44
68
  def euler_to_quat(euler_3: Array) -> Array:
@@ -89,7 +113,12 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
89
113
  return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
90
114
 
91
115
 
92
- def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
116
+ def rotate_vector_by_quat(
117
+ vector: Array,
118
+ quat: Array,
119
+ inverse: bool = False,
120
+ eps: float = 1e-6,
121
+ ) -> Array:
93
122
  """Rotates a vector by a quaternion.
94
123
 
95
124
  Args:
@@ -678,7 +678,7 @@ class TrainMixin(
678
678
 
679
679
  def log_state(self) -> None:
680
680
  logger.log(LOG_STATUS, self.task_path)
681
- logger.log(LOG_STATUS, self.task_name)
681
+ logger.log(LOG_STATUS, self.exp_dir)
682
682
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
683
683
  self.logger.log_file("state.txt", get_state_file_string(self))
684
684
  self.logger.log_file("training_code.py", get_training_code(self))
@@ -51,9 +51,25 @@ def get_named_leaves(
51
51
  return ret
52
52
 
53
53
 
54
- def breakpoint_if_nan(x: Array) -> None:
55
- jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
54
+ def breakpoint_if_nonfinite(x: Array) -> None:
55
+ is_finite = jnp.isfinite(x).all()
56
56
 
57
+ def true_fn(x: Array) -> None:
58
+ pass
57
59
 
58
- def log_if_nan(x: Array, loc: str) -> None:
59
- jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
60
+ def false_fn(x: Array) -> None:
61
+ jax.debug.breakpoint()
62
+
63
+ jax.lax.cond(is_finite, true_fn, false_fn, x)
64
+
65
+
66
+ def log_if_nonfinite(x: Array, loc: str) -> None:
67
+ is_finite = jnp.isfinite(x).all()
68
+
69
+ def true_fn(x: Array) -> None:
70
+ pass
71
+
72
+ def false_fn(x: Array) -> None:
73
+ jax.debug.print("=== NaNs: {loc} ===", loc=loc)
74
+
75
+ jax.lax.cond(is_finite, true_fn, false_fn, x)
@@ -274,6 +274,9 @@ def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
274
274
 
275
275
  # Handles dataclasses.
276
276
  if is_dataclass(tree_a) and is_dataclass(tree_b):
277
+ if type(tree_a) is not type(tree_b):
278
+ diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
279
+ return diffs
277
280
  for field in fields(tree_a):
278
281
  attr_a, attr_b = getattr(tree_a, field.name), getattr(tree_b, field.name)
279
282
  diffs.extend(diff_pytree(attr_a, attr_b, prefix + f"{field.name}."))
@@ -330,10 +333,5 @@ def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
330
333
  diffs.append(f"{prefix}: aval {aval_a} vs {aval_b}")
331
334
  return diffs
332
335
 
333
- # Handle mismatched types
334
- elif type(tree_a) is not type(tree_b):
335
- diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
336
- return diffs
337
-
338
336
  else:
339
337
  raise ValueError(f"Unknown type: {type(tree_a)}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.12
3
+ Version: 0.3.13
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes