xax 0.2.17__tar.gz → 0.2.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.17/xax.egg-info → xax-0.2.19}/PKG-INFO +1 -1
  2. {xax-0.2.17 → xax-0.2.19}/xax/__init__.py +1 -1
  3. {xax-0.2.17 → xax-0.2.19}/xax/task/base.py +12 -4
  4. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/artifacts.py +8 -2
  5. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/checkpointing.py +9 -2
  6. {xax-0.2.17 → xax-0.2.19/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.2.17 → xax-0.2.19}/LICENSE +0 -0
  8. {xax-0.2.17 → xax-0.2.19}/MANIFEST.in +0 -0
  9. {xax-0.2.17 → xax-0.2.19}/README.md +0 -0
  10. {xax-0.2.17 → xax-0.2.19}/pyproject.toml +0 -0
  11. {xax-0.2.17 → xax-0.2.19}/setup.cfg +0 -0
  12. {xax-0.2.17 → xax-0.2.19}/setup.py +0 -0
  13. {xax-0.2.17 → xax-0.2.19}/xax/core/__init__.py +0 -0
  14. {xax-0.2.17 → xax-0.2.19}/xax/core/conf.py +0 -0
  15. {xax-0.2.17 → xax-0.2.19}/xax/core/state.py +0 -0
  16. {xax-0.2.17 → xax-0.2.19}/xax/nn/__init__.py +0 -0
  17. {xax-0.2.17 → xax-0.2.19}/xax/nn/embeddings.py +0 -0
  18. {xax-0.2.17 → xax-0.2.19}/xax/nn/equinox.py +0 -0
  19. {xax-0.2.17 → xax-0.2.19}/xax/nn/export.py +0 -0
  20. {xax-0.2.17 → xax-0.2.19}/xax/nn/functions.py +0 -0
  21. {xax-0.2.17 → xax-0.2.19}/xax/nn/geom.py +0 -0
  22. {xax-0.2.17 → xax-0.2.19}/xax/nn/losses.py +0 -0
  23. {xax-0.2.17 → xax-0.2.19}/xax/nn/metrics.py +0 -0
  24. {xax-0.2.17 → xax-0.2.19}/xax/nn/parallel.py +0 -0
  25. {xax-0.2.17 → xax-0.2.19}/xax/nn/ssm.py +0 -0
  26. {xax-0.2.17 → xax-0.2.19}/xax/py.typed +0 -0
  27. {xax-0.2.17 → xax-0.2.19}/xax/requirements-dev.txt +0 -0
  28. {xax-0.2.17 → xax-0.2.19}/xax/requirements.txt +0 -0
  29. {xax-0.2.17 → xax-0.2.19}/xax/task/__init__.py +0 -0
  30. {xax-0.2.17 → xax-0.2.19}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.2.17 → xax-0.2.19}/xax/task/launchers/base.py +0 -0
  32. {xax-0.2.17 → xax-0.2.19}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.2.17 → xax-0.2.19}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.2.17 → xax-0.2.19}/xax/task/logger.py +0 -0
  35. {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/json.py +0 -0
  38. {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/state.py +0 -0
  39. {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/tensorboard.py +0 -0
  41. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/compile.py +0 -0
  43. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/cpu_stats.py +0 -0
  44. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/data_loader.py +0 -0
  45. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/gpu_stats.py +0 -0
  46. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/logger.py +0 -0
  47. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/process.py +0 -0
  48. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/runnable.py +0 -0
  49. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/step_wrapper.py +0 -0
  50. {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/train.py +0 -0
  51. {xax-0.2.17 → xax-0.2.19}/xax/task/script.py +0 -0
  52. {xax-0.2.17 → xax-0.2.19}/xax/task/task.py +0 -0
  53. {xax-0.2.17 → xax-0.2.19}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.17 → xax-0.2.19}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.17 → xax-0.2.19}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.17 → xax-0.2.19}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.17 → xax-0.2.19}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.17 → xax-0.2.19}/xax/utils/jax.py +0 -0
  59. {xax-0.2.17 → xax-0.2.19}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.17 → xax-0.2.19}/xax/utils/logging.py +0 -0
  61. {xax-0.2.17 → xax-0.2.19}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.17 → xax-0.2.19}/xax/utils/profile.py +0 -0
  63. {xax-0.2.17 → xax-0.2.19}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.17 → xax-0.2.19}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.17 → xax-0.2.19}/xax/utils/text.py +0 -0
  66. {xax-0.2.17 → xax-0.2.19}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.17 → xax-0.2.19}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.17 → xax-0.2.19}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.17 → xax-0.2.19}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.17 → xax-0.2.19}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.17 → xax-0.2.19}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.17 → xax-0.2.19}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.17
3
+ Version: 0.2.19
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.17"
15
+ __version__ = "0.2.19"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -92,7 +92,11 @@ class BaseTask(Generic[Config]):
92
92
 
93
93
  @functools.cached_property
94
94
  def task_path(self) -> Path:
95
- return Path(inspect.getfile(self.__class__))
95
+ try:
96
+ return Path(inspect.getfile(self.__class__))
97
+ except OSError:
98
+ logger.warning("Could not resolve task path for %s, returning current working directory")
99
+ return Path.cwd()
96
100
 
97
101
  @functools.cached_property
98
102
  def task_module(self) -> str:
@@ -172,14 +176,18 @@ class BaseTask(Generic[Config]):
172
176
  Returns:
173
177
  The merged configs.
174
178
  """
175
- task_path = Path(inspect.getfile(cls))
179
+ try:
180
+ task_path = Path(inspect.getfile(cls))
181
+ except OSError:
182
+ logger.warning("Could not resolve task path for %s, returning current working directory", cls.__name__)
183
+ task_path = Path.cwd()
176
184
  cfg = OmegaConf.structured(cls.get_config_class())
177
185
  cfg = OmegaConf.merge(cfg, *(get_config(other_cfg, task_path) for other_cfg in cfgs))
178
186
  if use_cli:
179
187
  args = use_cli if isinstance(use_cli, list) else sys.argv[1:]
180
188
  if "-h" in args or "--help" in args:
181
- sys.stderr.write(OmegaConf.to_yaml(cfg))
182
- sys.stderr.flush()
189
+ sys.stdout.write(OmegaConf.to_yaml(cfg, sort_keys=True))
190
+ sys.stdout.flush()
183
191
  sys.exit(0)
184
192
 
185
193
  # Attempts to load any paths as configs.
@@ -43,8 +43,14 @@ class ArtifactsMixin(BaseTask[Config]):
43
43
  def run_dir(self) -> Path:
44
44
  run_dir = get_run_dir()
45
45
  if run_dir is None:
46
- task_file = inspect.getfile(self.__class__)
47
- run_dir = Path(task_file).resolve().parent
46
+ try:
47
+ task_file = inspect.getfile(self.__class__)
48
+ run_dir = Path(task_file).resolve().parent
49
+ except OSError:
50
+ logger.warning(
51
+ "Could not resolve task path for %s, returning current working directory", self.__class__.__name__
52
+ )
53
+ run_dir = Path.cwd()
48
54
  return run_dir / self.task_name
49
55
 
50
56
  @property
@@ -6,7 +6,7 @@ import logging
6
6
  import tarfile
7
7
  from dataclasses import dataclass
8
8
  from pathlib import Path
9
- from typing import Generic, Literal, Sequence, TypeVar, cast, overload
9
+ from typing import Generic, Literal, Self, Sequence, TypeVar, cast, overload
10
10
 
11
11
  import equinox as eqx
12
12
  import jax
@@ -46,7 +46,6 @@ class CheckpointingConfig(ArtifactsConfig):
46
46
  save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
47
47
  only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
48
48
  load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
49
- load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
50
49
 
51
50
 
52
51
  Config = TypeVar("Config", bound=CheckpointingConfig)
@@ -306,3 +305,11 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
306
305
  self.on_after_checkpoint_save(ckpt_path, state)
307
306
 
308
307
  return ckpt_path
308
+
309
+ @classmethod
310
+ def load_config(cls, ckpt_path: str | Path) -> Config:
311
+ return cls.get_config(load_ckpt(Path(ckpt_path), part="config"), use_cli=False)
312
+
313
+ @classmethod
314
+ def load_task(cls, ckpt_path: str | Path) -> Self:
315
+ return cls(cls.load_config(ckpt_path))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.17
3
+ Version: 0.2.19
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes