xax 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -4,14 +4,15 @@ This package is structured so that all the important stuff can be accessed
4
4
  without having to dig around through the internals. This is done by lazily
5
5
  importing the module by name.
6
6
 
7
- This file can be maintained by running the update script:
7
+ This file can be maintained by updating the imports at the bottom of the file
8
+ and running the update script:
8
9
 
9
10
  .. code-block:: bash
10
11
 
11
12
  python -m scripts.update_api --inplace
12
13
  """
13
14
 
14
- __version__ = "0.0.6"
15
+ __version__ = "0.1.0"
15
16
 
16
17
  # This list shouldn't be modified by hand; instead, run the update script.
17
18
  __all__ = [
@@ -34,6 +35,20 @@ __all__ = [
34
35
  "get_positional_embeddings",
35
36
  "get_rotary_embeddings",
36
37
  "rotary_embeddings",
38
+ "MLPHyperParams",
39
+ "export_eqx_mlp",
40
+ "load_eqx",
41
+ "load_eqx_mlp",
42
+ "make_eqx_mlp",
43
+ "save_eqx",
44
+ "export",
45
+ "export_flax",
46
+ "export_with_params",
47
+ "euler_to_quat",
48
+ "get_projected_gravity_vector_from_quat",
49
+ "quat_to_euler",
50
+ "cast_norm_type",
51
+ "get_norm",
37
52
  "is_master",
38
53
  "BaseLauncher",
39
54
  "CliLauncher",
@@ -50,13 +65,16 @@ __all__ = [
50
65
  "CPUStatsOptions",
51
66
  "DataloaderConfig",
52
67
  "GPUStatsOptions",
68
+ "StepContext",
53
69
  "Script",
54
70
  "ScriptConfig",
55
71
  "Config",
56
72
  "Task",
57
73
  "collate",
58
74
  "collate_non_null",
75
+ "get_named_leaves",
59
76
  "BaseFileDownloader",
77
+ "ContextTimer",
60
78
  "CumulativeTimer",
61
79
  "DataDownloader",
62
80
  "IntervalTicker",
@@ -78,11 +96,24 @@ __all__ = [
78
96
  "save_config",
79
97
  "stage_environment",
80
98
  "to_markdown_table",
99
+ "jit",
100
+ "save_jaxpr_dot",
81
101
  "ColoredFormatter",
82
102
  "configure_logging",
83
103
  "one_hot",
84
104
  "partial_flatten",
85
105
  "worker_chunk",
106
+ "profile",
107
+ "compute_nan_ratio",
108
+ "flatten_array",
109
+ "flatten_pytree",
110
+ "pytree_has_nans",
111
+ "reshuffle_pytree",
112
+ "reshuffle_pytree_along_dims",
113
+ "reshuffle_pytree_independently",
114
+ "slice_array",
115
+ "slice_pytree",
116
+ "update_pytree",
86
117
  "TextBlock",
87
118
  "camelcase_to_snakecase",
88
119
  "colored",
@@ -104,21 +135,36 @@ __all__ += [
104
135
  "Batch",
105
136
  "CollateMode",
106
137
  "EmbeddingKind",
138
+ "ActivationFunction",
139
+ "DTYPE",
107
140
  "LOG_ERROR_SUMMARY",
108
141
  "LOG_PING",
109
142
  "LOG_STATUS",
143
+ "NormType",
110
144
  "Output",
111
145
  "Phase",
112
146
  "RawConfigType",
113
147
  ]
114
148
 
115
149
  import os
150
+ import shutil
116
151
  from typing import TYPE_CHECKING
117
152
 
153
# Sets some useful XLA flags, keeping anything the user already put in `XLA_FLAGS`.
xla_flags: list[str] = []
if "XLA_FLAGS" in os.environ:
    xla_flags.append(os.environ["XLA_FLAGS"])

# If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), enable the
# latency-hiding scheduler and disable Triton GEMM kernels, following
# https://github.com/NVIDIA/JAX-Toolbox. A bare boolean flag such as
# `--xla_gpu_enable_triton_gemm` switches the option ON, so every value is
# spelled out explicitly. A default is only added when the user has not already
# set that flag, so explicit user choices always win. (A list comprehension is
# used so no loop variables leak into the package namespace.)
if shutil.which("nvidia-smi") is not None:
    xla_flags += [
        f"{flag_name}={flag_value}"
        for flag_name, flag_value in (
            ("--xla_gpu_enable_latency_hiding_scheduler", "true"),
            ("--xla_gpu_enable_triton_gemm", "false"),
        )
        if all(token.split("=", 1)[0] != flag_name for token in " ".join(xla_flags).split())
    ]
os.environ["XLA_FLAGS"] = " ".join(xla_flags)
163
+
118
164
  # If this flag is set, eagerly imports the entire package (not recommended).
119
165
  IMPORT_ALL = int(os.environ.get("XAX_IMPORT_ALL", "0")) != 0
120
166
 
121
- del os
167
+ del os, shutil, xla_flags
122
168
 
123
169
  # This dictionary is auto-generated and shouldn't be modified by hand; instead,
124
170
  # run the update script.
@@ -142,6 +188,20 @@ NAME_MAP: dict[str, str] = {
142
188
  "get_positional_embeddings": "nn.embeddings",
143
189
  "get_rotary_embeddings": "nn.embeddings",
144
190
  "rotary_embeddings": "nn.embeddings",
191
+ "MLPHyperParams": "nn.equinox",
192
+ "export_eqx_mlp": "nn.equinox",
193
+ "load_eqx": "nn.equinox",
194
+ "load_eqx_mlp": "nn.equinox",
195
+ "make_eqx_mlp": "nn.equinox",
196
+ "save_eqx": "nn.equinox",
197
+ "export": "nn.export",
198
+ "export_flax": "nn.export",
199
+ "export_with_params": "nn.export",
200
+ "euler_to_quat": "nn.geom",
201
+ "get_projected_gravity_vector_from_quat": "nn.geom",
202
+ "quat_to_euler": "nn.geom",
203
+ "cast_norm_type": "nn.norm",
204
+ "get_norm": "nn.norm",
145
205
  "is_master": "nn.parallel",
146
206
  "BaseLauncher": "task.launchers.base",
147
207
  "CliLauncher": "task.launchers.cli",
@@ -158,13 +218,16 @@ NAME_MAP: dict[str, str] = {
158
218
  "CPUStatsOptions": "task.mixins.cpu_stats",
159
219
  "DataloaderConfig": "task.mixins.data_loader",
160
220
  "GPUStatsOptions": "task.mixins.gpu_stats",
221
+ "StepContext": "task.mixins.step_wrapper",
161
222
  "Script": "task.script",
162
223
  "ScriptConfig": "task.script",
163
224
  "Config": "task.task",
164
225
  "Task": "task.task",
165
226
  "collate": "utils.data.collate",
166
227
  "collate_non_null": "utils.data.collate",
228
+ "get_named_leaves": "utils.debugging",
167
229
  "BaseFileDownloader": "utils.experiments",
230
+ "ContextTimer": "utils.experiments",
168
231
  "CumulativeTimer": "utils.experiments",
169
232
  "DataDownloader": "utils.experiments",
170
233
  "IntervalTicker": "utils.experiments",
@@ -186,11 +249,24 @@ NAME_MAP: dict[str, str] = {
186
249
  "save_config": "utils.experiments",
187
250
  "stage_environment": "utils.experiments",
188
251
  "to_markdown_table": "utils.experiments",
252
+ "jit": "utils.jax",
253
+ "save_jaxpr_dot": "utils.jaxpr",
189
254
  "ColoredFormatter": "utils.logging",
190
255
  "configure_logging": "utils.logging",
191
256
  "one_hot": "utils.numpy",
192
257
  "partial_flatten": "utils.numpy",
193
258
  "worker_chunk": "utils.numpy",
259
+ "profile": "utils.profile",
260
+ "compute_nan_ratio": "utils.pytree",
261
+ "flatten_array": "utils.pytree",
262
+ "flatten_pytree": "utils.pytree",
263
+ "pytree_has_nans": "utils.pytree",
264
+ "reshuffle_pytree": "utils.pytree",
265
+ "reshuffle_pytree_along_dims": "utils.pytree",
266
+ "reshuffle_pytree_independently": "utils.pytree",
267
+ "slice_array": "utils.pytree",
268
+ "slice_pytree": "utils.pytree",
269
+ "update_pytree": "utils.pytree",
194
270
  "TextBlock": "utils.text",
195
271
  "camelcase_to_snakecase": "utils.text",
196
272
  "colored": "utils.text",
@@ -217,9 +293,12 @@ NAME_MAP.update(
217
293
  "LOG_ERROR_SUMMARY": "utils.logging",
218
294
  "LOG_PING": "utils.logging",
219
295
  "LOG_STATUS": "utils.logging",
296
+ "NormType": "nn.norm",
220
297
  "Output": "task.mixins.output",
221
298
  "Phase": "core.state",
222
299
  "RawConfigType": "task.base",
300
+ "ActivationFunction": "nn.equinox",
301
+ "DTYPE": "nn.equinox",
223
302
  },
224
303
  )
225
304
 
@@ -257,6 +336,27 @@ if IMPORT_ALL or TYPE_CHECKING:
257
336
  get_rotary_embeddings,
258
337
  rotary_embeddings,
259
338
  )
339
+ from xax.nn.equinox import (
340
+ DTYPE,
341
+ ActivationFunction,
342
+ MLPHyperParams,
343
+ export_eqx_mlp,
344
+ load_eqx,
345
+ load_eqx_mlp,
346
+ make_eqx_mlp,
347
+ save_eqx,
348
+ )
349
+ from xax.nn.export import (
350
+ export,
351
+ export_flax,
352
+ export_with_params,
353
+ )
354
+ from xax.nn.geom import (
355
+ euler_to_quat,
356
+ get_projected_gravity_vector_from_quat,
357
+ quat_to_euler,
358
+ )
359
+ from xax.nn.norm import NormType, cast_norm_type, get_norm
260
360
  from xax.nn.parallel import is_master
261
361
  from xax.task.base import RawConfigType
262
362
  from xax.task.launchers.base import BaseLauncher
@@ -271,12 +371,15 @@ if IMPORT_ALL or TYPE_CHECKING:
271
371
  from xax.task.mixins.cpu_stats import CPUStatsOptions
272
372
  from xax.task.mixins.data_loader import DataloaderConfig
273
373
  from xax.task.mixins.gpu_stats import GPUStatsOptions
374
+ from xax.task.mixins.step_wrapper import StepContext
274
375
  from xax.task.mixins.train import Batch, Output
275
376
  from xax.task.script import Script, ScriptConfig
276
377
  from xax.task.task import Config, Task
277
378
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
379
+ from xax.utils.debugging import get_named_leaves
278
380
  from xax.utils.experiments import (
279
381
  BaseFileDownloader,
382
+ ContextTimer,
280
383
  CumulativeTimer,
281
384
  DataDownloader,
282
385
  IntervalTicker,
@@ -299,6 +402,8 @@ if IMPORT_ALL or TYPE_CHECKING:
299
402
  stage_environment,
300
403
  to_markdown_table,
301
404
  )
405
+ from xax.utils.jax import jit
406
+ from xax.utils.jaxpr import save_jaxpr_dot
302
407
  from xax.utils.logging import (
303
408
  LOG_ERROR_SUMMARY,
304
409
  LOG_PING,
@@ -307,6 +412,19 @@ if IMPORT_ALL or TYPE_CHECKING:
307
412
  configure_logging,
308
413
  )
309
414
  from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
415
+ from xax.utils.profile import profile
416
+ from xax.utils.pytree import (
417
+ compute_nan_ratio,
418
+ flatten_array,
419
+ flatten_pytree,
420
+ pytree_has_nans,
421
+ reshuffle_pytree,
422
+ reshuffle_pytree_along_dims,
423
+ reshuffle_pytree_independently,
424
+ slice_array,
425
+ slice_pytree,
426
+ update_pytree,
427
+ )
310
428
  from xax.utils.text import (
311
429
  TextBlock,
312
430
  camelcase_to_snakecase,
xax/nn/equinox.py ADDED
@@ -0,0 +1,180 @@
1
+ """Equinox utilities."""
2
+
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Callable, Literal, TypedDict, cast
7
+
8
+ import equinox as eqx
9
+ import jax
10
+ from jaxtyping import PRNGKeyArray
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
# Activation names accepted in `MLPHyperParams`. "identity" is special-cased;
# every other entry is expected to name a function in `jax.nn`.
ActivationFunction = Literal[
    "relu",
    "tanh",
    "celu",
    "elu",
    "gelu",
    "glu",
    "hard_sigmoid",
    "hard_silu",
    "hard_swish",
    "hard_tanh",
    "leaky_relu",
    "log_sigmoid",
    "log_softmax",
    "logsumexp",
    "relu6",
    "selu",
    "sigmoid",
    "soft_sign",
    "softmax",
    "softplus",
    "sparse_plus",
    "sparse_sigmoid",
    "silu",
    "swish",
    "squareplus",
    "mish",
    "identity",
]

# Names of the floating-point precisions an MLP may be built with.
DTYPE = Literal["float32", "float64"]

# Lookup from each `DTYPE` name to the corresponding JAX scalar type.
DTYPE_MAP: dict[DTYPE, jax.numpy.dtype] = dict(
    float32=jax.numpy.float32,
    float64=jax.numpy.float64,
)
50
+
51
+
52
class MLPHyperParams(TypedDict):
    """Hyperparameters of an Equinox MLP.

    Each key is forwarded to the `eqx.nn.MLP` constructor argument of the same
    name when the model is built. `activation`, `final_activation` and `dtype`
    are stored as string names rather than objects, so the whole dict can be
    written as JSON.
    """

    in_size: int | Literal["scalar"]  # Input size, or the literal "scalar"; passed through unchanged.
    out_size: int | Literal["scalar"]  # Output size, or the literal "scalar"; passed through unchanged.
    width_size: int  # Passed to `eqx.nn.MLP` as `width_size`.
    depth: int  # Passed to `eqx.nn.MLP` as `depth`.
    activation: ActivationFunction  # Name of the hidden-layer activation, resolved to a callable at build time.
    final_activation: ActivationFunction  # Name of the output-layer activation, resolved at build time.
    use_bias: bool  # Passed to `eqx.nn.MLP` as `use_bias`.
    use_final_bias: bool  # Passed to `eqx.nn.MLP` as `use_final_bias`.
    dtype: DTYPE  # Parameter dtype name; a key of `DTYPE_MAP`.
64
+
65
+
66
def _infer_activation(activation: ActivationFunction) -> Callable:
    """Resolve an activation-function name to a callable.

    Args:
        activation: Either "identity" or the name of a function in `jax.nn`.

    Returns:
        The matching callable. For "identity" this is a named function whose
        `__name__` is "identity" (not a lambda), so the name survives a round
        trip through `__name__`-based serialization such as `export_eqx_mlp`.

    Raises:
        ValueError: If `activation` is not an attribute of `jax.nn`.
    """
    if activation == "identity":

        # A `def` rather than a lambda: a lambda's `__name__` is "<lambda>",
        # which cannot be mapped back to an activation name when loading.
        def identity(x: jax.Array) -> jax.Array:
            return x

        return identity
    try:
        return getattr(jax.nn, activation)
    except AttributeError as err:
        raise ValueError(f"Activation function `{activation}` not found in `jax.nn`") from err
73
+
74
+
75
def make_eqx_mlp(hyperparams: MLPHyperParams, key: PRNGKeyArray | None = None) -> eqx.nn.MLP:
    """Create an Equinox MLP from a set of hyperparameters.

    Args:
        hyperparams: The hyperparameters of the MLP. Activation and dtype
            entries are names that are resolved here to concrete objects.
        key: The PRNG key used to initialize the MLP weights. When omitted,
            `jax.random.PRNGKey(0)` is used.

    Returns:
        A newly initialized `eqx.nn.MLP`.

    Raises:
        ValueError: If an activation name cannot be resolved.
        KeyError: If `hyperparams` lacks a field or names an unknown dtype.
    """
    # The default key is created per call instead of in the signature: a
    # default expression is evaluated once at import time, and creating a JAX
    # array there forces a backend to initialize merely by importing the module.
    if key is None:
        key = jax.random.PRNGKey(0)

    activation = _infer_activation(hyperparams["activation"])
    final_activation = _infer_activation(hyperparams["final_activation"])
    dtype = DTYPE_MAP[hyperparams["dtype"]]

    return eqx.nn.MLP(
        in_size=hyperparams["in_size"],
        out_size=hyperparams["out_size"],
        width_size=hyperparams["width_size"],
        depth=hyperparams["depth"],
        activation=activation,
        final_activation=final_activation,
        use_bias=hyperparams["use_bias"],
        use_final_bias=hyperparams["use_final_bias"],
        dtype=dtype,
        key=key,
    )
98
+
99
+
100
def export_eqx_mlp(
    model: eqx.nn.MLP,
    output_path: str | Path,
    dtype: jax.numpy.dtype | None = None,
) -> None:
    """Serialize an Equinox MLP to a .eqx file.

    The file starts with one UTF-8 line of JSON holding the `MLPHyperParams`
    needed to rebuild the model. The model leaves follow, written with
    `eqx.tree_serialise_leaves`.

    Args:
        model: The Equinox MLP to export.
        output_path: The path to save the exported model.
        dtype: The dtype recorded in the hyperparameters. Any dtype-like value
            equivalent to an entry of `DTYPE_MAP` is accepted, for example
            `jnp.float32` or `jnp.dtype("float32")`. Defaults to Equinox's
            default floating dtype, looked up when this function is called.

    Raises:
        ValueError: If `dtype` is not one of the supported dtypes.
    """
    if dtype is None:
        # Looked up per call rather than in the signature, so the default is
        # not frozen at import time.
        dtype = eqx._misc.default_floating_dtype()

    activation = model.activation.__name__
    final_activation = model.final_activation.__name__

    # A lambda has no recoverable name, so it is assumed to be the identity.
    # Both slots are handled the same way; otherwise "<lambda>" would be
    # written to the header and the file could not be loaded.
    if activation == "<lambda>":
        logger.warning("Activation is a lambda function. Assuming identity.")
        activation = "identity"
    if final_activation == "<lambda>":
        logger.warning("Final activation is a lambda function. Assuming identity.")
        final_activation = "identity"

    # cast strings to ActivationFunction for type checking
    activation = cast(ActivationFunction, activation)
    final_activation = cast(ActivationFunction, final_activation)

    try:
        requested = jax.numpy.dtype(dtype)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Invalid dtype: {dtype}. Must be one of {DTYPE_MAP.values()}") from err

    # Compare normalized dtypes instead of using `dtype` as a dict key. A
    # scalar type such as `jnp.float32` and the dtype `jnp.dtype("float32")`
    # compare equal but hash differently, so a reverse-dict lookup would
    # raise KeyError for the latter.
    dtype_name = next(
        (name for name, candidate in DTYPE_MAP.items() if jax.numpy.dtype(candidate) == requested),
        None,
    )
    if dtype_name is None:
        raise ValueError(f"Invalid dtype: {dtype}. Must be one of {DTYPE_MAP.values()}")

    hyperparams: MLPHyperParams = {
        "in_size": model.in_size,
        "out_size": model.out_size,
        "width_size": model.width_size,
        "depth": model.depth,
        "activation": activation,
        "final_activation": final_activation,
        "use_bias": model.use_bias,
        "use_final_bias": model.use_final_bias,
        "dtype": dtype_name,
    }

    with open(output_path, "wb") as f:
        # One-line JSON header (json.dumps emits no raw newlines by default),
        # then the binary leaves.
        hyperparam_str = json.dumps(hyperparams)
        f.write((hyperparam_str + "\n").encode(encoding="utf-8"))
        eqx.tree_serialise_leaves(f, model)
144
+
145
+
146
def save_eqx(
    model: eqx.Module,
    output_path: str | Path,
) -> None:
    """Write the leaves of an Equinox module to a .eqx file.

    Only the leaves are stored, via `eqx.tree_serialise_leaves`. Reading them
    back needs a template module of the same structure (see `load_eqx`).

    Args:
        model: The Equinox module to export.
        output_path: Destination file; it is created or overwritten.
    """
    destination = Path(output_path)
    with destination.open("wb") as sink:
        eqx.tree_serialise_leaves(sink, model)
158
+
159
+
160
def load_eqx(
    model: eqx.Module,
    eqx_file: str | Path,
) -> eqx.Module:
    """Fill a template Equinox module with leaves read from a .eqx file.

    Args:
        model: A module with the same structure as the one that was saved;
            its leaves are replaced by the stored values.
        eqx_file: The path to the .eqx file to load.

    Returns:
        The module returned by `eqx.tree_deserialise_leaves`.
    """
    source_path = Path(eqx_file)
    with source_path.open("rb") as source:
        restored = eqx.tree_deserialise_leaves(source, model)
    return restored
172
+
173
+
174
def load_eqx_mlp(
    eqx_file: str | Path,
) -> eqx.nn.MLP:
    """Load an MLP from a .eqx file in the format written by `export_eqx_mlp`.

    The first line of the file is decoded as UTF-8 JSON hyperparameters. These
    are used to build a template MLP with `make_eqx_mlp`, whose leaves are then
    filled from the rest of the file.

    Args:
        eqx_file: The path to the .eqx file to load.

    Returns:
        The reconstructed `eqx.nn.MLP`.
    """
    with open(eqx_file, "rb") as stream:
        header_text = stream.readline().decode(encoding="utf-8")
        skeleton = make_eqx_mlp(hyperparams=json.loads(header_text))
        return eqx.tree_deserialise_leaves(stream, skeleton)
xax/nn/export.py ADDED
@@ -0,0 +1,147 @@
1
+ """Export JAX functions to TensorFlow SavedModel format."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Callable
6
+
7
+ import flax
8
+ import jax
9
+ import tensorflow as tf
10
+ from jax.experimental import jax2tf
11
+ from jaxtyping import Array, PyTree
12
+ from orbax.export import ExportManager, JaxModule, ServingConfig
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def _run_infer(tf_module: tf.Module, input_shapes: list[tuple[int, ...]], batch_size: int | None) -> tf.Tensor:
    """Call `tf_module.infer` once on random inputs to warm it up.

    Every input is drawn from the same fixed PRNG key (42). Its leading
    dimension is `batch_size`, or 1 when `batch_size` is None.

    Raises:
        ValueError: If `tf_module` has no `infer` attribute.
    """
    leading_dim = 1 if batch_size is None else batch_size
    rng = jax.random.PRNGKey(42)
    test_inputs = [jax.random.normal(rng, (leading_dim, *shape)) for shape in input_shapes]
    if not hasattr(tf_module, "infer"):
        raise ValueError("Model does not have an infer method")
    return tf_module.infer(*test_inputs)
28
+
29
+
30
def export(
    model: Callable,
    input_shapes: list[tuple[int, ...]],
    output_dir: str | Path = "export",
    batch_size: int | None = None,
) -> None:
    """Export a JAX function to TensorFlow SavedModel.

    Note: Tensorflow GraphDef can't be larger than 2GB - https://github.com/tensorflow/tensorflow/issues/51870
    You can avoid this by saving model parameters as non-constants.

    Args:
        model: The JAX function to export. It is called with one float32
            array per entry of `input_shapes`.
        input_shapes: The shape of the input tensors, excluding batch dimension.
        output_dir: Directory to save the exported model.
        batch_size: Optional fixed batch dimension. If None, a polymorphic batch
            dimension is used; it is the symbolic size `b`, shared by all inputs.
    """
    tf_module = tf.Module()

    # A fixed batch size makes every input shape static, so no shape
    # polymorphism is requested. Without one, the leading dimension of every
    # input is the symbolic dimension `b`, which matches the `None` batch entry
    # in the input signature below.
    polymorphic_shapes: list[str] | None = None if batch_size is not None else ["(b, ...)"] * len(input_shapes)

    tf_module.infer = tf.function(  # type: ignore [attr-defined]
        jax2tf.convert(
            model,
            polymorphic_shapes=polymorphic_shapes,
            # setting this to False will allow the model to run on platforms other than the one that exports the model
            # https://github.com/jax-ml/jax/blob/051687dc4c899df3d95c30b812ade401d8b31166/jax/experimental/jax2tf/README.md?plain=1#L1342
            # generally though I think native_serialization is recommended
            native_serialization=False,
            with_gradient=False,
        ),
        autograph=False,
        input_signature=[tf.TensorSpec([batch_size, *input_shape], tf.float32) for input_shape in input_shapes],
    )

    # warm up the model
    _run_infer(tf_module, input_shapes, batch_size)

    logger.info("Exporting SavedModel to %s", output_dir)
    tf.saved_model.save(
        tf_module,
        output_dir,
    )
73
+
74
+
75
def export_with_params(
    model: Callable,
    params: PyTree,
    input_shapes: list[tuple[int, ...]],
    output_dir: str | Path = "export",
    batch_dim: int | None = None,
) -> None:
    """Export a JAX function that takes parameters to TensorFlow SavedModel.

    The parameters are stored as TensorFlow variables on the module rather
    than embedded in the graph as constants.

    Args:
        model: The JAX function to export. Should take parameters as first argument.
        params: The parameters to use for the model.
        input_shapes: The shape of the input tensors, excluding batch dimension.
        output_dir: Directory to save the exported model.
        batch_dim: Optional batch dimension. If None, a polymorphic batch dimension is used.
    """
    param_vars = tf.nest.map_structure(tf.Variable, params)

    # With `batch_dim=None` the input signature below has an unknown leading
    # dimension, and jax2tf refuses unknown dimensions unless a polymorphic
    # shape is declared for them. There is one entry per positional argument:
    # the parameters (first argument) are fully known, hence `None`; every
    # input shares the symbolic batch size `b`.
    polymorphic_shapes: list[str | None] | None
    if batch_dim is None:
        polymorphic_shapes = [None] + ["(b, ...)"] * len(input_shapes)
    else:
        polymorphic_shapes = None

    converted_model = jax2tf.convert(model, polymorphic_shapes=polymorphic_shapes)

    def model_fn(*inputs: PyTree) -> Array:
        return converted_model(param_vars, *inputs)

    tf_module = tf.Module()
    # Attach the variables so they are tracked and saved with the module.
    tf_module._variables = tf.nest.flatten(param_vars)  # type: ignore [attr-defined]
    tf_module.infer = tf.function(  # type: ignore [attr-defined]
        model_fn,
        jit_compile=True,
        autograph=False,
        input_signature=[tf.TensorSpec([batch_dim, *input_shape], tf.float32) for input_shape in input_shapes],
    )

    # warm up the model
    _run_infer(tf_module, input_shapes, batch_dim)

    logger.info("Exporting SavedModel to %s", output_dir)
    tf.saved_model.save(tf_module, output_dir)
112
+
113
+
114
def export_flax(
    model: flax.linen.Module,
    params: PyTree,
    input_shape: tuple[int, ...],
    preprocessor: Callable | None = None,
    postprocessor: Callable | None = None,
    input_name: str = "inputs",
    output_name: str = "outputs",
    output_dir: str | Path = "export",
) -> None:
    """Export a Flax Linen module to a TensorFlow SavedModel using Orbax.

    `model.apply` is wrapped in an Orbax `JaxModule`, non-trainable and with
    polymorphic input shape "(b, ...)". It is exported under a single
    "serving_default" signature taking one float32 tensor of shape
    `(None, *input_shape)`.

    Args:
        model: The Flax module whose `apply` method is exported.
        params: The variables passed to `model.apply`.
        input_shape: Shape of a single input, excluding the batch dimension.
        preprocessor: Optional TF preprocessing function for the signature.
        postprocessor: Optional TF postprocessing function. When omitted, the
            raw output is wrapped as `{output_name: output}`.
        input_name: Name of the input tensor in the serving signature.
        output_name: Output key used when `postprocessor` is omitted.
        output_dir: Directory the SavedModel is written to.
    """
    if postprocessor is None:
        # Give the output an explicit key so that sequence-like outputs are
        # not mapped to ambiguous signature names.
        def postprocessor(x: PyTree) -> PyTree:
            return {output_name: x}

    jax_module = JaxModule(params, model.apply, trainable=False, input_polymorphic_shape="(b, ...)")

    serving_signature = [tf.TensorSpec([None, *input_shape], tf.float32, name=input_name)]
    serving_config = ServingConfig(
        "serving_default",
        input_signature=serving_signature,
        tf_preprocessor=preprocessor,
        tf_postprocessor=postprocessor,
    )
    export_manager = ExportManager(jax_module, [serving_config])

    logger.info("Exporting model to %s", output_dir)
    export_manager.save(output_dir)
xax/nn/geom.py ADDED
@@ -0,0 +1,101 @@
1
+ """Defines geometry functions."""
2
+
3
+ import jax
4
+ from jax import numpy as jnp
5
+
6
+
7
def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.

    Args:
        quat_4: The quaternion to convert, shape (*, 4).
        eps: A small epsilon value to avoid division by zero.

    Returns:
        The roll, pitch, yaw angles with shape (*, 3).
    """
    quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)
    w, x, y, z = jnp.split(quat_4, 4, axis=-1)

    # Roll (x-axis rotation)
    sinr_cosp = 2.0 * (w * x + y * z)
    cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
    roll = jnp.arctan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation)
    sinp = 2.0 * (w * y - z * x)

    # Handle edge cases where |sinp| >= 1: the pitch saturates at +/-90 degrees.
    # `jnp.where` differentiates BOTH branches, so evaluating `arcsin` on an
    # out-of-range value would put NaN into the gradient even though its value
    # is discarded. Feed `arcsin` a safe placeholder there (double-where trick).
    # Forward values are unchanged.
    out_of_range = jnp.abs(sinp) >= 1.0
    safe_sinp = jnp.where(out_of_range, 0.0, sinp)
    pitch = jnp.where(
        out_of_range,
        jnp.sign(sinp) * jnp.pi / 2.0,  # Use 90 degrees if out of range
        jnp.arcsin(safe_sinp),
    )

    # Yaw (z-axis rotation)
    siny_cosp = 2.0 * (w * z + x * y)
    cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
    yaw = jnp.arctan2(siny_cosp, cosy_cosp)

    return jnp.concatenate([roll, pitch, yaw], axis=-1)
41
+
42
+
43
def euler_to_quat(euler_3: jax.Array) -> jax.Array:
    """Build a unit quaternion (w, x, y, z) from roll, pitch and yaw angles.

    Args:
        euler_3: Roll, pitch and yaw in radians, shape (*, 3).

    Returns:
        The normalized quaternion, shape (*, 4).
    """
    half_angles = euler_3 * 0.5
    cos_r, cos_p, cos_y = jnp.split(jnp.cos(half_angles), 3, axis=-1)
    sin_r, sin_p, sin_y = jnp.split(jnp.sin(half_angles), 3, axis=-1)

    # Product of the three half-angle axis rotations (yaw * pitch * roll).
    components = [
        cos_r * cos_p * cos_y + sin_r * sin_p * sin_y,  # w
        sin_r * cos_p * cos_y - cos_r * sin_p * sin_y,  # x
        cos_r * sin_p * cos_y + sin_r * cos_p * sin_y,  # y
        cos_r * cos_p * sin_y - sin_r * sin_p * cos_y,  # z
    ]
    unnormalized = jnp.concatenate(components, axis=-1)

    return unnormalized / jnp.linalg.norm(unnormalized, axis=-1, keepdims=True)
76
+
77
+
78
def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Calculates the gravity vector projected onto the local frame given a quaternion orientation.

    Args:
        quat: A quaternion (w,x,y,z) representing the orientation, shape (*, 4).
        eps: A small epsilon value to avoid division by zero.

    Returns:
        A 3D vector representing the gravity in the local frame, shape (*, 3).
        For the identity orientation this is [0, 0, -1].
    """
    # Normalize quaternion
    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
    w, x, y, z = jnp.split(quat, 4, axis=-1)

    # Gravity in the world frame is g = [0, 0, -1] (pointing down). Its
    # coordinates in the local frame are R(q)^T g, i.e. the inverse rotation
    # q^-1 * g * q, where R(q) is the rotation matrix of q. Since g only has a
    # z component, R(q)^T g = -R(q)[2, :], with
    #   R(q)[2, :] = [2(xz - wy), 2(yz + wx), 1 - 2(x^2 + y^2)]
    # and, for a unit quaternion, 1 - 2(x^2 + y^2) = w^2 - x^2 - y^2 + z^2.
    gx = 2 * (x * z - w * y)
    gy = 2 * (y * z + w * x)
    gz = w * w - x * x - y * y + z * z

    # g points along -z, so ALL three components are negated (negating only
    # gz flips the sign of the tilt components).
    return -jnp.concatenate([gx, gy, gz], axis=-1)