xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- xax/__init__.py +122 -8
- xax/core/conf.py +9 -33
- xax/core/state.py +13 -23
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +17 -10
- xax/task/base.py +2 -6
- xax/task/logger.py +419 -412
- xax/task/loggers/callback.py +44 -0
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +16 -33
- xax/task/mixins/__init__.py +3 -1
- xax/task/mixins/artifacts.py +19 -9
- xax/task/mixins/checkpointing.py +221 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +26 -15
- xax/task/mixins/data_loader.py +27 -19
- xax/task/mixins/gpu_stats.py +22 -8
- xax/task/mixins/logger.py +5 -251
- xax/task/mixins/process.py +8 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +236 -145
- xax/task/script.py +1 -1
- xax/task/task.py +13 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +89 -21
- xax-0.0.6.dist-info/METADATA +50 -0
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.3.dist-info/METADATA +0 -39
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/task/mixins/train.py
CHANGED
@@ -1,9 +1,11 @@
 """Defines a mixin for running the training loop."""
 
+import bdb
 import contextlib
 import functools
 import itertools
 import logging
+import signal
 import sys
 import textwrap
 import time
@@ -11,20 +13,31 @@ import traceback
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
 from threading import Thread
-from typing import
+from typing import (
+    Any,
+    Generator,
+    Generic,
+    Iterator,
+    Literal,
+    Mapping,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+)
 
 import equinox as eqx
 import jax
-import jax.numpy as jnp
 import numpy as np
 import optax
-from jaxtyping import Array
+from jaxtyping import Array, PRNGKeyArray, PyTree
 from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
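The import change replaces the bare `Array` import and `jax.numpy` with jaxtyping's `Array`, `PRNGKeyArray`, and `PyTree` aliases, which are used throughout the new signatures below. A minimal sketch of how these annotations are typically used (toy functions, not from xax):

import jax
from jaxtyping import Array, PRNGKeyArray, PyTree


def init_params(key: PRNGKeyArray) -> PyTree:
    # PyTree annotates an arbitrary nested container of arrays.
    w_key, b_key = jax.random.split(key)
    return {"w": jax.random.normal(w_key, (4, 2)), "b": jax.random.normal(b_key, (2,))}


def forward(params: PyTree, x: Array) -> Array:
    return x @ params["w"] + params["b"]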
@@ -32,6 +45,8 @@ from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
 from xax.utils.experiments import (
     StateTimer,
     TrainingFinishedError,
+    diff_configs,
+    get_diff_string,
     get_git_state,
     get_training_code,
 )
@@ -40,9 +55,11 @@ from xax.utils.text import highlight_exception_message, show_info
 
 logger = logging.getLogger(__name__)
 
-
-
-
+# Batch = TypeVar("Batch")
+# Output = TypeVar("Output")
+
+Batch = Any
+Output = Any
 
 StepKind = Literal["step", "sample", "second"]
 
@@ -123,8 +140,10 @@ class ValidStepTimer:
         return False
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class TrainConfig(
+    CheckpointingConfig,
     DataloadersConfig,
     LoggerConfig,
     StepContextConfig,
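`TrainConfig` now carries `@jax.tree_util.register_dataclass`, which registers the dataclass as a JAX pytree so it can flow through `jit` and tree utilities. A minimal sketch of the registration API, using the explicit-field-list form from the JAX docs (toy class; xax uses the bare decorator form shown in the diff):

import functools
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@functools.partial(jax.tree_util.register_dataclass, data_fields=["x", "y"], meta_fields=["name"])
@dataclass
class Point:
    x: jax.Array
    y: jax.Array
    name: str  # static metadata, kept out of tracing


p = Point(x=jnp.array(1.0), y=jnp.array(2.0), name="demo")

# Registered dataclasses behave like any other pytree.
doubled = jax.tree_util.tree_map(lambda a: a * 2, p)
print(doubled.x, doubled.y, doubled.name)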
@@ -145,12 +164,13 @@ Config = TypeVar("Config", bound=TrainConfig)
|
|
145
164
|
|
146
165
|
|
147
166
|
class TrainMixin(
|
167
|
+
CheckpointingMixin[Config],
|
148
168
|
DataloadersMixin[Config],
|
149
169
|
LoggerMixin[Config],
|
150
170
|
StepContextMixin[Config],
|
151
171
|
ArtifactsMixin[Config],
|
152
172
|
RunnableMixin[Config],
|
153
|
-
Generic[Config
|
173
|
+
Generic[Config],
|
154
174
|
ABC,
|
155
175
|
):
|
156
176
|
valid_step_timer: ValidStepTimer
|
@@ -159,7 +179,6 @@ class TrainMixin(
     _training_over_flag: bool
     _last_printed_remaining_time: float
     _step_kind: StepKind
-    _prng_key: jnp.ndarray
 
     def __init__(self, config: Config) -> None:
         super().__init__(config)
@@ -183,22 +202,15 @@ class TrainMixin(
         # The kind of step that was specified in the config.
         self._step_kind = cast_step_kind(self.config.step_kind)
 
-
-
-
-    @property
-    def prng_key(self) -> jnp.ndarray:
-        return self._prng_key
+    def prng_key(self) -> PRNGKeyArray:
+        return jax.random.PRNGKey(self.config.random_seed)
 
     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
-
-
-                "elapsed_time_s": time.time() - state.start_time_s,
-            },
-        )
+        state.elapsed_time_s = time.time() - state.start_time_s
+        return state
 
-    def log_train_step(self, model:
+    def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
        """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
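`prng_key` is now derived directly from `config.random_seed` instead of being stored as mutable state, so the same seed always yields the same root key. A short reminder of the underlying JAX pattern (seed value is illustrative):

import jax

key = jax.random.PRNGKey(0)  # deterministic root key from a seed
key, model_key, data_key = jax.random.split(key, 3)  # split before each independent use

weights = jax.random.normal(model_key, (8, 8))
noise = jax.random.uniform(data_key, (8,))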
@@ -211,7 +223,7 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def log_valid_step(self, model:
+    def log_valid_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
@@ -224,15 +236,18 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def
-
-
-        # Log the state timers.
-        timer = self.state_timers[phase]
+    def log_state_timers(self, state: State) -> None:
+        timer = self.state_timers[state.phase]
         timer.step(state)
         for ns, d in timer.log_dict().items():
             for k, v in d.items():
-                self.log_scalar(k, v, namespace=ns)
+                self.logger.log_scalar(k, v, namespace=ns)
+
+    def log_step(self, model: PyTree, batch: Batch, output: Output, loss: Array, state: State) -> None:
+        phase = state.phase
+
+        self.logger.log_scalar("loss", loss, namespace="loss")
+        self.log_state_timers(state)
 
         # Delegate to the appropriate logging function based on the phase.
         match phase:
@@ -243,8 +258,10 @@ class TrainMixin(
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
+        self.write_logs(state)
+
     @abstractmethod
-    def get_model(self) ->
+    def get_model(self, key: PRNGKeyArray) -> PyTree:
         """Returns the Equinox model to train.
 
         Returns:
@@ -259,11 +276,37 @@ class TrainMixin(
             The optimizer to use to train the model.
         """
 
-    def get_initial_opt_state(self, model:
+    def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
-
-
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+        init_ckpt_path = self.get_init_ckpt_path()
+
+        if init_ckpt_path is not None:
+            logger.info("Loading checkpoint from %s", init_ckpt_path)
+            with self.step_context("load_checkpoint"):
+                model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+            return model, optimizer, opt_state, state
+
+        with self.step_context("get_model"):
+            model = self.get_model(key)
+
+        with self.step_context("get_optimizer"):
+            optimizer = self.get_optimizer()
+
+        with self.step_context("get_initial_opt_state"):
+            opt_state = self.get_initial_opt_state(model, optimizer)
+
+        return model, optimizer, opt_state, State.init_state()
+
+    @eqx.filter_jit
+    def get_output(self, model: PyTree, batch: Batch) -> Output:
         """Gets the output from the model.
 
         By default, we assume the model is a function that takes the batch as
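With this change, `get_model` receives an explicit PRNG key and `load_initial_state` either restores a checkpoint or builds the model, optimizer, and optimizer state fresh. A hypothetical task subclass under the new signatures might look like the sketch below; the class, hyperparameters, and batch layout are illustrative, and the xax base class is omitted so the sketch stays self-contained:

import equinox as eqx
import jax
import optax
from jaxtyping import Array, PRNGKeyArray, PyTree


class MyClassifierTask:  # stand-in for a subclass of the xax training task
    def get_model(self, key: PRNGKeyArray) -> PyTree:
        # The key now comes from the framework, keeping model init reproducible.
        return eqx.nn.MLP(in_size=32, out_size=10, width_size=64, depth=2, key=key)

    def get_optimizer(self) -> optax.GradientTransformation:
        return optax.adamw(learning_rate=3e-4)

    def get_output(self, model: PyTree, batch: tuple[Array, Array]) -> Array:
        x, _ = batch
        return jax.vmap(model)(x)

    def compute_loss(self, model: PyTree, batch: tuple[Array, Array], output: Array) -> Array:
        _, y = batch
        return optax.softmax_cross_entropy_with_integer_labels(output, y).mean()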
@@ -273,10 +316,11 @@ class TrainMixin(
         Args:
             model: The current model.
             batch: The current minibatch of samples.
-            state: The current training state.
         """
+        raise NotImplementedError("`get_output` must be implemented by the subclass")
 
-
+    @eqx.filter_jit
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output) -> Array:
         """Gets the loss for the current batch.
 
         By default, we assume the model is a function that takes the batch as
@@ -287,7 +331,6 @@ class TrainMixin(
             model: The current model.
             batch: The current minibatch of samples.
             output: The output from the model.
-            state: The current training state.
 
         Returns:
             The computed loss, as a tensor.
@@ -296,22 +339,22 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
-
-
-
+    @eqx.filter_jit
+    def get_output_and_loss(self, model: PyTree, batch: Batch) -> tuple[Array, Output]:
+        output = self.get_output(model, batch)
+        loss = self.compute_loss(model, batch, output)
         return loss, output
 
     @eqx.filter_jit
     def update(
         self,
-        model:
+        model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
-
-
-
-        updates, opt_state = optimizer.update(grads, opt_state)
+    ) -> tuple[Array, PyTree, optax.OptState, Output]:
+        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch)
+        updates, opt_state = optimizer.update(grads, opt_state, model)
         model = eqx.apply_updates(model, updates)
         return loss, model, opt_state, output
 
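The rewritten `update` uses `eqx.filter_value_and_grad(..., has_aux=True)` and now passes the params as the third argument to `optimizer.update`, which is the standard Equinox-plus-optax training step. A self-contained sketch of that pattern (toy model and data, independent of xax; the diff passes the model itself to `optimizer.update`, while the filtered form below is the variant shown in the Equinox docs):

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

key = jax.random.PRNGKey(0)
model_key, x_key = jax.random.split(key)

model = eqx.nn.MLP(in_size=4, out_size=1, width_size=16, depth=1, key=model_key)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

x = jax.random.normal(x_key, (32, 4))
y = jnp.sin(x.sum(axis=-1, keepdims=True))


def loss_and_output(model, x, y):
    pred = jax.vmap(model)(x)
    loss = jnp.mean((pred - y) ** 2)
    return loss, pred  # aux output, mirroring get_output_and_loss


@eqx.filter_jit
def step(model, opt_state, x, y):
    (loss, pred), grads = eqx.filter_value_and_grad(loss_and_output, has_aux=True)(model, x, y)
    # Passing the params lets transforms like weight decay see the current values.
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state


loss, model, opt_state = step(model, opt_state, x, y)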
@@ -350,7 +393,13 @@ class TrainMixin(
             self._last_printed_remaining_time = state.elapsed_time_s
             remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
             termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
-            logger.info("Estimated finish time: %s", termination_time)
+            # logger.info("Estimated finish time: %s", termination_time)
+            jax.debug.print("Estimated finish time: {}", termination_time)
+
+    def get_remaining_percent(self, state: State) -> float | None:
+        if self.config.max_steps is None:
+            return None
+        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
 
     def is_training_over(self, state: State) -> bool:
         if self._training_over_flag:
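The finish-time message switches from `logger.info` to `jax.debug.print`, which is safe to call from traced or jitted code and prints runtime values when the computation actually runs. A toy sketch of the difference:

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Works under jit; {} is filled with the runtime value at execution time.
    jax.debug.print("x has mean {}", jnp.mean(x))
    return x * 2


f(jnp.arange(4.0))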
@@ -358,7 +407,6 @@ class TrainMixin(
         remaining_percent = self.get_remaining_percent(state)
         if remaining_percent is None:
             return False
-        self.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
         self.maybe_log_termination_time(remaining_percent, state)
         return remaining_percent <= 0.0
 
@@ -373,59 +421,124 @@ class TrainMixin(
             case _:
                 raise ValueError(f"Invalid step kind {self._step_kind}")
 
-    def get_remaining_percent(self, state: State) -> float | None:
-        if self.config.max_steps is None:
-            return None
-        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
-
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
-        self.logger.
-        self.logger.
-        self.logger.
+        self.logger.log_file("git_state.txt", get_git_state(self))
+        self.logger.log_file("training_code.txt", get_training_code(self))
+        self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
 
+    @eqx.filter_jit
     def train_step(
         self,
-        model:
+        model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
+    ) -> tuple[PyTree, optax.OptState, Array, Output]:
+        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch)
+        return model, opt_state, loss, output
+
+    @eqx.filter_jit
+    def val_step(self, model: PyTree, batch: Batch) -> tuple[PyTree, Array, Output]:
+        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch)
+        return model, loss, output
+
+    def train_loop(
+        self,
+        model: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        train_pf: Iterator[Batch],
+        valid_pf: Iterator[Batch],
         state: State,
-    ) ->
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ) -> None:
+        while not self.is_training_over(state):
+            if self.valid_step_timer.is_valid_step(state):
+                valid_batch = next(valid_pf)
+                model, loss, output = self.val_step(model, valid_batch)
+
+                # Perform logging.
+                with self.step_context("write_logs"):
+                    state.phase = "valid"
+                    self.log_step(model, valid_batch, output, loss, state)
+                    state.num_valid_samples += 1
+
+            with self.step_context("on_step_start"):
+                state = self.on_step_start(state)
+
+            with self.step_context("update_state"):
+                train_batch = next(train_pf)
+                model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
+
+            # Perform logging.
+            with self.step_context("write_logs"):
+                state.phase = "train"
+                self.log_step(model, train_batch, output, loss, state)
+                state.num_steps += 1
+                state.num_samples += self.get_size_of_batch(train_batch) or 0
+
+            with self.step_context("on_step_end"):
+                state = self.on_step_end(state)
+
+            if self.should_checkpoint(state):
+                self.save_checkpoint(model, optimizer, opt_state, state)
+
+        # After finishing training, save the final checkpoint.
+        self.save_checkpoint(model, optimizer, opt_state, state)
+
+    @contextlib.contextmanager
+    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            yield train_iterator
+            return
+        except NotImplementedError:
+            pass
 
-
-
-
-        self.
-
-
-
-
-
-
-
-
+        with self.step_context("get_dataset"):
+            train_ds = self.get_dataset("train")
+
+        with self.step_context("get_dataloader"):
+            train_dl = self.get_dataloader(train_ds, "train")
+
+        with self.step_context("get_prefetcher"):
+            train_pf = self.get_prefetcher(train_dl)
+
+        try:
+            with train_pf as train_pf_ctx:
+                yield train_pf_ctx
+        finally:
+            logger.info("Closing train prefetcher")
+
+    @contextlib.contextmanager
+    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            yield valid_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        with self.step_context("get_dataset"):
+            valid_ds = self.get_dataset("valid")
+
+        with self.step_context("get_dataloader"):
+            valid_dl = self.get_dataloader(valid_ds, "valid")
+
+        with self.step_context("get_prefetcher"):
+            valid_pf = self.get_prefetcher(valid_dl)
+
+        try:
+            with valid_pf as valid_pf_ctx:
+                yield valid_pf_ctx
+        finally:
+            logger.info("Closing valid prefetcher")
 
     def run(self) -> None:
-        self.
+        self.run_training()
 
-    def
+    def run_training(self) -> None:
         """Runs the training loop.
 
         Args:
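The new `get_train_iterator`/`get_valid_iterator` helpers wrap data access in context managers that first try a direct `get_data_iterator` and fall back to the dataset → dataloader → prefetcher pipeline when it raises `NotImplementedError`. The general "try the override, else build the default" generator pattern looks like this (names are illustrative, not the xax API):

import contextlib
from typing import Generator, Iterator


class Pipeline:
    def get_data_iterator(self) -> Iterator[int]:
        raise NotImplementedError  # subclasses may provide a direct iterator

    def build_default_iterator(self) -> Iterator[int]:
        return iter(range(10))

    @contextlib.contextmanager
    def get_iterator(self) -> Generator[Iterator[int], None, None]:
        try:
            iterator = self.get_data_iterator()
        except NotImplementedError:
            iterator = self.build_default_iterator()
        try:
            yield iterator
        finally:
            print("closing iterator resources")


with Pipeline().get_iterator() as it:
    print(sum(it))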
@@ -437,74 +550,52 @@ class TrainMixin(
         Raises:
             ValueError: If the task is not a supervised learning task
         """
-        with
+        with self:
+            key = self.prng_key()
+
             self.set_loggers()
 
             if is_master():
                 Thread(target=self.log_state, daemon=True).start()
 
-
-
-                train_ds = self.get_dataset("train")
-                valid_ds = self.get_dataset("valid")
-
-            # Gets the dataloaders.
-            with self.step_context("get_dataloader"):
-                train_dl = self.get_dataloader(train_ds, "train")
-                valid_dl = self.get_dataloader(valid_ds, "valid")
-
-            # Gets the prefetchers.
-            with self.step_context("get_prefetcher"):
-                train_pf = self.get_prefetcher(train_dl)
-                valid_pf = self.get_prefetcher(valid_dl)
-
-            ctx.enter_context(self)
-            ctx.enter_context(train_pf)
-            ctx.enter_context(valid_pf)
-
-            # Gets the model.
-            with self.step_context("get_model"):
-                model = self.get_model()
-
-            # Gets the optimizer.
-            with self.step_context("get_optimizer"):
-                optimizer = self.get_optimizer()
-
-            # Gets the initial optimizer state.
-            with self.step_context("get_initial_opt_state"):
-                opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
-
-            state = State.init_state()
+            key, model_key = jax.random.split(key)
+            model, optimizer, opt_state, state = self.load_initial_state(model_key)
             state = self.on_training_start(state)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                state = self.on_step_end(state)
-
-            except TrainingFinishedError:
-                if is_master():
-                    show_info(
-                        f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
-                        important=True,
+            def on_exit() -> None:
+                self.save_checkpoint(model, optimizer, opt_state, state)
+
+            # Handle user-defined interrupts during the training loop.
+            self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
+
+            with self.get_train_iterator() as train_pf, self.get_valid_iterator() as valid_pf:
+                try:
+                    self.train_loop(
+                        model=model,
+                        optimizer=optimizer,
+                        opt_state=opt_state,
+                        train_pf=train_pf,
+                        valid_pf=valid_pf,
+                        state=state,
                     )
 
-
-
-
-
-
-
-
+                except TrainingFinishedError:
+                    if is_master():
+                        show_info(
+                            f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
+                            important=True,
+                        )
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                except (KeyboardInterrupt, bdb.BdbQuit):
+                    if is_master():
+                        show_info("Interrupted training", important=True)
+
+                except BaseException:
+                    exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), "  ")
+                    sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                    sys.stdout.flush()
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                finally:
+                    state = self.on_training_end(state)
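`run_training` now registers an `on_exit` callback for SIGUSR1/SIGTERM via the task's `add_signal_handler`, so a checkpoint is written when the job is preempted or terminated. The underlying stdlib mechanism, stripped of the xax plumbing, is roughly the following (illustrative sketch using POSIX-only signals, not the xax implementation):

import signal
import sys


def save_checkpoint() -> None:
    print("saving checkpoint before exiting")  # stand-in for real checkpoint I/O


def handle_termination(signum, frame) -> None:
    save_checkpoint()
    sys.exit(0)


# SIGTERM is what most schedulers send on shutdown; SIGUSR1 is a common
# cluster convention for "about to be requeued".
signal.signal(signal.SIGTERM, handle_termination)
signal.signal(signal.SIGUSR1, handle_termination)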
xax/task/script.py
CHANGED
xax/task/task.py
CHANGED
@@ -3,11 +3,16 @@
 from dataclasses import dataclass
 from typing import Generic, TypeVar
 
+import jax
+
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
     ArtifactsMixin,
-
+    CheckpointingConfig,
+    CheckpointingMixin,
+    CompileConfig,
+    CompileMixin,
     CPUStatsConfig,
     CPUStatsMixin,
     DataloadersConfig,
@@ -16,8 +21,6 @@ from xax.task.mixins import (
     GPUStatsMixin,
     LoggerConfig,
     LoggerMixin,
-    Model,
-    Output,
     ProcessConfig,
     ProcessMixin,
     RunnableConfig,
@@ -29,9 +32,12 @@ from xax.task.mixins import (
 )
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class Config(
     TrainConfig,
+    CheckpointingConfig,
+    CompileConfig,
     DataloadersConfig,
     CPUStatsConfig,
     GPUStatsConfig,
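`Config` is assembled by inheriting from each mixin's config dataclass, so one flat config object carries every mixin's fields. A toy illustration of that composition pattern (class names and fields are made up):

from dataclasses import dataclass


@dataclass
class CheckpointingConfigDemo:
    save_every_n_steps: int = 1000


@dataclass
class DataloaderConfigDemo:
    batch_size: int = 32


@dataclass
class ComposedConfigDemo(CheckpointingConfigDemo, DataloaderConfigDemo):
    max_steps: int | None = None


cfg = ComposedConfigDemo(batch_size=64)
print(cfg.batch_size, cfg.save_every_n_steps, cfg.max_steps)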
@@ -49,7 +55,9 @@ ConfigT = TypeVar("ConfigT", bound=Config)
|
|
49
55
|
|
50
56
|
|
51
57
|
class Task(
|
52
|
-
TrainMixin[ConfigT
|
58
|
+
TrainMixin[ConfigT],
|
59
|
+
CheckpointingMixin[ConfigT],
|
60
|
+
CompileMixin[ConfigT],
|
53
61
|
DataloadersMixin[ConfigT],
|
54
62
|
CPUStatsMixin[ConfigT],
|
55
63
|
GPUStatsMixin[ConfigT],
|
@@ -59,6 +67,6 @@ class Task(
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
-    Generic[ConfigT
+    Generic[ConfigT],
 ):
     pass
xax/utils/data/collate.py
CHANGED
@@ -167,9 +167,9 @@ def collate(
     # Collate dictionaries if they have the same keys.
     if isinstance(item, dict) and all(set(i.keys()) == set(item.keys()) for i in items):
         output_dict = {}
-
-        for
-            output_dict[
+        item_keys_set = set(item.keys())
+        for key_in_set in item_keys_set:
+            output_dict[key_in_set] = collate([i[key_in_set] for i in items], mode=mode, pad=pad)
         return output_dict
 
     # Collate lists and tuples if they have the same lengths.
@@ -186,9 +186,9 @@ def collate(
     # Handles dataclasses.
     if is_dataclass(item):
         output_dict = {}
-
-        for
-            output_dict[
+        item_keys_dict = item.__dict__.keys()
+        for key_in_dict in item_keys_dict:
+            output_dict[key_in_dict] = collate([getattr(i, key_in_dict) for i in items], mode=mode, pad=pad)
         return item.__class__(**output_dict)
 
     # By default, don't do anything.