xax 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.2"
15
+ __version__ = "0.1.4"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/task/base.py CHANGED
@@ -153,6 +153,8 @@ class BaseTask(Generic[Config]):
153
153
  for base in cls.__orig_bases__:
154
154
  if hasattr(base, "__args__"):
155
155
  for arg in base.__args__:
156
+ if isinstance(arg, TypeVar) and arg.__bound__ is not None:
157
+ arg = arg.__bound__
156
158
  if issubclass(arg, BaseConfig):
157
159
  return arg
158
160
 
@@ -47,7 +47,6 @@ class CheckpointingConfig(ArtifactsConfig):
47
47
  only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
48
48
  load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
49
49
  load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
50
- save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
51
50
 
52
51
 
53
52
  Config = TypeVar("Config", bound=CheckpointingConfig)
@@ -213,15 +212,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
213
212
  add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
214
213
  add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
215
214
 
216
- if self.config.save_tf_model:
217
- try:
218
- from jax.experimental import jax2tf
219
- except ModuleNotFoundError:
220
- raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
221
-
222
- tf_model = jax2tf.convert(model)
223
- add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
224
-
225
215
  # Updates the symlink to the new checkpoint.
226
216
  last_ckpt_path.unlink(missing_ok=True)
227
217
  try:
@@ -5,6 +5,7 @@ behavior during initialization and training.
5
5
  """
6
6
 
7
7
  import logging
8
+ import sys
8
9
  from dataclasses import dataclass
9
10
  from pathlib import Path
10
11
  from typing import Generic, TypeVar
@@ -17,6 +18,16 @@ from xax.task.base import BaseConfig, BaseTask
17
18
  logger = logging.getLogger(__name__)
18
19
 
19
20
 
21
+ def get_cache_dir() -> str | None:
22
+ # By default, only cache on MacOS, since Jax caching on Linux is very
23
+ # prone to NaNs.
24
+ match sys.platform:
25
+ case "darwin":
26
+ return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
27
+ case _:
28
+ return None
29
+
30
+
20
31
  @jax.tree_util.register_dataclass
21
32
  @dataclass
22
33
  class CompileOptions:
@@ -42,7 +53,8 @@ class CompileOptions:
42
53
 
43
54
  # JAX cache options
44
55
  cache_dir: str | None = field(
45
- value=lambda: str((Path.home() / ".cache" / "jax" / "jaxcache").resolve()),
56
+ # Only cache by default on MacOS systems.
57
+ value=get_cache_dir,
46
58
  help="Directory for JAX compilation cache. If None, caching is disabled",
47
59
  )
48
60
  cache_min_size_bytes: int = field(
@@ -54,7 +66,7 @@ class CompileOptions:
54
66
  help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
55
67
  )
56
68
  cache_enable_xla: str = field(
57
- value="all",
69
+ value="none",
58
70
  help="Which XLA caches to enable",
59
71
  )
60
72
 
xax/task/mixins/train.py CHANGED
@@ -50,6 +50,7 @@ from xax.utils.experiments import (
50
50
  diff_configs,
51
51
  get_diff_string,
52
52
  get_git_state,
53
+ get_packages_with_versions,
53
54
  get_training_code,
54
55
  )
55
56
  from xax.utils.logging import LOG_STATUS
@@ -452,6 +453,7 @@ class TrainMixin(
452
453
  logger.log(LOG_STATUS, self.task_name)
453
454
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
454
455
  self.logger.log_file("git_state.txt", get_git_state(self))
456
+ self.logger.log_file("packages.txt", get_packages_with_versions())
455
457
  self.logger.log_file("training_code.txt", get_training_code(self))
456
458
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
457
459
 
xax/utils/experiments.py CHANGED
@@ -28,6 +28,7 @@ from typing import Any, Iterator, Self, TypeVar, cast
28
28
  from urllib.parse import urlparse
29
29
 
30
30
  import git
31
+ import pkg_resources
31
32
  import requests
32
33
  from jaxtyping import Array
33
34
  from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
@@ -468,6 +469,16 @@ def get_git_state(obj: object) -> str:
468
469
  return traceback.format_exc()
469
470
 
470
471
 
472
+ def get_packages_with_versions() -> str:
473
+ """Gets the packages and their versions.
474
+
475
+ Returns:
476
+ A dictionary of packages and their versions.
477
+ """
478
+ packages = [(pkg.key, pkg.version) for pkg in pkg_resources.working_set]
479
+ return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
480
+
481
+
471
482
  def get_training_code(obj: object) -> str:
472
483
  """Gets the text from the file containing the provided object.
473
484
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=Ti6hrfoY5wnywzOvkvtCwq2SvLsjYfbm_6U_UzYakls,13361
1
+ xax/__init__.py,sha256=Ib-tn4sRqSIiFjoQceuUzdwRc4i-VW9P5bBLq4wDbYc,13361
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -14,7 +14,7 @@ xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
14
14
  xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
15
15
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
16
16
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- xax/task/base.py,sha256=MlH5dTKAiMzFRI5fmXCvL1k8ELbalWMBICeVxmW6k2U,7479
17
+ xax/task/base.py,sha256=4fUjrG-llQpeESQuaQbww4M6WR6djjTK89fY20UV9zU,7610
18
18
  xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
19
19
  xax/task/script.py,sha256=zt36Sobdoer86gXHqc4sMAW7bqZRVl6IEExuQZH2USk,926
20
20
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
@@ -30,8 +30,8 @@ xax/task/loggers/stdout.py,sha256=bR0k-PfmFgLfPxLPb4hZw_8G_msA32UeHfAAu11nEYs,67
30
30
  xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
31
31
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
32
32
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
33
- xax/task/mixins/checkpointing.py,sha256=sRkVxJbQfqDf1-lp1KFrAGYWHhTlV8_DORxGQ_69P1A,8954
34
- xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
33
+ xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
34
+ xax/task/mixins/compile.py,sha256=9pVJEUvizu6-6tq0HaMtHGNSi9Yk_mxNyqBFimcfwL0,3683
35
35
  xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
36
36
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
37
37
  xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
@@ -39,10 +39,10 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
39
39
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
40
40
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
41
41
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
42
- xax/task/mixins/train.py,sha256=BEC7HSwBlGZDe7jCsedqEA8-K1Zx52-bTjsBONYIE5g,22225
42
+ xax/task/mixins/train.py,sha256=8AaBXaopnrxtSZXldyFCE3QX1k5r3IsZMr6O0ICnNnU,22332
43
43
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
44
  xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
45
- xax/utils/experiments.py,sha256=d-e-RCw9PlnuqV3FPW0U74zcvlOKV48lUrX8tvAfhew,28887
45
+ xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
46
46
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
47
47
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
48
48
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
@@ -53,8 +53,8 @@ xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,1617
53
53
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
54
54
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
55
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
56
- xax-0.1.2.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
- xax-0.1.2.dist-info/METADATA,sha256=-BB6_Qiip_pPkf96Wl9FZsM_7PPKJr5l8v2owrXCvoI,1877
58
- xax-0.1.2.dist-info/WHEEL,sha256=L0N565qmK-3nM2eBoMNFszYJ_MTx03_tQ0CQu1bHLYo,91
59
- xax-0.1.2.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
- xax-0.1.2.dist-info/RECORD,,
56
+ xax-0.1.4.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
+ xax-0.1.4.dist-info/METADATA,sha256=VMyCF4gPJh5ww_VzF0HgfInf_bfD40ahwxpHkxEMGoI,1877
58
+ xax-0.1.4.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
59
+ xax-0.1.4.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
+ xax-0.1.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.0.1)
2
+ Generator: setuptools (78.0.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5