xax 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.0/xax.egg-info → xax-0.3.2}/PKG-INFO +1 -1
- {xax-0.3.0 → xax-0.3.2}/xax/__init__.py +4 -1
- {xax-0.3.0 → xax-0.3.2}/xax/core/state.py +1 -1
- {xax-0.3.0 → xax-0.3.2}/xax/nn/geom.py +21 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/base.py +3 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/launchers/single_process.py +2 -1
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/artifacts.py +10 -1
- {xax-0.3.0 → xax-0.3.2}/xax/utils/logging.py +3 -1
- {xax-0.3.0 → xax-0.3.2/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.3.0 → xax-0.3.2}/LICENSE +0 -0
- {xax-0.3.0 → xax-0.3.2}/MANIFEST.in +0 -0
- {xax-0.3.0 → xax-0.3.2}/README.md +0 -0
- {xax-0.3.0 → xax-0.3.2}/pyproject.toml +0 -0
- {xax-0.3.0 → xax-0.3.2}/setup.cfg +0 -0
- {xax-0.3.0 → xax-0.3.2}/setup.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/cli/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/cli/edit_config.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/core/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/core/conf.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/nn/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/nn/attention.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/nn/embeddings.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/nn/functions.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/nn/losses.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/nn/metrics.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/nn/parallel.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/nn/ssm.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/py.typed +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/requirements-dev.txt +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/requirements.txt +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/launchers/base.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/logger.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/loggers/json.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/loggers/state.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/process.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/mixins/train.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/script.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/task/task.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/data/collate.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/debugging.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/experiments.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/jax.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/numpy.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/profile.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/pytree.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/text.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.0 → xax-0.3.2}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.3.0"
|
15
|
+
__version__ = "0.3.2"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -44,6 +44,7 @@ __all__ = [
|
|
44
44
|
"normalize",
|
45
45
|
"quat_to_euler",
|
46
46
|
"quat_to_rotmat",
|
47
|
+
"quat_mul",
|
47
48
|
"rotate_vector_by_quat",
|
48
49
|
"rotation6d_to_rotation_matrix",
|
49
50
|
"rotation_matrix_to_rotation6d",
|
@@ -225,6 +226,7 @@ NAME_MAP: dict[str, str] = {
|
|
225
226
|
"normalize": "nn.geom",
|
226
227
|
"quat_to_euler": "nn.geom",
|
227
228
|
"quat_to_rotmat": "nn.geom",
|
229
|
+
"quat_mul": "nn.geom",
|
228
230
|
"rotate_vector_by_quat": "nn.geom",
|
229
231
|
"rotation6d_to_rotation_matrix": "nn.geom",
|
230
232
|
"rotation_matrix_to_rotation6d": "nn.geom",
|
@@ -398,6 +400,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
398
400
|
euler_to_quat,
|
399
401
|
get_projected_gravity_vector_from_quat,
|
400
402
|
normalize,
|
403
|
+
quat_mul,
|
401
404
|
quat_to_euler,
|
402
405
|
quat_to_rotmat,
|
403
406
|
rotate_vector_by_quat,
|
@@ -107,7 +107,7 @@ class State:
|
|
107
107
|
@classmethod
|
108
108
|
def from_dict(cls, **d: Unpack[StateDict]) -> "State":
|
109
109
|
if "phase" in d:
|
110
|
-
d["_phase"] = _phase_to_int(
|
110
|
+
d["_phase"] = _phase_to_int(d.pop("phase"))
|
111
111
|
|
112
112
|
int32_arr = jnp.array(
|
113
113
|
[
|
@@ -251,3 +251,24 @@ def rotation_matrix_to_rotation6d(rotation_matrix: jnp.ndarray) -> jnp.ndarray:
|
|
251
251
|
# Simply concatenate a1 and a2 from SO(3)
|
252
252
|
r6d = jnp.concatenate([rotation_matrix[..., 0], rotation_matrix[..., 1]], axis=-1)
|
253
253
|
return r6d.reshape(shape[:-2] + (6,))
|
254
|
+
|
255
|
+
|
256
|
+
def quat_mul(q2: Array, q1: Array) -> Array:
|
257
|
+
"""Multiply two quaternions (supports batching).
|
258
|
+
|
259
|
+
Args:
|
260
|
+
q2: Second quaternion (w, x, y, z), shape (..., 4)
|
261
|
+
q1: First quaternion (w, x, y, z), shape (..., 4)
|
262
|
+
|
263
|
+
Returns:
|
264
|
+
Product quaternion, shape (..., 4)
|
265
|
+
"""
|
266
|
+
w1, x1, y1, z1 = jnp.split(q1, 4, axis=-1)
|
267
|
+
w2, x2, y2, z2 = jnp.split(q2, 4, axis=-1)
|
268
|
+
|
269
|
+
w = w2 * w1 - x2 * x1 - y2 * y1 - z2 * z1
|
270
|
+
x = w2 * x1 + x2 * w1 + y2 * z1 - z2 * y1
|
271
|
+
y = w2 * y1 - x2 * z1 + y2 * w1 + z2 * x1
|
272
|
+
z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1
|
273
|
+
|
274
|
+
return jnp.concatenate([w, x, y, z], axis=-1)
|
@@ -82,6 +82,9 @@ class BaseTask(Generic[Config]):
|
|
82
82
|
def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
|
83
83
|
return state
|
84
84
|
|
85
|
+
def add_logger_handlers(self, logger: logging.Logger) -> None:
|
86
|
+
pass
|
87
|
+
|
85
88
|
@functools.cached_property
|
86
89
|
def task_class_name(self) -> str:
|
87
90
|
return self.__class__.__name__
|
@@ -15,8 +15,9 @@ def run_single_process_training(
|
|
15
15
|
*cfgs: RawConfigType,
|
16
16
|
use_cli: bool | list[str] = True,
|
17
17
|
) -> None:
|
18
|
-
configure_logging()
|
18
|
+
logger = configure_logging()
|
19
19
|
task_obj = task.get_task(*cfgs, use_cli=use_cli)
|
20
|
+
task_obj.add_logger_handlers(logger)
|
20
21
|
task_obj.run()
|
21
22
|
|
22
23
|
|
@@ -14,7 +14,7 @@ from xax.core.state import State
|
|
14
14
|
from xax.nn.parallel import is_master
|
15
15
|
from xax.task.base import BaseConfig, BaseTask
|
16
16
|
from xax.utils.experiments import stage_environment
|
17
|
-
from xax.utils.logging import LOG_STATUS
|
17
|
+
from xax.utils.logging import LOG_STATUS, RankFilter
|
18
18
|
from xax.utils.text import show_info
|
19
19
|
|
20
20
|
logger = logging.getLogger(__name__)
|
@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
|
|
24
24
|
@dataclass
|
25
25
|
class ArtifactsConfig(BaseConfig):
|
26
26
|
exp_dir: str | None = field(None, help="The fixed experiment directory")
|
27
|
+
log_to_file: bool = field(True, help="If set, add a file handler to the logger to write all logs to the exp dir")
|
27
28
|
|
28
29
|
|
29
30
|
Config = TypeVar("Config", bound=ArtifactsConfig)
|
@@ -39,6 +40,14 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
39
40
|
self._exp_dir = None
|
40
41
|
self._stage_dir = None
|
41
42
|
|
43
|
+
def add_logger_handlers(self, logger: logging.Logger) -> None:
|
44
|
+
super().add_logger_handlers(logger)
|
45
|
+
if is_master() and self.config.log_to_file:
|
46
|
+
file_handler = logging.FileHandler(self.exp_dir / "logs.txt")
|
47
|
+
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
|
48
|
+
file_handler.addFilter(RankFilter(rank=0))
|
49
|
+
logger.addHandler(file_handler)
|
50
|
+
|
42
51
|
@functools.cached_property
|
43
52
|
def run_dir(self) -> Path:
|
44
53
|
run_dir = get_run_dir()
|
@@ -146,7 +146,7 @@ def configure_logging(
|
|
146
146
|
rank: int | None = None,
|
147
147
|
world_size: int | None = None,
|
148
148
|
debug: bool | None = None,
|
149
|
-
) -> None:
|
149
|
+
) -> logging.Logger:
|
150
150
|
"""Instantiates logging.
|
151
151
|
|
152
152
|
This captures logs and reroutes them to the Toasts module, which is
|
@@ -186,6 +186,8 @@ def configure_logging(
|
|
186
186
|
logging.getLogger("PIL").setLevel(logging.WARNING)
|
187
187
|
logging.getLogger("torch").setLevel(logging.WARNING)
|
188
188
|
|
189
|
+
return root_logger
|
190
|
+
|
189
191
|
|
190
192
|
def get_unused_port(default: int | None = None) -> int:
|
191
193
|
"""Returns an unused port number on the local machine.
|
{xax-0.3.0 → xax-0.3.2}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.3.0 → xax-0.3.2}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|