xax 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +102 -2
- xax/core/conf.py +8 -33
- xax/core/state.py +13 -23
- xax/nn/geom.py +75 -0
- xax/requirements.txt +2 -0
- xax/task/base.py +2 -0
- xax/task/logger.py +194 -122
- xax/task/loggers/callback.py +4 -16
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +14 -28
- xax/task/mixins/__init__.py +1 -0
- xax/task/mixins/artifacts.py +7 -4
- xax/task/mixins/checkpointing.py +12 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +16 -5
- xax/task/mixins/data_loader.py +23 -12
- xax/task/mixins/gpu_stats.py +19 -5
- xax/task/mixins/logger.py +4 -2
- xax/task/mixins/process.py +4 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +189 -129
- xax/task/script.py +1 -1
- xax/task/task.py +7 -0
- xax/utils/jax.py +126 -0
- xax/utils/profile.py +61 -0
- xax/utils/pytree.py +50 -0
- xax/utils/tensorboard.py +48 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/METADATA +12 -2
- xax-0.0.7.dist-info/RECORD +55 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.5.dist-info/RECORD +0 -52
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/LICENSE +0 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
|
|
11
11
|
python -m scripts.update_api --inplace
|
12
12
|
"""
|
13
13
|
|
14
|
-
__version__ = "0.0.5"
|
14
|
+
__version__ = "0.0.7"
|
15
15
|
|
16
16
|
# This list shouldn't be modified by hand; instead, run the update script.
|
17
17
|
__all__ = [
|
@@ -34,6 +34,9 @@ __all__ = [
|
|
34
34
|
"get_positional_embeddings",
|
35
35
|
"get_rotary_embeddings",
|
36
36
|
"rotary_embeddings",
|
37
|
+
"euler_to_quat",
|
38
|
+
"quat_to_euler",
|
39
|
+
"is_master",
|
37
40
|
"BaseLauncher",
|
38
41
|
"CliLauncher",
|
39
42
|
"SingleProcessLauncher",
|
@@ -56,27 +59,52 @@ __all__ = [
|
|
56
59
|
"collate",
|
57
60
|
"collate_non_null",
|
58
61
|
"BaseFileDownloader",
|
62
|
+
"CumulativeTimer",
|
59
63
|
"DataDownloader",
|
64
|
+
"IntervalTicker",
|
65
|
+
"IterationTimer",
|
66
|
+
"MinGradScaleError",
|
60
67
|
"ModelDownloader",
|
68
|
+
"NaNError",
|
69
|
+
"StateTimer",
|
70
|
+
"TrainingFinishedError",
|
61
71
|
"check_md5",
|
62
72
|
"check_sha256",
|
73
|
+
"cpu_count",
|
74
|
+
"date_str",
|
75
|
+
"diff_configs",
|
63
76
|
"get_git_state",
|
77
|
+
"get_random_port",
|
64
78
|
"get_state_dict_prefix",
|
65
79
|
"get_training_code",
|
66
80
|
"save_config",
|
81
|
+
"stage_environment",
|
82
|
+
"to_markdown_table",
|
83
|
+
"jit",
|
67
84
|
"ColoredFormatter",
|
68
85
|
"configure_logging",
|
69
86
|
"one_hot",
|
70
87
|
"partial_flatten",
|
71
88
|
"worker_chunk",
|
89
|
+
"profile",
|
90
|
+
"compute_nan_ratio",
|
91
|
+
"flatten_array",
|
92
|
+
"flatten_pytree",
|
93
|
+
"slice_array",
|
94
|
+
"slice_pytree",
|
72
95
|
"TextBlock",
|
96
|
+
"camelcase_to_snakecase",
|
73
97
|
"colored",
|
74
98
|
"format_datetime",
|
75
99
|
"format_timedelta",
|
100
|
+
"highlight_exception_message",
|
101
|
+
"is_interactive_session",
|
76
102
|
"outlined",
|
77
103
|
"render_text_blocks",
|
78
104
|
"show_error",
|
105
|
+
"show_info",
|
79
106
|
"show_warning",
|
107
|
+
"snakecase_to_camelcase",
|
80
108
|
"uncolored",
|
81
109
|
"wrapped",
|
82
110
|
]
|
@@ -85,8 +113,12 @@ __all__ += [
|
|
85
113
|
"Batch",
|
86
114
|
"CollateMode",
|
87
115
|
"EmbeddingKind",
|
116
|
+
"LOG_ERROR_SUMMARY",
|
117
|
+
"LOG_PING",
|
118
|
+
"LOG_STATUS",
|
88
119
|
"Output",
|
89
120
|
"Phase",
|
121
|
+
"RawConfigType",
|
90
122
|
]
|
91
123
|
|
92
124
|
import os
|
@@ -119,6 +151,9 @@ NAME_MAP: dict[str, str] = {
|
|
119
151
|
"get_positional_embeddings": "nn.embeddings",
|
120
152
|
"get_rotary_embeddings": "nn.embeddings",
|
121
153
|
"rotary_embeddings": "nn.embeddings",
|
154
|
+
"euler_to_quat": "nn.geom",
|
155
|
+
"quat_to_euler": "nn.geom",
|
156
|
+
"is_master": "nn.parallel",
|
122
157
|
"BaseLauncher": "task.launchers.base",
|
123
158
|
"CliLauncher": "task.launchers.cli",
|
124
159
|
"SingleProcessLauncher": "task.launchers.single_process",
|
@@ -141,27 +176,52 @@ NAME_MAP: dict[str, str] = {
|
|
141
176
|
"collate": "utils.data.collate",
|
142
177
|
"collate_non_null": "utils.data.collate",
|
143
178
|
"BaseFileDownloader": "utils.experiments",
|
179
|
+
"CumulativeTimer": "utils.experiments",
|
144
180
|
"DataDownloader": "utils.experiments",
|
181
|
+
"IntervalTicker": "utils.experiments",
|
182
|
+
"IterationTimer": "utils.experiments",
|
183
|
+
"MinGradScaleError": "utils.experiments",
|
145
184
|
"ModelDownloader": "utils.experiments",
|
185
|
+
"NaNError": "utils.experiments",
|
186
|
+
"StateTimer": "utils.experiments",
|
187
|
+
"TrainingFinishedError": "utils.experiments",
|
146
188
|
"check_md5": "utils.experiments",
|
147
189
|
"check_sha256": "utils.experiments",
|
190
|
+
"cpu_count": "utils.experiments",
|
191
|
+
"date_str": "utils.experiments",
|
192
|
+
"diff_configs": "utils.experiments",
|
148
193
|
"get_git_state": "utils.experiments",
|
194
|
+
"get_random_port": "utils.experiments",
|
149
195
|
"get_state_dict_prefix": "utils.experiments",
|
150
196
|
"get_training_code": "utils.experiments",
|
151
197
|
"save_config": "utils.experiments",
|
198
|
+
"stage_environment": "utils.experiments",
|
199
|
+
"to_markdown_table": "utils.experiments",
|
200
|
+
"jit": "utils.jax",
|
152
201
|
"ColoredFormatter": "utils.logging",
|
153
202
|
"configure_logging": "utils.logging",
|
154
203
|
"one_hot": "utils.numpy",
|
155
204
|
"partial_flatten": "utils.numpy",
|
156
205
|
"worker_chunk": "utils.numpy",
|
206
|
+
"profile": "utils.profile",
|
207
|
+
"compute_nan_ratio": "utils.pytree",
|
208
|
+
"flatten_array": "utils.pytree",
|
209
|
+
"flatten_pytree": "utils.pytree",
|
210
|
+
"slice_array": "utils.pytree",
|
211
|
+
"slice_pytree": "utils.pytree",
|
157
212
|
"TextBlock": "utils.text",
|
213
|
+
"camelcase_to_snakecase": "utils.text",
|
158
214
|
"colored": "utils.text",
|
159
215
|
"format_datetime": "utils.text",
|
160
216
|
"format_timedelta": "utils.text",
|
217
|
+
"highlight_exception_message": "utils.text",
|
218
|
+
"is_interactive_session": "utils.text",
|
161
219
|
"outlined": "utils.text",
|
162
220
|
"render_text_blocks": "utils.text",
|
163
221
|
"show_error": "utils.text",
|
222
|
+
"show_info": "utils.text",
|
164
223
|
"show_warning": "utils.text",
|
224
|
+
"snakecase_to_camelcase": "utils.text",
|
165
225
|
"uncolored": "utils.text",
|
166
226
|
"wrapped": "utils.text",
|
167
227
|
}
|
@@ -172,8 +232,12 @@ NAME_MAP.update(
|
|
172
232
|
"Batch": "task.mixins.train",
|
173
233
|
"CollateMode": "utils.data.collate",
|
174
234
|
"EmbeddingKind": "nn.embeddings",
|
235
|
+
"LOG_ERROR_SUMMARY": "utils.logging",
|
236
|
+
"LOG_PING": "utils.logging",
|
237
|
+
"LOG_STATUS": "utils.logging",
|
175
238
|
"Output": "task.mixins.output",
|
176
239
|
"Phase": "core.state",
|
240
|
+
"RawConfigType": "task.base",
|
177
241
|
},
|
178
242
|
)
|
179
243
|
|
@@ -211,6 +275,9 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
211
275
|
get_rotary_embeddings,
|
212
276
|
rotary_embeddings,
|
213
277
|
)
|
278
|
+
from xax.nn.geom import euler_to_quat, quat_to_euler
|
279
|
+
from xax.nn.parallel import is_master
|
280
|
+
from xax.task.base import RawConfigType
|
214
281
|
from xax.task.launchers.base import BaseLauncher
|
215
282
|
from xax.task.launchers.cli import CliLauncher
|
216
283
|
from xax.task.launchers.single_process import SingleProcessLauncher
|
@@ -229,26 +296,59 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
229
296
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
230
297
|
from xax.utils.experiments import (
|
231
298
|
BaseFileDownloader,
|
299
|
+
CumulativeTimer,
|
232
300
|
DataDownloader,
|
301
|
+
IntervalTicker,
|
302
|
+
IterationTimer,
|
303
|
+
MinGradScaleError,
|
233
304
|
ModelDownloader,
|
305
|
+
NaNError,
|
306
|
+
StateTimer,
|
307
|
+
TrainingFinishedError,
|
234
308
|
check_md5,
|
235
309
|
check_sha256,
|
310
|
+
cpu_count,
|
311
|
+
date_str,
|
312
|
+
diff_configs,
|
236
313
|
get_git_state,
|
314
|
+
get_random_port,
|
237
315
|
get_state_dict_prefix,
|
238
316
|
get_training_code,
|
239
317
|
save_config,
|
318
|
+
stage_environment,
|
319
|
+
to_markdown_table,
|
320
|
+
)
|
321
|
+
from xax.utils.jax import jit
|
322
|
+
from xax.utils.logging import (
|
323
|
+
LOG_ERROR_SUMMARY,
|
324
|
+
LOG_PING,
|
325
|
+
LOG_STATUS,
|
326
|
+
ColoredFormatter,
|
327
|
+
configure_logging,
|
240
328
|
)
|
241
|
-
from xax.utils.logging import ColoredFormatter, configure_logging
|
242
329
|
from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
|
330
|
+
from xax.utils.profile import profile
|
331
|
+
from xax.utils.pytree import (
|
332
|
+
compute_nan_ratio,
|
333
|
+
flatten_array,
|
334
|
+
flatten_pytree,
|
335
|
+
slice_array,
|
336
|
+
slice_pytree,
|
337
|
+
)
|
243
338
|
from xax.utils.text import (
|
244
339
|
TextBlock,
|
340
|
+
camelcase_to_snakecase,
|
245
341
|
colored,
|
246
342
|
format_datetime,
|
247
343
|
format_timedelta,
|
344
|
+
highlight_exception_message,
|
345
|
+
is_interactive_session,
|
248
346
|
outlined,
|
249
347
|
render_text_blocks,
|
250
348
|
show_error,
|
349
|
+
show_info,
|
251
350
|
show_warning,
|
351
|
+
snakecase_to_camelcase,
|
252
352
|
uncolored,
|
253
353
|
wrapped,
|
254
354
|
)
|
xax/core/conf.py
CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_base
|
|
6
6
|
from pathlib import Path
|
7
7
|
from typing import Any, cast
|
8
8
|
|
9
|
-
import jax.numpy as jnp
|
10
9
|
from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
|
11
10
|
|
12
11
|
from xax.utils.text import show_error
|
@@ -61,68 +60,44 @@ def is_missing(cfg: Any, key: str) -> bool: # noqa: ANN401
|
|
61
60
|
return False
|
62
61
|
|
63
62
|
|
64
|
-
@dataclass
|
63
|
+
@dataclass(kw_only=True)
|
65
64
|
class Logging:
|
66
65
|
hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
|
67
|
-
log_level: str = field("INFO", help="The logging level to use")
|
66
|
+
log_level: str = field(II("oc.env:XAX_LOG_LEVEL,INFO"), help="The logging level to use")
|
68
67
|
|
69
68
|
|
70
|
-
@dataclass
|
71
|
-
class Device:
|
72
|
-
cpu: bool = field(True, help="Whether to use the CPU")
|
73
|
-
gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
|
74
|
-
metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
|
75
|
-
use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
|
76
|
-
use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
|
77
|
-
use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
|
78
|
-
use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
|
79
|
-
|
80
|
-
|
81
|
-
def parse_dtype(cfg: Device) -> jnp.dtype | None:
|
82
|
-
if cfg.use_fp64:
|
83
|
-
return jnp.float64
|
84
|
-
if cfg.use_fp32:
|
85
|
-
return jnp.float32
|
86
|
-
if cfg.use_bf16:
|
87
|
-
return jnp.bfloat16
|
88
|
-
if cfg.use_fp16:
|
89
|
-
return jnp.float16
|
90
|
-
return None
|
91
|
-
|
92
|
-
|
93
|
-
@dataclass
|
69
|
+
@dataclass(kw_only=True)
|
94
70
|
class Triton:
|
95
71
|
use_triton_if_available: bool = field(True, help="Use Triton if available")
|
96
72
|
|
97
73
|
|
98
|
-
@dataclass
|
74
|
+
@dataclass(kw_only=True)
|
99
75
|
class Experiment:
|
100
76
|
default_random_seed: int = field(1337, help="The default random seed to use")
|
101
77
|
max_workers: int = field(32, help="Maximum number of workers to use")
|
102
78
|
|
103
79
|
|
104
|
-
@dataclass
|
80
|
+
@dataclass(kw_only=True)
|
105
81
|
class Directories:
|
106
82
|
run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
|
107
83
|
data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
|
108
84
|
pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
|
109
85
|
|
110
86
|
|
111
|
-
@dataclass
|
87
|
+
@dataclass(kw_only=True)
|
112
88
|
class SlurmPartition:
|
113
89
|
partition: str = field(MISSING, help="The partition name")
|
114
90
|
num_nodes: int = field(1, help="The number of nodes to use")
|
115
91
|
|
116
92
|
|
117
|
-
@dataclass
|
93
|
+
@dataclass(kw_only=True)
|
118
94
|
class Slurm:
|
119
95
|
launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
|
120
96
|
|
121
97
|
|
122
|
-
@dataclass
|
98
|
+
@dataclass(kw_only=True)
|
123
99
|
class UserConfig:
|
124
100
|
logging: Logging = field(Logging)
|
125
|
-
device: Device = field(Device)
|
126
101
|
triton: Triton = field(Triton)
|
127
102
|
experiment: Experiment = field(Experiment)
|
128
103
|
directories: Directories = field(Directories)
|
xax/core/state.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import time
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from typing import Literal, TypedDict, cast, get_args
|
5
|
+
from typing import Literal, NotRequired, TypedDict, cast, get_args
|
6
6
|
|
7
7
|
from omegaconf import MISSING
|
8
8
|
|
@@ -18,16 +18,16 @@ def cast_phase(raw_phase: str) -> Phase:
|
|
18
18
|
|
19
19
|
|
20
20
|
class StateDict(TypedDict, total=False):
|
21
|
-
num_steps: int
|
22
|
-
num_samples: int
|
23
|
-
num_valid_steps: int
|
24
|
-
num_valid_samples: int
|
25
|
-
start_time_s: float
|
26
|
-
elapsed_time_s: float
|
27
|
-
raw_phase: str
|
21
|
+
num_steps: NotRequired[int]
|
22
|
+
num_samples: NotRequired[int]
|
23
|
+
num_valid_steps: NotRequired[int]
|
24
|
+
num_valid_samples: NotRequired[int]
|
25
|
+
start_time_s: NotRequired[float]
|
26
|
+
elapsed_time_s: NotRequired[float]
|
27
|
+
raw_phase: NotRequired[str]
|
28
28
|
|
29
29
|
|
30
|
-
@dataclass
|
30
|
+
@dataclass
|
31
31
|
class State:
|
32
32
|
num_steps: int = field(MISSING, help="Number of steps so far")
|
33
33
|
num_samples: int = field(MISSING, help="Number of sample so far")
|
@@ -41,6 +41,10 @@ class State:
|
|
41
41
|
def phase(self) -> Phase:
|
42
42
|
return cast_phase(self.raw_phase)
|
43
43
|
|
44
|
+
@phase.setter
|
45
|
+
def phase(self, phase: Phase) -> None:
|
46
|
+
self.raw_phase = phase
|
47
|
+
|
44
48
|
@classmethod
|
45
49
|
def init_state(cls) -> "State":
|
46
50
|
return cls(
|
@@ -65,17 +69,3 @@ class State:
|
|
65
69
|
return self.num_valid_steps
|
66
70
|
case _:
|
67
71
|
raise ValueError(f"Invalid phase: {phase}")
|
68
|
-
|
69
|
-
def replace(self, values: StateDict) -> "State":
|
70
|
-
return State(
|
71
|
-
num_steps=values.get("num_steps", self.num_steps),
|
72
|
-
num_samples=values.get("num_samples", self.num_samples),
|
73
|
-
num_valid_steps=values.get("num_valid_steps", self.num_valid_steps),
|
74
|
-
num_valid_samples=values.get("num_valid_samples", self.num_valid_samples),
|
75
|
-
start_time_s=values.get("start_time_s", self.start_time_s),
|
76
|
-
elapsed_time_s=values.get("elapsed_time_s", self.elapsed_time_s),
|
77
|
-
raw_phase=values.get("raw_phase", self.raw_phase),
|
78
|
-
)
|
79
|
-
|
80
|
-
def with_phase(self, phase: Phase) -> "State":
|
81
|
-
return self.replace({"raw_phase": phase})
|
xax/nn/geom.py
ADDED
@@ -0,0 +1,75 @@
|
|
1
|
+
"""Defines geometry functions."""
|
2
|
+
|
3
|
+
import jax
|
4
|
+
from jax import numpy as jnp
|
5
|
+
|
6
|
+
|
7
|
+
def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
|
8
|
+
"""Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.
|
9
|
+
|
10
|
+
Args:
|
11
|
+
quat_4: The quaternion to convert, shape (*, 4).
|
12
|
+
eps: A small epsilon value to avoid division by zero.
|
13
|
+
|
14
|
+
Returns:
|
15
|
+
The roll, pitch, yaw angles with shape (*, 3).
|
16
|
+
"""
|
17
|
+
quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)
|
18
|
+
w, x, y, z = jnp.split(quat_4, 4, axis=-1)
|
19
|
+
|
20
|
+
# Roll (x-axis rotation)
|
21
|
+
sinr_cosp = 2.0 * (w * x + y * z)
|
22
|
+
cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
|
23
|
+
roll = jnp.arctan2(sinr_cosp, cosr_cosp)
|
24
|
+
|
25
|
+
# Pitch (y-axis rotation)
|
26
|
+
sinp = 2.0 * (w * y - z * x)
|
27
|
+
|
28
|
+
# Handle edge cases where |sinp| >= 1
|
29
|
+
pitch = jnp.where(
|
30
|
+
jnp.abs(sinp) >= 1.0,
|
31
|
+
jnp.sign(sinp) * jnp.pi / 2.0, # Use 90 degrees if out of range
|
32
|
+
jnp.arcsin(sinp),
|
33
|
+
)
|
34
|
+
|
35
|
+
# Yaw (z-axis rotation)
|
36
|
+
siny_cosp = 2.0 * (w * z + x * y)
|
37
|
+
cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
|
38
|
+
yaw = jnp.arctan2(siny_cosp, cosy_cosp)
|
39
|
+
|
40
|
+
return jnp.concatenate([roll, pitch, yaw], axis=-1)
|
41
|
+
|
42
|
+
|
43
|
+
def euler_to_quat(euler_3: jax.Array) -> jax.Array:
|
44
|
+
"""Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).
|
45
|
+
|
46
|
+
Args:
|
47
|
+
euler_3: The roll, pitch, yaw angles, shape (*, 3).
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
The quaternion with shape (*, 4).
|
51
|
+
"""
|
52
|
+
# Extract roll, pitch, yaw from input
|
53
|
+
roll, pitch, yaw = jnp.split(euler_3, 3, axis=-1)
|
54
|
+
|
55
|
+
# Calculate trigonometric functions for each angle
|
56
|
+
cr = jnp.cos(roll * 0.5)
|
57
|
+
sr = jnp.sin(roll * 0.5)
|
58
|
+
cp = jnp.cos(pitch * 0.5)
|
59
|
+
sp = jnp.sin(pitch * 0.5)
|
60
|
+
cy = jnp.cos(yaw * 0.5)
|
61
|
+
sy = jnp.sin(yaw * 0.5)
|
62
|
+
|
63
|
+
# Calculate quaternion components using the conversion formula
|
64
|
+
w = cr * cp * cy + sr * sp * sy
|
65
|
+
x = sr * cp * cy - cr * sp * sy
|
66
|
+
y = cr * sp * cy + sr * cp * sy
|
67
|
+
z = cr * cp * sy - sr * sp * cy
|
68
|
+
|
69
|
+
# Combine into quaternion [w, x, y, z]
|
70
|
+
quat = jnp.concatenate([w, x, y, z], axis=-1)
|
71
|
+
|
72
|
+
# Normalize the quaternion
|
73
|
+
quat = quat / jnp.linalg.norm(quat, axis=-1, keepdims=True)
|
74
|
+
|
75
|
+
return quat
|
xax/requirements.txt
CHANGED
xax/task/base.py
CHANGED
@@ -15,6 +15,7 @@ from pathlib import Path
|
|
15
15
|
from types import TracebackType
|
16
16
|
from typing import Generic, Self, TypeVar, cast
|
17
17
|
|
18
|
+
import jax
|
18
19
|
from omegaconf import Container, DictConfig, OmegaConf
|
19
20
|
|
20
21
|
from xax.core.state import State
|
@@ -23,6 +24,7 @@ from xax.utils.text import camelcase_to_snakecase
|
|
23
24
|
logger = logging.getLogger(__name__)
|
24
25
|
|
25
26
|
|
27
|
+
@jax.tree_util.register_dataclass
|
26
28
|
@dataclass
|
27
29
|
class BaseConfig:
|
28
30
|
pass
|