xax 0.0.1-py3-none-any.whl → 0.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/mixins/train.py
ADDED
@@ -0,0 +1,541 @@
"""Defines a mixin for running the training loop."""

import bdb
import contextlib
import functools
import itertools
import logging
import signal
import sys
import textwrap
import time
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass, is_dataclass
from threading import Thread
from typing import Any, Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jaxtyping import Array, PyTree
from omegaconf import DictConfig

from xax.core.conf import field
from xax.core.state import Phase, State
from xax.nn.parallel import is_master
from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
from xax.task.mixins.logger import LoggerConfig, LoggerMixin
from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
from xax.utils.experiments import (
    StateTimer,
    TrainingFinishedError,
    diff_configs,
    get_diff_string,
    get_git_state,
    get_training_code,
)
from xax.utils.logging import LOG_STATUS
from xax.utils.text import highlight_exception_message, show_info

logger = logging.getLogger(__name__)

# Batch = TypeVar("Batch")
# Output = TypeVar("Output")

Batch = Any
Output = Any

StepKind = Literal["step", "sample", "second"]

PRINT_FINISH_TIME_EVERY_N_SECONDS = 60 * 2


def cast_step_kind(s: str) -> StepKind:
    assert s in get_args(StepKind), f"`step_kind` must be one of {get_args(StepKind)}, not {s}"
    return cast(StepKind, s)


@functools.lru_cache(maxsize=None)
def batch_chunks_schedule(schedule: list[int] | None) -> list[int] | None:
    if schedule is None:
        return None
    if any(s < 1 for s in schedule):
        raise ValueError("Batch chunk schedule must be positive")
    return list(itertools.accumulate([0] + schedule))


@functools.lru_cache(maxsize=None)
def batches_per_step_schedule(schedule: list[int] | None) -> list[int] | None:
    if schedule is None:
        return None
    if any(s < 1 for s in schedule):
        raise ValueError("Batch chunk schedule must be positive")
    return list(itertools.accumulate([0] + schedule))


class ValidStepTimer:
    def __init__(
        self,
        valid_every_n_steps: int | None = None,
        valid_first_n_steps: int = 0,
        valid_every_n_seconds: float | None = None,
        valid_first_n_seconds: float | None = None,
    ) -> None:
        super().__init__()

        self.valid_every_n_steps = valid_every_n_steps
        self.valid_first_n_steps = valid_first_n_steps
        self.valid_every_n_seconds = valid_every_n_seconds
        self.valid_first_n_seconds = valid_first_n_seconds
        self.first_valid_step_flag = True

        self.last_valid_time: float | None = None
        self.last_valid_step: int | None = None

    def is_valid_step(self, state: State) -> bool:
        if state.num_steps < self.valid_first_n_steps:
            return True

        if self.last_valid_time is None or self.last_valid_step is None:
            self.last_valid_time = state.elapsed_time_s
            self.last_valid_step = state.num_steps
            return True

        # Step-based validation.
        valid_every_n_steps = self.valid_every_n_steps
        if valid_every_n_steps is not None and state.num_steps > valid_every_n_steps + self.last_valid_step:
            self.last_valid_step = state.num_steps
            return True

        # Time-based validation.
        valid_every_n_seconds = self.valid_every_n_seconds
        if valid_every_n_seconds is not None and state.elapsed_time_s - self.last_valid_time >= valid_every_n_seconds:
            self.last_valid_time = state.elapsed_time_s
            return True

        # Time-based validation for first validation step.
        if self.first_valid_step_flag:
            valid_first_n_seconds = self.valid_first_n_seconds
            if valid_first_n_seconds is not None and state.elapsed_time_s >= valid_first_n_seconds:
                self.last_valid_time = state.elapsed_time_s
                self.first_valid_step_flag = False
                return True

        return False


@dataclass
class TrainConfig(
    CheckpointingConfig,
    DataloadersConfig,
    LoggerConfig,
    StepContextConfig,
    ArtifactsConfig,
    RunnableConfig,
):
    valid_every_n_steps: int | None = field(None, help="Number of training steps to run per validation step")
    valid_first_n_steps: int = field(0, help="Treat the first N steps as validation steps")
    valid_every_n_seconds: float | None = field(60.0 * 10.0, help="Run validation every N seconds")
    valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
    batch_dim: int = field(0, help="The batch dimension, for splitting batches into chunks")
    max_steps: int | None = field(None, help="Maximum number of steps to run")
    step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
    random_seed: int = field(1337, help="Random seed for the task")


Config = TypeVar("Config", bound=TrainConfig)


class TrainMixin(
    CheckpointingMixin[Config],
    DataloadersMixin[Config],
    LoggerMixin[Config],
    StepContextMixin[Config],
    ArtifactsMixin[Config],
    RunnableMixin[Config],
    Generic[Config],
    ABC,
):
    valid_step_timer: ValidStepTimer
    state_timers: dict[Phase, StateTimer]

    _training_over_flag: bool
    _last_printed_remaining_time: float
    _step_kind: StepKind

    def __init__(self, config: Config) -> None:
        super().__init__(config)

        # Timer for validation steps.
        self.valid_step_timer = ValidStepTimer(
            valid_every_n_steps=config.valid_every_n_steps,
            valid_first_n_steps=config.valid_first_n_steps,
            valid_every_n_seconds=config.valid_every_n_seconds,
            valid_first_n_seconds=config.valid_first_n_seconds,
        )

        # Timers for iterations.
        self.state_timers = {phase: StateTimer() for phase in get_args(Phase)}

        # This flag can be toggled to end training from anywhere in the task.
        self._training_over_flag = False

        self._last_printed_remaining_time = 0.0

        # The kind of step that was specified in the config.
        self._step_kind = cast_step_kind(self.config.step_kind)

    def prng_key(self) -> jnp.ndarray:
        return jax.random.PRNGKey(self.config.random_seed)

    def on_step_end(self, state: State) -> State:
        state = super().on_step_end(state)
        return state.replace(
            {
                "elapsed_time_s": time.time() - state.start_time_s,
            },
        )

    def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
        """Override this function to do logging during the training phase.

        This function is called after the model forward pass and before the
        backward pass. It is called in the training phase.

        Args:
            model: The current model.
            batch: The batch from the dataloader.
            output: The model output.
            state: The current training state.
        """

    def log_valid_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
        """Override this function to do logging during the validation phase.

        This function is called after the model forward pass. It is called in
        the validation phase.

        Args:
            model: The current model.
            batch: The batch from the dataloader.
            output: The model output.
            state: The current training state.
        """

    def log_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
        phase = state.phase

        # Log the state timers.
        timer = self.state_timers[phase]
        timer.step(state)
        for ns, d in timer.log_dict().items():
            for k, v in d.items():
                self.logger.log_scalar(k, v, namespace=ns)

        # Delegate to the appropriate logging function based on the phase.
        match phase:
            case "train":
                self.log_train_step(model, batch, output, state)
            case "valid":
                self.log_valid_step(model, batch, output, state)
            case _:
                raise KeyError(f"Unknown phase: {phase}")

    @abstractmethod
    def get_model(self) -> PyTree:
        """Returns the Equinox model to train.

        Returns:
            The model to train.
        """

    @abstractmethod
    def get_optimizer(self) -> optax.GradientTransformation:
        """Gets the optimizer for the model.

        Returns:
            The optimizer to use to train the model.
        """

    def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
        return optimizer.init(eqx.filter(model, eqx.is_array))

    def load_initial_state(self) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
        init_ckpt_path = self.get_init_ckpt_path()

        if init_ckpt_path is not None:
            logger.info("Loading checkpoint from %s", init_ckpt_path)
            with self.step_context("load_checkpoint"):
                model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
                if config_diff:
                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
                return model, optimizer, opt_state, state

        with self.step_context("get_model"):
            model = self.get_model()

        with self.step_context("get_optimizer"):
            optimizer = self.get_optimizer()

        with self.step_context("get_initial_opt_state"):
            opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

        return model, optimizer, opt_state, State.init_state()

    @abstractmethod
    def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
        """Gets the output from the model.

        By default, we assume the model is a function that takes the batch as
        input and returns the loss. This function can be patched to do more
        complex operations instead.

        Args:
            model: The current model.
            batch: The current minibatch of samples.
            state: The current training state.
        """

    def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
        """Gets the loss for the current batch.

        By default, we assume the model is a function that takes the batch as
        input and returns the loss. This function can be patched to do more
        complex operations instead.

        Args:
            model: The current model.
            batch: The current minibatch of samples.
            output: The output from the model.
            state: The current training state.

        Returns:
            The computed loss, as a tensor.
        """
        if not isinstance(output, Array):
            raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
        return output

    def get_output_and_loss(self, model: PyTree, batch: Batch, state: State) -> tuple[Array, Output]:
        output = self.get_output(model, batch, state)
        loss = self.compute_loss(model, batch, output, state)
        return loss, output

    @eqx.filter_jit
    def update(
        self,
        model: PyTree,
        optimizer: optax.GradientTransformation,
        opt_state: optax.OptState,
        batch: Batch,
        state: State,
    ) -> tuple[Array, PyTree, optax.OptState, Output]:
        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch, state)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state, output

    def get_size_of_batch(self, batch: Batch) -> int | None:
        """Gets the batch size for the current batch.

        Args:
            batch: The current minibatch of samples.

        Returns:
            The parsed batch size, or None if the batch size could not be
            determined.
        """
        if isinstance(batch, (np.ndarray, Array)):
            return batch.shape[0]
        if is_dataclass(batch):
            for v in batch.__dict__.values():
                if bsz := self.get_size_of_batch(v):
                    return bsz
        if isinstance(batch, Mapping):
            for v in batch.values():
                if bsz := self.get_size_of_batch(v):
                    return bsz
        if isinstance(batch, Sequence):
            for i in batch:
                if bsz := self.get_size_of_batch(i):
                    return bsz
        return None

    def set_training_over(self) -> None:
        self._training_over_flag = True

    def maybe_log_termination_time(self, remaining_percent: float, state: State) -> None:
        if self._last_printed_remaining_time + PRINT_FINISH_TIME_EVERY_N_SECONDS > state.elapsed_time_s:
            return
        self._last_printed_remaining_time = state.elapsed_time_s
        remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
        termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
        logger.info("Estimated finish time: %s", termination_time)

    def is_training_over(self, state: State) -> bool:
        if self._training_over_flag:
            return True
        remaining_percent = self.get_remaining_percent(state)
        if remaining_percent is None:
            return False
        self.logger.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
        self.maybe_log_termination_time(remaining_percent, state)
        return remaining_percent <= 0.0

    def get_step(self, state: State) -> int:
        match self._step_kind:
            case "step":
                return state.num_steps
            case "sample":
                return state.num_samples
            case "second":
                return int(state.elapsed_time_s)
            case _:
                raise ValueError(f"Invalid step kind {self._step_kind}")

    def get_remaining_percent(self, state: State) -> float | None:
        if self.config.max_steps is None:
            return None
        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps

    def log_state(self) -> None:
        logger.log(LOG_STATUS, self.task_path)
        logger.log(LOG_STATUS, self.task_name)
        self.logger.log_git_state(get_git_state(self))
        self.logger.log_training_code(get_training_code(self))
        self.logger.log_config(cast(DictConfig, self.config))

    def train_step(
        self,
        model: PyTree,
        optimizer: optax.GradientTransformation,
        opt_state: optax.OptState,
        batch: Batch,
        state: State,
    ) -> tuple[PyTree, optax.OptState, State]:
        state = state.with_phase("train")
        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch, state)
        self.logger.log_scalar("loss", loss, namespace="loss")
        self.log_step(model, batch, output, state)
        self.write_logs(state)
        return (
            model,
            opt_state,
            state.replace(
                {
                    "num_steps": state.num_steps + 1,
                    "num_samples": state.num_samples + (self.get_size_of_batch(batch) or 0),
                },
            ),
        )

    def val_step(self, model: PyTree, batch: Batch, state: State) -> tuple[PyTree, State]:
        state = state.with_phase("valid")
        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch, state)
        self.logger.log_scalar("loss", loss, namespace="loss")
        self.log_step(model, batch, output, state)
        self.write_logs(state)
        return model, state.replace(
            {
                "num_valid_steps": state.num_valid_steps + 1,
                "num_valid_samples": state.num_valid_samples + (self.get_size_of_batch(batch) or 0),
            },
        )

    def run(self) -> None:
        self.run_training_loop()

    def run_training_loop(self) -> None:
        """Runs the training loop.

        Args:
            model: The current model
            task: The current task
            optimizer: The current optimizer
            lr_scheduler: The current learning rate scheduler

        Raises:
            ValueError: If the task is not a supervised learning task
        """
        with contextlib.ExitStack() as ctx:
            self.set_loggers()

            if is_master():
                Thread(target=self.log_state, daemon=True).start()

            # Gets the datasets.
            with self.step_context("get_dataset"):
                train_ds = self.get_dataset("train")
                valid_ds = self.get_dataset("valid")

            # Gets the dataloaders.
            with self.step_context("get_dataloader"):
                train_dl = self.get_dataloader(train_ds, "train")
                valid_dl = self.get_dataloader(valid_ds, "valid")

            # Gets the prefetchers.
            with self.step_context("get_prefetcher"):
                train_pf = self.get_prefetcher(train_dl)
                valid_pf = self.get_prefetcher(valid_dl)

            ctx.enter_context(self)
            ctx.enter_context(train_pf)
            ctx.enter_context(valid_pf)

            model, optimizer, opt_state, state = self.load_initial_state()

            state = self.on_training_start(state)

            def on_exit() -> None:
                self.save_checkpoint(model, optimizer, opt_state, state)

            # Handle user-defined interrupts during the training loop.
            self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)

            try:
                while True:
                    while True:
                        if self.is_training_over(state):
                            raise TrainingFinishedError

                        if self.valid_step_timer.is_valid_step(state):
                            model, state = self.val_step(model, next(valid_pf), state)

                        with self.step_context("on_step_start"):
                            state = self.on_step_start(state)

                        model, opt_state, state = self.train_step(model, optimizer, opt_state, next(train_pf), state)

                        with self.step_context("on_step_end"):
                            state = self.on_step_end(state)

                        if self.should_checkpoint(state):
                            self.save_checkpoint(model, optimizer, opt_state, state)

            except TrainingFinishedError:
                if is_master():
                    show_info(
                        f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
                        important=True,
                    )
                self.save_checkpoint(model, optimizer, opt_state, state)

            except (KeyboardInterrupt, bdb.BdbQuit):
                if is_master():
                    show_info("Interrupted training", important=True)

            except BaseException:
                exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                sys.stdout.flush()
                self.save_checkpoint(model, optimizer, opt_state, state)

            finally:
                state = self.on_training_end(state)
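
For orientation (not part of the diff itself): the body of `TrainMixin.update` above is the standard Equinox/Optax gradient-update pattern. Below is a minimal standalone sketch of that same three-step sequence using only public Equinox/Optax APIs; the `LinearModel`, data shapes, and hyperparameters are illustrative and do not come from the package.

# Standalone sketch: the value_and_grad -> optimizer.update -> apply_updates
# sequence that TrainMixin.update wraps. All names below are illustrative.
import equinox as eqx
import jax
import jax.numpy as jnp
import optax


class LinearModel(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, key: jax.Array) -> None:
        self.weight = jax.random.normal(key, (3, 1))
        self.bias = jnp.zeros((1,))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight + self.bias


def loss_fn(model: LinearModel, x: jax.Array, y: jax.Array) -> jax.Array:
    # Mean squared error, standing in for get_output_and_loss.
    return jnp.mean((model(x) - y) ** 2)


@eqx.filter_jit
def update(model, optimizer, opt_state, x, y):
    # Differentiate with respect to the model's array leaves only.
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state


key = jax.random.PRNGKey(1337)
model = LinearModel(key)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
x, y = jax.random.normal(key, (8, 3)), jnp.ones((8, 1))
loss, model, opt_state = update(model, optimizer, opt_state, x, y)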
xax/task/script.py
ADDED
@@ -0,0 +1,53 @@
"""Composes various mixins into a single script interface."""

from dataclasses import dataclass
from typing import Generic, TypeVar

from xax.task.base import BaseConfig, BaseTask
from xax.task.mixins import (
    ArtifactsConfig,
    ArtifactsMixin,
    CPUStatsConfig,
    CPUStatsMixin,
    GPUStatsConfig,
    GPUStatsMixin,
    LoggerConfig,
    LoggerMixin,
    ProcessConfig,
    ProcessMixin,
    RunnableConfig,
    RunnableMixin,
    StepContextConfig,
    StepContextMixin,
)


@dataclass
class ScriptConfig(
    CPUStatsConfig,
    GPUStatsConfig,
    ProcessConfig,
    LoggerConfig,
    StepContextConfig,
    ArtifactsConfig,
    RunnableConfig,
    BaseConfig,
):
    pass


ConfigT = TypeVar("ConfigT", bound=ScriptConfig)


class Script(
    CPUStatsMixin[ConfigT],
    GPUStatsMixin[ConfigT],
    ProcessMixin[ConfigT],
    LoggerMixin[ConfigT],
    StepContextMixin[ConfigT],
    ArtifactsMixin[ConfigT],
    RunnableMixin[ConfigT],
    BaseTask[ConfigT],
    Generic[ConfigT],
):
    pass
xax/task/task.py
ADDED
@@ -0,0 +1,65 @@
"""Composes the base task with all the mixins into a single task interface."""

from dataclasses import dataclass
from typing import Generic, TypeVar

from xax.task.base import BaseConfig, BaseTask
from xax.task.mixins import (
    ArtifactsConfig,
    ArtifactsMixin,
    CheckpointingConfig,
    CheckpointingMixin,
    CPUStatsConfig,
    CPUStatsMixin,
    DataloadersConfig,
    DataloadersMixin,
    GPUStatsConfig,
    GPUStatsMixin,
    LoggerConfig,
    LoggerMixin,
    ProcessConfig,
    ProcessMixin,
    RunnableConfig,
    RunnableMixin,
    StepContextConfig,
    StepContextMixin,
    TrainConfig,
    TrainMixin,
)


@dataclass
class Config(
    TrainConfig,
    CheckpointingConfig,
    DataloadersConfig,
    CPUStatsConfig,
    GPUStatsConfig,
    ProcessConfig,
    LoggerConfig,
    StepContextConfig,
    ArtifactsConfig,
    RunnableConfig,
    BaseConfig,
):
    pass


ConfigT = TypeVar("ConfigT", bound=Config)


class Task(
    TrainMixin[ConfigT],
    CheckpointingMixin[ConfigT],
    DataloadersMixin[ConfigT],
    CPUStatsMixin[ConfigT],
    GPUStatsMixin[ConfigT],
    ProcessMixin[ConfigT],
    LoggerMixin[ConfigT],
    StepContextMixin[ConfigT],
    ArtifactsMixin[ConfigT],
    RunnableMixin[ConfigT],
    BaseTask[ConfigT],
    Generic[ConfigT],
):
    pass
xax/utils/__init__.py
ADDED
File without changes