xax 0.1.12__tar.gz → 0.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. {xax-0.1.12/xax.egg-info → xax-0.1.14}/PKG-INFO +1 -1
  2. {xax-0.1.12 → xax-0.1.14}/xax/__init__.py +6 -2
  3. {xax-0.1.12 → xax-0.1.14}/xax/task/base.py +0 -3
  4. {xax-0.1.12 → xax-0.1.14}/xax/task/loggers/stdout.py +5 -6
  5. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/checkpointing.py +25 -8
  6. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/compile.py +8 -0
  7. {xax-0.1.12 → xax-0.1.14}/xax/utils/debugging.py +10 -0
  8. {xax-0.1.12 → xax-0.1.14/xax.egg-info}/PKG-INFO +1 -1
  9. {xax-0.1.12 → xax-0.1.14}/LICENSE +0 -0
  10. {xax-0.1.12 → xax-0.1.14}/MANIFEST.in +0 -0
  11. {xax-0.1.12 → xax-0.1.14}/README.md +0 -0
  12. {xax-0.1.12 → xax-0.1.14}/pyproject.toml +0 -0
  13. {xax-0.1.12 → xax-0.1.14}/setup.cfg +0 -0
  14. {xax-0.1.12 → xax-0.1.14}/setup.py +0 -0
  15. {xax-0.1.12 → xax-0.1.14}/xax/core/__init__.py +0 -0
  16. {xax-0.1.12 → xax-0.1.14}/xax/core/conf.py +0 -0
  17. {xax-0.1.12 → xax-0.1.14}/xax/core/state.py +0 -0
  18. {xax-0.1.12 → xax-0.1.14}/xax/nn/__init__.py +0 -0
  19. {xax-0.1.12 → xax-0.1.14}/xax/nn/embeddings.py +0 -0
  20. {xax-0.1.12 → xax-0.1.14}/xax/nn/equinox.py +0 -0
  21. {xax-0.1.12 → xax-0.1.14}/xax/nn/export.py +0 -0
  22. {xax-0.1.12 → xax-0.1.14}/xax/nn/functions.py +0 -0
  23. {xax-0.1.12 → xax-0.1.14}/xax/nn/geom.py +0 -0
  24. {xax-0.1.12 → xax-0.1.14}/xax/nn/losses.py +0 -0
  25. {xax-0.1.12 → xax-0.1.14}/xax/nn/norm.py +0 -0
  26. {xax-0.1.12 → xax-0.1.14}/xax/nn/parallel.py +0 -0
  27. {xax-0.1.12 → xax-0.1.14}/xax/nn/ssm.py +0 -0
  28. {xax-0.1.12 → xax-0.1.14}/xax/py.typed +0 -0
  29. {xax-0.1.12 → xax-0.1.14}/xax/requirements-dev.txt +0 -0
  30. {xax-0.1.12 → xax-0.1.14}/xax/requirements.txt +0 -0
  31. {xax-0.1.12 → xax-0.1.14}/xax/task/__init__.py +0 -0
  32. {xax-0.1.12 → xax-0.1.14}/xax/task/launchers/__init__.py +0 -0
  33. {xax-0.1.12 → xax-0.1.14}/xax/task/launchers/base.py +0 -0
  34. {xax-0.1.12 → xax-0.1.14}/xax/task/launchers/cli.py +0 -0
  35. {xax-0.1.12 → xax-0.1.14}/xax/task/launchers/single_process.py +0 -0
  36. {xax-0.1.12 → xax-0.1.14}/xax/task/logger.py +0 -0
  37. {xax-0.1.12 → xax-0.1.14}/xax/task/loggers/__init__.py +0 -0
  38. {xax-0.1.12 → xax-0.1.14}/xax/task/loggers/callback.py +0 -0
  39. {xax-0.1.12 → xax-0.1.14}/xax/task/loggers/json.py +0 -0
  40. {xax-0.1.12 → xax-0.1.14}/xax/task/loggers/state.py +0 -0
  41. {xax-0.1.12 → xax-0.1.14}/xax/task/loggers/tensorboard.py +0 -0
  42. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/__init__.py +0 -0
  43. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/artifacts.py +0 -0
  44. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/process.py +0 -0
  49. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.1.12 → xax-0.1.14}/xax/task/mixins/train.py +0 -0
  52. {xax-0.1.12 → xax-0.1.14}/xax/task/script.py +0 -0
  53. {xax-0.1.12 → xax-0.1.14}/xax/task/task.py +0 -0
  54. {xax-0.1.12 → xax-0.1.14}/xax/utils/__init__.py +0 -0
  55. {xax-0.1.12 → xax-0.1.14}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.1.12 → xax-0.1.14}/xax/utils/data/collate.py +0 -0
  57. {xax-0.1.12 → xax-0.1.14}/xax/utils/experiments.py +0 -0
  58. {xax-0.1.12 → xax-0.1.14}/xax/utils/jax.py +0 -0
  59. {xax-0.1.12 → xax-0.1.14}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.1.12 → xax-0.1.14}/xax/utils/logging.py +0 -0
  61. {xax-0.1.12 → xax-0.1.14}/xax/utils/numpy.py +0 -0
  62. {xax-0.1.12 → xax-0.1.14}/xax/utils/profile.py +0 -0
  63. {xax-0.1.12 → xax-0.1.14}/xax/utils/pytree.py +0 -0
  64. {xax-0.1.12 → xax-0.1.14}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.1.12 → xax-0.1.14}/xax/utils/text.py +0 -0
  66. {xax-0.1.12 → xax-0.1.14}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.1.12 → xax-0.1.14}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.1.12 → xax-0.1.14}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.1.12 → xax-0.1.14}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.1.12 → xax-0.1.14}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.1.12 → xax-0.1.14}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.1.12 → xax-0.1.14}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.12
3
+ Version: 0.1.14
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.12"
15
+ __version__ = "0.1.14"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -75,7 +75,9 @@ __all__ = [
75
75
  "Task",
76
76
  "collate",
77
77
  "collate_non_null",
78
+ "breakpoint_if_nan",
78
79
  "get_named_leaves",
80
+ "log_if_nan",
79
81
  "BaseFileDownloader",
80
82
  "ContextTimer",
81
83
  "CumulativeTimer",
@@ -234,7 +236,9 @@ NAME_MAP: dict[str, str] = {
234
236
  "Task": "task.task",
235
237
  "collate": "utils.data.collate",
236
238
  "collate_non_null": "utils.data.collate",
239
+ "breakpoint_if_nan": "utils.debugging",
237
240
  "get_named_leaves": "utils.debugging",
241
+ "log_if_nan": "utils.debugging",
238
242
  "BaseFileDownloader": "utils.experiments",
239
243
  "ContextTimer": "utils.experiments",
240
244
  "CumulativeTimer": "utils.experiments",
@@ -386,7 +390,7 @@ if IMPORT_ALL or TYPE_CHECKING:
386
390
  from xax.task.script import Script, ScriptConfig
387
391
  from xax.task.task import Config, Task
388
392
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
389
- from xax.utils.debugging import get_named_leaves
393
+ from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
390
394
  from xax.utils.experiments import (
391
395
  BaseFileDownloader,
392
396
  ContextTimer,
@@ -82,9 +82,6 @@ class BaseTask(Generic[Config]):
82
82
  def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
83
83
  return state
84
84
 
85
- def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
86
- pass
87
-
88
85
  @functools.cached_property
89
86
  def task_class_name(self) -> str:
90
87
  return self.__class__.__name__
@@ -14,7 +14,7 @@ from xax.utils.text import Color, colored, format_timedelta
14
14
 
15
15
  def format_number(value: int | float, precision: int) -> str:
16
16
  if isinstance(value, int):
17
- return str(value)
17
+ return f"{value:,}" # Add commas to the number
18
18
  return f"{value:.{precision}g}"
19
19
 
20
20
 
@@ -80,11 +80,10 @@ class StdoutLogger(LoggerImpl):
80
80
  self.write_fp.write("\033[2J\033[H")
81
81
 
82
82
  def write_state_window(self, line: LogLine) -> None:
83
- elapsed_time = format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True)
84
- state_info = {
85
- "Steps": f"{line.state.num_steps}",
86
- "Samples": f"{line.state.num_samples}",
87
- "Elapsed Time": f"{elapsed_time}",
83
+ state_info: dict[str, str] = {
84
+ "Steps": format_number(line.state.num_steps, 0),
85
+ "Samples": format_number(line.state.num_samples, 0),
86
+ "Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True),
88
87
  }
89
88
 
90
89
  colored_prefix = colored("Phase: ", "grey", bold=True)
@@ -98,19 +98,39 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
98
98
  ) -> tuple[PyTree, State, DictConfig]: ...
99
99
 
100
100
  @overload
101
- def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
101
+ def load_checkpoint(
102
+ self,
103
+ path: Path,
104
+ part: Literal["model"],
105
+ ) -> PyTree: ...
102
106
 
103
107
  @overload
104
- def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
108
+ def load_checkpoint(
109
+ self,
110
+ path: Path,
111
+ part: Literal["opt"],
112
+ ) -> optax.GradientTransformation: ...
105
113
 
106
114
  @overload
107
- def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
115
+ def load_checkpoint(
116
+ self,
117
+ path: Path,
118
+ part: Literal["opt_state"],
119
+ ) -> optax.OptState: ...
108
120
 
109
121
  @overload
110
- def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
122
+ def load_checkpoint(
123
+ self,
124
+ path: Path,
125
+ part: Literal["state"],
126
+ ) -> State: ...
111
127
 
112
128
  @overload
113
- def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
129
+ def load_checkpoint(
130
+ self,
131
+ path: Path,
132
+ part: Literal["config"],
133
+ ) -> DictConfig: ...
114
134
 
115
135
  def load_checkpoint(
116
136
  self,
@@ -125,9 +145,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
125
145
  | State
126
146
  | DictConfig
127
147
  ):
128
- # Calls the base callback.
129
- self.on_before_checkpoint_load(path)
130
-
131
148
  with tarfile.open(path, "r:gz") as tar:
132
149
 
133
150
  def get_model() -> PyTree:
@@ -32,6 +32,10 @@ def get_cache_dir() -> str | None:
32
32
  @dataclass
33
33
  class CompileOptions:
34
34
  # JAX compilation options
35
+ debug_nans: bool = field(
36
+ value=False,
37
+ help="If True, breaks on NaNs",
38
+ )
35
39
  disable_jit: bool = field(
36
40
  value=False,
37
41
  help="If True, disables JIT compilation",
@@ -89,6 +93,10 @@ class CompileMixin(BaseTask[Config], Generic[Config]):
89
93
  cc = self.config.compile
90
94
 
91
95
  # Set basic compilation flags
96
+ if cc.debug_nans:
97
+ logger.info("Enabling NaNs debugging")
98
+ jax.config.update("jax_debug_nans", True)
99
+
92
100
  if cc.disable_jit:
93
101
  logger.info("Disabling JIT compilation")
94
102
  jax.config.update("jax_disable_jit", True)
@@ -4,6 +4,8 @@ from collections import deque
4
4
  from collections.abc import Iterable, Mapping
5
5
  from typing import Any, Callable, Deque
6
6
 
7
+ import jax
8
+ import jax.numpy as jnp
7
9
  from jaxtyping import Array
8
10
 
9
11
 
@@ -47,3 +49,11 @@ def get_named_leaves(
47
49
  q.append((depth + 1, gname, cnode))
48
50
 
49
51
  return ret
52
+
53
+
54
+ def breakpoint_if_nan(x: Array) -> None:
55
+ jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
56
+
57
+
58
+ def log_if_nan(x: Array, loc: str) -> None:
59
+ jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.12
3
+ Version: 0.1.14
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes