xax 0.3.2__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.2/xax.egg-info → xax-0.3.3}/PKG-INFO +1 -1
- {xax-0.3.2 → xax-0.3.3}/xax/__init__.py +6 -3
- {xax-0.3.2 → xax-0.3.3}/xax/nn/geom.py +82 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/runnable.py +1 -2
- {xax-0.3.2 → xax-0.3.3/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.3.2 → xax-0.3.3}/LICENSE +0 -0
- {xax-0.3.2 → xax-0.3.3}/MANIFEST.in +0 -0
- {xax-0.3.2 → xax-0.3.3}/README.md +0 -0
- {xax-0.3.2 → xax-0.3.3}/pyproject.toml +0 -0
- {xax-0.3.2 → xax-0.3.3}/setup.cfg +0 -0
- {xax-0.3.2 → xax-0.3.3}/setup.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/cli/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/cli/edit_config.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/core/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/core/conf.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/core/state.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/nn/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/nn/attention.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/nn/embeddings.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/nn/functions.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/nn/losses.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/nn/metrics.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/nn/parallel.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/nn/ssm.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/py.typed +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/requirements-dev.txt +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/requirements.txt +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/base.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/launchers/base.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/logger.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/json.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/state.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/process.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/mixins/train.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/script.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/task/task.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/data/collate.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/debugging.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/experiments.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/jax.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/logging.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/numpy.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/profile.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/pytree.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/text.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.2 → xax-0.3.3}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.3.
|
15
|
+
__version__ = "0.3.3"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -42,11 +42,12 @@ __all__ = [
|
|
42
42
|
"euler_to_quat",
|
43
43
|
"get_projected_gravity_vector_from_quat",
|
44
44
|
"normalize",
|
45
|
+
"quat_mul",
|
45
46
|
"quat_to_euler",
|
46
47
|
"quat_to_rotmat",
|
47
|
-
"quat_mul",
|
48
48
|
"rotate_vector_by_quat",
|
49
49
|
"rotation6d_to_rotation_matrix",
|
50
|
+
"rotation_matrix_to_quat",
|
50
51
|
"rotation_matrix_to_rotation6d",
|
51
52
|
"cross_entropy",
|
52
53
|
"cast_norm_type",
|
@@ -224,11 +225,12 @@ NAME_MAP: dict[str, str] = {
|
|
224
225
|
"euler_to_quat": "nn.geom",
|
225
226
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
226
227
|
"normalize": "nn.geom",
|
228
|
+
"quat_mul": "nn.geom",
|
227
229
|
"quat_to_euler": "nn.geom",
|
228
230
|
"quat_to_rotmat": "nn.geom",
|
229
|
-
"quat_mul": "nn.geom",
|
230
231
|
"rotate_vector_by_quat": "nn.geom",
|
231
232
|
"rotation6d_to_rotation_matrix": "nn.geom",
|
233
|
+
"rotation_matrix_to_quat": "nn.geom",
|
232
234
|
"rotation_matrix_to_rotation6d": "nn.geom",
|
233
235
|
"cross_entropy": "nn.losses",
|
234
236
|
"cast_norm_type": "nn.metrics",
|
@@ -405,6 +407,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
405
407
|
quat_to_rotmat,
|
406
408
|
rotate_vector_by_quat,
|
407
409
|
rotation6d_to_rotation_matrix,
|
410
|
+
rotation_matrix_to_quat,
|
408
411
|
rotation_matrix_to_rotation6d,
|
409
412
|
)
|
410
413
|
from xax.nn.losses import cross_entropy
|
@@ -272,3 +272,85 @@ def quat_mul(q2: Array, q1: Array) -> Array:
|
|
272
272
|
z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1
|
273
273
|
|
274
274
|
return jnp.concatenate([w, x, y, z], axis=-1)
|
275
|
+
|
276
|
+
|
277
|
+
def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
    """Converts a rotation matrix to a unit quaternion ``(w, x, y, z)``.

    The quaternion is recovered from one of four candidate formulas: the
    trace-based one when ``trace > 0``, otherwise the one keyed on the largest
    diagonal entry. All four candidates are evaluated and combined with
    ``jnp.where``, so there is no Python control flow on array values.

    The square roots use the "double ``where``" pattern. This stops branches
    that are *not* selected from injecting ``NaN`` into gradients. Such
    branches can have a square-root argument of zero or below; for the
    identity matrix, every diagonal-keyed argument is exactly zero.

    Args:
        rotation_matrix: The rotation matrix, shape ``(*, 3, 3)``.
        eps: A small epsilon value to avoid division by zero when normalising.

    Returns:
        A quaternion with shape ``(*, 4)``. The sign is not canonicalised, so
        ``w`` may be negative when a diagonal-keyed candidate is selected.
    """
    chex.assert_shape(rotation_matrix, (..., 3, 3))

    m00 = rotation_matrix[..., 0, 0]
    m01 = rotation_matrix[..., 0, 1]
    m02 = rotation_matrix[..., 0, 2]
    m10 = rotation_matrix[..., 1, 0]
    m11 = rotation_matrix[..., 1, 1]
    m12 = rotation_matrix[..., 1, 2]
    m20 = rotation_matrix[..., 2, 0]
    m21 = rotation_matrix[..., 2, 1]
    m22 = rotation_matrix[..., 2, 2]

    trace = m00 + m11 + m22

    def _twice_sqrt(value: Array) -> Array:
        # Computes ``2 * sqrt(max(value, 0))`` with a finite gradient everywhere.
        # A plain ``sqrt(clip(value, 0))`` has an infinite derivative at 0. Once
        # ``jnp.where`` masks that branch, the backward pass computes
        # ``0 * inf = NaN``. Feeding ``sqrt`` a dummy 1.0 for non-positive
        # inputs keeps its derivative finite. The outer ``where`` then zeroes
        # both the value and its gradient for those entries.
        positive = value > 0.0
        safe_value = jnp.where(positive, value, 1.0)
        return jnp.where(positive, jnp.sqrt(safe_value), 0.0) * 2.0

    def _safe_denominator(scale: Array) -> Array:
        # Replaces near-zero scale factors with 1.0 to avoid division by zero.
        # The candidate actually selected below always has a square-root
        # argument >= 1 (so scale >= 2). In practice, only discarded candidates
        # reach this guard.
        return jnp.where(scale < eps, 1.0, scale)

    # Candidate 0: trace is positive (s0 = 4 * qw).
    s0 = _twice_sqrt(trace + 1.0)
    d0 = _safe_denominator(s0)
    q0 = jnp.stack([0.25 * s0, (m21 - m12) / d0, (m02 - m20) / d0, (m10 - m01) / d0], axis=-1)

    # Candidate 1: m00 is the largest diagonal term (s1 = 4 * qx).
    s1 = _twice_sqrt(1.0 + m00 - m11 - m22)
    d1 = _safe_denominator(s1)
    q1 = jnp.stack([(m21 - m12) / d1, 0.25 * s1, (m01 + m10) / d1, (m02 + m20) / d1], axis=-1)

    # Candidate 2: m11 is the largest diagonal term (s2 = 4 * qy).
    s2 = _twice_sqrt(1.0 + m11 - m00 - m22)
    d2 = _safe_denominator(s2)
    q2 = jnp.stack([(m02 - m20) / d2, (m01 + m10) / d2, 0.25 * s2, (m12 + m21) / d2], axis=-1)

    # Candidate 3: m22 is the largest diagonal term (s3 = 4 * qz).
    s3 = _twice_sqrt(1.0 + m22 - m00 - m11)
    d3 = _safe_denominator(s3)
    q3 = jnp.stack([(m10 - m01) / d3, (m02 + m20) / d3, (m12 + m21) / d3, 0.25 * s3], axis=-1)

    # Selection order matches the classic algorithm: the trace first, then
    # m00, then m11 vs m22. The trailing axis lets the per-element conditions
    # broadcast over the 4 quaternion components.
    cond0 = (trace > 0.0)[..., None]
    cond1 = ((m00 > m11) & (m00 > m22))[..., None]
    cond2 = (m11 > m22)[..., None]

    quat = jnp.where(cond0, q0, jnp.where(cond1, q1, jnp.where(cond2, q2, q3)))
    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
    return quat
|
@@ -10,6 +10,7 @@ import jax
|
|
10
10
|
|
11
11
|
from xax.task.base import BaseConfig, BaseTask, RawConfigType
|
12
12
|
from xax.task.launchers.base import BaseLauncher
|
13
|
+
from xax.task.launchers.cli import CliLauncher
|
13
14
|
|
14
15
|
|
15
16
|
@jax.tree_util.register_dataclass
|
@@ -45,8 +46,6 @@ class RunnableMixin(BaseTask[Config], ABC):
|
|
45
46
|
use_cli: bool | list[str] = True,
|
46
47
|
) -> None:
|
47
48
|
if launcher is None:
|
48
|
-
from xax.task.launchers.cli import CliLauncher
|
49
|
-
|
50
49
|
launcher = CliLauncher()
|
51
50
|
launcher.launch(cls, *cfgs, use_cli=use_cli)
|
52
51
|
|
{xax-0.3.2 → xax-0.3.3}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.3.2 → xax-0.3.3}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|