xax 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.12"
15
+ __version__ = "0.1.13"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -75,6 +75,7 @@ __all__ = [
75
75
  "Task",
76
76
  "collate",
77
77
  "collate_non_null",
78
+ "breakpoint_if_nan",
78
79
  "get_named_leaves",
79
80
  "BaseFileDownloader",
80
81
  "ContextTimer",
@@ -234,6 +235,7 @@ NAME_MAP: dict[str, str] = {
234
235
  "Task": "task.task",
235
236
  "collate": "utils.data.collate",
236
237
  "collate_non_null": "utils.data.collate",
238
+ "breakpoint_if_nan": "utils.debugging",
237
239
  "get_named_leaves": "utils.debugging",
238
240
  "BaseFileDownloader": "utils.experiments",
239
241
  "ContextTimer": "utils.experiments",
@@ -386,7 +388,7 @@ if IMPORT_ALL or TYPE_CHECKING:
386
388
  from xax.task.script import Script, ScriptConfig
387
389
  from xax.task.task import Config, Task
388
390
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
389
- from xax.utils.debugging import get_named_leaves
391
+ from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
390
392
  from xax.utils.experiments import (
391
393
  BaseFileDownloader,
392
394
  ContextTimer,
@@ -14,7 +14,7 @@ from xax.utils.text import Color, colored, format_timedelta
14
14
 
15
15
  def format_number(value: int | float, precision: int) -> str:
16
16
  if isinstance(value, int):
17
- return str(value)
17
+ return f"{value:,}" # Add commas to the number
18
18
  return f"{value:.{precision}g}"
19
19
 
20
20
 
@@ -80,11 +80,10 @@ class StdoutLogger(LoggerImpl):
80
80
  self.write_fp.write("\033[2J\033[H")
81
81
 
82
82
  def write_state_window(self, line: LogLine) -> None:
83
- elapsed_time = format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True)
84
- state_info = {
85
- "Steps": f"{line.state.num_steps}",
86
- "Samples": f"{line.state.num_samples}",
87
- "Elapsed Time": f"{elapsed_time}",
83
+ state_info: dict[str, str] = {
84
+ "Steps": format_number(line.state.num_steps, 0),
85
+ "Samples": format_number(line.state.num_samples, 0),
86
+ "Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True),
88
87
  }
89
88
 
90
89
  colored_prefix = colored("Phase: ", "grey", bold=True)
xax/utils/debugging.py CHANGED
@@ -4,6 +4,8 @@ from collections import deque
4
4
  from collections.abc import Iterable, Mapping
5
5
  from typing import Any, Callable, Deque
6
6
 
7
+ import jax
8
+ import jax.numpy as jnp
7
9
  from jaxtyping import Array
8
10
 
9
11
 
@@ -47,3 +49,7 @@ def get_named_leaves(
47
49
  q.append((depth + 1, gname, cnode))
48
50
 
49
51
  return ret
52
+
53
+
54
+ def breakpoint_if_nan(x: Array) -> None:
55
+ jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.12
3
+ Version: 0.1.13
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=7vdTYO7jAJdDxKZURlFxc3Y5kr5mVQcTQjeh_sYjD6I,13834
1
+ xax/__init__.py,sha256=zBQiNKXLrE-9UPy1q-XBFwPNAKTAxr6wpAwYwaTVggs,13922
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -28,7 +28,7 @@ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
28
28
  xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
29
29
  xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
30
30
  xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
31
- xax/task/loggers/stdout.py,sha256=bR0k-PfmFgLfPxLPb4hZw_8G_msA32UeHfAAu11nEYs,6757
31
+ xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,6788
32
32
  xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
@@ -43,7 +43,7 @@ xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
44
  xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
- xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
46
+ xax/utils/debugging.py,sha256=0DU41DvYp3SZ9tMrM7sSFpfhC7dieMYR7eRlGNAFrdM,1783
47
47
  xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
48
48
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
49
49
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.1.12.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.1.12.dist-info/METADATA,sha256=hLRAX5__7QjBgjzhxbRftGvEsNrt8IAdgd22dMtHu_Y,1878
63
- xax-0.1.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.1.12.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.1.12.dist-info/RECORD,,
61
+ xax-0.1.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.1.13.dist-info/METADATA,sha256=P3N5hJMZtXPs199OHPN_cBp57S9zlOGB2B1TRPaczuI,1878
63
+ xax-0.1.13.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.1.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.1.13.dist-info/RECORD,,
File without changes