xax 0.1.13__tar.gz → 0.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.1.13/xax.egg-info → xax-0.1.14}/PKG-INFO +1 -1
  2. {xax-0.1.13 → xax-0.1.14}/xax/__init__.py +4 -2
  3. {xax-0.1.13 → xax-0.1.14}/xax/task/base.py +0 -3
  4. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/checkpointing.py +25 -8
  5. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/compile.py +8 -0
  6. {xax-0.1.13 → xax-0.1.14}/xax/utils/debugging.py +4 -0
  7. {xax-0.1.13 → xax-0.1.14/xax.egg-info}/PKG-INFO +1 -1
  8. {xax-0.1.13 → xax-0.1.14}/LICENSE +0 -0
  9. {xax-0.1.13 → xax-0.1.14}/MANIFEST.in +0 -0
  10. {xax-0.1.13 → xax-0.1.14}/README.md +0 -0
  11. {xax-0.1.13 → xax-0.1.14}/pyproject.toml +0 -0
  12. {xax-0.1.13 → xax-0.1.14}/setup.cfg +0 -0
  13. {xax-0.1.13 → xax-0.1.14}/setup.py +0 -0
  14. {xax-0.1.13 → xax-0.1.14}/xax/core/__init__.py +0 -0
  15. {xax-0.1.13 → xax-0.1.14}/xax/core/conf.py +0 -0
  16. {xax-0.1.13 → xax-0.1.14}/xax/core/state.py +0 -0
  17. {xax-0.1.13 → xax-0.1.14}/xax/nn/__init__.py +0 -0
  18. {xax-0.1.13 → xax-0.1.14}/xax/nn/embeddings.py +0 -0
  19. {xax-0.1.13 → xax-0.1.14}/xax/nn/equinox.py +0 -0
  20. {xax-0.1.13 → xax-0.1.14}/xax/nn/export.py +0 -0
  21. {xax-0.1.13 → xax-0.1.14}/xax/nn/functions.py +0 -0
  22. {xax-0.1.13 → xax-0.1.14}/xax/nn/geom.py +0 -0
  23. {xax-0.1.13 → xax-0.1.14}/xax/nn/losses.py +0 -0
  24. {xax-0.1.13 → xax-0.1.14}/xax/nn/norm.py +0 -0
  25. {xax-0.1.13 → xax-0.1.14}/xax/nn/parallel.py +0 -0
  26. {xax-0.1.13 → xax-0.1.14}/xax/nn/ssm.py +0 -0
  27. {xax-0.1.13 → xax-0.1.14}/xax/py.typed +0 -0
  28. {xax-0.1.13 → xax-0.1.14}/xax/requirements-dev.txt +0 -0
  29. {xax-0.1.13 → xax-0.1.14}/xax/requirements.txt +0 -0
  30. {xax-0.1.13 → xax-0.1.14}/xax/task/__init__.py +0 -0
  31. {xax-0.1.13 → xax-0.1.14}/xax/task/launchers/__init__.py +0 -0
  32. {xax-0.1.13 → xax-0.1.14}/xax/task/launchers/base.py +0 -0
  33. {xax-0.1.13 → xax-0.1.14}/xax/task/launchers/cli.py +0 -0
  34. {xax-0.1.13 → xax-0.1.14}/xax/task/launchers/single_process.py +0 -0
  35. {xax-0.1.13 → xax-0.1.14}/xax/task/logger.py +0 -0
  36. {xax-0.1.13 → xax-0.1.14}/xax/task/loggers/__init__.py +0 -0
  37. {xax-0.1.13 → xax-0.1.14}/xax/task/loggers/callback.py +0 -0
  38. {xax-0.1.13 → xax-0.1.14}/xax/task/loggers/json.py +0 -0
  39. {xax-0.1.13 → xax-0.1.14}/xax/task/loggers/state.py +0 -0
  40. {xax-0.1.13 → xax-0.1.14}/xax/task/loggers/stdout.py +0 -0
  41. {xax-0.1.13 → xax-0.1.14}/xax/task/loggers/tensorboard.py +0 -0
  42. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/__init__.py +0 -0
  43. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/artifacts.py +0 -0
  44. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/process.py +0 -0
  49. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.1.13 → xax-0.1.14}/xax/task/mixins/train.py +0 -0
  52. {xax-0.1.13 → xax-0.1.14}/xax/task/script.py +0 -0
  53. {xax-0.1.13 → xax-0.1.14}/xax/task/task.py +0 -0
  54. {xax-0.1.13 → xax-0.1.14}/xax/utils/__init__.py +0 -0
  55. {xax-0.1.13 → xax-0.1.14}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.1.13 → xax-0.1.14}/xax/utils/data/collate.py +0 -0
  57. {xax-0.1.13 → xax-0.1.14}/xax/utils/experiments.py +0 -0
  58. {xax-0.1.13 → xax-0.1.14}/xax/utils/jax.py +0 -0
  59. {xax-0.1.13 → xax-0.1.14}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.1.13 → xax-0.1.14}/xax/utils/logging.py +0 -0
  61. {xax-0.1.13 → xax-0.1.14}/xax/utils/numpy.py +0 -0
  62. {xax-0.1.13 → xax-0.1.14}/xax/utils/profile.py +0 -0
  63. {xax-0.1.13 → xax-0.1.14}/xax/utils/pytree.py +0 -0
  64. {xax-0.1.13 → xax-0.1.14}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.1.13 → xax-0.1.14}/xax/utils/text.py +0 -0
  66. {xax-0.1.13 → xax-0.1.14}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.1.13 → xax-0.1.14}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.1.13 → xax-0.1.14}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.1.13 → xax-0.1.14}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.1.13 → xax-0.1.14}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.1.13 → xax-0.1.14}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.1.13 → xax-0.1.14}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.13
3
+ Version: 0.1.14
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.13"
15
+ __version__ = "0.1.14"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -77,6 +77,7 @@ __all__ = [
77
77
  "collate_non_null",
78
78
  "breakpoint_if_nan",
79
79
  "get_named_leaves",
80
+ "log_if_nan",
80
81
  "BaseFileDownloader",
81
82
  "ContextTimer",
82
83
  "CumulativeTimer",
@@ -237,6 +238,7 @@ NAME_MAP: dict[str, str] = {
237
238
  "collate_non_null": "utils.data.collate",
238
239
  "breakpoint_if_nan": "utils.debugging",
239
240
  "get_named_leaves": "utils.debugging",
241
+ "log_if_nan": "utils.debugging",
240
242
  "BaseFileDownloader": "utils.experiments",
241
243
  "ContextTimer": "utils.experiments",
242
244
  "CumulativeTimer": "utils.experiments",
@@ -388,7 +390,7 @@ if IMPORT_ALL or TYPE_CHECKING:
388
390
  from xax.task.script import Script, ScriptConfig
389
391
  from xax.task.task import Config, Task
390
392
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
391
- from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
393
+ from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
392
394
  from xax.utils.experiments import (
393
395
  BaseFileDownloader,
394
396
  ContextTimer,
@@ -82,9 +82,6 @@ class BaseTask(Generic[Config]):
82
82
  def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
83
83
  return state
84
84
 
85
- def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
86
- pass
87
-
88
85
  @functools.cached_property
89
86
  def task_class_name(self) -> str:
90
87
  return self.__class__.__name__
@@ -98,19 +98,39 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
98
98
  ) -> tuple[PyTree, State, DictConfig]: ...
99
99
 
100
100
  @overload
101
- def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
101
+ def load_checkpoint(
102
+ self,
103
+ path: Path,
104
+ part: Literal["model"],
105
+ ) -> PyTree: ...
102
106
 
103
107
  @overload
104
- def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
108
+ def load_checkpoint(
109
+ self,
110
+ path: Path,
111
+ part: Literal["opt"],
112
+ ) -> optax.GradientTransformation: ...
105
113
 
106
114
  @overload
107
- def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
115
+ def load_checkpoint(
116
+ self,
117
+ path: Path,
118
+ part: Literal["opt_state"],
119
+ ) -> optax.OptState: ...
108
120
 
109
121
  @overload
110
- def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
122
+ def load_checkpoint(
123
+ self,
124
+ path: Path,
125
+ part: Literal["state"],
126
+ ) -> State: ...
111
127
 
112
128
  @overload
113
- def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
129
+ def load_checkpoint(
130
+ self,
131
+ path: Path,
132
+ part: Literal["config"],
133
+ ) -> DictConfig: ...
114
134
 
115
135
  def load_checkpoint(
116
136
  self,
@@ -125,9 +145,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
125
145
  | State
126
146
  | DictConfig
127
147
  ):
128
- # Calls the base callback.
129
- self.on_before_checkpoint_load(path)
130
-
131
148
  with tarfile.open(path, "r:gz") as tar:
132
149
 
133
150
  def get_model() -> PyTree:
@@ -32,6 +32,10 @@ def get_cache_dir() -> str | None:
32
32
  @dataclass
33
33
  class CompileOptions:
34
34
  # JAX compilation options
35
+ debug_nans: bool = field(
36
+ value=False,
37
+ help="If True, breaks on NaNs",
38
+ )
35
39
  disable_jit: bool = field(
36
40
  value=False,
37
41
  help="If True, disables JIT compilation",
@@ -89,6 +93,10 @@ class CompileMixin(BaseTask[Config], Generic[Config]):
89
93
  cc = self.config.compile
90
94
 
91
95
  # Set basic compilation flags
96
+ if cc.debug_nans:
97
+ logger.info("Enabling NaNs debugging")
98
+ jax.config.update("jax_debug_nans", True)
99
+
92
100
  if cc.disable_jit:
93
101
  logger.info("Disabling JIT compilation")
94
102
  jax.config.update("jax_disable_jit", True)
@@ -53,3 +53,7 @@ def get_named_leaves(
53
53
 
54
54
  def breakpoint_if_nan(x: Array) -> None:
55
55
  jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
56
+
57
+
58
+ def log_if_nan(x: Array, loc: str) -> None:
59
+ jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.13
3
+ Version: 0.1.14
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes