xax 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +28 -10
- xax/nn/attention.py +531 -381
- xax/nn/geom.py +82 -0
- xax/task/mixins/runnable.py +1 -2
- {xax-0.3.2.dist-info → xax-0.3.4.dist-info}/METADATA +1 -1
- {xax-0.3.2.dist-info → xax-0.3.4.dist-info}/RECORD +10 -10
- {xax-0.3.2.dist-info → xax-0.3.4.dist-info}/WHEEL +0 -0
- {xax-0.3.2.dist-info → xax-0.3.4.dist-info}/entry_points.txt +0 -0
- {xax-0.3.2.dist-info → xax-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.2.dist-info → xax-0.3.4.dist-info}/top_level.txt +0 -0
xax/nn/geom.py
CHANGED
@@ -272,3 +272,85 @@ def quat_mul(q2: Array, q1: Array) -> Array:
|
|
272
272
|
z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1
|
273
273
|
|
274
274
|
return jnp.concatenate([w, x, y, z], axis=-1)
|
275
|
+
|
276
|
+
|
277
|
+
def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
    """Converts a rotation matrix to a unit quaternion ``(w, x, y, z)``.

    Four candidate quaternions are computed, one per pivot (the trace, or one
    of the three diagonal entries), and the candidate whose pivot condition
    holds is selected element-wise. All candidates are evaluated for every
    input, so each one must stay finite (in value and in gradient) even when
    it is not the one selected.

    Args:
        rotation_matrix: The rotation matrix, shape ``(*, 3, 3)``.
        eps: A small epsilon value to avoid division by zero, both in the
            per-candidate denominators and in the final normalisation.

    Returns:
        A quaternion with shape ``(*, 4)``, divided by ``(norm + eps)``.
    """
    chex.assert_shape(rotation_matrix, (..., 3, 3))

    def scaled_sqrt(value: Array) -> Array:
        # Computes ``2 * sqrt(max(value, 0))``. The inner ``where`` keeps the
        # argument of ``sqrt`` away from zero so that its derivative is never
        # infinite; otherwise the zero cotangent of an unselected candidate
        # multiplied by that infinity produces NaN gradients (this happens
        # for the identity rotation, for example). ``value <= 0`` is False
        # for NaN, so NaN inputs still propagate as NaN.
        non_positive = value <= 0.0
        safe_value = jnp.where(non_positive, 1.0, value)
        return jnp.where(non_positive, 0.0, jnp.sqrt(safe_value)) * 2.0

    def safe_denominator(scale: Array) -> Array:
        # Replaces near-zero scales with 1 so the division below is finite.
        return jnp.where(scale < eps, 1.0, scale)

    m00 = rotation_matrix[..., 0, 0]
    m01 = rotation_matrix[..., 0, 1]
    m02 = rotation_matrix[..., 0, 2]
    m10 = rotation_matrix[..., 1, 0]
    m11 = rotation_matrix[..., 1, 1]
    m12 = rotation_matrix[..., 1, 2]
    m20 = rotation_matrix[..., 2, 0]
    m21 = rotation_matrix[..., 2, 1]
    m22 = rotation_matrix[..., 2, 2]

    trace = m00 + m11 + m22

    # Case 0: trace is positive
    s0 = scaled_sqrt(trace + 1.0)  # S = 4 * qw
    d0 = safe_denominator(s0)
    w0 = 0.25 * s0
    x0 = (m21 - m12) / d0
    y0 = (m02 - m20) / d0
    z0 = (m10 - m01) / d0

    # Case 1: m00 is the largest diagonal term
    s1 = scaled_sqrt(1.0 + m00 - m11 - m22)  # S = 4 * qx
    d1 = safe_denominator(s1)
    w1 = (m21 - m12) / d1
    x1 = 0.25 * s1
    y1 = (m01 + m10) / d1
    z1 = (m02 + m20) / d1

    # Case 2: m11 is the largest diagonal term
    s2 = scaled_sqrt(1.0 + m11 - m00 - m22)  # S = 4 * qy
    d2 = safe_denominator(s2)
    w2 = (m02 - m20) / d2
    x2 = (m01 + m10) / d2
    y2 = 0.25 * s2
    z2 = (m12 + m21) / d2

    # Case 3: m22 is the largest diagonal term
    s3 = scaled_sqrt(1.0 + m22 - m00 - m11)  # S = 4 * qz
    d3 = safe_denominator(s3)
    w3 = (m10 - m01) / d3
    x3 = (m02 + m20) / d3
    y3 = (m12 + m21) / d3
    z3 = 0.25 * s3

    cond0 = trace > 0.0
    cond1 = (m00 > m11) & (m00 > m22)
    cond2 = m11 > m22

    def select(c0: Array, c1: Array, c2: Array, c3: Array) -> Array:
        # Picks the first candidate whose condition holds, falling back to
        # case 3 when none of the earlier conditions is satisfied.
        return jnp.where(cond0, c0, jnp.where(cond1, c1, jnp.where(cond2, c2, c3)))

    w = select(w0, w1, w2, w3)
    x = select(x0, x1, x2, x3)
    y = select(y0, y1, y2, y3)
    z = select(z0, z1, z2, z3)

    quat = jnp.stack([w, x, y, z], axis=-1)
    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
    return quat
|
xax/task/mixins/runnable.py
CHANGED
@@ -10,6 +10,7 @@ import jax
|
|
10
10
|
|
11
11
|
from xax.task.base import BaseConfig, BaseTask, RawConfigType
|
12
12
|
from xax.task.launchers.base import BaseLauncher
|
13
|
+
from xax.task.launchers.cli import CliLauncher
|
13
14
|
|
14
15
|
|
15
16
|
@jax.tree_util.register_dataclass
|
@@ -45,8 +46,6 @@ class RunnableMixin(BaseTask[Config], ABC):
|
|
45
46
|
use_cli: bool | list[str] = True,
|
46
47
|
) -> None:
|
47
48
|
if launcher is None:
|
48
|
-
from xax.task.launchers.cli import CliLauncher
|
49
|
-
|
50
49
|
launcher = CliLauncher()
|
51
50
|
launcher.launch(cls, *cfgs, use_cli=use_cli)
|
52
51
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=LJFB4xQplzC08tkbkZMxaCd-7jIB7aJZzBMcs9AuqiM,16240
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -8,10 +8,10 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
|
9
9
|
xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
|
10
10
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
xax/nn/attention.py,sha256=
|
11
|
+
xax/nn/attention.py,sha256=aIEtrM7vAQtaXTPKmsqGcYqt03CyiUQMccXj8Cjw3vc,29514
|
12
12
|
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
13
13
|
xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
|
14
|
-
xax/nn/geom.py,sha256=
|
14
|
+
xax/nn/geom.py,sha256=6rBQrZRX1miG08VG-s8phPjA6MEFxUAfQVPt5F0RQQI,10645
|
15
15
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
16
16
|
xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
|
17
17
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
@@ -40,7 +40,7 @@ xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-go
|
|
40
40
|
xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
|
41
41
|
xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
|
42
42
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
43
|
-
xax/task/mixins/runnable.py,sha256=
|
43
|
+
xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
|
44
44
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
45
45
|
xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
|
46
46
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
59
59
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
60
60
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
61
61
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
62
|
-
xax-0.3.
|
63
|
-
xax-0.3.
|
64
|
-
xax-0.3.
|
65
|
-
xax-0.3.
|
66
|
-
xax-0.3.
|
67
|
-
xax-0.3.
|
62
|
+
xax-0.3.4.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
63
|
+
xax-0.3.4.dist-info/METADATA,sha256=j_UQdK4iPYbhzMH0osmHm5XJnYnFY1A_Z5MwSJwXr-4,1246
|
64
|
+
xax-0.3.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
65
|
+
xax-0.3.4.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
|
66
|
+
xax-0.3.4.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
67
|
+
xax-0.3.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|