xax 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +6 -2
- xax/task/base.py +0 -3
- xax/task/loggers/stdout.py +5 -6
- xax/task/mixins/checkpointing.py +25 -8
- xax/task/mixins/compile.py +8 -0
- xax/utils/debugging.py +10 -0
- {xax-0.1.12.dist-info → xax-0.1.14.dist-info}/METADATA +1 -1
- {xax-0.1.12.dist-info → xax-0.1.14.dist-info}/RECORD +11 -11
- {xax-0.1.12.dist-info → xax-0.1.14.dist-info}/WHEEL +0 -0
- {xax-0.1.12.dist-info → xax-0.1.14.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.12.dist-info → xax-0.1.14.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.
|
15
|
+
__version__ = "0.1.14"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -75,7 +75,9 @@ __all__ = [
|
|
75
75
|
"Task",
|
76
76
|
"collate",
|
77
77
|
"collate_non_null",
|
78
|
+
"breakpoint_if_nan",
|
78
79
|
"get_named_leaves",
|
80
|
+
"log_if_nan",
|
79
81
|
"BaseFileDownloader",
|
80
82
|
"ContextTimer",
|
81
83
|
"CumulativeTimer",
|
@@ -234,7 +236,9 @@ NAME_MAP: dict[str, str] = {
|
|
234
236
|
"Task": "task.task",
|
235
237
|
"collate": "utils.data.collate",
|
236
238
|
"collate_non_null": "utils.data.collate",
|
239
|
+
"breakpoint_if_nan": "utils.debugging",
|
237
240
|
"get_named_leaves": "utils.debugging",
|
241
|
+
"log_if_nan": "utils.debugging",
|
238
242
|
"BaseFileDownloader": "utils.experiments",
|
239
243
|
"ContextTimer": "utils.experiments",
|
240
244
|
"CumulativeTimer": "utils.experiments",
|
@@ -386,7 +390,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
386
390
|
from xax.task.script import Script, ScriptConfig
|
387
391
|
from xax.task.task import Config, Task
|
388
392
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
389
|
-
from xax.utils.debugging import get_named_leaves
|
393
|
+
from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
|
390
394
|
from xax.utils.experiments import (
|
391
395
|
BaseFileDownloader,
|
392
396
|
ContextTimer,
|
xax/task/base.py
CHANGED
@@ -82,9 +82,6 @@ class BaseTask(Generic[Config]):
|
|
82
82
|
def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
|
83
83
|
return state
|
84
84
|
|
85
|
-
def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
|
86
|
-
pass
|
87
|
-
|
88
85
|
@functools.cached_property
|
89
86
|
def task_class_name(self) -> str:
|
90
87
|
return self.__class__.__name__
|
xax/task/loggers/stdout.py
CHANGED
@@ -14,7 +14,7 @@ from xax.utils.text import Color, colored, format_timedelta
|
|
14
14
|
|
15
15
|
def format_number(value: int | float, precision: int) -> str:
|
16
16
|
if isinstance(value, int):
|
17
|
-
return
|
17
|
+
return f"{value:,}" # Add commas to the number
|
18
18
|
return f"{value:.{precision}g}"
|
19
19
|
|
20
20
|
|
@@ -80,11 +80,10 @@ class StdoutLogger(LoggerImpl):
|
|
80
80
|
self.write_fp.write("\033[2J\033[H")
|
81
81
|
|
82
82
|
def write_state_window(self, line: LogLine) -> None:
|
83
|
-
|
84
|
-
|
85
|
-
"
|
86
|
-
"
|
87
|
-
"Elapsed Time": f"{elapsed_time}",
|
83
|
+
state_info: dict[str, str] = {
|
84
|
+
"Steps": format_number(line.state.num_steps, 0),
|
85
|
+
"Samples": format_number(line.state.num_samples, 0),
|
86
|
+
"Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True),
|
88
87
|
}
|
89
88
|
|
90
89
|
colored_prefix = colored("Phase: ", "grey", bold=True)
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -98,19 +98,39 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
98
98
|
) -> tuple[PyTree, State, DictConfig]: ...
|
99
99
|
|
100
100
|
@overload
|
101
|
-
def load_checkpoint(
|
101
|
+
def load_checkpoint(
|
102
|
+
self,
|
103
|
+
path: Path,
|
104
|
+
part: Literal["model"],
|
105
|
+
) -> PyTree: ...
|
102
106
|
|
103
107
|
@overload
|
104
|
-
def load_checkpoint(
|
108
|
+
def load_checkpoint(
|
109
|
+
self,
|
110
|
+
path: Path,
|
111
|
+
part: Literal["opt"],
|
112
|
+
) -> optax.GradientTransformation: ...
|
105
113
|
|
106
114
|
@overload
|
107
|
-
def load_checkpoint(
|
115
|
+
def load_checkpoint(
|
116
|
+
self,
|
117
|
+
path: Path,
|
118
|
+
part: Literal["opt_state"],
|
119
|
+
) -> optax.OptState: ...
|
108
120
|
|
109
121
|
@overload
|
110
|
-
def load_checkpoint(
|
122
|
+
def load_checkpoint(
|
123
|
+
self,
|
124
|
+
path: Path,
|
125
|
+
part: Literal["state"],
|
126
|
+
) -> State: ...
|
111
127
|
|
112
128
|
@overload
|
113
|
-
def load_checkpoint(
|
129
|
+
def load_checkpoint(
|
130
|
+
self,
|
131
|
+
path: Path,
|
132
|
+
part: Literal["config"],
|
133
|
+
) -> DictConfig: ...
|
114
134
|
|
115
135
|
def load_checkpoint(
|
116
136
|
self,
|
@@ -125,9 +145,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
125
145
|
| State
|
126
146
|
| DictConfig
|
127
147
|
):
|
128
|
-
# Calls the base callback.
|
129
|
-
self.on_before_checkpoint_load(path)
|
130
|
-
|
131
148
|
with tarfile.open(path, "r:gz") as tar:
|
132
149
|
|
133
150
|
def get_model() -> PyTree:
|
xax/task/mixins/compile.py
CHANGED
@@ -32,6 +32,10 @@ def get_cache_dir() -> str | None:
|
|
32
32
|
@dataclass
|
33
33
|
class CompileOptions:
|
34
34
|
# JAX compilation options
|
35
|
+
debug_nans: bool = field(
|
36
|
+
value=False,
|
37
|
+
help="If True, breaks on NaNs",
|
38
|
+
)
|
35
39
|
disable_jit: bool = field(
|
36
40
|
value=False,
|
37
41
|
help="If True, disables JIT compilation",
|
@@ -89,6 +93,10 @@ class CompileMixin(BaseTask[Config], Generic[Config]):
|
|
89
93
|
cc = self.config.compile
|
90
94
|
|
91
95
|
# Set basic compilation flags
|
96
|
+
if cc.debug_nans:
|
97
|
+
logger.info("Enabling NaNs debugging")
|
98
|
+
jax.config.update("jax_debug_nans", True)
|
99
|
+
|
92
100
|
if cc.disable_jit:
|
93
101
|
logger.info("Disabling JIT compilation")
|
94
102
|
jax.config.update("jax_disable_jit", True)
|
xax/utils/debugging.py
CHANGED
@@ -4,6 +4,8 @@ from collections import deque
|
|
4
4
|
from collections.abc import Iterable, Mapping
|
5
5
|
from typing import Any, Callable, Deque
|
6
6
|
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
7
9
|
from jaxtyping import Array
|
8
10
|
|
9
11
|
|
@@ -47,3 +49,11 @@ def get_named_leaves(
|
|
47
49
|
q.append((depth + 1, gname, cnode))
|
48
50
|
|
49
51
|
return ret
|
52
|
+
|
53
|
+
|
54
|
+
def breakpoint_if_nan(x: Array) -> None:
|
55
|
+
jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
|
56
|
+
|
57
|
+
|
58
|
+
def log_if_nan(x: Array, loc: str) -> None:
|
59
|
+
jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=D7czvfKKQJlemPuatMPVYbAO4ST3U272QRIyTOru7JI,13989
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -16,7 +16,7 @@ xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
|
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
xax/task/base.py,sha256=
|
19
|
+
xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
|
20
20
|
xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
|
21
21
|
xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
|
22
22
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
@@ -28,12 +28,12 @@ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
28
28
|
xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
|
29
29
|
xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
|
30
30
|
xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
|
31
|
-
xax/task/loggers/stdout.py,sha256=
|
31
|
+
xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,6788
|
32
32
|
xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
36
|
-
xax/task/mixins/compile.py,sha256=
|
35
|
+
xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
|
36
|
+
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
37
|
xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
|
38
38
|
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
39
39
|
xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
|
@@ -43,7 +43,7 @@ xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1
|
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
44
|
xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
|
-
xax/utils/debugging.py,sha256=
|
46
|
+
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
47
|
xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
|
48
48
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
49
49
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.1.
|
62
|
-
xax-0.1.
|
63
|
-
xax-0.1.
|
64
|
-
xax-0.1.
|
65
|
-
xax-0.1.
|
61
|
+
xax-0.1.14.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.1.14.dist-info/METADATA,sha256=WbKtAXJUYKHvBrOJPEm_eXF9O9ekc0WdPmsQQCSGG5Q,1878
|
63
|
+
xax-0.1.14.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
64
|
+
xax-0.1.14.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.1.14.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|