xax 0.2.18__tar.gz → 0.2.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.18/xax.egg-info → xax-0.2.20}/PKG-INFO +1 -1
- {xax-0.2.18 → xax-0.2.20}/setup.py +5 -0
- {xax-0.2.18 → xax-0.2.20}/xax/__init__.py +1 -1
- xax-0.2.20/xax/cli/edit_config.py +67 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/base.py +13 -5
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/artifacts.py +8 -2
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/checkpointing.py +1 -1
- xax-0.2.20/xax/utils/types/__init__.py +0 -0
- {xax-0.2.18 → xax-0.2.20/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.18 → xax-0.2.20}/xax.egg-info/SOURCES.txt +3 -0
- xax-0.2.20/xax.egg-info/entry_points.txt +2 -0
- {xax-0.2.18 → xax-0.2.20}/LICENSE +0 -0
- {xax-0.2.18 → xax-0.2.20}/MANIFEST.in +0 -0
- {xax-0.2.18 → xax-0.2.20}/README.md +0 -0
- {xax-0.2.18 → xax-0.2.20}/pyproject.toml +0 -0
- {xax-0.2.18 → xax-0.2.20}/setup.cfg +0 -0
- {xax-0.2.18/xax/core → xax-0.2.20/xax/cli}/__init__.py +0 -0
- {xax-0.2.18/xax/nn → xax-0.2.20/xax/core}/__init__.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/core/conf.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/core/state.py +0 -0
- {xax-0.2.18/xax/task → xax-0.2.20/xax/nn}/__init__.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/embeddings.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/equinox.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/export.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/functions.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/geom.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/losses.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/metrics.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/parallel.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/nn/ssm.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/py.typed +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/requirements-dev.txt +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/requirements.txt +0 -0
- {xax-0.2.18/xax/task/launchers → xax-0.2.20/xax/task}/__init__.py +0 -0
- {xax-0.2.18/xax/task/loggers → xax-0.2.20/xax/task/launchers}/__init__.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/launchers/base.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/logger.py +0 -0
- {xax-0.2.18/xax/utils → xax-0.2.20/xax/task/loggers}/__init__.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/json.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/state.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/process.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/train.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/script.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/task/task.py +0 -0
- {xax-0.2.18/xax/utils/data → xax-0.2.20/xax/utils}/__init__.py +0 -0
- {xax-0.2.18/xax/utils/types → xax-0.2.20/xax/utils/data}/__init__.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/data/collate.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/debugging.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/experiments.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/jax.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/logging.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/numpy.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/profile.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/pytree.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/text.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.18 → xax-0.2.20}/xax.egg-info/top_level.txt +0 -0
@@ -0,0 +1,67 @@
|
|
1
|
+
"""Lets you edit a checkpoint config programmatically."""
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import difflib
|
5
|
+
import io
|
6
|
+
import os
|
7
|
+
import subprocess
|
8
|
+
import tarfile
|
9
|
+
import tempfile
|
10
|
+
from pathlib import Path
|
11
|
+
|
12
|
+
from omegaconf import OmegaConf
|
13
|
+
|
14
|
+
from xax.task.mixins.checkpointing import load_ckpt
|
15
|
+
from xax.utils.text import colored, show_info
|
16
|
+
|
17
|
+
|
18
|
+
def _edit_config_str_in_editor(config_str: str) -> str:
    """Opens ``config_str`` in the user's editor and returns the edited YAML.

    The editor command comes from ``$EDITOR`` (``vim`` when unset or empty) and
    is split with shell-like rules, so values such as ``code --wait`` work. The
    edited file is parsed with OmegaConf and re-rendered with sorted keys so it
    can be compared against a sorted rendering of the original config.

    Raises:
        subprocess.CalledProcessError: If the editor exits with a non-zero
            status (e.g. ``:cq`` in vim), which aborts the edit.
    """
    import shlex  # Local import: only needed by this CLI helper.

    editor_cmd = shlex.split(os.environ.get("EDITOR") or "vim")
    fd, tmp_name = tempfile.mkstemp(suffix=".yaml")
    try:
        # Close the file before launching the editor so editors that replace
        # the file, and platforms that lock open files, behave correctly.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(config_str)
        subprocess.run([*editor_cmd, tmp_name], check=True)
        edited_config = OmegaConf.load(tmp_name)
        return OmegaConf.to_yaml(edited_config, sort_keys=True)
    finally:
        # Always clean up, including when the editor fails.
        os.remove(tmp_name)


def _print_config_diff(before: str, after: str) -> None:
    """Prints a colored line diff of two YAML strings, showing changed lines only."""
    for line in difflib.ndiff(before.splitlines(), after.splitlines()):
        if line.startswith("+ "):
            print(colored(line, "light-green"), flush=True)
        elif line.startswith("- "):
            print(colored(line, "light-red"), flush=True)
        elif line.startswith("? "):
            print(colored(line, "light-cyan"), flush=True)


def _replace_ckpt_config(ckpt_path: Path, config_bytes: bytes) -> None:
    """Rewrites the checkpoint archive with its ``config`` member replaced.

    Every other member of the archive is copied over unchanged, and the new
    config takes the position of the old one (it is appended if the archive
    had none). The new archive is written to a temporary file next to the
    target and atomically moved into place, so a failure part-way through
    never leaves a truncated checkpoint behind. A symlinked checkpoint path is
    resolved first, so the symlink itself is preserved.
    """
    target = Path(ckpt_path).resolve()
    fd, tmp_name = tempfile.mkstemp(dir=target.parent, prefix=f".{target.name}.", suffix=".tmp")
    os.close(fd)

    def _config_member() -> tuple[tarfile.TarInfo, io.BytesIO]:
        info = tarfile.TarInfo(name="config")
        info.size = len(config_bytes)
        return info, io.BytesIO(config_bytes)

    try:
        with tarfile.open(target, "r:*") as src, tarfile.open(tmp_name, "w:gz") as dst:
            replaced = False
            for member in src.getmembers():
                if member.name == "config":
                    dst.addfile(*_config_member())
                    replaced = True
                else:
                    dst.addfile(member, src.extractfile(member) if member.isfile() else None)
            if not replaced:
                dst.addfile(*_config_member())
        # mkstemp creates the file as 0600; keep the original checkpoint's permissions.
        os.chmod(tmp_name, os.stat(target).st_mode & 0o7777)
        os.replace(tmp_name, target)
    except BaseException:
        Path(tmp_name).unlink(missing_ok=True)
        raise


def main() -> None:
    """Interactively edits the config stored inside a checkpoint.

    Loads the ``config`` part of the checkpoint given on the command line and
    opens it in ``$EDITOR``. If anything changed, it prints a colored diff and
    writes the edited config back into the checkpoint, leaving every other part
    of the checkpoint intact.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt_path", type=Path)
    args = parser.parse_args()

    # Loads the config from the checkpoint. Keys are sorted so this matches the
    # sorted rendering of the edited config. Otherwise checkpoints saved with
    # unsorted keys would always show a spurious diff.
    config = load_ckpt(args.ckpt_path, part="config")
    config_str = OmegaConf.to_yaml(config, sort_keys=True)

    # Opens the user's preferred editor to edit the config.
    edited_config_str = _edit_config_str_in_editor(config_str)

    if edited_config_str == config_str:
        show_info("No changes were made to the config.")
        return

    # Diffs the original and edited configs.
    _print_config_diff(config_str, edited_config_str)

    # Saves the edited config to the checkpoint, preserving all other members.
    _replace_ckpt_config(args.ckpt_path, edited_config_str.encode("utf-8"))


if __name__ == "__main__":
    # python -m xax.cli.edit_config
    main()
|
@@ -92,7 +92,11 @@ class BaseTask(Generic[Config]):
|
|
92
92
|
|
93
93
|
@functools.cached_property
def task_path(self) -> Path:
    """Path of the source file that defines this task's class.

    Falls back to the current working directory, logging a warning, when
    ``inspect.getfile`` cannot determine the file. It raises ``OSError`` for
    classes defined in ``__main__`` without a file, and ``TypeError`` for
    classes whose module has no ``__file__``.
    """
    try:
        return Path(inspect.getfile(self.__class__))
    except (OSError, TypeError):
        logger.warning(
            "Could not resolve task path for %s, returning current working directory", self.__class__.__name__
        )
        return Path.cwd()
|
96
100
|
|
97
101
|
@functools.cached_property
|
98
102
|
def task_module(self) -> str:
|
@@ -172,14 +176,18 @@ class BaseTask(Generic[Config]):
|
|
172
176
|
Returns:
|
173
177
|
The merged configs.
|
174
178
|
"""
|
175
|
-
|
179
|
+
try:
|
180
|
+
task_path = Path(inspect.getfile(cls))
|
181
|
+
except OSError:
|
182
|
+
logger.warning("Could not resolve task path for %s, returning current working directory", cls.__name__)
|
183
|
+
task_path = Path.cwd()
|
176
184
|
cfg = OmegaConf.structured(cls.get_config_class())
|
177
185
|
cfg = OmegaConf.merge(cfg, *(get_config(other_cfg, task_path) for other_cfg in cfgs))
|
178
186
|
if use_cli:
|
179
187
|
args = use_cli if isinstance(use_cli, list) else sys.argv[1:]
|
180
188
|
if "-h" in args or "--help" in args:
|
181
|
-
sys.
|
182
|
-
sys.
|
189
|
+
sys.stdout.write(OmegaConf.to_yaml(cfg, sort_keys=True))
|
190
|
+
sys.stdout.flush()
|
183
191
|
sys.exit(0)
|
184
192
|
|
185
193
|
# Attempts to load any paths as configs.
|
@@ -202,7 +210,7 @@ class BaseTask(Generic[Config]):
|
|
202
210
|
|
203
211
|
@classmethod
def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
    """Renders the merged task config as YAML with keys in sorted order.

    The arguments are forwarded unchanged to ``get_config``.
    """
    merged_cfg = cls.get_config(*cfgs, use_cli=use_cli)
    return OmegaConf.to_yaml(merged_cfg, sort_keys=True)
|
206
214
|
|
207
215
|
@classmethod
|
208
216
|
def get_task(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Self:
|
@@ -43,8 +43,14 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
43
43
|
def run_dir(self) -> Path:
    """Directory under which this task's runs are stored.

    Uses ``get_run_dir()`` when it returns a value. Otherwise it falls back to
    the directory containing the task's source file, or to the current working
    directory (with a warning) when ``inspect.getfile`` cannot determine that
    file. The task name is appended in every case.
    """
    run_dir = get_run_dir()
    if run_dir is None:
        try:
            task_file = inspect.getfile(self.__class__)
            run_dir = Path(task_file).resolve().parent
        except (OSError, TypeError):
            # OSError: class defined in __main__ without a file.
            # TypeError: class whose module has no __file__.
            logger.warning(
                "Could not resolve task path for %s, returning current working directory", self.__class__.__name__
            )
            run_dir = Path.cwd()
    return run_dir / self.task_name
|
49
55
|
|
50
56
|
@property
|
@@ -292,7 +292,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
292
292
|
|
293
293
|
if state is not None:
|
294
294
|
add_file_bytes("state", json.dumps(state.to_dict(), indent=2).encode())
|
295
|
-
add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
|
295
|
+
add_file_bytes("config", OmegaConf.to_yaml(self.config, sort_keys=True).encode())
|
296
296
|
|
297
297
|
# Updates the symlink to the new checkpoint
|
298
298
|
last_ckpt_path.unlink(missing_ok=True)
|
File without changes
|
@@ -11,8 +11,11 @@ xax/requirements.txt
|
|
11
11
|
xax.egg-info/PKG-INFO
|
12
12
|
xax.egg-info/SOURCES.txt
|
13
13
|
xax.egg-info/dependency_links.txt
|
14
|
+
xax.egg-info/entry_points.txt
|
14
15
|
xax.egg-info/requires.txt
|
15
16
|
xax.egg-info/top_level.txt
|
17
|
+
xax/cli/__init__.py
|
18
|
+
xax/cli/edit_config.py
|
16
19
|
xax/core/__init__.py
|
17
20
|
xax/core/conf.py
|
18
21
|
xax/core/state.py
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|