xax 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/nn/geom.py +1 -14
- xax/task/mixins/checkpointing.py +9 -2
- {xax-0.2.16.dist-info → xax-0.2.18.dist-info}/METADATA +1 -1
- {xax-0.2.16.dist-info → xax-0.2.18.dist-info}/RECORD +8 -8
- {xax-0.2.16.dist-info → xax-0.2.18.dist-info}/WHEEL +1 -1
- {xax-0.2.16.dist-info → xax-0.2.18.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.16.dist-info → xax-0.2.18.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/nn/geom.py
CHANGED
@@ -86,20 +86,7 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
|
|
86
86
|
Returns:
|
87
87
|
A 3D vector representing the gravity in the local frame, shape (*, 3).
|
88
88
|
"""
|
89
|
-
|
90
|
-
quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
|
91
|
-
w, x, y, z = jnp.split(quat, 4, axis=-1)
|
92
|
-
|
93
|
-
# Gravity vector in world frame is [0, 0, -1] (pointing down)
|
94
|
-
# Rotate gravity vector using quaternion rotation
|
95
|
-
|
96
|
-
# Calculate quaternion rotation: q * [0,0,-1] * q^-1
|
97
|
-
gx = 2 * (x * z - w * y)
|
98
|
-
gy = 2 * (y * z + w * x)
|
99
|
-
gz = w * w - x * x - y * y + z * z
|
100
|
-
|
101
|
-
# Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
|
102
|
-
return jnp.concatenate([gx, gy, -gz], axis=-1)
|
89
|
+
return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
|
103
90
|
|
104
91
|
|
105
92
|
def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -6,7 +6,7 @@ import logging
|
|
6
6
|
import tarfile
|
7
7
|
from dataclasses import dataclass
|
8
8
|
from pathlib import Path
|
9
|
-
from typing import Generic, Literal, Sequence, TypeVar, cast, overload
|
9
|
+
from typing import Generic, Literal, Self, Sequence, TypeVar, cast, overload
|
10
10
|
|
11
11
|
import equinox as eqx
|
12
12
|
import jax
|
@@ -46,7 +46,6 @@ class CheckpointingConfig(ArtifactsConfig):
|
|
46
46
|
save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
|
47
47
|
only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
|
48
48
|
load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
|
49
|
-
load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
|
50
49
|
|
51
50
|
|
52
51
|
Config = TypeVar("Config", bound=CheckpointingConfig)
|
@@ -306,3 +305,11 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
306
305
|
self.on_after_checkpoint_save(ckpt_path, state)
|
307
306
|
|
308
307
|
return ckpt_path
|
308
|
+
|
309
|
+
@classmethod
|
310
|
+
def load_config(cls, ckpt_path: str | Path) -> Config:
|
311
|
+
return cls.get_config(load_ckpt(Path(ckpt_path), part="config"), use_cli=False)
|
312
|
+
|
313
|
+
@classmethod
|
314
|
+
def load_task(cls, ckpt_path: str | Path) -> Self:
|
315
|
+
return cls(cls.load_config(ckpt_path))
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=ru-amUXLIpFI1cE7NfkPpff6NnuLbSh6LHbJkmvH2zM,15733
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -10,7 +10,7 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
|
10
10
|
xax/nn/equinox.py,sha256=JZuSApD4bL0UK5W1nrQtucWYvNWUha07J6LTLk_RX-Y,4910
|
11
11
|
xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
|
12
12
|
xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
|
13
|
-
xax/nn/geom.py,sha256=
|
13
|
+
xax/nn/geom.py,sha256=A7WPefMvgwUNReZC7_HX1GmvHPASyghbaXaKsuhwDrE,7382
|
14
14
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
15
15
|
xax/nn/metrics.py,sha256=OAkeScwhi-wTBIJ59KHUhYbZTq4V4V-LG-mKlxMJ7bY,3238
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
@@ -32,7 +32,7 @@ xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,66
|
|
32
32
|
xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=Ma7fwsp-SA1w6GcuBSskszj5TB83yxYJm4Ns_EnqkI4,3018
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
35
|
+
xax/task/mixins/checkpointing.py,sha256=w-xsx8tH8PPYRf9eoqvz-4kAdunHMvazcoZAIj6niyI,11468
|
36
36
|
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
37
|
xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
|
38
38
|
xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.18.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.18.dist-info/METADATA,sha256=phU79eE74vMM79HouJP9iQgTDONJ59no3ktUl5GHF78,1880
|
63
|
+
xax-0.2.18.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
|
64
|
+
xax-0.2.18.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.18.dist-info/RECORD,,
|
File without changes
|
File without changes
|