xax 0.2.19__py3-none-any.whl → 0.2.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.19"
15
+ __version__ = "0.2.20"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/cli/__init__.py ADDED
File without changes
xax/cli/edit_config.py ADDED
@@ -0,0 +1,67 @@
1
+ """Lets you edit a checkpoint config programmatically."""
2
+
3
+ import argparse
4
+ import difflib
5
+ import io
6
+ import os
7
+ import subprocess
8
+ import tarfile
9
+ import tempfile
10
+ from pathlib import Path
11
+
12
+ from omegaconf import OmegaConf
13
+
14
+ from xax.task.mixins.checkpointing import load_ckpt
15
+ from xax.utils.text import colored, show_info
16
+
17
+
18
def _edit_yaml_in_editor(text: str) -> str:
    """Opens ``text`` in the user's editor and returns the edited YAML, re-dumped with sorted keys.

    The editor command comes from ``$EDITOR`` (falling back to ``vim``) and is split with
    ``shlex`` so values such as ``"code --wait"`` work. The temporary file is always removed,
    even if the editor exits with an error or the edited YAML fails to parse.
    """
    import shlex  # Local import: only needed to split EDITOR values that carry arguments.

    editor_cmd = shlex.split(os.environ.get("EDITOR") or "vim")
    with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".yaml", delete=False) as f:
        f.write(text)
        tmp_path = f.name
    try:
        # Launch the editor only after our handle is closed, so the editor can open the file on
        # platforms that lock open files.
        subprocess.run([*editor_cmd, tmp_path], check=True)
        return OmegaConf.to_yaml(OmegaConf.load(tmp_path), sort_keys=True)
    finally:
        os.remove(tmp_path)


def _print_diff(before: str, after: str) -> None:
    """Prints the added / removed / hint lines of a line diff between two strings, colored."""
    colors = {"+ ": "light-green", "- ": "light-red", "? ": "light-cyan"}
    for line in difflib.ndiff(before.splitlines(), after.splitlines()):
        color = colors.get(line[:2])
        if color is not None:
            print(colored(line, color), flush=True)


def _replace_tar_member(tar_path: Path, name: str, data: bytes) -> None:
    """Rewrites the checkpoint tarball so that member ``name`` contains ``data``.

    Every other member is copied through unchanged, and ``name`` keeps its position if it was
    already present; otherwise it is appended. The new archive is written to a temporary file
    next to the real checkpoint and moved into place with ``os.replace``, so a failure midway
    leaves the original checkpoint intact.
    """
    # Resolve symlinks so we update the real checkpoint file instead of replacing the link.
    target = Path(tar_path).resolve()
    fd, tmp_name = tempfile.mkstemp(prefix=f".{target.name}.", suffix=".tmp", dir=target.parent)
    os.close(fd)

    info = tarfile.TarInfo(name=name)
    info.size = len(data)

    try:
        with tarfile.open(target, "r:*") as src, tarfile.open(tmp_name, "w:gz") as dst:
            written = False
            for member in src.getmembers():
                if member.name != name:
                    # Only regular files have a payload; links/dirs are written with no data.
                    fileobj = src.extractfile(member) if member.isfile() else None
                    dst.addfile(member, fileobj)
                elif not written:
                    dst.addfile(info, io.BytesIO(data))
                    written = True
                # Any further duplicates of `name` are dropped; only the new copy is kept.
            if not written:
                dst.addfile(info, io.BytesIO(data))
        # mkstemp creates the file with mode 0600; keep the checkpoint's original permissions.
        os.chmod(tmp_name, os.stat(target).st_mode & 0o7777)
        os.replace(tmp_name, target)
    except BaseException:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


def main() -> None:
    """Interactively edits the config stored inside a checkpoint.

    Loads the ``config`` part of the checkpoint at ``ckpt_path`` and opens it as YAML in the
    user's editor. If anything changed, it prints a colored diff and writes the edited config
    back into the checkpoint, preserving every other member of the archive.
    """
    parser = argparse.ArgumentParser(description="Edit the config stored in a checkpoint.")
    parser.add_argument("ckpt_path", type=Path)
    args = parser.parse_args()

    # Dump with sorted keys so the comparison below matches the sorted re-dump of the edited
    # file. Otherwise checkpoints saved with unsorted YAML would always look "changed".
    config = load_ckpt(args.ckpt_path, part="config")
    config_str = OmegaConf.to_yaml(config, sort_keys=True)

    edited_config_str = _edit_yaml_in_editor(config_str)
    if edited_config_str == config_str:
        show_info("No changes were made to the config.")
        return

    _print_diff(config_str, edited_config_str)

    # Replace only the config member; opening the checkpoint with mode "w" would truncate it and
    # drop the model/optimizer/state members.
    _replace_tar_member(args.ckpt_path, "config", edited_config_str.encode("utf-8"))
63
+
64
+
65
if __name__ == "__main__":
    # Run as `python -m xax.cli.edit_config <ckpt_path>`, or via the `xax-edit-config` console script.
    main()
xax/task/base.py CHANGED
@@ -210,7 +210,7 @@ class BaseTask(Generic[Config]):
210
210
 
211
211
  @classmethod
212
212
  def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
213
- return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli))
213
+ return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli), sort_keys=True)
214
214
 
215
215
  @classmethod
216
216
  def get_task(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Self:
@@ -292,7 +292,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
292
292
 
293
293
  if state is not None:
294
294
  add_file_bytes("state", json.dumps(state.to_dict(), indent=2).encode())
295
- add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
295
+ add_file_bytes("config", OmegaConf.to_yaml(self.config, sort_keys=True).encode())
296
296
 
297
297
  # Updates the symlink to the new checkpoint
298
298
  last_ckpt_path.unlink(missing_ok=True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.19
3
+ Version: 0.2.20
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,7 +1,9 @@
1
- xax/__init__.py,sha256=Pv4UWs4GGvojcnBDz4hk9whnU4ZEklWXmTMaTwlSZPM,15733
1
+ xax/__init__.py,sha256=yWQcHMlP2cKIpfJCJLXv796F-AAHQhS-1sRxu7871mw,15733
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
5
+ xax/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ xax/cli/edit_config.py,sha256=99x_k6aNimbcebi2vSJhln-cv4364h6GQdRccuv_qcs,2069
5
7
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
8
  xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
7
9
  xax/core/state.py,sha256=KsNMnM_RgsZ2Ntc2pp4Fi6zG4rZb_89-kqmyGxDvyRg,4974
@@ -16,7 +18,7 @@ xax/nn/metrics.py,sha256=OAkeScwhi-wTBIJ59KHUhYbZTq4V4V-LG-mKlxMJ7bY,3238
16
18
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
19
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
20
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- xax/task/base.py,sha256=qF-ed7bFrBIRHxM7hVPLTGRYYwiLp98H0b6NWF80LTQ,8101
21
+ xax/task/base.py,sha256=TYANmjNcce4_V5ZSYLnE91PXRn7Nn0nT7hN8plW_Au0,8117
20
22
  xax/task/logger.py,sha256=W_BpluYvQai1lh1dDCAj-2_mWUC1buhwJncHygDffjc,41125
21
23
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
22
24
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
@@ -32,7 +34,7 @@ xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,66
32
34
  xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
33
35
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
36
  xax/task/mixins/artifacts.py,sha256=IBPAQMCGd7PQiZHfSjLakPW5j7cNuL6AsW6QkVSc02E,3277
35
- xax/task/mixins/checkpointing.py,sha256=w-xsx8tH8PPYRf9eoqvz-4kAdunHMvazcoZAIj6niyI,11468
37
+ xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
36
38
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
39
  xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
38
40
  xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
@@ -58,8 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
60
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
61
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
60
62
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.2.19.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.2.19.dist-info/METADATA,sha256=fmZSs-aNP8Rwo9bPOnb9Fiu9UUlZ_5mxXnVl4fZkDMw,1880
63
- xax-0.2.19.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
64
- xax-0.2.19.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.2.19.dist-info/RECORD,,
63
+ xax-0.2.20.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
64
+ xax-0.2.20.dist-info/METADATA,sha256=YCcDox7HsHIVUeDeZYCvoaGYhsU7TzspX-v-Xw0-H4g,1880
65
+ xax-0.2.20.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
66
+ xax-0.2.20.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
67
+ xax-0.2.20.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
68
+ xax-0.2.20.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.1.0)
2
+ Generator: setuptools (80.3.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ xax-edit-config = xax.cli.edit_config:main