xax 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +7 -2
- xax/nn/geom.py +26 -5
- xax/task/base.py +0 -3
- xax/task/mixins/checkpointing.py +25 -8
- xax/task/mixins/compile.py +8 -0
- xax/task/mixins/train.py +3 -5
- xax/utils/debugging.py +4 -0
- xax/utils/experiments.py +14 -0
- {xax-0.1.13.dist-info → xax-0.1.15.dist-info}/METADATA +1 -1
- {xax-0.1.13.dist-info → xax-0.1.15.dist-info}/RECORD +13 -13
- {xax-0.1.13.dist-info → xax-0.1.15.dist-info}/WHEEL +0 -0
- {xax-0.1.13.dist-info → xax-0.1.15.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.13.dist-info → xax-0.1.15.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.
|
15
|
+
__version__ = "0.1.15"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -40,6 +40,7 @@ __all__ = [
|
|
40
40
|
"load_eqx_mlp",
|
41
41
|
"make_eqx_mlp",
|
42
42
|
"save_eqx",
|
43
|
+
"cubic_bezier_interpolation",
|
43
44
|
"euler_to_quat",
|
44
45
|
"get_projected_gravity_vector_from_quat",
|
45
46
|
"quat_to_euler",
|
@@ -77,6 +78,7 @@ __all__ = [
|
|
77
78
|
"collate_non_null",
|
78
79
|
"breakpoint_if_nan",
|
79
80
|
"get_named_leaves",
|
81
|
+
"log_if_nan",
|
80
82
|
"BaseFileDownloader",
|
81
83
|
"ContextTimer",
|
82
84
|
"CumulativeTimer",
|
@@ -200,6 +202,7 @@ NAME_MAP: dict[str, str] = {
|
|
200
202
|
"load_eqx_mlp": "nn.equinox",
|
201
203
|
"make_eqx_mlp": "nn.equinox",
|
202
204
|
"save_eqx": "nn.equinox",
|
205
|
+
"cubic_bezier_interpolation": "nn.geom",
|
203
206
|
"euler_to_quat": "nn.geom",
|
204
207
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
205
208
|
"quat_to_euler": "nn.geom",
|
@@ -237,6 +240,7 @@ NAME_MAP: dict[str, str] = {
|
|
237
240
|
"collate_non_null": "utils.data.collate",
|
238
241
|
"breakpoint_if_nan": "utils.debugging",
|
239
242
|
"get_named_leaves": "utils.debugging",
|
243
|
+
"log_if_nan": "utils.debugging",
|
240
244
|
"BaseFileDownloader": "utils.experiments",
|
241
245
|
"ContextTimer": "utils.experiments",
|
242
246
|
"CumulativeTimer": "utils.experiments",
|
@@ -361,6 +365,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
361
365
|
save_eqx,
|
362
366
|
)
|
363
367
|
from xax.nn.geom import (
|
368
|
+
cubic_bezier_interpolation,
|
364
369
|
euler_to_quat,
|
365
370
|
get_projected_gravity_vector_from_quat,
|
366
371
|
quat_to_euler,
|
@@ -388,7 +393,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
388
393
|
from xax.task.script import Script, ScriptConfig
|
389
394
|
from xax.task.task import Config, Task
|
390
395
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
391
|
-
from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
|
396
|
+
from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
|
392
397
|
from xax.utils.experiments import (
|
393
398
|
BaseFileDownloader,
|
394
399
|
ContextTimer,
|
xax/nn/geom.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1
1
|
"""Defines geometry functions."""
|
2
2
|
|
3
|
-
import jax
|
4
3
|
from jax import numpy as jnp
|
4
|
+
from jaxtyping import Array
|
5
5
|
|
6
6
|
|
7
|
-
def quat_to_euler(quat_4:
|
7
|
+
def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
|
8
8
|
"""Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.
|
9
9
|
|
10
10
|
Args:
|
@@ -40,7 +40,7 @@ def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
|
|
40
40
|
return jnp.concatenate([roll, pitch, yaw], axis=-1)
|
41
41
|
|
42
42
|
|
43
|
-
def euler_to_quat(euler_3:
|
43
|
+
def euler_to_quat(euler_3: Array) -> Array:
|
44
44
|
"""Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).
|
45
45
|
|
46
46
|
Args:
|
@@ -75,7 +75,7 @@ def euler_to_quat(euler_3: jax.Array) -> jax.Array:
|
|
75
75
|
return quat
|
76
76
|
|
77
77
|
|
78
|
-
def get_projected_gravity_vector_from_quat(quat:
|
78
|
+
def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Array:
|
79
79
|
"""Calculates the gravity vector projected onto the local frame given a quaternion orientation.
|
80
80
|
|
81
81
|
Args:
|
@@ -101,7 +101,7 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
|
|
101
101
|
return jnp.concatenate([gx, gy, -gz], axis=-1)
|
102
102
|
|
103
103
|
|
104
|
-
def rotate_vector_by_quat(vector:
|
104
|
+
def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
|
105
105
|
"""Rotates a vector by a quaternion.
|
106
106
|
|
107
107
|
Args:
|
@@ -156,3 +156,24 @@ def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6)
|
|
156
156
|
)
|
157
157
|
|
158
158
|
return jnp.concatenate([xx, yy, zz], axis=-1)
|
159
|
+
|
160
|
+
|
161
|
+
def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
|
162
|
+
"""Cubic bezier interpolation.
|
163
|
+
|
164
|
+
This is a cubic bezier curve that starts at y_start and ends at y_end,
|
165
|
+
and is controlled by the parameter x. The curve is defined by the following formula:
|
166
|
+
|
167
|
+
y(x) = y_start + (y_end - y_start) * (x**3 + 3 * (x**2 * (1 - x)))
|
168
|
+
|
169
|
+
Args:
|
170
|
+
y_start: The start value, shape (*).
|
171
|
+
y_end: The end value, shape (*).
|
172
|
+
x: The interpolation parameter, shape (*).
|
173
|
+
|
174
|
+
Returns:
|
175
|
+
The interpolated value, shape (*).
|
176
|
+
"""
|
177
|
+
y_diff = y_end - y_start
|
178
|
+
bezier = x**3 + 3 * (x**2 * (1 - x))
|
179
|
+
return y_start + y_diff * bezier
|
xax/task/base.py
CHANGED
@@ -82,9 +82,6 @@ class BaseTask(Generic[Config]):
|
|
82
82
|
def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
|
83
83
|
return state
|
84
84
|
|
85
|
-
def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
|
86
|
-
pass
|
87
|
-
|
88
85
|
@functools.cached_property
|
89
86
|
def task_class_name(self) -> str:
|
90
87
|
return self.__class__.__name__
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -98,19 +98,39 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
98
98
|
) -> tuple[PyTree, State, DictConfig]: ...
|
99
99
|
|
100
100
|
@overload
|
101
|
-
def load_checkpoint(
|
101
|
+
def load_checkpoint(
|
102
|
+
self,
|
103
|
+
path: Path,
|
104
|
+
part: Literal["model"],
|
105
|
+
) -> PyTree: ...
|
102
106
|
|
103
107
|
@overload
|
104
|
-
def load_checkpoint(
|
108
|
+
def load_checkpoint(
|
109
|
+
self,
|
110
|
+
path: Path,
|
111
|
+
part: Literal["opt"],
|
112
|
+
) -> optax.GradientTransformation: ...
|
105
113
|
|
106
114
|
@overload
|
107
|
-
def load_checkpoint(
|
115
|
+
def load_checkpoint(
|
116
|
+
self,
|
117
|
+
path: Path,
|
118
|
+
part: Literal["opt_state"],
|
119
|
+
) -> optax.OptState: ...
|
108
120
|
|
109
121
|
@overload
|
110
|
-
def load_checkpoint(
|
122
|
+
def load_checkpoint(
|
123
|
+
self,
|
124
|
+
path: Path,
|
125
|
+
part: Literal["state"],
|
126
|
+
) -> State: ...
|
111
127
|
|
112
128
|
@overload
|
113
|
-
def load_checkpoint(
|
129
|
+
def load_checkpoint(
|
130
|
+
self,
|
131
|
+
path: Path,
|
132
|
+
part: Literal["config"],
|
133
|
+
) -> DictConfig: ...
|
114
134
|
|
115
135
|
def load_checkpoint(
|
116
136
|
self,
|
@@ -125,9 +145,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
125
145
|
| State
|
126
146
|
| DictConfig
|
127
147
|
):
|
128
|
-
# Calls the base callback.
|
129
|
-
self.on_before_checkpoint_load(path)
|
130
|
-
|
131
148
|
with tarfile.open(path, "r:gz") as tar:
|
132
149
|
|
133
150
|
def get_model() -> PyTree:
|
xax/task/mixins/compile.py
CHANGED
@@ -32,6 +32,10 @@ def get_cache_dir() -> str | None:
|
|
32
32
|
@dataclass
|
33
33
|
class CompileOptions:
|
34
34
|
# JAX compilation options
|
35
|
+
debug_nans: bool = field(
|
36
|
+
value=False,
|
37
|
+
help="If True, breaks on NaNs",
|
38
|
+
)
|
35
39
|
disable_jit: bool = field(
|
36
40
|
value=False,
|
37
41
|
help="If True, disables JIT compilation",
|
@@ -89,6 +93,10 @@ class CompileMixin(BaseTask[Config], Generic[Config]):
|
|
89
93
|
cc = self.config.compile
|
90
94
|
|
91
95
|
# Set basic compilation flags
|
96
|
+
if cc.debug_nans:
|
97
|
+
logger.info("Enabling NaNs debugging")
|
98
|
+
jax.config.update("jax_debug_nans", True)
|
99
|
+
|
92
100
|
if cc.disable_jit:
|
93
101
|
logger.info("Disabling JIT compilation")
|
94
102
|
jax.config.update("jax_disable_jit", True)
|
xax/task/mixins/train.py
CHANGED
@@ -50,8 +50,7 @@ from xax.utils.experiments import (
|
|
50
50
|
TrainingFinishedError,
|
51
51
|
diff_configs,
|
52
52
|
get_diff_string,
|
53
|
-
|
54
|
-
get_packages_with_versions,
|
53
|
+
get_state_file_string,
|
55
54
|
get_training_code,
|
56
55
|
)
|
57
56
|
from xax.utils.jax import jit as xax_jit
|
@@ -534,9 +533,8 @@ class TrainMixin(
|
|
534
533
|
logger.log(LOG_STATUS, self.task_path)
|
535
534
|
logger.log(LOG_STATUS, self.task_name)
|
536
535
|
logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
|
537
|
-
self.logger.log_file("
|
538
|
-
self.logger.log_file("
|
539
|
-
self.logger.log_file("training_code.txt", get_training_code(self))
|
536
|
+
self.logger.log_file("state.txt", get_state_file_string(self))
|
537
|
+
self.logger.log_file("training_code.py", get_training_code(self))
|
540
538
|
self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
|
541
539
|
|
542
540
|
def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
|
xax/utils/debugging.py
CHANGED
@@ -53,3 +53,7 @@ def get_named_leaves(
|
|
53
53
|
|
54
54
|
def breakpoint_if_nan(x: Array) -> None:
|
55
55
|
jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
|
56
|
+
|
57
|
+
|
58
|
+
def log_if_nan(x: Array, loc: str) -> None:
|
59
|
+
jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
|
xax/utils/experiments.py
CHANGED
@@ -479,6 +479,20 @@ def get_packages_with_versions() -> str:
|
|
479
479
|
return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
|
480
480
|
|
481
481
|
|
482
|
+
def get_command_line_string() -> str:
|
483
|
+
return " ".join(sys.argv)
|
484
|
+
|
485
|
+
|
486
|
+
def get_state_file_string(obj: object) -> str:
|
487
|
+
return "\n\n".join(
|
488
|
+
[
|
489
|
+
f"=== Command Line ===\n\n{get_command_line_string()}",
|
490
|
+
f"=== Git State ===\n\n{get_git_state(obj)}",
|
491
|
+
f"=== Packages ===\n\n{get_packages_with_versions()}",
|
492
|
+
]
|
493
|
+
)
|
494
|
+
|
495
|
+
|
482
496
|
def get_training_code(obj: object) -> str:
|
483
497
|
"""Gets the text from the file containing the provided object.
|
484
498
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=bV2mTcuiVaVNvwgbDgg7dKDkMeuyA0mqF0muU5KZHeg,14104
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -10,13 +10,13 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
|
10
10
|
xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
|
11
11
|
xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
|
12
12
|
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
13
|
-
xax/nn/geom.py,sha256=
|
13
|
+
xax/nn/geom.py,sha256=PN0Ndn575aVtsSfxi67RghHB7luRkqtpS7bPbT1LpLE,5201
|
14
14
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
15
15
|
xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
xax/task/base.py,sha256=
|
19
|
+
xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
|
20
20
|
xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
|
21
21
|
xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
|
22
22
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
@@ -32,8 +32,8 @@ xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,67
|
|
32
32
|
xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
36
|
-
xax/task/mixins/compile.py,sha256=
|
35
|
+
xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
|
36
|
+
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
37
|
xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
|
38
38
|
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
39
39
|
xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
|
@@ -41,10 +41,10 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
41
41
|
xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=1hmUx1HIL8HKfwOnupS3Knsw1CiK2YCbIQnUTYyDEms,26157
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
|
-
xax/utils/debugging.py,sha256=
|
47
|
-
xax/utils/experiments.py,sha256=
|
46
|
+
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
|
+
xax/utils/experiments.py,sha256=X6MESZ3z_Z0DLH6NQucuPzibuOc6rZmlf5UZt4in458,29591
|
48
48
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
49
49
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
50
50
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.1.
|
62
|
-
xax-0.1.
|
63
|
-
xax-0.1.
|
64
|
-
xax-0.1.
|
65
|
-
xax-0.1.
|
61
|
+
xax-0.1.15.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.1.15.dist-info/METADATA,sha256=i5thFSTL1Zx03UpnCj7f71rxSgs0P3L6ZDd6vYEtM7U,1878
|
63
|
+
xax-0.1.15.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
64
|
+
xax-0.1.15.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.1.15.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|