xax 0.2.19__tar.gz → 0.2.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. {xax-0.2.19/xax.egg-info → xax-0.2.21}/PKG-INFO +1 -17
  2. {xax-0.2.19 → xax-0.2.21}/pyproject.toml +0 -1
  3. {xax-0.2.19 → xax-0.2.21}/setup.py +5 -8
  4. {xax-0.2.19 → xax-0.2.21}/xax/__init__.py +1 -23
  5. xax-0.2.21/xax/cli/edit_config.py +77 -0
  6. {xax-0.2.19 → xax-0.2.21}/xax/nn/metrics.py +0 -3
  7. {xax-0.2.19 → xax-0.2.21}/xax/task/base.py +1 -1
  8. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/checkpointing.py +1 -1
  9. xax-0.2.21/xax/utils/types/__init__.py +0 -0
  10. {xax-0.2.19 → xax-0.2.21/xax.egg-info}/PKG-INFO +1 -17
  11. {xax-0.2.19 → xax-0.2.21}/xax.egg-info/SOURCES.txt +3 -2
  12. xax-0.2.21/xax.egg-info/entry_points.txt +2 -0
  13. {xax-0.2.19 → xax-0.2.21}/xax.egg-info/requires.txt +0 -18
  14. xax-0.2.19/xax/nn/equinox.py +0 -183
  15. xax-0.2.19/xax/nn/export.py +0 -154
  16. {xax-0.2.19 → xax-0.2.21}/LICENSE +0 -0
  17. {xax-0.2.19 → xax-0.2.21}/MANIFEST.in +0 -0
  18. {xax-0.2.19 → xax-0.2.21}/README.md +0 -0
  19. {xax-0.2.19 → xax-0.2.21}/setup.cfg +0 -0
  20. {xax-0.2.19/xax/core → xax-0.2.21/xax/cli}/__init__.py +0 -0
  21. {xax-0.2.19/xax/nn → xax-0.2.21/xax/core}/__init__.py +0 -0
  22. {xax-0.2.19 → xax-0.2.21}/xax/core/conf.py +0 -0
  23. {xax-0.2.19 → xax-0.2.21}/xax/core/state.py +0 -0
  24. {xax-0.2.19/xax/task → xax-0.2.21/xax/nn}/__init__.py +0 -0
  25. {xax-0.2.19 → xax-0.2.21}/xax/nn/embeddings.py +0 -0
  26. {xax-0.2.19 → xax-0.2.21}/xax/nn/functions.py +0 -0
  27. {xax-0.2.19 → xax-0.2.21}/xax/nn/geom.py +0 -0
  28. {xax-0.2.19 → xax-0.2.21}/xax/nn/losses.py +0 -0
  29. {xax-0.2.19 → xax-0.2.21}/xax/nn/parallel.py +0 -0
  30. {xax-0.2.19 → xax-0.2.21}/xax/nn/ssm.py +0 -0
  31. {xax-0.2.19 → xax-0.2.21}/xax/py.typed +0 -0
  32. {xax-0.2.19 → xax-0.2.21}/xax/requirements-dev.txt +0 -0
  33. {xax-0.2.19 → xax-0.2.21}/xax/requirements.txt +0 -0
  34. {xax-0.2.19/xax/task/launchers → xax-0.2.21/xax/task}/__init__.py +0 -0
  35. {xax-0.2.19/xax/task/loggers → xax-0.2.21/xax/task/launchers}/__init__.py +0 -0
  36. {xax-0.2.19 → xax-0.2.21}/xax/task/launchers/base.py +0 -0
  37. {xax-0.2.19 → xax-0.2.21}/xax/task/launchers/cli.py +0 -0
  38. {xax-0.2.19 → xax-0.2.21}/xax/task/launchers/single_process.py +0 -0
  39. {xax-0.2.19 → xax-0.2.21}/xax/task/logger.py +0 -0
  40. {xax-0.2.19/xax/utils → xax-0.2.21/xax/task/loggers}/__init__.py +0 -0
  41. {xax-0.2.19 → xax-0.2.21}/xax/task/loggers/callback.py +0 -0
  42. {xax-0.2.19 → xax-0.2.21}/xax/task/loggers/json.py +0 -0
  43. {xax-0.2.19 → xax-0.2.21}/xax/task/loggers/state.py +0 -0
  44. {xax-0.2.19 → xax-0.2.21}/xax/task/loggers/stdout.py +0 -0
  45. {xax-0.2.19 → xax-0.2.21}/xax/task/loggers/tensorboard.py +0 -0
  46. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/__init__.py +0 -0
  47. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/artifacts.py +0 -0
  48. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/compile.py +0 -0
  49. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/cpu_stats.py +0 -0
  50. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/data_loader.py +0 -0
  51. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/gpu_stats.py +0 -0
  52. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/logger.py +0 -0
  53. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/process.py +0 -0
  54. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/runnable.py +0 -0
  55. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/step_wrapper.py +0 -0
  56. {xax-0.2.19 → xax-0.2.21}/xax/task/mixins/train.py +0 -0
  57. {xax-0.2.19 → xax-0.2.21}/xax/task/script.py +0 -0
  58. {xax-0.2.19 → xax-0.2.21}/xax/task/task.py +0 -0
  59. {xax-0.2.19/xax/utils/data → xax-0.2.21/xax/utils}/__init__.py +0 -0
  60. {xax-0.2.19/xax/utils/types → xax-0.2.21/xax/utils/data}/__init__.py +0 -0
  61. {xax-0.2.19 → xax-0.2.21}/xax/utils/data/collate.py +0 -0
  62. {xax-0.2.19 → xax-0.2.21}/xax/utils/debugging.py +0 -0
  63. {xax-0.2.19 → xax-0.2.21}/xax/utils/experiments.py +0 -0
  64. {xax-0.2.19 → xax-0.2.21}/xax/utils/jax.py +0 -0
  65. {xax-0.2.19 → xax-0.2.21}/xax/utils/jaxpr.py +0 -0
  66. {xax-0.2.19 → xax-0.2.21}/xax/utils/logging.py +0 -0
  67. {xax-0.2.19 → xax-0.2.21}/xax/utils/numpy.py +0 -0
  68. {xax-0.2.19 → xax-0.2.21}/xax/utils/profile.py +0 -0
  69. {xax-0.2.19 → xax-0.2.21}/xax/utils/pytree.py +0 -0
  70. {xax-0.2.19 → xax-0.2.21}/xax/utils/tensorboard.py +0 -0
  71. {xax-0.2.19 → xax-0.2.21}/xax/utils/text.py +0 -0
  72. {xax-0.2.19 → xax-0.2.21}/xax/utils/types/frozen_dict.py +0 -0
  73. {xax-0.2.19 → xax-0.2.21}/xax/utils/types/hashable_array.py +0 -0
  74. {xax-0.2.19 → xax-0.2.21}/xax.egg-info/dependency_links.txt +0 -0
  75. {xax-0.2.19 → xax-0.2.21}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.19
3
+ Version: 0.2.21
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,22 +31,6 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
- Provides-Extra: exportable
35
- Requires-Dist: flax; extra == "exportable"
36
- Requires-Dist: orbax-export; extra == "exportable"
37
- Requires-Dist: tensorflow; extra == "exportable"
38
- Provides-Extra: all
39
- Requires-Dist: black; extra == "all"
40
- Requires-Dist: darglint; extra == "all"
41
- Requires-Dist: mypy; extra == "all"
42
- Requires-Dist: ruff; extra == "all"
43
- Requires-Dist: pytest; extra == "all"
44
- Requires-Dist: types-pillow; extra == "all"
45
- Requires-Dist: types-psutil; extra == "all"
46
- Requires-Dist: types-requests; extra == "all"
47
- Requires-Dist: flax; extra == "all"
48
- Requires-Dist: orbax-export; extra == "all"
49
- Requires-Dist: tensorflow; extra == "all"
50
34
  Dynamic: author
51
35
  Dynamic: description
52
36
  Dynamic: description-content-type
@@ -39,7 +39,6 @@ module = [
39
39
  "setuptools.*",
40
40
  "tensorboard.*",
41
41
  "transformers.*",
42
- "orbax.export.*",
43
42
  ]
44
43
 
45
44
  ignore_missing_imports = true
@@ -14,12 +14,6 @@ with open("xax/requirements.txt", "r", encoding="utf-8") as f:
14
14
  with open("xax/requirements-dev.txt", "r", encoding="utf-8") as f:
15
15
  requirements_dev: list[str] = f.read().splitlines()
16
16
 
17
- requirements_export: list[str] = [
18
- "flax",
19
- "orbax-export",
20
- "tensorflow",
21
- ]
22
-
23
17
  with open("xax/__init__.py", "r", encoding="utf-8") as fh:
24
18
  version_re = re.search(r"^__version__ = \"([^\"]*)\"", fh.read(), re.MULTILINE)
25
19
  assert version_re is not None, "Could not find version in xax/__init__.py"
@@ -39,8 +33,6 @@ setup(
39
33
  tests_require=requirements_dev,
40
34
  extras_require={
41
35
  "dev": requirements_dev,
42
- "exportable": requirements_export,
43
- "all": requirements_dev + requirements_export,
44
36
  },
45
37
  package_data={
46
38
  "xax": [
@@ -48,4 +40,9 @@ setup(
48
40
  "requirements*.txt",
49
41
  ],
50
42
  },
43
+ entry_points={
44
+ "console_scripts": [
45
+ "xax-edit-config=xax.cli.edit_config:main",
46
+ ],
47
+ },
51
48
  )
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.19"
15
+ __version__ = "0.2.21"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -34,12 +34,6 @@ __all__ = [
34
34
  "get_positional_embeddings",
35
35
  "get_rotary_embeddings",
36
36
  "rotary_embeddings",
37
- "MLPHyperParams",
38
- "export_eqx_mlp",
39
- "load_eqx",
40
- "load_eqx_mlp",
41
- "make_eqx_mlp",
42
- "save_eqx",
43
37
  "cubic_bezier_interpolation",
44
38
  "euler_to_quat",
45
39
  "get_projected_gravity_vector_from_quat",
@@ -215,12 +209,6 @@ NAME_MAP: dict[str, str] = {
215
209
  "get_positional_embeddings": "nn.embeddings",
216
210
  "get_rotary_embeddings": "nn.embeddings",
217
211
  "rotary_embeddings": "nn.embeddings",
218
- "MLPHyperParams": "nn.equinox",
219
- "export_eqx_mlp": "nn.equinox",
220
- "load_eqx": "nn.equinox",
221
- "load_eqx_mlp": "nn.equinox",
222
- "make_eqx_mlp": "nn.equinox",
223
- "save_eqx": "nn.equinox",
224
212
  "cubic_bezier_interpolation": "nn.geom",
225
213
  "euler_to_quat": "nn.geom",
226
214
  "get_projected_gravity_vector_from_quat": "nn.geom",
@@ -392,16 +380,6 @@ if IMPORT_ALL or TYPE_CHECKING:
392
380
  get_rotary_embeddings,
393
381
  rotary_embeddings,
394
382
  )
395
- from xax.nn.equinox import (
396
- DTYPE,
397
- ActivationFunction,
398
- MLPHyperParams,
399
- export_eqx_mlp,
400
- load_eqx,
401
- load_eqx_mlp,
402
- make_eqx_mlp,
403
- save_eqx,
404
- )
405
383
  from xax.nn.geom import (
406
384
  cubic_bezier_interpolation,
407
385
  euler_to_quat,
@@ -0,0 +1,77 @@
1
+ """Lets you edit a checkpoint config programmatically."""
2
+
3
+ import argparse
4
+ import difflib
5
+ import io
6
+ import os
7
+ import subprocess
8
+ import tarfile
9
+ import tempfile
10
+ from pathlib import Path
11
+
12
+ from omegaconf import OmegaConf
13
+
14
+ from xax.task.mixins.checkpointing import load_ckpt
15
+ from xax.utils.text import colored, show_info
16
+
17
+
18
+ def main() -> None:
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("ckpt_path", type=Path)
21
+ args = parser.parse_args()
22
+
23
+ # Loads the config from the checkpoint.
24
+ config = load_ckpt(args.ckpt_path, part="config")
25
+ config_str = OmegaConf.to_yaml(config)
26
+
27
+ # Opens the user's preferred editor to edit the config.
28
+ with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
29
+ f.write(config_str.encode("utf-8"))
30
+ f.flush()
31
+ subprocess.run([os.environ.get("EDITOR", "vim"), f.name], check=True)
32
+
33
+ # Loads the edited config.
34
+ try:
35
+ edited_config = OmegaConf.load(f.name)
36
+ edited_config_str = OmegaConf.to_yaml(edited_config, sort_keys=True)
37
+ finally:
38
+ os.remove(f.name)
39
+
40
+ if edited_config_str == config_str:
41
+ show_info("No changes were made to the config.")
42
+ return
43
+
44
+ # Diffs the original and edited configs.
45
+ diff = difflib.ndiff(config_str.splitlines(), edited_config_str.splitlines())
46
+ for line in diff:
47
+ if line.startswith("+ "):
48
+ print(colored(line, "light-green"), flush=True)
49
+ elif line.startswith("- "):
50
+ print(colored(line, "light-red"), flush=True)
51
+ elif line.startswith("? "):
52
+ print(colored(line, "light-cyan"), flush=True)
53
+
54
+ # Saves the edited config to the checkpoint.
55
+ with tempfile.TemporaryDirectory() as tmp_dir:
56
+ with tarfile.open(args.ckpt_path, "r:gz") as src_tar:
57
+ for member in src_tar.getmembers():
58
+ if member.name != "config": # Skip the old config file
59
+ src_tar.extract(member, tmp_dir)
60
+
61
+ with tarfile.open(args.ckpt_path, "w:gz") as tar:
62
+ for root, _, files in os.walk(tmp_dir):
63
+ for file in files:
64
+ file_path = os.path.join(root, file)
65
+ arcname = os.path.relpath(file_path, tmp_dir)
66
+ tar.add(file_path, arcname=arcname)
67
+
68
+ # Add the new config file
69
+ info = tarfile.TarInfo(name="config")
70
+ config_bytes = edited_config_str.encode()
71
+ info.size = len(config_bytes)
72
+ tar.addfile(info, io.BytesIO(config_bytes))
73
+
74
+
75
+ if __name__ == "__main__":
76
+ # python -m xax.cli.edit_config
77
+ main()
@@ -7,8 +7,6 @@ import jax
7
7
  import jax.numpy as jnp
8
8
  from jaxtyping import Array
9
9
 
10
- from xax.utils.jax import jit as xax_jit
11
-
12
10
  NormType = Literal["l1", "l2"]
13
11
 
14
12
 
@@ -36,7 +34,6 @@ def dynamic_time_warping(distance_matrix_nm: Array) -> Array: ...
36
34
  def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[True]) -> tuple[Array, Array]: ...
37
35
 
38
36
 
39
- @xax_jit(static_argnames=["return_path"])
40
37
  def dynamic_time_warping(distance_matrix_nm: Array, return_path: bool = False) -> Array | tuple[Array, Array]:
41
38
  """Dynamic Time Warping.
42
39
 
@@ -210,7 +210,7 @@ class BaseTask(Generic[Config]):
210
210
 
211
211
  @classmethod
212
212
  def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
213
- return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli))
213
+ return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli), sort_keys=True)
214
214
 
215
215
  @classmethod
216
216
  def get_task(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Self:
@@ -292,7 +292,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
292
292
 
293
293
  if state is not None:
294
294
  add_file_bytes("state", json.dumps(state.to_dict(), indent=2).encode())
295
- add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
295
+ add_file_bytes("config", OmegaConf.to_yaml(self.config, sort_keys=True).encode())
296
296
 
297
297
  # Updates the symlink to the new checkpoint
298
298
  last_ckpt_path.unlink(missing_ok=True)
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.19
3
+ Version: 0.2.21
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,22 +31,6 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
- Provides-Extra: exportable
35
- Requires-Dist: flax; extra == "exportable"
36
- Requires-Dist: orbax-export; extra == "exportable"
37
- Requires-Dist: tensorflow; extra == "exportable"
38
- Provides-Extra: all
39
- Requires-Dist: black; extra == "all"
40
- Requires-Dist: darglint; extra == "all"
41
- Requires-Dist: mypy; extra == "all"
42
- Requires-Dist: ruff; extra == "all"
43
- Requires-Dist: pytest; extra == "all"
44
- Requires-Dist: types-pillow; extra == "all"
45
- Requires-Dist: types-psutil; extra == "all"
46
- Requires-Dist: types-requests; extra == "all"
47
- Requires-Dist: flax; extra == "all"
48
- Requires-Dist: orbax-export; extra == "all"
49
- Requires-Dist: tensorflow; extra == "all"
50
34
  Dynamic: author
51
35
  Dynamic: description
52
36
  Dynamic: description-content-type
@@ -11,15 +11,16 @@ xax/requirements.txt
11
11
  xax.egg-info/PKG-INFO
12
12
  xax.egg-info/SOURCES.txt
13
13
  xax.egg-info/dependency_links.txt
14
+ xax.egg-info/entry_points.txt
14
15
  xax.egg-info/requires.txt
15
16
  xax.egg-info/top_level.txt
17
+ xax/cli/__init__.py
18
+ xax/cli/edit_config.py
16
19
  xax/core/__init__.py
17
20
  xax/core/conf.py
18
21
  xax/core/state.py
19
22
  xax/nn/__init__.py
20
23
  xax/nn/embeddings.py
21
- xax/nn/equinox.py
22
- xax/nn/export.py
23
24
  xax/nn/functions.py
24
25
  xax/nn/geom.py
25
26
  xax/nn/losses.py
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ xax-edit-config = xax.cli.edit_config:main
@@ -14,19 +14,6 @@ tensorboard
14
14
  psutil
15
15
  requests
16
16
 
17
- [all]
18
- black
19
- darglint
20
- mypy
21
- ruff
22
- pytest
23
- types-pillow
24
- types-psutil
25
- types-requests
26
- flax
27
- orbax-export
28
- tensorflow
29
-
30
17
  [dev]
31
18
  black
32
19
  darglint
@@ -36,8 +23,3 @@ pytest
36
23
  types-pillow
37
24
  types-psutil
38
25
  types-requests
39
-
40
- [exportable]
41
- flax
42
- orbax-export
43
- tensorflow
@@ -1,183 +0,0 @@
1
- """Equinox utilities."""
2
-
3
- import json
4
- import logging
5
- from pathlib import Path
6
- from typing import Callable, Literal, TypedDict, cast
7
-
8
- import equinox as eqx
9
- import jax
10
- from jaxtyping import PRNGKeyArray
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
- ActivationFunction = Literal[
15
- "relu",
16
- "tanh",
17
- "celu",
18
- "elu",
19
- "gelu",
20
- "glu",
21
- "hard_sigmoid",
22
- "hard_silu",
23
- "hard_swish",
24
- "hard_tanh",
25
- "leaky_relu",
26
- "log_sigmoid",
27
- "log_softmax",
28
- "logsumexp",
29
- "relu6",
30
- "selu",
31
- "sigmoid",
32
- "soft_sign",
33
- "softmax",
34
- "softplus",
35
- "sparse_plus",
36
- "sparse_sigmoid",
37
- "silu",
38
- "swish",
39
- "squareplus",
40
- "mish",
41
- "identity",
42
- ]
43
-
44
- DTYPE = Literal["float32", "float64"]
45
-
46
- DTYPE_MAP: dict[DTYPE, jax.numpy.dtype] = {
47
- "float32": jax.numpy.float32,
48
- "float64": jax.numpy.float64,
49
- }
50
-
51
-
52
- class MLPHyperParams(TypedDict):
53
- """Hyperparameters of an Equinox MLP."""
54
-
55
- in_size: int | Literal["scalar"]
56
- out_size: int | Literal["scalar"]
57
- width_size: int
58
- depth: int
59
- activation: ActivationFunction
60
- final_activation: ActivationFunction
61
- use_bias: bool
62
- use_final_bias: bool
63
- dtype: DTYPE
64
-
65
-
66
- def _infer_activation(activation: ActivationFunction) -> Callable:
67
- if activation == "identity":
68
- return lambda x: x
69
- try:
70
- return getattr(jax.nn, activation)
71
- except AttributeError as err:
72
- raise ValueError(f"Activation function `{activation}` not found in `jax.nn`") from err
73
-
74
-
75
- def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
76
- """Create an Equinox MLP from a set of hyperparameters.
77
-
78
- Args:
79
- hyperparams: The hyperparameters of the MLP.
80
- key: The PRNG key to use for the MLP.
81
- """
82
- activation = _infer_activation(hyperparams["activation"])
83
- final_activation = _infer_activation(hyperparams["final_activation"])
84
- dtype = DTYPE_MAP[hyperparams["dtype"]]
85
-
86
- return eqx.nn.MLP(
87
- in_size=hyperparams["in_size"],
88
- out_size=hyperparams["out_size"],
89
- width_size=hyperparams["width_size"],
90
- depth=hyperparams["depth"],
91
- activation=activation,
92
- final_activation=final_activation,
93
- use_bias=hyperparams["use_bias"],
94
- use_final_bias=hyperparams["use_final_bias"],
95
- dtype=dtype,
96
- key=key,
97
- )
98
-
99
-
100
- def export_eqx_mlp(
101
- model: eqx.nn.MLP,
102
- output_path: str | Path,
103
- dtype: jax.numpy.dtype | None = None,
104
- ) -> None:
105
- """Serialize an Equinox MLP to a .eqx file.
106
-
107
- Args:
108
- model: The JAX MLP to export.
109
- output_path: The path to save the exported model.
110
- dtype: The dtype of the model.
111
- """
112
- if dtype is None:
113
- dtype = eqx._misc.default_floating_dtype()
114
-
115
- activation = model.activation.__name__
116
- final_activation = model.final_activation.__name__
117
-
118
- if final_activation == "<lambda>":
119
- logger.warning("Final activation is a lambda function. Assuming identity.")
120
- final_activation = "identity"
121
-
122
- # cast strings to ActivationFunction for type checking
123
- activation = cast(ActivationFunction, activation)
124
- final_activation = cast(ActivationFunction, final_activation)
125
-
126
- if dtype not in DTYPE_MAP.values():
127
- raise ValueError(f"Invalid dtype: {dtype}. Must be one of {DTYPE_MAP.values()}")
128
-
129
- dtype = {v: k for k, v in DTYPE_MAP.items()}[dtype]
130
-
131
- hyperparams: MLPHyperParams = {
132
- "in_size": model.in_size,
133
- "out_size": model.out_size,
134
- "width_size": model.width_size,
135
- "depth": model.depth,
136
- "activation": activation,
137
- "final_activation": final_activation,
138
- "use_bias": model.use_bias,
139
- "use_final_bias": model.use_final_bias,
140
- "dtype": dtype,
141
- }
142
-
143
- with open(output_path, "wb") as f:
144
- hyperparam_str = json.dumps(hyperparams)
145
- f.write((hyperparam_str + "\n").encode(encoding="utf-8"))
146
- eqx.tree_serialise_leaves(f, model)
147
-
148
-
149
- def save_eqx(
150
- model: eqx.Module,
151
- output_path: str | Path,
152
- ) -> None:
153
- """Serialize an Equinox module to a .eqx file.
154
-
155
- Args:
156
- model: The Equinox module to export.
157
- output_path: The path to save the exported model.
158
- """
159
- with open(output_path, "wb") as f:
160
- eqx.tree_serialise_leaves(f, model)
161
-
162
-
163
- def load_eqx(
164
- model: eqx.Module,
165
- eqx_file: str | Path,
166
- ) -> eqx.Module:
167
- """Deserialize an Equinox module from a .eqx file.
168
-
169
- Args:
170
- model: The Equinox module to load into.
171
- eqx_file: The path to the .eqx file to load.
172
- """
173
- with open(eqx_file, "rb") as f:
174
- return eqx.tree_deserialise_leaves(f, model)
175
-
176
-
177
- def load_eqx_mlp(
178
- eqx_file: str | Path,
179
- ) -> eqx.nn.MLP:
180
- with open(eqx_file, "rb") as f:
181
- hyperparams = json.loads(f.readline().decode(encoding="utf-8"))
182
- model = make_eqx_mlp(hyperparams=hyperparams, key=jax.random.PRNGKey(0))
183
- return eqx.tree_deserialise_leaves(f, model)
@@ -1,154 +0,0 @@
1
- """Export JAX functions to TensorFlow SavedModel format."""
2
-
3
- import logging
4
- from pathlib import Path
5
- from typing import Callable
6
-
7
- import jax
8
- from jaxtyping import Array, PyTree
9
-
10
- try:
11
- import flax
12
- import tensorflow as tf
13
- from jax.experimental import jax2tf
14
- from orbax.export import ExportManager, JaxModule, ServingConfig
15
- except ImportError as e:
16
- raise ImportError(
17
- "In order to export models, please install Xax with exportable dependencies, "
18
- "using 'xax[exportable]` to install the required dependencies."
19
- ) from e
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- def _run_infer(tf_module: tf.Module, input_shapes: list[tuple[int, ...]], batch_size: int | None) -> tf.Tensor:
25
- """Warm up the model by running it once."""
26
- if batch_size is not None:
27
- test_inputs = [
28
- jax.random.normal(jax.random.PRNGKey(42), (batch_size, *input_shape)) for input_shape in input_shapes
29
- ]
30
- else:
31
- test_inputs = [jax.random.normal(jax.random.PRNGKey(42), (1, *input_shape)) for input_shape in input_shapes]
32
- if not hasattr(tf_module, "infer"):
33
- raise ValueError("Model does not have an infer method")
34
- return tf_module.infer(*test_inputs)
35
-
36
-
37
- def export(
38
- model: Callable,
39
- input_shapes: list[tuple[int, ...]],
40
- output_dir: str | Path = "export",
41
- batch_size: int | None = None,
42
- ) -> None:
43
- """Export a JAX function to TensorFlow SavedModel.
44
-
45
- Note: Tensorflow GraphDef can't be larger than 2GB - https://github.com/tensorflow/tensorflow/issues/51870
46
- You can avoid this by saving model parameters as non-constants.
47
-
48
- Args:
49
- model: The JAX function to export.
50
- input_shapes: The shape of the input tensors, excluding batch dimension.
51
- output_dir: Directory to save the exported model.
52
- batch_size: Optional batch dimension. If None, a polymorphic batch dimension is used.
53
- """
54
- tf_module = tf.Module()
55
- # Create a polymorphic shape specification for each input
56
- poly_spec = "(b, ...)" if batch_size is not None else "(None, ...)"
57
- polymorphic_shapes = [poly_spec] * len(input_shapes)
58
- tf_module.infer = tf.function( # type: ignore [attr-defined]
59
- jax2tf.convert(
60
- model,
61
- polymorphic_shapes=polymorphic_shapes,
62
- # setting this to False will allow the model to run on platforms other than the one that exports the model
63
- # https://github.com/jax-ml/jax/blob/051687dc4c899df3d95c30b812ade401d8b31166/jax/experimental/jax2tf/README.md?plain=1#L1342
64
- # generally though I think native_serialization is recommended
65
- native_serialization=False,
66
- with_gradient=False,
67
- ),
68
- autograph=False,
69
- input_signature=[tf.TensorSpec([batch_size] + list(input_shape), tf.float32) for input_shape in input_shapes],
70
- )
71
-
72
- # warm up the model
73
- _run_infer(tf_module, input_shapes, batch_size)
74
-
75
- logger.info("Exporting SavedModel to %s", output_dir)
76
- tf.saved_model.save(
77
- tf_module,
78
- output_dir,
79
- )
80
-
81
-
82
- def export_with_params(
83
- model: Callable,
84
- params: PyTree,
85
- input_shapes: list[tuple[int, ...]],
86
- output_dir: str | Path = "export",
87
- batch_dim: int | None = None,
88
- ) -> None:
89
- """Export a JAX function that takes parameters to TensorFlow SavedModel.
90
-
91
- Args:
92
- model: The JAX function to export. Should take parameters as first argument.
93
- params: The parameters to use for the model.
94
- input_shapes: The shape of the input tensors, excluding batch dimension.
95
- output_dir: Directory to save the exported model.
96
- batch_dim: Optional batch dimension. If None, a polymorphic batch dimension is used.
97
- """
98
- param_vars = tf.nest.map_structure(tf.Variable, params)
99
-
100
- converted_model = jax2tf.convert(model)
101
-
102
- def model_fn(*inputs: PyTree) -> Array:
103
- return converted_model(param_vars, *inputs)
104
-
105
- tf_module = tf.Module()
106
- tf_module._variables = tf.nest.flatten(param_vars) # type: ignore [attr-defined]
107
- tf_module.infer = tf.function( # type: ignore [attr-defined]
108
- model_fn,
109
- jit_compile=True,
110
- autograph=False,
111
- input_signature=[tf.TensorSpec([batch_dim] + list(input_shape), tf.float32) for input_shape in input_shapes],
112
- )
113
-
114
- # warm up the model
115
- _run_infer(tf_module, input_shapes, batch_dim)
116
-
117
- logger.info("Exporting SavedModel to %s", output_dir)
118
- tf.saved_model.save(tf_module, output_dir)
119
-
120
-
121
- def export_flax(
122
- model: flax.linen.Module,
123
- params: PyTree,
124
- input_shape: tuple[int, ...],
125
- preprocessor: Callable | None = None,
126
- postprocessor: Callable | None = None,
127
- input_name: str = "inputs",
128
- output_name: str = "outputs",
129
- output_dir: str | Path = "export",
130
- ) -> None:
131
- jax_module = JaxModule(
132
- params, model.apply, trainable=False, input_polymorphic_shape="(b, ...)"
133
- ) # if you want to use a batch dimension
134
-
135
- # to avoid mapping sequences to ambiguous mappings
136
- if postprocessor is None:
137
-
138
- def postprocessor(x: PyTree) -> PyTree:
139
- return {output_name: x}
140
-
141
- export_manager = ExportManager(
142
- jax_module,
143
- [
144
- ServingConfig(
145
- "serving_default",
146
- input_signature=[tf.TensorSpec([None] + list(input_shape), tf.float32, name=input_name)],
147
- tf_preprocessor=preprocessor,
148
- tf_postprocessor=postprocessor,
149
- )
150
- ],
151
- )
152
-
153
- logger.info("Exporting model to %s", output_dir)
154
- export_manager.save(output_dir)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes