xax 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """

-__version__ = "0.2.13"
+__version__ = "0.2.14"

 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
xax/task/mixins/checkpointing.py CHANGED
@@ -6,7 +6,7 @@ import logging
 import tarfile
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Generic, Literal, TypeVar, cast, overload
+from typing import Generic, Literal, Sequence, TypeVar, cast, overload

 import equinox as eqx
 import jax
@@ -57,10 +57,10 @@ def load_ckpt(
     path: Path,
     *,
     part: Literal["all"],
-    model_template: PyTree,
-    optimizer_template: PyTree,
-    opt_state_template: PyTree,
-) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+    model_templates: Sequence[PyTree],
+    optimizer_templates: Sequence[optax.GradientTransformation],
+    opt_state_templates: Sequence[optax.OptState],
+) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, DictConfig]: ...


 @overload
@@ -68,20 +68,35 @@ def load_ckpt(
     path: Path,
     *,
     part: Literal["model_state_config"],
-    model_template: PyTree,
-) -> tuple[PyTree, State, DictConfig]: ...
+    model_templates: Sequence[PyTree],
+) -> tuple[list[PyTree], State, DictConfig]: ...


 @overload
-def load_ckpt(path: Path, *, part: Literal["model"], model_template: PyTree) -> PyTree: ...
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["model"],
+    model_templates: Sequence[PyTree],
+) -> list[PyTree]: ...


 @overload
-def load_ckpt(path: Path, *, part: Literal["opt"], optimizer_template: PyTree) -> optax.GradientTransformation: ...
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["opt"],
+    optimizer_templates: Sequence[optax.GradientTransformation],
+) -> list[optax.GradientTransformation]: ...


 @overload
-def load_ckpt(path: Path, *, part: Literal["opt_state"], opt_state_template: PyTree) -> optax.OptState: ...
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["opt_state"],
+    opt_state_templates: Sequence[optax.OptState],
+) -> list[optax.OptState]: ...


 @overload
@@ -96,40 +111,49 @@ def load_ckpt(
     path: str | Path,
     *,
     part: CheckpointPart = "model",
-    model_template: PyTree | None = None,
-    optimizer_template: PyTree | None = None,
-    opt_state_template: PyTree | None = None,
+    model_templates: Sequence[PyTree] | None = None,
+    optimizer_templates: Sequence[optax.GradientTransformation] | None = None,
+    opt_state_templates: Sequence[optax.OptState] | None = None,
 ) -> (
-    tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
-    | tuple[PyTree, State, DictConfig]
-    | PyTree
-    | optax.GradientTransformation
-    | optax.OptState
+    tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, DictConfig]
+    | tuple[list[PyTree], State, DictConfig]
+    | list[PyTree]
+    | list[optax.GradientTransformation]
+    | list[optax.OptState]
     | State
     | DictConfig
 ):
     with tarfile.open(path, "r:gz") as tar:

-        def get_model() -> PyTree:
-            if model_template is None:
+        def get_model() -> list[PyTree]:
+            if model_templates is None:
                 raise ValueError("model_template must be provided to load model weights")
-            if (model := tar.extractfile("model")) is None:
-                raise ValueError(f"Checkpoint does not contain a model file: {path}")
-            return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
-
-        def get_opt() -> optax.GradientTransformation:
-            if optimizer_template is None:
+            models: list[PyTree] = []
+            for i, model_template in enumerate(model_templates):
+                if (model := tar.extractfile(f"model_{i}")) is None:
+                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
+                models.append(eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template))
+            return models
+
+        def get_opt() -> list[optax.GradientTransformation]:
+            if optimizer_templates is None:
                 raise ValueError("optimizer_template must be provided to load optimizer")
-            if (opt := tar.extractfile("optimizer")) is None:
-                raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
-            return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
-
-        def get_opt_state() -> optax.OptState:
-            if opt_state_template is None:
+            opts: list[optax.GradientTransformation] = []
+            for i, optimizer_template in enumerate(optimizer_templates):
+                if (opt := tar.extractfile(f"optimizer_{i}")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+                opts.append(eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template))
+            return opts
+
+        def get_opt_state() -> list[optax.OptState]:
+            if opt_state_templates is None:
                 raise ValueError("opt_state_template must be provided to load optimizer state")
-            if (opt_state := tar.extractfile("opt_state")) is None:
-                raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
-            return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
+            opt_states: list[optax.OptState] = []
+            for i, opt_state_template in enumerate(opt_state_templates):
+                if (opt_state := tar.extractfile(f"opt_state_{i}")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+                opt_states.append(eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template))
+            return opt_states

         def get_state() -> State:
             if (state := tar.extractfile("state")) is None:
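
For orientation, a minimal loading sketch under the new multi-entry layout, where each checkpoint stores indexed entries (model_0, optimizer_0, opt_state_0, ...). The MLP template, the Adam optimizer and the ckpt.bin path are hypothetical placeholders; only the load_ckpt module path, signature and keyword names come from the diff above.

    import equinox as eqx
    import jax
    import optax
    from pathlib import Path

    from xax.task.mixins.checkpointing import load_ckpt

    # Hypothetical templates with the same tree structure as the checkpointed
    # objects; the saved leaves are deserialised into them.
    key = jax.random.PRNGKey(0)
    model_template = eqx.nn.MLP(in_size=4, out_size=2, width_size=32, depth=2, key=key)
    optimizer_template = optax.adam(1e-3)

    # Templates are now passed as sequences, and the loaders return lists in
    # the same order as the indexed entries in the archive.
    models = load_ckpt(Path("ckpt.bin"), part="model", model_templates=[model_template])
    optimizers = load_ckpt(Path("ckpt.bin"), part="opt", optimizer_templates=[optimizer_template])
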
@@ -192,20 +216,20 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):

     def save_checkpoint(
         self,
-        model: PyTree | None = None,
-        optimizer: optax.GradientTransformation | None = None,
-        opt_state: optax.OptState | None = None,
+        models: Sequence[PyTree] | None = None,
+        optimizers: Sequence[optax.GradientTransformation] | None = None,
+        opt_states: Sequence[optax.OptState] | None = None,
         aux_data: PyTree | None = None,
         state: State | None = None,
     ) -> Path:
         """Save a checkpoint.

         Args:
-            model: The model to save
-            state: The current training state
-            optimizer: The optimizer to save
+            models: The models to save
+            optimizers: The optimizers to save
+            opt_states: The optimizer states to save
             aux_data: Additional data to save
-            opt_state: The optimizer state to save
+            state: The current training state

         Returns:
             Path to the saved checkpoint
@@ -235,22 +259,25 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
                tar.addfile(tarinfo, buf)

            # Save model using Equinox
-            if model is not None:
-                with io.BytesIO() as buf:
-                    eqx.tree_serialise_leaves(buf, model)
-                    add_file("model", buf)
+            if models is not None:
+                for i, model in enumerate(models):
+                    with io.BytesIO() as buf:
+                        eqx.tree_serialise_leaves(buf, model)
+                        add_file(f"model_{i}", buf)

            # Save optimizer using Equinox
-            if optimizer is not None:
-                with io.BytesIO() as buf:
-                    eqx.tree_serialise_leaves(buf, optimizer)
-                    add_file("optimizer", buf)
+            if optimizers is not None:
+                for i, optimizer in enumerate(optimizers):
+                    with io.BytesIO() as buf:
+                        eqx.tree_serialise_leaves(buf, optimizer)
+                        add_file(f"optimizer_{i}", buf)

            # Save optimizer state using Equinox
-            if opt_state is not None:
-                with io.BytesIO() as buf:
-                    eqx.tree_serialise_leaves(buf, opt_state)
-                    add_file("opt_state", buf)
+            if opt_states is not None:
+                for i, opt_state in enumerate(opt_states):
+                    with io.BytesIO() as buf:
+                        eqx.tree_serialise_leaves(buf, opt_state)
+                        add_file(f"opt_state_{i}", buf)

            # Save aux data using Equinox.
            if aux_data is not None:
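
A corresponding saving sketch (the task, model, optimizer, opt_state and state objects are assumed to already exist on a task that mixes in CheckpointingMixin): each element of the sequences is serialised separately and written under an indexed name.

    # Hypothetical objects; save_checkpoint now accepts sequences and writes
    # one archive entry per element.
    ckpt_path = task.save_checkpoint(
        models=[model],          # stored as "model_0"
        optimizers=[optimizer],  # stored as "optimizer_0"
        opt_states=[opt_state],  # stored as "opt_state_0"
        state=state,
    )
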
xax/task/mixins/train.py CHANGED
@@ -310,23 +310,46 @@ class TrainMixin(
         self.write_logs(state)

     @abstractmethod
-    def get_model(self, key: PRNGKeyArray) -> PyTree:
+    def get_model(self, key: PRNGKeyArray) -> PyTree | Sequence[PyTree]:
         """Returns the Equinox model to train.

         Returns:
             The model to train.
         """

+    def _get_models(self, key: PRNGKeyArray) -> list[PyTree]:
+        models = self.get_model(key)
+        if isinstance(models, Sequence):
+            models = list(models)
+        elif isinstance(models, eqx.Module):
+            models = [models]
+        else:
+            logger.warning("Model is not a sequence or an eqx.Module, wrapping it in a list anyway")
+            models = [models]
+        return models
+
     @abstractmethod
-    def get_optimizer(self) -> optax.GradientTransformation:
+    def get_optimizer(self) -> optax.GradientTransformation | Sequence[optax.GradientTransformation]:
         """Gets the optimizer for the model.

         Returns:
             The optimizer to use to train the model.
         """

-    def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
-        return optimizer.init(eqx.filter(model, eqx.is_array))
+    def _get_optimizers(self) -> list[optax.GradientTransformation]:
+        optimizers = self.get_optimizer()
+        if isinstance(optimizers, optax.GradientTransformation):
+            optimizers = [optimizers]
+        elif isinstance(optimizers, Sequence):
+            optimizers = list(optimizers)
+        return optimizers
+
+    def get_initial_opt_state(
+        self,
+        models: list[PyTree],
+        optimizers: list[optax.GradientTransformation],
+    ) -> list[optax.OptState]:
+        return [opt.init(eqx.filter(model, eqx.is_array)) for model, opt in zip(models, optimizers, strict=True)]

     @overload
     def load_initial_state(
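
As a sketch of what the widened get_model / get_optimizer contracts permit, a task may now return either a single module or transformation (which _get_models and _get_optimizers wrap into a one-element list) or a sequence, e.g. one optimizer per model. The TwoNetworkTask class and its networks below are hypothetical; only the method names and return types follow the diff above.

    from typing import Sequence

    import equinox as eqx
    import jax
    import optax
    from jaxtyping import PRNGKeyArray

    class TwoNetworkTask:  # in practice, a subclass of the xax training task (base omitted in this sketch)
        def get_model(self, key: PRNGKeyArray) -> Sequence[eqx.Module]:
            actor_key, critic_key = jax.random.split(key)
            actor = eqx.nn.MLP(in_size=8, out_size=2, width_size=64, depth=2, key=actor_key)
            critic = eqx.nn.MLP(in_size=8, out_size=1, width_size=64, depth=2, key=critic_key)
            return [actor, critic]

        def get_optimizer(self) -> Sequence[optax.GradientTransformation]:
            # One optimizer per model; get_initial_opt_state pairs them by
            # position via zip(..., strict=True).
            return [optax.adam(3e-4), optax.adam(1e-3)]

Note that the default train_loop still expects exactly one model, optimizer and optimizer state (see the ValueError added further down), so multi-model tasks are expected to override it.
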
@@ -346,7 +369,10 @@
         self,
         key: PRNGKeyArray,
         load_optimizer: bool = False,
-    ) -> tuple[PyTree, State] | tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+    ) -> (
+        tuple[list[PyTree], State]
+        | tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]
+    ):
         init_ckpt_path = self.get_init_ckpt_path()

         if init_ckpt_path is not None:
@@ -364,16 +390,17 @@
                 return model, optimizer, opt_state, state

         logger.info("Starting a new training run")
-        model = self.get_model(key)
+        models = self._get_models(key)
         state = State.init_state()

         if not load_optimizer:
-            return model, state
+            return models, state

-        optimizer = self.get_optimizer()
-        opt_state = self.get_initial_opt_state(model, optimizer)
+        # Gets the optimizer(s) for the model.
+        optimizers = self._get_optimizers()
+        opt_states = self.get_initial_opt_state(models, optimizers)

-        return model, optimizer, opt_state, state
+        return models, optimizers, opt_states, state

     @overload
     def load_ckpt(
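
Correspondingly, callers of load_initial_state now unpack lists rather than single objects; a brief sketch, assuming an existing task instance and PRNG key:

    # Without the optimizer: a list of models plus the training state.
    models, state = task.load_initial_state(key)

    # With the optimizer: parallel lists of models, optimizers and optimizer
    # states, matched by index.
    models, optimizers, opt_states, state = task.load_initial_state(key, load_optimizer=True)
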
@@ -381,7 +408,7 @@
         path: Path,
         *,
         part: Literal["all"],
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
+    ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, Config]: ...

     @overload
     def load_ckpt(
@@ -389,7 +416,7 @@
         path: Path,
         *,
         part: Literal["model_state_config"],
-    ) -> tuple[PyTree, State, Config]: ...
+    ) -> tuple[list[PyTree], State, Config]: ...

     @overload
     def load_ckpt(
@@ -397,7 +424,7 @@
         path: Path,
         *,
         part: Literal["model"],
-    ) -> PyTree: ...
+    ) -> list[PyTree]: ...

     @overload
     def load_ckpt(
@@ -405,7 +432,7 @@
         path: Path,
         *,
         part: Literal["opt"],
-    ) -> optax.GradientTransformation: ...
+    ) -> list[optax.GradientTransformation]: ...

     @overload
     def load_ckpt(
@@ -415,7 +442,7 @@
         part: Literal["opt_state"],
         model: PyTree | None = None,
         optimizer: optax.GradientTransformation | None = None,
-    ) -> optax.OptState: ...
+    ) -> list[optax.OptState]: ...

     @overload
     def load_ckpt(
@@ -423,7 +450,7 @@
         path: Path,
         *,
         part: Literal["state"],
-    ) -> State: ...
+    ) -> list[State]: ...

     @overload
     def load_ckpt(
@@ -431,7 +458,7 @@
         path: Path,
         *,
         part: Literal["config"],
-    ) -> Config: ...
+    ) -> list[Config]: ...

     def load_ckpt(
         self,
@@ -441,11 +468,11 @@
         model: PyTree | None = None,
         optimizer: optax.GradientTransformation | None = None,
     ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
-        | tuple[PyTree, State, Config]
-        | PyTree
-        | optax.GradientTransformation
-        | optax.OptState
+        tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, Config]
+        | tuple[list[PyTree], State, Config]
+        | list[PyTree]
+        | list[optax.GradientTransformation]
+        | list[optax.OptState]
         | State
         | Config
     ):
@@ -456,28 +483,28 @@

         match part:
             case "model_state_config":
-                model_spec = eqx.filter_eval_shape(self.get_model, key)
-                model, state, config = load_ckpt(path, part="model_state_config", model_template=model_spec)
+                model_specs = eqx.filter_eval_shape(self._get_models, key)
+                model, state, config = load_ckpt(path, part="model_state_config", model_templates=model_specs)
                 config = self.get_config(config, use_cli=False)
                 return model, state, config

             case "model":
-                model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return load_ckpt(path, part="model", model_template=model_spec)
+                model_specs = eqx.filter_eval_shape(self._get_models, key)
+                return load_ckpt(path, part="model", model_templates=model_specs)

             case "opt":
-                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                return load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
+                optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
+                return load_ckpt(path, part="opt", optimizer_templates=optimizer_specs)

             case "opt_state":
                 if model is None:
-                    model_spec = eqx.filter_eval_shape(self.get_model, key)
-                    model = load_ckpt(path, part="model", model_template=model_spec)
+                    model_specs = eqx.filter_eval_shape(self._get_models, key)
+                    model = load_ckpt(path, part="model", model_templates=model_specs)
                 if optimizer is None:
-                    optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                    optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
-                opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                return load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
+                    optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
+                    optimizer = load_ckpt(path, part="opt", optimizer_templates=optimizer_specs)
+                opt_state_specs = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                return load_ckpt(path, part="opt_state", opt_state_templates=opt_state_specs)

             case "state":
                 return load_ckpt(path, part="state")
@@ -486,12 +513,12 @@
                 return self.get_config(load_ckpt(path, part="config"), use_cli=False)

             case "all":
-                model_spec = eqx.filter_eval_shape(self.get_model, key)
-                model = load_ckpt(path, part="model", model_template=model_spec)
-                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
-                opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                opt_state = load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
+                model_specs = eqx.filter_eval_shape(self._get_models, key)
+                model = load_ckpt(path, part="model", model_templates=model_specs)
+                optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
+                optimizer = load_ckpt(path, part="opt", optimizer_templates=optimizer_specs)
+                opt_state_specs = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                opt_state = load_ckpt(path, part="opt_state", opt_state_templates=opt_state_specs)
                 state = load_ckpt(path, part="state")
                 config = self.get_config(load_ckpt(path, part="config"), use_cli=False)
                 return model, optimizer, opt_state, state, config
@@ -718,14 +745,22 @@

     def train_loop(
         self,
-        model: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
+        models: Sequence[PyTree],
+        optimizers: Sequence[optax.GradientTransformation],
+        opt_states: Sequence[optax.OptState],
         train_pf: Iterator[Batch],
         valid_pf: Iterator[Batch],
         state: State,
     ) -> None:
-        model_arr, model_static = eqx.partition(model, self.model_partition_fn)
+        if len(models) != 1 or len(optimizers) != 1 or len(opt_states) != 1:
+            raise ValueError(
+                "Vanilla training expects a single model, optimizer and optimizer state. "
+                f"Found {len(models)} models, {len(optimizers)} optimizers and {len(opt_states)} optimizer states."
+            )
+
+        model_arr, model_static = eqx.partition(models[0], self.model_partition_fn)
+        optimizer = optimizers[0]
+        opt_state = opt_states[0]

         while not self.is_training_over(state):
             valid_step = self.valid_step_timer(state)
@@ -773,11 +808,11 @@

             if self.should_checkpoint(state):
                 model = eqx.combine(model_arr, model_static)
-                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
+                self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)

         # After finishing training, save the final checkpoint.
         model = eqx.combine(model_arr, model_static)
-        self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
+        self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)

     @contextlib.contextmanager
     def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
@@ -841,14 +876,14 @@
         Thread(target=self.log_state, daemon=True).start()

         key, model_key = jax.random.split(key)
-        model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
-        logger.info("Model size: %s", f"{get_pytree_param_count(model):,}")
-        logger.info("Optimizer size: %s", f"{get_pytree_param_count(optimizer):,}")
+        models, optimizers, opt_states, state = self.load_initial_state(model_key, load_optimizer=True)
+        logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
+        logger.info("Optimizer size: %s", f"{get_pytree_param_count(optimizers):,}")

         state = self.on_training_start(state)

         def on_exit() -> None:
-            self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
+            self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)

         # Handle user-defined interrupts during the training loop.
         self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
@@ -857,9 +892,9 @@
         with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
             try:
                 self.train_loop(
-                    model=model,
-                    optimizer=optimizer,
-                    opt_state=opt_state,
+                    models=models,
+                    optimizers=optimizers,
+                    opt_states=opt_states,
                     train_pf=train_pf,
                     valid_pf=valid_pf,
                     state=state,
@@ -869,7 +904,7 @@
                 if is_master():
                     num_steps, num_samples = int(state.num_steps), int(state.num_samples)
                     show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
-                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
+                self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)

             except (KeyboardInterrupt, bdb.BdbQuit):
                 if is_master():
@@ -879,7 +914,7 @@
                 exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                 sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                 sys.stdout.flush()
-                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
+                self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)

             finally:
                 state = self.on_training_end(state)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.13
+Version: 0.2.14
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=33wIwGeXDFReg2ZnFqUHfSybj5cKyMqnI8ncj8-9yVg,15510
+xax/__init__.py,sha256=9RtfWhU2Qb-YlZurazJ_GKpSpPaGs0nwu9-nmzbC0Pk,15510
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -32,7 +32,7 @@ xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,66
 xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
 xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
 xax/task/mixins/artifacts.py,sha256=Ma7fwsp-SA1w6GcuBSskszj5TB83yxYJm4Ns_EnqkI4,3018
-xax/task/mixins/checkpointing.py,sha256=zqospBFnTbGt_iriiduVfXazINPbzWpwmIs91KAniMY,10147
+xax/task/mixins/checkpointing.py,sha256=ypdXvC6oJlsUGm4PiTJWXrtTi9w0K9IpoO0-8gM1hZ4,11295
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
 xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=_QoxSDMW6nmpH82Un2LDsVIBg9KIx8npRwSjY4TEGYA,31830
+xax/task/mixins/train.py,sha256=Fgx2SWGC0e1QtRv1iTXXNg45dzbCzur3UHvjRZfOoiM,33465
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=bj8BftSHT3fFzfiJ0Co0WvqWo0rUS8kQnQYpVvH8FTM,29942
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.2.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.2.13.dist-info/METADATA,sha256=-foHRw3ph7yxBmMmjO_oqZqwvdEROYTH4Drc9P58ujI,1880
-xax-0.2.13.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
-xax-0.2.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.2.13.dist-info/RECORD,,
+xax-0.2.14.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.2.14.dist-info/METADATA,sha256=5Agik9rI2VgDzlElHUVYpQm8JVOytLrcIxtNMCa_UUE,1880
+xax-0.2.14.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
+xax-0.2.14.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.2.14.dist-info/RECORD,,