xax 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {xax-0.1.2/xax.egg-info → xax-0.1.4}/PKG-INFO +1 -1
  2. {xax-0.1.2 → xax-0.1.4}/xax/__init__.py +1 -1
  3. {xax-0.1.2 → xax-0.1.4}/xax/task/base.py +2 -0
  4. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/checkpointing.py +0 -10
  5. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/compile.py +14 -2
  6. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/train.py +2 -0
  7. {xax-0.1.2 → xax-0.1.4}/xax/utils/experiments.py +11 -0
  8. {xax-0.1.2 → xax-0.1.4/xax.egg-info}/PKG-INFO +1 -1
  9. {xax-0.1.2 → xax-0.1.4}/LICENSE +0 -0
  10. {xax-0.1.2 → xax-0.1.4}/MANIFEST.in +0 -0
  11. {xax-0.1.2 → xax-0.1.4}/README.md +0 -0
  12. {xax-0.1.2 → xax-0.1.4}/pyproject.toml +0 -0
  13. {xax-0.1.2 → xax-0.1.4}/setup.cfg +0 -0
  14. {xax-0.1.2 → xax-0.1.4}/setup.py +0 -0
  15. {xax-0.1.2 → xax-0.1.4}/xax/core/__init__.py +0 -0
  16. {xax-0.1.2 → xax-0.1.4}/xax/core/conf.py +0 -0
  17. {xax-0.1.2 → xax-0.1.4}/xax/core/state.py +0 -0
  18. {xax-0.1.2 → xax-0.1.4}/xax/nn/__init__.py +0 -0
  19. {xax-0.1.2 → xax-0.1.4}/xax/nn/embeddings.py +0 -0
  20. {xax-0.1.2 → xax-0.1.4}/xax/nn/equinox.py +0 -0
  21. {xax-0.1.2 → xax-0.1.4}/xax/nn/export.py +0 -0
  22. {xax-0.1.2 → xax-0.1.4}/xax/nn/functions.py +0 -0
  23. {xax-0.1.2 → xax-0.1.4}/xax/nn/geom.py +0 -0
  24. {xax-0.1.2 → xax-0.1.4}/xax/nn/norm.py +0 -0
  25. {xax-0.1.2 → xax-0.1.4}/xax/nn/parallel.py +0 -0
  26. {xax-0.1.2 → xax-0.1.4}/xax/py.typed +0 -0
  27. {xax-0.1.2 → xax-0.1.4}/xax/requirements-dev.txt +0 -0
  28. {xax-0.1.2 → xax-0.1.4}/xax/requirements.txt +0 -0
  29. {xax-0.1.2 → xax-0.1.4}/xax/task/__init__.py +0 -0
  30. {xax-0.1.2 → xax-0.1.4}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.1.2 → xax-0.1.4}/xax/task/launchers/base.py +0 -0
  32. {xax-0.1.2 → xax-0.1.4}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.1.2 → xax-0.1.4}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.1.2 → xax-0.1.4}/xax/task/logger.py +0 -0
  35. {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/json.py +0 -0
  38. {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/state.py +0 -0
  39. {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.1.2 → xax-0.1.4}/xax/task/loggers/tensorboard.py +0 -0
  41. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/artifacts.py +0 -0
  43. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/cpu_stats.py +0 -0
  44. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/data_loader.py +0 -0
  45. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/gpu_stats.py +0 -0
  46. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/logger.py +0 -0
  47. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/process.py +0 -0
  48. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/runnable.py +0 -0
  49. {xax-0.1.2 → xax-0.1.4}/xax/task/mixins/step_wrapper.py +0 -0
  50. {xax-0.1.2 → xax-0.1.4}/xax/task/script.py +0 -0
  51. {xax-0.1.2 → xax-0.1.4}/xax/task/task.py +0 -0
  52. {xax-0.1.2 → xax-0.1.4}/xax/utils/__init__.py +0 -0
  53. {xax-0.1.2 → xax-0.1.4}/xax/utils/data/__init__.py +0 -0
  54. {xax-0.1.2 → xax-0.1.4}/xax/utils/data/collate.py +0 -0
  55. {xax-0.1.2 → xax-0.1.4}/xax/utils/debugging.py +0 -0
  56. {xax-0.1.2 → xax-0.1.4}/xax/utils/jax.py +0 -0
  57. {xax-0.1.2 → xax-0.1.4}/xax/utils/jaxpr.py +0 -0
  58. {xax-0.1.2 → xax-0.1.4}/xax/utils/logging.py +0 -0
  59. {xax-0.1.2 → xax-0.1.4}/xax/utils/numpy.py +0 -0
  60. {xax-0.1.2 → xax-0.1.4}/xax/utils/profile.py +0 -0
  61. {xax-0.1.2 → xax-0.1.4}/xax/utils/pytree.py +0 -0
  62. {xax-0.1.2 → xax-0.1.4}/xax/utils/tensorboard.py +0 -0
  63. {xax-0.1.2 → xax-0.1.4}/xax/utils/text.py +0 -0
  64. {xax-0.1.2 → xax-0.1.4}/xax.egg-info/SOURCES.txt +0 -0
  65. {xax-0.1.2 → xax-0.1.4}/xax.egg-info/dependency_links.txt +0 -0
  66. {xax-0.1.2 → xax-0.1.4}/xax.egg-info/requires.txt +0 -0
  67. {xax-0.1.2 → xax-0.1.4}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.2"
15
+ __version__ = "0.1.4"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -153,6 +153,8 @@ class BaseTask(Generic[Config]):
153
153
  for base in cls.__orig_bases__:
154
154
  if hasattr(base, "__args__"):
155
155
  for arg in base.__args__:
156
+ if isinstance(arg, TypeVar) and arg.__bound__ is not None:
157
+ arg = arg.__bound__
156
158
  if issubclass(arg, BaseConfig):
157
159
  return arg
158
160
 
@@ -47,7 +47,6 @@ class CheckpointingConfig(ArtifactsConfig):
47
47
  only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
48
48
  load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
49
49
  load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
50
- save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
51
50
 
52
51
 
53
52
  Config = TypeVar("Config", bound=CheckpointingConfig)
@@ -213,15 +212,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
213
212
  add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
214
213
  add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
215
214
 
216
- if self.config.save_tf_model:
217
- try:
218
- from jax.experimental import jax2tf
219
- except ModuleNotFoundError:
220
- raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
221
-
222
- tf_model = jax2tf.convert(model)
223
- add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
224
-
225
215
  # Updates the symlink to the new checkpoint.
226
216
  last_ckpt_path.unlink(missing_ok=True)
227
217
  try:
@@ -5,6 +5,7 @@ behavior during initialization and training.
5
5
  """
6
6
 
7
7
  import logging
8
+ import sys
8
9
  from dataclasses import dataclass
9
10
  from pathlib import Path
10
11
  from typing import Generic, TypeVar
@@ -17,6 +18,16 @@ from xax.task.base import BaseConfig, BaseTask
17
18
  logger = logging.getLogger(__name__)
18
19
 
19
20
 
21
+ def get_cache_dir() -> str | None:
22
+ # By default, only cache on MacOS, since Jax caching on Linux is very
23
+ # prone to NaNs.
24
+ match sys.platform:
25
+ case "darwin":
26
+ return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
27
+ case _:
28
+ return None
29
+
30
+
20
31
  @jax.tree_util.register_dataclass
21
32
  @dataclass
22
33
  class CompileOptions:
@@ -42,7 +53,8 @@ class CompileOptions:
42
53
 
43
54
  # JAX cache options
44
55
  cache_dir: str | None = field(
45
- value=lambda: str((Path.home() / ".cache" / "jax" / "jaxcache").resolve()),
56
+ # Only cache by default on MacOS systems.
57
+ value=get_cache_dir,
46
58
  help="Directory for JAX compilation cache. If None, caching is disabled",
47
59
  )
48
60
  cache_min_size_bytes: int = field(
@@ -54,7 +66,7 @@ class CompileOptions:
54
66
  help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
55
67
  )
56
68
  cache_enable_xla: str = field(
57
- value="all",
69
+ value="none",
58
70
  help="Which XLA caches to enable",
59
71
  )
60
72
 
@@ -50,6 +50,7 @@ from xax.utils.experiments import (
50
50
  diff_configs,
51
51
  get_diff_string,
52
52
  get_git_state,
53
+ get_packages_with_versions,
53
54
  get_training_code,
54
55
  )
55
56
  from xax.utils.logging import LOG_STATUS
@@ -452,6 +453,7 @@ class TrainMixin(
452
453
  logger.log(LOG_STATUS, self.task_name)
453
454
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
454
455
  self.logger.log_file("git_state.txt", get_git_state(self))
456
+ self.logger.log_file("packages.txt", get_packages_with_versions())
455
457
  self.logger.log_file("training_code.txt", get_training_code(self))
456
458
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
457
459
 
@@ -28,6 +28,7 @@ from typing import Any, Iterator, Self, TypeVar, cast
28
28
  from urllib.parse import urlparse
29
29
 
30
30
  import git
31
+ import pkg_resources
31
32
  import requests
32
33
  from jaxtyping import Array
33
34
  from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
@@ -468,6 +469,16 @@ def get_git_state(obj: object) -> str:
468
469
  return traceback.format_exc()
469
470
 
470
471
 
472
+ def get_packages_with_versions() -> str:
473
+ """Gets the packages and their versions.
474
+
475
+ Returns:
476
+ A newline-separated string of `name==version` entries, sorted by package name.
477
+ """
478
+ packages = [(pkg.key, pkg.version) for pkg in pkg_resources.working_set]
479
+ return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
480
+
481
+
471
482
  def get_training_code(obj: object) -> str:
472
483
  """Gets the text from the file containing the provided object.
473
484
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes