xax 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -2
- xax/task/loggers/stdout.py +5 -6
- xax/utils/debugging.py +6 -0
- {xax-0.1.12.dist-info → xax-0.1.13.dist-info}/METADATA +1 -1
- {xax-0.1.12.dist-info → xax-0.1.13.dist-info}/RECORD +8 -8
- {xax-0.1.12.dist-info → xax-0.1.13.dist-info}/WHEEL +0 -0
- {xax-0.1.12.dist-info → xax-0.1.13.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.12.dist-info → xax-0.1.13.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.
|
15
|
+
__version__ = "0.1.13"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -75,6 +75,7 @@ __all__ = [
|
|
75
75
|
"Task",
|
76
76
|
"collate",
|
77
77
|
"collate_non_null",
|
78
|
+
"breakpoint_if_nan",
|
78
79
|
"get_named_leaves",
|
79
80
|
"BaseFileDownloader",
|
80
81
|
"ContextTimer",
|
@@ -234,6 +235,7 @@ NAME_MAP: dict[str, str] = {
|
|
234
235
|
"Task": "task.task",
|
235
236
|
"collate": "utils.data.collate",
|
236
237
|
"collate_non_null": "utils.data.collate",
|
238
|
+
"breakpoint_if_nan": "utils.debugging",
|
237
239
|
"get_named_leaves": "utils.debugging",
|
238
240
|
"BaseFileDownloader": "utils.experiments",
|
239
241
|
"ContextTimer": "utils.experiments",
|
@@ -386,7 +388,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
386
388
|
from xax.task.script import Script, ScriptConfig
|
387
389
|
from xax.task.task import Config, Task
|
388
390
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
389
|
-
from xax.utils.debugging import get_named_leaves
|
391
|
+
from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
|
390
392
|
from xax.utils.experiments import (
|
391
393
|
BaseFileDownloader,
|
392
394
|
ContextTimer,
|
xax/task/loggers/stdout.py
CHANGED
@@ -14,7 +14,7 @@ from xax.utils.text import Color, colored, format_timedelta
|
|
14
14
|
|
15
15
|
def format_number(value: int | float, precision: int) -> str:
|
16
16
|
if isinstance(value, int):
|
17
|
-
return
|
17
|
+
return f"{value:,}" # Add commas to the number
|
18
18
|
return f"{value:.{precision}g}"
|
19
19
|
|
20
20
|
|
@@ -80,11 +80,10 @@ class StdoutLogger(LoggerImpl):
|
|
80
80
|
self.write_fp.write("\033[2J\033[H")
|
81
81
|
|
82
82
|
def write_state_window(self, line: LogLine) -> None:
|
83
|
-
|
84
|
-
|
85
|
-
"
|
86
|
-
"
|
87
|
-
"Elapsed Time": f"{elapsed_time}",
|
83
|
+
state_info: dict[str, str] = {
|
84
|
+
"Steps": format_number(line.state.num_steps, 0),
|
85
|
+
"Samples": format_number(line.state.num_samples, 0),
|
86
|
+
"Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True),
|
88
87
|
}
|
89
88
|
|
90
89
|
colored_prefix = colored("Phase: ", "grey", bold=True)
|
xax/utils/debugging.py
CHANGED
@@ -4,6 +4,8 @@ from collections import deque
|
|
4
4
|
from collections.abc import Iterable, Mapping
|
5
5
|
from typing import Any, Callable, Deque
|
6
6
|
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
7
9
|
from jaxtyping import Array
|
8
10
|
|
9
11
|
|
@@ -47,3 +49,7 @@ def get_named_leaves(
|
|
47
49
|
q.append((depth + 1, gname, cnode))
|
48
50
|
|
49
51
|
return ret
|
52
|
+
|
53
|
+
|
54
|
+
def breakpoint_if_nan(x: Array) -> None:
|
55
|
+
jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=zBQiNKXLrE-9UPy1q-XBFwPNAKTAxr6wpAwYwaTVggs,13922
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -28,7 +28,7 @@ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
28
28
|
xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
|
29
29
|
xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
|
30
30
|
xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
|
31
|
-
xax/task/loggers/stdout.py,sha256=
|
31
|
+
xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,6788
|
32
32
|
xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
@@ -43,7 +43,7 @@ xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1
|
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
44
|
xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
|
-
xax/utils/debugging.py,sha256=
|
46
|
+
xax/utils/debugging.py,sha256=0DU41DvYp3SZ9tMrM7sSFpfhC7dieMYR7eRlGNAFrdM,1783
|
47
47
|
xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
|
48
48
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
49
49
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.1.
|
62
|
-
xax-0.1.
|
63
|
-
xax-0.1.
|
64
|
-
xax-0.1.
|
65
|
-
xax-0.1.
|
61
|
+
xax-0.1.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.1.13.dist-info/METADATA,sha256=P3N5hJMZtXPs199OHPN_cBp57S9zlOGB2B1TRPaczuI,1878
|
63
|
+
xax-0.1.13.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
64
|
+
xax-0.1.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.1.13.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|