xax 0.2.20__tar.gz → 0.2.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.20/xax.egg-info → xax-0.2.22}/PKG-INFO +1 -17
- {xax-0.2.20 → xax-0.2.22}/pyproject.toml +0 -1
- {xax-0.2.20 → xax-0.2.22}/setup.py +0 -8
- {xax-0.2.20 → xax-0.2.22}/xax/__init__.py +6 -24
- {xax-0.2.20 → xax-0.2.22}/xax/cli/edit_config.py +16 -6
- {xax-0.2.20 → xax-0.2.22}/xax/nn/metrics.py +0 -3
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/train.py +7 -3
- {xax-0.2.20 → xax-0.2.22}/xax/utils/jax.py +109 -7
- {xax-0.2.20 → xax-0.2.22/xax.egg-info}/PKG-INFO +1 -17
- {xax-0.2.20 → xax-0.2.22}/xax.egg-info/SOURCES.txt +0 -2
- {xax-0.2.20 → xax-0.2.22}/xax.egg-info/requires.txt +0 -18
- xax-0.2.20/xax/nn/equinox.py +0 -183
- xax-0.2.20/xax/nn/export.py +0 -154
- {xax-0.2.20 → xax-0.2.22}/LICENSE +0 -0
- {xax-0.2.20 → xax-0.2.22}/MANIFEST.in +0 -0
- {xax-0.2.20 → xax-0.2.22}/README.md +0 -0
- {xax-0.2.20 → xax-0.2.22}/setup.cfg +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/cli/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/core/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/core/conf.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/core/state.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/nn/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/nn/embeddings.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/nn/functions.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/nn/geom.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/nn/losses.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/nn/parallel.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/nn/ssm.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/py.typed +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/requirements-dev.txt +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/requirements.txt +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/base.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/launchers/base.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/logger.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/loggers/json.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/loggers/state.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/process.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/script.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/task/task.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/data/collate.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/debugging.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/experiments.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/logging.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/numpy.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/profile.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/pytree.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/text.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.2.20 → xax-0.2.22}/xax.egg-info/top_level.txt +0 -0
{xax-0.2.20/xax.egg-info → xax-0.2.22}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.20
+Version: 0.2.22
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -31,22 +31,6 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: types-pillow; extra == "dev"
 Requires-Dist: types-psutil; extra == "dev"
 Requires-Dist: types-requests; extra == "dev"
-Provides-Extra: exportable
-Requires-Dist: flax; extra == "exportable"
-Requires-Dist: orbax-export; extra == "exportable"
-Requires-Dist: tensorflow; extra == "exportable"
-Provides-Extra: all
-Requires-Dist: black; extra == "all"
-Requires-Dist: darglint; extra == "all"
-Requires-Dist: mypy; extra == "all"
-Requires-Dist: ruff; extra == "all"
-Requires-Dist: pytest; extra == "all"
-Requires-Dist: types-pillow; extra == "all"
-Requires-Dist: types-psutil; extra == "all"
-Requires-Dist: types-requests; extra == "all"
-Requires-Dist: flax; extra == "all"
-Requires-Dist: orbax-export; extra == "all"
-Requires-Dist: tensorflow; extra == "all"
 Dynamic: author
 Dynamic: description
 Dynamic: description-content-type
{xax-0.2.20 → xax-0.2.22}/setup.py

@@ -14,12 +14,6 @@ with open("xax/requirements.txt", "r", encoding="utf-8") as f:
 with open("xax/requirements-dev.txt", "r", encoding="utf-8") as f:
     requirements_dev: list[str] = f.read().splitlines()

-requirements_export: list[str] = [
-    "flax",
-    "orbax-export",
-    "tensorflow",
-]
-
 with open("xax/__init__.py", "r", encoding="utf-8") as fh:
     version_re = re.search(r"^__version__ = \"([^\"]*)\"", fh.read(), re.MULTILINE)
     assert version_re is not None, "Could not find version in xax/__init__.py"
@@ -39,8 +33,6 @@ setup(
     tests_require=requirements_dev,
     extras_require={
         "dev": requirements_dev,
-        "exportable": requirements_export,
-        "all": requirements_dev + requirements_export,
     },
     package_data={
         "xax": [
{xax-0.2.20 → xax-0.2.22}/xax/__init__.py

@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """

-__version__ = "0.2.20"
+__version__ = "0.2.22"

 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -34,12 +34,6 @@ __all__ = [
     "get_positional_embeddings",
     "get_rotary_embeddings",
     "rotary_embeddings",
-    "MLPHyperParams",
-    "export_eqx_mlp",
-    "load_eqx",
-    "load_eqx_mlp",
-    "make_eqx_mlp",
-    "save_eqx",
     "cubic_bezier_interpolation",
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
@@ -118,8 +112,10 @@ __all__ = [
     "save_config",
     "stage_environment",
     "to_markdown_table",
+    "grad",
     "jit",
     "scan",
+    "vmap",
     "save_jaxpr_dot",
     "ColoredFormatter",
     "configure_logging",
@@ -215,12 +211,6 @@ NAME_MAP: dict[str, str] = {
     "get_positional_embeddings": "nn.embeddings",
     "get_rotary_embeddings": "nn.embeddings",
     "rotary_embeddings": "nn.embeddings",
-    "MLPHyperParams": "nn.equinox",
-    "export_eqx_mlp": "nn.equinox",
-    "load_eqx": "nn.equinox",
-    "load_eqx_mlp": "nn.equinox",
-    "make_eqx_mlp": "nn.equinox",
-    "save_eqx": "nn.equinox",
     "cubic_bezier_interpolation": "nn.geom",
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
@@ -299,8 +289,10 @@ NAME_MAP: dict[str, str] = {
     "save_config": "utils.experiments",
     "stage_environment": "utils.experiments",
     "to_markdown_table": "utils.experiments",
+    "grad": "utils.jax",
     "jit": "utils.jax",
     "scan": "utils.jax",
+    "vmap": "utils.jax",
     "save_jaxpr_dot": "utils.jaxpr",
     "ColoredFormatter": "utils.logging",
     "configure_logging": "utils.logging",
@@ -392,16 +384,6 @@ if IMPORT_ALL or TYPE_CHECKING:
         get_rotary_embeddings,
         rotary_embeddings,
     )
-    from xax.nn.equinox import (
-        DTYPE,
-        ActivationFunction,
-        MLPHyperParams,
-        export_eqx_mlp,
-        load_eqx,
-        load_eqx_mlp,
-        make_eqx_mlp,
-        save_eqx,
-    )
     from xax.nn.geom import (
         cubic_bezier_interpolation,
         euler_to_quat,
@@ -482,7 +464,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         stage_environment,
         to_markdown_table,
     )
-    from xax.utils.jax import jit, scan
+    from xax.utils.jax import grad, jit, scan, vmap
     from xax.utils.jaxpr import save_jaxpr_dot
     from xax.utils.logging import (
         LOG_ERROR_SUMMARY,
{xax-0.2.20 → xax-0.2.22}/xax/cli/edit_config.py

@@ -52,14 +52,24 @@ def main() -> None:
         print(colored(line, "light-cyan"), flush=True)

     # Saves the edited config to the checkpoint.
-    with
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with tarfile.open(args.ckpt_path, "r:gz") as src_tar:
+            for member in src_tar.getmembers():
+                if member.name != "config":  # Skip the old config file
+                    src_tar.extract(member, tmp_dir)

-
-
-
-
+        with tarfile.open(args.ckpt_path, "w:gz") as tar:
+            for root, _, files in os.walk(tmp_dir):
+                for file in files:
+                    file_path = os.path.join(root, file)
+                    arcname = os.path.relpath(file_path, tmp_dir)
+                    tar.add(file_path, arcname=arcname)

-
+            # Add the new config file
+            info = tarfile.TarInfo(name="config")
+            config_bytes = edited_config_str.encode()
+            info.size = len(config_bytes)
+            tar.addfile(info, io.BytesIO(config_bytes))


 if __name__ == "__main__":
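The new flow rewrites the checkpoint archive in place: it extracts every member except the old `config` into a temporary directory, re-packs them, and appends the edited config from memory. A minimal standalone sketch of the same pattern is below; `replace_config` and its arguments are illustrative names, not part of the xax CLI.

```python
import io
import os
import tarfile
import tempfile


def replace_config(ckpt_path: str, new_config: str) -> None:
    """Rewrite a .tar.gz archive, swapping out its "config" member (illustrative only)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Extract everything except the old config.
        with tarfile.open(ckpt_path, "r:gz") as src_tar:
            for member in src_tar.getmembers():
                if member.name != "config":
                    src_tar.extract(member, tmp_dir)

        # Re-pack the extracted files, then append the new config from memory.
        with tarfile.open(ckpt_path, "w:gz") as tar:
            for root, _, files in os.walk(tmp_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    tar.add(file_path, arcname=os.path.relpath(file_path, tmp_dir))
            info = tarfile.TarInfo(name="config")
            config_bytes = new_config.encode()
            info.size = len(config_bytes)
            tar.addfile(info, io.BytesIO(config_bytes))
```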
{xax-0.2.20 → xax-0.2.22}/xax/nn/metrics.py

@@ -7,8 +7,6 @@ import jax
 import jax.numpy as jnp
 from jaxtyping import Array

-from xax.utils.jax import jit as xax_jit
-
 NormType = Literal["l1", "l2"]


@@ -36,7 +34,6 @@ def dynamic_time_warping(distance_matrix_nm: Array) -> Array: ...
 def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[True]) -> tuple[Array, Array]: ...


-@xax_jit(static_argnames=["return_path"])
 def dynamic_time_warping(distance_matrix_nm: Array, return_path: bool = False) -> Array | tuple[Array, Array]:
     """Dynamic Time Warping.

{xax-0.2.20 → xax-0.2.22}/xax/task/mixins/train.py

@@ -625,9 +625,13 @@ class TrainMixin(
         grad_metrics = {"grad_norm": grad_norm}

         def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
-            # Clip
-
-
+            # Clip gradients based on global norm, similar to optax.clip_by_global_norm
+            trigger = jnp.squeeze(grad_norm < self.config.global_grad_clip)
+
+            def clip_fn(t: Array) -> Array:
+                return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * self.config.global_grad_clip)
+
+            grads = jax.tree.map(clip_fn, grads)

             # Apply the gradient updates.
             updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
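The rewritten `apply` clips gradients to a maximum global norm before calling `optimizer.update`, which is the same role `optax.clip_by_global_norm` plays as a gradient transformation. For comparison, here is a small sketch of the optax-native approach; the parameters and gradients are made-up values just to exercise the transformation, and this is not the code xax itself uses.

```python
import jax.numpy as jnp
import optax

# Toy parameters and deliberately large gradients.
params = {"w": jnp.ones((3,)), "b": jnp.zeros((3,))}
grads = {"w": jnp.full((3,), 10.0), "b": jnp.full((3,), -10.0)}

tx = optax.chain(
    optax.clip_by_global_norm(2.0),  # plays the role of the manual clip_fn above
    optax.adam(1e-3),
)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```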
{xax-0.2.20 → xax-0.2.22}/xax/utils/jax.py

@@ -6,13 +6,14 @@ import logging
 import os
 import time
 from functools import wraps
-from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
+from typing import Any, Callable, Hashable, Iterable, ParamSpec, Sequence, TypeVar, cast

 import jax
 import jax.numpy as jnp
 import numpy as np
 from jax._src import sharding_impls
 from jax._src.lib import xla_client as xc
+from jaxtyping import PyTree

 logger = logging.getLogger(__name__)

@@ -20,6 +21,7 @@ DEFAULT_COMPILE_TIMEOUT = 1.0

 Number = int | float | np.ndarray | jnp.ndarray

+T = TypeVar("T", bound=PyTree)

 P = ParamSpec("P")  # For function parameters
 R = TypeVar("R")  # For function return type
@@ -29,6 +31,9 @@ Carry = TypeVar("Carry")
 X = TypeVar("X")
 Y = TypeVar("Y")

+F = TypeVar("F", bound=Callable)
+AxisName = Hashable
+

 @functools.lru_cache(maxsize=None)
 def disable_jit_level() -> int:
@@ -166,6 +171,22 @@ def jit(
     return decorator


+def _split_module(tree: T, axis: int = 0) -> list[T]:
+    """Splits a module in the same way that jax.lax.scan and jax.vmap do.
+
+    Args:
+        tree: The tree to split.
+        axis: The axis to split on.
+
+    Returns:
+        A list of the split trees.
+    """
+    first_leaf = jax.tree.leaves(tree)[0]
+    num_slices = first_leaf.shape[axis]
+    result = [jax.tree.map(lambda x, idx=i: jnp.take(x, idx, axis=axis), tree) for i in range(num_slices)]
+    return result
+
+
 def scan(
     f: Callable[[Carry, X], tuple[Carry, Y]],
     init: Carry,
@@ -195,15 +216,96 @@ def scan(
     if not should_disable_jit(jit_level):
         return jax.lax.scan(f, init, xs, length, reverse, unroll)

+    carry = init
+    ys = []
+
     if xs is None:
         if length is None:
             raise ValueError("length must be provided if xs is None")
-
+        for _ in range(length) if not reverse else range(length - 1, -1, -1):
+            carry, y = f(carry, None)  # type: ignore[arg-type]
+            ys.append(y)

-
-
-
-
-
+    else:
+        xlist = _split_module(xs, axis=0)
+        if reverse:
+            xlist = xlist[::-1]
+        for x in xlist:
+            carry, y = f(carry, x)
+            ys.append(y)
+
+    if reverse:
+        ys = ys[::-1]
+
+    if not ys:
+        return carry, jnp.array([])  # type: ignore[return-value]

     return carry, jax.tree.map(lambda *ys: jnp.stack(ys), *ys)
+
+
+def vmap(
+    fun: Callable[P, R],
+    in_axes: int | Sequence[int | None] = 0,
+    jit_level: int | None = None,
+) -> Callable[P, R]:
+    """A wrapper around jax.lax.vmap that allows for more flexible tracing.
+
+    If the provided JIT level is below the environment JIT level, we manually
+    unroll the scan function as a for loop.
+    """
+    if not should_disable_jit(jit_level):
+        return jax.vmap(fun, in_axes=in_axes)
+
+    @functools.wraps(fun)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
+        if kwargs:
+            raise ValueError("vmap does not support keyword arguments")
+
+        ia = in_axes
+        if isinstance(ia, int):
+            ia = [ia] * len(args)
+        elif len(ia) != len(args):
+            raise ValueError("in_axes must be the same length as args")
+
+        if not all(isinstance(a, int) or a is None for a in ia):
+            raise ValueError("in_axes must be a list of integers or None")
+
+        ns = next((len(_split_module(a, axis=i)) for i, a in zip(ia, args, strict=True) if i is not None), None)
+        if ns is None:
+            return fun(*args, **kwargs)
+        split_args = [[a] * ns if i is None else _split_module(a, axis=i) for i, a in zip(ia, args, strict=True)]
+        split_outputs = [fun(*sargs, **kwargs) for sargs in zip(*split_args, strict=True)]
+
+        if not split_outputs:
+            return jnp.array([])  # type: ignore[return-value]
+
+        return jax.tree.map(lambda *ys: jnp.stack(ys), *split_outputs)
+
+    return wrapped
+
+
+def grad(
+    fun: Callable[P, R],
+    argnums: int | Sequence[int] = 0,
+    has_aux: bool = False,
+    holomorphic: bool = False,
+    allow_int: bool = False,
+    reduce_axes: Sequence[AxisName] = (),
+    jit_level: int | None = None,
+) -> Callable:
+    """A wrapper around jax.grad that allows for more flexible tracing.
+
+    We don't do anything special here, we just manually evaluate the function
+    if the JIT level is below the environment JIT level.
+    """
+    if not should_disable_jit(jit_level):
+        return jax.grad(fun, argnums, has_aux, holomorphic, allow_int, reduce_axes)
+
+    @functools.wraps(fun)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> Callable:
+        # Evaluate the function once, then just return the gradient.
+        fun(*args, **kwargs)
+
+        return jax.grad(fun, argnums, has_aux, holomorphic, allow_int, reduce_axes)(*args, **kwargs)
+
+    return wrapped
{xax-0.2.20 → xax-0.2.22/xax.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.20
+Version: 0.2.22
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -31,22 +31,6 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: types-pillow; extra == "dev"
 Requires-Dist: types-psutil; extra == "dev"
 Requires-Dist: types-requests; extra == "dev"
-Provides-Extra: exportable
-Requires-Dist: flax; extra == "exportable"
-Requires-Dist: orbax-export; extra == "exportable"
-Requires-Dist: tensorflow; extra == "exportable"
-Provides-Extra: all
-Requires-Dist: black; extra == "all"
-Requires-Dist: darglint; extra == "all"
-Requires-Dist: mypy; extra == "all"
-Requires-Dist: ruff; extra == "all"
-Requires-Dist: pytest; extra == "all"
-Requires-Dist: types-pillow; extra == "all"
-Requires-Dist: types-psutil; extra == "all"
-Requires-Dist: types-requests; extra == "all"
-Requires-Dist: flax; extra == "all"
-Requires-Dist: orbax-export; extra == "all"
-Requires-Dist: tensorflow; extra == "all"
 Dynamic: author
 Dynamic: description
 Dynamic: description-content-type
{xax-0.2.20 → xax-0.2.22}/xax.egg-info/requires.txt

@@ -14,19 +14,6 @@ tensorboard
 psutil
 requests

-[all]
-black
-darglint
-mypy
-ruff
-pytest
-types-pillow
-types-psutil
-types-requests
-flax
-orbax-export
-tensorflow
-
 [dev]
 black
 darglint
@@ -36,8 +23,3 @@ pytest
 types-pillow
 types-psutil
 types-requests
-
-[exportable]
-flax
-orbax-export
-tensorflow
xax-0.2.20/xax/nn/equinox.py DELETED

@@ -1,183 +0,0 @@
-"""Equinox utilities."""
-
-import json
-import logging
-from pathlib import Path
-from typing import Callable, Literal, TypedDict, cast
-
-import equinox as eqx
-import jax
-from jaxtyping import PRNGKeyArray
-
-logger = logging.getLogger(__name__)
-
-ActivationFunction = Literal[
-    "relu",
-    "tanh",
-    "celu",
-    "elu",
-    "gelu",
-    "glu",
-    "hard_sigmoid",
-    "hard_silu",
-    "hard_swish",
-    "hard_tanh",
-    "leaky_relu",
-    "log_sigmoid",
-    "log_softmax",
-    "logsumexp",
-    "relu6",
-    "selu",
-    "sigmoid",
-    "soft_sign",
-    "softmax",
-    "softplus",
-    "sparse_plus",
-    "sparse_sigmoid",
-    "silu",
-    "swish",
-    "squareplus",
-    "mish",
-    "identity",
-]
-
-DTYPE = Literal["float32", "float64"]
-
-DTYPE_MAP: dict[DTYPE, jax.numpy.dtype] = {
-    "float32": jax.numpy.float32,
-    "float64": jax.numpy.float64,
-}
-
-
-class MLPHyperParams(TypedDict):
-    """Hyperparameters of an Equinox MLP."""
-
-    in_size: int | Literal["scalar"]
-    out_size: int | Literal["scalar"]
-    width_size: int
-    depth: int
-    activation: ActivationFunction
-    final_activation: ActivationFunction
-    use_bias: bool
-    use_final_bias: bool
-    dtype: DTYPE
-
-
-def _infer_activation(activation: ActivationFunction) -> Callable:
-    if activation == "identity":
-        return lambda x: x
-    try:
-        return getattr(jax.nn, activation)
-    except AttributeError as err:
-        raise ValueError(f"Activation function `{activation}` not found in `jax.nn`") from err
-
-
-def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
-    """Create an Equinox MLP from a set of hyperparameters.
-
-    Args:
-        hyperparams: The hyperparameters of the MLP.
-        key: The PRNG key to use for the MLP.
-    """
-    activation = _infer_activation(hyperparams["activation"])
-    final_activation = _infer_activation(hyperparams["final_activation"])
-    dtype = DTYPE_MAP[hyperparams["dtype"]]
-
-    return eqx.nn.MLP(
-        in_size=hyperparams["in_size"],
-        out_size=hyperparams["out_size"],
-        width_size=hyperparams["width_size"],
-        depth=hyperparams["depth"],
-        activation=activation,
-        final_activation=final_activation,
-        use_bias=hyperparams["use_bias"],
-        use_final_bias=hyperparams["use_final_bias"],
-        dtype=dtype,
-        key=key,
-    )
-
-
-def export_eqx_mlp(
-    model: eqx.nn.MLP,
-    output_path: str | Path,
-    dtype: jax.numpy.dtype | None = None,
-) -> None:
-    """Serialize an Equinox MLP to a .eqx file.
-
-    Args:
-        model: The JAX MLP to export.
-        output_path: The path to save the exported model.
-        dtype: The dtype of the model.
-    """
-    if dtype is None:
-        dtype = eqx._misc.default_floating_dtype()
-
-    activation = model.activation.__name__
-    final_activation = model.final_activation.__name__
-
-    if final_activation == "<lambda>":
-        logger.warning("Final activation is a lambda function. Assuming identity.")
-        final_activation = "identity"
-
-    # cast strings to ActivationFunction for type checking
-    activation = cast(ActivationFunction, activation)
-    final_activation = cast(ActivationFunction, final_activation)
-
-    if dtype not in DTYPE_MAP.values():
-        raise ValueError(f"Invalid dtype: {dtype}. Must be one of {DTYPE_MAP.values()}")
-
-    dtype = {v: k for k, v in DTYPE_MAP.items()}[dtype]
-
-    hyperparams: MLPHyperParams = {
-        "in_size": model.in_size,
-        "out_size": model.out_size,
-        "width_size": model.width_size,
-        "depth": model.depth,
-        "activation": activation,
-        "final_activation": final_activation,
-        "use_bias": model.use_bias,
-        "use_final_bias": model.use_final_bias,
-        "dtype": dtype,
-    }
-
-    with open(output_path, "wb") as f:
-        hyperparam_str = json.dumps(hyperparams)
-        f.write((hyperparam_str + "\n").encode(encoding="utf-8"))
-        eqx.tree_serialise_leaves(f, model)
-
-
-def save_eqx(
-    model: eqx.Module,
-    output_path: str | Path,
-) -> None:
-    """Serialize an Equinox module to a .eqx file.
-
-    Args:
-        model: The Equinox module to export.
-        output_path: The path to save the exported model.
-    """
-    with open(output_path, "wb") as f:
-        eqx.tree_serialise_leaves(f, model)
-
-
-def load_eqx(
-    model: eqx.Module,
-    eqx_file: str | Path,
-) -> eqx.Module:
-    """Deserialize an Equinox module from a .eqx file.
-
-    Args:
-        model: The Equinox module to load into.
-        eqx_file: The path to the .eqx file to load.
-    """
-    with open(eqx_file, "rb") as f:
-        return eqx.tree_deserialise_leaves(f, model)
-
-
-def load_eqx_mlp(
-    eqx_file: str | Path,
-) -> eqx.nn.MLP:
-    with open(eqx_file, "rb") as f:
-        hyperparams = json.loads(f.readline().decode(encoding="utf-8"))
-        model = make_eqx_mlp(hyperparams=hyperparams, key=jax.random.PRNGKey(0))
-        return eqx.tree_deserialise_leaves(f, model)
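With these MLP helpers removed, Equinox modules can still be saved and restored directly through Equinox's own serialization calls, which are the same functions the deleted wrappers used internally. A minimal sketch, assuming an `eqx.nn.MLP` and a local file path:

```python
import equinox as eqx
import jax

# Any eqx.Module works; the MLP hyperparameters here are arbitrary.
model = eqx.nn.MLP(in_size=4, out_size=2, width_size=32, depth=2, key=jax.random.PRNGKey(0))

# Save the leaves, then load them back into a structurally identical skeleton.
eqx.tree_serialise_leaves("model.eqx", model)
skeleton = eqx.nn.MLP(in_size=4, out_size=2, width_size=32, depth=2, key=jax.random.PRNGKey(1))
restored = eqx.tree_deserialise_leaves("model.eqx", skeleton)
```

Unlike the removed `load_eqx_mlp`, this does not record hyperparameters alongside the weights, so the loading side has to reconstruct the architecture itself.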
xax-0.2.20/xax/nn/export.py DELETED

@@ -1,154 +0,0 @@
-"""Export JAX functions to TensorFlow SavedModel format."""
-
-import logging
-from pathlib import Path
-from typing import Callable
-
-import jax
-from jaxtyping import Array, PyTree
-
-try:
-    import flax
-    import tensorflow as tf
-    from jax.experimental import jax2tf
-    from orbax.export import ExportManager, JaxModule, ServingConfig
-except ImportError as e:
-    raise ImportError(
-        "In order to export models, please install Xax with exportable dependencies, "
-        "using 'xax[exportable]` to install the required dependencies."
-    ) from e
-
-logger = logging.getLogger(__name__)
-
-
-def _run_infer(tf_module: tf.Module, input_shapes: list[tuple[int, ...]], batch_size: int | None) -> tf.Tensor:
-    """Warm up the model by running it once."""
-    if batch_size is not None:
-        test_inputs = [
-            jax.random.normal(jax.random.PRNGKey(42), (batch_size, *input_shape)) for input_shape in input_shapes
-        ]
-    else:
-        test_inputs = [jax.random.normal(jax.random.PRNGKey(42), (1, *input_shape)) for input_shape in input_shapes]
-    if not hasattr(tf_module, "infer"):
-        raise ValueError("Model does not have an infer method")
-    return tf_module.infer(*test_inputs)
-
-
-def export(
-    model: Callable,
-    input_shapes: list[tuple[int, ...]],
-    output_dir: str | Path = "export",
-    batch_size: int | None = None,
-) -> None:
-    """Export a JAX function to TensorFlow SavedModel.
-
-    Note: Tensorflow GraphDef can't be larger than 2GB - https://github.com/tensorflow/tensorflow/issues/51870
-    You can avoid this by saving model parameters as non-constants.
-
-    Args:
-        model: The JAX function to export.
-        input_shapes: The shape of the input tensors, excluding batch dimension.
-        output_dir: Directory to save the exported model.
-        batch_size: Optional batch dimension. If None, a polymorphic batch dimension is used.
-    """
-    tf_module = tf.Module()
-    # Create a polymorphic shape specification for each input
-    poly_spec = "(b, ...)" if batch_size is not None else "(None, ...)"
-    polymorphic_shapes = [poly_spec] * len(input_shapes)
-    tf_module.infer = tf.function(  # type: ignore [attr-defined]
-        jax2tf.convert(
-            model,
-            polymorphic_shapes=polymorphic_shapes,
-            # setting this to False will allow the model to run on platforms other than the one that exports the model
-            # https://github.com/jax-ml/jax/blob/051687dc4c899df3d95c30b812ade401d8b31166/jax/experimental/jax2tf/README.md?plain=1#L1342
-            # generally though I think native_serialization is recommended
-            native_serialization=False,
-            with_gradient=False,
-        ),
-        autograph=False,
-        input_signature=[tf.TensorSpec([batch_size] + list(input_shape), tf.float32) for input_shape in input_shapes],
-    )
-
-    # warm up the model
-    _run_infer(tf_module, input_shapes, batch_size)
-
-    logger.info("Exporting SavedModel to %s", output_dir)
-    tf.saved_model.save(
-        tf_module,
-        output_dir,
-    )
-
-
-def export_with_params(
-    model: Callable,
-    params: PyTree,
-    input_shapes: list[tuple[int, ...]],
-    output_dir: str | Path = "export",
-    batch_dim: int | None = None,
-) -> None:
-    """Export a JAX function that takes parameters to TensorFlow SavedModel.
-
-    Args:
-        model: The JAX function to export. Should take parameters as first argument.
-        params: The parameters to use for the model.
-        input_shapes: The shape of the input tensors, excluding batch dimension.
-        output_dir: Directory to save the exported model.
-        batch_dim: Optional batch dimension. If None, a polymorphic batch dimension is used.
-    """
-    param_vars = tf.nest.map_structure(tf.Variable, params)
-
-    converted_model = jax2tf.convert(model)
-
-    def model_fn(*inputs: PyTree) -> Array:
-        return converted_model(param_vars, *inputs)
-
-    tf_module = tf.Module()
-    tf_module._variables = tf.nest.flatten(param_vars)  # type: ignore [attr-defined]
-    tf_module.infer = tf.function(  # type: ignore [attr-defined]
-        model_fn,
-        jit_compile=True,
-        autograph=False,
-        input_signature=[tf.TensorSpec([batch_dim] + list(input_shape), tf.float32) for input_shape in input_shapes],
-    )
-
-    # warm up the model
-    _run_infer(tf_module, input_shapes, batch_dim)
-
-    logger.info("Exporting SavedModel to %s", output_dir)
-    tf.saved_model.save(tf_module, output_dir)
-
-
-def export_flax(
-    model: flax.linen.Module,
-    params: PyTree,
-    input_shape: tuple[int, ...],
-    preprocessor: Callable | None = None,
-    postprocessor: Callable | None = None,
-    input_name: str = "inputs",
-    output_name: str = "outputs",
-    output_dir: str | Path = "export",
-) -> None:
-    jax_module = JaxModule(
-        params, model.apply, trainable=False, input_polymorphic_shape="(b, ...)"
-    )  # if you want to use a batch dimension
-
-    # to avoid mapping sequences to ambiguous mappings
-    if postprocessor is None:
-
-        def postprocessor(x: PyTree) -> PyTree:
-            return {output_name: x}
-
-    export_manager = ExportManager(
-        jax_module,
-        [
-            ServingConfig(
-                "serving_default",
-                input_signature=[tf.TensorSpec([None] + list(input_shape), tf.float32, name=input_name)],
-                tf_preprocessor=preprocessor,
-                tf_postprocessor=postprocessor,
-            )
-        ],
-    )
-
-    logger.info("Exporting model to %s", output_dir)
-    export_manager.save(output_dir)
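With this module (and the `exportable` extra) removed, xax no longer ships a SavedModel export path; the same effect can still be achieved with `jax2tf` directly, which is essentially what the deleted `export` helper wrapped. A rough sketch under that assumption; the model function and shapes are placeholders, not anything from xax.

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf


def model(x):
    # Placeholder JAX function standing in for a real network.
    return jnp.tanh(x) * 2.0


tf_module = tf.Module()
tf_module.infer = tf.function(
    jax2tf.convert(model, with_gradient=False),
    autograph=False,
    input_signature=[tf.TensorSpec([None, 4], tf.float32)],
)
tf_module.infer(tf.zeros([1, 4]))  # trace once before saving
tf.saved_model.save(tf_module, "export")
```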