xax 0.3.12__tar.gz → 0.3.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.12/xax.egg-info → xax-0.3.14}/PKG-INFO +1 -1
- {xax-0.3.12 → xax-0.3.14}/xax/__init__.py +17 -7
- {xax-0.3.12 → xax-0.3.14}/xax/nn/geom.py +42 -13
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/artifacts.py +1 -1
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/train.py +1 -1
- {xax-0.3.12 → xax-0.3.14}/xax/utils/debugging.py +20 -4
- {xax-0.3.12 → xax-0.3.14}/xax/utils/pytree.py +3 -5
- {xax-0.3.12 → xax-0.3.14/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.3.12 → xax-0.3.14}/LICENSE +0 -0
- {xax-0.3.12 → xax-0.3.14}/MANIFEST.in +0 -0
- {xax-0.3.12 → xax-0.3.14}/README.md +0 -0
- {xax-0.3.12 → xax-0.3.14}/pyproject.toml +0 -0
- {xax-0.3.12 → xax-0.3.14}/setup.cfg +0 -0
- {xax-0.3.12 → xax-0.3.14}/setup.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/cli/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/cli/edit_config.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/core/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/core/conf.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/core/state.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/attention.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/distributions.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/embeddings.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/functions.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/losses.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/metrics.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/parallel.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/nn/ssm.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/py.typed +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/requirements-dev.txt +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/requirements.txt +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/base.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/launchers/base.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/logger.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/loggers/json.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/loggers/state.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/process.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/script.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/task/task.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/data/collate.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/experiments.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/jax.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/logging.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/numpy.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/profile.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/text.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.12 → xax-0.3.14}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.3.
|
15
|
+
__version__ = "0.3.14"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -53,6 +53,7 @@ __all__ = [
|
|
53
53
|
"quat_mul",
|
54
54
|
"quat_to_euler",
|
55
55
|
"quat_to_rotmat",
|
56
|
+
"quat_to_yaw",
|
56
57
|
"rotate_vector_by_quat",
|
57
58
|
"rotation6d_to_rotation_matrix",
|
58
59
|
"rotation_matrix_to_quat",
|
@@ -100,9 +101,9 @@ __all__ = [
|
|
100
101
|
"Task",
|
101
102
|
"collate",
|
102
103
|
"collate_non_null",
|
103
|
-
"
|
104
|
+
"breakpoint_if_nonfinite",
|
104
105
|
"get_named_leaves",
|
105
|
-
"
|
106
|
+
"log_if_nonfinite",
|
106
107
|
"BaseFileDownloader",
|
107
108
|
"ContextTimer",
|
108
109
|
"CumulativeTimer",
|
@@ -198,7 +199,10 @@ if "XLA_FLAGS" in os.environ:
|
|
198
199
|
# If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
|
199
200
|
# Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
|
200
201
|
if shutil.which("nvidia-smi") is not None:
|
201
|
-
xla_flags += [
|
202
|
+
xla_flags += [
|
203
|
+
"--xla_gpu_enable_latency_hiding_scheduler=true",
|
204
|
+
"--xla_gpu_enable_triton_gemm=false",
|
205
|
+
]
|
202
206
|
os.environ["XLA_FLAGS"] = " ".join(xla_flags)
|
203
207
|
|
204
208
|
# If this flag is set, eagerly imports the entire package (not recommended).
|
@@ -246,6 +250,7 @@ NAME_MAP: dict[str, str] = {
|
|
246
250
|
"quat_mul": "nn.geom",
|
247
251
|
"quat_to_euler": "nn.geom",
|
248
252
|
"quat_to_rotmat": "nn.geom",
|
253
|
+
"quat_to_yaw": "nn.geom",
|
249
254
|
"rotate_vector_by_quat": "nn.geom",
|
250
255
|
"rotation6d_to_rotation_matrix": "nn.geom",
|
251
256
|
"rotation_matrix_to_quat": "nn.geom",
|
@@ -293,9 +298,9 @@ NAME_MAP: dict[str, str] = {
|
|
293
298
|
"Task": "task.task",
|
294
299
|
"collate": "utils.data.collate",
|
295
300
|
"collate_non_null": "utils.data.collate",
|
296
|
-
"
|
301
|
+
"breakpoint_if_nonfinite": "utils.debugging",
|
297
302
|
"get_named_leaves": "utils.debugging",
|
298
|
-
"
|
303
|
+
"log_if_nonfinite": "utils.debugging",
|
299
304
|
"BaseFileDownloader": "utils.experiments",
|
300
305
|
"ContextTimer": "utils.experiments",
|
301
306
|
"CumulativeTimer": "utils.experiments",
|
@@ -443,6 +448,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
443
448
|
quat_mul,
|
444
449
|
quat_to_euler,
|
445
450
|
quat_to_rotmat,
|
451
|
+
quat_to_yaw,
|
446
452
|
rotate_vector_by_quat,
|
447
453
|
rotation6d_to_rotation_matrix,
|
448
454
|
rotation_matrix_to_quat,
|
@@ -486,7 +492,11 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
486
492
|
from xax.task.script import Script, ScriptConfig
|
487
493
|
from xax.task.task import Config, Task
|
488
494
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
489
|
-
from xax.utils.debugging import
|
495
|
+
from xax.utils.debugging import (
|
496
|
+
breakpoint_if_nonfinite,
|
497
|
+
get_named_leaves,
|
498
|
+
log_if_nonfinite,
|
499
|
+
)
|
490
500
|
from xax.utils.experiments import (
|
491
501
|
BaseFileDownloader,
|
492
502
|
ContextTimer,
|
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Defines geometry functions."""
|
2
2
|
|
3
3
|
import chex
|
4
|
+
import jax
|
4
5
|
from jax import numpy as jnp
|
5
6
|
from jaxtyping import Array
|
6
7
|
|
@@ -15,30 +16,53 @@ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
|
|
15
16
|
Returns:
|
16
17
|
The roll, pitch, yaw angles with shape (*, 3).
|
17
18
|
"""
|
18
|
-
|
19
|
-
|
19
|
+
# Normalize with clamping
|
20
|
+
norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
|
21
|
+
inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
|
22
|
+
quat_4 = quat_4 * inv_norm
|
23
|
+
|
24
|
+
w, x, y, z = jnp.unstack(quat_4, axis=-1)
|
20
25
|
|
21
26
|
# Roll (x-axis rotation)
|
22
27
|
sinr_cosp = 2.0 * (w * x + y * z)
|
23
28
|
cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
|
24
|
-
roll =
|
29
|
+
roll = jax.lax.atan2(sinr_cosp, cosr_cosp)
|
25
30
|
|
26
31
|
# Pitch (y-axis rotation)
|
27
32
|
sinp = 2.0 * (w * y - z * x)
|
28
|
-
|
29
|
-
|
30
|
-
pitch = jnp.where(
|
31
|
-
jnp.abs(sinp) >= 1.0,
|
32
|
-
jnp.sign(sinp) * jnp.pi / 2.0, # Use 90 degrees if out of range
|
33
|
-
jnp.arcsin(sinp),
|
34
|
-
)
|
33
|
+
sinp = jnp.clip(sinp, -1.0, 1.0) # Clamp to valid domain
|
34
|
+
pitch = jax.lax.asin(sinp)
|
35
35
|
|
36
36
|
# Yaw (z-axis rotation)
|
37
37
|
siny_cosp = 2.0 * (w * z + x * y)
|
38
38
|
cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
|
39
|
-
yaw =
|
39
|
+
yaw = jax.lax.atan2(siny_cosp, cosy_cosp)
|
40
|
+
|
41
|
+
return jnp.stack([roll, pitch, yaw], axis=-1)
|
42
|
+
|
43
|
+
|
44
|
+
def quat_to_yaw(quat_4: Array, eps: float = 1e-6) -> Array:
|
45
|
+
"""Converts a quaternion to a yaw angle.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
quat_4: The quaternion to convert, shape (*, 4).
|
49
|
+
eps: A small epsilon value to avoid division by zero.
|
50
|
+
|
51
|
+
Returns:
|
52
|
+
The yaw angle, shape (*).
|
53
|
+
"""
|
54
|
+
# Normalize using a max + safe norm to handle extremely small values robustly
|
55
|
+
norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
|
56
|
+
inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
|
57
|
+
quat_4 = quat_4 * inv_norm
|
58
|
+
|
59
|
+
w, x, y, z = jnp.unstack(quat_4, axis=-1)
|
60
|
+
|
61
|
+
# Compute components with clamping to avoid rounding errors near limits
|
62
|
+
siny_cosp = 2.0 * (w * z + x * y)
|
63
|
+
cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
|
40
64
|
|
41
|
-
return
|
65
|
+
return jax.lax.atan2(siny_cosp, cosy_cosp)
|
42
66
|
|
43
67
|
|
44
68
|
def euler_to_quat(euler_3: Array) -> Array:
|
@@ -89,7 +113,12 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
|
|
89
113
|
return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
|
90
114
|
|
91
115
|
|
92
|
-
def rotate_vector_by_quat(
|
116
|
+
def rotate_vector_by_quat(
|
117
|
+
vector: Array,
|
118
|
+
quat: Array,
|
119
|
+
inverse: bool = False,
|
120
|
+
eps: float = 1e-6,
|
121
|
+
) -> Array:
|
93
122
|
"""Rotates a vector by a quaternion.
|
94
123
|
|
95
124
|
Args:
|
@@ -82,7 +82,7 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
82
82
|
return self._exp_dir
|
83
83
|
|
84
84
|
def get_exp_dir(run_id: int) -> Path:
|
85
|
-
return self.run_dir / f"run_{run_id}"
|
85
|
+
return self.run_dir / f"run_{run_id:03d}"
|
86
86
|
|
87
87
|
run_id = 0
|
88
88
|
while (exp_dir := get_exp_dir(run_id)).is_dir():
|
@@ -678,7 +678,7 @@ class TrainMixin(
|
|
678
678
|
|
679
679
|
def log_state(self) -> None:
|
680
680
|
logger.log(LOG_STATUS, self.task_path)
|
681
|
-
logger.log(LOG_STATUS, self.
|
681
|
+
logger.log(LOG_STATUS, self.exp_dir)
|
682
682
|
logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
|
683
683
|
self.logger.log_file("state.txt", get_state_file_string(self))
|
684
684
|
self.logger.log_file("training_code.py", get_training_code(self))
|
@@ -51,9 +51,25 @@ def get_named_leaves(
|
|
51
51
|
return ret
|
52
52
|
|
53
53
|
|
54
|
-
def
|
55
|
-
|
54
|
+
def breakpoint_if_nonfinite(x: Array) -> None:
|
55
|
+
is_finite = jnp.isfinite(x).all()
|
56
56
|
|
57
|
+
def true_fn(x: Array) -> None:
|
58
|
+
pass
|
57
59
|
|
58
|
-
def
|
59
|
-
|
60
|
+
def false_fn(x: Array) -> None:
|
61
|
+
jax.debug.breakpoint()
|
62
|
+
|
63
|
+
jax.lax.cond(is_finite, true_fn, false_fn, x)
|
64
|
+
|
65
|
+
|
66
|
+
def log_if_nonfinite(x: Array, loc: str) -> None:
|
67
|
+
is_finite = jnp.isfinite(x).all()
|
68
|
+
|
69
|
+
def true_fn(x: Array) -> None:
|
70
|
+
pass
|
71
|
+
|
72
|
+
def false_fn(x: Array) -> None:
|
73
|
+
jax.debug.print("=== NaNs: {loc} ===", loc=loc)
|
74
|
+
|
75
|
+
jax.lax.cond(is_finite, true_fn, false_fn, x)
|
@@ -274,6 +274,9 @@ def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
|
|
274
274
|
|
275
275
|
# Handles dataclasses.
|
276
276
|
if is_dataclass(tree_a) and is_dataclass(tree_b):
|
277
|
+
if type(tree_a) is not type(tree_b):
|
278
|
+
diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
|
279
|
+
return diffs
|
277
280
|
for field in fields(tree_a):
|
278
281
|
attr_a, attr_b = getattr(tree_a, field.name), getattr(tree_b, field.name)
|
279
282
|
diffs.extend(diff_pytree(attr_a, attr_b, prefix + f"{field.name}."))
|
@@ -330,10 +333,5 @@ def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
|
|
330
333
|
diffs.append(f"{prefix}: aval {aval_a} vs {aval_b}")
|
331
334
|
return diffs
|
332
335
|
|
333
|
-
# Handle mismatched types
|
334
|
-
elif type(tree_a) is not type(tree_b):
|
335
|
-
diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
|
336
|
-
return diffs
|
337
|
-
|
338
336
|
else:
|
339
337
|
raise ValueError(f"Unknown type: {type(tree_a)}")
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|