xax 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the differences between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
11
11
  python -m scripts.update_api --inplace
12
12
  """
13
13
 
14
- __version__ = "0.0.5"
14
+ __version__ = "0.0.6"
15
15
 
16
16
  # This list shouldn't be modified by hand; instead, run the update script.
17
17
  __all__ = [
@@ -34,6 +34,7 @@ __all__ = [
34
34
  "get_positional_embeddings",
35
35
  "get_rotary_embeddings",
36
36
  "rotary_embeddings",
37
+ "is_master",
37
38
  "BaseLauncher",
38
39
  "CliLauncher",
39
40
  "SingleProcessLauncher",
@@ -56,27 +57,45 @@ __all__ = [
56
57
  "collate",
57
58
  "collate_non_null",
58
59
  "BaseFileDownloader",
60
+ "CumulativeTimer",
59
61
  "DataDownloader",
62
+ "IntervalTicker",
63
+ "IterationTimer",
64
+ "MinGradScaleError",
60
65
  "ModelDownloader",
66
+ "NaNError",
67
+ "StateTimer",
68
+ "TrainingFinishedError",
61
69
  "check_md5",
62
70
  "check_sha256",
71
+ "cpu_count",
72
+ "date_str",
73
+ "diff_configs",
63
74
  "get_git_state",
75
+ "get_random_port",
64
76
  "get_state_dict_prefix",
65
77
  "get_training_code",
66
78
  "save_config",
79
+ "stage_environment",
80
+ "to_markdown_table",
67
81
  "ColoredFormatter",
68
82
  "configure_logging",
69
83
  "one_hot",
70
84
  "partial_flatten",
71
85
  "worker_chunk",
72
86
  "TextBlock",
87
+ "camelcase_to_snakecase",
73
88
  "colored",
74
89
  "format_datetime",
75
90
  "format_timedelta",
91
+ "highlight_exception_message",
92
+ "is_interactive_session",
76
93
  "outlined",
77
94
  "render_text_blocks",
78
95
  "show_error",
96
+ "show_info",
79
97
  "show_warning",
98
+ "snakecase_to_camelcase",
80
99
  "uncolored",
81
100
  "wrapped",
82
101
  ]
@@ -85,8 +104,12 @@ __all__ += [
85
104
  "Batch",
86
105
  "CollateMode",
87
106
  "EmbeddingKind",
107
+ "LOG_ERROR_SUMMARY",
108
+ "LOG_PING",
109
+ "LOG_STATUS",
88
110
  "Output",
89
111
  "Phase",
112
+ "RawConfigType",
90
113
  ]
91
114
 
92
115
  import os
@@ -119,6 +142,7 @@ NAME_MAP: dict[str, str] = {
119
142
  "get_positional_embeddings": "nn.embeddings",
120
143
  "get_rotary_embeddings": "nn.embeddings",
121
144
  "rotary_embeddings": "nn.embeddings",
145
+ "is_master": "nn.parallel",
122
146
  "BaseLauncher": "task.launchers.base",
123
147
  "CliLauncher": "task.launchers.cli",
124
148
  "SingleProcessLauncher": "task.launchers.single_process",
@@ -141,27 +165,45 @@ NAME_MAP: dict[str, str] = {
141
165
  "collate": "utils.data.collate",
142
166
  "collate_non_null": "utils.data.collate",
143
167
  "BaseFileDownloader": "utils.experiments",
168
+ "CumulativeTimer": "utils.experiments",
144
169
  "DataDownloader": "utils.experiments",
170
+ "IntervalTicker": "utils.experiments",
171
+ "IterationTimer": "utils.experiments",
172
+ "MinGradScaleError": "utils.experiments",
145
173
  "ModelDownloader": "utils.experiments",
174
+ "NaNError": "utils.experiments",
175
+ "StateTimer": "utils.experiments",
176
+ "TrainingFinishedError": "utils.experiments",
146
177
  "check_md5": "utils.experiments",
147
178
  "check_sha256": "utils.experiments",
179
+ "cpu_count": "utils.experiments",
180
+ "date_str": "utils.experiments",
181
+ "diff_configs": "utils.experiments",
148
182
  "get_git_state": "utils.experiments",
183
+ "get_random_port": "utils.experiments",
149
184
  "get_state_dict_prefix": "utils.experiments",
150
185
  "get_training_code": "utils.experiments",
151
186
  "save_config": "utils.experiments",
187
+ "stage_environment": "utils.experiments",
188
+ "to_markdown_table": "utils.experiments",
152
189
  "ColoredFormatter": "utils.logging",
153
190
  "configure_logging": "utils.logging",
154
191
  "one_hot": "utils.numpy",
155
192
  "partial_flatten": "utils.numpy",
156
193
  "worker_chunk": "utils.numpy",
157
194
  "TextBlock": "utils.text",
195
+ "camelcase_to_snakecase": "utils.text",
158
196
  "colored": "utils.text",
159
197
  "format_datetime": "utils.text",
160
198
  "format_timedelta": "utils.text",
199
+ "highlight_exception_message": "utils.text",
200
+ "is_interactive_session": "utils.text",
161
201
  "outlined": "utils.text",
162
202
  "render_text_blocks": "utils.text",
163
203
  "show_error": "utils.text",
204
+ "show_info": "utils.text",
164
205
  "show_warning": "utils.text",
206
+ "snakecase_to_camelcase": "utils.text",
165
207
  "uncolored": "utils.text",
166
208
  "wrapped": "utils.text",
167
209
  }
@@ -172,8 +214,12 @@ NAME_MAP.update(
172
214
  "Batch": "task.mixins.train",
173
215
  "CollateMode": "utils.data.collate",
174
216
  "EmbeddingKind": "nn.embeddings",
217
+ "LOG_ERROR_SUMMARY": "utils.logging",
218
+ "LOG_PING": "utils.logging",
219
+ "LOG_STATUS": "utils.logging",
175
220
  "Output": "task.mixins.output",
176
221
  "Phase": "core.state",
222
+ "RawConfigType": "task.base",
177
223
  },
178
224
  )
179
225
 
@@ -211,6 +257,8 @@ if IMPORT_ALL or TYPE_CHECKING:
211
257
  get_rotary_embeddings,
212
258
  rotary_embeddings,
213
259
  )
260
+ from xax.nn.parallel import is_master
261
+ from xax.task.base import RawConfigType
214
262
  from xax.task.launchers.base import BaseLauncher
215
263
  from xax.task.launchers.cli import CliLauncher
216
264
  from xax.task.launchers.single_process import SingleProcessLauncher
@@ -229,26 +277,50 @@ if IMPORT_ALL or TYPE_CHECKING:
229
277
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
230
278
  from xax.utils.experiments import (
231
279
  BaseFileDownloader,
280
+ CumulativeTimer,
232
281
  DataDownloader,
282
+ IntervalTicker,
283
+ IterationTimer,
284
+ MinGradScaleError,
233
285
  ModelDownloader,
286
+ NaNError,
287
+ StateTimer,
288
+ TrainingFinishedError,
234
289
  check_md5,
235
290
  check_sha256,
291
+ cpu_count,
292
+ date_str,
293
+ diff_configs,
236
294
  get_git_state,
295
+ get_random_port,
237
296
  get_state_dict_prefix,
238
297
  get_training_code,
239
298
  save_config,
299
+ stage_environment,
300
+ to_markdown_table,
301
+ )
302
+ from xax.utils.logging import (
303
+ LOG_ERROR_SUMMARY,
304
+ LOG_PING,
305
+ LOG_STATUS,
306
+ ColoredFormatter,
307
+ configure_logging,
240
308
  )
241
- from xax.utils.logging import ColoredFormatter, configure_logging
242
309
  from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
243
310
  from xax.utils.text import (
244
311
  TextBlock,
312
+ camelcase_to_snakecase,
245
313
  colored,
246
314
  format_datetime,
247
315
  format_timedelta,
316
+ highlight_exception_message,
317
+ is_interactive_session,
248
318
  outlined,
249
319
  render_text_blocks,
250
320
  show_error,
321
+ show_info,
251
322
  show_warning,
323
+ snakecase_to_camelcase,
252
324
  uncolored,
253
325
  wrapped,
254
326
  )
xax/core/conf.py CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_base
6
6
  from pathlib import Path
7
7
  from typing import Any, cast
8
8
 
9
- import jax.numpy as jnp
10
9
  from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
11
10
 
12
11
  from xax.utils.text import show_error
@@ -61,68 +60,44 @@ def is_missing(cfg: Any, key: str) -> bool: # noqa: ANN401
61
60
  return False
62
61
 
63
62
 
64
- @dataclass
63
+ @dataclass(kw_only=True)
65
64
  class Logging:
66
65
  hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
67
- log_level: str = field("INFO", help="The logging level to use")
66
+ log_level: str = field(II("oc.env:XAX_LOG_LEVEL,INFO"), help="The logging level to use")
68
67
 
69
68
 
70
- @dataclass
71
- class Device:
72
- cpu: bool = field(True, help="Whether to use the CPU")
73
- gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
74
- metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
75
- use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
76
- use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
77
- use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
78
- use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
79
-
80
-
81
- def parse_dtype(cfg: Device) -> jnp.dtype | None:
82
- if cfg.use_fp64:
83
- return jnp.float64
84
- if cfg.use_fp32:
85
- return jnp.float32
86
- if cfg.use_bf16:
87
- return jnp.bfloat16
88
- if cfg.use_fp16:
89
- return jnp.float16
90
- return None
91
-
92
-
93
- @dataclass
69
+ @dataclass(kw_only=True)
94
70
  class Triton:
95
71
  use_triton_if_available: bool = field(True, help="Use Triton if available")
96
72
 
97
73
 
98
- @dataclass
74
+ @dataclass(kw_only=True)
99
75
  class Experiment:
100
76
  default_random_seed: int = field(1337, help="The default random seed to use")
101
77
  max_workers: int = field(32, help="Maximum number of workers to use")
102
78
 
103
79
 
104
- @dataclass
80
+ @dataclass(kw_only=True)
105
81
  class Directories:
106
82
  run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
107
83
  data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
108
84
  pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
109
85
 
110
86
 
111
- @dataclass
87
+ @dataclass(kw_only=True)
112
88
  class SlurmPartition:
113
89
  partition: str = field(MISSING, help="The partition name")
114
90
  num_nodes: int = field(1, help="The number of nodes to use")
115
91
 
116
92
 
117
- @dataclass
93
+ @dataclass(kw_only=True)
118
94
  class Slurm:
119
95
  launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
120
96
 
121
97
 
122
- @dataclass
98
+ @dataclass(kw_only=True)
123
99
  class UserConfig:
124
100
  logging: Logging = field(Logging)
125
- device: Device = field(Device)
126
101
  triton: Triton = field(Triton)
127
102
  experiment: Experiment = field(Experiment)
128
103
  directories: Directories = field(Directories)
xax/core/state.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  import time
4
4
  from dataclasses import dataclass
5
- from typing import Literal, TypedDict, cast, get_args
5
+ from typing import Literal, NotRequired, TypedDict, cast, get_args
6
6
 
7
7
  from omegaconf import MISSING
8
8
 
@@ -18,16 +18,16 @@ def cast_phase(raw_phase: str) -> Phase:
18
18
 
19
19
 
20
20
  class StateDict(TypedDict, total=False):
21
- num_steps: int
22
- num_samples: int
23
- num_valid_steps: int
24
- num_valid_samples: int
25
- start_time_s: float
26
- elapsed_time_s: float
27
- raw_phase: str
21
+ num_steps: NotRequired[int]
22
+ num_samples: NotRequired[int]
23
+ num_valid_steps: NotRequired[int]
24
+ num_valid_samples: NotRequired[int]
25
+ start_time_s: NotRequired[float]
26
+ elapsed_time_s: NotRequired[float]
27
+ raw_phase: NotRequired[str]
28
28
 
29
29
 
30
- @dataclass(frozen=True)
30
+ @dataclass
31
31
  class State:
32
32
  num_steps: int = field(MISSING, help="Number of steps so far")
33
33
  num_samples: int = field(MISSING, help="Number of sample so far")
@@ -41,6 +41,10 @@ class State:
41
41
  def phase(self) -> Phase:
42
42
  return cast_phase(self.raw_phase)
43
43
 
44
+ @phase.setter
45
+ def phase(self, phase: Phase) -> None:
46
+ self.raw_phase = phase
47
+
44
48
  @classmethod
45
49
  def init_state(cls) -> "State":
46
50
  return cls(
@@ -65,17 +69,3 @@ class State:
65
69
  return self.num_valid_steps
66
70
  case _:
67
71
  raise ValueError(f"Invalid phase: {phase}")
68
-
69
- def replace(self, values: StateDict) -> "State":
70
- return State(
71
- num_steps=values.get("num_steps", self.num_steps),
72
- num_samples=values.get("num_samples", self.num_samples),
73
- num_valid_steps=values.get("num_valid_steps", self.num_valid_steps),
74
- num_valid_samples=values.get("num_valid_samples", self.num_valid_samples),
75
- start_time_s=values.get("start_time_s", self.start_time_s),
76
- elapsed_time_s=values.get("elapsed_time_s", self.elapsed_time_s),
77
- raw_phase=values.get("raw_phase", self.raw_phase),
78
- )
79
-
80
- def with_phase(self, phase: Phase) -> "State":
81
- return self.replace({"raw_phase": phase})
xax/requirements.txt CHANGED
@@ -6,6 +6,8 @@ jaxtyping
6
6
  equinox
7
7
  optax
8
8
  dpshdl
9
+ chex
10
+ importlib-resources
9
11
 
10
12
  # Data processing and serialization
11
13
  cloudpickle
xax/task/base.py CHANGED
@@ -15,6 +15,7 @@ from pathlib import Path
15
15
  from types import TracebackType
16
16
  from typing import Generic, Self, TypeVar, cast
17
17
 
18
+ import jax
18
19
  from omegaconf import Container, DictConfig, OmegaConf
19
20
 
20
21
  from xax.core.state import State
@@ -23,6 +24,7 @@ from xax.utils.text import camelcase_to_snakecase
23
24
  logger = logging.getLogger(__name__)
24
25
 
25
26
 
27
+ @jax.tree_util.register_dataclass
26
28
  @dataclass
27
29
  class BaseConfig:
28
30
  pass