xax 0.1.3__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. {xax-0.1.3/xax.egg-info → xax-0.1.4}/PKG-INFO +1 -1
  2. {xax-0.1.3 → xax-0.1.4}/xax/__init__.py +1 -1
  3. {xax-0.1.3 → xax-0.1.4}/xax/task/base.py +2 -0
  4. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/compile.py +14 -2
  5. {xax-0.1.3 → xax-0.1.4/xax.egg-info}/PKG-INFO +1 -1
  6. {xax-0.1.3 → xax-0.1.4}/LICENSE +0 -0
  7. {xax-0.1.3 → xax-0.1.4}/MANIFEST.in +0 -0
  8. {xax-0.1.3 → xax-0.1.4}/README.md +0 -0
  9. {xax-0.1.3 → xax-0.1.4}/pyproject.toml +0 -0
  10. {xax-0.1.3 → xax-0.1.4}/setup.cfg +0 -0
  11. {xax-0.1.3 → xax-0.1.4}/setup.py +0 -0
  12. {xax-0.1.3 → xax-0.1.4}/xax/core/__init__.py +0 -0
  13. {xax-0.1.3 → xax-0.1.4}/xax/core/conf.py +0 -0
  14. {xax-0.1.3 → xax-0.1.4}/xax/core/state.py +0 -0
  15. {xax-0.1.3 → xax-0.1.4}/xax/nn/__init__.py +0 -0
  16. {xax-0.1.3 → xax-0.1.4}/xax/nn/embeddings.py +0 -0
  17. {xax-0.1.3 → xax-0.1.4}/xax/nn/equinox.py +0 -0
  18. {xax-0.1.3 → xax-0.1.4}/xax/nn/export.py +0 -0
  19. {xax-0.1.3 → xax-0.1.4}/xax/nn/functions.py +0 -0
  20. {xax-0.1.3 → xax-0.1.4}/xax/nn/geom.py +0 -0
  21. {xax-0.1.3 → xax-0.1.4}/xax/nn/norm.py +0 -0
  22. {xax-0.1.3 → xax-0.1.4}/xax/nn/parallel.py +0 -0
  23. {xax-0.1.3 → xax-0.1.4}/xax/py.typed +0 -0
  24. {xax-0.1.3 → xax-0.1.4}/xax/requirements-dev.txt +0 -0
  25. {xax-0.1.3 → xax-0.1.4}/xax/requirements.txt +0 -0
  26. {xax-0.1.3 → xax-0.1.4}/xax/task/__init__.py +0 -0
  27. {xax-0.1.3 → xax-0.1.4}/xax/task/launchers/__init__.py +0 -0
  28. {xax-0.1.3 → xax-0.1.4}/xax/task/launchers/base.py +0 -0
  29. {xax-0.1.3 → xax-0.1.4}/xax/task/launchers/cli.py +0 -0
  30. {xax-0.1.3 → xax-0.1.4}/xax/task/launchers/single_process.py +0 -0
  31. {xax-0.1.3 → xax-0.1.4}/xax/task/logger.py +0 -0
  32. {xax-0.1.3 → xax-0.1.4}/xax/task/loggers/__init__.py +0 -0
  33. {xax-0.1.3 → xax-0.1.4}/xax/task/loggers/callback.py +0 -0
  34. {xax-0.1.3 → xax-0.1.4}/xax/task/loggers/json.py +0 -0
  35. {xax-0.1.3 → xax-0.1.4}/xax/task/loggers/state.py +0 -0
  36. {xax-0.1.3 → xax-0.1.4}/xax/task/loggers/stdout.py +0 -0
  37. {xax-0.1.3 → xax-0.1.4}/xax/task/loggers/tensorboard.py +0 -0
  38. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/__init__.py +0 -0
  39. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/artifacts.py +0 -0
  40. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/checkpointing.py +0 -0
  41. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/cpu_stats.py +0 -0
  42. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/data_loader.py +0 -0
  43. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/gpu_stats.py +0 -0
  44. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/logger.py +0 -0
  45. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/process.py +0 -0
  46. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/runnable.py +0 -0
  47. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/step_wrapper.py +0 -0
  48. {xax-0.1.3 → xax-0.1.4}/xax/task/mixins/train.py +0 -0
  49. {xax-0.1.3 → xax-0.1.4}/xax/task/script.py +0 -0
  50. {xax-0.1.3 → xax-0.1.4}/xax/task/task.py +0 -0
  51. {xax-0.1.3 → xax-0.1.4}/xax/utils/__init__.py +0 -0
  52. {xax-0.1.3 → xax-0.1.4}/xax/utils/data/__init__.py +0 -0
  53. {xax-0.1.3 → xax-0.1.4}/xax/utils/data/collate.py +0 -0
  54. {xax-0.1.3 → xax-0.1.4}/xax/utils/debugging.py +0 -0
  55. {xax-0.1.3 → xax-0.1.4}/xax/utils/experiments.py +0 -0
  56. {xax-0.1.3 → xax-0.1.4}/xax/utils/jax.py +0 -0
  57. {xax-0.1.3 → xax-0.1.4}/xax/utils/jaxpr.py +0 -0
  58. {xax-0.1.3 → xax-0.1.4}/xax/utils/logging.py +0 -0
  59. {xax-0.1.3 → xax-0.1.4}/xax/utils/numpy.py +0 -0
  60. {xax-0.1.3 → xax-0.1.4}/xax/utils/profile.py +0 -0
  61. {xax-0.1.3 → xax-0.1.4}/xax/utils/pytree.py +0 -0
  62. {xax-0.1.3 → xax-0.1.4}/xax/utils/tensorboard.py +0 -0
  63. {xax-0.1.3 → xax-0.1.4}/xax/utils/text.py +0 -0
  64. {xax-0.1.3 → xax-0.1.4}/xax.egg-info/SOURCES.txt +0 -0
  65. {xax-0.1.3 → xax-0.1.4}/xax.egg-info/dependency_links.txt +0 -0
  66. {xax-0.1.3 → xax-0.1.4}/xax.egg-info/requires.txt +0 -0
  67. {xax-0.1.3 → xax-0.1.4}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.3"
15
+ __version__ = "0.1.4"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -153,6 +153,8 @@ class BaseTask(Generic[Config]):
153
153
  for base in cls.__orig_bases__:
154
154
  if hasattr(base, "__args__"):
155
155
  for arg in base.__args__:
156
+ if isinstance(arg, TypeVar) and arg.__bound__ is not None:
157
+ arg = arg.__bound__
156
158
  if issubclass(arg, BaseConfig):
157
159
  return arg
158
160
 
@@ -5,6 +5,7 @@ behavior during initialization and training.
5
5
  """
6
6
 
7
7
  import logging
8
+ import sys
8
9
  from dataclasses import dataclass
9
10
  from pathlib import Path
10
11
  from typing import Generic, TypeVar
@@ -17,6 +18,16 @@ from xax.task.base import BaseConfig, BaseTask
17
18
  logger = logging.getLogger(__name__)
18
19
 
19
20
 
21
+ def get_cache_dir() -> str | None:
22
+ # By default, only cache on MacOS, since Jax caching on Linux is very
23
+ # prone to NaNs.
24
+ match sys.platform:
25
+ case "darwin":
26
+ return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
27
+ case _:
28
+ return None
29
+
30
+
20
31
  @jax.tree_util.register_dataclass
21
32
  @dataclass
22
33
  class CompileOptions:
@@ -42,7 +53,8 @@ class CompileOptions:
42
53
 
43
54
  # JAX cache options
44
55
  cache_dir: str | None = field(
45
- value=lambda: str((Path.home() / ".cache" / "jax" / "jaxcache").resolve()),
56
+ # Only cache by default on MacOS systems.
57
+ value=get_cache_dir,
46
58
  help="Directory for JAX compilation cache. If None, caching is disabled",
47
59
  )
48
60
  cache_min_size_bytes: int = field(
@@ -54,7 +66,7 @@ class CompileOptions:
54
66
  help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
55
67
  )
56
68
  cache_enable_xla: str = field(
57
- value="all",
69
+ value="none",
58
70
  help="Which XLA caches to enable",
59
71
  )
60
72
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes