xax 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.13"
15
+ __version__ = "0.1.14"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -77,6 +77,7 @@ __all__ = [
77
77
  "collate_non_null",
78
78
  "breakpoint_if_nan",
79
79
  "get_named_leaves",
80
+ "log_if_nan",
80
81
  "BaseFileDownloader",
81
82
  "ContextTimer",
82
83
  "CumulativeTimer",
@@ -237,6 +238,7 @@ NAME_MAP: dict[str, str] = {
237
238
  "collate_non_null": "utils.data.collate",
238
239
  "breakpoint_if_nan": "utils.debugging",
239
240
  "get_named_leaves": "utils.debugging",
241
+ "log_if_nan": "utils.debugging",
240
242
  "BaseFileDownloader": "utils.experiments",
241
243
  "ContextTimer": "utils.experiments",
242
244
  "CumulativeTimer": "utils.experiments",
@@ -388,7 +390,7 @@ if IMPORT_ALL or TYPE_CHECKING:
388
390
  from xax.task.script import Script, ScriptConfig
389
391
  from xax.task.task import Config, Task
390
392
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
391
- from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
393
+ from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
392
394
  from xax.utils.experiments import (
393
395
  BaseFileDownloader,
394
396
  ContextTimer,
xax/task/base.py CHANGED
@@ -82,9 +82,6 @@ class BaseTask(Generic[Config]):
82
82
  def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
83
83
  return state
84
84
 
85
- def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
86
- pass
87
-
88
85
  @functools.cached_property
89
86
  def task_class_name(self) -> str:
90
87
  return self.__class__.__name__
@@ -98,19 +98,39 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
98
98
  ) -> tuple[PyTree, State, DictConfig]: ...
99
99
 
100
100
  @overload
101
- def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
101
+ def load_checkpoint(
102
+ self,
103
+ path: Path,
104
+ part: Literal["model"],
105
+ ) -> PyTree: ...
102
106
 
103
107
  @overload
104
- def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
108
+ def load_checkpoint(
109
+ self,
110
+ path: Path,
111
+ part: Literal["opt"],
112
+ ) -> optax.GradientTransformation: ...
105
113
 
106
114
  @overload
107
- def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
115
+ def load_checkpoint(
116
+ self,
117
+ path: Path,
118
+ part: Literal["opt_state"],
119
+ ) -> optax.OptState: ...
108
120
 
109
121
  @overload
110
- def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
122
+ def load_checkpoint(
123
+ self,
124
+ path: Path,
125
+ part: Literal["state"],
126
+ ) -> State: ...
111
127
 
112
128
  @overload
113
- def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
129
+ def load_checkpoint(
130
+ self,
131
+ path: Path,
132
+ part: Literal["config"],
133
+ ) -> DictConfig: ...
114
134
 
115
135
  def load_checkpoint(
116
136
  self,
@@ -125,9 +145,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
125
145
  | State
126
146
  | DictConfig
127
147
  ):
128
- # Calls the base callback.
129
- self.on_before_checkpoint_load(path)
130
-
131
148
  with tarfile.open(path, "r:gz") as tar:
132
149
 
133
150
  def get_model() -> PyTree:
@@ -32,6 +32,10 @@ def get_cache_dir() -> str | None:
32
32
  @dataclass
33
33
  class CompileOptions:
34
34
  # JAX compilation options
35
+ debug_nans: bool = field(
36
+ value=False,
37
+ help="If True, breaks on NaNs",
38
+ )
35
39
  disable_jit: bool = field(
36
40
  value=False,
37
41
  help="If True, disables JIT compilation",
@@ -89,6 +93,10 @@ class CompileMixin(BaseTask[Config], Generic[Config]):
89
93
  cc = self.config.compile
90
94
 
91
95
  # Set basic compilation flags
96
+ if cc.debug_nans:
97
+ logger.info("Enabling NaNs debugging")
98
+ jax.config.update("jax_debug_nans", True)
99
+
92
100
  if cc.disable_jit:
93
101
  logger.info("Disabling JIT compilation")
94
102
  jax.config.update("jax_disable_jit", True)
xax/utils/debugging.py CHANGED
@@ -53,3 +53,7 @@ def get_named_leaves(
53
53
 
54
54
  def breakpoint_if_nan(x: Array) -> None:
55
55
  jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
56
+
57
+
58
+ def log_if_nan(x: Array, loc: str) -> None:
59
+ jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.13
3
+ Version: 0.1.14
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=zBQiNKXLrE-9UPy1q-XBFwPNAKTAxr6wpAwYwaTVggs,13922
1
+ xax/__init__.py,sha256=D7czvfKKQJlemPuatMPVYbAO4ST3U272QRIyTOru7JI,13989
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -16,7 +16,7 @@ xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
16
16
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
17
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
18
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
19
+ xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
20
20
  xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
21
21
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
22
22
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
@@ -32,8 +32,8 @@ xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,67
32
32
  xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
35
- xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
36
- xax/task/mixins/compile.py,sha256=8jEdlGs-a14N_CwZA3Rxe461MT83dyIDr3Z56VkjviQ,3693
35
+ xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
36
+ xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
37
  xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
38
38
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
39
39
  xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
@@ -43,7 +43,7 @@ xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
44
  xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
- xax/utils/debugging.py,sha256=0DU41DvYp3SZ9tMrM7sSFpfhC7dieMYR7eRlGNAFrdM,1783
46
+ xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
47
  xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
48
48
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
49
49
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.1.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.1.13.dist-info/METADATA,sha256=P3N5hJMZtXPs199OHPN_cBp57S9zlOGB2B1TRPaczuI,1878
63
- xax-0.1.13.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.1.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.1.13.dist-info/RECORD,,
61
+ xax-0.1.14.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.1.14.dist-info/METADATA,sha256=WbKtAXJUYKHvBrOJPEm_eXF9O9ekc0WdPmsQQCSGG5Q,1878
63
+ xax-0.1.14.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.1.14.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.1.14.dist-info/RECORD,,
File without changes