xax 0.2.18__tar.gz → 0.2.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75):
  1. {xax-0.2.18/xax.egg-info → xax-0.2.20}/PKG-INFO +1 -1
  2. {xax-0.2.18 → xax-0.2.20}/setup.py +5 -0
  3. {xax-0.2.18 → xax-0.2.20}/xax/__init__.py +1 -1
  4. xax-0.2.20/xax/cli/edit_config.py +67 -0
  5. {xax-0.2.18 → xax-0.2.20}/xax/task/base.py +13 -5
  6. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/artifacts.py +8 -2
  7. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/checkpointing.py +1 -1
  8. xax-0.2.20/xax/utils/types/__init__.py +0 -0
  9. {xax-0.2.18 → xax-0.2.20/xax.egg-info}/PKG-INFO +1 -1
  10. {xax-0.2.18 → xax-0.2.20}/xax.egg-info/SOURCES.txt +3 -0
  11. xax-0.2.20/xax.egg-info/entry_points.txt +2 -0
  12. {xax-0.2.18 → xax-0.2.20}/LICENSE +0 -0
  13. {xax-0.2.18 → xax-0.2.20}/MANIFEST.in +0 -0
  14. {xax-0.2.18 → xax-0.2.20}/README.md +0 -0
  15. {xax-0.2.18 → xax-0.2.20}/pyproject.toml +0 -0
  16. {xax-0.2.18 → xax-0.2.20}/setup.cfg +0 -0
  17. {xax-0.2.18/xax/core → xax-0.2.20/xax/cli}/__init__.py +0 -0
  18. {xax-0.2.18/xax/nn → xax-0.2.20/xax/core}/__init__.py +0 -0
  19. {xax-0.2.18 → xax-0.2.20}/xax/core/conf.py +0 -0
  20. {xax-0.2.18 → xax-0.2.20}/xax/core/state.py +0 -0
  21. {xax-0.2.18/xax/task → xax-0.2.20/xax/nn}/__init__.py +0 -0
  22. {xax-0.2.18 → xax-0.2.20}/xax/nn/embeddings.py +0 -0
  23. {xax-0.2.18 → xax-0.2.20}/xax/nn/equinox.py +0 -0
  24. {xax-0.2.18 → xax-0.2.20}/xax/nn/export.py +0 -0
  25. {xax-0.2.18 → xax-0.2.20}/xax/nn/functions.py +0 -0
  26. {xax-0.2.18 → xax-0.2.20}/xax/nn/geom.py +0 -0
  27. {xax-0.2.18 → xax-0.2.20}/xax/nn/losses.py +0 -0
  28. {xax-0.2.18 → xax-0.2.20}/xax/nn/metrics.py +0 -0
  29. {xax-0.2.18 → xax-0.2.20}/xax/nn/parallel.py +0 -0
  30. {xax-0.2.18 → xax-0.2.20}/xax/nn/ssm.py +0 -0
  31. {xax-0.2.18 → xax-0.2.20}/xax/py.typed +0 -0
  32. {xax-0.2.18 → xax-0.2.20}/xax/requirements-dev.txt +0 -0
  33. {xax-0.2.18 → xax-0.2.20}/xax/requirements.txt +0 -0
  34. {xax-0.2.18/xax/task/launchers → xax-0.2.20/xax/task}/__init__.py +0 -0
  35. {xax-0.2.18/xax/task/loggers → xax-0.2.20/xax/task/launchers}/__init__.py +0 -0
  36. {xax-0.2.18 → xax-0.2.20}/xax/task/launchers/base.py +0 -0
  37. {xax-0.2.18 → xax-0.2.20}/xax/task/launchers/cli.py +0 -0
  38. {xax-0.2.18 → xax-0.2.20}/xax/task/launchers/single_process.py +0 -0
  39. {xax-0.2.18 → xax-0.2.20}/xax/task/logger.py +0 -0
  40. {xax-0.2.18/xax/utils → xax-0.2.20/xax/task/loggers}/__init__.py +0 -0
  41. {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/callback.py +0 -0
  42. {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/json.py +0 -0
  43. {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/state.py +0 -0
  44. {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/stdout.py +0 -0
  45. {xax-0.2.18 → xax-0.2.20}/xax/task/loggers/tensorboard.py +0 -0
  46. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/__init__.py +0 -0
  47. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/compile.py +0 -0
  48. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/cpu_stats.py +0 -0
  49. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/data_loader.py +0 -0
  50. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/gpu_stats.py +0 -0
  51. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/logger.py +0 -0
  52. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/process.py +0 -0
  53. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/runnable.py +0 -0
  54. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/step_wrapper.py +0 -0
  55. {xax-0.2.18 → xax-0.2.20}/xax/task/mixins/train.py +0 -0
  56. {xax-0.2.18 → xax-0.2.20}/xax/task/script.py +0 -0
  57. {xax-0.2.18 → xax-0.2.20}/xax/task/task.py +0 -0
  58. {xax-0.2.18/xax/utils/data → xax-0.2.20/xax/utils}/__init__.py +0 -0
  59. {xax-0.2.18/xax/utils/types → xax-0.2.20/xax/utils/data}/__init__.py +0 -0
  60. {xax-0.2.18 → xax-0.2.20}/xax/utils/data/collate.py +0 -0
  61. {xax-0.2.18 → xax-0.2.20}/xax/utils/debugging.py +0 -0
  62. {xax-0.2.18 → xax-0.2.20}/xax/utils/experiments.py +0 -0
  63. {xax-0.2.18 → xax-0.2.20}/xax/utils/jax.py +0 -0
  64. {xax-0.2.18 → xax-0.2.20}/xax/utils/jaxpr.py +0 -0
  65. {xax-0.2.18 → xax-0.2.20}/xax/utils/logging.py +0 -0
  66. {xax-0.2.18 → xax-0.2.20}/xax/utils/numpy.py +0 -0
  67. {xax-0.2.18 → xax-0.2.20}/xax/utils/profile.py +0 -0
  68. {xax-0.2.18 → xax-0.2.20}/xax/utils/pytree.py +0 -0
  69. {xax-0.2.18 → xax-0.2.20}/xax/utils/tensorboard.py +0 -0
  70. {xax-0.2.18 → xax-0.2.20}/xax/utils/text.py +0 -0
  71. {xax-0.2.18 → xax-0.2.20}/xax/utils/types/frozen_dict.py +0 -0
  72. {xax-0.2.18 → xax-0.2.20}/xax/utils/types/hashable_array.py +0 -0
  73. {xax-0.2.18 → xax-0.2.20}/xax.egg-info/dependency_links.txt +0 -0
  74. {xax-0.2.18 → xax-0.2.20}/xax.egg-info/requires.txt +0 -0
  75. {xax-0.2.18 → xax-0.2.20}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.18
3
+ Version: 0.2.20
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -48,4 +48,9 @@ setup(
48
48
  "requirements*.txt",
49
49
  ],
50
50
  },
51
+ entry_points={
52
+ "console_scripts": [
53
+ "xax-edit-config=xax.cli.edit_config:main",
54
+ ],
55
+ },
51
56
  )
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.18"
15
+ __version__ = "0.2.20"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -0,0 +1,67 @@
1
+ """Lets you edit a checkpoint config programmatically."""
2
+
3
+ import argparse
4
+ import difflib
5
+ import io
6
+ import os
7
+ import subprocess
8
+ import tarfile
9
+ import tempfile
10
+ from pathlib import Path
11
+
12
+ from omegaconf import OmegaConf
13
+
14
+ from xax.task.mixins.checkpointing import load_ckpt
15
+ from xax.utils.text import colored, show_info
16
+
17
+
18
def _edit_yaml_in_editor(text: str) -> str:
    """Open ``text`` in the user's editor and return the edited YAML re-serialized with sorted keys.

    The editor command comes from ``$EDITOR`` (split shell-style so values like ``code --wait``
    work), falling back to ``vim`` when it is unset or empty.

    Raises:
        subprocess.CalledProcessError: If the editor exits with a non-zero status.
    """
    import shlex  # Local import: only needed to split the ``$EDITOR`` command line.

    with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
        f.write(text.encode("utf-8"))

    # The file is closed before the editor runs so its contents are flushed and it is not held
    # open by this process. Removal is in ``finally`` so the temp file is cleaned up even when the
    # editor fails or the edited YAML cannot be parsed.
    try:
        editor_cmd = shlex.split(os.environ.get("EDITOR") or "vim")
        subprocess.run([*editor_cmd, f.name], check=True)
        edited_config = OmegaConf.load(f.name)
        return OmegaConf.to_yaml(edited_config, sort_keys=True)
    finally:
        os.remove(f.name)


def _replace_tar_member(archive_path: Path, name: str, data: bytes) -> None:
    """Rewrite the tarball at ``archive_path`` so that member ``name`` contains ``data``.

    Every other member is copied over unchanged. The new archive is written to a temporary file in
    the same directory and then moved into place with ``os.replace``. A failure part-way through
    therefore leaves the original checkpoint untouched. Symlinks are resolved first, so the link
    target is updated instead of the link being replaced by a regular file.
    """
    target = archive_path.resolve()
    fd, tmp_name = tempfile.mkstemp(prefix=f".{target.name}.", suffix=".tmp", dir=target.parent)
    os.close(fd)
    try:
        with tarfile.open(target, "r:*") as src, tarfile.open(tmp_name, "w:gz") as dst:
            for member in src.getmembers():
                if member.name == name:
                    continue
                dst.addfile(member, src.extractfile(member) if member.isfile() else None)
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            dst.addfile(info, io.BytesIO(data))
        # mkstemp creates the file as 0o600; keep the checkpoint's original permissions.
        os.chmod(tmp_name, os.stat(target).st_mode & 0o7777)
        os.replace(tmp_name, target)
    except BaseException:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


def main() -> None:
    """Edit the config stored inside a checkpoint using the user's editor.

    The config is loaded from the checkpoint and rendered as YAML with sorted keys. It is then
    opened in ``$EDITOR`` (default ``vim``). If the result differs, a colored line diff is printed
    and the checkpoint's ``config`` member is replaced in place. All other members of the archive
    (e.g. ``state``) are preserved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt_path", type=Path)
    args = parser.parse_args()
    ckpt_path: Path = args.ckpt_path

    # Both sides of the comparison use sorted keys. Otherwise a checkpoint saved with unsorted
    # keys would always look "edited" even when the user changed nothing.
    config = load_ckpt(ckpt_path, part="config")
    config_str = OmegaConf.to_yaml(config, sort_keys=True)
    edited_config_str = _edit_yaml_in_editor(config_str)

    if edited_config_str == config_str:
        show_info("No changes were made to the config.")
        return

    # Shows only added/removed lines and ndiff's intraline hint lines; unchanged lines are skipped.
    for line in difflib.ndiff(config_str.splitlines(), edited_config_str.splitlines()):
        if line.startswith("+ "):
            print(colored(line, "light-green"), flush=True)
        elif line.startswith("- "):
            print(colored(line, "light-red"), flush=True)
        elif line.startswith("? "):
            print(colored(line, "light-cyan"), flush=True)

    # Replace only the config member; opening the checkpoint with "w:gz" would truncate the
    # archive and drop everything else stored in it.
    _replace_tar_member(ckpt_path, "config", edited_config_str.encode("utf-8"))


if __name__ == "__main__":
    # python -m xax.cli.edit_config
    main()
@@ -92,7 +92,11 @@ class BaseTask(Generic[Config]):
92
92
 
93
93
  @functools.cached_property
94
94
  def task_path(self) -> Path:
95
- return Path(inspect.getfile(self.__class__))
95
+ try:
96
+ return Path(inspect.getfile(self.__class__))
97
+ except OSError:
98
+ logger.warning("Could not resolve task path for %s, returning current working directory")
99
+ return Path.cwd()
96
100
 
97
101
  @functools.cached_property
98
102
  def task_module(self) -> str:
@@ -172,14 +176,18 @@ class BaseTask(Generic[Config]):
172
176
  Returns:
173
177
  The merged configs.
174
178
  """
175
- task_path = Path(inspect.getfile(cls))
179
+ try:
180
+ task_path = Path(inspect.getfile(cls))
181
+ except OSError:
182
+ logger.warning("Could not resolve task path for %s, returning current working directory", cls.__name__)
183
+ task_path = Path.cwd()
176
184
  cfg = OmegaConf.structured(cls.get_config_class())
177
185
  cfg = OmegaConf.merge(cfg, *(get_config(other_cfg, task_path) for other_cfg in cfgs))
178
186
  if use_cli:
179
187
  args = use_cli if isinstance(use_cli, list) else sys.argv[1:]
180
188
  if "-h" in args or "--help" in args:
181
- sys.stderr.write(OmegaConf.to_yaml(cfg))
182
- sys.stderr.flush()
189
+ sys.stdout.write(OmegaConf.to_yaml(cfg, sort_keys=True))
190
+ sys.stdout.flush()
183
191
  sys.exit(0)
184
192
 
185
193
  # Attempts to load any paths as configs.
@@ -202,7 +210,7 @@ class BaseTask(Generic[Config]):
202
210
 
203
211
  @classmethod
204
212
  def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
205
- return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli))
213
+ return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli), sort_keys=True)
206
214
 
207
215
  @classmethod
208
216
  def get_task(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Self:
@@ -43,8 +43,14 @@ class ArtifactsMixin(BaseTask[Config]):
43
43
  def run_dir(self) -> Path:
44
44
  run_dir = get_run_dir()
45
45
  if run_dir is None:
46
- task_file = inspect.getfile(self.__class__)
47
- run_dir = Path(task_file).resolve().parent
46
+ try:
47
+ task_file = inspect.getfile(self.__class__)
48
+ run_dir = Path(task_file).resolve().parent
49
+ except OSError:
50
+ logger.warning(
51
+ "Could not resolve task path for %s, returning current working directory", self.__class__.__name__
52
+ )
53
+ run_dir = Path.cwd()
48
54
  return run_dir / self.task_name
49
55
 
50
56
  @property
@@ -292,7 +292,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
292
292
 
293
293
  if state is not None:
294
294
  add_file_bytes("state", json.dumps(state.to_dict(), indent=2).encode())
295
- add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
295
+ add_file_bytes("config", OmegaConf.to_yaml(self.config, sort_keys=True).encode())
296
296
 
297
297
  # Updates the symlink to the new checkpoint
298
298
  last_ckpt_path.unlink(missing_ok=True)
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.18
3
+ Version: 0.2.20
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -11,8 +11,11 @@ xax/requirements.txt
11
11
  xax.egg-info/PKG-INFO
12
12
  xax.egg-info/SOURCES.txt
13
13
  xax.egg-info/dependency_links.txt
14
+ xax.egg-info/entry_points.txt
14
15
  xax.egg-info/requires.txt
15
16
  xax.egg-info/top_level.txt
17
+ xax/cli/__init__.py
18
+ xax/cli/edit_config.py
16
19
  xax/core/__init__.py
17
20
  xax/core/conf.py
18
21
  xax/core/state.py
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ xax-edit-config = xax.cli.edit_config:main
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes