xax 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
11
11
  python -m scripts.update_api --inplace
12
12
  """
13
13
 
14
- __version__ = "0.0.5"
14
+ __version__ = "0.0.7"
15
15
 
16
16
  # This list shouldn't be modified by hand; instead, run the update script.
17
17
  __all__ = [
@@ -34,6 +34,9 @@ __all__ = [
34
34
  "get_positional_embeddings",
35
35
  "get_rotary_embeddings",
36
36
  "rotary_embeddings",
37
+ "euler_to_quat",
38
+ "quat_to_euler",
39
+ "is_master",
37
40
  "BaseLauncher",
38
41
  "CliLauncher",
39
42
  "SingleProcessLauncher",
@@ -56,27 +59,52 @@ __all__ = [
56
59
  "collate",
57
60
  "collate_non_null",
58
61
  "BaseFileDownloader",
62
+ "CumulativeTimer",
59
63
  "DataDownloader",
64
+ "IntervalTicker",
65
+ "IterationTimer",
66
+ "MinGradScaleError",
60
67
  "ModelDownloader",
68
+ "NaNError",
69
+ "StateTimer",
70
+ "TrainingFinishedError",
61
71
  "check_md5",
62
72
  "check_sha256",
73
+ "cpu_count",
74
+ "date_str",
75
+ "diff_configs",
63
76
  "get_git_state",
77
+ "get_random_port",
64
78
  "get_state_dict_prefix",
65
79
  "get_training_code",
66
80
  "save_config",
81
+ "stage_environment",
82
+ "to_markdown_table",
83
+ "jit",
67
84
  "ColoredFormatter",
68
85
  "configure_logging",
69
86
  "one_hot",
70
87
  "partial_flatten",
71
88
  "worker_chunk",
89
+ "profile",
90
+ "compute_nan_ratio",
91
+ "flatten_array",
92
+ "flatten_pytree",
93
+ "slice_array",
94
+ "slice_pytree",
72
95
  "TextBlock",
96
+ "camelcase_to_snakecase",
73
97
  "colored",
74
98
  "format_datetime",
75
99
  "format_timedelta",
100
+ "highlight_exception_message",
101
+ "is_interactive_session",
76
102
  "outlined",
77
103
  "render_text_blocks",
78
104
  "show_error",
105
+ "show_info",
79
106
  "show_warning",
107
+ "snakecase_to_camelcase",
80
108
  "uncolored",
81
109
  "wrapped",
82
110
  ]
@@ -85,8 +113,12 @@ __all__ += [
85
113
  "Batch",
86
114
  "CollateMode",
87
115
  "EmbeddingKind",
116
+ "LOG_ERROR_SUMMARY",
117
+ "LOG_PING",
118
+ "LOG_STATUS",
88
119
  "Output",
89
120
  "Phase",
121
+ "RawConfigType",
90
122
  ]
91
123
 
92
124
  import os
@@ -119,6 +151,9 @@ NAME_MAP: dict[str, str] = {
119
151
  "get_positional_embeddings": "nn.embeddings",
120
152
  "get_rotary_embeddings": "nn.embeddings",
121
153
  "rotary_embeddings": "nn.embeddings",
154
+ "euler_to_quat": "nn.geom",
155
+ "quat_to_euler": "nn.geom",
156
+ "is_master": "nn.parallel",
122
157
  "BaseLauncher": "task.launchers.base",
123
158
  "CliLauncher": "task.launchers.cli",
124
159
  "SingleProcessLauncher": "task.launchers.single_process",
@@ -141,27 +176,52 @@ NAME_MAP: dict[str, str] = {
141
176
  "collate": "utils.data.collate",
142
177
  "collate_non_null": "utils.data.collate",
143
178
  "BaseFileDownloader": "utils.experiments",
179
+ "CumulativeTimer": "utils.experiments",
144
180
  "DataDownloader": "utils.experiments",
181
+ "IntervalTicker": "utils.experiments",
182
+ "IterationTimer": "utils.experiments",
183
+ "MinGradScaleError": "utils.experiments",
145
184
  "ModelDownloader": "utils.experiments",
185
+ "NaNError": "utils.experiments",
186
+ "StateTimer": "utils.experiments",
187
+ "TrainingFinishedError": "utils.experiments",
146
188
  "check_md5": "utils.experiments",
147
189
  "check_sha256": "utils.experiments",
190
+ "cpu_count": "utils.experiments",
191
+ "date_str": "utils.experiments",
192
+ "diff_configs": "utils.experiments",
148
193
  "get_git_state": "utils.experiments",
194
+ "get_random_port": "utils.experiments",
149
195
  "get_state_dict_prefix": "utils.experiments",
150
196
  "get_training_code": "utils.experiments",
151
197
  "save_config": "utils.experiments",
198
+ "stage_environment": "utils.experiments",
199
+ "to_markdown_table": "utils.experiments",
200
+ "jit": "utils.jax",
152
201
  "ColoredFormatter": "utils.logging",
153
202
  "configure_logging": "utils.logging",
154
203
  "one_hot": "utils.numpy",
155
204
  "partial_flatten": "utils.numpy",
156
205
  "worker_chunk": "utils.numpy",
206
+ "profile": "utils.profile",
207
+ "compute_nan_ratio": "utils.pytree",
208
+ "flatten_array": "utils.pytree",
209
+ "flatten_pytree": "utils.pytree",
210
+ "slice_array": "utils.pytree",
211
+ "slice_pytree": "utils.pytree",
157
212
  "TextBlock": "utils.text",
213
+ "camelcase_to_snakecase": "utils.text",
158
214
  "colored": "utils.text",
159
215
  "format_datetime": "utils.text",
160
216
  "format_timedelta": "utils.text",
217
+ "highlight_exception_message": "utils.text",
218
+ "is_interactive_session": "utils.text",
161
219
  "outlined": "utils.text",
162
220
  "render_text_blocks": "utils.text",
163
221
  "show_error": "utils.text",
222
+ "show_info": "utils.text",
164
223
  "show_warning": "utils.text",
224
+ "snakecase_to_camelcase": "utils.text",
165
225
  "uncolored": "utils.text",
166
226
  "wrapped": "utils.text",
167
227
  }
@@ -172,8 +232,12 @@ NAME_MAP.update(
172
232
  "Batch": "task.mixins.train",
173
233
  "CollateMode": "utils.data.collate",
174
234
  "EmbeddingKind": "nn.embeddings",
235
+ "LOG_ERROR_SUMMARY": "utils.logging",
236
+ "LOG_PING": "utils.logging",
237
+ "LOG_STATUS": "utils.logging",
175
238
  "Output": "task.mixins.output",
176
239
  "Phase": "core.state",
240
+ "RawConfigType": "task.base",
177
241
  },
178
242
  )
179
243
 
@@ -211,6 +275,9 @@ if IMPORT_ALL or TYPE_CHECKING:
211
275
  get_rotary_embeddings,
212
276
  rotary_embeddings,
213
277
  )
278
+ from xax.nn.geom import euler_to_quat, quat_to_euler
279
+ from xax.nn.parallel import is_master
280
+ from xax.task.base import RawConfigType
214
281
  from xax.task.launchers.base import BaseLauncher
215
282
  from xax.task.launchers.cli import CliLauncher
216
283
  from xax.task.launchers.single_process import SingleProcessLauncher
@@ -229,26 +296,59 @@ if IMPORT_ALL or TYPE_CHECKING:
229
296
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
230
297
  from xax.utils.experiments import (
231
298
  BaseFileDownloader,
299
+ CumulativeTimer,
232
300
  DataDownloader,
301
+ IntervalTicker,
302
+ IterationTimer,
303
+ MinGradScaleError,
233
304
  ModelDownloader,
305
+ NaNError,
306
+ StateTimer,
307
+ TrainingFinishedError,
234
308
  check_md5,
235
309
  check_sha256,
310
+ cpu_count,
311
+ date_str,
312
+ diff_configs,
236
313
  get_git_state,
314
+ get_random_port,
237
315
  get_state_dict_prefix,
238
316
  get_training_code,
239
317
  save_config,
318
+ stage_environment,
319
+ to_markdown_table,
320
+ )
321
+ from xax.utils.jax import jit
322
+ from xax.utils.logging import (
323
+ LOG_ERROR_SUMMARY,
324
+ LOG_PING,
325
+ LOG_STATUS,
326
+ ColoredFormatter,
327
+ configure_logging,
240
328
  )
241
- from xax.utils.logging import ColoredFormatter, configure_logging
242
329
  from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
330
+ from xax.utils.profile import profile
331
+ from xax.utils.pytree import (
332
+ compute_nan_ratio,
333
+ flatten_array,
334
+ flatten_pytree,
335
+ slice_array,
336
+ slice_pytree,
337
+ )
243
338
  from xax.utils.text import (
244
339
  TextBlock,
340
+ camelcase_to_snakecase,
245
341
  colored,
246
342
  format_datetime,
247
343
  format_timedelta,
344
+ highlight_exception_message,
345
+ is_interactive_session,
248
346
  outlined,
249
347
  render_text_blocks,
250
348
  show_error,
349
+ show_info,
251
350
  show_warning,
351
+ snakecase_to_camelcase,
252
352
  uncolored,
253
353
  wrapped,
254
354
  )
xax/core/conf.py CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_base
6
6
  from pathlib import Path
7
7
  from typing import Any, cast
8
8
 
9
- import jax.numpy as jnp
10
9
  from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
11
10
 
12
11
  from xax.utils.text import show_error
@@ -61,68 +60,44 @@ def is_missing(cfg: Any, key: str) -> bool: # noqa: ANN401
61
60
  return False
62
61
 
63
62
 
64
- @dataclass
63
+ @dataclass(kw_only=True)
65
64
  class Logging:
66
65
  hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
67
- log_level: str = field("INFO", help="The logging level to use")
66
+ log_level: str = field(II("oc.env:XAX_LOG_LEVEL,INFO"), help="The logging level to use")
68
67
 
69
68
 
70
- @dataclass
71
- class Device:
72
- cpu: bool = field(True, help="Whether to use the CPU")
73
- gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
74
- metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
75
- use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
76
- use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
77
- use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
78
- use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
79
-
80
-
81
- def parse_dtype(cfg: Device) -> jnp.dtype | None:
82
- if cfg.use_fp64:
83
- return jnp.float64
84
- if cfg.use_fp32:
85
- return jnp.float32
86
- if cfg.use_bf16:
87
- return jnp.bfloat16
88
- if cfg.use_fp16:
89
- return jnp.float16
90
- return None
91
-
92
-
93
- @dataclass
69
+ @dataclass(kw_only=True)
94
70
  class Triton:
95
71
  use_triton_if_available: bool = field(True, help="Use Triton if available")
96
72
 
97
73
 
98
- @dataclass
74
+ @dataclass(kw_only=True)
99
75
  class Experiment:
100
76
  default_random_seed: int = field(1337, help="The default random seed to use")
101
77
  max_workers: int = field(32, help="Maximum number of workers to use")
102
78
 
103
79
 
104
- @dataclass
80
+ @dataclass(kw_only=True)
105
81
  class Directories:
106
82
  run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
107
83
  data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
108
84
  pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
109
85
 
110
86
 
111
- @dataclass
87
+ @dataclass(kw_only=True)
112
88
  class SlurmPartition:
113
89
  partition: str = field(MISSING, help="The partition name")
114
90
  num_nodes: int = field(1, help="The number of nodes to use")
115
91
 
116
92
 
117
- @dataclass
93
+ @dataclass(kw_only=True)
118
94
  class Slurm:
119
95
  launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
120
96
 
121
97
 
122
- @dataclass
98
+ @dataclass(kw_only=True)
123
99
  class UserConfig:
124
100
  logging: Logging = field(Logging)
125
- device: Device = field(Device)
126
101
  triton: Triton = field(Triton)
127
102
  experiment: Experiment = field(Experiment)
128
103
  directories: Directories = field(Directories)
xax/core/state.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  import time
4
4
  from dataclasses import dataclass
5
- from typing import Literal, TypedDict, cast, get_args
5
+ from typing import Literal, NotRequired, TypedDict, cast, get_args
6
6
 
7
7
  from omegaconf import MISSING
8
8
 
@@ -18,16 +18,16 @@ def cast_phase(raw_phase: str) -> Phase:
18
18
 
19
19
 
20
20
  class StateDict(TypedDict, total=False):
21
- num_steps: int
22
- num_samples: int
23
- num_valid_steps: int
24
- num_valid_samples: int
25
- start_time_s: float
26
- elapsed_time_s: float
27
- raw_phase: str
21
+ num_steps: NotRequired[int]
22
+ num_samples: NotRequired[int]
23
+ num_valid_steps: NotRequired[int]
24
+ num_valid_samples: NotRequired[int]
25
+ start_time_s: NotRequired[float]
26
+ elapsed_time_s: NotRequired[float]
27
+ raw_phase: NotRequired[str]
28
28
 
29
29
 
30
- @dataclass(frozen=True)
30
+ @dataclass
31
31
  class State:
32
32
  num_steps: int = field(MISSING, help="Number of steps so far")
33
33
  num_samples: int = field(MISSING, help="Number of sample so far")
@@ -41,6 +41,10 @@ class State:
41
41
  def phase(self) -> Phase:
42
42
  return cast_phase(self.raw_phase)
43
43
 
44
+ @phase.setter
45
+ def phase(self, phase: Phase) -> None:
46
+ self.raw_phase = phase
47
+
44
48
  @classmethod
45
49
  def init_state(cls) -> "State":
46
50
  return cls(
@@ -65,17 +69,3 @@ class State:
65
69
  return self.num_valid_steps
66
70
  case _:
67
71
  raise ValueError(f"Invalid phase: {phase}")
68
-
69
- def replace(self, values: StateDict) -> "State":
70
- return State(
71
- num_steps=values.get("num_steps", self.num_steps),
72
- num_samples=values.get("num_samples", self.num_samples),
73
- num_valid_steps=values.get("num_valid_steps", self.num_valid_steps),
74
- num_valid_samples=values.get("num_valid_samples", self.num_valid_samples),
75
- start_time_s=values.get("start_time_s", self.start_time_s),
76
- elapsed_time_s=values.get("elapsed_time_s", self.elapsed_time_s),
77
- raw_phase=values.get("raw_phase", self.raw_phase),
78
- )
79
-
80
- def with_phase(self, phase: Phase) -> "State":
81
- return self.replace({"raw_phase": phase})
xax/nn/geom.py ADDED
@@ -0,0 +1,75 @@
1
+ """Defines geometry functions."""
2
+
3
+ import jax
4
+ from jax import numpy as jnp
5
+
6
+
7
+ def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
8
+ """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.
9
+
10
+ Args:
11
+ quat_4: The quaternion to convert, shape (*, 4).
12
+ eps: A small epsilon value to avoid division by zero.
13
+
14
+ Returns:
15
+ The roll, pitch, yaw angles with shape (*, 3).
16
+ """
17
+ quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)
18
+ w, x, y, z = jnp.split(quat_4, 4, axis=-1)
19
+
20
+ # Roll (x-axis rotation)
21
+ sinr_cosp = 2.0 * (w * x + y * z)
22
+ cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
23
+ roll = jnp.arctan2(sinr_cosp, cosr_cosp)
24
+
25
+ # Pitch (y-axis rotation)
26
+ sinp = 2.0 * (w * y - z * x)
27
+
28
+ # Handle edge cases where |sinp| >= 1
29
+ pitch = jnp.where(
30
+ jnp.abs(sinp) >= 1.0,
31
+ jnp.sign(sinp) * jnp.pi / 2.0, # Use 90 degrees if out of range
32
+ jnp.arcsin(sinp),
33
+ )
34
+
35
+ # Yaw (z-axis rotation)
36
+ siny_cosp = 2.0 * (w * z + x * y)
37
+ cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
38
+ yaw = jnp.arctan2(siny_cosp, cosy_cosp)
39
+
40
+ return jnp.concatenate([roll, pitch, yaw], axis=-1)
41
+
42
+
43
+ def euler_to_quat(euler_3: jax.Array) -> jax.Array:
44
+ """Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).
45
+
46
+ Args:
47
+ euler_3: The roll, pitch, yaw angles, shape (*, 3).
48
+
49
+ Returns:
50
+ The quaternion with shape (*, 4).
51
+ """
52
+ # Extract roll, pitch, yaw from input
53
+ roll, pitch, yaw = jnp.split(euler_3, 3, axis=-1)
54
+
55
+ # Calculate trigonometric functions for each angle
56
+ cr = jnp.cos(roll * 0.5)
57
+ sr = jnp.sin(roll * 0.5)
58
+ cp = jnp.cos(pitch * 0.5)
59
+ sp = jnp.sin(pitch * 0.5)
60
+ cy = jnp.cos(yaw * 0.5)
61
+ sy = jnp.sin(yaw * 0.5)
62
+
63
+ # Calculate quaternion components using the conversion formula
64
+ w = cr * cp * cy + sr * sp * sy
65
+ x = sr * cp * cy - cr * sp * sy
66
+ y = cr * sp * cy + sr * cp * sy
67
+ z = cr * cp * sy - sr * sp * cy
68
+
69
+ # Combine into quaternion [w, x, y, z]
70
+ quat = jnp.concatenate([w, x, y, z], axis=-1)
71
+
72
+ # Normalize the quaternion
73
+ quat = quat / jnp.linalg.norm(quat, axis=-1, keepdims=True)
74
+
75
+ return quat
xax/requirements.txt CHANGED
@@ -6,6 +6,8 @@ jaxtyping
6
6
  equinox
7
7
  optax
8
8
  dpshdl
9
+ chex
10
+ importlib-resources
9
11
 
10
12
  # Data processing and serialization
11
13
  cloudpickle
xax/task/base.py CHANGED
@@ -15,6 +15,7 @@ from pathlib import Path
15
15
  from types import TracebackType
16
16
  from typing import Generic, Self, TypeVar, cast
17
17
 
18
+ import jax
18
19
  from omegaconf import Container, DictConfig, OmegaConf
19
20
 
20
21
  from xax.core.state import State
@@ -23,6 +24,7 @@ from xax.utils.text import camelcase_to_snakecase
23
24
  logger = logging.getLogger(__name__)
24
25
 
25
26
 
27
+ @jax.tree_util.register_dataclass
26
28
  @dataclass
27
29
  class BaseConfig:
28
30
  pass