xax 0.2.17__tar.gz → 0.2.19__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.17/xax.egg-info → xax-0.2.19}/PKG-INFO +1 -1
- {xax-0.2.17 → xax-0.2.19}/xax/__init__.py +1 -1
- {xax-0.2.17 → xax-0.2.19}/xax/task/base.py +12 -4
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/artifacts.py +8 -2
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/checkpointing.py +9 -2
- {xax-0.2.17 → xax-0.2.19/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.17 → xax-0.2.19}/LICENSE +0 -0
- {xax-0.2.17 → xax-0.2.19}/MANIFEST.in +0 -0
- {xax-0.2.17 → xax-0.2.19}/README.md +0 -0
- {xax-0.2.17 → xax-0.2.19}/pyproject.toml +0 -0
- {xax-0.2.17 → xax-0.2.19}/setup.cfg +0 -0
- {xax-0.2.17 → xax-0.2.19}/setup.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/core/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/core/conf.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/core/state.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/embeddings.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/equinox.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/export.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/functions.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/geom.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/losses.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/metrics.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/parallel.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/nn/ssm.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/py.typed +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/requirements-dev.txt +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/requirements.txt +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/launchers/base.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/logger.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/json.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/state.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/process.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/mixins/train.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/script.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/task/task.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/data/collate.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/debugging.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/experiments.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/jax.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/logging.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/numpy.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/profile.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/pytree.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/text.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.17 → xax-0.2.19}/xax.egg-info/top_level.txt +0 -0
@@ -92,7 +92,11 @@ class BaseTask(Generic[Config]):
|
|
92
92
|
|
93
93
|
@functools.cached_property
|
94
94
|
def task_path(self) -> Path:
|
95
|
-
|
95
|
+
try:
|
96
|
+
return Path(inspect.getfile(self.__class__))
|
97
|
+
except OSError:
|
98
|
+
logger.warning("Could not resolve task path for %s, returning current working directory")
|
99
|
+
return Path.cwd()
|
96
100
|
|
97
101
|
@functools.cached_property
|
98
102
|
def task_module(self) -> str:
|
@@ -172,14 +176,18 @@ class BaseTask(Generic[Config]):
|
|
172
176
|
Returns:
|
173
177
|
The merged configs.
|
174
178
|
"""
|
175
|
-
|
179
|
+
try:
|
180
|
+
task_path = Path(inspect.getfile(cls))
|
181
|
+
except OSError:
|
182
|
+
logger.warning("Could not resolve task path for %s, returning current working directory", cls.__name__)
|
183
|
+
task_path = Path.cwd()
|
176
184
|
cfg = OmegaConf.structured(cls.get_config_class())
|
177
185
|
cfg = OmegaConf.merge(cfg, *(get_config(other_cfg, task_path) for other_cfg in cfgs))
|
178
186
|
if use_cli:
|
179
187
|
args = use_cli if isinstance(use_cli, list) else sys.argv[1:]
|
180
188
|
if "-h" in args or "--help" in args:
|
181
|
-
sys.
|
182
|
-
sys.
|
189
|
+
sys.stdout.write(OmegaConf.to_yaml(cfg, sort_keys=True))
|
190
|
+
sys.stdout.flush()
|
183
191
|
sys.exit(0)
|
184
192
|
|
185
193
|
# Attempts to load any paths as configs.
|
@@ -43,8 +43,14 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
43
43
|
def run_dir(self) -> Path:
|
44
44
|
run_dir = get_run_dir()
|
45
45
|
if run_dir is None:
|
46
|
-
|
47
|
-
|
46
|
+
try:
|
47
|
+
task_file = inspect.getfile(self.__class__)
|
48
|
+
run_dir = Path(task_file).resolve().parent
|
49
|
+
except OSError:
|
50
|
+
logger.warning(
|
51
|
+
"Could not resolve task path for %s, returning current working directory", self.__class__.__name__
|
52
|
+
)
|
53
|
+
run_dir = Path.cwd()
|
48
54
|
return run_dir / self.task_name
|
49
55
|
|
50
56
|
@property
|
@@ -6,7 +6,7 @@ import logging
|
|
6
6
|
import tarfile
|
7
7
|
from dataclasses import dataclass
|
8
8
|
from pathlib import Path
|
9
|
-
from typing import Generic, Literal, Sequence, TypeVar, cast, overload
|
9
|
+
from typing import Generic, Literal, Self, Sequence, TypeVar, cast, overload
|
10
10
|
|
11
11
|
import equinox as eqx
|
12
12
|
import jax
|
@@ -46,7 +46,6 @@ class CheckpointingConfig(ArtifactsConfig):
|
|
46
46
|
save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
|
47
47
|
only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
|
48
48
|
load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
|
49
|
-
load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
|
50
49
|
|
51
50
|
|
52
51
|
Config = TypeVar("Config", bound=CheckpointingConfig)
|
@@ -306,3 +305,11 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
306
305
|
self.on_after_checkpoint_save(ckpt_path, state)
|
307
306
|
|
308
307
|
return ckpt_path
|
308
|
+
|
309
|
+
@classmethod
|
310
|
+
def load_config(cls, ckpt_path: str | Path) -> Config:
|
311
|
+
return cls.get_config(load_ckpt(Path(ckpt_path), part="config"), use_cli=False)
|
312
|
+
|
313
|
+
@classmethod
|
314
|
+
def load_task(cls, ckpt_path: str | Path) -> Self:
|
315
|
+
return cls(cls.load_config(ckpt_path))
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|