xax 0.1.12.tar.gz → 0.1.13.tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.1.12/xax.egg-info → xax-0.1.13}/PKG-INFO +1 -1
  2. {xax-0.1.12 → xax-0.1.13}/xax/__init__.py +4 -2
  3. {xax-0.1.12 → xax-0.1.13}/xax/task/loggers/stdout.py +5 -6
  4. {xax-0.1.12 → xax-0.1.13}/xax/utils/debugging.py +6 -0
  5. {xax-0.1.12 → xax-0.1.13/xax.egg-info}/PKG-INFO +1 -1
  6. {xax-0.1.12 → xax-0.1.13}/LICENSE +0 -0
  7. {xax-0.1.12 → xax-0.1.13}/MANIFEST.in +0 -0
  8. {xax-0.1.12 → xax-0.1.13}/README.md +0 -0
  9. {xax-0.1.12 → xax-0.1.13}/pyproject.toml +0 -0
  10. {xax-0.1.12 → xax-0.1.13}/setup.cfg +0 -0
  11. {xax-0.1.12 → xax-0.1.13}/setup.py +0 -0
  12. {xax-0.1.12 → xax-0.1.13}/xax/core/__init__.py +0 -0
  13. {xax-0.1.12 → xax-0.1.13}/xax/core/conf.py +0 -0
  14. {xax-0.1.12 → xax-0.1.13}/xax/core/state.py +0 -0
  15. {xax-0.1.12 → xax-0.1.13}/xax/nn/__init__.py +0 -0
  16. {xax-0.1.12 → xax-0.1.13}/xax/nn/embeddings.py +0 -0
  17. {xax-0.1.12 → xax-0.1.13}/xax/nn/equinox.py +0 -0
  18. {xax-0.1.12 → xax-0.1.13}/xax/nn/export.py +0 -0
  19. {xax-0.1.12 → xax-0.1.13}/xax/nn/functions.py +0 -0
  20. {xax-0.1.12 → xax-0.1.13}/xax/nn/geom.py +0 -0
  21. {xax-0.1.12 → xax-0.1.13}/xax/nn/losses.py +0 -0
  22. {xax-0.1.12 → xax-0.1.13}/xax/nn/norm.py +0 -0
  23. {xax-0.1.12 → xax-0.1.13}/xax/nn/parallel.py +0 -0
  24. {xax-0.1.12 → xax-0.1.13}/xax/nn/ssm.py +0 -0
  25. {xax-0.1.12 → xax-0.1.13}/xax/py.typed +0 -0
  26. {xax-0.1.12 → xax-0.1.13}/xax/requirements-dev.txt +0 -0
  27. {xax-0.1.12 → xax-0.1.13}/xax/requirements.txt +0 -0
  28. {xax-0.1.12 → xax-0.1.13}/xax/task/__init__.py +0 -0
  29. {xax-0.1.12 → xax-0.1.13}/xax/task/base.py +0 -0
  30. {xax-0.1.12 → xax-0.1.13}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.1.12 → xax-0.1.13}/xax/task/launchers/base.py +0 -0
  32. {xax-0.1.12 → xax-0.1.13}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.1.12 → xax-0.1.13}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.1.12 → xax-0.1.13}/xax/task/logger.py +0 -0
  35. {xax-0.1.12 → xax-0.1.13}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.1.12 → xax-0.1.13}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.1.12 → xax-0.1.13}/xax/task/loggers/json.py +0 -0
  38. {xax-0.1.12 → xax-0.1.13}/xax/task/loggers/state.py +0 -0
  39. {xax-0.1.12 → xax-0.1.13}/xax/task/loggers/tensorboard.py +0 -0
  40. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/checkpointing.py +0 -0
  43. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/compile.py +0 -0
  44. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/process.py +0 -0
  49. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.1.12 → xax-0.1.13}/xax/task/mixins/train.py +0 -0
  52. {xax-0.1.12 → xax-0.1.13}/xax/task/script.py +0 -0
  53. {xax-0.1.12 → xax-0.1.13}/xax/task/task.py +0 -0
  54. {xax-0.1.12 → xax-0.1.13}/xax/utils/__init__.py +0 -0
  55. {xax-0.1.12 → xax-0.1.13}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.1.12 → xax-0.1.13}/xax/utils/data/collate.py +0 -0
  57. {xax-0.1.12 → xax-0.1.13}/xax/utils/experiments.py +0 -0
  58. {xax-0.1.12 → xax-0.1.13}/xax/utils/jax.py +0 -0
  59. {xax-0.1.12 → xax-0.1.13}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.1.12 → xax-0.1.13}/xax/utils/logging.py +0 -0
  61. {xax-0.1.12 → xax-0.1.13}/xax/utils/numpy.py +0 -0
  62. {xax-0.1.12 → xax-0.1.13}/xax/utils/profile.py +0 -0
  63. {xax-0.1.12 → xax-0.1.13}/xax/utils/pytree.py +0 -0
  64. {xax-0.1.12 → xax-0.1.13}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.1.12 → xax-0.1.13}/xax/utils/text.py +0 -0
  66. {xax-0.1.12 → xax-0.1.13}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.1.12 → xax-0.1.13}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.1.12 → xax-0.1.13}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.1.12 → xax-0.1.13}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.1.12 → xax-0.1.13}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.1.12 → xax-0.1.13}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.1.12 → xax-0.1.13}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.12
3
+ Version: 0.1.13
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.12"
15
+ __version__ = "0.1.13"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -75,6 +75,7 @@ __all__ = [
75
75
  "Task",
76
76
  "collate",
77
77
  "collate_non_null",
78
+ "breakpoint_if_nan",
78
79
  "get_named_leaves",
79
80
  "BaseFileDownloader",
80
81
  "ContextTimer",
@@ -234,6 +235,7 @@ NAME_MAP: dict[str, str] = {
234
235
  "Task": "task.task",
235
236
  "collate": "utils.data.collate",
236
237
  "collate_non_null": "utils.data.collate",
238
+ "breakpoint_if_nan": "utils.debugging",
237
239
  "get_named_leaves": "utils.debugging",
238
240
  "BaseFileDownloader": "utils.experiments",
239
241
  "ContextTimer": "utils.experiments",
@@ -386,7 +388,7 @@ if IMPORT_ALL or TYPE_CHECKING:
386
388
  from xax.task.script import Script, ScriptConfig
387
389
  from xax.task.task import Config, Task
388
390
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
389
- from xax.utils.debugging import get_named_leaves
391
+ from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
390
392
  from xax.utils.experiments import (
391
393
  BaseFileDownloader,
392
394
  ContextTimer,
@@ -14,7 +14,7 @@ from xax.utils.text import Color, colored, format_timedelta
14
14
 
15
15
  def format_number(value: int | float, precision: int) -> str:
16
16
  if isinstance(value, int):
17
- return str(value)
17
+ return f"{value:,}" # Add commas to the number
18
18
  return f"{value:.{precision}g}"
19
19
 
20
20
 
@@ -80,11 +80,10 @@ class StdoutLogger(LoggerImpl):
80
80
  self.write_fp.write("\033[2J\033[H")
81
81
 
82
82
  def write_state_window(self, line: LogLine) -> None:
83
- elapsed_time = format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True)
84
- state_info = {
85
- "Steps": f"{line.state.num_steps}",
86
- "Samples": f"{line.state.num_samples}",
87
- "Elapsed Time": f"{elapsed_time}",
83
+ state_info: dict[str, str] = {
84
+ "Steps": format_number(line.state.num_steps, 0),
85
+ "Samples": format_number(line.state.num_samples, 0),
86
+ "Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True),
88
87
  }
89
88
 
90
89
  colored_prefix = colored("Phase: ", "grey", bold=True)
@@ -4,6 +4,8 @@ from collections import deque
4
4
  from collections.abc import Iterable, Mapping
5
5
  from typing import Any, Callable, Deque
6
6
 
7
+ import jax
8
+ import jax.numpy as jnp
7
9
  from jaxtyping import Array
8
10
 
9
11
 
@@ -47,3 +49,7 @@ def get_named_leaves(
47
49
  q.append((depth + 1, gname, cnode))
48
50
 
49
51
  return ret
52
+
53
+
54
+ def breakpoint_if_nan(x: Array) -> None:
55
+ jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.12
3
+ Version: 0.1.13
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes