xax-0.0.1-py3-none-any.whl → xax-0.0.5-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/__init__.py
CHANGED
@@ -1 +1,256 @@
-
+"""Defines the top-level xax API.
+
+This package is structured so that all the important stuff can be accessed
+without having to dig around through the internals. This is done by lazily
+importing the module by name.
+
+This file can be maintained by running the update script:
+
+.. code-block:: bash
+
+    python -m scripts.update_api --inplace
+"""
+
+__version__ = "0.0.5"
+
+# This list shouldn't be modified by hand; instead, run the update script.
+__all__ = [
+    "UserConfig",
+    "field",
+    "get_data_dir",
+    "get_pretrained_models_dir",
+    "get_run_dir",
+    "load_user_config",
+    "State",
+    "cast_phase",
+    "FourierEmbeddings",
+    "IdentityPositionalEmbeddings",
+    "LearnedPositionalEmbeddings",
+    "RotaryEmbeddings",
+    "SinusoidalEmbeddings",
+    "apply_rotary_embeddings",
+    "cast_embedding_kind",
+    "fourier_embeddings",
+    "get_positional_embeddings",
+    "get_rotary_embeddings",
+    "rotary_embeddings",
+    "BaseLauncher",
+    "CliLauncher",
+    "SingleProcessLauncher",
+    "LogImage",
+    "LogLine",
+    "Logger",
+    "LoggerImpl",
+    "CallbackLogger",
+    "JsonLogger",
+    "StateLogger",
+    "StdoutLogger",
+    "TensorboardLogger",
+    "CPUStatsOptions",
+    "DataloaderConfig",
+    "GPUStatsOptions",
+    "Script",
+    "ScriptConfig",
+    "Config",
+    "Task",
+    "collate",
+    "collate_non_null",
+    "BaseFileDownloader",
+    "DataDownloader",
+    "ModelDownloader",
+    "check_md5",
+    "check_sha256",
+    "get_git_state",
+    "get_state_dict_prefix",
+    "get_training_code",
+    "save_config",
+    "ColoredFormatter",
+    "configure_logging",
+    "one_hot",
+    "partial_flatten",
+    "worker_chunk",
+    "TextBlock",
+    "colored",
+    "format_datetime",
+    "format_timedelta",
+    "outlined",
+    "render_text_blocks",
+    "show_error",
+    "show_warning",
+    "uncolored",
+    "wrapped",
+]
+
+__all__ += [
+    "Batch",
+    "CollateMode",
+    "EmbeddingKind",
+    "Output",
+    "Phase",
+]
+
+import os
+from typing import TYPE_CHECKING
+
+# If this flag is set, eagerly imports the entire package (not recommended).
+IMPORT_ALL = int(os.environ.get("XAX_IMPORT_ALL", "0")) != 0
+
+del os
+
+# This dictionary is auto-generated and shouldn't be modified by hand; instead,
+# run the update script.
+NAME_MAP: dict[str, str] = {
+    "UserConfig": "core.conf",
+    "field": "core.conf",
+    "get_data_dir": "core.conf",
+    "get_pretrained_models_dir": "core.conf",
+    "get_run_dir": "core.conf",
+    "load_user_config": "core.conf",
+    "State": "core.state",
+    "cast_phase": "core.state",
+    "FourierEmbeddings": "nn.embeddings",
+    "IdentityPositionalEmbeddings": "nn.embeddings",
+    "LearnedPositionalEmbeddings": "nn.embeddings",
+    "RotaryEmbeddings": "nn.embeddings",
+    "SinusoidalEmbeddings": "nn.embeddings",
+    "apply_rotary_embeddings": "nn.embeddings",
+    "cast_embedding_kind": "nn.embeddings",
+    "fourier_embeddings": "nn.embeddings",
+    "get_positional_embeddings": "nn.embeddings",
+    "get_rotary_embeddings": "nn.embeddings",
+    "rotary_embeddings": "nn.embeddings",
+    "BaseLauncher": "task.launchers.base",
+    "CliLauncher": "task.launchers.cli",
+    "SingleProcessLauncher": "task.launchers.single_process",
+    "LogImage": "task.logger",
+    "LogLine": "task.logger",
+    "Logger": "task.logger",
+    "LoggerImpl": "task.logger",
+    "CallbackLogger": "task.loggers.callback",
+    "JsonLogger": "task.loggers.json",
+    "StateLogger": "task.loggers.state",
+    "StdoutLogger": "task.loggers.stdout",
+    "TensorboardLogger": "task.loggers.tensorboard",
+    "CPUStatsOptions": "task.mixins.cpu_stats",
+    "DataloaderConfig": "task.mixins.data_loader",
+    "GPUStatsOptions": "task.mixins.gpu_stats",
+    "Script": "task.script",
+    "ScriptConfig": "task.script",
+    "Config": "task.task",
+    "Task": "task.task",
+    "collate": "utils.data.collate",
+    "collate_non_null": "utils.data.collate",
+    "BaseFileDownloader": "utils.experiments",
+    "DataDownloader": "utils.experiments",
+    "ModelDownloader": "utils.experiments",
+    "check_md5": "utils.experiments",
+    "check_sha256": "utils.experiments",
+    "get_git_state": "utils.experiments",
+    "get_state_dict_prefix": "utils.experiments",
+    "get_training_code": "utils.experiments",
+    "save_config": "utils.experiments",
+    "ColoredFormatter": "utils.logging",
+    "configure_logging": "utils.logging",
+    "one_hot": "utils.numpy",
+    "partial_flatten": "utils.numpy",
+    "worker_chunk": "utils.numpy",
+    "TextBlock": "utils.text",
+    "colored": "utils.text",
+    "format_datetime": "utils.text",
+    "format_timedelta": "utils.text",
+    "outlined": "utils.text",
+    "render_text_blocks": "utils.text",
+    "show_error": "utils.text",
+    "show_warning": "utils.text",
+    "uncolored": "utils.text",
+    "wrapped": "utils.text",
+}
+
+# Need to manually set some values which can't be auto-generated.
+NAME_MAP.update(
+    {
+        "Batch": "task.mixins.train",
+        "CollateMode": "utils.data.collate",
+        "EmbeddingKind": "nn.embeddings",
+        "Output": "task.mixins.output",
+        "Phase": "core.state",
+    },
+)
+
+
+def __getattr__(name: str) -> object:
+    if name not in NAME_MAP:
+        raise AttributeError(f"{__name__} has no attribute {name}")
+
+    module_name = f"xax.{NAME_MAP[name]}"
+    module = __import__(module_name, fromlist=[name])
+    return getattr(module, name)
+
+
+if IMPORT_ALL or TYPE_CHECKING:
+    from xax.core.conf import (
+        UserConfig,
+        field,
+        get_data_dir,
+        get_pretrained_models_dir,
+        get_run_dir,
+        load_user_config,
+    )
+    from xax.core.state import Phase, State, cast_phase
+    from xax.nn.embeddings import (
+        EmbeddingKind,
+        FourierEmbeddings,
+        IdentityPositionalEmbeddings,
+        LearnedPositionalEmbeddings,
+        RotaryEmbeddings,
+        SinusoidalEmbeddings,
+        apply_rotary_embeddings,
+        cast_embedding_kind,
+        fourier_embeddings,
+        get_positional_embeddings,
+        get_rotary_embeddings,
+        rotary_embeddings,
+    )
+    from xax.task.launchers.base import BaseLauncher
+    from xax.task.launchers.cli import CliLauncher
+    from xax.task.launchers.single_process import SingleProcessLauncher
+    from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
+    from xax.task.loggers.callback import CallbackLogger
+    from xax.task.loggers.json import JsonLogger
+    from xax.task.loggers.state import StateLogger
+    from xax.task.loggers.stdout import StdoutLogger
+    from xax.task.loggers.tensorboard import TensorboardLogger
+    from xax.task.mixins.cpu_stats import CPUStatsOptions
+    from xax.task.mixins.data_loader import DataloaderConfig
+    from xax.task.mixins.gpu_stats import GPUStatsOptions
+    from xax.task.mixins.train import Batch, Output
+    from xax.task.script import Script, ScriptConfig
+    from xax.task.task import Config, Task
+    from xax.utils.data.collate import CollateMode, collate, collate_non_null
+    from xax.utils.experiments import (
+        BaseFileDownloader,
+        DataDownloader,
+        ModelDownloader,
+        check_md5,
+        check_sha256,
+        get_git_state,
+        get_state_dict_prefix,
+        get_training_code,
+        save_config,
+    )
+    from xax.utils.logging import ColoredFormatter, configure_logging
+    from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
+    from xax.utils.text import (
+        TextBlock,
+        colored,
+        format_datetime,
+        format_timedelta,
+        outlined,
+        render_text_blocks,
+        show_error,
+        show_warning,
+        uncolored,
+        wrapped,
+    )
+
+del TYPE_CHECKING, IMPORT_ALL
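The new top-level module resolves public names on first access through the module-level `__getattr__` hook and `NAME_MAP`. A minimal usage sketch follows; it is illustrative only and not part of the diff, and assumes xax 0.0.5 and its dependencies (jax, omegaconf) are installed.

import xax

print(xax.__version__)  # "0.0.5"

# Accessing a public name calls xax.__getattr__, which looks the name up in
# NAME_MAP and imports only the owning submodule (here xax.core.state).
state = xax.State.init_state()
print(state.phase)  # "train"

# Setting XAX_IMPORT_ALL=1 in the environment before the first import makes
# the package import every listed submodule eagerly instead.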
xax/core/conf.py
ADDED
@@ -0,0 +1,193 @@
+"""Defines base configuration functions and utilities."""
+
+import functools
+import os
+from dataclasses import dataclass, field as field_base
+from pathlib import Path
+from typing import Any, cast
+
+import jax.numpy as jnp
+from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
+
+from xax.utils.text import show_error
+
+FieldType = Any
+
+
+def field(value: FieldType, **kwargs: str) -> FieldType:
+    """Short-hand function for getting a config field.
+
+    Args:
+        value: The current field's default value.
+        kwargs: Additional metadata fields to supply.
+
+    Returns:
+        The dataclass field.
+    """
+    metadata: dict[str, Any] = {}
+    metadata.update(kwargs)
+
+    if hasattr(value, "__call__"):
+        return field_base(default_factory=value, metadata=metadata)
+    if value.__class__.__hash__ is None:
+        return field_base(default_factory=lambda: value, metadata=metadata)
+    return field_base(default=value, metadata=metadata)
+
+
+def is_missing(cfg: Any, key: str) -> bool:  # noqa: ANN401
+    """Utility function for checking if a config key is missing.
+
+    This is for cases when you are using a raw dataclass rather than an
+    OmegaConf container but want to treat them the same way.
+
+    Args:
+        cfg: The config to check
+        key: The key to check
+
+    Returns:
+        Whether or not the key is missing a value in the config
+    """
+    if isinstance(cfg, OmegaConfContainer):
+        if OmegaConf.is_missing(cfg, key):
+            return True
+        if OmegaConf.is_interpolation(cfg, key):
+            try:
+                getattr(cfg, key)
+                return False
+            except Exception:
+                return True
+    if getattr(cfg, key) is MISSING:
+        return True
+    return False
+
+
+@dataclass
+class Logging:
+    hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
+    log_level: str = field("INFO", help="The logging level to use")
+
+
+@dataclass
+class Device:
+    cpu: bool = field(True, help="Whether to use the CPU")
+    gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
+    metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
+    use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
+    use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
+    use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
+    use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
+
+
+def parse_dtype(cfg: Device) -> jnp.dtype | None:
+    if cfg.use_fp64:
+        return jnp.float64
+    if cfg.use_fp32:
+        return jnp.float32
+    if cfg.use_bf16:
+        return jnp.bfloat16
+    if cfg.use_fp16:
+        return jnp.float16
+    return None
+
+
+@dataclass
+class Triton:
+    use_triton_if_available: bool = field(True, help="Use Triton if available")
+
+
+@dataclass
+class Experiment:
+    default_random_seed: int = field(1337, help="The default random seed to use")
+    max_workers: int = field(32, help="Maximum number of workers to use")
+
+
+@dataclass
+class Directories:
+    run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
+    data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
+    pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
+
+
+@dataclass
+class SlurmPartition:
+    partition: str = field(MISSING, help="The partition name")
+    num_nodes: int = field(1, help="The number of nodes to use")
+
+
+@dataclass
+class Slurm:
+    launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
+
+
+@dataclass
+class UserConfig:
+    logging: Logging = field(Logging)
+    device: Device = field(Device)
+    triton: Triton = field(Triton)
+    experiment: Experiment = field(Experiment)
+    directories: Directories = field(Directories)
+    slurm: Slurm = field(Slurm)
+
+
+def user_config_path() -> Path:
+    xaxrc_path_raw = os.environ.get("XAXRC_PATH", "~/.xax.yml")
+    xaxrc_path = Path(xaxrc_path_raw).expanduser()
+    return xaxrc_path
+
+
+@functools.lru_cache(maxsize=None)
+def _load_user_config_cached() -> UserConfig:
+    xaxrc_path = user_config_path()
+    base_cfg = OmegaConf.structured(UserConfig)
+
+    # Writes the config file.
+    if xaxrc_path.exists():
+        cfg = OmegaConf.merge(base_cfg, OmegaConf.load(xaxrc_path))
+    else:
+        show_error(f"No config file was found in {xaxrc_path}; writing one...", important=True)
+        OmegaConf.save(base_cfg, xaxrc_path)
+        cfg = base_cfg
+
+    # Looks in the current directory for a config file.
+    local_cfg_path = Path("xax.yml")
+    if local_cfg_path.exists():
+        cfg = OmegaConf.merge(cfg, OmegaConf.load(local_cfg_path))
+
+    return cast(UserConfig, cfg)
+
+
+def load_user_config() -> UserConfig:
+    """Loads the ``~/.xax.yml`` configuration file.
+
+    Returns:
+        The loaded configuration.
+    """
+    return _load_user_config_cached()
+
+
+def get_run_dir() -> Path | None:
+    config = load_user_config().directories
+    if is_missing(config, "run"):
+        return None
+    (run_dir := Path(config.run)).mkdir(parents=True, exist_ok=True)
+    return run_dir
+
+
+def get_data_dir() -> Path:
+    config = load_user_config().directories
+    if is_missing(config, "data"):
+        raise RuntimeError(
+            "The data directory has not been set! You should set it in your config file "
+            f"in {user_config_path()} or set the DATA_DIR environment variable."
+        )
+    return Path(config.data)
+
+
+def get_pretrained_models_dir() -> Path:
+    config = load_user_config().directories
+    if is_missing(config, "pretrained_models"):
+        raise RuntimeError(
+            "The data directory has not been set! You should set it in your config file "
+            f"in {user_config_path()} or set the MODEL_DIR environment variable."
+        )
+    return Path(config.pretrained_models)
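A short sketch of how these configuration helpers fit together; it is illustrative only and not part of the diff. The YAML keys mirror the dataclasses above, while the example values and the /tmp path are assumptions.

# Hypothetical ~/.xax.yml (or the file named by XAXRC_PATH):
#
#   directories:
#     data: /tmp/xax-data
#   experiment:
#     default_random_seed: 42
#
from xax.core.conf import get_data_dir, load_user_config

config = load_user_config()  # merged with any local xax.yml, cached via functools.lru_cache
print(config.experiment.default_random_seed)  # 42 with the file above, else 1337
print(get_data_dir())  # raises RuntimeError if directories.data is still MISSING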
xax/core/state.py
ADDED
@@ -0,0 +1,81 @@
+"""Defines a dataclass for keeping track of the current training state."""
+
+import time
+from dataclasses import dataclass
+from typing import Literal, TypedDict, cast, get_args
+
+from omegaconf import MISSING
+
+from xax.core.conf import field
+
+Phase = Literal["train", "valid"]
+
+
+def cast_phase(raw_phase: str) -> Phase:
+    args = get_args(Phase)
+    assert raw_phase in args, f"Invalid phase: '{raw_phase}' Valid options are {args}"
+    return cast(Phase, raw_phase)
+
+
+class StateDict(TypedDict, total=False):
+    num_steps: int
+    num_samples: int
+    num_valid_steps: int
+    num_valid_samples: int
+    start_time_s: float
+    elapsed_time_s: float
+    raw_phase: str
+
+
+@dataclass(frozen=True)
+class State:
+    num_steps: int = field(MISSING, help="Number of steps so far")
+    num_samples: int = field(MISSING, help="Number of sample so far")
+    num_valid_steps: int = field(MISSING, help="Number of validation steps so far")
+    num_valid_samples: int = field(MISSING, help="Number of validation samples so far")
+    start_time_s: float = field(MISSING, help="Start time of training")
+    elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")
+    raw_phase: str = field(MISSING, help="Current training phase")
+
+    @property
+    def phase(self) -> Phase:
+        return cast_phase(self.raw_phase)
+
+    @classmethod
+    def init_state(cls) -> "State":
+        return cls(
+            num_steps=0,
+            num_samples=0,
+            num_valid_steps=0,
+            num_valid_samples=0,
+            start_time_s=time.time(),
+            elapsed_time_s=0.0,
+            raw_phase="train",
+        )
+
+    @property
+    def training(self) -> bool:
+        return self.phase == "train"
+
+    def num_phase_steps(self, phase: Phase) -> int:
+        match phase:
+            case "train":
+                return self.num_steps
+            case "valid":
+                return self.num_valid_steps
+            case _:
+                raise ValueError(f"Invalid phase: {phase}")
+
+    def replace(self, values: StateDict) -> "State":
+        return State(
+            num_steps=values.get("num_steps", self.num_steps),
+            num_samples=values.get("num_samples", self.num_samples),
+            num_valid_steps=values.get("num_valid_steps", self.num_valid_steps),
+            num_valid_samples=values.get("num_valid_samples", self.num_valid_samples),
+            start_time_s=values.get("start_time_s", self.start_time_s),
+            elapsed_time_s=values.get("elapsed_time_s", self.elapsed_time_s),
+            raw_phase=values.get("raw_phase", self.raw_phase),
+        )
+
+    def with_phase(self, phase: Phase) -> "State":
+        return self.replace({"raw_phase": phase})
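A brief sketch of how the frozen `State` dataclass above might be threaded through a training loop; it is illustrative only and not part of the diff, and the step count and batch size are arbitrary assumptions.

from xax.core.state import State

state = State.init_state()  # num_steps=0, raw_phase="train"
for _ in range(3):
    # replace() returns a new State because the dataclass is frozen.
    state = state.replace(
        {
            "num_steps": state.num_steps + 1,
            "num_samples": state.num_samples + 32,  # assumed batch size
        }
    )

state = state.with_phase("valid")  # switch to the validation phase
print(state.num_steps, state.training)  # 3 False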
xax/nn/__init__.py
ADDED
File without changes