xax-0.3.14.tar.gz → xax-0.3.15.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {xax-0.3.14/xax.egg-info → xax-0.3.15}/PKG-INFO +1 -1
  2. {xax-0.3.14 → xax-0.3.15}/xax/__init__.py +9 -3
  3. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/__init__.py +2 -1
  4. xax-0.3.15/xax/task/mixins/supervised.py +368 -0
  5. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/train.py +36 -345
  6. {xax-0.3.14 → xax-0.3.15}/xax/task/task.py +26 -2
  7. {xax-0.3.14 → xax-0.3.15/xax.egg-info}/PKG-INFO +1 -1
  8. {xax-0.3.14 → xax-0.3.15}/xax.egg-info/SOURCES.txt +1 -0
  9. {xax-0.3.14 → xax-0.3.15}/LICENSE +0 -0
  10. {xax-0.3.14 → xax-0.3.15}/MANIFEST.in +0 -0
  11. {xax-0.3.14 → xax-0.3.15}/README.md +0 -0
  12. {xax-0.3.14 → xax-0.3.15}/pyproject.toml +0 -0
  13. {xax-0.3.14 → xax-0.3.15}/setup.cfg +0 -0
  14. {xax-0.3.14 → xax-0.3.15}/setup.py +0 -0
  15. {xax-0.3.14 → xax-0.3.15}/xax/cli/__init__.py +0 -0
  16. {xax-0.3.14 → xax-0.3.15}/xax/cli/edit_config.py +0 -0
  17. {xax-0.3.14 → xax-0.3.15}/xax/core/__init__.py +0 -0
  18. {xax-0.3.14 → xax-0.3.15}/xax/core/conf.py +0 -0
  19. {xax-0.3.14 → xax-0.3.15}/xax/core/state.py +0 -0
  20. {xax-0.3.14 → xax-0.3.15}/xax/nn/__init__.py +0 -0
  21. {xax-0.3.14 → xax-0.3.15}/xax/nn/attention.py +0 -0
  22. {xax-0.3.14 → xax-0.3.15}/xax/nn/distributions.py +0 -0
  23. {xax-0.3.14 → xax-0.3.15}/xax/nn/embeddings.py +0 -0
  24. {xax-0.3.14 → xax-0.3.15}/xax/nn/functions.py +0 -0
  25. {xax-0.3.14 → xax-0.3.15}/xax/nn/geom.py +0 -0
  26. {xax-0.3.14 → xax-0.3.15}/xax/nn/losses.py +0 -0
  27. {xax-0.3.14 → xax-0.3.15}/xax/nn/metrics.py +0 -0
  28. {xax-0.3.14 → xax-0.3.15}/xax/nn/parallel.py +0 -0
  29. {xax-0.3.14 → xax-0.3.15}/xax/nn/ssm.py +0 -0
  30. {xax-0.3.14 → xax-0.3.15}/xax/py.typed +0 -0
  31. {xax-0.3.14 → xax-0.3.15}/xax/requirements-dev.txt +0 -0
  32. {xax-0.3.14 → xax-0.3.15}/xax/requirements.txt +0 -0
  33. {xax-0.3.14 → xax-0.3.15}/xax/task/__init__.py +0 -0
  34. {xax-0.3.14 → xax-0.3.15}/xax/task/base.py +0 -0
  35. {xax-0.3.14 → xax-0.3.15}/xax/task/launchers/__init__.py +0 -0
  36. {xax-0.3.14 → xax-0.3.15}/xax/task/launchers/base.py +0 -0
  37. {xax-0.3.14 → xax-0.3.15}/xax/task/launchers/cli.py +0 -0
  38. {xax-0.3.14 → xax-0.3.15}/xax/task/launchers/single_process.py +0 -0
  39. {xax-0.3.14 → xax-0.3.15}/xax/task/logger.py +0 -0
  40. {xax-0.3.14 → xax-0.3.15}/xax/task/loggers/__init__.py +0 -0
  41. {xax-0.3.14 → xax-0.3.15}/xax/task/loggers/callback.py +0 -0
  42. {xax-0.3.14 → xax-0.3.15}/xax/task/loggers/json.py +0 -0
  43. {xax-0.3.14 → xax-0.3.15}/xax/task/loggers/state.py +0 -0
  44. {xax-0.3.14 → xax-0.3.15}/xax/task/loggers/stdout.py +0 -0
  45. {xax-0.3.14 → xax-0.3.15}/xax/task/loggers/tensorboard.py +0 -0
  46. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/artifacts.py +0 -0
  47. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/checkpointing.py +0 -0
  48. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/compile.py +0 -0
  49. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/cpu_stats.py +0 -0
  50. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/data_loader.py +0 -0
  51. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/gpu_stats.py +0 -0
  52. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/logger.py +0 -0
  53. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/process.py +0 -0
  54. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/runnable.py +0 -0
  55. {xax-0.3.14 → xax-0.3.15}/xax/task/mixins/step_wrapper.py +0 -0
  56. {xax-0.3.14 → xax-0.3.15}/xax/task/script.py +0 -0
  57. {xax-0.3.14 → xax-0.3.15}/xax/utils/__init__.py +0 -0
  58. {xax-0.3.14 → xax-0.3.15}/xax/utils/data/__init__.py +0 -0
  59. {xax-0.3.14 → xax-0.3.15}/xax/utils/data/collate.py +0 -0
  60. {xax-0.3.14 → xax-0.3.15}/xax/utils/debugging.py +0 -0
  61. {xax-0.3.14 → xax-0.3.15}/xax/utils/experiments.py +0 -0
  62. {xax-0.3.14 → xax-0.3.15}/xax/utils/jax.py +0 -0
  63. {xax-0.3.14 → xax-0.3.15}/xax/utils/jaxpr.py +0 -0
  64. {xax-0.3.14 → xax-0.3.15}/xax/utils/logging.py +0 -0
  65. {xax-0.3.14 → xax-0.3.15}/xax/utils/numpy.py +0 -0
  66. {xax-0.3.14 → xax-0.3.15}/xax/utils/profile.py +0 -0
  67. {xax-0.3.14 → xax-0.3.15}/xax/utils/pytree.py +0 -0
  68. {xax-0.3.14 → xax-0.3.15}/xax/utils/tensorboard.py +0 -0
  69. {xax-0.3.14 → xax-0.3.15}/xax/utils/text.py +0 -0
  70. {xax-0.3.14 → xax-0.3.15}/xax/utils/types/__init__.py +0 -0
  71. {xax-0.3.14 → xax-0.3.15}/xax/utils/types/frozen_dict.py +0 -0
  72. {xax-0.3.14 → xax-0.3.15}/xax/utils/types/hashable_array.py +0 -0
  73. {xax-0.3.14 → xax-0.3.15}/xax.egg-info/dependency_links.txt +0 -0
  74. {xax-0.3.14 → xax-0.3.15}/xax.egg-info/entry_points.txt +0 -0
  75. {xax-0.3.14 → xax-0.3.15}/xax.egg-info/requires.txt +0 -0
  76. {xax-0.3.14 → xax-0.3.15}/xax.egg-info/top_level.txt +0 -0
{xax-0.3.14/xax.egg-info → xax-0.3.15}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: xax
- Version: 0.3.14
+ Version: 0.3.15
  Summary: A library for fast Jax experimentation
  Home-page: https://github.com/kscalelabs/xax
  Author: Benjamin Bolte
{xax-0.3.14 → xax-0.3.15}/xax/__init__.py
@@ -12,7 +12,7 @@ and running the update script:
  python -m scripts.update_api --inplace
  """

- __version__ = "0.3.14"
+ __version__ = "0.3.15"

  # This list shouldn't be modified by hand; instead, run the update script.
  __all__ = [
@@ -94,10 +94,13 @@ __all__ = [
  "DataloaderConfig",
  "GPUStatsOptions",
  "StepContext",
+ "InitParams",
  "ValidStepTimer",
  "Script",
  "ScriptConfig",
  "Config",
+ "SupervisedConfig",
+ "SupervisedTask",
  "Task",
  "collate",
  "collate_non_null",
@@ -291,10 +294,13 @@ NAME_MAP: dict[str, str] = {
  "DataloaderConfig": "task.mixins.data_loader",
  "GPUStatsOptions": "task.mixins.gpu_stats",
  "StepContext": "task.mixins.step_wrapper",
+ "InitParams": "task.mixins.train",
  "ValidStepTimer": "task.mixins.train",
  "Script": "task.script",
  "ScriptConfig": "task.script",
  "Config": "task.task",
+ "SupervisedConfig": "task.task",
+ "SupervisedTask": "task.task",
  "Task": "task.task",
  "collate": "utils.data.collate",
  "collate_non_null": "utils.data.collate",
@@ -488,9 +494,9 @@ if IMPORT_ALL or TYPE_CHECKING:
  from xax.task.mixins.data_loader import DataloaderConfig
  from xax.task.mixins.gpu_stats import GPUStatsOptions
  from xax.task.mixins.step_wrapper import StepContext
- from xax.task.mixins.train import Batch, Output, ValidStepTimer
+ from xax.task.mixins.train import Batch, InitParams, Output, ValidStepTimer
  from xax.task.script import Script, ScriptConfig
- from xax.task.task import Config, Task
+ from xax.task.task import Config, SupervisedConfig, SupervisedTask, Task
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
  from xax.utils.debugging import (
  breakpoint_if_nonfinite,
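
Taken together, the three `xax/__init__.py` hunks register the new names in `__all__`, map them to their defining modules in `NAME_MAP`, and import them eagerly only under `IMPORT_ALL` or `TYPE_CHECKING`. A minimal sketch of what this gives downstream code, assuming the lazy attribute resolution that `NAME_MAP` drives in the existing `xax/__init__.py` (the sketch itself is not part of this diff):

    import xax

    # Resolved on first access via NAME_MAP ("SupervisedTask" -> "task.task"),
    # without importing the whole library up front.
    task_cls = xax.SupervisedTask
    config_cls = xax.SupervisedConfig
    init_params_cls = xax.InitParams  # maps to "task.mixins.train"
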
{xax-0.3.14 → xax-0.3.15}/xax/task/mixins/__init__.py
@@ -10,4 +10,5 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
  from xax.task.mixins.process import ProcessConfig, ProcessMixin
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
  from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
- from xax.task.mixins.train import TrainConfig, TrainMixin
+ from xax.task.mixins.supervised import SupervisedConfig, SupervisedMixin
+ from xax.task.mixins.train import InitParams, TrainConfig, TrainMixin
xax-0.3.15/xax/task/mixins/supervised.py (new file)
@@ -0,0 +1,368 @@
+ """Defines a mixin for running the training loop."""
+
+ import bdb
+ import contextlib
+ import itertools
+ import logging
+ import signal
+ import sys
+ import textwrap
+ import traceback
+ from abc import ABC
+ from dataclasses import dataclass
+ from threading import Thread
+ from typing import (
+     Generator,
+     Generic,
+     Iterator,
+     Sequence,
+     TypeVar,
+ )
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import optax
+ from jaxtyping import Array, PRNGKeyArray, PyTree
+
+ from xax.core.conf import field
+ from xax.core.state import State
+ from xax.nn.parallel import is_master
+ from xax.task.mixins.train import Batch, InitParams, Output, TrainConfig, TrainMixin
+ from xax.utils.experiments import (
+     ContextTimer,
+     TrainingFinishedError,
+ )
+ from xax.utils.jax import jit as xax_jit, scan as xax_scan
+ from xax.utils.logging import LOG_PING
+ from xax.utils.pytree import get_pytree_param_count
+ from xax.utils.text import highlight_exception_message, show_info
+ from xax.utils.types.frozen_dict import FrozenDict
+
+ logger = logging.getLogger(__name__)
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class SupervisedConfig(TrainConfig):
+     updates_per_step: int = field(1, help="Number of updates to perform per step")
+
+
+ Config = TypeVar("Config", bound=SupervisedConfig)
+
+
+ class SupervisedMixin(
+     TrainMixin[Config, InitParams],
+     Generic[Config],
+     ABC,
+ ):
+     def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
+         """Gets the output from the model.
+
+         By default, we assume the model is a function that takes the batch as
+         input and returns the loss. This function can be patched to do more
+         complex operations instead.
+
+         Args:
+             model: The current model.
+             batch: The current minibatch of samples.
+             state: The current training state.
+         """
+         raise NotImplementedError("`get_output` must be implemented by the subclass")
+
+     def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
+         """Gets the loss for the current batch.
+
+         By default, we assume the model is a function that takes the batch as
+         input and returns the loss. This function can be patched to do more
+         complex operations instead.
+
+         Args:
+             model: The current model.
+             batch: The current minibatch of samples.
+             output: The output from the model.
+             state: The current training state.
+
+         Returns:
+             The computed loss, as a tensor.
+         """
+         if not isinstance(output, Array):
+             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
+         return output
+
+     def compute_metrics(
+         self,
+         model: PyTree,
+         batch: Batch,
+         output: Output,
+         loss: Array,
+         state: State,
+     ) -> dict[str, Array]:
+         """Computes the metrics for the current batch.
+
+         Args:
+             model: The current model.
+             batch: The current minibatch of samples.
+             output: The output from the model.
+             loss: The loss for the current batch.
+             state: The current training state.
+
+         Returns:
+             A dictionary of metrics.
+         """
+         return {
+             "loss": loss,
+         }
+
+     @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
+     def get_output_and_loss(
+         self,
+         model_arr: PyTree,
+         model_static: PyTree,
+         batch: Batch,
+         state: State,
+     ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
+         model = eqx.combine(model_arr, model_static)
+         output = self.get_output(model, batch, state)
+         loss = self.compute_loss(model, batch, output, state)
+         metrics = self.compute_metrics(model, batch, output, loss, state)
+         return loss, (output, metrics)
+
+     @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
+     def update(
+         self,
+         model_arr: PyTree,
+         model_static: PyTree,
+         optimizer: optax.GradientTransformation,
+         opt_state: optax.OptState,
+         batch: Batch,
+         state: State,
+     ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
+         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+         grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
+         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
+         updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+         model_arr = eqx.apply_updates(model_arr, updates)
+         return model_arr, opt_state, output, metrics
+
+     @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
+     def train_step(
+         self,
+         model_arr: PyTree,
+         model_static: PyTree,
+         optimizer: optax.GradientTransformation,
+         opt_state: optax.OptState,
+         batches: Batch,
+         state: State,
+     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+         def update_fn(
+             carry: tuple[PyTree, optax.OptState],
+             batch: Batch,
+         ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
+             model_arr, opt_state = carry
+             model_arr, opt_state, output, metrics = self.update(
+                 model_arr,
+                 model_static,
+                 optimizer,
+                 opt_state,
+                 batch,
+                 state,
+             )
+             return (model_arr, opt_state), (output, FrozenDict(metrics))
+
+         (model_arr, opt_state), (output, metrics) = xax_scan(
+             update_fn,
+             (model_arr, opt_state),
+             batches,
+             jit_level=3,
+         )
+
+         # Only get the final output and metrics.
+         output = jax.tree.map(lambda x: x[-1], output)
+         metrics = jax.tree.map(lambda x: x[-1], metrics)
+
+         return model_arr, opt_state, output, metrics
+
+     @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
+     def val_step(
+         self,
+         model_arr: PyTree,
+         model_static: PyTree,
+         batch: Batch,
+         state: State,
+     ) -> tuple[Output, FrozenDict[str, Array]]:
+         _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+         return output, FrozenDict(metrics)
+
+     def train_loop(
+         self,
+         models: Sequence[PyTree],
+         optimizers: Sequence[optax.GradientTransformation],
+         opt_states: Sequence[optax.OptState],
+         train_pf: Iterator[Batch],
+         valid_pf: Iterator[Batch],
+         state: State,
+     ) -> None:
+         if len(models) != 1 or len(optimizers) != 1 or len(opt_states) != 1:
+             raise ValueError(
+                 "Vanilla training expects a single model, optimizer and optimizer state. "
+                 f"Found {len(models)} models, {len(optimizers)} optimizers and {len(opt_states)} optimizer states."
+             )
+
+         model_arr, model_static = eqx.partition(models[0], self.model_partition_fn)
+         optimizer = optimizers[0]
+         opt_state = opt_states[0]
+
+         while not self.is_training_over(state):
+             valid_step = self.valid_step_timer(state)
+
+             if valid_step:
+                 with ContextTimer() as timer:
+                     state = state.replace(phase="valid")
+                     valid_batch = next(valid_pf)
+                     output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+                     self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
+
+                 state = state.replace(
+                     num_steps=state.num_steps + 1,
+                     num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
+                     elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
+                 )
+
+             with ContextTimer() as timer:
+                 state = self.on_step_start(state)
+                 state = state.replace(phase="train")
+                 train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
+                 model_arr, opt_state, output, metrics = self.train_step(
+                     model_arr=model_arr,
+                     model_static=model_static,
+                     optimizer=optimizer,
+                     opt_state=opt_state,
+                     batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
+                     state=state,
+                 )
+                 self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
+                 state = self.on_step_end(state)
+
+             state = state.replace(
+                 num_steps=state.num_steps + 1,
+                 num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
+                 elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
+             )
+
+             if state.num_steps <= 3:
+                 logger.log(LOG_PING, "Step %d took %.2f second", state.num_steps, timer.elapsed_time)
+
+             if self.should_checkpoint(state):
+                 model = eqx.combine(model_arr, model_static)
+                 self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
+
+         # After finishing training, save the final checkpoint.
+         model = eqx.combine(model_arr, model_static)
+         self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
+
+     @contextlib.contextmanager
+     def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
+         try:
+             train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
+             yield train_iterator
+             return
+         except NotImplementedError:
+             pass
+
+         train_ds = self.get_dataset("train")
+         train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
+         train_pf = self.get_prefetcher(train_dl)
+
+         try:
+             with train_pf as train_pf_ctx:
+                 yield train_pf_ctx
+         finally:
+             logger.info("Closing train prefetcher")
+
+     @contextlib.contextmanager
+     def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
+         try:
+             valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
+             yield valid_iterator
+             return
+         except NotImplementedError:
+             pass
+
+         valid_ds = self.get_dataset("valid")
+         valid_dl = self.get_dataloader(valid_ds, "valid")
+         valid_pf = self.get_prefetcher(valid_dl)
+
+         try:
+             with valid_pf as valid_pf_ctx:
+                 yield valid_pf_ctx
+         finally:
+             logger.info("Closing valid prefetcher")
+
+     def run(self) -> None:
+         self.run_training()
+
+     def run_training(self) -> None:
+         """Runs the training loop.
+
+         Args:
+             model: The current model
+             task: The current task
+             optimizer: The current optimizer
+             lr_scheduler: The current learning rate scheduler
+
+         Raises:
+             ValueError: If the task is not a supervised learning task
+         """
+         with self:
+             key = self.prng_key()
+
+             self.set_loggers()
+
+             if is_master():
+                 Thread(target=self.log_state, daemon=True).start()
+
+             key, model_key = jax.random.split(key)
+             init_params = InitParams(key=model_key)
+             models, optimizers, opt_states, state = self.load_initial_state(init_params, load_optimizer=True)
+             logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
+             logger.info("Optimizer size: %s", f"{get_pytree_param_count(opt_states):,}")
+
+             state = self.on_training_start(state)
+
+             def on_exit() -> None:
+                 self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+             # Handle user-defined interrupts during the training loop.
+             self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
+
+             key, tkey, vkey = jax.random.split(key, 3)
+             with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
+                 try:
+                     self.train_loop(
+                         models=models,
+                         optimizers=optimizers,
+                         opt_states=opt_states,
+                         train_pf=train_pf,
+                         valid_pf=valid_pf,
+                         state=state,
+                     )
+
+                 except TrainingFinishedError:
+                     if is_master():
+                         num_steps, num_samples = int(state.num_steps), int(state.num_samples)
+                         show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
+                     self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+                 except (KeyboardInterrupt, bdb.BdbQuit):
+                     if is_master():
+                         show_info("Interrupted training", important=True)
+
+                 except BaseException:
+                     exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
+                     sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                     sys.stdout.flush()
+                     self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+                 finally:
+                     state = self.on_training_end(state)
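
For orientation, here is a minimal sketch of how the new supervised API might be used, based only on the hooks visible above. The `MyTask` name, the `(inputs, targets)` batch layout, the mean-squared-error loss, and the `SupervisedTask[...]` parameterization (inferred from `SupervisedMixin[Config]`) are illustrative assumptions rather than part of the release:

    import jax.numpy as jnp
    from jaxtyping import Array, PyTree

    import xax
    from xax.core.state import State


    class MyTask(xax.SupervisedTask[xax.SupervisedConfig]):
        def get_output(self, model: PyTree, batch: tuple[Array, Array], state: State) -> Array:
            # Assumed batch layout: return the model's predictions for the inputs.
            inputs, _ = batch
            return model(inputs)

        def compute_loss(self, model: PyTree, batch: tuple[Array, Array], output: Array, state: State) -> Array:
            # Since get_output returns predictions rather than the loss itself,
            # compute_loss must be overridden (see the ValueError above).
            _, targets = batch
            return jnp.mean((output - targets) ** 2)

A real task would typically also define its own config subclass following the `SupervisedConfig` pattern above; the surrounding wiring (model construction, datasets, optimizer) comes from the other task mixins and is not shown in this excerpt.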