xax 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +74 -2
- xax/core/conf.py +8 -33
- xax/core/state.py +13 -23
- xax/requirements.txt +2 -0
- xax/task/base.py +2 -0
- xax/task/logger.py +194 -122
- xax/task/loggers/callback.py +4 -16
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +14 -28
- xax/task/mixins/__init__.py +1 -0
- xax/task/mixins/artifacts.py +7 -4
- xax/task/mixins/checkpointing.py +12 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +16 -5
- xax/task/mixins/data_loader.py +23 -12
- xax/task/mixins/gpu_stats.py +19 -5
- xax/task/mixins/logger.py +4 -2
- xax/task/mixins/process.py +4 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +189 -129
- xax/task/script.py +1 -1
- xax/task/task.py +7 -0
- xax/utils/tensorboard.py +48 -0
- {xax-0.0.5.dist-info → xax-0.0.6.dist-info}/METADATA +12 -2
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.5.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.5.dist-info/RECORD +0 -52
- {xax-0.0.5.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.5.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
|
|
11
11
|
python -m scripts.update_api --inplace
|
12
12
|
"""
|
13
13
|
|
14
|
-
__version__ = "0.0.5"
|
14
|
+
__version__ = "0.0.6"
|
15
15
|
|
16
16
|
# This list shouldn't be modified by hand; instead, run the update script.
|
17
17
|
__all__ = [
|
@@ -34,6 +34,7 @@ __all__ = [
|
|
34
34
|
"get_positional_embeddings",
|
35
35
|
"get_rotary_embeddings",
|
36
36
|
"rotary_embeddings",
|
37
|
+
"is_master",
|
37
38
|
"BaseLauncher",
|
38
39
|
"CliLauncher",
|
39
40
|
"SingleProcessLauncher",
|
@@ -56,27 +57,45 @@ __all__ = [
|
|
56
57
|
"collate",
|
57
58
|
"collate_non_null",
|
58
59
|
"BaseFileDownloader",
|
60
|
+
"CumulativeTimer",
|
59
61
|
"DataDownloader",
|
62
|
+
"IntervalTicker",
|
63
|
+
"IterationTimer",
|
64
|
+
"MinGradScaleError",
|
60
65
|
"ModelDownloader",
|
66
|
+
"NaNError",
|
67
|
+
"StateTimer",
|
68
|
+
"TrainingFinishedError",
|
61
69
|
"check_md5",
|
62
70
|
"check_sha256",
|
71
|
+
"cpu_count",
|
72
|
+
"date_str",
|
73
|
+
"diff_configs",
|
63
74
|
"get_git_state",
|
75
|
+
"get_random_port",
|
64
76
|
"get_state_dict_prefix",
|
65
77
|
"get_training_code",
|
66
78
|
"save_config",
|
79
|
+
"stage_environment",
|
80
|
+
"to_markdown_table",
|
67
81
|
"ColoredFormatter",
|
68
82
|
"configure_logging",
|
69
83
|
"one_hot",
|
70
84
|
"partial_flatten",
|
71
85
|
"worker_chunk",
|
72
86
|
"TextBlock",
|
87
|
+
"camelcase_to_snakecase",
|
73
88
|
"colored",
|
74
89
|
"format_datetime",
|
75
90
|
"format_timedelta",
|
91
|
+
"highlight_exception_message",
|
92
|
+
"is_interactive_session",
|
76
93
|
"outlined",
|
77
94
|
"render_text_blocks",
|
78
95
|
"show_error",
|
96
|
+
"show_info",
|
79
97
|
"show_warning",
|
98
|
+
"snakecase_to_camelcase",
|
80
99
|
"uncolored",
|
81
100
|
"wrapped",
|
82
101
|
]
|
@@ -85,8 +104,12 @@ __all__ += [
|
|
85
104
|
"Batch",
|
86
105
|
"CollateMode",
|
87
106
|
"EmbeddingKind",
|
107
|
+
"LOG_ERROR_SUMMARY",
|
108
|
+
"LOG_PING",
|
109
|
+
"LOG_STATUS",
|
88
110
|
"Output",
|
89
111
|
"Phase",
|
112
|
+
"RawConfigType",
|
90
113
|
]
|
91
114
|
|
92
115
|
import os
|
@@ -119,6 +142,7 @@ NAME_MAP: dict[str, str] = {
|
|
119
142
|
"get_positional_embeddings": "nn.embeddings",
|
120
143
|
"get_rotary_embeddings": "nn.embeddings",
|
121
144
|
"rotary_embeddings": "nn.embeddings",
|
145
|
+
"is_master": "nn.parallel",
|
122
146
|
"BaseLauncher": "task.launchers.base",
|
123
147
|
"CliLauncher": "task.launchers.cli",
|
124
148
|
"SingleProcessLauncher": "task.launchers.single_process",
|
@@ -141,27 +165,45 @@ NAME_MAP: dict[str, str] = {
|
|
141
165
|
"collate": "utils.data.collate",
|
142
166
|
"collate_non_null": "utils.data.collate",
|
143
167
|
"BaseFileDownloader": "utils.experiments",
|
168
|
+
"CumulativeTimer": "utils.experiments",
|
144
169
|
"DataDownloader": "utils.experiments",
|
170
|
+
"IntervalTicker": "utils.experiments",
|
171
|
+
"IterationTimer": "utils.experiments",
|
172
|
+
"MinGradScaleError": "utils.experiments",
|
145
173
|
"ModelDownloader": "utils.experiments",
|
174
|
+
"NaNError": "utils.experiments",
|
175
|
+
"StateTimer": "utils.experiments",
|
176
|
+
"TrainingFinishedError": "utils.experiments",
|
146
177
|
"check_md5": "utils.experiments",
|
147
178
|
"check_sha256": "utils.experiments",
|
179
|
+
"cpu_count": "utils.experiments",
|
180
|
+
"date_str": "utils.experiments",
|
181
|
+
"diff_configs": "utils.experiments",
|
148
182
|
"get_git_state": "utils.experiments",
|
183
|
+
"get_random_port": "utils.experiments",
|
149
184
|
"get_state_dict_prefix": "utils.experiments",
|
150
185
|
"get_training_code": "utils.experiments",
|
151
186
|
"save_config": "utils.experiments",
|
187
|
+
"stage_environment": "utils.experiments",
|
188
|
+
"to_markdown_table": "utils.experiments",
|
152
189
|
"ColoredFormatter": "utils.logging",
|
153
190
|
"configure_logging": "utils.logging",
|
154
191
|
"one_hot": "utils.numpy",
|
155
192
|
"partial_flatten": "utils.numpy",
|
156
193
|
"worker_chunk": "utils.numpy",
|
157
194
|
"TextBlock": "utils.text",
|
195
|
+
"camelcase_to_snakecase": "utils.text",
|
158
196
|
"colored": "utils.text",
|
159
197
|
"format_datetime": "utils.text",
|
160
198
|
"format_timedelta": "utils.text",
|
199
|
+
"highlight_exception_message": "utils.text",
|
200
|
+
"is_interactive_session": "utils.text",
|
161
201
|
"outlined": "utils.text",
|
162
202
|
"render_text_blocks": "utils.text",
|
163
203
|
"show_error": "utils.text",
|
204
|
+
"show_info": "utils.text",
|
164
205
|
"show_warning": "utils.text",
|
206
|
+
"snakecase_to_camelcase": "utils.text",
|
165
207
|
"uncolored": "utils.text",
|
166
208
|
"wrapped": "utils.text",
|
167
209
|
}
|
@@ -172,8 +214,12 @@ NAME_MAP.update(
|
|
172
214
|
"Batch": "task.mixins.train",
|
173
215
|
"CollateMode": "utils.data.collate",
|
174
216
|
"EmbeddingKind": "nn.embeddings",
|
217
|
+
"LOG_ERROR_SUMMARY": "utils.logging",
|
218
|
+
"LOG_PING": "utils.logging",
|
219
|
+
"LOG_STATUS": "utils.logging",
|
175
220
|
"Output": "task.mixins.output",
|
176
221
|
"Phase": "core.state",
|
222
|
+
"RawConfigType": "task.base",
|
177
223
|
},
|
178
224
|
)
|
179
225
|
|
@@ -211,6 +257,8 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
211
257
|
get_rotary_embeddings,
|
212
258
|
rotary_embeddings,
|
213
259
|
)
|
260
|
+
from xax.nn.parallel import is_master
|
261
|
+
from xax.task.base import RawConfigType
|
214
262
|
from xax.task.launchers.base import BaseLauncher
|
215
263
|
from xax.task.launchers.cli import CliLauncher
|
216
264
|
from xax.task.launchers.single_process import SingleProcessLauncher
|
@@ -229,26 +277,50 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
229
277
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
230
278
|
from xax.utils.experiments import (
|
231
279
|
BaseFileDownloader,
|
280
|
+
CumulativeTimer,
|
232
281
|
DataDownloader,
|
282
|
+
IntervalTicker,
|
283
|
+
IterationTimer,
|
284
|
+
MinGradScaleError,
|
233
285
|
ModelDownloader,
|
286
|
+
NaNError,
|
287
|
+
StateTimer,
|
288
|
+
TrainingFinishedError,
|
234
289
|
check_md5,
|
235
290
|
check_sha256,
|
291
|
+
cpu_count,
|
292
|
+
date_str,
|
293
|
+
diff_configs,
|
236
294
|
get_git_state,
|
295
|
+
get_random_port,
|
237
296
|
get_state_dict_prefix,
|
238
297
|
get_training_code,
|
239
298
|
save_config,
|
299
|
+
stage_environment,
|
300
|
+
to_markdown_table,
|
301
|
+
)
|
302
|
+
from xax.utils.logging import (
|
303
|
+
LOG_ERROR_SUMMARY,
|
304
|
+
LOG_PING,
|
305
|
+
LOG_STATUS,
|
306
|
+
ColoredFormatter,
|
307
|
+
configure_logging,
|
240
308
|
)
|
241
|
-
from xax.utils.logging import ColoredFormatter, configure_logging
|
242
309
|
from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
|
243
310
|
from xax.utils.text import (
|
244
311
|
TextBlock,
|
312
|
+
camelcase_to_snakecase,
|
245
313
|
colored,
|
246
314
|
format_datetime,
|
247
315
|
format_timedelta,
|
316
|
+
highlight_exception_message,
|
317
|
+
is_interactive_session,
|
248
318
|
outlined,
|
249
319
|
render_text_blocks,
|
250
320
|
show_error,
|
321
|
+
show_info,
|
251
322
|
show_warning,
|
323
|
+
snakecase_to_camelcase,
|
252
324
|
uncolored,
|
253
325
|
wrapped,
|
254
326
|
)
|
xax/core/conf.py
CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_base
|
|
6
6
|
from pathlib import Path
|
7
7
|
from typing import Any, cast
|
8
8
|
|
9
|
-
import jax.numpy as jnp
|
10
9
|
from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
|
11
10
|
|
12
11
|
from xax.utils.text import show_error
|
@@ -61,68 +60,44 @@ def is_missing(cfg: Any, key: str) -> bool: # noqa: ANN401
|
|
61
60
|
return False
|
62
61
|
|
63
62
|
|
64
|
-
@dataclass
|
63
|
+
@dataclass(kw_only=True)
|
65
64
|
class Logging:
|
66
65
|
hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
|
67
|
-
log_level: str = field("INFO", help="The logging level to use")
|
66
|
+
log_level: str = field(II("oc.env:XAX_LOG_LEVEL,INFO"), help="The logging level to use")
|
68
67
|
|
69
68
|
|
70
|
-
@dataclass
|
71
|
-
class Device:
|
72
|
-
cpu: bool = field(True, help="Whether to use the CPU")
|
73
|
-
gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
|
74
|
-
metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
|
75
|
-
use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
|
76
|
-
use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
|
77
|
-
use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
|
78
|
-
use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
|
79
|
-
|
80
|
-
|
81
|
-
def parse_dtype(cfg: Device) -> jnp.dtype | None:
|
82
|
-
if cfg.use_fp64:
|
83
|
-
return jnp.float64
|
84
|
-
if cfg.use_fp32:
|
85
|
-
return jnp.float32
|
86
|
-
if cfg.use_bf16:
|
87
|
-
return jnp.bfloat16
|
88
|
-
if cfg.use_fp16:
|
89
|
-
return jnp.float16
|
90
|
-
return None
|
91
|
-
|
92
|
-
|
93
|
-
@dataclass
|
69
|
+
@dataclass(kw_only=True)
|
94
70
|
class Triton:
|
95
71
|
use_triton_if_available: bool = field(True, help="Use Triton if available")
|
96
72
|
|
97
73
|
|
98
|
-
@dataclass
|
74
|
+
@dataclass(kw_only=True)
|
99
75
|
class Experiment:
|
100
76
|
default_random_seed: int = field(1337, help="The default random seed to use")
|
101
77
|
max_workers: int = field(32, help="Maximum number of workers to use")
|
102
78
|
|
103
79
|
|
104
|
-
@dataclass
|
80
|
+
@dataclass(kw_only=True)
|
105
81
|
class Directories:
|
106
82
|
run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
|
107
83
|
data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
|
108
84
|
pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
|
109
85
|
|
110
86
|
|
111
|
-
@dataclass
|
87
|
+
@dataclass(kw_only=True)
|
112
88
|
class SlurmPartition:
|
113
89
|
partition: str = field(MISSING, help="The partition name")
|
114
90
|
num_nodes: int = field(1, help="The number of nodes to use")
|
115
91
|
|
116
92
|
|
117
|
-
@dataclass
|
93
|
+
@dataclass(kw_only=True)
|
118
94
|
class Slurm:
|
119
95
|
launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
|
120
96
|
|
121
97
|
|
122
|
-
@dataclass
|
98
|
+
@dataclass(kw_only=True)
|
123
99
|
class UserConfig:
|
124
100
|
logging: Logging = field(Logging)
|
125
|
-
device: Device = field(Device)
|
126
101
|
triton: Triton = field(Triton)
|
127
102
|
experiment: Experiment = field(Experiment)
|
128
103
|
directories: Directories = field(Directories)
|
xax/core/state.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import time
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from typing import Literal, TypedDict, cast, get_args
|
5
|
+
from typing import Literal, NotRequired, TypedDict, cast, get_args
|
6
6
|
|
7
7
|
from omegaconf import MISSING
|
8
8
|
|
@@ -18,16 +18,16 @@ def cast_phase(raw_phase: str) -> Phase:
|
|
18
18
|
|
19
19
|
|
20
20
|
class StateDict(TypedDict, total=False):
|
21
|
-
num_steps: int
|
22
|
-
num_samples: int
|
23
|
-
num_valid_steps: int
|
24
|
-
num_valid_samples: int
|
25
|
-
start_time_s: float
|
26
|
-
elapsed_time_s: float
|
27
|
-
raw_phase: str
|
21
|
+
num_steps: NotRequired[int]
|
22
|
+
num_samples: NotRequired[int]
|
23
|
+
num_valid_steps: NotRequired[int]
|
24
|
+
num_valid_samples: NotRequired[int]
|
25
|
+
start_time_s: NotRequired[float]
|
26
|
+
elapsed_time_s: NotRequired[float]
|
27
|
+
raw_phase: NotRequired[str]
|
28
28
|
|
29
29
|
|
30
|
-
@dataclass
|
30
|
+
@dataclass
|
31
31
|
class State:
|
32
32
|
num_steps: int = field(MISSING, help="Number of steps so far")
|
33
33
|
num_samples: int = field(MISSING, help="Number of sample so far")
|
@@ -41,6 +41,10 @@ class State:
|
|
41
41
|
def phase(self) -> Phase:
|
42
42
|
return cast_phase(self.raw_phase)
|
43
43
|
|
44
|
+
@phase.setter
|
45
|
+
def phase(self, phase: Phase) -> None:
|
46
|
+
self.raw_phase = phase
|
47
|
+
|
44
48
|
@classmethod
|
45
49
|
def init_state(cls) -> "State":
|
46
50
|
return cls(
|
@@ -65,17 +69,3 @@ class State:
|
|
65
69
|
return self.num_valid_steps
|
66
70
|
case _:
|
67
71
|
raise ValueError(f"Invalid phase: {phase}")
|
68
|
-
|
69
|
-
def replace(self, values: StateDict) -> "State":
|
70
|
-
return State(
|
71
|
-
num_steps=values.get("num_steps", self.num_steps),
|
72
|
-
num_samples=values.get("num_samples", self.num_samples),
|
73
|
-
num_valid_steps=values.get("num_valid_steps", self.num_valid_steps),
|
74
|
-
num_valid_samples=values.get("num_valid_samples", self.num_valid_samples),
|
75
|
-
start_time_s=values.get("start_time_s", self.start_time_s),
|
76
|
-
elapsed_time_s=values.get("elapsed_time_s", self.elapsed_time_s),
|
77
|
-
raw_phase=values.get("raw_phase", self.raw_phase),
|
78
|
-
)
|
79
|
-
|
80
|
-
def with_phase(self, phase: Phase) -> "State":
|
81
|
-
return self.replace({"raw_phase": phase})
|
xax/requirements.txt
CHANGED
xax/task/base.py
CHANGED
@@ -15,6 +15,7 @@ from pathlib import Path
|
|
15
15
|
from types import TracebackType
|
16
16
|
from typing import Generic, Self, TypeVar, cast
|
17
17
|
|
18
|
+
import jax
|
18
19
|
from omegaconf import Container, DictConfig, OmegaConf
|
19
20
|
|
20
21
|
from xax.core.state import State
|
@@ -23,6 +24,7 @@ from xax.utils.text import camelcase_to_snakecase
|
|
23
24
|
logger = logging.getLogger(__name__)
|
24
25
|
|
25
26
|
|
27
|
+
@jax.tree_util.register_dataclass
|
26
28
|
@dataclass
|
27
29
|
class BaseConfig:
|
28
30
|
pass
|