xax 0.1.5__tar.gz → 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.5/xax.egg-info → xax-0.1.7}/PKG-INFO +1 -1
- {xax-0.1.5 → xax-0.1.7}/xax/__init__.py +10 -13
- {xax-0.1.5 → xax-0.1.7}/xax/nn/equinox.py +2 -2
- {xax-0.1.5 → xax-0.1.7}/xax/nn/export.py +5 -5
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/compile.py +1 -1
- xax-0.1.7/xax/utils/types/__init__.py +0 -0
- xax-0.1.7/xax/utils/types/frozen_dict.py +148 -0
- xax-0.1.7/xax/utils/types/hashable_array.py +31 -0
- {xax-0.1.5 → xax-0.1.7/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.1.5 → xax-0.1.7}/xax.egg-info/SOURCES.txt +4 -1
- {xax-0.1.5 → xax-0.1.7}/LICENSE +0 -0
- {xax-0.1.5 → xax-0.1.7}/MANIFEST.in +0 -0
- {xax-0.1.5 → xax-0.1.7}/README.md +0 -0
- {xax-0.1.5 → xax-0.1.7}/pyproject.toml +0 -0
- {xax-0.1.5 → xax-0.1.7}/setup.cfg +0 -0
- {xax-0.1.5 → xax-0.1.7}/setup.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/core/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/core/conf.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/core/state.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/nn/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/nn/embeddings.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/nn/functions.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/nn/geom.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/nn/norm.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/nn/parallel.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/py.typed +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/requirements-dev.txt +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/requirements.txt +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/base.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/launchers/base.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/logger.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/json.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/state.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/stdout.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/process.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/train.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/script.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/task/task.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/data/collate.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/debugging.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/experiments.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/jax.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/logging.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/numpy.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/profile.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/pytree.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/tensorboard.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax/utils/text.py +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax.egg-info/requires.txt +0 -0
- {xax-0.1.5 → xax-0.1.7}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.5"
|
15
|
+
__version__ = "0.1.7"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -41,9 +41,6 @@ __all__ = [
|
|
41
41
|
"load_eqx_mlp",
|
42
42
|
"make_eqx_mlp",
|
43
43
|
"save_eqx",
|
44
|
-
"export",
|
45
|
-
"export_flax",
|
46
|
-
"export_with_params",
|
47
44
|
"euler_to_quat",
|
48
45
|
"get_projected_gravity_vector_from_quat",
|
49
46
|
"quat_to_euler",
|
@@ -130,6 +127,9 @@ __all__ = [
|
|
130
127
|
"snakecase_to_camelcase",
|
131
128
|
"uncolored",
|
132
129
|
"wrapped",
|
130
|
+
"FrozenDict",
|
131
|
+
"HashableArray",
|
132
|
+
"hashable_array",
|
133
133
|
]
|
134
134
|
|
135
135
|
__all__ += [
|
@@ -159,7 +159,7 @@ if "XLA_FLAGS" in os.environ:
|
|
159
159
|
# If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
|
160
160
|
# Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
|
161
161
|
if shutil.which("nvidia-smi") is not None:
|
162
|
-
xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler", "--xla_gpu_enable_triton_gemm"]
|
162
|
+
xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler=true", "--xla_gpu_enable_triton_gemm=false"]
|
163
163
|
os.environ["XLA_FLAGS"] = " ".join(xla_flags)
|
164
164
|
|
165
165
|
# If this flag is set, eagerly imports the entire package (not recommended).
|
@@ -195,9 +195,6 @@ NAME_MAP: dict[str, str] = {
|
|
195
195
|
"load_eqx_mlp": "nn.equinox",
|
196
196
|
"make_eqx_mlp": "nn.equinox",
|
197
197
|
"save_eqx": "nn.equinox",
|
198
|
-
"export": "nn.export",
|
199
|
-
"export_flax": "nn.export",
|
200
|
-
"export_with_params": "nn.export",
|
201
198
|
"euler_to_quat": "nn.geom",
|
202
199
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
203
200
|
"quat_to_euler": "nn.geom",
|
@@ -284,6 +281,9 @@ NAME_MAP: dict[str, str] = {
|
|
284
281
|
"snakecase_to_camelcase": "utils.text",
|
285
282
|
"uncolored": "utils.text",
|
286
283
|
"wrapped": "utils.text",
|
284
|
+
"FrozenDict": "utils.types.frozen_dict",
|
285
|
+
"HashableArray": "utils.types.hashable_array",
|
286
|
+
"hashable_array": "utils.types.hashable_array",
|
287
287
|
}
|
288
288
|
|
289
289
|
# Need to manually set some values which can't be auto-generated.
|
@@ -348,11 +348,6 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
348
348
|
make_eqx_mlp,
|
349
349
|
save_eqx,
|
350
350
|
)
|
351
|
-
from xax.nn.export import (
|
352
|
-
export,
|
353
|
-
export_flax,
|
354
|
-
export_with_params,
|
355
|
-
)
|
356
351
|
from xax.nn.geom import (
|
357
352
|
euler_to_quat,
|
358
353
|
get_projected_gravity_vector_from_quat,
|
@@ -444,5 +439,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
444
439
|
uncolored,
|
445
440
|
wrapped,
|
446
441
|
)
|
442
|
+
from xax.utils.types.frozen_dict import FrozenDict
|
443
|
+
from xax.utils.types.hashable_array import HashableArray, hashable_array
|
447
444
|
|
448
445
|
del TYPE_CHECKING, IMPORT_ALL
|
@@ -72,7 +72,7 @@ def _infer_activation(activation: ActivationFunction) -> Callable:
|
|
72
72
|
raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
|
73
73
|
|
74
74
|
|
75
|
-
def make_eqx_mlp(hyperparams: MLPHyperParams, key: PRNGKeyArray
|
75
|
+
def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
|
76
76
|
"""Create an Equinox MLP from a set of hyperparameters.
|
77
77
|
|
78
78
|
Args:
|
@@ -176,5 +176,5 @@ def load_eqx_mlp(
|
|
176
176
|
) -> eqx.nn.MLP:
|
177
177
|
with open(eqx_file, "rb") as f:
|
178
178
|
hyperparams = json.loads(f.readline().decode(encoding="utf-8"))
|
179
|
-
model = make_eqx_mlp(hyperparams=hyperparams)
|
179
|
+
model = make_eqx_mlp(hyperparams=hyperparams, key=jax.random.PRNGKey(0))
|
180
180
|
return eqx.tree_deserialise_leaves(f, model)
|
@@ -4,18 +4,18 @@ import logging
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import Callable
|
6
6
|
|
7
|
-
import flax
|
8
7
|
import jax
|
9
|
-
import tensorflow as tf
|
10
|
-
from jax.experimental import jax2tf
|
11
8
|
from jaxtyping import Array, PyTree
|
12
9
|
|
13
10
|
try:
|
11
|
+
import flax
|
12
|
+
import tensorflow as tf
|
13
|
+
from jax.experimental import jax2tf
|
14
14
|
from orbax.export import ExportManager, JaxModule, ServingConfig
|
15
15
|
except ImportError as e:
|
16
16
|
raise ImportError(
|
17
|
-
"
|
18
|
-
"'xax[export]` to install the required dependencies."
|
17
|
+
"In order to export models, please install Xax with export dependencies, "
|
18
|
+
"using 'xax[export]` to install the required dependencies."
|
19
19
|
) from e
|
20
20
|
|
21
21
|
logger = logging.getLogger(__name__)
|
@@ -22,7 +22,7 @@ def get_cache_dir() -> str | None:
|
|
22
22
|
# By default, only cache on MacOS, since Jax caching on Linux is very
|
23
23
|
# prone to NaNs.
|
24
24
|
match sys.platform:
|
25
|
-
case "darwin":
|
25
|
+
case "darwin" | "linux":
|
26
26
|
return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
|
27
27
|
case _:
|
28
28
|
return None
|
File without changes
|
@@ -0,0 +1,148 @@
|
|
1
|
+
"""Defines a frozen dictionary type.
|
2
|
+
|
3
|
+
This is mostly taken from Flax - we move it here to avoid having to use Flax as
|
4
|
+
a dependency in downstream projects.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import collections
|
8
|
+
from types import MappingProxyType
|
9
|
+
from typing import Any, Iterator, Mapping, Self, TypeVar
|
10
|
+
|
11
|
+
import jax
|
12
|
+
|
13
|
+
K = TypeVar("K")
|
14
|
+
V = TypeVar("V")
|
15
|
+
|
16
|
+
|
17
|
+
def _prepare_freeze(xs: Any) -> Any: # noqa: ANN401
|
18
|
+
"""Deep copy unfrozen dicts to make the dictionary FrozenDict safe."""
|
19
|
+
if isinstance(xs, FrozenDict):
|
20
|
+
return xs._dict # pylint: disable=protected-access
|
21
|
+
if not isinstance(xs, dict):
|
22
|
+
return xs
|
23
|
+
return {key: _prepare_freeze(val) for key, val in xs.items()}
|
24
|
+
|
25
|
+
|
26
|
+
def _indent(x: str, num_spaces: int) -> str:
|
27
|
+
indent_str = " " * num_spaces
|
28
|
+
lines = x.split("\n")
|
29
|
+
assert not lines[-1]
|
30
|
+
return "\n".join(indent_str + line for line in lines[:-1]) + "\n"
|
31
|
+
|
32
|
+
|
33
|
+
class FrozenKeysView(collections.abc.KeysView[K]):
|
34
|
+
def __repr__(self) -> str:
|
35
|
+
return f"frozen_dict_keys({list(self)})"
|
36
|
+
|
37
|
+
|
38
|
+
class FrozenValuesView(collections.abc.ValuesView[V]):
|
39
|
+
def __repr__(self) -> str:
|
40
|
+
return f"frozen_dict_values({list(self)})"
|
41
|
+
|
42
|
+
|
43
|
+
@jax.tree_util.register_pytree_with_keys_class
|
44
|
+
class FrozenDict(Mapping[K, V]):
|
45
|
+
"""An immutable variant of the Python dict."""
|
46
|
+
|
47
|
+
__slots__ = ("_dict", "_hash")
|
48
|
+
|
49
|
+
def __init__(self, *args: Any, __unsafe_skip_copy__: bool = False, **kwargs: Any) -> None: # noqa: ANN401
|
50
|
+
# make sure the dict is as
|
51
|
+
xs = dict(*args, **kwargs)
|
52
|
+
if __unsafe_skip_copy__:
|
53
|
+
self._dict = xs
|
54
|
+
else:
|
55
|
+
self._dict = _prepare_freeze(xs)
|
56
|
+
|
57
|
+
self._hash: int | None = None
|
58
|
+
|
59
|
+
def __getitem__(self, key: K) -> V:
|
60
|
+
v = self._dict[key]
|
61
|
+
if isinstance(v, dict):
|
62
|
+
return FrozenDict(v) # type: ignore[return-value]
|
63
|
+
return v
|
64
|
+
|
65
|
+
def __setitem__(self, key: K, value: V) -> None:
|
66
|
+
raise ValueError("FrozenDict is immutable.")
|
67
|
+
|
68
|
+
def __contains__(self, key: object) -> bool:
|
69
|
+
return key in self._dict
|
70
|
+
|
71
|
+
def __iter__(self) -> Iterator[K]:
|
72
|
+
return iter(self._dict)
|
73
|
+
|
74
|
+
def __len__(self) -> int:
|
75
|
+
return len(self._dict)
|
76
|
+
|
77
|
+
def __repr__(self) -> str:
|
78
|
+
return self.pretty_repr()
|
79
|
+
|
80
|
+
def __reduce__(self) -> tuple[type["FrozenDict[K, V]"], tuple[dict[K, V]]]:
|
81
|
+
return FrozenDict, (self.unfreeze(),)
|
82
|
+
|
83
|
+
def pretty_repr(self, num_spaces: int = 4) -> str:
|
84
|
+
"""Returns an indented representation of the nested dictionary."""
|
85
|
+
|
86
|
+
def pretty_dict(x: Any) -> str: # noqa: ANN401
|
87
|
+
if not isinstance(x, dict):
|
88
|
+
return repr(x)
|
89
|
+
rep = ""
|
90
|
+
for key, val in x.items():
|
91
|
+
rep += f"{key}: {pretty_dict(val)},\n"
|
92
|
+
if rep:
|
93
|
+
return "{\n" + _indent(rep, num_spaces) + "}"
|
94
|
+
else:
|
95
|
+
return "{}"
|
96
|
+
|
97
|
+
return f"FrozenDict({pretty_dict(self._dict)})"
|
98
|
+
|
99
|
+
def __hash__(self) -> int:
|
100
|
+
if self._hash is None:
|
101
|
+
h = 0
|
102
|
+
for key, value in self.items():
|
103
|
+
h ^= hash((key, value))
|
104
|
+
self._hash = h
|
105
|
+
return self._hash
|
106
|
+
|
107
|
+
def copy(self, add_or_replace: Mapping[K, V] = MappingProxyType({})) -> Self:
|
108
|
+
return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type]
|
109
|
+
|
110
|
+
def keys(self) -> FrozenKeysView[K]:
|
111
|
+
return FrozenKeysView(self)
|
112
|
+
|
113
|
+
def values(self) -> FrozenValuesView[V]:
|
114
|
+
return FrozenValuesView(self)
|
115
|
+
|
116
|
+
def items(self) -> Iterator[tuple[K, V]]: # type: ignore[override]
|
117
|
+
for key in self._dict:
|
118
|
+
yield (key, self[key])
|
119
|
+
|
120
|
+
def pop(self, key: K) -> tuple["FrozenDict[K, V]", V]:
|
121
|
+
value = self[key]
|
122
|
+
new_dict = dict(self._dict)
|
123
|
+
new_dict.pop(key)
|
124
|
+
new_self = type(self)(new_dict)
|
125
|
+
return new_self, value
|
126
|
+
|
127
|
+
def unfreeze(self) -> dict[K, V]:
|
128
|
+
return unfreeze(self)
|
129
|
+
|
130
|
+
def tree_flatten_with_keys(self) -> tuple[tuple[tuple[jax.tree_util.DictKey, Any], ...], tuple[K, ...]]:
|
131
|
+
sorted_keys = sorted(self._dict)
|
132
|
+
return tuple([(jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys]), tuple(sorted_keys)
|
133
|
+
|
134
|
+
@classmethod
|
135
|
+
def tree_unflatten(cls, keys: tuple[K, ...], values: tuple[Any, ...]) -> "FrozenDict[K, V]":
|
136
|
+
return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
|
137
|
+
|
138
|
+
|
139
|
+
def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: ANN401
|
140
|
+
if isinstance(x, FrozenDict):
|
141
|
+
return jax.tree_util.tree_map(lambda y: y, x._dict)
|
142
|
+
elif isinstance(x, dict):
|
143
|
+
ys = {}
|
144
|
+
for key, value in x.items():
|
145
|
+
ys[key] = unfreeze(value)
|
146
|
+
return ys
|
147
|
+
else:
|
148
|
+
return x
|
@@ -0,0 +1,31 @@
|
|
1
|
+
"""Defines a hashable array wrapper.
|
2
|
+
|
3
|
+
Since Jax relies extensively on hashing, and we sometimes want to treat Jax
|
4
|
+
arrays as constants, this wrapper lets us ensure that Jax and Numpy arrays can
|
5
|
+
be hashed for Jitting.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import numpy as np
|
10
|
+
|
11
|
+
|
12
|
+
class HashableArray:
|
13
|
+
def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
|
14
|
+
if not isinstance(array, (np.ndarray, jnp.ndarray)):
|
15
|
+
raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
|
16
|
+
self.array = array
|
17
|
+
self._hash: int | None = None
|
18
|
+
|
19
|
+
def __hash__(self) -> int:
|
20
|
+
if self._hash is None:
|
21
|
+
self._hash = hash(self.array.tobytes())
|
22
|
+
return self._hash
|
23
|
+
|
24
|
+
def __eq__(self, other: object) -> bool:
|
25
|
+
if not isinstance(other, HashableArray):
|
26
|
+
return False
|
27
|
+
return bool(jnp.array_equal(self.array, other.array))
|
28
|
+
|
29
|
+
|
30
|
+
def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
|
31
|
+
return HashableArray(array)
|
{xax-0.1.5 → xax-0.1.7}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.1.5 → xax-0.1.7}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|