xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +122 -8
- xax/core/conf.py +9 -33
- xax/core/state.py +13 -23
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +17 -10
- xax/task/base.py +2 -6
- xax/task/logger.py +419 -412
- xax/task/loggers/callback.py +44 -0
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +16 -33
- xax/task/mixins/__init__.py +3 -1
- xax/task/mixins/artifacts.py +19 -9
- xax/task/mixins/checkpointing.py +221 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +26 -15
- xax/task/mixins/data_loader.py +27 -19
- xax/task/mixins/gpu_stats.py +22 -8
- xax/task/mixins/logger.py +5 -251
- xax/task/mixins/process.py +8 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +236 -145
- xax/task/script.py +1 -1
- xax/task/task.py +13 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +89 -21
- xax-0.0.6.dist-info/METADATA +50 -0
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.3.dist-info/METADATA +0 -39
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
|
|
11
11
|
python -m scripts.update_api --inplace
|
12
12
|
"""
|
13
13
|
|
14
|
-
__version__ = "0.0.
|
14
|
+
__version__ = "0.0.6"
|
15
15
|
|
16
16
|
# This list shouldn't be modified by hand; instead, run the update script.
|
17
17
|
__all__ = [
|
@@ -23,15 +23,26 @@ __all__ = [
|
|
23
23
|
"load_user_config",
|
24
24
|
"State",
|
25
25
|
"cast_phase",
|
26
|
+
"FourierEmbeddings",
|
27
|
+
"IdentityPositionalEmbeddings",
|
28
|
+
"LearnedPositionalEmbeddings",
|
29
|
+
"RotaryEmbeddings",
|
30
|
+
"SinusoidalEmbeddings",
|
31
|
+
"apply_rotary_embeddings",
|
32
|
+
"cast_embedding_kind",
|
33
|
+
"fourier_embeddings",
|
34
|
+
"get_positional_embeddings",
|
35
|
+
"get_rotary_embeddings",
|
36
|
+
"rotary_embeddings",
|
37
|
+
"is_master",
|
26
38
|
"BaseLauncher",
|
27
39
|
"CliLauncher",
|
28
40
|
"SingleProcessLauncher",
|
29
|
-
"LogAudio",
|
30
41
|
"LogImage",
|
31
42
|
"LogLine",
|
32
|
-
"LogVideo",
|
33
43
|
"Logger",
|
34
44
|
"LoggerImpl",
|
45
|
+
"CallbackLogger",
|
35
46
|
"JsonLogger",
|
36
47
|
"StateLogger",
|
37
48
|
"StdoutLogger",
|
@@ -46,34 +57,59 @@ __all__ = [
|
|
46
57
|
"collate",
|
47
58
|
"collate_non_null",
|
48
59
|
"BaseFileDownloader",
|
60
|
+
"CumulativeTimer",
|
49
61
|
"DataDownloader",
|
62
|
+
"IntervalTicker",
|
63
|
+
"IterationTimer",
|
64
|
+
"MinGradScaleError",
|
50
65
|
"ModelDownloader",
|
66
|
+
"NaNError",
|
67
|
+
"StateTimer",
|
68
|
+
"TrainingFinishedError",
|
51
69
|
"check_md5",
|
52
70
|
"check_sha256",
|
71
|
+
"cpu_count",
|
72
|
+
"date_str",
|
73
|
+
"diff_configs",
|
53
74
|
"get_git_state",
|
75
|
+
"get_random_port",
|
54
76
|
"get_state_dict_prefix",
|
55
77
|
"get_training_code",
|
56
78
|
"save_config",
|
79
|
+
"stage_environment",
|
80
|
+
"to_markdown_table",
|
57
81
|
"ColoredFormatter",
|
58
82
|
"configure_logging",
|
59
83
|
"one_hot",
|
60
84
|
"partial_flatten",
|
61
85
|
"worker_chunk",
|
62
86
|
"TextBlock",
|
87
|
+
"camelcase_to_snakecase",
|
63
88
|
"colored",
|
64
89
|
"format_datetime",
|
65
90
|
"format_timedelta",
|
91
|
+
"highlight_exception_message",
|
92
|
+
"is_interactive_session",
|
66
93
|
"outlined",
|
67
94
|
"render_text_blocks",
|
68
95
|
"show_error",
|
96
|
+
"show_info",
|
69
97
|
"show_warning",
|
98
|
+
"snakecase_to_camelcase",
|
70
99
|
"uncolored",
|
71
100
|
"wrapped",
|
72
101
|
]
|
73
102
|
|
74
103
|
__all__ += [
|
104
|
+
"Batch",
|
75
105
|
"CollateMode",
|
106
|
+
"EmbeddingKind",
|
107
|
+
"LOG_ERROR_SUMMARY",
|
108
|
+
"LOG_PING",
|
109
|
+
"LOG_STATUS",
|
110
|
+
"Output",
|
76
111
|
"Phase",
|
112
|
+
"RawConfigType",
|
77
113
|
]
|
78
114
|
|
79
115
|
import os
|
@@ -95,21 +131,32 @@ NAME_MAP: dict[str, str] = {
|
|
95
131
|
"load_user_config": "core.conf",
|
96
132
|
"State": "core.state",
|
97
133
|
"cast_phase": "core.state",
|
134
|
+
"FourierEmbeddings": "nn.embeddings",
|
135
|
+
"IdentityPositionalEmbeddings": "nn.embeddings",
|
136
|
+
"LearnedPositionalEmbeddings": "nn.embeddings",
|
137
|
+
"RotaryEmbeddings": "nn.embeddings",
|
138
|
+
"SinusoidalEmbeddings": "nn.embeddings",
|
139
|
+
"apply_rotary_embeddings": "nn.embeddings",
|
140
|
+
"cast_embedding_kind": "nn.embeddings",
|
141
|
+
"fourier_embeddings": "nn.embeddings",
|
142
|
+
"get_positional_embeddings": "nn.embeddings",
|
143
|
+
"get_rotary_embeddings": "nn.embeddings",
|
144
|
+
"rotary_embeddings": "nn.embeddings",
|
145
|
+
"is_master": "nn.parallel",
|
98
146
|
"BaseLauncher": "task.launchers.base",
|
99
147
|
"CliLauncher": "task.launchers.cli",
|
100
148
|
"SingleProcessLauncher": "task.launchers.single_process",
|
101
|
-
"LogAudio": "task.logger",
|
102
149
|
"LogImage": "task.logger",
|
103
150
|
"LogLine": "task.logger",
|
104
|
-
"LogVideo": "task.logger",
|
105
151
|
"Logger": "task.logger",
|
106
152
|
"LoggerImpl": "task.logger",
|
153
|
+
"CallbackLogger": "task.loggers.callback",
|
107
154
|
"JsonLogger": "task.loggers.json",
|
108
155
|
"StateLogger": "task.loggers.state",
|
109
156
|
"StdoutLogger": "task.loggers.stdout",
|
110
157
|
"TensorboardLogger": "task.loggers.tensorboard",
|
111
158
|
"CPUStatsOptions": "task.mixins.cpu_stats",
|
112
|
-
"
|
159
|
+
"DataloaderConfig": "task.mixins.data_loader",
|
113
160
|
"GPUStatsOptions": "task.mixins.gpu_stats",
|
114
161
|
"Script": "task.script",
|
115
162
|
"ScriptConfig": "task.script",
|
@@ -118,27 +165,45 @@ NAME_MAP: dict[str, str] = {
|
|
118
165
|
"collate": "utils.data.collate",
|
119
166
|
"collate_non_null": "utils.data.collate",
|
120
167
|
"BaseFileDownloader": "utils.experiments",
|
168
|
+
"CumulativeTimer": "utils.experiments",
|
121
169
|
"DataDownloader": "utils.experiments",
|
170
|
+
"IntervalTicker": "utils.experiments",
|
171
|
+
"IterationTimer": "utils.experiments",
|
172
|
+
"MinGradScaleError": "utils.experiments",
|
122
173
|
"ModelDownloader": "utils.experiments",
|
174
|
+
"NaNError": "utils.experiments",
|
175
|
+
"StateTimer": "utils.experiments",
|
176
|
+
"TrainingFinishedError": "utils.experiments",
|
123
177
|
"check_md5": "utils.experiments",
|
124
178
|
"check_sha256": "utils.experiments",
|
179
|
+
"cpu_count": "utils.experiments",
|
180
|
+
"date_str": "utils.experiments",
|
181
|
+
"diff_configs": "utils.experiments",
|
125
182
|
"get_git_state": "utils.experiments",
|
183
|
+
"get_random_port": "utils.experiments",
|
126
184
|
"get_state_dict_prefix": "utils.experiments",
|
127
185
|
"get_training_code": "utils.experiments",
|
128
186
|
"save_config": "utils.experiments",
|
187
|
+
"stage_environment": "utils.experiments",
|
188
|
+
"to_markdown_table": "utils.experiments",
|
129
189
|
"ColoredFormatter": "utils.logging",
|
130
190
|
"configure_logging": "utils.logging",
|
131
191
|
"one_hot": "utils.numpy",
|
132
192
|
"partial_flatten": "utils.numpy",
|
133
193
|
"worker_chunk": "utils.numpy",
|
134
194
|
"TextBlock": "utils.text",
|
195
|
+
"camelcase_to_snakecase": "utils.text",
|
135
196
|
"colored": "utils.text",
|
136
197
|
"format_datetime": "utils.text",
|
137
198
|
"format_timedelta": "utils.text",
|
199
|
+
"highlight_exception_message": "utils.text",
|
200
|
+
"is_interactive_session": "utils.text",
|
138
201
|
"outlined": "utils.text",
|
139
202
|
"render_text_blocks": "utils.text",
|
140
203
|
"show_error": "utils.text",
|
204
|
+
"show_info": "utils.text",
|
141
205
|
"show_warning": "utils.text",
|
206
|
+
"snakecase_to_camelcase": "utils.text",
|
142
207
|
"uncolored": "utils.text",
|
143
208
|
"wrapped": "utils.text",
|
144
209
|
}
|
@@ -146,8 +211,15 @@ NAME_MAP: dict[str, str] = {
|
|
146
211
|
# Need to manually set some values which can't be auto-generated.
|
147
212
|
NAME_MAP.update(
|
148
213
|
{
|
214
|
+
"Batch": "task.mixins.train",
|
149
215
|
"CollateMode": "utils.data.collate",
|
216
|
+
"EmbeddingKind": "nn.embeddings",
|
217
|
+
"LOG_ERROR_SUMMARY": "utils.logging",
|
218
|
+
"LOG_PING": "utils.logging",
|
219
|
+
"LOG_STATUS": "utils.logging",
|
220
|
+
"Output": "task.mixins.output",
|
150
221
|
"Phase": "core.state",
|
222
|
+
"RawConfigType": "task.base",
|
151
223
|
},
|
152
224
|
)
|
153
225
|
|
@@ -171,10 +243,27 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
171
243
|
load_user_config,
|
172
244
|
)
|
173
245
|
from xax.core.state import Phase, State, cast_phase
|
246
|
+
from xax.nn.embeddings import (
|
247
|
+
EmbeddingKind,
|
248
|
+
FourierEmbeddings,
|
249
|
+
IdentityPositionalEmbeddings,
|
250
|
+
LearnedPositionalEmbeddings,
|
251
|
+
RotaryEmbeddings,
|
252
|
+
SinusoidalEmbeddings,
|
253
|
+
apply_rotary_embeddings,
|
254
|
+
cast_embedding_kind,
|
255
|
+
fourier_embeddings,
|
256
|
+
get_positional_embeddings,
|
257
|
+
get_rotary_embeddings,
|
258
|
+
rotary_embeddings,
|
259
|
+
)
|
260
|
+
from xax.nn.parallel import is_master
|
261
|
+
from xax.task.base import RawConfigType
|
174
262
|
from xax.task.launchers.base import BaseLauncher
|
175
263
|
from xax.task.launchers.cli import CliLauncher
|
176
264
|
from xax.task.launchers.single_process import SingleProcessLauncher
|
177
|
-
from xax.task.logger import
|
265
|
+
from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
|
266
|
+
from xax.task.loggers.callback import CallbackLogger
|
178
267
|
from xax.task.loggers.json import JsonLogger
|
179
268
|
from xax.task.loggers.state import StateLogger
|
180
269
|
from xax.task.loggers.stdout import StdoutLogger
|
@@ -182,31 +271,56 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
182
271
|
from xax.task.mixins.cpu_stats import CPUStatsOptions
|
183
272
|
from xax.task.mixins.data_loader import DataloaderConfig
|
184
273
|
from xax.task.mixins.gpu_stats import GPUStatsOptions
|
274
|
+
from xax.task.mixins.train import Batch, Output
|
185
275
|
from xax.task.script import Script, ScriptConfig
|
186
276
|
from xax.task.task import Config, Task
|
187
277
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
188
278
|
from xax.utils.experiments import (
|
189
279
|
BaseFileDownloader,
|
280
|
+
CumulativeTimer,
|
190
281
|
DataDownloader,
|
282
|
+
IntervalTicker,
|
283
|
+
IterationTimer,
|
284
|
+
MinGradScaleError,
|
191
285
|
ModelDownloader,
|
286
|
+
NaNError,
|
287
|
+
StateTimer,
|
288
|
+
TrainingFinishedError,
|
192
289
|
check_md5,
|
193
290
|
check_sha256,
|
291
|
+
cpu_count,
|
292
|
+
date_str,
|
293
|
+
diff_configs,
|
194
294
|
get_git_state,
|
295
|
+
get_random_port,
|
195
296
|
get_state_dict_prefix,
|
196
297
|
get_training_code,
|
197
298
|
save_config,
|
299
|
+
stage_environment,
|
300
|
+
to_markdown_table,
|
301
|
+
)
|
302
|
+
from xax.utils.logging import (
|
303
|
+
LOG_ERROR_SUMMARY,
|
304
|
+
LOG_PING,
|
305
|
+
LOG_STATUS,
|
306
|
+
ColoredFormatter,
|
307
|
+
configure_logging,
|
198
308
|
)
|
199
|
-
from xax.utils.logging import ColoredFormatter, configure_logging
|
200
309
|
from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
|
201
310
|
from xax.utils.text import (
|
202
311
|
TextBlock,
|
312
|
+
camelcase_to_snakecase,
|
203
313
|
colored,
|
204
314
|
format_datetime,
|
205
315
|
format_timedelta,
|
316
|
+
highlight_exception_message,
|
317
|
+
is_interactive_session,
|
206
318
|
outlined,
|
207
319
|
render_text_blocks,
|
208
320
|
show_error,
|
321
|
+
show_info,
|
209
322
|
show_warning,
|
323
|
+
snakecase_to_camelcase,
|
210
324
|
uncolored,
|
211
325
|
wrapped,
|
212
326
|
)
|
xax/core/conf.py
CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_base
|
|
6
6
|
from pathlib import Path
|
7
7
|
from typing import Any, cast
|
8
8
|
|
9
|
-
import jax.numpy as jnp
|
10
9
|
from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
|
11
10
|
|
12
11
|
from xax.utils.text import show_error
|
@@ -61,67 +60,44 @@ def is_missing(cfg: Any, key: str) -> bool: # noqa: ANN401
|
|
61
60
|
return False
|
62
61
|
|
63
62
|
|
64
|
-
@dataclass
|
63
|
+
@dataclass(kw_only=True)
|
65
64
|
class Logging:
|
66
65
|
hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
|
67
|
-
log_level: str = field("INFO", help="The logging level to use")
|
66
|
+
log_level: str = field(II("oc.env:XAX_LOG_LEVEL,INFO"), help="The logging level to use")
|
68
67
|
|
69
68
|
|
70
|
-
@dataclass
|
71
|
-
class Device:
|
72
|
-
cpu: bool = field(True, help="Whether to use the CPU")
|
73
|
-
gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
|
74
|
-
metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
|
75
|
-
use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
|
76
|
-
use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
|
77
|
-
use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
|
78
|
-
use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
|
79
|
-
|
80
|
-
|
81
|
-
def parse_dtype(cfg: Device) -> jnp.dtype | None:
|
82
|
-
if cfg.use_fp64:
|
83
|
-
return jnp.float64
|
84
|
-
if cfg.use_fp32:
|
85
|
-
return jnp.float32
|
86
|
-
if cfg.use_bf16:
|
87
|
-
return jnp.bfloat16
|
88
|
-
if cfg.use_fp16:
|
89
|
-
return jnp.float16
|
90
|
-
return None
|
91
|
-
|
92
|
-
|
93
|
-
@dataclass
|
69
|
+
@dataclass(kw_only=True)
|
94
70
|
class Triton:
|
95
71
|
use_triton_if_available: bool = field(True, help="Use Triton if available")
|
96
72
|
|
97
73
|
|
98
|
-
@dataclass
|
74
|
+
@dataclass(kw_only=True)
|
99
75
|
class Experiment:
|
100
76
|
default_random_seed: int = field(1337, help="The default random seed to use")
|
77
|
+
max_workers: int = field(32, help="Maximum number of workers to use")
|
101
78
|
|
102
79
|
|
103
|
-
@dataclass
|
80
|
+
@dataclass(kw_only=True)
|
104
81
|
class Directories:
|
105
82
|
run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
|
106
83
|
data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
|
107
84
|
pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
|
108
85
|
|
109
86
|
|
110
|
-
@dataclass
|
87
|
+
@dataclass(kw_only=True)
|
111
88
|
class SlurmPartition:
|
112
89
|
partition: str = field(MISSING, help="The partition name")
|
113
90
|
num_nodes: int = field(1, help="The number of nodes to use")
|
114
91
|
|
115
92
|
|
116
|
-
@dataclass
|
93
|
+
@dataclass(kw_only=True)
|
117
94
|
class Slurm:
|
118
95
|
launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
|
119
96
|
|
120
97
|
|
121
|
-
@dataclass
|
98
|
+
@dataclass(kw_only=True)
|
122
99
|
class UserConfig:
|
123
100
|
logging: Logging = field(Logging)
|
124
|
-
device: Device = field(Device)
|
125
101
|
triton: Triton = field(Triton)
|
126
102
|
experiment: Experiment = field(Experiment)
|
127
103
|
directories: Directories = field(Directories)
|
xax/core/state.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import time
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from typing import Literal, TypedDict, cast, get_args
|
5
|
+
from typing import Literal, NotRequired, TypedDict, cast, get_args
|
6
6
|
|
7
7
|
from omegaconf import MISSING
|
8
8
|
|
@@ -18,16 +18,16 @@ def cast_phase(raw_phase: str) -> Phase:
|
|
18
18
|
|
19
19
|
|
20
20
|
class StateDict(TypedDict, total=False):
|
21
|
-
num_steps: int
|
22
|
-
num_samples: int
|
23
|
-
num_valid_steps: int
|
24
|
-
num_valid_samples: int
|
25
|
-
start_time_s: float
|
26
|
-
elapsed_time_s: float
|
27
|
-
raw_phase: str
|
21
|
+
num_steps: NotRequired[int]
|
22
|
+
num_samples: NotRequired[int]
|
23
|
+
num_valid_steps: NotRequired[int]
|
24
|
+
num_valid_samples: NotRequired[int]
|
25
|
+
start_time_s: NotRequired[float]
|
26
|
+
elapsed_time_s: NotRequired[float]
|
27
|
+
raw_phase: NotRequired[str]
|
28
28
|
|
29
29
|
|
30
|
-
@dataclass
|
30
|
+
@dataclass
|
31
31
|
class State:
|
32
32
|
num_steps: int = field(MISSING, help="Number of steps so far")
|
33
33
|
num_samples: int = field(MISSING, help="Number of sample so far")
|
@@ -41,6 +41,10 @@ class State:
|
|
41
41
|
def phase(self) -> Phase:
|
42
42
|
return cast_phase(self.raw_phase)
|
43
43
|
|
44
|
+
@phase.setter
|
45
|
+
def phase(self, phase: Phase) -> None:
|
46
|
+
self.raw_phase = phase
|
47
|
+
|
44
48
|
@classmethod
|
45
49
|
def init_state(cls) -> "State":
|
46
50
|
return cls(
|
@@ -65,17 +69,3 @@ class State:
|
|
65
69
|
return self.num_valid_steps
|
66
70
|
case _:
|
67
71
|
raise ValueError(f"Invalid phase: {phase}")
|
68
|
-
|
69
|
-
def replace(self, values: StateDict) -> "State":
|
70
|
-
return State(
|
71
|
-
num_steps=values.get("num_steps", self.num_steps),
|
72
|
-
num_samples=values.get("num_samples", self.num_samples),
|
73
|
-
num_valid_steps=values.get("num_valid_steps", self.num_valid_steps),
|
74
|
-
num_valid_samples=values.get("num_valid_samples", self.num_valid_samples),
|
75
|
-
start_time_s=values.get("start_time_s", self.start_time_s),
|
76
|
-
elapsed_time_s=values.get("elapsed_time_s", self.elapsed_time_s),
|
77
|
-
raw_phase=values.get("raw_phase", self.raw_phase),
|
78
|
-
)
|
79
|
-
|
80
|
-
def with_phase(self, phase: Phase) -> "State":
|
81
|
-
return self.replace({"raw_phase": phase})
|