xax 0.0.1-py3-none-any.whl → 0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/mixins/train.py ADDED
@@ -0,0 +1,541 @@
+ """Defines a mixin for running the training loop."""
+
+ import bdb
+ import contextlib
+ import functools
+ import itertools
+ import logging
+ import signal
+ import sys
+ import textwrap
+ import time
+ import traceback
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, is_dataclass
+ from threading import Thread
+ from typing import Any, Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import optax
+ from jaxtyping import Array, PyTree
+ from omegaconf import DictConfig
+
+ from xax.core.conf import field
+ from xax.core.state import Phase, State
+ from xax.nn.parallel import is_master
+ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+ from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
+ from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
+ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
+ from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
+ from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
+ from xax.utils.experiments import (
+     StateTimer,
+     TrainingFinishedError,
+     diff_configs,
+     get_diff_string,
+     get_git_state,
+     get_training_code,
+ )
+ from xax.utils.logging import LOG_STATUS
+ from xax.utils.text import highlight_exception_message, show_info
+
+ logger = logging.getLogger(__name__)
+
+ # Batch = TypeVar("Batch")
+ # Output = TypeVar("Output")
+
+ Batch = Any
+ Output = Any
+
+ StepKind = Literal["step", "sample", "second"]
+
+ PRINT_FINISH_TIME_EVERY_N_SECONDS = 60 * 2
+
+
+ def cast_step_kind(s: str) -> StepKind:
+     assert s in get_args(StepKind), f"`step_kind` must be one of {get_args(StepKind)}, not {s}"
+     return cast(StepKind, s)
+
+
+ @functools.lru_cache(maxsize=None)
+ def batch_chunks_schedule(schedule: list[int] | None) -> list[int] | None:
+     if schedule is None:
+         return None
+     if any(s < 1 for s in schedule):
+         raise ValueError("Batch chunk schedule must be positive")
+     return list(itertools.accumulate([0] + schedule))
+
+
+ @functools.lru_cache(maxsize=None)
+ def batches_per_step_schedule(schedule: list[int] | None) -> list[int] | None:
+     if schedule is None:
+         return None
+     if any(s < 1 for s in schedule):
+         raise ValueError("Batch chunk schedule must be positive")
+     return list(itertools.accumulate([0] + schedule))
+
+
+ class ValidStepTimer:
+     def __init__(
+         self,
+         valid_every_n_steps: int | None = None,
+         valid_first_n_steps: int = 0,
+         valid_every_n_seconds: float | None = None,
+         valid_first_n_seconds: float | None = None,
+     ) -> None:
+         super().__init__()
+
+         self.valid_every_n_steps = valid_every_n_steps
+         self.valid_first_n_steps = valid_first_n_steps
+         self.valid_every_n_seconds = valid_every_n_seconds
+         self.valid_first_n_seconds = valid_first_n_seconds
+         self.first_valid_step_flag = True
+
+         self.last_valid_time: float | None = None
+         self.last_valid_step: int | None = None
+
+     def is_valid_step(self, state: State) -> bool:
+         if state.num_steps < self.valid_first_n_steps:
+             return True
+
+         if self.last_valid_time is None or self.last_valid_step is None:
+             self.last_valid_time = state.elapsed_time_s
+             self.last_valid_step = state.num_steps
+             return True
+
+         # Step-based validation.
+         valid_every_n_steps = self.valid_every_n_steps
+         if valid_every_n_steps is not None and state.num_steps > valid_every_n_steps + self.last_valid_step:
+             self.last_valid_step = state.num_steps
+             return True
+
+         # Time-based validation.
+         valid_every_n_seconds = self.valid_every_n_seconds
+         if valid_every_n_seconds is not None and state.elapsed_time_s - self.last_valid_time >= valid_every_n_seconds:
+             self.last_valid_time = state.elapsed_time_s
+             return True
+
+         # Time-based validation for first validation step.
+         if self.first_valid_step_flag:
+             valid_first_n_seconds = self.valid_first_n_seconds
+             if valid_first_n_seconds is not None and state.elapsed_time_s >= valid_first_n_seconds:
+                 self.last_valid_time = state.elapsed_time_s
+                 self.first_valid_step_flag = False
+                 return True
+
+         return False
+
+
+ @dataclass
+ class TrainConfig(
+     CheckpointingConfig,
+     DataloadersConfig,
+     LoggerConfig,
+     StepContextConfig,
+     ArtifactsConfig,
+     RunnableConfig,
+ ):
+     valid_every_n_steps: int | None = field(None, help="Number of training steps to run per validation step")
+     valid_first_n_steps: int = field(0, help="Treat the first N steps as validation steps")
+     valid_every_n_seconds: float | None = field(60.0 * 10.0, help="Run validation every N seconds")
+     valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
+     batch_dim: int = field(0, help="The batch dimension, for splitting batches into chunks")
+     max_steps: int | None = field(None, help="Maximum number of steps to run")
+     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
+     random_seed: int = field(1337, help="Random seed for the task")
+
+
+ Config = TypeVar("Config", bound=TrainConfig)
+
+
+ class TrainMixin(
+     CheckpointingMixin[Config],
+     DataloadersMixin[Config],
+     LoggerMixin[Config],
+     StepContextMixin[Config],
+     ArtifactsMixin[Config],
+     RunnableMixin[Config],
+     Generic[Config],
+     ABC,
+ ):
+     valid_step_timer: ValidStepTimer
+     state_timers: dict[Phase, StateTimer]
+
+     _training_over_flag: bool
+     _last_printed_remaining_time: float
+     _step_kind: StepKind
+
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         # Timer for validation steps.
+         self.valid_step_timer = ValidStepTimer(
+             valid_every_n_steps=config.valid_every_n_steps,
+             valid_first_n_steps=config.valid_first_n_steps,
+             valid_every_n_seconds=config.valid_every_n_seconds,
+             valid_first_n_seconds=config.valid_first_n_seconds,
+         )
+
+         # Timers for iterations.
+         self.state_timers = {phase: StateTimer() for phase in get_args(Phase)}
+
+         # This flag can be toggled to end training from anywhere in the task.
+         self._training_over_flag = False
+
+         self._last_printed_remaining_time = 0.0
+
+         # The kind of step that was specified in the config.
+         self._step_kind = cast_step_kind(self.config.step_kind)
+
+     def prng_key(self) -> jnp.ndarray:
+         return jax.random.PRNGKey(self.config.random_seed)
+
+     def on_step_end(self, state: State) -> State:
+         state = super().on_step_end(state)
+         return state.replace(
+             {
+                 "elapsed_time_s": time.time() - state.start_time_s,
+             },
+         )
+
+     def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
+         """Override this function to do logging during the training phase.
+
+         This function is called after the model forward pass and before the
+         backward pass. It is called in the training phase.
+
+         Args:
+             model: The current model.
+             batch: The batch from the dataloader.
+             output: The model output.
+             state: The current training state.
+         """
+
+     def log_valid_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
+         """Override this function to do logging during the validation phase.
+
+         This function is called after the model forward pass. It is called in
+         the validation phase.
+
+         Args:
+             model: The current model.
+             batch: The batch from the dataloader.
+             output: The model output.
+             state: The current training state.
+         """
+
+     def log_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
+         phase = state.phase
+
+         # Log the state timers.
+         timer = self.state_timers[phase]
+         timer.step(state)
+         for ns, d in timer.log_dict().items():
+             for k, v in d.items():
+                 self.logger.log_scalar(k, v, namespace=ns)
+
+         # Delegate to the appropriate logging function based on the phase.
+         match phase:
+             case "train":
+                 self.log_train_step(model, batch, output, state)
+             case "valid":
+                 self.log_valid_step(model, batch, output, state)
+             case _:
+                 raise KeyError(f"Unknown phase: {phase}")
+
+     @abstractmethod
+     def get_model(self) -> PyTree:
+         """Returns the Equinox model to train.
+
+         Returns:
+             The model to train.
+         """
+
+     @abstractmethod
+     def get_optimizer(self) -> optax.GradientTransformation:
+         """Gets the optimizer for the model.
+
+         Returns:
+             The optimizer to use to train the model.
+         """
+
+     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
+         return optimizer.init(eqx.filter(model, eqx.is_array))
+
+     def load_initial_state(self) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+         init_ckpt_path = self.get_init_ckpt_path()
+
+         if init_ckpt_path is not None:
+             logger.info("Loading checkpoint from %s", init_ckpt_path)
+             with self.step_context("load_checkpoint"):
+                 model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
+                 config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                 if config_diff:
+                     logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                 return model, optimizer, opt_state, state
+
+         with self.step_context("get_model"):
+             model = self.get_model()
+
+         with self.step_context("get_optimizer"):
+             optimizer = self.get_optimizer()
+
+         with self.step_context("get_initial_opt_state"):
+             opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
+
+         return model, optimizer, opt_state, State.init_state()
+
+     @abstractmethod
+     def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
+         """Gets the output from the model.
+
+         By default, we assume the model is a function that takes the batch as
+         input and returns the loss. This function can be patched to do more
+         complex operations instead.
+
+         Args:
+             model: The current model.
+             batch: The current minibatch of samples.
+             state: The current training state.
+         """
+
+     def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
+         """Gets the loss for the current batch.
+
+         By default, we assume the model is a function that takes the batch as
+         input and returns the loss. This function can be patched to do more
+         complex operations instead.
+
+         Args:
+             model: The current model.
+             batch: The current minibatch of samples.
+             output: The output from the model.
+             state: The current training state.
+
+         Returns:
+             The computed loss, as a tensor.
+         """
+         if not isinstance(output, Array):
+             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
+         return output
+
+     def get_output_and_loss(self, model: PyTree, batch: Batch, state: State) -> tuple[Array, Output]:
+         output = self.get_output(model, batch, state)
+         loss = self.compute_loss(model, batch, output, state)
+         return loss, output
+
+     @eqx.filter_jit
+     def update(
+         self,
+         model: PyTree,
+         optimizer: optax.GradientTransformation,
+         opt_state: optax.OptState,
+         batch: Batch,
+         state: State,
+     ) -> tuple[Array, PyTree, optax.OptState, Output]:
+         (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch, state)
+         updates, opt_state = optimizer.update(grads, opt_state)
+         model = eqx.apply_updates(model, updates)
+         return loss, model, opt_state, output
+
+     def get_size_of_batch(self, batch: Batch) -> int | None:
+         """Gets the batch size for the current batch.
+
+         Args:
+             batch: The current minibatch of samples.
+
+         Returns:
+             The parsed batch size, or None if the batch size could not be
+             determined.
+         """
+         if isinstance(batch, (np.ndarray, Array)):
+             return batch.shape[0]
+         if is_dataclass(batch):
+             for v in batch.__dict__.values():
+                 if bsz := self.get_size_of_batch(v):
+                     return bsz
+         if isinstance(batch, Mapping):
+             for v in batch.values():
+                 if bsz := self.get_size_of_batch(v):
+                     return bsz
+         if isinstance(batch, Sequence):
+             for i in batch:
+                 if bsz := self.get_size_of_batch(i):
+                     return bsz
+         return None
+
+     def set_training_over(self) -> None:
+         self._training_over_flag = True
+
+     def maybe_log_termination_time(self, remaining_percent: float, state: State) -> None:
+         if self._last_printed_remaining_time + PRINT_FINISH_TIME_EVERY_N_SECONDS > state.elapsed_time_s:
+             return
+         self._last_printed_remaining_time = state.elapsed_time_s
+         remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
+         termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
+         logger.info("Estimated finish time: %s", termination_time)
+
+     def is_training_over(self, state: State) -> bool:
+         if self._training_over_flag:
+             return True
+         remaining_percent = self.get_remaining_percent(state)
+         if remaining_percent is None:
+             return False
+         self.logger.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
+         self.maybe_log_termination_time(remaining_percent, state)
+         return remaining_percent <= 0.0
+
+     def get_step(self, state: State) -> int:
+         match self._step_kind:
+             case "step":
+                 return state.num_steps
+             case "sample":
+                 return state.num_samples
+             case "second":
+                 return int(state.elapsed_time_s)
+             case _:
+                 raise ValueError(f"Invalid step kind {self._step_kind}")
+
+     def get_remaining_percent(self, state: State) -> float | None:
+         if self.config.max_steps is None:
+             return None
+         return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
+
+     def log_state(self) -> None:
+         logger.log(LOG_STATUS, self.task_path)
+         logger.log(LOG_STATUS, self.task_name)
+         self.logger.log_git_state(get_git_state(self))
+         self.logger.log_training_code(get_training_code(self))
+         self.logger.log_config(cast(DictConfig, self.config))
+
+     def train_step(
+         self,
+         model: PyTree,
+         optimizer: optax.GradientTransformation,
+         opt_state: optax.OptState,
+         batch: Batch,
+         state: State,
+     ) -> tuple[PyTree, optax.OptState, State]:
+         state = state.with_phase("train")
+         loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch, state)
+         self.logger.log_scalar("loss", loss, namespace="loss")
+         self.log_step(model, batch, output, state)
+         self.write_logs(state)
+         return (
+             model,
+             opt_state,
+             state.replace(
+                 {
+                     "num_steps": state.num_steps + 1,
+                     "num_samples": state.num_samples + (self.get_size_of_batch(batch) or 0),
+                 },
+             ),
+         )
+
+     def val_step(self, model: PyTree, batch: Batch, state: State) -> tuple[PyTree, State]:
+         state = state.with_phase("valid")
+         loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch, state)
+         self.logger.log_scalar("loss", loss, namespace="loss")
+         self.log_step(model, batch, output, state)
+         self.write_logs(state)
+         return model, state.replace(
+             {
+                 "num_valid_steps": state.num_valid_steps + 1,
+                 "num_valid_samples": state.num_valid_samples + (self.get_size_of_batch(batch) or 0),
+             },
+         )
+
+     def run(self) -> None:
+         self.run_training_loop()
+
+     def run_training_loop(self) -> None:
+         """Runs the training loop.
+
+         Args:
+             model: The current model
+             task: The current task
+             optimizer: The current optimizer
+             lr_scheduler: The current learning rate scheduler
+
+         Raises:
+             ValueError: If the task is not a supervised learning task
+         """
+         with contextlib.ExitStack() as ctx:
+             self.set_loggers()
+
+             if is_master():
+                 Thread(target=self.log_state, daemon=True).start()
+
+             # Gets the datasets.
+             with self.step_context("get_dataset"):
+                 train_ds = self.get_dataset("train")
+                 valid_ds = self.get_dataset("valid")
+
+             # Gets the dataloaders.
+             with self.step_context("get_dataloader"):
+                 train_dl = self.get_dataloader(train_ds, "train")
+                 valid_dl = self.get_dataloader(valid_ds, "valid")
+
+             # Gets the prefetchers.
+             with self.step_context("get_prefetcher"):
+                 train_pf = self.get_prefetcher(train_dl)
+                 valid_pf = self.get_prefetcher(valid_dl)
+
+             ctx.enter_context(self)
+             ctx.enter_context(train_pf)
+             ctx.enter_context(valid_pf)
+
+             model, optimizer, opt_state, state = self.load_initial_state()
+
+             state = self.on_training_start(state)
+
+             def on_exit() -> None:
+                 self.save_checkpoint(model, optimizer, opt_state, state)
+
+             # Handle user-defined interrupts during the training loop.
+             self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
+
+             try:
+                 while True:
+                     while True:
+                         if self.is_training_over(state):
+                             raise TrainingFinishedError
+
+                         if self.valid_step_timer.is_valid_step(state):
+                             model, state = self.val_step(model, next(valid_pf), state)
+
+                         with self.step_context("on_step_start"):
+                             state = self.on_step_start(state)
+
+                         model, opt_state, state = self.train_step(model, optimizer, opt_state, next(train_pf), state)
+
+                         with self.step_context("on_step_end"):
+                             state = self.on_step_end(state)
+
+                         if self.should_checkpoint(state):
+                             self.save_checkpoint(model, optimizer, opt_state, state)
+
+             except TrainingFinishedError:
+                 if is_master():
+                     show_info(
+                         f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
+                         important=True,
+                     )
+                 self.save_checkpoint(model, optimizer, opt_state, state)
+
+             except (KeyboardInterrupt, bdb.BdbQuit):
+                 if is_master():
+                     show_info("Interrupted training", important=True)
+
+             except BaseException:
+                 exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
+                 sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                 sys.stdout.flush()
+                 self.save_checkpoint(model, optimizer, opt_state, state)
+
+             finally:
+                 state = self.on_training_end(state)
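
Note: the update method above is the standard Equinox/Optax training step: take a filtered value-and-grad of the loss, transform the gradients with the optimizer, and apply the updates to the model. The standalone sketch below (an editorial illustration, not code from the package) shows the same pattern; the MLP model, loss function, and toy data are assumed for the example.

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

# Assumed toy model and optimizer; mirrors get_model / get_optimizer / get_initial_opt_state.
model = eqx.nn.MLP(in_size=3, out_size=1, width_size=8, depth=2, key=jax.random.PRNGKey(0))
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))


def loss_fn(model, x, y):
    # Hypothetical regression loss standing in for get_output_and_loss.
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)


@eqx.filter_jit
def update(model, opt_state, x, y):
    # Same shape as TrainMixin.update: value-and-grad, optimizer update, apply updates.
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state


x, y = jnp.ones((16, 3)), jnp.zeros((16, 1))
loss, model, opt_state = update(model, opt_state, x, y)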
xax/task/script.py ADDED
@@ -0,0 +1,53 @@
+ """Composes various mixins into a single script interface."""
+
+ from dataclasses import dataclass
+ from typing import Generic, TypeVar
+
+ from xax.task.base import BaseConfig, BaseTask
+ from xax.task.mixins import (
+     ArtifactsConfig,
+     ArtifactsMixin,
+     CPUStatsConfig,
+     CPUStatsMixin,
+     GPUStatsConfig,
+     GPUStatsMixin,
+     LoggerConfig,
+     LoggerMixin,
+     ProcessConfig,
+     ProcessMixin,
+     RunnableConfig,
+     RunnableMixin,
+     StepContextConfig,
+     StepContextMixin,
+ )
+
+
+ @dataclass
+ class ScriptConfig(
+     CPUStatsConfig,
+     GPUStatsConfig,
+     ProcessConfig,
+     LoggerConfig,
+     StepContextConfig,
+     ArtifactsConfig,
+     RunnableConfig,
+     BaseConfig,
+ ):
+     pass
+
+
+ ConfigT = TypeVar("ConfigT", bound=ScriptConfig)
+
+
+ class Script(
+     CPUStatsMixin[ConfigT],
+     GPUStatsMixin[ConfigT],
+     ProcessMixin[ConfigT],
+     LoggerMixin[ConfigT],
+     StepContextMixin[ConfigT],
+     ArtifactsMixin[ConfigT],
+     RunnableMixin[ConfigT],
+     BaseTask[ConfigT],
+     Generic[ConfigT],
+ ):
+     pass
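
Note: ScriptConfig and Script illustrate the config/mixin pairing used throughout the package: each mixin is generic over a config type, and the bases of the composed class mirror the bases of the composed config, so cooperative super().__init__ calls resolve along the same MRO. The reduced, self-contained sketch below is an editorial illustration of that pattern with made-up fields and a made-up base class, not the package's actual classes.

from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass
class LoggerConfig:
    log_interval: int = 10  # hypothetical field


@dataclass
class ProcessConfig:
    num_workers: int = 4  # hypothetical field


@dataclass
class ScriptConfig(ProcessConfig, LoggerConfig):
    pass


ConfigT = TypeVar("ConfigT", bound=ScriptConfig)


class BaseScript(Generic[ConfigT]):
    # Plays the role BaseTask plays in the package: end of the cooperative chain.
    def __init__(self, config: ConfigT) -> None:
        super().__init__()
        self.config = config


class LoggerMixin(BaseScript[ConfigT], Generic[ConfigT]):
    def __init__(self, config: ConfigT) -> None:
        super().__init__(config)
        self.log_interval = config.log_interval


class ProcessMixin(BaseScript[ConfigT], Generic[ConfigT]):
    def __init__(self, config: ConfigT) -> None:
        super().__init__(config)
        self.num_workers = config.num_workers


class Script(ProcessMixin[ConfigT], LoggerMixin[ConfigT], Generic[ConfigT]):
    # Mixin order matches the config's base order, so every __init__ sees its fields.
    pass


script = Script(ScriptConfig())
print(script.num_workers, script.log_interval)  # 4 10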
xax/task/task.py ADDED
@@ -0,0 +1,65 @@
+ """Composes the base task with all the mixins into a single task interface."""
+
+ from dataclasses import dataclass
+ from typing import Generic, TypeVar
+
+ from xax.task.base import BaseConfig, BaseTask
+ from xax.task.mixins import (
+     ArtifactsConfig,
+     ArtifactsMixin,
+     CheckpointingConfig,
+     CheckpointingMixin,
+     CPUStatsConfig,
+     CPUStatsMixin,
+     DataloadersConfig,
+     DataloadersMixin,
+     GPUStatsConfig,
+     GPUStatsMixin,
+     LoggerConfig,
+     LoggerMixin,
+     ProcessConfig,
+     ProcessMixin,
+     RunnableConfig,
+     RunnableMixin,
+     StepContextConfig,
+     StepContextMixin,
+     TrainConfig,
+     TrainMixin,
+ )
+
+
+ @dataclass
+ class Config(
+     TrainConfig,
+     CheckpointingConfig,
+     DataloadersConfig,
+     CPUStatsConfig,
+     GPUStatsConfig,
+     ProcessConfig,
+     LoggerConfig,
+     StepContextConfig,
+     ArtifactsConfig,
+     RunnableConfig,
+     BaseConfig,
+ ):
+     pass
+
+
+ ConfigT = TypeVar("ConfigT", bound=Config)
+
+
+ class Task(
+     TrainMixin[ConfigT],
+     CheckpointingMixin[ConfigT],
+     DataloadersMixin[ConfigT],
+     CPUStatsMixin[ConfigT],
+     GPUStatsMixin[ConfigT],
+     ProcessMixin[ConfigT],
+     LoggerMixin[ConfigT],
+     StepContextMixin[ConfigT],
+     ArtifactsMixin[ConfigT],
+     RunnableMixin[ConfigT],
+     BaseTask[ConfigT],
+     Generic[ConfigT],
+ ):
+     pass
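
Note: a concrete task would subclass the composed Task and implement the hooks that TrainMixin leaves abstract (get_model, get_optimizer, get_output, and optionally compute_loss). The sketch below is an editorial illustration under that assumption; the dataset/dataloader hooks contributed by DataloadersMixin are not shown in this diff and are omitted here, and the MNIST-style shapes and the hidden_size field are hypothetical.

from dataclasses import dataclass

import equinox as eqx
import jax
import optax

from xax.task.task import Config, Task


@dataclass
class MyConfig(Config):
    hidden_size: int = 64  # hypothetical extra field


class MyTask(Task[MyConfig]):
    def get_model(self) -> eqx.nn.MLP:
        # Any Equinox PyTree works; an MLP is used here for illustration.
        return eqx.nn.MLP(
            in_size=784,
            out_size=10,
            width_size=self.config.hidden_size,
            depth=2,
            key=self.prng_key(),
        )

    def get_optimizer(self) -> optax.GradientTransformation:
        return optax.adam(1e-3)

    def get_output(self, model, batch, state):
        # Returning the loss directly lets the default compute_loss pass it through.
        x, y = batch
        logits = jax.vmap(model)(x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()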
xax/utils/__init__.py ADDED
File without changes