xax 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +121 -3
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +101 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +97 -2
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -34
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jax.py +126 -0
- xax/utils/jaxpr.py +77 -0
- xax/utils/profile.py +61 -0
- xax/utils/pytree.py +238 -0
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/METADATA +23 -4
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/RECORD +28 -20
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/WHEEL +1 -1
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -4,14 +4,15 @@ This package is structured so that all the important stuff can be accessed
|
|
4
4
|
without having to dig around through the internals. This is done by lazily
|
5
5
|
importing the module by name.
|
6
6
|
|
7
|
-
This file can be maintained by
|
7
|
+
This file can be maintained by updating the imports at the bottom of the file
|
8
|
+
and running the update script:
|
8
9
|
|
9
10
|
.. code-block:: bash
|
10
11
|
|
11
12
|
python -m scripts.update_api --inplace
|
12
13
|
"""
|
13
14
|
|
14
|
-
__version__ = "0.0
|
15
|
+
__version__ = "0.1.0"
|
15
16
|
|
16
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
17
18
|
__all__ = [
|
@@ -34,6 +35,20 @@ __all__ = [
|
|
34
35
|
"get_positional_embeddings",
|
35
36
|
"get_rotary_embeddings",
|
36
37
|
"rotary_embeddings",
|
38
|
+
"MLPHyperParams",
|
39
|
+
"export_eqx_mlp",
|
40
|
+
"load_eqx",
|
41
|
+
"load_eqx_mlp",
|
42
|
+
"make_eqx_mlp",
|
43
|
+
"save_eqx",
|
44
|
+
"export",
|
45
|
+
"export_flax",
|
46
|
+
"export_with_params",
|
47
|
+
"euler_to_quat",
|
48
|
+
"get_projected_gravity_vector_from_quat",
|
49
|
+
"quat_to_euler",
|
50
|
+
"cast_norm_type",
|
51
|
+
"get_norm",
|
37
52
|
"is_master",
|
38
53
|
"BaseLauncher",
|
39
54
|
"CliLauncher",
|
@@ -50,13 +65,16 @@ __all__ = [
|
|
50
65
|
"CPUStatsOptions",
|
51
66
|
"DataloaderConfig",
|
52
67
|
"GPUStatsOptions",
|
68
|
+
"StepContext",
|
53
69
|
"Script",
|
54
70
|
"ScriptConfig",
|
55
71
|
"Config",
|
56
72
|
"Task",
|
57
73
|
"collate",
|
58
74
|
"collate_non_null",
|
75
|
+
"get_named_leaves",
|
59
76
|
"BaseFileDownloader",
|
77
|
+
"ContextTimer",
|
60
78
|
"CumulativeTimer",
|
61
79
|
"DataDownloader",
|
62
80
|
"IntervalTicker",
|
@@ -78,11 +96,24 @@ __all__ = [
|
|
78
96
|
"save_config",
|
79
97
|
"stage_environment",
|
80
98
|
"to_markdown_table",
|
99
|
+
"jit",
|
100
|
+
"save_jaxpr_dot",
|
81
101
|
"ColoredFormatter",
|
82
102
|
"configure_logging",
|
83
103
|
"one_hot",
|
84
104
|
"partial_flatten",
|
85
105
|
"worker_chunk",
|
106
|
+
"profile",
|
107
|
+
"compute_nan_ratio",
|
108
|
+
"flatten_array",
|
109
|
+
"flatten_pytree",
|
110
|
+
"pytree_has_nans",
|
111
|
+
"reshuffle_pytree",
|
112
|
+
"reshuffle_pytree_along_dims",
|
113
|
+
"reshuffle_pytree_independently",
|
114
|
+
"slice_array",
|
115
|
+
"slice_pytree",
|
116
|
+
"update_pytree",
|
86
117
|
"TextBlock",
|
87
118
|
"camelcase_to_snakecase",
|
88
119
|
"colored",
|
@@ -104,21 +135,36 @@ __all__ += [
|
|
104
135
|
"Batch",
|
105
136
|
"CollateMode",
|
106
137
|
"EmbeddingKind",
|
138
|
+
"ActivationFunction",
|
139
|
+
"DTYPE",
|
107
140
|
"LOG_ERROR_SUMMARY",
|
108
141
|
"LOG_PING",
|
109
142
|
"LOG_STATUS",
|
143
|
+
"NormType",
|
110
144
|
"Output",
|
111
145
|
"Phase",
|
112
146
|
"RawConfigType",
|
113
147
|
]
|
114
148
|
|
115
149
|
import os
|
150
|
+
import shutil
|
116
151
|
from typing import TYPE_CHECKING
|
117
152
|
|
153
|
+
# Sets some useful XLA flags.
xla_flags: list[str] = []
if "XLA_FLAGS" in os.environ:
    xla_flags.append(os.environ["XLA_FLAGS"])

# If an Nvidia GPU is detected (meaning, is `nvidia-smi` available?), enable the
# latency-hiding scheduler and disable Triton GEMM kernels. See
# https://github.com/NVIDIA/JAX-Toolbox
# The values are spelled out because a boolean XLA flag given without a value is
# parsed as `true`; a bare `--xla_gpu_enable_triton_gemm` would *enable* Triton GEMM.
if shutil.which("nvidia-smi") is not None:
    xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler=true", "--xla_gpu_enable_triton_gemm=false"]
os.environ["XLA_FLAGS"] = " ".join(xla_flags)
|
163
|
+
|
118
164
|
# If this flag is set, eagerly imports the entire package (not recommended).
|
119
165
|
IMPORT_ALL = int(os.environ.get("XAX_IMPORT_ALL", "0")) != 0
|
120
166
|
|
121
|
-
del os
|
167
|
+
del os, shutil, xla_flags
|
122
168
|
|
123
169
|
# This dictionary is auto-generated and shouldn't be modified by hand; instead,
|
124
170
|
# run the update script.
|
@@ -142,6 +188,20 @@ NAME_MAP: dict[str, str] = {
|
|
142
188
|
"get_positional_embeddings": "nn.embeddings",
|
143
189
|
"get_rotary_embeddings": "nn.embeddings",
|
144
190
|
"rotary_embeddings": "nn.embeddings",
|
191
|
+
"MLPHyperParams": "nn.equinox",
|
192
|
+
"export_eqx_mlp": "nn.equinox",
|
193
|
+
"load_eqx": "nn.equinox",
|
194
|
+
"load_eqx_mlp": "nn.equinox",
|
195
|
+
"make_eqx_mlp": "nn.equinox",
|
196
|
+
"save_eqx": "nn.equinox",
|
197
|
+
"export": "nn.export",
|
198
|
+
"export_flax": "nn.export",
|
199
|
+
"export_with_params": "nn.export",
|
200
|
+
"euler_to_quat": "nn.geom",
|
201
|
+
"get_projected_gravity_vector_from_quat": "nn.geom",
|
202
|
+
"quat_to_euler": "nn.geom",
|
203
|
+
"cast_norm_type": "nn.norm",
|
204
|
+
"get_norm": "nn.norm",
|
145
205
|
"is_master": "nn.parallel",
|
146
206
|
"BaseLauncher": "task.launchers.base",
|
147
207
|
"CliLauncher": "task.launchers.cli",
|
@@ -158,13 +218,16 @@ NAME_MAP: dict[str, str] = {
|
|
158
218
|
"CPUStatsOptions": "task.mixins.cpu_stats",
|
159
219
|
"DataloaderConfig": "task.mixins.data_loader",
|
160
220
|
"GPUStatsOptions": "task.mixins.gpu_stats",
|
221
|
+
"StepContext": "task.mixins.step_wrapper",
|
161
222
|
"Script": "task.script",
|
162
223
|
"ScriptConfig": "task.script",
|
163
224
|
"Config": "task.task",
|
164
225
|
"Task": "task.task",
|
165
226
|
"collate": "utils.data.collate",
|
166
227
|
"collate_non_null": "utils.data.collate",
|
228
|
+
"get_named_leaves": "utils.debugging",
|
167
229
|
"BaseFileDownloader": "utils.experiments",
|
230
|
+
"ContextTimer": "utils.experiments",
|
168
231
|
"CumulativeTimer": "utils.experiments",
|
169
232
|
"DataDownloader": "utils.experiments",
|
170
233
|
"IntervalTicker": "utils.experiments",
|
@@ -186,11 +249,24 @@ NAME_MAP: dict[str, str] = {
|
|
186
249
|
"save_config": "utils.experiments",
|
187
250
|
"stage_environment": "utils.experiments",
|
188
251
|
"to_markdown_table": "utils.experiments",
|
252
|
+
"jit": "utils.jax",
|
253
|
+
"save_jaxpr_dot": "utils.jaxpr",
|
189
254
|
"ColoredFormatter": "utils.logging",
|
190
255
|
"configure_logging": "utils.logging",
|
191
256
|
"one_hot": "utils.numpy",
|
192
257
|
"partial_flatten": "utils.numpy",
|
193
258
|
"worker_chunk": "utils.numpy",
|
259
|
+
"profile": "utils.profile",
|
260
|
+
"compute_nan_ratio": "utils.pytree",
|
261
|
+
"flatten_array": "utils.pytree",
|
262
|
+
"flatten_pytree": "utils.pytree",
|
263
|
+
"pytree_has_nans": "utils.pytree",
|
264
|
+
"reshuffle_pytree": "utils.pytree",
|
265
|
+
"reshuffle_pytree_along_dims": "utils.pytree",
|
266
|
+
"reshuffle_pytree_independently": "utils.pytree",
|
267
|
+
"slice_array": "utils.pytree",
|
268
|
+
"slice_pytree": "utils.pytree",
|
269
|
+
"update_pytree": "utils.pytree",
|
194
270
|
"TextBlock": "utils.text",
|
195
271
|
"camelcase_to_snakecase": "utils.text",
|
196
272
|
"colored": "utils.text",
|
@@ -217,9 +293,12 @@ NAME_MAP.update(
|
|
217
293
|
"LOG_ERROR_SUMMARY": "utils.logging",
|
218
294
|
"LOG_PING": "utils.logging",
|
219
295
|
"LOG_STATUS": "utils.logging",
|
296
|
+
"NormType": "nn.norm",
|
220
297
|
"Output": "task.mixins.output",
|
221
298
|
"Phase": "core.state",
|
222
299
|
"RawConfigType": "task.base",
|
300
|
+
"ActivationFunction": "nn.equinox",
|
301
|
+
"DTYPE": "nn.equinox",
|
223
302
|
},
|
224
303
|
)
|
225
304
|
|
@@ -257,6 +336,27 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
257
336
|
get_rotary_embeddings,
|
258
337
|
rotary_embeddings,
|
259
338
|
)
|
339
|
+
from xax.nn.equinox import (
|
340
|
+
DTYPE,
|
341
|
+
ActivationFunction,
|
342
|
+
MLPHyperParams,
|
343
|
+
export_eqx_mlp,
|
344
|
+
load_eqx,
|
345
|
+
load_eqx_mlp,
|
346
|
+
make_eqx_mlp,
|
347
|
+
save_eqx,
|
348
|
+
)
|
349
|
+
from xax.nn.export import (
|
350
|
+
export,
|
351
|
+
export_flax,
|
352
|
+
export_with_params,
|
353
|
+
)
|
354
|
+
from xax.nn.geom import (
|
355
|
+
euler_to_quat,
|
356
|
+
get_projected_gravity_vector_from_quat,
|
357
|
+
quat_to_euler,
|
358
|
+
)
|
359
|
+
from xax.nn.norm import NormType, cast_norm_type, get_norm
|
260
360
|
from xax.nn.parallel import is_master
|
261
361
|
from xax.task.base import RawConfigType
|
262
362
|
from xax.task.launchers.base import BaseLauncher
|
@@ -271,12 +371,15 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
271
371
|
from xax.task.mixins.cpu_stats import CPUStatsOptions
|
272
372
|
from xax.task.mixins.data_loader import DataloaderConfig
|
273
373
|
from xax.task.mixins.gpu_stats import GPUStatsOptions
|
374
|
+
from xax.task.mixins.step_wrapper import StepContext
|
274
375
|
from xax.task.mixins.train import Batch, Output
|
275
376
|
from xax.task.script import Script, ScriptConfig
|
276
377
|
from xax.task.task import Config, Task
|
277
378
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
379
|
+
from xax.utils.debugging import get_named_leaves
|
278
380
|
from xax.utils.experiments import (
|
279
381
|
BaseFileDownloader,
|
382
|
+
ContextTimer,
|
280
383
|
CumulativeTimer,
|
281
384
|
DataDownloader,
|
282
385
|
IntervalTicker,
|
@@ -299,6 +402,8 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
299
402
|
stage_environment,
|
300
403
|
to_markdown_table,
|
301
404
|
)
|
405
|
+
from xax.utils.jax import jit
|
406
|
+
from xax.utils.jaxpr import save_jaxpr_dot
|
302
407
|
from xax.utils.logging import (
|
303
408
|
LOG_ERROR_SUMMARY,
|
304
409
|
LOG_PING,
|
@@ -307,6 +412,19 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
307
412
|
configure_logging,
|
308
413
|
)
|
309
414
|
from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
|
415
|
+
from xax.utils.profile import profile
|
416
|
+
from xax.utils.pytree import (
|
417
|
+
compute_nan_ratio,
|
418
|
+
flatten_array,
|
419
|
+
flatten_pytree,
|
420
|
+
pytree_has_nans,
|
421
|
+
reshuffle_pytree,
|
422
|
+
reshuffle_pytree_along_dims,
|
423
|
+
reshuffle_pytree_independently,
|
424
|
+
slice_array,
|
425
|
+
slice_pytree,
|
426
|
+
update_pytree,
|
427
|
+
)
|
310
428
|
from xax.utils.text import (
|
311
429
|
TextBlock,
|
312
430
|
camelcase_to_snakecase,
|
xax/nn/equinox.py
ADDED
@@ -0,0 +1,180 @@
|
|
1
|
+
"""Equinox utilities."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Callable, Literal, TypedDict, cast
|
7
|
+
|
8
|
+
import equinox as eqx
|
9
|
+
import jax
|
10
|
+
from jaxtyping import PRNGKeyArray
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)

# Activation names accepted in MLP hyperparameters. "identity" is handled
# specially; every other name is expected to be an attribute of `jax.nn`.
ActivationFunction = Literal[
    "relu", "tanh", "celu", "elu", "gelu", "glu",
    "hard_sigmoid", "hard_silu", "hard_swish", "hard_tanh",
    "leaky_relu", "log_sigmoid", "log_softmax", "logsumexp",
    "relu6", "selu", "sigmoid", "soft_sign", "softmax", "softplus",
    "sparse_plus", "sparse_sigmoid", "silu", "swish", "squareplus", "mish",
    "identity",
]

# Floating-point precisions an exported MLP may declare.
DTYPE = Literal["float32", "float64"]

# Precision name -> JAX scalar type used when building the MLP.
DTYPE_MAP: dict[DTYPE, jax.numpy.dtype] = dict(
    float32=jax.numpy.float32,
    float64=jax.numpy.float64,
)
|
50
|
+
|
51
|
+
|
52
|
+
class MLPHyperParams(TypedDict):
    """JSON-serializable description of an Equinox MLP.

    `make_eqx_mlp` forwards these keys to `eqx.nn.MLP`, resolving the two
    activation names and the dtype name into the corresponding JAX objects.
    """

    # Input feature count, or the literal "scalar" (passed to `eqx.nn.MLP` as-is).
    in_size: int | Literal["scalar"]
    # Output feature count, or the literal "scalar" (passed to `eqx.nn.MLP` as-is).
    out_size: int | Literal["scalar"]
    # Width of each hidden layer.
    width_size: int
    # Number of hidden layers.
    depth: int
    # Name of the hidden-layer activation.
    activation: ActivationFunction
    # Name of the activation applied after the last layer.
    final_activation: ActivationFunction
    # Whether hidden layers carry a bias term.
    use_bias: bool
    # Whether the output layer carries a bias term.
    use_final_bias: bool
    # Name of the floating-point precision of the parameters.
    dtype: DTYPE
|
64
|
+
|
65
|
+
|
66
|
+
def _infer_activation(activation: ActivationFunction) -> Callable:
    """Resolve an activation-function name to the callable it refers to.

    Args:
        activation: The activation name. ``"identity"`` maps to a pass-through
            function; any other name is looked up as an attribute of ``jax.nn``.

    Returns:
        The activation callable.

    Raises:
        ValueError: If the name does not refer to a callable attribute of ``jax.nn``.
    """
    if activation == "identity":
        # A named function (rather than a lambda) keeps `__name__ == "identity"`,
        # so an MLP built from these hyperparameters can be exported by name
        # and loaded back.
        def identity(x: jax.Array) -> jax.Array:
            return x

        return identity

    fn = getattr(jax.nn, activation, None)
    # The name may come from a deserialized file, so reject anything that is
    # missing or not callable (for example a submodule of `jax.nn`).
    if not callable(fn):
        raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
    return fn
|
73
|
+
|
74
|
+
|
75
|
+
def make_eqx_mlp(hyperparams: MLPHyperParams, key: PRNGKeyArray | None = None) -> eqx.nn.MLP:
    """Create an Equinox MLP from a set of hyperparameters.

    Args:
        hyperparams: The hyperparameters of the MLP.
        key: The PRNG key used to initialize the MLP. Defaults to
            ``jax.random.PRNGKey(0)``, so calls without a key are deterministic.

    Returns:
        A freshly initialized ``eqx.nn.MLP``.

    Raises:
        ValueError: If an activation name cannot be resolved.
        KeyError: If ``hyperparams["dtype"]`` is not a key of ``DTYPE_MAP``.
    """
    # Build the default key lazily: a `PRNGKey(0)` default in the signature would
    # allocate a device array as a side effect of importing this module.
    if key is None:
        key = jax.random.PRNGKey(0)

    activation = _infer_activation(hyperparams["activation"])
    final_activation = _infer_activation(hyperparams["final_activation"])
    dtype = DTYPE_MAP[hyperparams["dtype"]]

    return eqx.nn.MLP(
        in_size=hyperparams["in_size"],
        out_size=hyperparams["out_size"],
        width_size=hyperparams["width_size"],
        depth=hyperparams["depth"],
        activation=activation,
        final_activation=final_activation,
        use_bias=hyperparams["use_bias"],
        use_final_bias=hyperparams["use_final_bias"],
        dtype=dtype,
        key=key,
    )
|
98
|
+
|
99
|
+
|
100
|
+
def export_eqx_mlp(
    model: eqx.nn.MLP,
    output_path: str | Path,
    dtype: jax.numpy.dtype | None = None,
) -> None:
    """Serialize an Equinox MLP to a .eqx file.

    The file starts with one UTF-8 line holding the MLP hyperparameters as JSON
    (an ``MLPHyperParams`` dict), followed by the leaves written with
    ``eqx.tree_serialise_leaves``.

    Args:
        model: The JAX MLP to export.
        output_path: The path to save the exported model.
        dtype: The dtype of the model, recorded in the hyperparameters. Anything
            ``jax.numpy.dtype`` maps to float32 or float64 is accepted (e.g.
            ``jnp.float32`` or an array's ``.dtype``). Defaults to Equinox's
            default floating dtype, evaluated at call time.

    Raises:
        ValueError: If ``dtype`` is not float32 or float64.
    """
    # Resolved per call (not in the signature) so enabling x64 after import is honored.
    if dtype is None:
        dtype = eqx._misc.default_floating_dtype()

    def activation_name(fn: Callable, label: str) -> ActivationFunction:
        name = fn.__name__
        # A lambda has no resolvable name ("<lambda>" cannot be loaded back from
        # `jax.nn`), so it is recorded as identity, with a warning.
        if name == "<lambda>":
            logger.warning("%s is a lambda function. Assuming identity.", label)
            name = "identity"
        # cast strings to ActivationFunction for type checking
        return cast(ActivationFunction, name)

    activation = activation_name(model.activation, "Activation")
    final_activation = activation_name(model.final_activation, "Final activation")

    # Normalize before looking up: numpy dtype instances compare equal to the
    # JAX scalar types in DTYPE_MAP but hash differently, so a direct dict
    # lookup keyed on those types would raise KeyError for them.
    try:
        requested = jax.numpy.dtype(dtype)
    except (TypeError, ValueError):
        raise ValueError(f"Invalid dtype: {dtype}. Must be one of {DTYPE_MAP.values()}") from None
    names_by_dtype = {jax.numpy.dtype(value): name for name, value in DTYPE_MAP.items()}
    if requested not in names_by_dtype:
        raise ValueError(f"Invalid dtype: {dtype}. Must be one of {DTYPE_MAP.values()}")
    dtype_name = names_by_dtype[requested]

    hyperparams: MLPHyperParams = {
        "in_size": model.in_size,
        "out_size": model.out_size,
        "width_size": model.width_size,
        "depth": model.depth,
        "activation": activation,
        "final_activation": final_activation,
        "use_bias": model.use_bias,
        "use_final_bias": model.use_final_bias,
        "dtype": dtype_name,
    }

    # Serialize the header before opening the file so an encoding failure does
    # not leave a truncated file behind.
    hyperparam_str = json.dumps(hyperparams)
    with open(output_path, "wb") as f:
        f.write((hyperparam_str + "\n").encode(encoding="utf-8"))
        eqx.tree_serialise_leaves(f, model)
|
144
|
+
|
145
|
+
|
146
|
+
def save_eqx(
    model: eqx.Module,
    output_path: str | Path,
) -> None:
    """Write an Equinox module's leaves to ``output_path`` (a .eqx file).

    Only what ``eqx.tree_serialise_leaves`` emits is written; no hyperparameter
    header is added. Reading it back therefore needs a template module of the
    same structure, which is what ``load_eqx`` takes.

    Args:
        model: The Equinox module to serialize.
        output_path: Destination file; it is created or overwritten.
    """
    with open(output_path, "wb") as stream:
        eqx.tree_serialise_leaves(stream, model)
|
158
|
+
|
159
|
+
|
160
|
+
def load_eqx(
    model: eqx.Module,
    eqx_file: str | Path,
) -> eqx.Module:
    """Read leaves from a .eqx file into a template Equinox module.

    Args:
        model: Template module whose structure matches the saved one; its
            leaves are replaced by the deserialized values.
        eqx_file: The .eqx file to read.

    Returns:
        The module produced by ``eqx.tree_deserialise_leaves``.
    """
    with open(eqx_file, "rb") as stream:
        restored = eqx.tree_deserialise_leaves(stream, model)
    return restored
|
172
|
+
|
173
|
+
|
174
|
+
def load_eqx_mlp(
    eqx_file: str | Path,
) -> eqx.nn.MLP:
    """Load an MLP from a file in the format written by ``export_eqx_mlp``.

    The first line is decoded as UTF-8 JSON hyperparameters and used to build a
    skeleton MLP via ``make_eqx_mlp``. The rest of the file is then
    deserialized into that skeleton's leaves.

    Args:
        eqx_file: The file to read.

    Returns:
        The reconstructed MLP.
    """
    with open(eqx_file, "rb") as stream:
        header = stream.readline().decode(encoding="utf-8")
        skeleton = make_eqx_mlp(hyperparams=json.loads(header))
        return eqx.tree_deserialise_leaves(stream, skeleton)
|
xax/nn/export.py
ADDED
@@ -0,0 +1,147 @@
|
|
1
|
+
"""Export JAX functions to TensorFlow SavedModel format."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Callable
|
6
|
+
|
7
|
+
import flax
|
8
|
+
import jax
|
9
|
+
import tensorflow as tf
|
10
|
+
from jax.experimental import jax2tf
|
11
|
+
from jaxtyping import Array, PyTree
|
12
|
+
from orbax.export import ExportManager, JaxModule, ServingConfig
|
13
|
+
|
14
|
+
logger = logging.getLogger(__name__)
|
15
|
+
|
16
|
+
|
17
|
+
def _run_infer(tf_module: tf.Module, input_shapes: list[tuple[int, ...]], batch_size: int | None) -> tf.Tensor:
    """Warm up ``tf_module.infer`` by calling it once on random inputs.

    Args:
        tf_module: Module expected to expose an ``infer`` function.
        input_shapes: Per-input shapes, excluding the leading batch dimension.
        batch_size: Leading dimension of the dummy inputs; ``None`` means 1.

    Returns:
        The result of ``tf_module.infer`` on the dummy inputs.

    Raises:
        ValueError: If ``tf_module`` has no ``infer`` attribute.
    """
    leading = 1 if batch_size is None else batch_size
    warmup_key = jax.random.PRNGKey(42)
    dummy_inputs = [jax.random.normal(warmup_key, (leading, *shape)) for shape in input_shapes]
    if not hasattr(tf_module, "infer"):
        raise ValueError("Model does not have an infer method")
    return tf_module.infer(*dummy_inputs)
|
28
|
+
|
29
|
+
|
30
|
+
def export(
    model: Callable,
    input_shapes: list[tuple[int, ...]],
    output_dir: str | Path = "export",
    batch_size: int | None = None,
) -> None:
    """Export a JAX function to TensorFlow SavedModel.

    Note: Tensorflow GraphDef can't be larger than 2GB - https://github.com/tensorflow/tensorflow/issues/51870
    You can avoid this by saving model parameters as non-constants.

    The function is exposed as ``infer`` with ``tf.float32`` inputs. It is
    called once on random inputs as a warm-up before saving.

    Args:
        model: The JAX function to export.
        input_shapes: The shape of the input tensors, excluding batch dimension.
        output_dir: Directory to save the exported model.
        batch_size: Optional batch dimension. If None, a polymorphic batch dimension is used.
    """
    tf_module = tf.Module()

    # Create a polymorphic shape specification for each input. Only the leading
    # (batch) dimension is ever symbolic. With no fixed batch size it becomes
    # the shared symbolic dimension `b`, matching the `None` leading dimension
    # in the input signature below. With a fixed batch size every input shape is
    # fully known, so no polymorphism is requested.
    if batch_size is None:
        polymorphic_shapes: list[str | None] = ["(b, ...)"] * len(input_shapes)
    else:
        polymorphic_shapes = [None] * len(input_shapes)

    tf_module.infer = tf.function(  # type: ignore [attr-defined]
        jax2tf.convert(
            model,
            polymorphic_shapes=polymorphic_shapes,
            # setting this to False will allow the model to run on platforms other than the one that exports the model
            # https://github.com/jax-ml/jax/blob/051687dc4c899df3d95c30b812ade401d8b31166/jax/experimental/jax2tf/README.md?plain=1#L1342
            # generally though I think native_serialization is recommended
            native_serialization=False,
            with_gradient=False,
        ),
        autograph=False,
        input_signature=[tf.TensorSpec([batch_size] + list(input_shape), tf.float32) for input_shape in input_shapes],
    )

    # warm up the model
    _run_infer(tf_module, input_shapes, batch_size)

    logger.info("Exporting SavedModel to %s", output_dir)
    tf.saved_model.save(
        tf_module,
        output_dir,
    )
|
73
|
+
|
74
|
+
|
75
|
+
def export_with_params(
    model: Callable,
    params: PyTree,
    input_shapes: list[tuple[int, ...]],
    output_dir: str | Path = "export",
    batch_dim: int | None = None,
) -> None:
    """Export a JAX function that takes parameters to TensorFlow SavedModel.

    The parameters are stored as ``tf.Variable`` objects tracked by the module
    instead of being embedded as graph constants. This is the workaround
    suggested for TensorFlow's 2GB GraphDef limit. The function is exposed as
    ``infer`` with ``tf.float32`` inputs and is called once on random inputs
    as a warm-up before saving.

    Args:
        model: The JAX function to export. Should take parameters as first argument.
        params: The parameters to use for the model.
        input_shapes: The shape of the input tensors, excluding batch dimension.
        output_dir: Directory to save the exported model.
        batch_dim: Optional batch dimension. If None, a polymorphic batch dimension is used.
    """
    param_vars = tf.nest.map_structure(tf.Variable, params)

    # With no fixed batch size, the input signature below has an unknown (`None`)
    # leading dimension. jax2tf needs such dimensions declared through
    # `polymorphic_shapes`. The parameter pytree (first argument) always has
    # known shapes, hence its `None` spec.
    if batch_dim is None:
        polymorphic_shapes: list[str | None] | None = [None, *["(b, ...)"] * len(input_shapes)]
    else:
        polymorphic_shapes = None
    converted_model = jax2tf.convert(model, polymorphic_shapes=polymorphic_shapes)

    def model_fn(*inputs: PyTree) -> Array:
        return converted_model(param_vars, *inputs)

    tf_module = tf.Module()
    # Attach the variables so `tf.saved_model.save` tracks and serializes them.
    tf_module._variables = tf.nest.flatten(param_vars)  # type: ignore [attr-defined]
    tf_module.infer = tf.function(  # type: ignore [attr-defined]
        model_fn,
        jit_compile=True,
        autograph=False,
        input_signature=[tf.TensorSpec([batch_dim] + list(input_shape), tf.float32) for input_shape in input_shapes],
    )

    # warm up the model
    _run_infer(tf_module, input_shapes, batch_dim)

    logger.info("Exporting SavedModel to %s", output_dir)
    tf.saved_model.save(tf_module, output_dir)
|
112
|
+
|
113
|
+
|
114
|
+
def export_flax(
    model: flax.linen.Module,
    params: PyTree,
    input_shape: tuple[int, ...],
    preprocessor: Callable | None = None,
    postprocessor: Callable | None = None,
    input_name: str = "inputs",
    output_name: str = "outputs",
    output_dir: str | Path = "export",
) -> None:
    """Export a Flax module with its parameters as a SavedModel via Orbax.

    A single ``"serving_default"`` signature is created. It takes one
    ``tf.float32`` input named ``input_name`` whose leading (batch) dimension is
    left unknown.

    Args:
        model: The Flax module; its ``apply`` method is exported.
        params: Parameters passed to ``model.apply`` (exported as non-trainable).
        input_shape: Input shape, excluding the batch dimension.
        preprocessor: Optional TF preprocessing function for the serving signature.
        postprocessor: Optional TF postprocessing function. When omitted, the
            raw output is returned as ``{output_name: output}``.
        input_name: Name of the input tensor in the signature.
        output_name: Output key used by the default postprocessor.
        output_dir: Directory to write the exported model to.
    """
    # The leading input dimension is the symbolic batch dimension `b`.
    jax_module = JaxModule(params, model.apply, trainable=False, input_polymorphic_shape="(b, ...)")

    if postprocessor is None:
        # Return a dict keyed by `output_name` so a sequence output is not
        # mapped to ambiguous signature outputs.
        def postprocessor(x: PyTree) -> PyTree:
            return {output_name: x}

    serving_config = ServingConfig(
        "serving_default",
        input_signature=[tf.TensorSpec([None, *input_shape], tf.float32, name=input_name)],
        tf_preprocessor=preprocessor,
        tf_postprocessor=postprocessor,
    )
    export_manager = ExportManager(jax_module, [serving_config])

    logger.info("Exporting model to %s", output_dir)
    export_manager.save(output_dir)
|
xax/nn/geom.py
ADDED
@@ -0,0 +1,101 @@
|
|
1
|
+
"""Defines geometry functions."""
|
2
|
+
|
3
|
+
import jax
|
4
|
+
from jax import numpy as jnp
|
5
|
+
|
6
|
+
|
7
|
+
def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.

    Args:
        quat_4: The quaternion to convert, shape (*, 4).
        eps: A small epsilon value to avoid division by zero.

    Returns:
        The roll, pitch, yaw angles with shape (*, 3). Pitch is clamped to
        +/- pi/2 where the quaternion is at (or numerically past) gimbal lock.
    """
    quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)
    w, x, y, z = jnp.split(quat_4, 4, axis=-1)

    # Roll (x-axis rotation)
    sinr_cosp = 2.0 * (w * x + y * z)
    cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
    roll = jnp.arctan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation)
    sinp = 2.0 * (w * y - z * x)

    # Handle edge cases where |sinp| >= 1. `jnp.where` differentiates both
    # branches, and the derivative of `arcsin` is infinite or NaN for |x| >= 1.
    # Feeding it those values would poison gradients even though that branch is
    # not selected. Substitute a safe value first (the "double where" pattern).
    out_of_range = jnp.abs(sinp) >= 1.0
    safe_sinp = jnp.where(out_of_range, 0.0, sinp)
    pitch = jnp.where(
        out_of_range,
        jnp.sign(sinp) * jnp.pi / 2.0,  # Use 90 degrees if out of range
        jnp.arcsin(safe_sinp),
    )

    # Yaw (z-axis rotation)
    siny_cosp = 2.0 * (w * z + x * y)
    cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
    yaw = jnp.arctan2(siny_cosp, cosy_cosp)

    return jnp.concatenate([roll, pitch, yaw], axis=-1)
|
41
|
+
|
42
|
+
|
43
|
+
def euler_to_quat(euler_3: jax.Array) -> jax.Array:
    """Build a unit quaternion (w, x, y, z) from roll, pitch, yaw angles.

    Args:
        euler_3: Roll, pitch and yaw angles along the last axis, shape (*, 3).

    Returns:
        The normalized quaternion, shape (*, 4).
    """
    half_angles = euler_3 * 0.5
    # Cosines and sines of the half angles, split back into roll/pitch/yaw parts.
    cr, cp, cy = jnp.split(jnp.cos(half_angles), 3, axis=-1)
    sr, sp, sy = jnp.split(jnp.sin(half_angles), 3, axis=-1)

    quat = jnp.concatenate(
        [
            cr * cp * cy + sr * sp * sy,  # w
            sr * cp * cy - cr * sp * sy,  # x
            cr * sp * cy + sr * cp * sy,  # y
            cr * cp * sy - sr * sp * cy,  # z
        ],
        axis=-1,
    )
    return quat / jnp.linalg.norm(quat, axis=-1, keepdims=True)
|
76
|
+
|
77
|
+
|
78
|
+
def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Calculates the gravity vector projected onto the local frame given a quaternion orientation.

    The quaternion is taken to rotate local (body) coordinates into world
    coordinates, ``v_world = R(q) @ v_local``. Gravity is the world-frame unit
    vector ``[0, 0, -1]``, and the result is ``R(q)^T @ [0, 0, -1]``: the same
    vector expressed in the local frame. For example, pitching by ``theta``
    about +y gives ``[sin(theta), 0, -cos(theta)]``.

    Args:
        quat: A quaternion (w,x,y,z) representing the orientation, shape (*, 4).
        eps: A small epsilon value to avoid division by zero.

    Returns:
        A 3D vector representing the gravity in the local frame, shape (*, 3).
    """
    # Normalize quaternion
    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
    w, x, y, z = jnp.split(quat, 4, axis=-1)

    # R^T @ [0, 0, -1] is minus the third row of R(q):
    #   R[2] = [2(xz - wy), 2(yz + wx), w^2 - x^2 - y^2 + z^2]
    # (the last entry equals 1 - 2(x^2 + y^2) for a unit quaternion).
    # All three components must be negated; negating only z would be a
    # reflection, not a rotation, of the gravity vector.
    gx = 2 * (w * y - x * z)
    gy = -2 * (w * x + y * z)
    gz = -(w * w - x * x - y * y + z * z)

    return jnp.concatenate([gx, gy, gz], axis=-1)
|