xax 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -2
- xax/task/base.py +0 -3
- xax/task/mixins/checkpointing.py +25 -8
- xax/task/mixins/compile.py +8 -0
- xax/utils/debugging.py +4 -0
- {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/METADATA +1 -1
- {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/RECORD +10 -10
- {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/WHEEL +0 -0
- {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.
|
15
|
+
__version__ = "0.1.14"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -77,6 +77,7 @@ __all__ = [
|
|
77
77
|
"collate_non_null",
|
78
78
|
"breakpoint_if_nan",
|
79
79
|
"get_named_leaves",
|
80
|
+
"log_if_nan",
|
80
81
|
"BaseFileDownloader",
|
81
82
|
"ContextTimer",
|
82
83
|
"CumulativeTimer",
|
@@ -237,6 +238,7 @@ NAME_MAP: dict[str, str] = {
|
|
237
238
|
"collate_non_null": "utils.data.collate",
|
238
239
|
"breakpoint_if_nan": "utils.debugging",
|
239
240
|
"get_named_leaves": "utils.debugging",
|
241
|
+
"log_if_nan": "utils.debugging",
|
240
242
|
"BaseFileDownloader": "utils.experiments",
|
241
243
|
"ContextTimer": "utils.experiments",
|
242
244
|
"CumulativeTimer": "utils.experiments",
|
@@ -388,7 +390,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
388
390
|
from xax.task.script import Script, ScriptConfig
|
389
391
|
from xax.task.task import Config, Task
|
390
392
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
391
|
-
from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
|
393
|
+
from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
|
392
394
|
from xax.utils.experiments import (
|
393
395
|
BaseFileDownloader,
|
394
396
|
ContextTimer,
|
xax/task/base.py
CHANGED
@@ -82,9 +82,6 @@ class BaseTask(Generic[Config]):
|
|
82
82
|
def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
|
83
83
|
return state
|
84
84
|
|
85
|
-
def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
|
86
|
-
pass
|
87
|
-
|
88
85
|
@functools.cached_property
|
89
86
|
def task_class_name(self) -> str:
|
90
87
|
return self.__class__.__name__
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -98,19 +98,39 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
98
98
|
) -> tuple[PyTree, State, DictConfig]: ...
|
99
99
|
|
100
100
|
@overload
|
101
|
-
def load_checkpoint(
|
101
|
+
def load_checkpoint(
|
102
|
+
self,
|
103
|
+
path: Path,
|
104
|
+
part: Literal["model"],
|
105
|
+
) -> PyTree: ...
|
102
106
|
|
103
107
|
@overload
|
104
|
-
def load_checkpoint(
|
108
|
+
def load_checkpoint(
|
109
|
+
self,
|
110
|
+
path: Path,
|
111
|
+
part: Literal["opt"],
|
112
|
+
) -> optax.GradientTransformation: ...
|
105
113
|
|
106
114
|
@overload
|
107
|
-
def load_checkpoint(
|
115
|
+
def load_checkpoint(
|
116
|
+
self,
|
117
|
+
path: Path,
|
118
|
+
part: Literal["opt_state"],
|
119
|
+
) -> optax.OptState: ...
|
108
120
|
|
109
121
|
@overload
|
110
|
-
def load_checkpoint(
|
122
|
+
def load_checkpoint(
|
123
|
+
self,
|
124
|
+
path: Path,
|
125
|
+
part: Literal["state"],
|
126
|
+
) -> State: ...
|
111
127
|
|
112
128
|
@overload
|
113
|
-
def load_checkpoint(
|
129
|
+
def load_checkpoint(
|
130
|
+
self,
|
131
|
+
path: Path,
|
132
|
+
part: Literal["config"],
|
133
|
+
) -> DictConfig: ...
|
114
134
|
|
115
135
|
def load_checkpoint(
|
116
136
|
self,
|
@@ -125,9 +145,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
125
145
|
| State
|
126
146
|
| DictConfig
|
127
147
|
):
|
128
|
-
# Calls the base callback.
|
129
|
-
self.on_before_checkpoint_load(path)
|
130
|
-
|
131
148
|
with tarfile.open(path, "r:gz") as tar:
|
132
149
|
|
133
150
|
def get_model() -> PyTree:
|
xax/task/mixins/compile.py
CHANGED
@@ -32,6 +32,10 @@ def get_cache_dir() -> str | None:
|
|
32
32
|
@dataclass
|
33
33
|
class CompileOptions:
|
34
34
|
# JAX compilation options
|
35
|
+
debug_nans: bool = field(
|
36
|
+
value=False,
|
37
|
+
help="If True, breaks on NaNs",
|
38
|
+
)
|
35
39
|
disable_jit: bool = field(
|
36
40
|
value=False,
|
37
41
|
help="If True, disables JIT compilation",
|
@@ -89,6 +93,10 @@ class CompileMixin(BaseTask[Config], Generic[Config]):
|
|
89
93
|
cc = self.config.compile
|
90
94
|
|
91
95
|
# Set basic compilation flags
|
96
|
+
if cc.debug_nans:
|
97
|
+
logger.info("Enabling NaNs debugging")
|
98
|
+
jax.config.update("jax_debug_nans", True)
|
99
|
+
|
92
100
|
if cc.disable_jit:
|
93
101
|
logger.info("Disabling JIT compilation")
|
94
102
|
jax.config.update("jax_disable_jit", True)
|
xax/utils/debugging.py
CHANGED
@@ -53,3 +53,7 @@ def get_named_leaves(
|
|
53
53
|
|
54
54
|
def breakpoint_if_nan(x: Array) -> None:
|
55
55
|
jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
|
56
|
+
|
57
|
+
|
58
|
+
def log_if_nan(x: Array, loc: str) -> None:
|
59
|
+
jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=D7czvfKKQJlemPuatMPVYbAO4ST3U272QRIyTOru7JI,13989
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -16,7 +16,7 @@ xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
|
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
xax/task/base.py,sha256=
|
19
|
+
xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
|
20
20
|
xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
|
21
21
|
xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
|
22
22
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
@@ -32,8 +32,8 @@ xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,67
|
|
32
32
|
xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
36
|
-
xax/task/mixins/compile.py,sha256=
|
35
|
+
xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
|
36
|
+
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
37
|
xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
|
38
38
|
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
39
39
|
xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
|
@@ -43,7 +43,7 @@ xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1
|
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
44
|
xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
|
-
xax/utils/debugging.py,sha256=
|
46
|
+
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
47
|
xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
|
48
48
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
49
49
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.1.
|
62
|
-
xax-0.1.
|
63
|
-
xax-0.1.
|
64
|
-
xax-0.1.
|
65
|
-
xax-0.1.
|
61
|
+
xax-0.1.14.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.1.14.dist-info/METADATA,sha256=WbKtAXJUYKHvBrOJPEm_eXF9O9ekc0WdPmsQQCSGG5Q,1878
|
63
|
+
xax-0.1.14.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
64
|
+
xax-0.1.14.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.1.14.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|