xax 0.2.19__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/cli/__init__.py +0 -0
- xax/cli/edit_config.py +67 -0
- xax/task/base.py +1 -1
- xax/task/mixins/checkpointing.py +1 -1
- {xax-0.2.19.dist-info → xax-0.2.20.dist-info}/METADATA +1 -1
- {xax-0.2.19.dist-info → xax-0.2.20.dist-info}/RECORD +11 -8
- {xax-0.2.19.dist-info → xax-0.2.20.dist-info}/WHEEL +1 -1
- xax-0.2.20.dist-info/entry_points.txt +2 -0
- {xax-0.2.19.dist-info → xax-0.2.20.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.19.dist-info → xax-0.2.20.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/cli/__init__.py
ADDED
File without changes
|
xax/cli/edit_config.py
ADDED
@@ -0,0 +1,67 @@
|
|
1
|
+
"""Lets you edit a checkpoint config programmatically."""
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import difflib
|
5
|
+
import io
|
6
|
+
import os
|
7
|
+
import subprocess
|
8
|
+
import tarfile
|
9
|
+
import tempfile
|
10
|
+
from pathlib import Path
|
11
|
+
|
12
|
+
from omegaconf import OmegaConf
|
13
|
+
|
14
|
+
from xax.task.mixins.checkpointing import load_ckpt
|
15
|
+
from xax.utils.text import colored, show_info
|
16
|
+
|
17
|
+
|
18
|
+
def _open_in_editor(path: str) -> None:
    """Opens ``path`` in the user's editor and blocks until the editor exits.

    The command comes from ``$EDITOR`` (falling back to ``vim`` when unset or
    empty). On POSIX the value is split shell-style so that editors needing
    flags, such as ``code --wait``, work.

    Args:
        path: The file to open.

    Raises:
        subprocess.CalledProcessError: If the editor exits with a non-zero status.
    """
    import shlex

    editor = os.environ.get("EDITOR") or "vim"
    # Windows paths contain backslashes that POSIX-style splitting would mangle,
    # so the value is only split on POSIX.
    command = (shlex.split(editor) if os.name == "posix" else [editor]) or ["vim"]
    subprocess.run([*command, path], check=True)


def _print_diff(before: str, after: str) -> None:
    """Prints a colored, changed-lines-only diff between two YAML strings.

    Args:
        before: The original text.
        after: The edited text.
    """
    colors = {"+ ": "light-green", "- ": "light-red", "? ": "light-cyan"}
    for line in difflib.ndiff(before.splitlines(), after.splitlines()):
        color = colors.get(line[:2])
        if color is None:
            continue  # Unchanged ("  ") lines are not shown.
        # ndiff's "? " hint lines carry their own trailing newline; strip it so
        # print() does not emit an extra blank line.
        print(colored(line.rstrip("\n"), color), flush=True)


def _replace_ckpt_member(ckpt_path: Path, name: str, data: bytes) -> None:
    """Rewrites a checkpoint archive with member ``name`` replaced by ``data``.

    Every other member is copied through unchanged. If ``name`` is not present,
    it is appended. The new archive is written to a temporary file in the same
    directory and moved over the original with ``os.replace``, so an
    interrupted write never leaves a truncated checkpoint behind.

    Args:
        ckpt_path: Path to the checkpoint tar archive. Symlinks are resolved,
            so the link is kept and the file it points to is updated.
        name: The archive member to replace (e.g. ``"config"``).
        data: The new contents for that member.
    """
    target = Path(ckpt_path).resolve()

    # Keep gzip-compressed archives compressed; anything else is written as a
    # plain tar, which ``tarfile`` can read back with the same API.
    with open(target, "rb") as fh:
        write_mode = "w:gz" if fh.read(2) == b"\x1f\x8b" else "w"

    fd, tmp_name = tempfile.mkstemp(dir=target.parent, prefix=f".{target.name}.", suffix=".tmp")
    os.close(fd)
    try:
        # mkstemp creates the file with 0600; carry over the original mode.
        os.chmod(tmp_name, os.stat(target).st_mode & 0o7777)

        def new_member() -> tuple[tarfile.TarInfo, io.BytesIO]:
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            return info, io.BytesIO(data)

        replaced = False
        with tarfile.open(target, "r:*") as src, tarfile.open(tmp_name, write_mode) as dst:
            for member in src.getmembers():
                if member.name == name:
                    dst.addfile(*new_member())
                    replaced = True
                else:
                    fileobj = src.extractfile(member) if member.isfile() else None
                    dst.addfile(member, fileobj)
            if not replaced:
                dst.addfile(*new_member())

        os.replace(tmp_name, target)
    finally:
        # Only still present if something failed before os.replace.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)


def main() -> None:
    """Interactively edits the config stored inside a checkpoint.

    Loads the ``config`` part of the checkpoint, opens it as YAML in
    ``$EDITOR``, prints a colored diff of the changes and writes the edited
    config back into the checkpoint. All other parts of the checkpoint are
    kept.
    """
    parser = argparse.ArgumentParser(description="Edit the config stored in a checkpoint.")
    parser.add_argument("ckpt_path", type=Path, help="Path to the checkpoint file.")
    args = parser.parse_args()

    # Loads the config from the checkpoint. Both sides are serialized with
    # sorted keys so that the comparison and diff below reflect real edits,
    # not key-order differences (configs used to be saved unsorted).
    config = load_ckpt(args.ckpt_path, part="config")
    config_str = OmegaConf.to_yaml(config, sort_keys=True)

    # Opens the user's preferred editor on a temporary copy of the config.
    fd, tmp_name = tempfile.mkstemp(suffix=".yaml")
    try:
        # Close the file before launching the editor so the editor (and any
        # platform that locks open files) sees a complete, unlocked file.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(config_str)
        _open_in_editor(tmp_name)
        edited_config = OmegaConf.load(tmp_name)
        edited_config_str = OmegaConf.to_yaml(edited_config, sort_keys=True)
    finally:
        os.remove(tmp_name)

    if edited_config_str == config_str:
        show_info("No changes were made to the config.")
        return

    _print_diff(config_str, edited_config_str)

    # Saves the edited config while preserving every other checkpoint member.
    _replace_ckpt_member(args.ckpt_path, "config", edited_config_str.encode("utf-8"))
    show_info(f"Saved the edited config to {args.ckpt_path}")


if __name__ == "__main__":
    # python -m xax.cli.edit_config
    main()
|
xax/task/base.py
CHANGED
@@ -210,7 +210,7 @@ class BaseTask(Generic[Config]):
|
|
210
210
|
|
211
211
|
@classmethod
|
212
212
|
def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
|
213
|
-
return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli))
|
213
|
+
return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli), sort_keys=True)
|
214
214
|
|
215
215
|
@classmethod
|
216
216
|
def get_task(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Self:
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -292,7 +292,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
292
292
|
|
293
293
|
if state is not None:
|
294
294
|
add_file_bytes("state", json.dumps(state.to_dict(), indent=2).encode())
|
295
|
-
add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
|
295
|
+
add_file_bytes("config", OmegaConf.to_yaml(self.config, sort_keys=True).encode())
|
296
296
|
|
297
297
|
# Updates the symlink to the new checkpoint
|
298
298
|
last_ckpt_path.unlink(missing_ok=True)
|
@@ -1,7 +1,9 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=yWQcHMlP2cKIpfJCJLXv796F-AAHQhS-1sRxu7871mw,15733
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
5
|
+
xax/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
xax/cli/edit_config.py,sha256=99x_k6aNimbcebi2vSJhln-cv4364h6GQdRccuv_qcs,2069
|
5
7
|
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
8
|
xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
|
7
9
|
xax/core/state.py,sha256=KsNMnM_RgsZ2Ntc2pp4Fi6zG4rZb_89-kqmyGxDvyRg,4974
|
@@ -16,7 +18,7 @@ xax/nn/metrics.py,sha256=OAkeScwhi-wTBIJ59KHUhYbZTq4V4V-LG-mKlxMJ7bY,3238
|
|
16
18
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
17
19
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
20
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
xax/task/base.py,sha256=
|
21
|
+
xax/task/base.py,sha256=TYANmjNcce4_V5ZSYLnE91PXRn7Nn0nT7hN8plW_Au0,8117
|
20
22
|
xax/task/logger.py,sha256=W_BpluYvQai1lh1dDCAj-2_mWUC1buhwJncHygDffjc,41125
|
21
23
|
xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
|
22
24
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
@@ -32,7 +34,7 @@ xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,66
|
|
32
34
|
xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
|
33
35
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
36
|
xax/task/mixins/artifacts.py,sha256=IBPAQMCGd7PQiZHfSjLakPW5j7cNuL6AsW6QkVSc02E,3277
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
37
|
+
xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
|
36
38
|
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
39
|
xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
|
38
40
|
xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
|
@@ -58,8 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
60
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
61
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
60
62
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
63
|
+
xax-0.2.20.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
64
|
+
xax-0.2.20.dist-info/METADATA,sha256=YCcDox7HsHIVUeDeZYCvoaGYhsU7TzspX-v-Xw0-H4g,1880
|
65
|
+
xax-0.2.20.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
66
|
+
xax-0.2.20.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
|
67
|
+
xax-0.2.20.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
68
|
+
xax-0.2.20.dist-info/RECORD,,
|
File without changes
|
File without changes
|