xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/mixins/train.py CHANGED
@@ -1,9 +1,11 @@
 """Defines a mixin for running the training loop."""

+import bdb
 import contextlib
 import functools
 import itertools
 import logging
+import signal
 import sys
 import textwrap
 import time
@@ -11,20 +13,31 @@ import traceback
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
 from threading import Thread
-from typing import Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args
+from typing import (
+    Any,
+    Generator,
+    Generic,
+    Iterator,
+    Literal,
+    Mapping,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+)

 import equinox as eqx
 import jax
-import jax.numpy as jnp
 import numpy as np
 import optax
-from jaxtyping import Array
+from jaxtyping import Array, PRNGKeyArray, PyTree
 from omegaconf import DictConfig

 from xax.core.conf import field
 from xax.core.state import Phase, State
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -32,6 +45,8 @@ from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
 from xax.utils.experiments import (
     StateTimer,
     TrainingFinishedError,
+    diff_configs,
+    get_diff_string,
     get_git_state,
     get_training_code,
 )
@@ -40,9 +55,11 @@ from xax.utils.text import highlight_exception_message, show_info

 logger = logging.getLogger(__name__)

-Model = TypeVar("Model", bound=eqx.Module)
-Batch = TypeVar("Batch")
-Output = TypeVar("Output")
+# Batch = TypeVar("Batch")
+# Output = TypeVar("Output")
+
+Batch = Any
+Output = Any

 StepKind = Literal["step", "sample", "second"]

@@ -123,8 +140,10 @@ class ValidStepTimer:
         return False


+@jax.tree_util.register_dataclass
 @dataclass
 class TrainConfig(
+    CheckpointingConfig,
     DataloadersConfig,
     LoggerConfig,
     StepContextConfig,
@@ -145,12 +164,13 @@ Config = TypeVar("Config", bound=TrainConfig)


 class TrainMixin(
+    CheckpointingMixin[Config],
     DataloadersMixin[Config],
     LoggerMixin[Config],
     StepContextMixin[Config],
     ArtifactsMixin[Config],
     RunnableMixin[Config],
-    Generic[Config, Model, Batch, Output],
+    Generic[Config],
     ABC,
 ):
     valid_step_timer: ValidStepTimer
@@ -159,7 +179,6 @@ class TrainMixin(
     _training_over_flag: bool
     _last_printed_remaining_time: float
     _step_kind: StepKind
-    _prng_key: jnp.ndarray

     def __init__(self, config: Config) -> None:
         super().__init__(config)
@@ -183,22 +202,15 @@ class TrainMixin(
         # The kind of step that was specified in the config.
         self._step_kind = cast_step_kind(self.config.step_kind)

-        # Defines a PRNG key for the task.
-        self._prng_key = jax.random.PRNGKey(self.config.random_seed)
-
-    @property
-    def prng_key(self) -> jnp.ndarray:
-        return self._prng_key
+    def prng_key(self) -> PRNGKeyArray:
+        return jax.random.PRNGKey(self.config.random_seed)

     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
-        return state.replace(
-            {
-                "elapsed_time_s": time.time() - state.start_time_s,
-            },
-        )
+        state.elapsed_time_s = time.time() - state.start_time_s
+        return state

-    def log_train_step(self, model: Model, batch: Batch, output: Output, state: State) -> None:
+    def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the training phase.

         This function is called after the model forward pass and before the
@@ -211,7 +223,7 @@ class TrainMixin(
             state: The current training state.
         """

-    def log_valid_step(self, model: Model, batch: Batch, output: Output, state: State) -> None:
+    def log_valid_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the validation phase.

         This function is called after the model forward pass. It is called in
@@ -224,15 +236,18 @@ class TrainMixin(
             state: The current training state.
         """

-    def log_step(self, model: Model, batch: Batch, output: Output, state: State) -> None:
-        phase = state.phase
-
-        # Log the state timers.
-        timer = self.state_timers[phase]
+    def log_state_timers(self, state: State) -> None:
+        timer = self.state_timers[state.phase]
         timer.step(state)
         for ns, d in timer.log_dict().items():
             for k, v in d.items():
-                self.log_scalar(k, v, namespace=ns)
+                self.logger.log_scalar(k, v, namespace=ns)
+
+    def log_step(self, model: PyTree, batch: Batch, output: Output, loss: Array, state: State) -> None:
+        phase = state.phase
+
+        self.logger.log_scalar("loss", loss, namespace="loss")
+        self.log_state_timers(state)

         # Delegate to the appropriate logging function based on the phase.
         match phase:
@@ -243,8 +258,10 @@ class TrainMixin(
             case _:
                 raise KeyError(f"Unknown phase: {phase}")

+        self.write_logs(state)
+
     @abstractmethod
-    def get_model(self) -> Model:
+    def get_model(self, key: PRNGKeyArray) -> PyTree:
         """Returns the Equinox model to train.

         Returns:
@@ -259,11 +276,37 @@ class TrainMixin(
             The optimizer to use to train the model.
         """

-    def get_initial_opt_state(self, model: Model, optimizer: optax.GradientTransformation) -> optax.OptState:
+    def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))

-    @abstractmethod
-    def get_output(self, model: Model, batch: Batch, state: State) -> Output:
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+        init_ckpt_path = self.get_init_ckpt_path()
+
+        if init_ckpt_path is not None:
+            logger.info("Loading checkpoint from %s", init_ckpt_path)
+            with self.step_context("load_checkpoint"):
+                model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, optimizer, opt_state, state
+
+        with self.step_context("get_model"):
+            model = self.get_model(key)
+
+        with self.step_context("get_optimizer"):
+            optimizer = self.get_optimizer()
+
+        with self.step_context("get_initial_opt_state"):
+            opt_state = self.get_initial_opt_state(model, optimizer)
+
+        return model, optimizer, opt_state, State.init_state()
+
+    @eqx.filter_jit
+    def get_output(self, model: PyTree, batch: Batch) -> Output:
         """Gets the output from the model.

         By default, we assume the model is a function that takes the batch as
@@ -273,10 +316,11 @@ class TrainMixin(
         Args:
             model: The current model.
             batch: The current minibatch of samples.
-            state: The current training state.
         """
+        raise NotImplementedError("`get_output` must be implemented by the subclass")

-    def compute_loss(self, model: Model, batch: Batch, output: Output, state: State) -> Array:
+    @eqx.filter_jit
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output) -> Array:
         """Gets the loss for the current batch.

         By default, we assume the model is a function that takes the batch as
@@ -287,7 +331,6 @@ class TrainMixin(
             model: The current model.
             batch: The current minibatch of samples.
             output: The output from the model.
-            state: The current training state.

         Returns:
             The computed loss, as a tensor.
@@ -296,22 +339,22 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output

-    def get_output_and_loss(self, model: Model, batch: Batch, state: State) -> tuple[Array, Output]:
-        output = self.get_output(model, batch, state)
-        loss = self.compute_loss(model, batch, output, state)
+    @eqx.filter_jit
+    def get_output_and_loss(self, model: PyTree, batch: Batch) -> tuple[Array, Output]:
+        output = self.get_output(model, batch)
+        loss = self.compute_loss(model, batch, output)
         return loss, output

     @eqx.filter_jit
     def update(
         self,
-        model: Model,
+        model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
-        state: State,
-    ) -> tuple[Array, Model, optax.OptState, Output]:
-        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch, state)
-        updates, opt_state = optimizer.update(grads, opt_state)
+    ) -> tuple[Array, PyTree, optax.OptState, Output]:
+        (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch)
+        updates, opt_state = optimizer.update(grads, opt_state, model)
         model = eqx.apply_updates(model, updates)
         return loss, model, opt_state, output
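Note on the `update` hunk above: `optimizer.update` now receives the model as its third argument. In optax, `GradientTransformation.update(updates, state, params=None)` accepts the current parameters, which transforms such as `adamw`'s weight decay require; transforms that do not need them simply ignore the argument. A standalone optax sketch with toy parameters (illustrative, not xax code):

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,))}
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-2)
opt_state = optimizer.init(params)

grads = {"w": jnp.full((3,), 0.5)}
# adamw's weight-decay term needs the current params, hence the third argument.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```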

@@ -350,7 +393,13 @@ class TrainMixin(
             self._last_printed_remaining_time = state.elapsed_time_s
             remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
             termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
-            logger.info("Estimated finish time: %s", termination_time)
+            # logger.info("Estimated finish time: %s", termination_time)
+            jax.debug.print("Estimated finish time: {}", termination_time)
+
+    def get_remaining_percent(self, state: State) -> float | None:
+        if self.config.max_steps is None:
+            return None
+        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps

     def is_training_over(self, state: State) -> bool:
         if self._training_over_flag:
@@ -358,7 +407,6 @@ class TrainMixin(
         remaining_percent = self.get_remaining_percent(state)
         if remaining_percent is None:
             return False
-        self.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
         self.maybe_log_termination_time(remaining_percent, state)
         return remaining_percent <= 0.0

@@ -373,59 +421,124 @@ class TrainMixin(
             case _:
                 raise ValueError(f"Invalid step kind {self._step_kind}")

-    def get_remaining_percent(self, state: State) -> float | None:
-        if self.config.max_steps is None:
-            return None
-        return (self.config.max_steps - self.get_step(state)) / self.config.max_steps
-
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
-        self.logger.log_git_state(get_git_state(self))
-        self.logger.log_training_code(get_training_code(self))
-        self.logger.log_config(cast(DictConfig, self.config))
+        self.logger.log_file("git_state.txt", get_git_state(self))
+        self.logger.log_file("training_code.txt", get_training_code(self))
+        self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))

+    @eqx.filter_jit
     def train_step(
         self,
-        model: Model,
+        model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
+    ) -> tuple[PyTree, optax.OptState, Array, Output]:
+        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch)
+        return model, opt_state, loss, output
+
+    @eqx.filter_jit
+    def val_step(self, model: PyTree, batch: Batch) -> tuple[PyTree, Array, Output]:
+        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch)
+        return model, loss, output
+
+    def train_loop(
+        self,
+        model: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        train_pf: Iterator[Batch],
+        valid_pf: Iterator[Batch],
         state: State,
-    ) -> tuple[Model, optax.OptState, State]:
-        state = state.with_phase("train")
-        loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch, state)
-        self.log_scalar("loss", loss, namespace="loss")
-        self.log_step(model, batch, output, state)
-        self.write_logs(state)
-        return (
-            model,
-            opt_state,
-            state.replace(
-                {
-                    "num_steps": state.num_steps + 1,
-                    "num_samples": state.num_samples + (self.get_size_of_batch(batch) or 0),
-                },
-            ),
-        )
+    ) -> None:
+        while not self.is_training_over(state):
+            if self.valid_step_timer.is_valid_step(state):
+                valid_batch = next(valid_pf)
+                model, loss, output = self.val_step(model, valid_batch)
+
+                # Perform logging.
+                with self.step_context("write_logs"):
+                    state.phase = "valid"
+                    self.log_step(model, valid_batch, output, loss, state)
+                    state.num_valid_samples += 1
+
+            with self.step_context("on_step_start"):
+                state = self.on_step_start(state)
+
+            with self.step_context("update_state"):
+                train_batch = next(train_pf)
+                model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
+
+            # Perform logging.
+            with self.step_context("write_logs"):
+                state.phase = "train"
+                self.log_step(model, train_batch, output, loss, state)
+                state.num_steps += 1
+                state.num_samples += self.get_size_of_batch(train_batch) or 0
+
+            with self.step_context("on_step_end"):
+                state = self.on_step_end(state)
+
+            if self.should_checkpoint(state):
+                self.save_checkpoint(model, optimizer, opt_state, state)
+
+        # After finishing training, save the final checkpoint.
+        self.save_checkpoint(model, optimizer, opt_state, state)
+
+    @contextlib.contextmanager
+    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            yield train_iterator
+            return
+        except NotImplementedError:
+            pass

-    def val_step(self, model: Model, batch: Batch, state: State) -> tuple[Model, State]:
-        state = state.with_phase("valid")
-        loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch, state)
-        self.log_scalar("loss", loss, namespace="loss")
-        self.log_step(model, batch, output, state)
-        self.write_logs(state)
-        return model, state.replace(
-            {
-                "num_valid_steps": state.num_valid_steps + 1,
-                "num_valid_samples": state.num_valid_samples + (self.get_size_of_batch(batch) or 0),
-            },
-        )
+        with self.step_context("get_dataset"):
+            train_ds = self.get_dataset("train")
+
+        with self.step_context("get_dataloader"):
+            train_dl = self.get_dataloader(train_ds, "train")
+
+        with self.step_context("get_prefetcher"):
+            train_pf = self.get_prefetcher(train_dl)
+
+        try:
+            with train_pf as train_pf_ctx:
+                yield train_pf_ctx
+        finally:
+            logger.info("Closing train prefetcher")
+
+    @contextlib.contextmanager
+    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+        try:
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            yield valid_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        with self.step_context("get_dataset"):
+            valid_ds = self.get_dataset("valid")
+
+        with self.step_context("get_dataloader"):
+            valid_dl = self.get_dataloader(valid_ds, "valid")
+
+        with self.step_context("get_prefetcher"):
+            valid_pf = self.get_prefetcher(valid_dl)
+
+        try:
+            with valid_pf as valid_pf_ctx:
+                yield valid_pf_ctx
+        finally:
+            logger.info("Closing valid prefetcher")

     def run(self) -> None:
-        self.run_training_loop()
+        self.run_training()

-    def run_training_loop(self) -> None:
+    def run_training(self) -> None:
         """Runs the training loop.

         Args:
@@ -437,74 +550,52 @@ class TrainMixin(
         Raises:
             ValueError: If the task is not a supervised learning task
         """
-        with contextlib.ExitStack() as ctx:
+        with self:
+            key = self.prng_key()
+
             self.set_loggers()

             if is_master():
                 Thread(target=self.log_state, daemon=True).start()

-            # Gets the datasets.
-            with self.step_context("get_dataset"):
-                train_ds = self.get_dataset("train")
-                valid_ds = self.get_dataset("valid")
-
-            # Gets the dataloaders.
-            with self.step_context("get_dataloader"):
-                train_dl = self.get_dataloader(train_ds, "train")
-                valid_dl = self.get_dataloader(valid_ds, "valid")
-
-            # Gets the prefetchers.
-            with self.step_context("get_prefetcher"):
-                train_pf = self.get_prefetcher(train_dl)
-                valid_pf = self.get_prefetcher(valid_dl)
-
-            ctx.enter_context(self)
-            ctx.enter_context(train_pf)
-            ctx.enter_context(valid_pf)
-
-            # Gets the model.
-            with self.step_context("get_model"):
-                model = self.get_model()
-
-            # Gets the optimizer.
-            with self.step_context("get_optimizer"):
-                optimizer = self.get_optimizer()
-
-            # Gets the initial optimizer state.
-            with self.step_context("get_initial_opt_state"):
-                opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
-
-            state = State.init_state()
+            key, model_key = jax.random.split(key)
+            model, optimizer, opt_state, state = self.load_initial_state(model_key)
             state = self.on_training_start(state)

-            try:
-                while True:
-                    while True:
-                        if self.is_training_over(state):
-                            raise TrainingFinishedError
-
-                        if self.valid_step_timer.is_valid_step(state):
-                            model, state = self.val_step(model, next(valid_pf), state)
-
-                        with self.step_context("on_step_start"):
-                            state = self.on_step_start(state)
-
-                        model, opt_state, state = self.train_step(model, optimizer, opt_state, next(train_pf), state)
-
-                        with self.step_context("on_step_end"):
-                            state = self.on_step_end(state)
-
-            except TrainingFinishedError:
-                if is_master():
-                    show_info(
-                        f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
-                        important=True,
+            def on_exit() -> None:
+                self.save_checkpoint(model, optimizer, opt_state, state)
+
+            # Handle user-defined interrupts during the training loop.
+            self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
+
+            with self.get_train_iterator() as train_pf, self.get_valid_iterator() as valid_pf:
+                try:
+                    self.train_loop(
+                        model=model,
+                        optimizer=optimizer,
+                        opt_state=opt_state,
+                        train_pf=train_pf,
+                        valid_pf=valid_pf,
+                        state=state,
                     )

-            except BaseException:
-                exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), "  ")
-                sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
-                sys.stdout.flush()
-
-            finally:
-                state = self.on_training_end(state)
+                except TrainingFinishedError:
+                    if is_master():
+                        show_info(
+                            f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
+                            important=True,
+                        )
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                except (KeyboardInterrupt, bdb.BdbQuit):
+                    if is_master():
+                        show_info("Interrupted training", important=True)
+
+                except BaseException:
+                    exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), "  ")
+                    sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                    sys.stdout.flush()
+                    self.save_checkpoint(model, optimizer, opt_state, state)
+
+                finally:
+                    state = self.on_training_end(state)
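Taken together, the reworked `TrainMixin` asks a subclass for `get_model(key)`, `get_optimizer()`, and `get_output(model, batch)`, and the default `compute_loss` passes the output through when it is already the scalar loss. A rough sketch of a task written against this 0.0.6 surface; the MLP, batch shape, and `hidden_size` field are illustrative, and dataset/dataloader wiring plus the launch entry point are omitted:

```python
from dataclasses import dataclass
from typing import Any

import equinox as eqx
import jax
import optax
from jaxtyping import Array, PRNGKeyArray, PyTree

from xax.task.task import Config, Task


@dataclass(kw_only=True)
class MyConfig(Config):
    hidden_size: int = 64  # illustrative extra field


class MyTask(Task[MyConfig]):
    def get_model(self, key: PRNGKeyArray) -> PyTree:
        # Any Equinox module (or other pytree of arrays) works here.
        return eqx.nn.MLP(in_size=32, out_size=1, width_size=self.config.hidden_size, depth=2, key=key)

    def get_optimizer(self) -> optax.GradientTransformation:
        return optax.adamw(1e-3)

    def get_output(self, model: PyTree, batch: Any) -> Array:
        x, y = batch
        pred = jax.vmap(model)(x)
        # Returning the scalar loss directly lets the default compute_loss pass it through.
        return ((pred.squeeze(-1) - y) ** 2).mean()
```

Because `get_output`, `compute_loss`, and `update` are wrapped in `eqx.filter_jit` in the mixin, the override does not need its own jit decoration.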
xax/task/script.py CHANGED
@@ -22,7 +22,7 @@ from xax.task.mixins import (
 )


-@dataclass
+@dataclass(kw_only=True)
 class ScriptConfig(
     CPUStatsConfig,
     GPUStatsConfig,
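Note on the `kw_only=True` change: `ScriptConfig` inherits fields from several config dataclasses, and once any inherited field has a default, a plain `@dataclass` subclass cannot add a required field after it. Keyword-only fields are exempt from that ordering rule. A minimal illustration with hypothetical field names:

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    log_interval: int = 100  # inherited field with a default


# With a plain @dataclass this would raise:
#   TypeError: non-default argument 'hidden_size' follows default argument
@dataclass(kw_only=True)
class DerivedConfig(BaseConfig):
    hidden_size: int  # required field added after a defaulted one


cfg = DerivedConfig(hidden_size=128)
```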
xax/task/task.py CHANGED
@@ -3,11 +3,16 @@
 from dataclasses import dataclass
 from typing import Generic, TypeVar

+import jax
+
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
     ArtifactsMixin,
-    Batch,
+    CheckpointingConfig,
+    CheckpointingMixin,
+    CompileConfig,
+    CompileMixin,
     CPUStatsConfig,
     CPUStatsMixin,
     DataloadersConfig,
@@ -16,8 +21,6 @@ from xax.task.mixins import (
     GPUStatsMixin,
     LoggerConfig,
     LoggerMixin,
-    Model,
-    Output,
     ProcessConfig,
     ProcessMixin,
     RunnableConfig,
@@ -29,9 +32,12 @@ from xax.task.mixins import (
 )


+@jax.tree_util.register_dataclass
 @dataclass
 class Config(
     TrainConfig,
+    CheckpointingConfig,
+    CompileConfig,
     DataloadersConfig,
     CPUStatsConfig,
     GPUStatsConfig,
@@ -49,7 +55,9 @@ ConfigT = TypeVar("ConfigT", bound=Config)


 class Task(
-    TrainMixin[ConfigT, Model, Batch, Output],
+    TrainMixin[ConfigT],
+    CheckpointingMixin[ConfigT],
+    CompileMixin[ConfigT],
     DataloadersMixin[ConfigT],
     CPUStatsMixin[ConfigT],
     GPUStatsMixin[ConfigT],
@@ -59,6 +67,6 @@ class Task(
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
-    Generic[ConfigT, Model, Batch, Output],
+    Generic[ConfigT],
 ):
     pass
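Note on `@jax.tree_util.register_dataclass` (applied to `TrainConfig` above and to `Config` here): registering the config dataclass as a pytree node lets JAX transformations flatten and rebuild it instead of rejecting it as an unknown type. A standalone sketch of the mechanism using the explicit registration form; the `Stats` class and its fields are hypothetical:

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Stats:
    mean: jax.Array  # array data: a pytree leaf
    count: int       # static metadata


# Register the dataclass as a pytree node: `mean` is traced, `count` is aux data.
jax.tree_util.register_dataclass(Stats, data_fields=["mean"], meta_fields=["count"])


@jax.jit
def normalize(x: jax.Array, stats: Stats) -> jax.Array:
    return (x - stats.mean) / stats.count


print(normalize(jnp.arange(4.0), Stats(mean=jnp.asarray(1.0), count=4)))
```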
xax/utils/data/collate.py CHANGED
@@ -167,9 +167,9 @@ def collate(
     # Collate dictionaries if they have the same keys.
     if isinstance(item, dict) and all(set(i.keys()) == set(item.keys()) for i in items):
         output_dict = {}
-        item_keys = set(item.keys())
-        for key in item_keys:
-            output_dict[key] = collate([i[key] for i in items], mode=mode, pad=pad)
+        item_keys_set = set(item.keys())
+        for key_in_set in item_keys_set:
+            output_dict[key_in_set] = collate([i[key_in_set] for i in items], mode=mode, pad=pad)
         return output_dict

     # Collate lists and tuples if they have the same lengths.
@@ -186,9 +186,9 @@ def collate(
     # Handles dataclasses.
     if is_dataclass(item):
         output_dict = {}
-        item_keys = item.__dict__.keys()
-        for key in item_keys:
-            output_dict[key] = collate([getattr(i, key) for i in items], mode=mode, pad=pad)
+        item_keys_dict = item.__dict__.keys()
+        for key_in_dict in item_keys_dict:
+            output_dict[key_in_dict] = collate([getattr(i, key_in_dict) for i in items], mode=mode, pad=pad)
         return item.__class__(**output_dict)

     # By default, don't do anything.
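Note on usage: the dictionary branch above collates a list of same-keyed dicts key by key, and the dataclass branch does the same through `__dict__` before rebuilding the dataclass. A rough sketch of calling it on dict samples; the defaults for `mode` and `pad` are not visible in this diff, so this sketch assumes they are usable as-is:

```python
import numpy as np

from xax.utils.data.collate import collate

samples = [
    {"image": np.zeros((3, 32, 32)), "label": 0},
    {"image": np.ones((3, 32, 32)), "label": 1},
]

# Same-keyed dicts are collated per key, so the result is a single dict whose
# values are the collated "image" and "label" entries.
batch = collate(samples)
```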