xax 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in a supported public registry. It is provided for informational purposes only.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
  python -m scripts.update_api --inplace
  """

- __version__ = "0.3.13"
+ __version__ = "0.3.15"

  # This list shouldn't be modified by hand; instead, run the update script.
  __all__ = [
@@ -94,10 +94,13 @@ __all__ = [
  "DataloaderConfig",
  "GPUStatsOptions",
  "StepContext",
+ "InitParams",
  "ValidStepTimer",
  "Script",
  "ScriptConfig",
  "Config",
+ "SupervisedConfig",
+ "SupervisedTask",
  "Task",
  "collate",
  "collate_non_null",
@@ -291,10 +294,13 @@ NAME_MAP: dict[str, str] = {
  "DataloaderConfig": "task.mixins.data_loader",
  "GPUStatsOptions": "task.mixins.gpu_stats",
  "StepContext": "task.mixins.step_wrapper",
+ "InitParams": "task.mixins.train",
  "ValidStepTimer": "task.mixins.train",
  "Script": "task.script",
  "ScriptConfig": "task.script",
  "Config": "task.task",
+ "SupervisedConfig": "task.task",
+ "SupervisedTask": "task.task",
  "Task": "task.task",
  "collate": "utils.data.collate",
  "collate_non_null": "utils.data.collate",
@@ -488,9 +494,9 @@ if IMPORT_ALL or TYPE_CHECKING:
  from xax.task.mixins.data_loader import DataloaderConfig
  from xax.task.mixins.gpu_stats import GPUStatsOptions
  from xax.task.mixins.step_wrapper import StepContext
- from xax.task.mixins.train import Batch, Output, ValidStepTimer
+ from xax.task.mixins.train import Batch, InitParams, Output, ValidStepTimer
  from xax.task.script import Script, ScriptConfig
- from xax.task.task import Config, Task
+ from xax.task.task import Config, SupervisedConfig, SupervisedTask, Task
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
  from xax.utils.debugging import (
  breakpoint_if_nonfinite,
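The __init__.py changes register the new training symbols in both __all__ and NAME_MAP, so they resolve through the same lazy-import machinery as the existing names. A minimal smoke test (a sketch, assuming attribute access triggers the lazy import exactly as it does for the names that were already exported):

    import xax

    # Newly exported in 0.3.15 via the __all__ / NAME_MAP entries above.
    print(xax.InitParams, xax.SupervisedConfig, xax.SupervisedTask)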
xax/task/mixins/__init__.py CHANGED
@@ -10,4 +10,5 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
  from xax.task.mixins.process import ProcessConfig, ProcessMixin
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
  from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
- from xax.task.mixins.train import TrainConfig, TrainMixin
+ from xax.task.mixins.supervised import SupervisedConfig, SupervisedMixin
+ from xax.task.mixins.train import InitParams, TrainConfig, TrainMixin
xax/task/mixins/artifacts.py CHANGED
@@ -82,7 +82,7 @@ class ArtifactsMixin(BaseTask[Config]):
  return self._exp_dir

  def get_exp_dir(run_id: int) -> Path:
- return self.run_dir / f"run_{run_id}"
+ return self.run_dir / f"run_{run_id:03d}"

  run_id = 0
  while (exp_dir := get_exp_dir(run_id)).is_dir():
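The only functional change in artifacts.py is zero-padding the run directory name, so run directories sort lexicographically. A small illustration of the formatting before and after:

    run_id = 7
    f"run_{run_id}"      # old layout -> "run_7"
    f"run_{run_id:03d}"  # new layout -> "run_007"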
xax/task/mixins/supervised.py ADDED
@@ -0,0 +1,368 @@
+ """Defines a mixin for running the training loop."""
+
+ import bdb
+ import contextlib
+ import itertools
+ import logging
+ import signal
+ import sys
+ import textwrap
+ import traceback
+ from abc import ABC
+ from dataclasses import dataclass
+ from threading import Thread
+ from typing import (
+ Generator,
+ Generic,
+ Iterator,
+ Sequence,
+ TypeVar,
+ )
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import optax
+ from jaxtyping import Array, PRNGKeyArray, PyTree
+
+ from xax.core.conf import field
+ from xax.core.state import State
+ from xax.nn.parallel import is_master
+ from xax.task.mixins.train import Batch, InitParams, Output, TrainConfig, TrainMixin
+ from xax.utils.experiments import (
+ ContextTimer,
+ TrainingFinishedError,
+ )
+ from xax.utils.jax import jit as xax_jit, scan as xax_scan
+ from xax.utils.logging import LOG_PING
+ from xax.utils.pytree import get_pytree_param_count
+ from xax.utils.text import highlight_exception_message, show_info
+ from xax.utils.types.frozen_dict import FrozenDict
+
+ logger = logging.getLogger(__name__)
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class SupervisedConfig(TrainConfig):
+ updates_per_step: int = field(1, help="Number of updates to perform per step")
+
+
+ Config = TypeVar("Config", bound=SupervisedConfig)
+
+
+ class SupervisedMixin(
+ TrainMixin[Config, InitParams],
+ Generic[Config],
+ ABC,
+ ):
+ def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
+ """Gets the output from the model.
+
+ By default, we assume the model is a function that takes the batch as
+ input and returns the loss. This function can be patched to do more
+ complex operations instead.
+
+ Args:
+ model: The current model.
+ batch: The current minibatch of samples.
+ state: The current training state.
+ """
+ raise NotImplementedError("`get_output` must be implemented by the subclass")
+
+ def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
+ """Gets the loss for the current batch.
+
+ By default, we assume the model is a function that takes the batch as
+ input and returns the loss. This function can be patched to do more
+ complex operations instead.
+
+ Args:
+ model: The current model.
+ batch: The current minibatch of samples.
+ output: The output from the model.
+ state: The current training state.
+
+ Returns:
+ The computed loss, as a tensor.
+ """
+ if not isinstance(output, Array):
+ raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
+ return output
+
+ def compute_metrics(
+ self,
+ model: PyTree,
+ batch: Batch,
+ output: Output,
+ loss: Array,
+ state: State,
+ ) -> dict[str, Array]:
+ """Computes the metrics for the current batch.
+
+ Args:
+ model: The current model.
+ batch: The current minibatch of samples.
+ output: The output from the model.
+ loss: The loss for the current batch.
+ state: The current training state.
+
+ Returns:
+ A dictionary of metrics.
+ """
+ return {
+ "loss": loss,
+ }
+
+ @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
+ def get_output_and_loss(
+ self,
+ model_arr: PyTree,
+ model_static: PyTree,
+ batch: Batch,
+ state: State,
+ ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
+ model = eqx.combine(model_arr, model_static)
+ output = self.get_output(model, batch, state)
+ loss = self.compute_loss(model, batch, output, state)
+ metrics = self.compute_metrics(model, batch, output, loss, state)
+ return loss, (output, metrics)
+
+ @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
+ def update(
+ self,
+ model_arr: PyTree,
+ model_static: PyTree,
+ optimizer: optax.GradientTransformation,
+ opt_state: optax.OptState,
+ batch: Batch,
+ state: State,
+ ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
+ grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+ grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
+ grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
+ updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+ model_arr = eqx.apply_updates(model_arr, updates)
+ return model_arr, opt_state, output, metrics
+
+ @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
+ def train_step(
+ self,
+ model_arr: PyTree,
+ model_static: PyTree,
+ optimizer: optax.GradientTransformation,
+ opt_state: optax.OptState,
+ batches: Batch,
+ state: State,
+ ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+ def update_fn(
+ carry: tuple[PyTree, optax.OptState],
+ batch: Batch,
+ ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
+ model_arr, opt_state = carry
+ model_arr, opt_state, output, metrics = self.update(
+ model_arr,
+ model_static,
+ optimizer,
+ opt_state,
+ batch,
+ state,
+ )
+ return (model_arr, opt_state), (output, FrozenDict(metrics))
+
+ (model_arr, opt_state), (output, metrics) = xax_scan(
+ update_fn,
+ (model_arr, opt_state),
+ batches,
+ jit_level=3,
+ )
+
+ # Only get the final output and metrics.
+ output = jax.tree.map(lambda x: x[-1], output)
+ metrics = jax.tree.map(lambda x: x[-1], metrics)
+
+ return model_arr, opt_state, output, metrics
+
+ @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
+ def val_step(
+ self,
+ model_arr: PyTree,
+ model_static: PyTree,
+ batch: Batch,
+ state: State,
+ ) -> tuple[Output, FrozenDict[str, Array]]:
+ _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+ return output, FrozenDict(metrics)
+
+ def train_loop(
+ self,
+ models: Sequence[PyTree],
+ optimizers: Sequence[optax.GradientTransformation],
+ opt_states: Sequence[optax.OptState],
+ train_pf: Iterator[Batch],
+ valid_pf: Iterator[Batch],
+ state: State,
+ ) -> None:
+ if len(models) != 1 or len(optimizers) != 1 or len(opt_states) != 1:
+ raise ValueError(
+ "Vanilla training expects a single model, optimizer and optimizer state. "
+ f"Found {len(models)} models, {len(optimizers)} optimizers and {len(opt_states)} optimizer states."
+ )
+
+ model_arr, model_static = eqx.partition(models[0], self.model_partition_fn)
+ optimizer = optimizers[0]
+ opt_state = opt_states[0]
+
+ while not self.is_training_over(state):
+ valid_step = self.valid_step_timer(state)
+
+ if valid_step:
+ with ContextTimer() as timer:
+ state = state.replace(phase="valid")
+ valid_batch = next(valid_pf)
+ output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+ self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
+
+ state = state.replace(
+ num_steps=state.num_steps + 1,
+ num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
+ elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
+ )
+
+ with ContextTimer() as timer:
+ state = self.on_step_start(state)
+ state = state.replace(phase="train")
+ train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
+ model_arr, opt_state, output, metrics = self.train_step(
+ model_arr=model_arr,
+ model_static=model_static,
+ optimizer=optimizer,
+ opt_state=opt_state,
+ batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
+ state=state,
+ )
+ self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
+ state = self.on_step_end(state)
+
+ state = state.replace(
+ num_steps=state.num_steps + 1,
+ num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
+ elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
+ )
+
+ if state.num_steps <= 3:
+ logger.log(LOG_PING, "Step %d took %.2f second", state.num_steps, timer.elapsed_time)
+
+ if self.should_checkpoint(state):
+ model = eqx.combine(model_arr, model_static)
+ self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
+
+ # After finishing training, save the final checkpoint.
+ model = eqx.combine(model_arr, model_static)
+ self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
+
+ @contextlib.contextmanager
+ def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
+ try:
+ train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
+ yield train_iterator
+ return
+ except NotImplementedError:
+ pass
+
+ train_ds = self.get_dataset("train")
+ train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
+ train_pf = self.get_prefetcher(train_dl)
+
+ try:
+ with train_pf as train_pf_ctx:
+ yield train_pf_ctx
+ finally:
+ logger.info("Closing train prefetcher")
+
+ @contextlib.contextmanager
+ def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
+ try:
+ valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
+ yield valid_iterator
+ return
+ except NotImplementedError:
+ pass
+
+ valid_ds = self.get_dataset("valid")
+ valid_dl = self.get_dataloader(valid_ds, "valid")
+ valid_pf = self.get_prefetcher(valid_dl)
+
+ try:
+ with valid_pf as valid_pf_ctx:
+ yield valid_pf_ctx
+ finally:
+ logger.info("Closing valid prefetcher")
+
+ def run(self) -> None:
+ self.run_training()
+
+ def run_training(self) -> None:
+ """Runs the training loop.
+
+ Args:
+ model: The current model
+ task: The current task
+ optimizer: The current optimizer
+ lr_scheduler: The current learning rate scheduler
+
+ Raises:
+ ValueError: If the task is not a supervised learning task
+ """
+ with self:
+ key = self.prng_key()
+
+ self.set_loggers()
+
+ if is_master():
+ Thread(target=self.log_state, daemon=True).start()
+
+ key, model_key = jax.random.split(key)
+ init_params = InitParams(key=model_key)
+ models, optimizers, opt_states, state = self.load_initial_state(init_params, load_optimizer=True)
+ logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
+ logger.info("Optimizer size: %s", f"{get_pytree_param_count(opt_states):,}")
+
+ state = self.on_training_start(state)
+
+ def on_exit() -> None:
+ self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+ # Handle user-defined interrupts during the training loop.
+ self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
+
+ key, tkey, vkey = jax.random.split(key, 3)
+ with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
+ try:
+ self.train_loop(
+ models=models,
+ optimizers=optimizers,
+ opt_states=opt_states,
+ train_pf=train_pf,
+ valid_pf=valid_pf,
+ state=state,
+ )
+
+ except TrainingFinishedError:
+ if is_master():
+ num_steps, num_samples = int(state.num_steps), int(state.num_samples)
+ show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
+ self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+ except (KeyboardInterrupt, bdb.BdbQuit):
+ if is_master():
+ show_info("Interrupted training", important=True)
+
+ except BaseException:
+ exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
+ sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+ sys.stdout.flush()
+ self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+ finally:
+ state = self.on_training_end(state)
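The new supervised.py module carries the supervised training loop that used to live in TrainMixin: subclasses implement get_output (and override compute_loss / compute_metrics when the model output is not already the loss), while train_step scans update over updates_per_step batches stacked along a leading axis. A rough sketch of a downstream task, assuming the data-loading and optimizer hooks from the other mixins are implemented elsewhere; MyTask, the MLP model, and the MSE objective are illustrative only, not part of the package:

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import xax

    class MyTask(xax.SupervisedTask):  # config type parameter omitted for brevity
        def get_model(self, params):
            # Models are now built from an InitParams instance rather than a bare PRNG key.
            return eqx.nn.MLP(in_size=4, out_size=1, width_size=32, depth=2, key=params.key)

        def get_output(self, model, batch, state):
            # Returning the loss directly lets the default compute_loss pass it through unchanged.
            x, y = batch
            return jnp.mean((jax.vmap(model)(x) - y) ** 2)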
xax/task/mixins/train.py CHANGED
@@ -1,24 +1,15 @@
  """Defines a mixin for running the training loop."""

- import bdb
- import contextlib
  import functools
  import itertools
  import logging
- import signal
- import sys
- import textwrap
  import time
- import traceback
  from abc import ABC, abstractmethod
  from dataclasses import asdict, dataclass, is_dataclass
  from pathlib import Path
- from threading import Thread
  from typing import (
  Any,
- Generator,
  Generic,
- Iterator,
  Literal,
  Mapping,
  Sequence,
@@ -30,7 +21,6 @@ from typing import (

  import equinox as eqx
  import jax
- import jax.numpy as jnp
  import numpy as np
  import optax
  from jaxtyping import Array, PRNGKeyArray, PyTree
@@ -38,7 +28,6 @@ from jaxtyping import Array, PRNGKeyArray, PyTree
  from xax.core.conf import field
  from xax.core.state import Phase, State
  from xax.nn.functions import set_random_seed
- from xax.nn.parallel import is_master
  from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
  from xax.task.mixins.checkpointing import (
  CheckpointingConfig,
@@ -51,19 +40,14 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
  from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
  from xax.utils.experiments import (
- ContextTimer,
  StateTimer,
- TrainingFinishedError,
  diff_configs,
  get_diff_string,
  get_info_json,
  get_state_file_string,
  get_training_code,
  )
- from xax.utils.jax import jit as xax_jit, scan as xax_scan
  from xax.utils.logging import LOG_PING, LOG_STATUS
- from xax.utils.pytree import get_pytree_param_count
- from xax.utils.text import highlight_exception_message, show_info
  from xax.utils.types.frozen_dict import FrozenDict

  logger = logging.getLogger(__name__)
@@ -159,6 +143,16 @@ class ValidStepTimer:
  return False


+ @jax.tree_util.register_dataclass
+ @dataclass(frozen=True)
+ class InitParams:
+ key: PRNGKeyArray
+
+
+ # Subclasses should be able to override the init params.
+ InitParamsT = TypeVar("InitParamsT", bound=InitParams)
+
+
  @jax.tree_util.register_dataclass
  @dataclass
  class TrainConfig(
@@ -175,7 +169,6 @@ class TrainConfig(
  valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
  max_steps: int | None = field(None, help="Maximum number of steps to run")
  step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
- updates_per_step: int = field(1, help="Number of updates to perform per step")
  random_seed: int = field(1337, help="Random seed for the task")


@@ -189,7 +182,7 @@ class TrainMixin(
  StepContextMixin[Config],
  ArtifactsMixin[Config],
  RunnableMixin[Config],
- Generic[Config],
+ Generic[Config, InitParamsT],
  ABC,
  ):
  valid_step_timer: ValidStepTimer
@@ -309,15 +302,18 @@ class TrainMixin(
  self.write_logs(state)

  @abstractmethod
- def get_model(self, key: PRNGKeyArray) -> PyTree | Sequence[PyTree]:
+ def get_model(self, params: InitParamsT) -> PyTree | Sequence[PyTree]:
  """Returns the Equinox model to train.

+ Args:
+ params: The parameters for initializing the model.
+
  Returns:
  The model to train.
  """

- def _get_models(self, key: PRNGKeyArray) -> list[PyTree]:
- models = self.get_model(key)
+ def _get_models(self, params: InitParamsT) -> list[PyTree]:
+ models = self.get_model(params)
  if isinstance(models, Sequence):
  models = list(models)
  elif isinstance(models, eqx.Module):
@@ -353,20 +349,20 @@ class TrainMixin(
  @overload
  def load_initial_state(
  self,
- key: PRNGKeyArray,
+ params: InitParamsT,
  load_optimizer: Literal[False] = False,
  ) -> tuple[PyTree, State]: ...

  @overload
  def load_initial_state(
  self,
- key: PRNGKeyArray,
+ params: InitParamsT,
  load_optimizer: Literal[True],
  ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]: ...

  def load_initial_state(
  self,
- key: PRNGKeyArray,
+ params: InitParamsT,
  load_optimizer: bool = False,
  ) -> (
  tuple[list[PyTree], State]
@@ -376,7 +372,7 @@

  if init_ckpt_path is not None:
  logger.info("Loading checkpoint from %s", init_ckpt_path)
- model, state, config = self.load_ckpt(init_ckpt_path, part="model_state_config")
+ model, state, config = self.load_ckpt(init_ckpt_path, params, part="model_state_config")
  config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
  if config_diff:
  logger.warning("Loaded config differs from current config:\n%s", config_diff)
@@ -384,12 +380,12 @@
  if not load_optimizer:
  return model, state

- optimizer = self.load_ckpt(init_ckpt_path, part="opt")
- opt_state = self.load_ckpt(init_ckpt_path, part="opt_state", model=model, optimizer=optimizer)
+ optimizer = self.load_ckpt(init_ckpt_path, params, part="opt")
+ opt_state = self.load_ckpt(init_ckpt_path, params, part="opt_state", model=model, optimizer=optimizer)
  return model, optimizer, opt_state, state

  logger.info("Starting a new training run")
- models = self._get_models(key)
+ models = self._get_models(params)
  state = State.init_state()

  if not load_optimizer:
@@ -405,6 +401,7 @@
  def load_ckpt(
  self,
  path: Path,
+ init_params: InitParamsT,
  *,
  part: Literal["all"],
  ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, Config]: ...
@@ -413,6 +410,7 @@
  def load_ckpt(
  self,
  path: Path,
+ init_params: InitParamsT,
  *,
  part: Literal["model_state_config"],
  ) -> tuple[list[PyTree], State, Config]: ...
@@ -421,6 +419,7 @@
  def load_ckpt(
  self,
  path: Path,
+ init_params: InitParamsT,
  *,
  part: Literal["model"],
  ) -> list[PyTree]: ...
@@ -429,6 +428,7 @@
  def load_ckpt(
  self,
  path: Path,
+ init_params: InitParamsT,
  *,
  part: Literal["opt"],
  ) -> list[optax.GradientTransformation]: ...
@@ -437,6 +437,7 @@
  def load_ckpt(
  self,
  path: Path,
+ init_params: InitParamsT,
  *,
  part: Literal["opt_state"],
  model: PyTree | None = None,
@@ -447,6 +448,7 @@
  def load_ckpt(
  self,
  path: Path,
+ init_params: InitParamsT,
  *,
  part: Literal["state"],
  ) -> list[State]: ...
@@ -455,6 +457,7 @@
  def load_ckpt(
  self,
  path: Path,
+ init_params: InitParamsT,
  *,
  part: Literal["config"],
  ) -> list[Config]: ...
@@ -462,6 +465,7 @@
  def load_ckpt(
  self,
  path: str | Path,
+ init_params: InitParamsT,
  *,
  part: CheckpointPart = "all",
  model: PyTree | None = None,
@@ -477,18 +481,15 @@
  ):
  path = Path(path)

- # This key isn't used for anything, it's just a required argument.
- key = jax.random.PRNGKey(0)
-
  match part:
  case "model_state_config":
- model_specs = eqx.filter_eval_shape(self._get_models, key)
+ model_specs = eqx.filter_eval_shape(self._get_models, init_params)
  model, state, config = load_ckpt(path, part="model_state_config", model_templates=model_specs)
  config = self.get_config(config, use_cli=False)
  return model, state, config

  case "model":
- model_specs = eqx.filter_eval_shape(self._get_models, key)
+ model_specs = eqx.filter_eval_shape(self._get_models, init_params)
  return load_ckpt(path, part="model", model_templates=model_specs)

  case "opt":
@@ -497,7 +498,7 @@

  case "opt_state":
  if model is None:
- model_specs = eqx.filter_eval_shape(self._get_models, key)
+ model_specs = eqx.filter_eval_shape(self._get_models, init_params)
  model = load_ckpt(path, part="model", model_templates=model_specs)
  if optimizer is None:
  optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
@@ -512,7 +513,7 @@
  return self.get_config(load_ckpt(path, part="config"), use_cli=False)

  case "all":
- model_specs = eqx.filter_eval_shape(self._get_models, key)
+ model_specs = eqx.filter_eval_shape(self._get_models, init_params)
  model = load_ckpt(path, part="model", model_templates=model_specs)
  optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
  optimizer = load_ckpt(path, part="opt", optimizer_templates=optimizer_specs)
@@ -525,95 +526,6 @@
  case _:
  raise ValueError(f"Unknown checkpoint part: {part}")

- def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
- """Gets the output from the model.
-
- By default, we assume the model is a function that takes the batch as
- input and returns the loss. This function can be patched to do more
- complex operations instead.
-
- Args:
- model: The current model.
- batch: The current minibatch of samples.
- state: The current training state.
- """
- raise NotImplementedError("`get_output` must be implemented by the subclass")
-
- def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
- """Gets the loss for the current batch.
-
- By default, we assume the model is a function that takes the batch as
- input and returns the loss. This function can be patched to do more
- complex operations instead.
-
- Args:
- model: The current model.
- batch: The current minibatch of samples.
- output: The output from the model.
- state: The current training state.
-
- Returns:
- The computed loss, as a tensor.
- """
- if not isinstance(output, Array):
- raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
- return output
-
- def compute_metrics(
- self,
- model: PyTree,
- batch: Batch,
- output: Output,
- loss: Array,
- state: State,
- ) -> dict[str, Array]:
- """Computes the metrics for the current batch.
-
- Args:
- model: The current model.
- batch: The current minibatch of samples.
- output: The output from the model.
- loss: The loss for the current batch.
- state: The current training state.
-
- Returns:
- A dictionary of metrics.
- """
- return {
- "loss": loss,
- }
-
- @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
- def get_output_and_loss(
- self,
- model_arr: PyTree,
- model_static: PyTree,
- batch: Batch,
- state: State,
- ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
- model = eqx.combine(model_arr, model_static)
- output = self.get_output(model, batch, state)
- loss = self.compute_loss(model, batch, output, state)
- metrics = self.compute_metrics(model, batch, output, loss, state)
- return loss, (output, metrics)
-
- @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
- def update(
- self,
- model_arr: PyTree,
- model_static: PyTree,
- optimizer: optax.GradientTransformation,
- opt_state: optax.OptState,
- batch: Batch,
- state: State,
- ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
- grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
- grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
- grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
- updates, opt_state = optimizer.update(grads, opt_state, model_arr)
- model_arr = eqx.apply_updates(model_arr, updates)
- return model_arr, opt_state, output, metrics
-

  def get_size_of_batch(self, batch: Batch) -> int | None:
  """Gets the batch size for the current batch.
@@ -687,224 +599,3 @@

  def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
  return eqx.is_inexact_array(item)
-
- @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
- def train_step(
- self,
- model_arr: PyTree,
- model_static: PyTree,
- optimizer: optax.GradientTransformation,
- opt_state: optax.OptState,
- batches: Batch,
- state: State,
- ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
- def update_fn(
- carry: tuple[PyTree, optax.OptState],
- batch: Batch,
- ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
- model_arr, opt_state = carry
- model_arr, opt_state, output, metrics = self.update(
- model_arr,
- model_static,
- optimizer,
- opt_state,
- batch,
- state,
- )
- return (model_arr, opt_state), (output, FrozenDict(metrics))
-
- (model_arr, opt_state), (output, metrics) = xax_scan(
- update_fn,
- (model_arr, opt_state),
- batches,
- jit_level=3,
- )
-
- # Only get the final output and metrics.
- output = jax.tree.map(lambda x: x[-1], output)
- metrics = jax.tree.map(lambda x: x[-1], metrics)
-
- return model_arr, opt_state, output, metrics
-
- @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
- def val_step(
- self,
- model_arr: PyTree,
- model_static: PyTree,
- batch: Batch,
- state: State,
- ) -> tuple[Output, FrozenDict[str, Array]]:
- _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
- return output, FrozenDict(metrics)
-
- def train_loop(
- self,
- models: Sequence[PyTree],
- optimizers: Sequence[optax.GradientTransformation],
- opt_states: Sequence[optax.OptState],
- train_pf: Iterator[Batch],
- valid_pf: Iterator[Batch],
- state: State,
- ) -> None:
- if len(models) != 1 or len(optimizers) != 1 or len(opt_states) != 1:
- raise ValueError(
- "Vanilla training expects a single model, optimizer and optimizer state. "
- f"Found {len(models)} models, {len(optimizers)} optimizers and {len(opt_states)} optimizer states."
- )
-
- model_arr, model_static = eqx.partition(models[0], self.model_partition_fn)
- optimizer = optimizers[0]
- opt_state = opt_states[0]
-
- while not self.is_training_over(state):
- valid_step = self.valid_step_timer(state)
-
- if valid_step:
- with ContextTimer() as timer:
- state = state.replace(phase="valid")
- valid_batch = next(valid_pf)
- output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
- self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
-
- state = state.replace(
- num_steps=state.num_steps + 1,
- num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
- elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
- )
-
- with ContextTimer() as timer:
- state = self.on_step_start(state)
- state = state.replace(phase="train")
- train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
- model_arr, opt_state, output, metrics = self.train_step(
- model_arr=model_arr,
- model_static=model_static,
- optimizer=optimizer,
- opt_state=opt_state,
- batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
- state=state,
- )
- self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
- state = self.on_step_end(state)
-
- state = state.replace(
- num_steps=state.num_steps + 1,
- num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
- elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
- )
-
- if state.num_steps <= 3:
- logger.log(LOG_PING, "Step %d took %.2f second", state.num_steps, timer.elapsed_time)
-
- if self.should_checkpoint(state):
- model = eqx.combine(model_arr, model_static)
- self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
-
- # After finishing training, save the final checkpoint.
- model = eqx.combine(model_arr, model_static)
- self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
-
- @contextlib.contextmanager
- def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
- try:
- train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
- yield train_iterator
- return
- except NotImplementedError:
- pass
-
- train_ds = self.get_dataset("train")
- train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
- train_pf = self.get_prefetcher(train_dl)
-
- try:
- with train_pf as train_pf_ctx:
- yield train_pf_ctx
- finally:
- logger.info("Closing train prefetcher")
-
- @contextlib.contextmanager
- def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
- try:
- valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
- yield valid_iterator
- return
- except NotImplementedError:
- pass
-
- valid_ds = self.get_dataset("valid")
- valid_dl = self.get_dataloader(valid_ds, "valid")
- valid_pf = self.get_prefetcher(valid_dl)
-
- try:
- with valid_pf as valid_pf_ctx:
- yield valid_pf_ctx
- finally:
- logger.info("Closing valid prefetcher")
-
- def run(self) -> None:
- self.run_training()
-
- def run_training(self) -> None:
- """Runs the training loop.
-
- Args:
- model: The current model
- task: The current task
- optimizer: The current optimizer
- lr_scheduler: The current learning rate scheduler
-
- Raises:
- ValueError: If the task is not a supervised learning task
- """
- with self:
- key = self.prng_key()
-
- self.set_loggers()
-
- if is_master():
- Thread(target=self.log_state, daemon=True).start()
-
- key, model_key = jax.random.split(key)
- models, optimizers, opt_states, state = self.load_initial_state(model_key, load_optimizer=True)
- logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
- logger.info("Optimizer size: %s", f"{get_pytree_param_count(opt_states):,}")
-
- state = self.on_training_start(state)
-
- def on_exit() -> None:
- self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
-
- # Handle user-defined interrupts during the training loop.
- self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
-
- key, tkey, vkey = jax.random.split(key, 3)
- with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
- try:
- self.train_loop(
- models=models,
- optimizers=optimizers,
- opt_states=opt_states,
- train_pf=train_pf,
- valid_pf=valid_pf,
- state=state,
- )
-
- except TrainingFinishedError:
- if is_master():
- num_steps, num_samples = int(state.num_steps), int(state.num_samples)
- show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
- self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
-
- except (KeyboardInterrupt, bdb.BdbQuit):
- if is_master():
- show_info("Interrupted training", important=True)
-
- except BaseException:
- exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
- sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
- sys.stdout.flush()
- self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
-
- finally:
- state = self.on_training_end(state)
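For downstream tasks, the train.py changes amount to a signature migration: get_model and load_ckpt now take an InitParams instance (or a subclass) rather than a raw PRNG key, and the internal placeholder key used for building checkpoint templates is gone. A hedged sketch of the new call pattern; task and ckpt_path are placeholders:

    import jax
    import xax

    # InitParams is a frozen, tree-registered dataclass wrapping the model-init key.
    init_params = xax.InitParams(key=jax.random.PRNGKey(0))

    # Checkpoint loading now threads the init params through to the model templates,
    # e.g. model = task.load_ckpt(ckpt_path, init_params, part="model")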
xax/task/task.py CHANGED
@@ -19,6 +19,7 @@ from xax.task.mixins import (
  DataloadersMixin,
  GPUStatsConfig,
  GPUStatsMixin,
+ InitParams,
  LoggerConfig,
  LoggerMixin,
  ProcessConfig,
@@ -27,6 +28,8 @@
  RunnableMixin,
  StepContextConfig,
  StepContextMixin,
+ SupervisedConfig as BaseSupervisedConfig,
+ SupervisedMixin as BaseSupervisedMixin,
  TrainConfig,
  TrainMixin,
  )
@@ -52,10 +55,11 @@


  ConfigT = TypeVar("ConfigT", bound=Config)
+ InitParamsT = TypeVar("InitParamsT", bound=InitParams)


  class Task(
- TrainMixin[ConfigT],
+ TrainMixin[ConfigT, InitParamsT],
  CheckpointingMixin[ConfigT],
  CompileMixin[ConfigT],
  DataloadersMixin[ConfigT],
@@ -67,6 +71,26 @@
  ArtifactsMixin[ConfigT],
  RunnableMixin[ConfigT],
  BaseTask[ConfigT],
- Generic[ConfigT],
+ Generic[ConfigT, InitParamsT],
+ ):
+ pass
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class SupervisedConfig(
+ BaseSupervisedConfig,
+ Config,
+ ):
+ pass
+
+
+ SupervisedConfigT = TypeVar("SupervisedConfigT", bound=SupervisedConfig)
+
+
+ class SupervisedTask(
+ BaseSupervisedMixin[SupervisedConfigT],
+ Task[SupervisedConfigT, InitParams],
+ Generic[SupervisedConfigT],
  ):
  pass
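With the task.py additions, supervised projects subclass the new aliases instead of Task directly: SupervisedConfig layers the mixin's BaseSupervisedConfig (which now carries updates_per_step, moved out of TrainConfig) onto the existing Config, and SupervisedTask binds the mixin stack to the plain InitParams. A configuration sketch, mirroring the decorators and field helper the library uses for its own configs; the hidden_size field is illustrative only:

    from dataclasses import dataclass

    import jax
    import xax
    from xax.core.conf import field

    @jax.tree_util.register_dataclass
    @dataclass
    class MyConfig(xax.SupervisedConfig):
        hidden_size: int = field(128, help="Size of the hidden layer")  # updates_per_step etc. come from SupervisedConfig

    class MyTask(xax.SupervisedTask[MyConfig]):
        ...  # implement get_model / get_output as in the supervised mixin sketch above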
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: xax
- Version: 0.3.13
+ Version: 0.3.15
  Summary: A library for fast Jax experimentation
  Home-page: https://github.com/kscalelabs/xax
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
- xax/__init__.py,sha256=gTdL72cZZzdpYkHj1Ks981o3nE_BNlvIv1ISYlQarmM,16944
+ xax/__init__.py,sha256=4N9rKQk5lIy4eCDrgpRJEChLD3aXbzWxiGrCcdpuFMQ,17165
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -21,7 +21,7 @@ xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/task/base.py,sha256=i6FRJ75aqlekWkzJNRWDUEX7P514pUjLVuxjhX1GBgw,8198
  xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
- xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
+ xax/task/task.py,sha256=Iy02wRUti5lDX1rfDHIgX87dGYeayJxJ9nzJzp_lMq0,1960
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
  xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
@@ -32,8 +32,8 @@ xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
  xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,1407
  xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
  xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
- xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
- xax/task/mixins/artifacts.py,sha256=R-y3p7__zJHlHDqwDVAZysg2ZmebCJbqAx_xGT2Xpd0,3857
+ xax/task/mixins/__init__.py,sha256=wYc4zfutdMyEmzCVV421gSf25ZXW9htNTSY_TW6vL_8,894
+ xax/task/mixins/artifacts.py,sha256=UN26TW22ARduO6Bjs0yRu4-V6-Md9MPbXLKDnS28m44,3861
  xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
  xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
@@ -43,7 +43,8 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
  xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
- xax/task/mixins/train.py,sha256=qb0zpsyeCk_U8Sk8THxtXkUVwj5r0lOlMLNRTctvcWU,32812
+ xax/task/mixins/supervised.py,sha256=IxAh-ywvjDNoqXtzHwv2WpVsXFOX45SZjyF3qpbN-2k,13757
+ xax/task/mixins/train.py,sha256=0loO44W6vVjP5usWvN0D1TgYTJ7N3PDevR7brmw3ymQ,20493
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/utils/debugging.py,sha256=85JYIdnzLnvXsuli-4YHei_3tE3DnX3rmDSARKW2u1M,2192
  xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -60,9 +61,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
- xax-0.3.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
- xax-0.3.13.dist-info/METADATA,sha256=Gl4h20HE74S6yx7NlKB64JF1ngQMx7e8gM5uu1SEH-M,1247
- xax-0.3.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- xax-0.3.13.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
- xax-0.3.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
- xax-0.3.13.dist-info/RECORD,,
+ xax-0.3.15.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+ xax-0.3.15.dist-info/METADATA,sha256=6tDxvfIGYiR4FOCopyZWxz24XNU-cGGHwoiIjFJl4Pc,1247
+ xax-0.3.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ xax-0.3.15.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+ xax-0.3.15.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+ xax-0.3.15.dist-info/RECORD,,