xax 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/task/base.py +2 -0
- xax/task/mixins/compile.py +14 -2
- {xax-0.1.3.dist-info → xax-0.1.4.dist-info}/METADATA +1 -1
- {xax-0.1.3.dist-info → xax-0.1.4.dist-info}/RECORD +8 -8
- {xax-0.1.3.dist-info → xax-0.1.4.dist-info}/WHEEL +0 -0
- {xax-0.1.3.dist-info → xax-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.3.dist-info → xax-0.1.4.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/task/base.py
CHANGED
@@ -153,6 +153,8 @@ class BaseTask(Generic[Config]):
|
|
153
153
|
for base in cls.__orig_bases__:
|
154
154
|
if hasattr(base, "__args__"):
|
155
155
|
for arg in base.__args__:
|
156
|
+
if isinstance(arg, TypeVar) and arg.__bound__ is not None:
|
157
|
+
arg = arg.__bound__
|
156
158
|
if issubclass(arg, BaseConfig):
|
157
159
|
return arg
|
158
160
|
|
xax/task/mixins/compile.py
CHANGED
@@ -5,6 +5,7 @@ behavior during initialization and training.
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import logging
|
8
|
+
import sys
|
8
9
|
from dataclasses import dataclass
|
9
10
|
from pathlib import Path
|
10
11
|
from typing import Generic, TypeVar
|
@@ -17,6 +18,16 @@ from xax.task.base import BaseConfig, BaseTask
|
|
17
18
|
logger = logging.getLogger(__name__)
|
18
19
|
|
19
20
|
|
21
|
+
def get_cache_dir() -> str | None:
|
22
|
+
# By default, only cache on MacOS, since Jax caching on Linux is very
|
23
|
+
# prone to NaNs.
|
24
|
+
match sys.platform:
|
25
|
+
case "darwin":
|
26
|
+
return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
|
27
|
+
case _:
|
28
|
+
return None
|
29
|
+
|
30
|
+
|
20
31
|
@jax.tree_util.register_dataclass
|
21
32
|
@dataclass
|
22
33
|
class CompileOptions:
|
@@ -42,7 +53,8 @@ class CompileOptions:
|
|
42
53
|
|
43
54
|
# JAX cache options
|
44
55
|
cache_dir: str | None = field(
|
45
|
-
|
56
|
+
# Only cache by default on MacOS systems.
|
57
|
+
value=get_cache_dir,
|
46
58
|
help="Directory for JAX compilation cache. If None, caching is disabled",
|
47
59
|
)
|
48
60
|
cache_min_size_bytes: int = field(
|
@@ -54,7 +66,7 @@ class CompileOptions:
|
|
54
66
|
help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
|
55
67
|
)
|
56
68
|
cache_enable_xla: str = field(
|
57
|
-
value="
|
69
|
+
value="none",
|
58
70
|
help="Which XLA caches to enable",
|
59
71
|
)
|
60
72
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=Ib-tn4sRqSIiFjoQceuUzdwRc4i-VW9P5bBLq4wDbYc,13361
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -14,7 +14,7 @@ xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
|
|
14
14
|
xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
|
15
15
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
16
16
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
-
xax/task/base.py,sha256=
|
17
|
+
xax/task/base.py,sha256=4fUjrG-llQpeESQuaQbww4M6WR6djjTK89fY20UV9zU,7610
|
18
18
|
xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
|
19
19
|
xax/task/script.py,sha256=zt36Sobdoer86gXHqc4sMAW7bqZRVl6IEExuQZH2USk,926
|
20
20
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
@@ -31,7 +31,7 @@ xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCn
|
|
31
31
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
32
32
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
33
33
|
xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
|
34
|
-
xax/task/mixins/compile.py,sha256=
|
34
|
+
xax/task/mixins/compile.py,sha256=9pVJEUvizu6-6tq0HaMtHGNSi9Yk_mxNyqBFimcfwL0,3683
|
35
35
|
xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
|
36
36
|
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
37
37
|
xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
|
@@ -53,8 +53,8 @@ xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,1617
|
|
53
53
|
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
54
54
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
55
55
|
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
56
|
-
xax-0.1.
|
57
|
-
xax-0.1.
|
58
|
-
xax-0.1.
|
59
|
-
xax-0.1.
|
60
|
-
xax-0.1.
|
56
|
+
xax-0.1.4.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
57
|
+
xax-0.1.4.dist-info/METADATA,sha256=VMyCF4gPJh5ww_VzF0HgfInf_bfD40ahwxpHkxEMGoI,1877
|
58
|
+
xax-0.1.4.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
|
59
|
+
xax-0.1.4.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
60
|
+
xax-0.1.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|