xax 0.1.2__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.2/xax.egg-info → xax-0.1.4}/PKG-INFO +1 -1
- {xax-0.1.2 → xax-0.1.4}/xax/__init__.py +1 -1
- {xax-0.1.2 → xax-0.1.4}/xax/task/base.py +2 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/checkpointing.py +0 -10
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/compile.py +14 -2
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/train.py +2 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/experiments.py +11 -0
- {xax-0.1.2 → xax-0.1.4/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.1.2 → xax-0.1.4}/LICENSE +0 -0
- {xax-0.1.2 → xax-0.1.4}/MANIFEST.in +0 -0
- {xax-0.1.2 → xax-0.1.4}/README.md +0 -0
- {xax-0.1.2 → xax-0.1.4}/pyproject.toml +0 -0
- {xax-0.1.2 → xax-0.1.4}/setup.cfg +0 -0
- {xax-0.1.2 → xax-0.1.4}/setup.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/core/__init__.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/core/conf.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/core/state.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/nn/__init__.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/nn/embeddings.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/nn/equinox.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/nn/export.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/nn/functions.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/nn/geom.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/nn/norm.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/nn/parallel.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/py.typed +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/requirements-dev.txt +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/requirements.txt +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/__init__.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/launchers/base.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/logger.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/json.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/state.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/stdout.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/process.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/script.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/task/task.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/__init__.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/data/collate.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/debugging.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/jax.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/logging.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/numpy.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/profile.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/pytree.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/tensorboard.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax/utils/text.py +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax.egg-info/requires.txt +0 -0
- {xax-0.1.2 → xax-0.1.4}/xax.egg-info/top_level.txt +0 -0
@@ -153,6 +153,8 @@ class BaseTask(Generic[Config]):
|
|
153
153
|
for base in cls.__orig_bases__:
|
154
154
|
if hasattr(base, "__args__"):
|
155
155
|
for arg in base.__args__:
|
156
|
+
if isinstance(arg, TypeVar) and arg.__bound__ is not None:
|
157
|
+
arg = arg.__bound__
|
156
158
|
if issubclass(arg, BaseConfig):
|
157
159
|
return arg
|
158
160
|
|
@@ -47,7 +47,6 @@ class CheckpointingConfig(ArtifactsConfig):
|
|
47
47
|
only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
|
48
48
|
load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
|
49
49
|
load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
|
50
|
-
save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
|
51
50
|
|
52
51
|
|
53
52
|
Config = TypeVar("Config", bound=CheckpointingConfig)
|
@@ -213,15 +212,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
213
212
|
add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
|
214
213
|
add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
|
215
214
|
|
216
|
-
if self.config.save_tf_model:
|
217
|
-
try:
|
218
|
-
from jax.experimental import jax2tf
|
219
|
-
except ModuleNotFoundError:
|
220
|
-
raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
|
221
|
-
|
222
|
-
tf_model = jax2tf.convert(model)
|
223
|
-
add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
|
224
|
-
|
225
215
|
# Updates the symlink to the new checkpoint.
|
226
216
|
last_ckpt_path.unlink(missing_ok=True)
|
227
217
|
try:
|
@@ -5,6 +5,7 @@ behavior during initialization and training.
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import logging
|
8
|
+
import sys
|
8
9
|
from dataclasses import dataclass
|
9
10
|
from pathlib import Path
|
10
11
|
from typing import Generic, TypeVar
|
@@ -17,6 +18,16 @@ from xax.task.base import BaseConfig, BaseTask
|
|
17
18
|
logger = logging.getLogger(__name__)
|
18
19
|
|
19
20
|
|
21
|
+
def get_cache_dir() -> str | None:
|
22
|
+
# By default, only cache on MacOS, since Jax caching on Linux is very
|
23
|
+
# prone to NaNs.
|
24
|
+
match sys.platform:
|
25
|
+
case "darwin":
|
26
|
+
return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
|
27
|
+
case _:
|
28
|
+
return None
|
29
|
+
|
30
|
+
|
20
31
|
@jax.tree_util.register_dataclass
|
21
32
|
@dataclass
|
22
33
|
class CompileOptions:
|
@@ -42,7 +53,8 @@ class CompileOptions:
|
|
42
53
|
|
43
54
|
# JAX cache options
|
44
55
|
cache_dir: str | None = field(
|
45
|
-
|
56
|
+
# Only cache by default on MacOS systems.
|
57
|
+
value=get_cache_dir,
|
46
58
|
help="Directory for JAX compilation cache. If None, caching is disabled",
|
47
59
|
)
|
48
60
|
cache_min_size_bytes: int = field(
|
@@ -54,7 +66,7 @@ class CompileOptions:
|
|
54
66
|
help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
|
55
67
|
)
|
56
68
|
cache_enable_xla: str = field(
|
57
|
-
value="
|
69
|
+
value="none",
|
58
70
|
help="Which XLA caches to enable",
|
59
71
|
)
|
60
72
|
|
@@ -50,6 +50,7 @@ from xax.utils.experiments import (
|
|
50
50
|
diff_configs,
|
51
51
|
get_diff_string,
|
52
52
|
get_git_state,
|
53
|
+
get_packages_with_versions,
|
53
54
|
get_training_code,
|
54
55
|
)
|
55
56
|
from xax.utils.logging import LOG_STATUS
|
@@ -452,6 +453,7 @@ class TrainMixin(
|
|
452
453
|
logger.log(LOG_STATUS, self.task_name)
|
453
454
|
logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
|
454
455
|
self.logger.log_file("git_state.txt", get_git_state(self))
|
456
|
+
self.logger.log_file("packages.txt", get_packages_with_versions())
|
455
457
|
self.logger.log_file("training_code.txt", get_training_code(self))
|
456
458
|
self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
|
457
459
|
|
@@ -28,6 +28,7 @@ from typing import Any, Iterator, Self, TypeVar, cast
|
|
28
28
|
from urllib.parse import urlparse
|
29
29
|
|
30
30
|
import git
|
31
|
+
import pkg_resources
|
31
32
|
import requests
|
32
33
|
from jaxtyping import Array
|
33
34
|
from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
|
@@ -468,6 +469,16 @@ def get_git_state(obj: object) -> str:
|
|
468
469
|
return traceback.format_exc()
|
469
470
|
|
470
471
|
|
472
|
+
def get_packages_with_versions() -> str:
|
473
|
+
"""Gets the packages and their versions.
|
474
|
+
|
475
|
+
Returns:
|
476
|
+
A dictionary of packages and their versions.
|
477
|
+
"""
|
478
|
+
packages = [(pkg.key, pkg.version) for pkg in pkg_resources.working_set]
|
479
|
+
return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
|
480
|
+
|
481
|
+
|
471
482
|
def get_training_code(obj: object) -> str:
|
472
483
|
"""Gets the text from the file containing the provided object.
|
473
484
|
|
{xax-0.1.2 → xax-0.1.4}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.1.2 → xax-0.1.4}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|