xax 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/mixins/logger.py CHANGED
@@ -4,14 +4,12 @@ import os
 from dataclasses import dataclass
 from pathlib import Path
 from types import TracebackType
-from typing import Callable, Generic, Self, Sequence, TypeVar
-
-from jaxtyping import Array
+from typing import Generic, Self, TypeVar
 
 from xax.core.conf import Device as BaseDeviceConfig, field
 from xax.core.state import State
 from xax.task.base import BaseConfig, BaseTask
-from xax.task.logger import ChannelSelectMode, Logger, LoggerImpl, Number
+from xax.task.logger import Logger, LoggerImpl
 from xax.task.loggers.json import JsonLogger
 from xax.task.loggers.state import StateLogger
 from xax.task.loggers.stdout import StdoutLogger
@@ -59,252 +57,6 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
     def write_logs(self, state: State) -> None:
         self.logger.write(state)
 
-    def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
-        self.logger.log_scalar(key, value, namespace=namespace)
-
-    def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
-        self.logger.log_string(key, value, namespace=namespace)
-
-    def log_image(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_image(
-            key,
-            value,
-            namespace=namespace,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_labeled_image(
-        self,
-        key: str,
-        value: Callable[[], tuple[Array, str]] | tuple[Array, str],
-        *,
-        namespace: str | None = None,
-        max_line_length: int | None = None,
-        keep_resolution: bool = False,
-        centered: bool = True,
-    ) -> None:
-        self.logger.log_labeled_image(
-            key,
-            value,
-            namespace=namespace,
-            max_line_length=max_line_length,
-            keep_resolution=keep_resolution,
-            centered=centered,
-        )
-
-    def log_images(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        keep_resolution: bool = False,
-        max_images: int | None = None,
-        sep: int = 0,
-    ) -> None:
-        self.logger.log_images(
-            key,
-            value,
-            namespace=namespace,
-            keep_resolution=keep_resolution,
-            max_images=max_images,
-            sep=sep,
-        )
-
-    def log_labeled_images(
-        self,
-        key: str,
-        value: Callable[[], tuple[Array, Sequence[str]]] | tuple[Array, Sequence[str]],
-        *,
-        namespace: str | None = None,
-        max_line_length: int | None = None,
-        keep_resolution: bool = False,
-        max_images: int | None = None,
-        sep: int = 0,
-        centered: bool = True,
-    ) -> None:
-        self.logger.log_labeled_images(
-            key,
-            value,
-            namespace=namespace,
-            max_line_length=max_line_length,
-            keep_resolution=keep_resolution,
-            max_images=max_images,
-            sep=sep,
-            centered=centered,
-        )
-
-    def log_audio(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sample_rate: int = 44100,
-        log_spec: bool = True,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_audio(
-            key,
-            value,
-            namespace=namespace,
-            sample_rate=sample_rate,
-            log_spec=log_spec,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_audios(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sep_ms: float = 0.0,
-        max_audios: int | None = None,
-        sample_rate: int = 44100,
-        log_spec: bool = True,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        spec_sep: int = 0,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_audios(
-            key,
-            value,
-            namespace=namespace,
-            sep_ms=sep_ms,
-            max_audios=max_audios,
-            sample_rate=sample_rate,
-            log_spec=log_spec,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            spec_sep=spec_sep,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_spectrogram(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sample_rate: int = 44100,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_spectrogram(
-            key,
-            value,
-            namespace=namespace,
-            sample_rate=sample_rate,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_spectrograms(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        max_audios: int | None = None,
-        sample_rate: int = 44100,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        spec_sep: int = 0,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_spectrograms(
-            key,
-            value,
-            namespace=namespace,
-            max_audios=max_audios,
-            sample_rate=sample_rate,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            spec_sep=spec_sep,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_video(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        fps: int | None = None,
-        length: float | None = None,
-    ) -> None:
-        self.logger.log_video(
-            key,
-            value,
-            namespace=namespace,
-            fps=fps,
-            length=length,
-        )
-
-    def log_videos(
-        self,
-        key: str,
-        value: Callable[[], Array | list[Array]] | Array | list[Array],
-        *,
-        namespace: str | None = None,
-        max_videos: int | None = None,
-        sep: int = 0,
-        fps: int | None = None,
-        length: int | None = None,
-    ) -> None:
-        self.logger.log_videos(
-            key,
-            value,
-            namespace=namespace,
-            max_videos=max_videos,
-            sep=sep,
-            fps=fps,
-            length=length,
-        )
-
-    def log_histogram(self, key: str, value: Callable[[], Array] | Array, *, namespace: str | None = None) -> None:
-        self.logger.log_histogram(key, value, namespace=namespace)
-
-    def log_point_cloud(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        max_points: int = 1000,
-        colors: Callable[[], Array] | Array | None = None,
-    ) -> None:
-        self.logger.log_point_cloud(
-            key,
-            value,
-            namespace=namespace,
-            max_points=max_points,
-            colors=colors,
-        )
-
     def __enter__(self) -> Self:
         self.logger.__enter__()
         return self
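
Migration note (not part of the published diff): the pass-through helpers above are gone, so task code that used to call `self.log_scalar(...)` and friends now goes through the `Logger` instance that `LoggerMixin` still exposes, which is exactly the change the `train.py` hunks below make internally. A minimal sketch, with a hypothetical metric name:

```python
from xax.core.state import State


# Sketch of a logging override inside a task subclass (names are illustrative).
def log_train_step(self, model, batch, output, state: State) -> None:
    # xax 0.0.3 style, removed by this hunk:
    # self.log_scalar("accuracy", output["accuracy"], namespace="metrics")
    # xax 0.0.5 style:
    self.logger.log_scalar("accuracy", output["accuracy"], namespace="metrics")
```
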
@@ -38,6 +38,10 @@ class ProcessMixin(BaseTask[Config], Generic[Config]):
     def multiprocessing_context(self) -> BaseContext:
         return self._mp_ctx
 
+    @property
+    def multiprocessing_manager(self) -> SyncManager:
+        return self._mp_manager
+
     def on_training_end(self, state: State) -> State:
         state = super().on_training_end(state)
 
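
The new property exposes the mixin's `SyncManager`, which makes it easy to create process-shared objects alongside the existing `multiprocessing_context`. A hedged sketch of how task code might use it; it assumes `self._mp_manager` is initialized elsewhere in the mixin (not shown in this hunk), and the worker function is hypothetical:

```python
# Illustrative use of the new multiprocessing_manager property inside a task.
def my_worker(queue, results) -> None:  # hypothetical worker entry point
    results["done"] = queue.get()


def start_worker(self) -> None:
    queue = self.multiprocessing_manager.Queue()   # proxy queue shared across processes
    results = self.multiprocessing_manager.dict()  # proxy dict shared across processes
    proc = self.multiprocessing_context.Process(target=my_worker, args=(queue, results))
    proc.start()
```
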
xax/task/mixins/train.py CHANGED
@@ -1,9 +1,11 @@
 """Defines a mixin for running the training loop."""
 
+import bdb
 import contextlib
 import functools
 import itertools
 import logging
+import signal
 import sys
 import textwrap
 import time
@@ -11,20 +13,21 @@ import traceback
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
 from threading import Thread
-from typing import Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args
+from typing import Any, Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args
 
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
-from jaxtyping import Array
+from jaxtyping import Array, PyTree
 from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -32,6 +35,8 @@ from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
 from xax.utils.experiments import (
     StateTimer,
     TrainingFinishedError,
+    diff_configs,
+    get_diff_string,
     get_git_state,
     get_training_code,
 )
@@ -40,9 +45,11 @@ from xax.utils.text import highlight_exception_message, show_info
 
 logger = logging.getLogger(__name__)
 
-Model = TypeVar("Model", bound=eqx.Module)
-Batch = TypeVar("Batch")
-Output = TypeVar("Output")
+# Batch = TypeVar("Batch")
+# Output = TypeVar("Output")
+
+Batch = Any
+Output = Any
 
 StepKind = Literal["step", "sample", "second"]
 
@@ -125,6 +132,7 @@ class ValidStepTimer:
 
 @dataclass
 class TrainConfig(
+    CheckpointingConfig,
     DataloadersConfig,
     LoggerConfig,
     StepContextConfig,
@@ -145,12 +153,13 @@ Config = TypeVar("Config", bound=TrainConfig)
 
 
 class TrainMixin(
+    CheckpointingMixin[Config],
     DataloadersMixin[Config],
     LoggerMixin[Config],
     StepContextMixin[Config],
     ArtifactsMixin[Config],
     RunnableMixin[Config],
-    Generic[Config, Model, Batch, Output],
+    Generic[Config],
     ABC,
 ):
     valid_step_timer: ValidStepTimer
@@ -159,7 +168,6 @@ class TrainMixin(
     _training_over_flag: bool
     _last_printed_remaining_time: float
     _step_kind: StepKind
-    _prng_key: jnp.ndarray
 
     def __init__(self, config: Config) -> None:
         super().__init__(config)
@@ -183,12 +191,8 @@ class TrainMixin(
         # The kind of step that was specified in the config.
         self._step_kind = cast_step_kind(self.config.step_kind)
 
-        # Defines a PRNG key for the task.
-        self._prng_key = jax.random.PRNGKey(self.config.random_seed)
-
-    @property
     def prng_key(self) -> jnp.ndarray:
-        return self._prng_key
+        return jax.random.PRNGKey(self.config.random_seed)
 
     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
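
Note the behavioral change in this hunk: `prng_key` is no longer a property backed by a key stored in `__init__`, but a plain method that rebuilds the key from `config.random_seed`, so every call returns the same key. A hedged sketch (not something this diff adds) of one way task code might derive distinct keys per step:

```python
import jax


# `self` is a TrainMixin task; `state` is an xax.core.state.State with a num_steps counter.
def step_key(self, state) -> jax.Array:
    base_key = self.prng_key()  # a method call in 0.0.5, not a property access
    return jax.random.fold_in(base_key, state.num_steps)
```
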
@@ -198,7 +202,7 @@ class TrainMixin(
             },
         )
 
-    def log_train_step(self, model: Model, batch: Batch, output: Output, state: State) -> None:
+    def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
@@ -211,7 +215,7 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def log_valid_step(self, model: Model, batch: Batch, output: Output, state: State) -> None:
+    def log_valid_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
@@ -224,7 +228,7 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def log_step(self, model: Model, batch: Batch, output: Output, state: State) -> None:
+    def log_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         phase = state.phase
 
         # Log the state timers.
@@ -232,7 +236,7 @@ class TrainMixin(
         timer.step(state)
         for ns, d in timer.log_dict().items():
             for k, v in d.items():
-                self.log_scalar(k, v, namespace=ns)
+                self.logger.log_scalar(k, v, namespace=ns)
 
         # Delegate to the appropriate logging function based on the phase.
         match phase:
@@ -244,7 +248,7 @@ class TrainMixin(
                 raise KeyError(f"Unknown phase: {phase}")
 
     @abstractmethod
-    def get_model(self) -> Model:
+    def get_model(self) -> PyTree:
         """Returns the Equinox model to train.
 
         Returns:
@@ -259,11 +263,34 @@ class TrainMixin(
             The optimizer to use to train the model.
         """
 
-    def get_initial_opt_state(self, model: Model, optimizer: optax.GradientTransformation) -> optax.OptState:
+    def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
+    def load_initial_state(self) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+        init_ckpt_path = self.get_init_ckpt_path()
+
+        if init_ckpt_path is not None:
+            logger.info("Loading checkpoint from %s", init_ckpt_path)
+            with self.step_context("load_checkpoint"):
+                model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+            return model, optimizer, opt_state, state
+
+        with self.step_context("get_model"):
+            model = self.get_model()
+
+        with self.step_context("get_optimizer"):
+            optimizer = self.get_optimizer()
+
+        with self.step_context("get_initial_opt_state"):
+            opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
+
+        return model, optimizer, opt_state, State.init_state()
+
     @abstractmethod
-    def get_output(self, model: Model, batch: Batch, state: State) -> Output:
+    def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
         """Gets the output from the model.
 
         By default, we assume the model is a function that takes the batch as
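
For reading the new `load_initial_state`, the checkpointing calls it makes imply roughly the following surface on `CheckpointingMixin`. This is a sketch inferred from the hunk above, not the actual definitions in `xax/task/mixins/checkpointing.py`, so the exact signatures there may differ:

```python
from pathlib import Path
from typing import Protocol

import optax
from jaxtyping import PyTree
from omegaconf import DictConfig

from xax.core.state import State


class SupportsCheckpointing(Protocol):
    def get_init_ckpt_path(self) -> Path | None:
        """A checkpoint to resume from, or None to start training from scratch."""
        ...

    def load_checkpoint(
        self, path: Path
    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]:
        """Returns (model, optimizer, opt_state, state, config), as unpacked above."""
        ...
```
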
@@ -276,7 +303,7 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def compute_loss(self, model: Model, batch: Batch, output: Output, state: State) -> Array:
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
         """Gets the loss for the current batch.
 
         By default, we assume the model is a function that takes the batch as
@@ -296,7 +323,7 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
-    def get_output_and_loss(self, model: Model, batch: Batch, state: State) -> tuple[Array, Output]:
+    def get_output_and_loss(self, model: PyTree, batch: Batch, state: State) -> tuple[Array, Output]:
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
         return loss, output
@@ -304,12 +331,12 @@ class TrainMixin(
     @eqx.filter_jit
     def update(
         self,
-        model: Model,
+        model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, Model, optax.OptState, Output]:
+    ) -> tuple[Array, PyTree, optax.OptState, Output]:
         (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch, state)
         updates, opt_state = optimizer.update(grads, opt_state)
         model = eqx.apply_updates(model, updates)
@@ -358,7 +385,7 @@ class TrainMixin(
         remaining_percent = self.get_remaining_percent(state)
         if remaining_percent is None:
             return False
-        self.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
+        self.logger.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
         self.maybe_log_termination_time(remaining_percent, state)
         return remaining_percent <= 0.0
 
@@ -387,15 +414,15 @@ class TrainMixin(
 
     def train_step(
         self,
-        model: Model,
+        model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[Model, optax.OptState, State]:
+    ) -> tuple[PyTree, optax.OptState, State]:
         state = state.with_phase("train")
         loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch, state)
-        self.log_scalar("loss", loss, namespace="loss")
+        self.logger.log_scalar("loss", loss, namespace="loss")
         self.log_step(model, batch, output, state)
         self.write_logs(state)
         return (
@@ -409,10 +436,10 @@ class TrainMixin(
             ),
         )
 
-    def val_step(self, model: Model, batch: Batch, state: State) -> tuple[Model, State]:
+    def val_step(self, model: PyTree, batch: Batch, state: State) -> tuple[PyTree, State]:
         state = state.with_phase("valid")
         loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch, state)
-        self.log_scalar("loss", loss, namespace="loss")
+        self.logger.log_scalar("loss", loss, namespace="loss")
         self.log_step(model, batch, output, state)
         self.write_logs(state)
         return model, state.replace(
@@ -462,20 +489,15 @@ class TrainMixin(
             ctx.enter_context(train_pf)
             ctx.enter_context(valid_pf)
 
-            # Gets the model.
-            with self.step_context("get_model"):
-                model = self.get_model()
+            model, optimizer, opt_state, state = self.load_initial_state()
 
-            # Gets the optimizer.
-            with self.step_context("get_optimizer"):
-                optimizer = self.get_optimizer()
+            state = self.on_training_start(state)
 
-            # Gets the initial optimizer state.
-            with self.step_context("get_initial_opt_state"):
-                opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
+            def on_exit() -> None:
+                self.save_checkpoint(model, optimizer, opt_state, state)
 
-            state = State.init_state()
-            state = self.on_training_start(state)
+            # Handle user-defined interrupts during the training loop.
+            self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
 
             try:
                 while True:
@@ -494,17 +516,26 @@ class TrainMixin(
                     with self.step_context("on_step_end"):
                         state = self.on_step_end(state)
 
+                    if self.should_checkpoint(state):
+                        self.save_checkpoint(model, optimizer, opt_state, state)
+
             except TrainingFinishedError:
                 if is_master():
                     show_info(
                         f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
                         important=True,
                     )
+                self.save_checkpoint(model, optimizer, opt_state, state)
+
+            except (KeyboardInterrupt, bdb.BdbQuit):
+                if is_master():
+                    show_info("Interrupted training", important=True)
 
             except BaseException:
                 exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                 sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                 sys.stdout.flush()
+                self.save_checkpoint(model, optimizer, opt_state, state)
 
             finally:
                 state = self.on_training_end(state)
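
Taken together, the training loop now checkpoints when `should_checkpoint(state)` is true, when training finishes, and when an unhandled exception escapes the loop, while `KeyboardInterrupt`/`bdb.BdbQuit` exit quietly; the registered handler also saves on SIGUSR1 and SIGTERM. A hedged sketch (not part of the package) of triggering that handler from outside; whether the job keeps running afterwards depends on `add_signal_handler`, which this diff does not show:

```python
import os
import signal

training_pid = 12345  # hypothetical PID of the running training process
os.kill(training_pid, signal.SIGUSR1)  # invokes on_exit() -> save_checkpoint(...)
```
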
xax/task/task.py CHANGED
@@ -7,7 +7,8 @@ from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
     ArtifactsMixin,
-    Batch,
+    CheckpointingConfig,
+    CheckpointingMixin,
     CPUStatsConfig,
     CPUStatsMixin,
     DataloadersConfig,
@@ -16,8 +17,6 @@ from xax.task.mixins import (
     GPUStatsMixin,
     LoggerConfig,
     LoggerMixin,
-    Model,
-    Output,
     ProcessConfig,
     ProcessMixin,
     RunnableConfig,
@@ -32,6 +31,7 @@ from xax.task.mixins import (
 @dataclass
 class Config(
     TrainConfig,
+    CheckpointingConfig,
     DataloadersConfig,
     CPUStatsConfig,
     GPUStatsConfig,
@@ -49,7 +49,8 @@ ConfigT = TypeVar("ConfigT", bound=Config)
 
 
 class Task(
-    TrainMixin[ConfigT, Model, Batch, Output],
+    TrainMixin[ConfigT],
+    CheckpointingMixin[ConfigT],
     DataloadersMixin[ConfigT],
     CPUStatsMixin[ConfigT],
     GPUStatsMixin[ConfigT],
@@ -59,6 +60,6 @@ class Task(
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
-    Generic[ConfigT, Model, Batch, Output],
+    Generic[ConfigT],
 ):
     pass
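
The user-facing consequence: `Task` (like `TrainMixin`) is now parameterized only by its config, so the `Model`/`Batch`/`Output` arguments disappear from subclass declarations. A sketch with hypothetical names; the subclass still needs to implement the abstract methods such as `get_model` and `get_optimizer` before it can be instantiated:

```python
from dataclasses import dataclass

from xax.task.task import Config, Task


@dataclass
class MyConfig(Config):
    pass


# xax 0.0.3: class MyTask(Task[MyConfig, MyModel, MyBatch, MyOutput]): ...
class MyTask(Task[MyConfig]):
    pass
```
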
xax/utils/data/collate.py CHANGED
@@ -167,9 +167,9 @@ def collate(
     # Collate dictionaries if they have the same keys.
     if isinstance(item, dict) and all(set(i.keys()) == set(item.keys()) for i in items):
         output_dict = {}
-        item_keys = set(item.keys())
-        for key in item_keys:
-            output_dict[key] = collate([i[key] for i in items], mode=mode, pad=pad)
+        item_keys_set = set(item.keys())
+        for key_in_set in item_keys_set:
+            output_dict[key_in_set] = collate([i[key_in_set] for i in items], mode=mode, pad=pad)
         return output_dict
 
     # Collate lists and tuples if they have the same lengths.
@@ -186,9 +186,9 @@ def collate(
     # Handles dataclasses.
     if is_dataclass(item):
         output_dict = {}
-        item_keys = item.__dict__.keys()
-        for key in item_keys:
-            output_dict[key] = collate([getattr(i, key) for i in items], mode=mode, pad=pad)
+        item_keys_dict = item.__dict__.keys()
+        for key_in_dict in item_keys_dict:
+            output_dict[key_in_dict] = collate([getattr(i, key_in_dict) for i in items], mode=mode, pad=pad)
         return item.__class__(**output_dict)
 
     # By default, don't do anything.
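
The renames in both branches are behavior-preserving; one plausible motivation (an inference, not stated anywhere in the package) is that re-binding the same names to differently typed objects in successive branches is flagged by strict type checkers. A reduced illustration of the pattern being avoided, not xax code:

```python
from dataclasses import is_dataclass


def keys_of(item: object) -> list[str]:
    if isinstance(item, dict):
        item_keys = set(item.keys())      # inferred as set[Any]
        return sorted(item_keys)
    if is_dataclass(item):
        item_keys = item.__dict__.keys()  # KeysView[str]: strict mypy reports an incompatible redefinition
        return sorted(item_keys)
    return []
```
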