xax 0.0.7__tar.gz → 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. {xax-0.0.7/xax.egg-info → xax-0.1.0}/PKG-INFO +23 -4
  2. {xax-0.0.7 → xax-0.1.0}/pyproject.toml +1 -0
  3. {xax-0.0.7 → xax-0.1.0}/setup.py +14 -2
  4. {xax-0.0.7 → xax-0.1.0}/xax/__init__.py +94 -4
  5. xax-0.1.0/xax/nn/equinox.py +180 -0
  6. xax-0.1.0/xax/nn/export.py +147 -0
  7. {xax-0.0.7 → xax-0.1.0}/xax/nn/geom.py +26 -0
  8. xax-0.1.0/xax/nn/norm.py +23 -0
  9. {xax-0.0.7 → xax-0.1.0}/xax/requirements.txt +1 -0
  10. {xax-0.0.7 → xax-0.1.0}/xax/task/base.py +6 -0
  11. {xax-0.0.7 → xax-0.1.0}/xax/task/logger.py +97 -2
  12. {xax-0.0.7 → xax-0.1.0}/xax/task/loggers/stdout.py +2 -2
  13. {xax-0.0.7 → xax-0.1.0}/xax/task/loggers/tensorboard.py +25 -14
  14. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/artifacts.py +1 -21
  15. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/checkpointing.py +19 -5
  16. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/logger.py +28 -4
  17. xax-0.1.0/xax/task/mixins/step_wrapper.py +59 -0
  18. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/train.py +50 -34
  19. {xax-0.0.7 → xax-0.1.0}/xax/task/script.py +0 -4
  20. xax-0.1.0/xax/utils/debugging.py +49 -0
  21. {xax-0.0.7 → xax-0.1.0}/xax/utils/experiments.py +23 -4
  22. xax-0.1.0/xax/utils/jaxpr.py +77 -0
  23. xax-0.1.0/xax/utils/pytree.py +238 -0
  24. {xax-0.0.7 → xax-0.1.0}/xax/utils/tensorboard.py +177 -1
  25. {xax-0.0.7 → xax-0.1.0/xax.egg-info}/PKG-INFO +23 -4
  26. {xax-0.0.7 → xax-0.1.0}/xax.egg-info/SOURCES.txt +5 -0
  27. {xax-0.0.7 → xax-0.1.0}/xax.egg-info/requires.txt +21 -0
  28. xax-0.0.7/xax/task/mixins/step_wrapper.py +0 -68
  29. xax-0.0.7/xax/utils/pytree.py +0 -50
  30. {xax-0.0.7 → xax-0.1.0}/LICENSE +0 -0
  31. {xax-0.0.7 → xax-0.1.0}/MANIFEST.in +0 -0
  32. {xax-0.0.7 → xax-0.1.0}/README.md +0 -0
  33. {xax-0.0.7 → xax-0.1.0}/setup.cfg +0 -0
  34. {xax-0.0.7 → xax-0.1.0}/xax/core/__init__.py +0 -0
  35. {xax-0.0.7 → xax-0.1.0}/xax/core/conf.py +0 -0
  36. {xax-0.0.7 → xax-0.1.0}/xax/core/state.py +0 -0
  37. {xax-0.0.7 → xax-0.1.0}/xax/nn/__init__.py +0 -0
  38. {xax-0.0.7 → xax-0.1.0}/xax/nn/embeddings.py +0 -0
  39. {xax-0.0.7 → xax-0.1.0}/xax/nn/functions.py +0 -0
  40. {xax-0.0.7 → xax-0.1.0}/xax/nn/parallel.py +0 -0
  41. {xax-0.0.7 → xax-0.1.0}/xax/py.typed +0 -0
  42. {xax-0.0.7 → xax-0.1.0}/xax/requirements-dev.txt +0 -0
  43. {xax-0.0.7 → xax-0.1.0}/xax/task/__init__.py +0 -0
  44. {xax-0.0.7 → xax-0.1.0}/xax/task/launchers/__init__.py +0 -0
  45. {xax-0.0.7 → xax-0.1.0}/xax/task/launchers/base.py +0 -0
  46. {xax-0.0.7 → xax-0.1.0}/xax/task/launchers/cli.py +0 -0
  47. {xax-0.0.7 → xax-0.1.0}/xax/task/launchers/single_process.py +0 -0
  48. {xax-0.0.7 → xax-0.1.0}/xax/task/loggers/__init__.py +0 -0
  49. {xax-0.0.7 → xax-0.1.0}/xax/task/loggers/callback.py +0 -0
  50. {xax-0.0.7 → xax-0.1.0}/xax/task/loggers/json.py +0 -0
  51. {xax-0.0.7 → xax-0.1.0}/xax/task/loggers/state.py +0 -0
  52. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/__init__.py +0 -0
  53. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/compile.py +0 -0
  54. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/cpu_stats.py +0 -0
  55. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/data_loader.py +0 -0
  56. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/gpu_stats.py +0 -0
  57. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/process.py +0 -0
  58. {xax-0.0.7 → xax-0.1.0}/xax/task/mixins/runnable.py +0 -0
  59. {xax-0.0.7 → xax-0.1.0}/xax/task/task.py +0 -0
  60. {xax-0.0.7 → xax-0.1.0}/xax/utils/__init__.py +0 -0
  61. {xax-0.0.7 → xax-0.1.0}/xax/utils/data/__init__.py +0 -0
  62. {xax-0.0.7 → xax-0.1.0}/xax/utils/data/collate.py +0 -0
  63. {xax-0.0.7 → xax-0.1.0}/xax/utils/jax.py +0 -0
  64. {xax-0.0.7 → xax-0.1.0}/xax/utils/logging.py +0 -0
  65. {xax-0.0.7 → xax-0.1.0}/xax/utils/numpy.py +0 -0
  66. {xax-0.0.7 → xax-0.1.0}/xax/utils/profile.py +0 -0
  67. {xax-0.0.7 → xax-0.1.0}/xax/utils/text.py +0 -0
  68. {xax-0.0.7 → xax-0.1.0}/xax.egg-info/dependency_links.txt +0 -0
  69. {xax-0.0.7 → xax-0.1.0}/xax.egg-info/top_level.txt +0 -0
@@ -1,12 +1,13 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: xax
- Version: 0.0.7
- Summary: The xax project
- Home-page: https://github.com/dpshai/xax
+ Version: 0.1.0
+ Summary: A library for fast Jax experimentation
+ Home-page: https://github.com/kscalelabs/xax
  Author: Benjamin Bolte
  Requires-Python: >=3.11
  Description-Content-Type: text/markdown
  License-File: LICENSE
+ Requires-Dist: attrs
  Requires-Dist: jax
  Requires-Dist: jaxtyping
  Requires-Dist: equinox
@@ -30,10 +31,28 @@ Requires-Dist: pytest; extra == "dev"
  Requires-Dist: types-pillow; extra == "dev"
  Requires-Dist: types-psutil; extra == "dev"
  Requires-Dist: types-requests; extra == "dev"
+ Provides-Extra: export
+ Requires-Dist: orbax-export; extra == "export"
+ Requires-Dist: tensorflow; extra == "export"
+ Provides-Extra: flax
+ Requires-Dist: flax; extra == "flax"
+ Provides-Extra: all
+ Requires-Dist: black; extra == "all"
+ Requires-Dist: darglint; extra == "all"
+ Requires-Dist: mypy; extra == "all"
+ Requires-Dist: ruff; extra == "all"
+ Requires-Dist: pytest; extra == "all"
+ Requires-Dist: types-pillow; extra == "all"
+ Requires-Dist: types-psutil; extra == "all"
+ Requires-Dist: types-requests; extra == "all"
+ Requires-Dist: orbax-export; extra == "all"
+ Requires-Dist: tensorflow; extra == "all"
+ Requires-Dist: flax; extra == "all"
  Dynamic: author
  Dynamic: description
  Dynamic: description-content-type
  Dynamic: home-page
+ Dynamic: license-file
  Dynamic: provides-extra
  Dynamic: requires-dist
  Dynamic: requires-python
@@ -40,6 +40,7 @@ module = [
      "setuptools.*",
      "tensorboard.*",
      "transformers.*",
+     "orbax.export.*",
  ]

  ignore_missing_imports = true
@@ -14,6 +14,15 @@ with open("xax/requirements.txt", "r", encoding="utf-8") as f:
  with open("xax/requirements-dev.txt", "r", encoding="utf-8") as f:
      requirements_dev: list[str] = f.read().splitlines()

+ requirements_export: list[str] = [
+     "orbax-export",
+     "tensorflow",
+ ]
+
+ requirements_flax: list[str] = [
+     "flax",
+ ]
+
  with open("xax/__init__.py", "r", encoding="utf-8") as fh:
      version_re = re.search(r"^__version__ = \"([^\"]*)\"", fh.read(), re.MULTILINE)
  assert version_re is not None, "Could not find version in xax/__init__.py"
@@ -23,9 +32,9 @@ version: str = version_re.group(1)
  setup(
      name="xax",
      version=version,
-     description="The xax project",
+     description="A library for fast Jax experimentation",
      author="Benjamin Bolte",
-     url="https://github.com/dpshai/xax",
+     url="https://github.com/kscalelabs/xax",
      long_description=long_description,
      long_description_content_type="text/markdown",
      python_requires=">=3.11",
@@ -33,6 +42,9 @@ setup(
      tests_require=requirements_dev,
      extras_require={
          "dev": requirements_dev,
+         "export": requirements_export,
+         "flax": requirements_flax,
+         "all": requirements_dev + requirements_export + requirements_flax,
      },
      package_data={
          "xax": [
@@ -4,14 +4,15 @@ This package is structured so that all the important stuff can be accessed
  without having to dig around through the internals. This is done by lazily
  importing the module by name.

- This file can be maintained by running the update script:
+ This file can be maintained by updating the imports at the bottom of the file
+ and running the update script:

  .. code-block:: bash

      python -m scripts.update_api --inplace
  """

- __version__ = "0.0.7"
+ __version__ = "0.1.0"

  # This list shouldn't be modified by hand; instead, run the update script.
  __all__ = [
@@ -34,8 +35,20 @@ __all__ = [
      "get_positional_embeddings",
      "get_rotary_embeddings",
      "rotary_embeddings",
+     "MLPHyperParams",
+     "export_eqx_mlp",
+     "load_eqx",
+     "load_eqx_mlp",
+     "make_eqx_mlp",
+     "save_eqx",
+     "export",
+     "export_flax",
+     "export_with_params",
      "euler_to_quat",
+     "get_projected_gravity_vector_from_quat",
      "quat_to_euler",
+     "cast_norm_type",
+     "get_norm",
      "is_master",
      "BaseLauncher",
      "CliLauncher",
@@ -52,13 +65,16 @@ __all__ = [
      "CPUStatsOptions",
      "DataloaderConfig",
      "GPUStatsOptions",
+     "StepContext",
      "Script",
      "ScriptConfig",
      "Config",
      "Task",
      "collate",
      "collate_non_null",
+     "get_named_leaves",
      "BaseFileDownloader",
+     "ContextTimer",
      "CumulativeTimer",
      "DataDownloader",
      "IntervalTicker",
@@ -81,6 +97,7 @@ __all__ = [
      "stage_environment",
      "to_markdown_table",
      "jit",
+     "save_jaxpr_dot",
      "ColoredFormatter",
      "configure_logging",
      "one_hot",
@@ -90,8 +107,13 @@ __all__ = [
      "compute_nan_ratio",
      "flatten_array",
      "flatten_pytree",
+     "pytree_has_nans",
+     "reshuffle_pytree",
+     "reshuffle_pytree_along_dims",
+     "reshuffle_pytree_independently",
      "slice_array",
      "slice_pytree",
+     "update_pytree",
      "TextBlock",
      "camelcase_to_snakecase",
      "colored",
@@ -113,21 +135,36 @@ __all__ += [
      "Batch",
      "CollateMode",
      "EmbeddingKind",
+     "ActivationFunction",
+     "DTYPE",
      "LOG_ERROR_SUMMARY",
      "LOG_PING",
      "LOG_STATUS",
+     "NormType",
      "Output",
      "Phase",
      "RawConfigType",
  ]

  import os
+ import shutil
  from typing import TYPE_CHECKING

+ # Sets some useful XLA flags.
+ xla_flags: list[str] = []
+ if "XLA_FLAGS" in os.environ:
+     xla_flags.append(os.environ["XLA_FLAGS"])
+
+ # If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
+ # Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
+ if shutil.which("nvidia-smi") is not None:
+     xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler", "--xla_gpu_enable_triton_gemm"]
+ os.environ["XLA_FLAGS"] = " ".join(xla_flags)
+
  # If this flag is set, eagerly imports the entire package (not recommended).
  IMPORT_ALL = int(os.environ.get("XAX_IMPORT_ALL", "0")) != 0

- del os
+ del os, shutil, xla_flags

  # This dictionary is auto-generated and shouldn't be modified by hand; instead,
  # run the update script.
@@ -151,8 +188,20 @@ NAME_MAP: dict[str, str] = {
      "get_positional_embeddings": "nn.embeddings",
      "get_rotary_embeddings": "nn.embeddings",
      "rotary_embeddings": "nn.embeddings",
+     "MLPHyperParams": "nn.equinox",
+     "export_eqx_mlp": "nn.equinox",
+     "load_eqx": "nn.equinox",
+     "load_eqx_mlp": "nn.equinox",
+     "make_eqx_mlp": "nn.equinox",
+     "save_eqx": "nn.equinox",
+     "export": "nn.export",
+     "export_flax": "nn.export",
+     "export_with_params": "nn.export",
      "euler_to_quat": "nn.geom",
+     "get_projected_gravity_vector_from_quat": "nn.geom",
      "quat_to_euler": "nn.geom",
+     "cast_norm_type": "nn.norm",
+     "get_norm": "nn.norm",
      "is_master": "nn.parallel",
      "BaseLauncher": "task.launchers.base",
      "CliLauncher": "task.launchers.cli",
@@ -169,13 +218,16 @@ NAME_MAP: dict[str, str] = {
      "CPUStatsOptions": "task.mixins.cpu_stats",
      "DataloaderConfig": "task.mixins.data_loader",
      "GPUStatsOptions": "task.mixins.gpu_stats",
+     "StepContext": "task.mixins.step_wrapper",
      "Script": "task.script",
      "ScriptConfig": "task.script",
      "Config": "task.task",
      "Task": "task.task",
      "collate": "utils.data.collate",
      "collate_non_null": "utils.data.collate",
+     "get_named_leaves": "utils.debugging",
      "BaseFileDownloader": "utils.experiments",
+     "ContextTimer": "utils.experiments",
      "CumulativeTimer": "utils.experiments",
      "DataDownloader": "utils.experiments",
      "IntervalTicker": "utils.experiments",
@@ -198,6 +250,7 @@ NAME_MAP: dict[str, str] = {
      "stage_environment": "utils.experiments",
      "to_markdown_table": "utils.experiments",
      "jit": "utils.jax",
+     "save_jaxpr_dot": "utils.jaxpr",
      "ColoredFormatter": "utils.logging",
      "configure_logging": "utils.logging",
      "one_hot": "utils.numpy",
@@ -207,8 +260,13 @@ NAME_MAP: dict[str, str] = {
      "compute_nan_ratio": "utils.pytree",
      "flatten_array": "utils.pytree",
      "flatten_pytree": "utils.pytree",
+     "pytree_has_nans": "utils.pytree",
+     "reshuffle_pytree": "utils.pytree",
+     "reshuffle_pytree_along_dims": "utils.pytree",
+     "reshuffle_pytree_independently": "utils.pytree",
      "slice_array": "utils.pytree",
      "slice_pytree": "utils.pytree",
+     "update_pytree": "utils.pytree",
      "TextBlock": "utils.text",
      "camelcase_to_snakecase": "utils.text",
      "colored": "utils.text",
@@ -235,9 +293,12 @@ NAME_MAP.update(
          "LOG_ERROR_SUMMARY": "utils.logging",
          "LOG_PING": "utils.logging",
          "LOG_STATUS": "utils.logging",
+         "NormType": "nn.norm",
          "Output": "task.mixins.output",
          "Phase": "core.state",
          "RawConfigType": "task.base",
+         "ActivationFunction": "nn.equinox",
+         "DTYPE": "nn.equinox",
      },
  )

@@ -275,7 +336,27 @@ if IMPORT_ALL or TYPE_CHECKING:
          get_rotary_embeddings,
          rotary_embeddings,
      )
-     from xax.nn.geom import euler_to_quat, quat_to_euler
+     from xax.nn.equinox import (
+         DTYPE,
+         ActivationFunction,
+         MLPHyperParams,
+         export_eqx_mlp,
+         load_eqx,
+         load_eqx_mlp,
+         make_eqx_mlp,
+         save_eqx,
+     )
+     from xax.nn.export import (
+         export,
+         export_flax,
+         export_with_params,
+     )
+     from xax.nn.geom import (
+         euler_to_quat,
+         get_projected_gravity_vector_from_quat,
+         quat_to_euler,
+     )
+     from xax.nn.norm import NormType, cast_norm_type, get_norm
      from xax.nn.parallel import is_master
      from xax.task.base import RawConfigType
      from xax.task.launchers.base import BaseLauncher
@@ -290,12 +371,15 @@ if IMPORT_ALL or TYPE_CHECKING:
      from xax.task.mixins.cpu_stats import CPUStatsOptions
      from xax.task.mixins.data_loader import DataloaderConfig
      from xax.task.mixins.gpu_stats import GPUStatsOptions
+     from xax.task.mixins.step_wrapper import StepContext
      from xax.task.mixins.train import Batch, Output
      from xax.task.script import Script, ScriptConfig
      from xax.task.task import Config, Task
      from xax.utils.data.collate import CollateMode, collate, collate_non_null
+     from xax.utils.debugging import get_named_leaves
      from xax.utils.experiments import (
          BaseFileDownloader,
+         ContextTimer,
          CumulativeTimer,
          DataDownloader,
          IntervalTicker,
@@ -319,6 +403,7 @@ if IMPORT_ALL or TYPE_CHECKING:
          to_markdown_table,
      )
      from xax.utils.jax import jit
+     from xax.utils.jaxpr import save_jaxpr_dot
      from xax.utils.logging import (
          LOG_ERROR_SUMMARY,
          LOG_PING,
@@ -332,8 +417,13 @@ if IMPORT_ALL or TYPE_CHECKING:
          compute_nan_ratio,
          flatten_array,
          flatten_pytree,
+         pytree_has_nans,
+         reshuffle_pytree,
+         reshuffle_pytree_along_dims,
+         reshuffle_pytree_independently,
          slice_array,
          slice_pytree,
+         update_pytree,
      )
      from xax.utils.text import (
          TextBlock,
@@ -0,0 +1,180 @@
+ """Equinox utilities."""
+
+ import json
+ import logging
+ from pathlib import Path
+ from typing import Callable, Literal, TypedDict, cast
+
+ import equinox as eqx
+ import jax
+ from jaxtyping import PRNGKeyArray
+
+ logger = logging.getLogger(__name__)
+
+ ActivationFunction = Literal[
+     "relu",
+     "tanh",
+     "celu",
+     "elu",
+     "gelu",
+     "glu",
+     "hard_sigmoid",
+     "hard_silu",
+     "hard_swish",
+     "hard_tanh",
+     "leaky_relu",
+     "log_sigmoid",
+     "log_softmax",
+     "logsumexp",
+     "relu6",
+     "selu",
+     "sigmoid",
+     "soft_sign",
+     "softmax",
+     "softplus",
+     "sparse_plus",
+     "sparse_sigmoid",
+     "silu",
+     "swish",
+     "squareplus",
+     "mish",
+     "identity",
+ ]
+
+ DTYPE = Literal["float32", "float64"]
+
+ DTYPE_MAP: dict[DTYPE, jax.numpy.dtype] = {
+     "float32": jax.numpy.float32,
+     "float64": jax.numpy.float64,
+ }
+
+
+ class MLPHyperParams(TypedDict):
+     """Hyperparameters of an Equinox MLP."""
+
+     in_size: int | Literal["scalar"]
+     out_size: int | Literal["scalar"]
+     width_size: int
+     depth: int
+     activation: ActivationFunction
+     final_activation: ActivationFunction
+     use_bias: bool
+     use_final_bias: bool
+     dtype: DTYPE
+
+
+ def _infer_activation(activation: ActivationFunction) -> Callable:
+     if activation == "identity":
+         return lambda x: x
+     try:
+         return getattr(jax.nn, activation)
+     except AttributeError:
+         raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
+
+
+ def make_eqx_mlp(hyperparams: MLPHyperParams, key: PRNGKeyArray = jax.random.PRNGKey(0)) -> eqx.nn.MLP:
+     """Create an Equinox MLP from a set of hyperparameters.
+
+     Args:
+         hyperparams: The hyperparameters of the MLP.
+         key: The PRNG key to use for the MLP.
+     """
+     activation = _infer_activation(hyperparams["activation"])
+     final_activation = _infer_activation(hyperparams["final_activation"])
+     dtype = DTYPE_MAP[hyperparams["dtype"]]
+
+     return eqx.nn.MLP(
+         in_size=hyperparams["in_size"],
+         out_size=hyperparams["out_size"],
+         width_size=hyperparams["width_size"],
+         depth=hyperparams["depth"],
+         activation=activation,
+         final_activation=final_activation,
+         use_bias=hyperparams["use_bias"],
+         use_final_bias=hyperparams["use_final_bias"],
+         dtype=dtype,
+         key=key,
+     )
+
+
+ def export_eqx_mlp(
+     model: eqx.nn.MLP,
+     output_path: str | Path,
+     dtype: jax.numpy.dtype = eqx._misc.default_floating_dtype(),
+ ) -> None:
+     """Serialize an Equinox MLP to a .eqx file.
+
+     Args:
+         model: The JAX MLP to export.
+         output_path: The path to save the exported model.
+         dtype: The dtype of the model.
+     """
+     activation = model.activation.__name__
+     final_activation = model.final_activation.__name__
+
+     if final_activation == "<lambda>":
+         logger.warning("Final activation is a lambda function. Assuming identity.")
+         final_activation = "identity"
+
+     # cast strings to ActivationFunction for type checking
+     activation = cast(ActivationFunction, activation)
+     final_activation = cast(ActivationFunction, final_activation)
+
+     if dtype not in DTYPE_MAP.values():
+         raise ValueError(f"Invalid dtype: {dtype}. Must be one of {DTYPE_MAP.values()}")
+
+     dtype = {v: k for k, v in DTYPE_MAP.items()}[dtype]
+
+     hyperparams: MLPHyperParams = {
+         "in_size": model.in_size,
+         "out_size": model.out_size,
+         "width_size": model.width_size,
+         "depth": model.depth,
+         "activation": activation,
+         "final_activation": final_activation,
+         "use_bias": model.use_bias,
+         "use_final_bias": model.use_final_bias,
+         "dtype": dtype,
+     }
+
+     with open(output_path, "wb") as f:
+         hyperparam_str = json.dumps(hyperparams)
+         f.write((hyperparam_str + "\n").encode(encoding="utf-8"))
+         eqx.tree_serialise_leaves(f, model)
+
+
+ def save_eqx(
+     model: eqx.Module,
+     output_path: str | Path,
+ ) -> None:
+     """Serialize an Equinox module to a .eqx file.
+
+     Args:
+         model: The Equinox module to export.
+         output_path: The path to save the exported model.
+     """
+     with open(output_path, "wb") as f:
+         eqx.tree_serialise_leaves(f, model)
+
+
+ def load_eqx(
+     model: eqx.Module,
+     eqx_file: str | Path,
+ ) -> eqx.Module:
+     """Deserialize an Equinox module from a .eqx file.
+
+     Args:
+         model: The Equinox module to load into.
+         eqx_file: The path to the .eqx file to load.
+     """
+     with open(eqx_file, "rb") as f:
+         return eqx.tree_deserialise_leaves(f, model)
+
+
+ def load_eqx_mlp(
+     eqx_file: str | Path,
+ ) -> eqx.nn.MLP:
+     with open(eqx_file, "rb") as f:
+         hyperparams = json.loads(f.readline().decode(encoding="utf-8"))
+         model = make_eqx_mlp(hyperparams=hyperparams)
+         return eqx.tree_deserialise_leaves(f, model)
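
Taken together, the new `xax.nn.equinox` module gives a JSON-header-plus-weights round trip for Equinox MLPs. A minimal usage sketch against the signatures shown in this diff (the file name is illustrative):

```python
import jax

from xax.nn.equinox import MLPHyperParams, export_eqx_mlp, load_eqx_mlp, make_eqx_mlp

hyperparams: MLPHyperParams = {
    "in_size": 3,
    "out_size": 2,
    "width_size": 64,
    "depth": 2,
    "activation": "relu",
    "final_activation": "identity",
    "use_bias": True,
    "use_final_bias": True,
    "dtype": "float32",
}

# Build the MLP, write the hyperparameter header plus serialized weights,
# then round-trip the whole thing back from disk.
mlp = make_eqx_mlp(hyperparams, key=jax.random.PRNGKey(0))
export_eqx_mlp(mlp, "mlp.eqx")
restored = load_eqx_mlp("mlp.eqx")
```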
@@ -0,0 +1,147 @@
+ """Export JAX functions to TensorFlow SavedModel format."""
+
+ import logging
+ from pathlib import Path
+ from typing import Callable
+
+ import flax
+ import jax
+ import tensorflow as tf
+ from jax.experimental import jax2tf
+ from jaxtyping import Array, PyTree
+ from orbax.export import ExportManager, JaxModule, ServingConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ def _run_infer(tf_module: tf.Module, input_shapes: list[tuple[int, ...]], batch_size: int | None) -> tf.Tensor:
+     """Warm up the model by running it once."""
+     if batch_size is not None:
+         test_inputs = [
+             jax.random.normal(jax.random.PRNGKey(42), (batch_size, *input_shape)) for input_shape in input_shapes
+         ]
+     else:
+         test_inputs = [jax.random.normal(jax.random.PRNGKey(42), (1, *input_shape)) for input_shape in input_shapes]
+     if not hasattr(tf_module, "infer"):
+         raise ValueError("Model does not have an infer method")
+     return tf_module.infer(*test_inputs)
+
+
+ def export(
+     model: Callable,
+     input_shapes: list[tuple[int, ...]],
+     output_dir: str | Path = "export",
+     batch_size: int | None = None,
+ ) -> None:
+     """Export a JAX function to TensorFlow SavedModel.
+
+     Note: Tensorflow GraphDef can't be larger than 2GB - https://github.com/tensorflow/tensorflow/issues/51870
+     You can avoid this by saving model parameters as non-constants.
+
+     Args:
+         model: The JAX function to export.
+         input_shapes: The shape of the input tensors, excluding batch dimension.
+         output_dir: Directory to save the exported model.
+         batch_size: Optional batch dimension. If None, a polymorphic batch dimension is used.
+     """
+     tf_module = tf.Module()
+     # Create a polymorphic shape specification for each input
+     poly_spec = "(b, ...)" if batch_size is not None else "(None, ...)"
+     polymorphic_shapes = [poly_spec] * len(input_shapes)
+     tf_module.infer = tf.function(  # type: ignore [attr-defined]
+         jax2tf.convert(
+             model,
+             polymorphic_shapes=polymorphic_shapes,
+             # setting this to False will allow the model to run on platforms other than the one that exports the model
+             # https://github.com/jax-ml/jax/blob/051687dc4c899df3d95c30b812ade401d8b31166/jax/experimental/jax2tf/README.md?plain=1#L1342
+             # generally though I think native_serialization is recommended
+             native_serialization=False,
+             with_gradient=False,
+         ),
+         autograph=False,
+         input_signature=[tf.TensorSpec([batch_size] + list(input_shape), tf.float32) for input_shape in input_shapes],
+     )
+
+     # warm up the model
+     _run_infer(tf_module, input_shapes, batch_size)
+
+     logger.info("Exporting SavedModel to %s", output_dir)
+     tf.saved_model.save(
+         tf_module,
+         output_dir,
+     )
+
+
+ def export_with_params(
+     model: Callable,
+     params: PyTree,
+     input_shapes: list[tuple[int, ...]],
+     output_dir: str | Path = "export",
+     batch_dim: int | None = None,
+ ) -> None:
+     """Export a JAX function that takes parameters to TensorFlow SavedModel.
+
+     Args:
+         model: The JAX function to export. Should take parameters as first argument.
+         params: The parameters to use for the model.
+         input_shapes: The shape of the input tensors, excluding batch dimension.
+         output_dir: Directory to save the exported model.
+         batch_dim: Optional batch dimension. If None, a polymorphic batch dimension is used.
+     """
+     param_vars = tf.nest.map_structure(tf.Variable, params)
+
+     converted_model = jax2tf.convert(model)
+
+     def model_fn(*inputs: PyTree) -> Array:
+         return converted_model(param_vars, *inputs)
+
+     tf_module = tf.Module()
+     tf_module._variables = tf.nest.flatten(param_vars)  # type: ignore [attr-defined]
+     tf_module.infer = tf.function(  # type: ignore [attr-defined]
+         model_fn,
+         jit_compile=True,
+         autograph=False,
+         input_signature=[tf.TensorSpec([batch_dim] + list(input_shape), tf.float32) for input_shape in input_shapes],
+     )
+
+     # warm up the model
+     _run_infer(tf_module, input_shapes, batch_dim)
+
+     logger.info("Exporting SavedModel to %s", output_dir)
+     tf.saved_model.save(tf_module, output_dir)
+
+
+ def export_flax(
+     model: flax.linen.Module,
+     params: PyTree,
+     input_shape: tuple[int, ...],
+     preprocessor: Callable | None = None,
+     postprocessor: Callable | None = None,
+     input_name: str = "inputs",
+     output_name: str = "outputs",
+     output_dir: str | Path = "export",
+ ) -> None:
+     jax_module = JaxModule(
+         params, model.apply, trainable=False, input_polymorphic_shape="(b, ...)"
+     )  # if you want to use a batch dimension
+
+     # to avoid mapping sequences to ambiguous mappings
+     if postprocessor is None:
+
+         def postprocessor(x: PyTree) -> PyTree:
+             return {output_name: x}
+
+     export_manager = ExportManager(
+         jax_module,
+         [
+             ServingConfig(
+                 "serving_default",
+                 input_signature=[tf.TensorSpec([None] + list(input_shape), tf.float32, name=input_name)],
+                 tf_preprocessor=preprocessor,
+                 tf_postprocessor=postprocessor,
+             )
+         ],
+     )
+
+     logger.info("Exporting model to %s", output_dir)
+     export_manager.save(output_dir)
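
As a usage sketch, the new `export` helper accepts any JAX callable whose inputs are float32 tensors; the model and shapes below are illustrative stand-ins, not part of the package:

```python
import jax.numpy as jnp

from xax.nn.export import export


def model(x: jnp.ndarray) -> jnp.ndarray:
    # Stand-in for a real model: any pure JAX function of its inputs works.
    return jnp.tanh(x) @ jnp.ones((16, 4))


# One input of shape (batch, 16); passing batch_size=8 fixes the batch
# dimension, while batch_size=None requests a polymorphic batch instead.
export(model, input_shapes=[(16,)], output_dir="export", batch_size=8)
```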