xax 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +94 -4
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +26 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +97 -2
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -34
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jaxpr.py +77 -0
- xax/utils/pytree.py +189 -1
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/METADATA +23 -4
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/RECORD +26 -21
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/WHEEL +1 -1
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
```diff
@@ -4,14 +4,15 @@ This package is structured so that all the important stuff can be accessed
 without having to dig around through the internals. This is done by lazily
 importing the module by name.
 
-This file can be maintained by
+This file can be maintained by updating the imports at the bottom of the file
+and running the update script:
 
 .. code-block:: bash
 
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.0.7"
+__version__ = "0.1.0"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -34,8 +35,20 @@ __all__ = [
     "get_positional_embeddings",
     "get_rotary_embeddings",
     "rotary_embeddings",
+    "MLPHyperParams",
+    "export_eqx_mlp",
+    "load_eqx",
+    "load_eqx_mlp",
+    "make_eqx_mlp",
+    "save_eqx",
+    "export",
+    "export_flax",
+    "export_with_params",
     "euler_to_quat",
+    "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "cast_norm_type",
+    "get_norm",
     "is_master",
     "BaseLauncher",
     "CliLauncher",
@@ -52,13 +65,16 @@ __all__ = [
     "CPUStatsOptions",
     "DataloaderConfig",
     "GPUStatsOptions",
+    "StepContext",
     "Script",
     "ScriptConfig",
     "Config",
     "Task",
     "collate",
     "collate_non_null",
+    "get_named_leaves",
     "BaseFileDownloader",
+    "ContextTimer",
     "CumulativeTimer",
     "DataDownloader",
     "IntervalTicker",
@@ -81,6 +97,7 @@ __all__ = [
     "stage_environment",
     "to_markdown_table",
     "jit",
+    "save_jaxpr_dot",
     "ColoredFormatter",
     "configure_logging",
     "one_hot",
@@ -90,8 +107,13 @@ __all__ = [
     "compute_nan_ratio",
     "flatten_array",
     "flatten_pytree",
+    "pytree_has_nans",
+    "reshuffle_pytree",
+    "reshuffle_pytree_along_dims",
+    "reshuffle_pytree_independently",
     "slice_array",
     "slice_pytree",
+    "update_pytree",
     "TextBlock",
     "camelcase_to_snakecase",
     "colored",
@@ -113,21 +135,36 @@ __all__ += [
     "Batch",
     "CollateMode",
     "EmbeddingKind",
+    "ActivationFunction",
+    "DTYPE",
     "LOG_ERROR_SUMMARY",
     "LOG_PING",
     "LOG_STATUS",
+    "NormType",
     "Output",
     "Phase",
     "RawConfigType",
 ]
 
 import os
+import shutil
 from typing import TYPE_CHECKING
 
+# Sets some useful XLA flags.
+xla_flags: list[str] = []
+if "XLA_FLAGS" in os.environ:
+    xla_flags.append(os.environ["XLA_FLAGS"])
+
+# If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
+# Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
+if shutil.which("nvidia-smi") is not None:
+    xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler", "--xla_gpu_enable_triton_gemm"]
+os.environ["XLA_FLAGS"] = " ".join(xla_flags)
+
 # If this flag is set, eagerly imports the entire package (not recommended).
 IMPORT_ALL = int(os.environ.get("XAX_IMPORT_ALL", "0")) != 0
 
-del os
+del os, shutil, xla_flags
 
 # This dictionary is auto-generated and shouldn't be modified by hand; instead,
 # run the update script.
@@ -151,8 +188,20 @@ NAME_MAP: dict[str, str] = {
     "get_positional_embeddings": "nn.embeddings",
     "get_rotary_embeddings": "nn.embeddings",
     "rotary_embeddings": "nn.embeddings",
+    "MLPHyperParams": "nn.equinox",
+    "export_eqx_mlp": "nn.equinox",
+    "load_eqx": "nn.equinox",
+    "load_eqx_mlp": "nn.equinox",
+    "make_eqx_mlp": "nn.equinox",
+    "save_eqx": "nn.equinox",
+    "export": "nn.export",
+    "export_flax": "nn.export",
+    "export_with_params": "nn.export",
     "euler_to_quat": "nn.geom",
+    "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "cast_norm_type": "nn.norm",
+    "get_norm": "nn.norm",
     "is_master": "nn.parallel",
     "BaseLauncher": "task.launchers.base",
     "CliLauncher": "task.launchers.cli",
@@ -169,13 +218,16 @@ NAME_MAP: dict[str, str] = {
     "CPUStatsOptions": "task.mixins.cpu_stats",
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
+    "StepContext": "task.mixins.step_wrapper",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
     "Task": "task.task",
     "collate": "utils.data.collate",
     "collate_non_null": "utils.data.collate",
+    "get_named_leaves": "utils.debugging",
     "BaseFileDownloader": "utils.experiments",
+    "ContextTimer": "utils.experiments",
     "CumulativeTimer": "utils.experiments",
     "DataDownloader": "utils.experiments",
     "IntervalTicker": "utils.experiments",
@@ -198,6 +250,7 @@ NAME_MAP: dict[str, str] = {
     "stage_environment": "utils.experiments",
     "to_markdown_table": "utils.experiments",
     "jit": "utils.jax",
+    "save_jaxpr_dot": "utils.jaxpr",
     "ColoredFormatter": "utils.logging",
     "configure_logging": "utils.logging",
     "one_hot": "utils.numpy",
@@ -207,8 +260,13 @@ NAME_MAP: dict[str, str] = {
     "compute_nan_ratio": "utils.pytree",
     "flatten_array": "utils.pytree",
     "flatten_pytree": "utils.pytree",
+    "pytree_has_nans": "utils.pytree",
+    "reshuffle_pytree": "utils.pytree",
+    "reshuffle_pytree_along_dims": "utils.pytree",
+    "reshuffle_pytree_independently": "utils.pytree",
     "slice_array": "utils.pytree",
     "slice_pytree": "utils.pytree",
+    "update_pytree": "utils.pytree",
     "TextBlock": "utils.text",
     "camelcase_to_snakecase": "utils.text",
     "colored": "utils.text",
@@ -235,9 +293,12 @@ NAME_MAP.update(
         "LOG_ERROR_SUMMARY": "utils.logging",
         "LOG_PING": "utils.logging",
         "LOG_STATUS": "utils.logging",
+        "NormType": "nn.norm",
         "Output": "task.mixins.output",
         "Phase": "core.state",
         "RawConfigType": "task.base",
+        "ActivationFunction": "nn.equinox",
+        "DTYPE": "nn.equinox",
     },
 )
 
@@ -275,7 +336,27 @@ if IMPORT_ALL or TYPE_CHECKING:
         get_rotary_embeddings,
         rotary_embeddings,
     )
-    from xax.nn.
+    from xax.nn.equinox import (
+        DTYPE,
+        ActivationFunction,
+        MLPHyperParams,
+        export_eqx_mlp,
+        load_eqx,
+        load_eqx_mlp,
+        make_eqx_mlp,
+        save_eqx,
+    )
+    from xax.nn.export import (
+        export,
+        export_flax,
+        export_with_params,
+    )
+    from xax.nn.geom import (
+        euler_to_quat,
+        get_projected_gravity_vector_from_quat,
+        quat_to_euler,
+    )
+    from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
@@ -290,12 +371,15 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.mixins.cpu_stats import CPUStatsOptions
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
+    from xax.task.mixins.step_wrapper import StepContext
     from xax.task.mixins.train import Batch, Output
     from xax.task.script import Script, ScriptConfig
     from xax.task.task import Config, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
+    from xax.utils.debugging import get_named_leaves
     from xax.utils.experiments import (
         BaseFileDownloader,
+        ContextTimer,
         CumulativeTimer,
         DataDownloader,
         IntervalTicker,
@@ -319,6 +403,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         to_markdown_table,
     )
     from xax.utils.jax import jit
+    from xax.utils.jaxpr import save_jaxpr_dot
     from xax.utils.logging import (
         LOG_ERROR_SUMMARY,
         LOG_PING,
@@ -332,8 +417,13 @@ if IMPORT_ALL or TYPE_CHECKING:
         compute_nan_ratio,
         flatten_array,
         flatten_pytree,
+        pytree_has_nans,
+        reshuffle_pytree,
+        reshuffle_pytree_along_dims,
+        reshuffle_pytree_independently,
         slice_array,
         slice_pytree,
+        update_pytree,
     )
     from xax.utils.text import (
         TextBlock,
```
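Two behavioral changes fall out of this diff: importing `xax` now sets `XLA_FLAGS` as a side effect when `nvidia-smi` is on the PATH, and each new symbol is reachable lazily from the top-level package via `NAME_MAP`. A minimal sketch of both (assuming xax 0.1.0 is installed; the lazy resolution mechanism itself is internal to the package):

```python
import os

import xax  # importing mutates os.environ["XLA_FLAGS"] per the logic above

# On machines with `nvidia-smi`, the latency-hiding-scheduler and Triton GEMM
# flags were appended; otherwise this holds whatever XLA_FLAGS already had
# (possibly an empty string).
print(os.environ.get("XLA_FLAGS"))

# New top-level names resolve through NAME_MAP on first access.
print(xax.cast_norm_type("l2"))  # imported lazily from xax.nn.norm
```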
xax/nn/equinox.py
ADDED
```diff
@@ -0,0 +1,180 @@
+"""Equinox utilities."""
+
+import json
+import logging
+from pathlib import Path
+from typing import Callable, Literal, TypedDict, cast
+
+import equinox as eqx
+import jax
+from jaxtyping import PRNGKeyArray
+
+logger = logging.getLogger(__name__)
+
+ActivationFunction = Literal[
+    "relu",
+    "tanh",
+    "celu",
+    "elu",
+    "gelu",
+    "glu",
+    "hard_sigmoid",
+    "hard_silu",
+    "hard_swish",
+    "hard_tanh",
+    "leaky_relu",
+    "log_sigmoid",
+    "log_softmax",
+    "logsumexp",
+    "relu6",
+    "selu",
+    "sigmoid",
+    "soft_sign",
+    "softmax",
+    "softplus",
+    "sparse_plus",
+    "sparse_sigmoid",
+    "silu",
+    "swish",
+    "squareplus",
+    "mish",
+    "identity",
+]
+
+DTYPE = Literal["float32", "float64"]
+
+DTYPE_MAP: dict[DTYPE, jax.numpy.dtype] = {
+    "float32": jax.numpy.float32,
+    "float64": jax.numpy.float64,
+}
+
+
+class MLPHyperParams(TypedDict):
+    """Hyperparameters of an Equinox MLP."""
+
+    in_size: int | Literal["scalar"]
+    out_size: int | Literal["scalar"]
+    width_size: int
+    depth: int
+    activation: ActivationFunction
+    final_activation: ActivationFunction
+    use_bias: bool
+    use_final_bias: bool
+    dtype: DTYPE
+
+
+def _infer_activation(activation: ActivationFunction) -> Callable:
+    if activation == "identity":
+        return lambda x: x
+    try:
+        return getattr(jax.nn, activation)
+    except AttributeError:
+        raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
+
+
+def make_eqx_mlp(hyperparams: MLPHyperParams, key: PRNGKeyArray = jax.random.PRNGKey(0)) -> eqx.nn.MLP:
+    """Create an Equinox MLP from a set of hyperparameters.
+
+    Args:
+        hyperparams: The hyperparameters of the MLP.
+        key: The PRNG key to use for the MLP.
+    """
+    activation = _infer_activation(hyperparams["activation"])
+    final_activation = _infer_activation(hyperparams["final_activation"])
+    dtype = DTYPE_MAP[hyperparams["dtype"]]
+
+    return eqx.nn.MLP(
+        in_size=hyperparams["in_size"],
+        out_size=hyperparams["out_size"],
+        width_size=hyperparams["width_size"],
+        depth=hyperparams["depth"],
+        activation=activation,
+        final_activation=final_activation,
+        use_bias=hyperparams["use_bias"],
+        use_final_bias=hyperparams["use_final_bias"],
+        dtype=dtype,
+        key=key,
+    )
+
+
+def export_eqx_mlp(
+    model: eqx.nn.MLP,
+    output_path: str | Path,
+    dtype: jax.numpy.dtype = eqx._misc.default_floating_dtype(),
+) -> None:
+    """Serialize an Equinox MLP to a .eqx file.
+
+    Args:
+        model: The JAX MLP to export.
+        output_path: The path to save the exported model.
+        dtype: The dtype of the model.
+    """
+    activation = model.activation.__name__
+    final_activation = model.final_activation.__name__
+
+    if final_activation == "<lambda>":
+        logger.warning("Final activation is a lambda function. Assuming identity.")
+        final_activation = "identity"
+
+    # cast strings to ActivationFunction for type checking
+    activation = cast(ActivationFunction, activation)
+    final_activation = cast(ActivationFunction, final_activation)
+
+    if dtype not in DTYPE_MAP.values():
+        raise ValueError(f"Invalid dtype: {dtype}. Must be one of {DTYPE_MAP.values()}")
+
+    dtype = {v: k for k, v in DTYPE_MAP.items()}[dtype]
+
+    hyperparams: MLPHyperParams = {
+        "in_size": model.in_size,
+        "out_size": model.out_size,
+        "width_size": model.width_size,
+        "depth": model.depth,
+        "activation": activation,
+        "final_activation": final_activation,
+        "use_bias": model.use_bias,
+        "use_final_bias": model.use_final_bias,
+        "dtype": dtype,
+    }
+
+    with open(output_path, "wb") as f:
+        hyperparam_str = json.dumps(hyperparams)
+        f.write((hyperparam_str + "\n").encode(encoding="utf-8"))
+        eqx.tree_serialise_leaves(f, model)
+
+
+def save_eqx(
+    model: eqx.Module,
+    output_path: str | Path,
+) -> None:
+    """Serialize an Equinox module to a .eqx file.
+
+    Args:
+        model: The Equinox module to export.
+        output_path: The path to save the exported model.
+    """
+    with open(output_path, "wb") as f:
+        eqx.tree_serialise_leaves(f, model)
+
+
+def load_eqx(
+    model: eqx.Module,
+    eqx_file: str | Path,
+) -> eqx.Module:
+    """Deserialize an Equinox module from a .eqx file.
+
+    Args:
+        model: The Equinox module to load into.
+        eqx_file: The path to the .eqx file to load.
+    """
+    with open(eqx_file, "rb") as f:
+        return eqx.tree_deserialise_leaves(f, model)
+
+
+def load_eqx_mlp(
+    eqx_file: str | Path,
+) -> eqx.nn.MLP:
+    with open(eqx_file, "rb") as f:
+        hyperparams = json.loads(f.readline().decode(encoding="utf-8"))
+        model = make_eqx_mlp(hyperparams=hyperparams)
+        return eqx.tree_deserialise_leaves(f, model)
```
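Taken together these helpers make an MLP checkpoint self-describing: `export_eqx_mlp` writes the hyperparameters as a JSON header line ahead of the serialized leaves, and `load_eqx_mlp` rebuilds the skeleton from that header before deserializing into it. A minimal round-trip sketch (assuming xax 0.1.0 and Equinox are installed; `mlp.eqx` is an arbitrary path):

```python
import jax
import xax

# Build an MLP from hyperparameters, serialize it together with its
# hyperparameters, then reconstruct it from the file alone.
hyperparams: xax.MLPHyperParams = {
    "in_size": 3,
    "out_size": 2,
    "width_size": 32,
    "depth": 2,
    "activation": "relu",
    "final_activation": "identity",
    "use_bias": True,
    "use_final_bias": True,
    "dtype": "float32",
}
model = xax.make_eqx_mlp(hyperparams, key=jax.random.PRNGKey(0))
xax.export_eqx_mlp(model, "mlp.eqx")

# load_eqx_mlp reads the JSON header line, rebuilds the skeleton with
# make_eqx_mlp, then deserializes the saved leaves into it.
restored = xax.load_eqx_mlp("mlp.eqx")
out = restored(jax.numpy.ones(3))  # same weights as `model`
```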
xax/nn/export.py
ADDED
```diff
@@ -0,0 +1,147 @@
+"""Export JAX functions to TensorFlow SavedModel format."""
+
+import logging
+from pathlib import Path
+from typing import Callable
+
+import flax
+import jax
+import tensorflow as tf
+from jax.experimental import jax2tf
+from jaxtyping import Array, PyTree
+from orbax.export import ExportManager, JaxModule, ServingConfig
+
+logger = logging.getLogger(__name__)
+
+
+def _run_infer(tf_module: tf.Module, input_shapes: list[tuple[int, ...]], batch_size: int | None) -> tf.Tensor:
+    """Warm up the model by running it once."""
+    if batch_size is not None:
+        test_inputs = [
+            jax.random.normal(jax.random.PRNGKey(42), (batch_size, *input_shape)) for input_shape in input_shapes
+        ]
+    else:
+        test_inputs = [jax.random.normal(jax.random.PRNGKey(42), (1, *input_shape)) for input_shape in input_shapes]
+    if not hasattr(tf_module, "infer"):
+        raise ValueError("Model does not have an infer method")
+    return tf_module.infer(*test_inputs)
+
+
+def export(
+    model: Callable,
+    input_shapes: list[tuple[int, ...]],
+    output_dir: str | Path = "export",
+    batch_size: int | None = None,
+) -> None:
+    """Export a JAX function to TensorFlow SavedModel.
+
+    Note: Tensorflow GraphDef can't be larger than 2GB - https://github.com/tensorflow/tensorflow/issues/51870
+    You can avoid this by saving model parameters as non-constants.
+
+    Args:
+        model: The JAX function to export.
+        input_shapes: The shape of the input tensors, excluding batch dimension.
+        output_dir: Directory to save the exported model.
+        batch_size: Optional batch dimension. If None, a polymorphic batch dimension is used.
+    """
+    tf_module = tf.Module()
+    # Create a polymorphic shape specification for each input
+    poly_spec = "(b, ...)" if batch_size is not None else "(None, ...)"
+    polymorphic_shapes = [poly_spec] * len(input_shapes)
+    tf_module.infer = tf.function(  # type: ignore [attr-defined]
+        jax2tf.convert(
+            model,
+            polymorphic_shapes=polymorphic_shapes,
+            # setting this to False will allow the model to run on platforms other than the one that exports the model
+            # https://github.com/jax-ml/jax/blob/051687dc4c899df3d95c30b812ade401d8b31166/jax/experimental/jax2tf/README.md?plain=1#L1342
+            # generally though I think native_serialization is recommended
+            native_serialization=False,
+            with_gradient=False,
+        ),
+        autograph=False,
+        input_signature=[tf.TensorSpec([batch_size] + list(input_shape), tf.float32) for input_shape in input_shapes],
+    )
+
+    # warm up the model
+    _run_infer(tf_module, input_shapes, batch_size)
+
+    logger.info("Exporting SavedModel to %s", output_dir)
+    tf.saved_model.save(
+        tf_module,
+        output_dir,
+    )
+
+
+def export_with_params(
+    model: Callable,
+    params: PyTree,
+    input_shapes: list[tuple[int, ...]],
+    output_dir: str | Path = "export",
+    batch_dim: int | None = None,
+) -> None:
+    """Export a JAX function that takes parameters to TensorFlow SavedModel.
+
+    Args:
+        model: The JAX function to export. Should take parameters as first argument.
+        params: The parameters to use for the model.
+        input_shapes: The shape of the input tensors, excluding batch dimension.
+        output_dir: Directory to save the exported model.
+        batch_dim: Optional batch dimension. If None, a polymorphic batch dimension is used.
+    """
+    param_vars = tf.nest.map_structure(tf.Variable, params)
+
+    converted_model = jax2tf.convert(model)
+
+    def model_fn(*inputs: PyTree) -> Array:
+        return converted_model(param_vars, *inputs)
+
+    tf_module = tf.Module()
+    tf_module._variables = tf.nest.flatten(param_vars)  # type: ignore [attr-defined]
+    tf_module.infer = tf.function(  # type: ignore [attr-defined]
+        model_fn,
+        jit_compile=True,
+        autograph=False,
+        input_signature=[tf.TensorSpec([batch_dim] + list(input_shape), tf.float32) for input_shape in input_shapes],
+    )
+
+    # warm up the model
+    _run_infer(tf_module, input_shapes, batch_dim)
+
+    logger.info("Exporting SavedModel to %s", output_dir)
+    tf.saved_model.save(tf_module, output_dir)
+
+
+def export_flax(
+    model: flax.linen.Module,
+    params: PyTree,
+    input_shape: tuple[int, ...],
+    preprocessor: Callable | None = None,
+    postprocessor: Callable | None = None,
+    input_name: str = "inputs",
+    output_name: str = "outputs",
+    output_dir: str | Path = "export",
+) -> None:
+    jax_module = JaxModule(
+        params, model.apply, trainable=False, input_polymorphic_shape="(b, ...)"
+    )  # if you want to use a batch dimension
+
+    # to avoid mapping sequences to ambiguous mappings
+    if postprocessor is None:
+
+        def postprocessor(x: PyTree) -> PyTree:
+            return {output_name: x}
+
+    export_manager = ExportManager(
+        jax_module,
+        [
+            ServingConfig(
+                "serving_default",
+                input_signature=[tf.TensorSpec([None] + list(input_shape), tf.float32, name=input_name)],
+                tf_preprocessor=preprocessor,
+                tf_postprocessor=postprocessor,
+            )
+        ],
+    )
+
+    logger.info("Exporting model to %s", output_dir)
+    export_manager.save(output_dir)
```
xax/nn/geom.py
CHANGED
```diff
@@ -73,3 +73,29 @@ def euler_to_quat(euler_3: jax.Array) -> jax.Array:
     quat = quat / jnp.linalg.norm(quat, axis=-1, keepdims=True)
 
     return quat
+
+
+def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:
+    """Calculates the gravity vector projected onto the local frame given a quaternion orientation.
+
+    Args:
+        quat: A quaternion (w,x,y,z) representing the orientation, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        A 3D vector representing the gravity in the local frame, shape (*, 3).
+    """
+    # Normalize quaternion
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    # Gravity vector in world frame is [0, 0, -1] (pointing down)
+    # Rotate gravity vector using quaternion rotation
+
+    # Calculate quaternion rotation: q * [0,0,-1] * q^-1
+    gx = 2 * (x * z - w * y)
+    gy = 2 * (y * z + w * x)
+    gz = w * w - x * x - y * y + z * z
+
+    # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
+    return jnp.concatenate([gx, gy, -gz], axis=-1)
```
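As a quick sanity check of the (w, x, y, z) convention, a sketch whose expected output follows directly from the formulas above:

```python
import jax.numpy as jnp

import xax

# Identity orientation: the local frame coincides with the world frame, so
# the projected gravity vector is just the world-frame gravity [0, 0, -1].
identity = jnp.array([1.0, 0.0, 0.0, 0.0])
print(xax.get_projected_gravity_vector_from_quat(identity))  # ~[0. 0. -1.]
```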
xax/nn/norm.py
ADDED
```diff
@@ -0,0 +1,23 @@
+"""Normalization utilities."""
+
+from typing import Literal, cast, get_args
+
+import jax.numpy as jnp
+
+NormType = Literal["l1", "l2"]
+
+
+def cast_norm_type(norm: str) -> NormType:
+    if norm not in get_args(NormType):
+        raise ValueError(f"Invalid norm: {norm}")
+    return cast(NormType, norm)
+
+
+def get_norm(x: jnp.ndarray, norm: NormType) -> jnp.ndarray:
+    match norm:
+        case "l1":
+            return jnp.abs(x)
+        case "l2":
+            return jnp.square(x)
+        case _:
+            raise ValueError(f"Invalid norm: {norm}")
```
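Note that `get_norm` returns the elementwise penalty (|x| for "l1", x^2 for "l2") rather than a reduced norm; any summing is left to the caller. A small sketch:

```python
import jax.numpy as jnp

import xax

x = jnp.array([-3.0, 4.0])
print(xax.get_norm(x, "l1"))                       # [3. 4.]
print(xax.get_norm(x, xax.cast_norm_type("l2")))   # [9. 16.]
```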
xax/requirements.txt
CHANGED
xax/task/base.py
CHANGED
```diff
@@ -81,6 +81,12 @@ class BaseTask(Generic[Config]):
     def on_training_end(self, state: State) -> State:
         return state
 
+    def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
+        return state
+
+    def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
+        pass
+
     @functools.cached_property
     def task_class_name(self) -> str:
         return self.__class__.__name__
```
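These hooks give a task a place to react to checkpointing without overriding the checkpointing mixin itself. A sketch of an override (assuming `State` lives in `xax.core.state`, as the `Phase` mapping above suggests, and using the base `Config` as a stand-in type parameter):

```python
from pathlib import Path

import xax
from xax.core.state import State  # assumed location, per NAME_MAP above


class MyTask(xax.Task[xax.Config]):
    def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
        # For example, mirror the checkpoint or record its location.
        print(f"Saved checkpoint to {ckpt_path}")
        return state

    def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
        print(f"Loading checkpoint from {ckpt_path}")
```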