xax 0.3.13.tar.gz → 0.3.15.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.13/xax.egg-info → xax-0.3.15}/PKG-INFO +1 -1
- {xax-0.3.13 → xax-0.3.15}/xax/__init__.py +9 -3
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/__init__.py +2 -1
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/artifacts.py +1 -1
- xax-0.3.15/xax/task/mixins/supervised.py +368 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/train.py +36 -345
- {xax-0.3.13 → xax-0.3.15}/xax/task/task.py +26 -2
- {xax-0.3.13 → xax-0.3.15/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.3.13 → xax-0.3.15}/xax.egg-info/SOURCES.txt +1 -0
- {xax-0.3.13 → xax-0.3.15}/LICENSE +0 -0
- {xax-0.3.13 → xax-0.3.15}/MANIFEST.in +0 -0
- {xax-0.3.13 → xax-0.3.15}/README.md +0 -0
- {xax-0.3.13 → xax-0.3.15}/pyproject.toml +0 -0
- {xax-0.3.13 → xax-0.3.15}/setup.cfg +0 -0
- {xax-0.3.13 → xax-0.3.15}/setup.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/cli/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/cli/edit_config.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/core/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/core/conf.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/core/state.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/attention.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/distributions.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/embeddings.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/functions.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/geom.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/losses.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/metrics.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/parallel.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/nn/ssm.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/py.typed +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/requirements-dev.txt +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/requirements.txt +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/base.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/launchers/base.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/logger.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/loggers/json.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/loggers/state.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/process.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/task/script.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/data/collate.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/debugging.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/experiments.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/jax.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/logging.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/numpy.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/profile.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/pytree.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/text.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.13 → xax-0.3.15}/xax.egg-info/top_level.txt +0 -0
{xax-0.3.13 → xax-0.3.15}/xax/__init__.py

@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.13"
+__version__ = "0.3.15"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [

@@ -94,10 +94,13 @@ __all__ = [
     "DataloaderConfig",
     "GPUStatsOptions",
     "StepContext",
+    "InitParams",
     "ValidStepTimer",
     "Script",
     "ScriptConfig",
     "Config",
+    "SupervisedConfig",
+    "SupervisedTask",
     "Task",
     "collate",
     "collate_non_null",

@@ -291,10 +294,13 @@ NAME_MAP: dict[str, str] = {
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
     "StepContext": "task.mixins.step_wrapper",
+    "InitParams": "task.mixins.train",
     "ValidStepTimer": "task.mixins.train",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
+    "SupervisedConfig": "task.task",
+    "SupervisedTask": "task.task",
     "Task": "task.task",
     "collate": "utils.data.collate",
     "collate_non_null": "utils.data.collate",

@@ -488,9 +494,9 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
     from xax.task.mixins.step_wrapper import StepContext
-    from xax.task.mixins.train import Batch, Output, ValidStepTimer
+    from xax.task.mixins.train import Batch, InitParams, Output, ValidStepTimer
     from xax.task.script import Script, ScriptConfig
-    from xax.task.task import Config, Task
+    from xax.task.task import Config, SupervisedConfig, SupervisedTask, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
     from xax.utils.debugging import (
         breakpoint_if_nonfinite,
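
Taken together, the `xax/__init__.py` changes register three new public names. As an illustration only (assuming the package's lazy top-level import machinery, suggested by `NAME_MAP` and the `IMPORT_ALL` guard, behaves as it did in 0.3.13), the new symbols can be pulled straight from the package root:

```python
import xax

# New in 0.3.15: these names resolve from the package root via NAME_MAP.
print(xax.__version__)       # "0.3.15"
print(xax.InitParams)        # mapped to xax.task.mixins.train
print(xax.SupervisedConfig)  # mapped to xax.task.task
print(xax.SupervisedTask)    # mapped to xax.task.task
```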

{xax-0.3.13 → xax-0.3.15}/xax/task/mixins/__init__.py

@@ -10,4 +10,5 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.process import ProcessConfig, ProcessMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
 from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
-from xax.task.mixins.
+from xax.task.mixins.supervised import SupervisedConfig, SupervisedMixin
+from xax.task.mixins.train import InitParams, TrainConfig, TrainMixin
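
The `+26 -2` change to `xax/task/task.py` is not shown in the hunks above, but given the new exports it presumably wires `SupervisedMixin` into the `SupervisedTask` and `SupervisedConfig` classes. A purely hypothetical sketch of that composition (the names, decorators, and structure here are assumptions, not taken from the diff):

```python
# Hypothetical sketch only; the real task.py change is not shown in this diff.
from dataclasses import dataclass
from typing import Generic, TypeVar

import jax

from xax.task.mixins import SupervisedConfig as SupervisedMixinConfig, SupervisedMixin


@jax.tree_util.register_dataclass
@dataclass
class SupervisedConfig(SupervisedMixinConfig):
    pass


ConfigT = TypeVar("ConfigT", bound=SupervisedConfig)


class SupervisedTask(SupervisedMixin[ConfigT], Generic[ConfigT]):
    pass
```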

{xax-0.3.13 → xax-0.3.15}/xax/task/mixins/artifacts.py

@@ -82,7 +82,7 @@ class ArtifactsMixin(BaseTask[Config]):
         return self._exp_dir
 
         def get_exp_dir(run_id: int) -> Path:
-            return self.run_dir / f"run_{run_id}"
+            return self.run_dir / f"run_{run_id:03d}"
 
         run_id = 0
         while (exp_dir := get_exp_dir(run_id)).is_dir():
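
The only change here is zero-padding the run index, so run directories sort lexicographically in the same order they were created. A quick standalone illustration of the two formats:

```python
from pathlib import Path

run_dir = Path("/tmp/exp")
print(sorted(f"run_{i}" for i in (0, 2, 10)))      # ['run_0', 'run_10', 'run_2']      (old, misordered)
print(sorted(f"run_{i:03d}" for i in (0, 2, 10)))  # ['run_000', 'run_002', 'run_010'] (new, ordered)
print(run_dir / f"run_{10:03d}")                   # /tmp/exp/run_010
```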

xax-0.3.15/xax/task/mixins/supervised.py (new file)

@@ -0,0 +1,368 @@
+"""Defines a mixin for running the training loop."""
+
+import bdb
+import contextlib
+import itertools
+import logging
+import signal
+import sys
+import textwrap
+import traceback
+from abc import ABC
+from dataclasses import dataclass
+from threading import Thread
+from typing import (
+    Generator,
+    Generic,
+    Iterator,
+    Sequence,
+    TypeVar,
+)
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import optax
+from jaxtyping import Array, PRNGKeyArray, PyTree
+
+from xax.core.conf import field
+from xax.core.state import State
+from xax.nn.parallel import is_master
+from xax.task.mixins.train import Batch, InitParams, Output, TrainConfig, TrainMixin
+from xax.utils.experiments import (
+    ContextTimer,
+    TrainingFinishedError,
+)
+from xax.utils.jax import jit as xax_jit, scan as xax_scan
+from xax.utils.logging import LOG_PING
+from xax.utils.pytree import get_pytree_param_count
+from xax.utils.text import highlight_exception_message, show_info
+from xax.utils.types.frozen_dict import FrozenDict
+
+logger = logging.getLogger(__name__)
+
+
+@jax.tree_util.register_dataclass
+@dataclass
+class SupervisedConfig(TrainConfig):
+    updates_per_step: int = field(1, help="Number of updates to perform per step")
+
+
+Config = TypeVar("Config", bound=SupervisedConfig)
+
+
+class SupervisedMixin(
+    TrainMixin[Config, InitParams],
+    Generic[Config],
+    ABC,
+):
+    def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
+        """Gets the output from the model.
+
+        By default, we assume the model is a function that takes the batch as
+        input and returns the loss. This function can be patched to do more
+        complex operations instead.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            state: The current training state.
+        """
+        raise NotImplementedError("`get_output` must be implemented by the subclass")
+
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
+        """Gets the loss for the current batch.
+
+        By default, we assume the model is a function that takes the batch as
+        input and returns the loss. This function can be patched to do more
+        complex operations instead.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            output: The output from the model.
+            state: The current training state.
+
+        Returns:
+            The computed loss, as a tensor.
+        """
+        if not isinstance(output, Array):
+            raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
+        return output
+
+    def compute_metrics(
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        loss: Array,
+        state: State,
+    ) -> dict[str, Array]:
+        """Computes the metrics for the current batch.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            output: The output from the model.
+            loss: The loss for the current batch.
+            state: The current training state.
+
+        Returns:
+            A dictionary of metrics.
+        """
+        return {
+            "loss": loss,
+        }
+
+    @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
+    def get_output_and_loss(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        batch: Batch,
+        state: State,
+    ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
+        model = eqx.combine(model_arr, model_static)
+        output = self.get_output(model, batch, state)
+        loss = self.compute_loss(model, batch, output, state)
+        metrics = self.compute_metrics(model, batch, output, loss, state)
+        return loss, (output, metrics)
+
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
+    def update(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        batch: Batch,
+        state: State,
+    ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
+        grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+        grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
+        grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
+        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+        model_arr = eqx.apply_updates(model_arr, updates)
+        return model_arr, opt_state, output, metrics
+
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
+    def train_step(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        batches: Batch,
+        state: State,
+    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+        def update_fn(
+            carry: tuple[PyTree, optax.OptState],
+            batch: Batch,
+        ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
+            model_arr, opt_state = carry
+            model_arr, opt_state, output, metrics = self.update(
+                model_arr,
+                model_static,
+                optimizer,
+                opt_state,
+                batch,
+                state,
+            )
+            return (model_arr, opt_state), (output, FrozenDict(metrics))
+
+        (model_arr, opt_state), (output, metrics) = xax_scan(
+            update_fn,
+            (model_arr, opt_state),
+            batches,
+            jit_level=3,
+        )
+
+        # Only get the final output and metrics.
+        output = jax.tree.map(lambda x: x[-1], output)
+        metrics = jax.tree.map(lambda x: x[-1], metrics)
+
+        return model_arr, opt_state, output, metrics
+
+    @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
+    def val_step(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        batch: Batch,
+        state: State,
+    ) -> tuple[Output, FrozenDict[str, Array]]:
+        _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+        return output, FrozenDict(metrics)
+
+    def train_loop(
+        self,
+        models: Sequence[PyTree],
+        optimizers: Sequence[optax.GradientTransformation],
+        opt_states: Sequence[optax.OptState],
+        train_pf: Iterator[Batch],
+        valid_pf: Iterator[Batch],
+        state: State,
+    ) -> None:
+        if len(models) != 1 or len(optimizers) != 1 or len(opt_states) != 1:
+            raise ValueError(
+                "Vanilla training expects a single model, optimizer and optimizer state. "
+                f"Found {len(models)} models, {len(optimizers)} optimizers and {len(opt_states)} optimizer states."
+            )
+
+        model_arr, model_static = eqx.partition(models[0], self.model_partition_fn)
+        optimizer = optimizers[0]
+        opt_state = opt_states[0]
+
+        while not self.is_training_over(state):
+            valid_step = self.valid_step_timer(state)
+
+            if valid_step:
+                with ContextTimer() as timer:
+                    state = state.replace(phase="valid")
+                    valid_batch = next(valid_pf)
+                    output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+                    self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
+
+                state = state.replace(
+                    num_steps=state.num_steps + 1,
+                    num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
+                    elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
+                )
+
+            with ContextTimer() as timer:
+                state = self.on_step_start(state)
+                state = state.replace(phase="train")
+                train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
+                model_arr, opt_state, output, metrics = self.train_step(
+                    model_arr=model_arr,
+                    model_static=model_static,
+                    optimizer=optimizer,
+                    opt_state=opt_state,
+                    batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
+                    state=state,
+                )
+                self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
+                state = self.on_step_end(state)
+
+            state = state.replace(
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
+                elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
+            )
+
+            if state.num_steps <= 3:
+                logger.log(LOG_PING, "Step %d took %.2f second", state.num_steps, timer.elapsed_time)
+
+            if self.should_checkpoint(state):
+                model = eqx.combine(model_arr, model_static)
+                self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
+
+        # After finishing training, save the final checkpoint.
+        model = eqx.combine(model_arr, model_static)
+        self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
+
+    @contextlib.contextmanager
+    def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
+        try:
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
+            yield train_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
+        train_pf = self.get_prefetcher(train_dl)
+
+        try:
+            with train_pf as train_pf_ctx:
+                yield train_pf_ctx
+        finally:
+            logger.info("Closing train prefetcher")
+
+    @contextlib.contextmanager
+    def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
+        try:
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
+            yield valid_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)
+
+        try:
+            with valid_pf as valid_pf_ctx:
+                yield valid_pf_ctx
+        finally:
+            logger.info("Closing valid prefetcher")
+
+    def run(self) -> None:
+        self.run_training()
+
+    def run_training(self) -> None:
+        """Runs the training loop.
+
+        Args:
+            model: The current model
+            task: The current task
+            optimizer: The current optimizer
+            lr_scheduler: The current learning rate scheduler
+
+        Raises:
+            ValueError: If the task is not a supervised learning task
+        """
+        with self:
+            key = self.prng_key()
+
+            self.set_loggers()
+
+            if is_master():
+                Thread(target=self.log_state, daemon=True).start()
+
+            key, model_key = jax.random.split(key)
+            init_params = InitParams(key=model_key)
+            models, optimizers, opt_states, state = self.load_initial_state(init_params, load_optimizer=True)
+            logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
+            logger.info("Optimizer size: %s", f"{get_pytree_param_count(opt_states):,}")
+
+            state = self.on_training_start(state)
+
+            def on_exit() -> None:
+                self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+            # Handle user-defined interrupts during the training loop.
+            self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
+
+            key, tkey, vkey = jax.random.split(key, 3)
+            with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
+                try:
+                    self.train_loop(
+                        models=models,
+                        optimizers=optimizers,
+                        opt_states=opt_states,
+                        train_pf=train_pf,
+                        valid_pf=valid_pf,
+                        state=state,
+                    )
+
+                except TrainingFinishedError:
+                    if is_master():
+                        num_steps, num_samples = int(state.num_steps), int(state.num_samples)
+                        show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
+                    self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+                except (KeyboardInterrupt, bdb.BdbQuit):
+                    if is_master():
+                        show_info("Interrupted training", important=True)
+
+                except BaseException:
+                    exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
+                    sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                    sys.stdout.flush()
+                    self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+                finally:
+                    state = self.on_training_end(state)
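
To make the mixin's contract concrete, here is a hedged usage sketch: a downstream task that only overrides `get_output` and `compute_loss` (both signatures come from the diff above). The `xax.SupervisedTask[Config]` base class usage, the config subclass, the batch layout, and the regression model are illustrative assumptions; the remaining `TrainMixin` hooks (model construction, optimizer, data loading) live in `train.py` and are not shown in this diff.

```python
# Illustrative sketch, not taken from the xax source: a minimal supervised task
# that overrides the two hooks defined by SupervisedMixin above. Everything
# else (model/optimizer/data hooks inherited from TrainMixin) is omitted.
from dataclasses import dataclass

import jax.numpy as jnp
from jaxtyping import Array, PyTree

import xax
from xax.core.state import State


@dataclass
class Config(xax.SupervisedConfig):
    # updates_per_step is inherited from SupervisedConfig (default 1).
    pass


class RegressionTask(xax.SupervisedTask[Config]):
    def get_output(self, model: PyTree, batch: tuple[Array, Array], state: State) -> Array:
        # The (features, targets) batch layout is an assumption for this sketch.
        features, _ = batch
        return model(features)

    def compute_loss(self, model: PyTree, batch: tuple[Array, Array], output: Array, state: State) -> Array:
        # Mean-squared error against the targets; reported as "loss" by compute_metrics.
        _, targets = batch
        return jnp.mean((output - targets) ** 2)
```

Note how `updates_per_step` interacts with the loop above: `train_loop` pulls that many batches from the prefetcher, stacks them along a leading axis, and `train_step` scans `update` over the stack, logging only the final output and metrics.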