xax-0.3.14-py3-none-any.whl → xax-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +9 -3
- xax/task/mixins/__init__.py +2 -1
- xax/task/mixins/supervised.py +368 -0
- xax/task/mixins/train.py +36 -345
- xax/task/task.py +26 -2
- {xax-0.3.14.dist-info → xax-0.4.0.dist-info}/METADATA +1 -1
- {xax-0.3.14.dist-info → xax-0.4.0.dist-info}/RECORD +11 -10
- {xax-0.3.14.dist-info → xax-0.4.0.dist-info}/WHEEL +0 -0
- {xax-0.3.14.dist-info → xax-0.4.0.dist-info}/entry_points.txt +0 -0
- {xax-0.3.14.dist-info → xax-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.14.dist-info → xax-0.4.0.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.14"
+__version__ = "0.4.0"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -94,10 +94,13 @@ __all__ = [
     "DataloaderConfig",
     "GPUStatsOptions",
     "StepContext",
+    "InitParams",
     "ValidStepTimer",
     "Script",
     "ScriptConfig",
     "Config",
+    "SupervisedConfig",
+    "SupervisedTask",
     "Task",
     "collate",
     "collate_non_null",
@@ -291,10 +294,13 @@ NAME_MAP: dict[str, str] = {
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
     "StepContext": "task.mixins.step_wrapper",
+    "InitParams": "task.mixins.train",
     "ValidStepTimer": "task.mixins.train",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
+    "SupervisedConfig": "task.task",
+    "SupervisedTask": "task.task",
     "Task": "task.task",
     "collate": "utils.data.collate",
     "collate_non_null": "utils.data.collate",
@@ -488,9 +494,9 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
     from xax.task.mixins.step_wrapper import StepContext
-    from xax.task.mixins.train import Batch, Output, ValidStepTimer
+    from xax.task.mixins.train import Batch, InitParams, Output, ValidStepTimer
     from xax.task.script import Script, ScriptConfig
-    from xax.task.task import Config, Task
+    from xax.task.task import Config, SupervisedConfig, SupervisedTask, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
     from xax.utils.debugging import (
         breakpoint_if_nonfinite,
xax/task/mixins/__init__.py
CHANGED
@@ -10,4 +10,5 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.process import ProcessConfig, ProcessMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
 from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
-from xax.task.mixins.
+from xax.task.mixins.supervised import SupervisedConfig, SupervisedMixin
+from xax.task.mixins.train import InitParams, TrainConfig, TrainMixin
xax/task/mixins/supervised.py
ADDED
@@ -0,0 +1,368 @@
+"""Defines a mixin for running the training loop."""
+
+import bdb
+import contextlib
+import itertools
+import logging
+import signal
+import sys
+import textwrap
+import traceback
+from abc import ABC
+from dataclasses import dataclass
+from threading import Thread
+from typing import (
+    Generator,
+    Generic,
+    Iterator,
+    Sequence,
+    TypeVar,
+)
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import optax
+from jaxtyping import Array, PRNGKeyArray, PyTree
+
+from xax.core.conf import field
+from xax.core.state import State
+from xax.nn.parallel import is_master
+from xax.task.mixins.train import Batch, InitParams, Output, TrainConfig, TrainMixin
+from xax.utils.experiments import (
+    ContextTimer,
+    TrainingFinishedError,
+)
+from xax.utils.jax import jit as xax_jit, scan as xax_scan
+from xax.utils.logging import LOG_PING
+from xax.utils.pytree import get_pytree_param_count
+from xax.utils.text import highlight_exception_message, show_info
+from xax.utils.types.frozen_dict import FrozenDict
+
+logger = logging.getLogger(__name__)
+
+
+@jax.tree_util.register_dataclass
+@dataclass
+class SupervisedConfig(TrainConfig):
+    updates_per_step: int = field(1, help="Number of updates to perform per step")
+
+
+Config = TypeVar("Config", bound=SupervisedConfig)
+
+
+class SupervisedMixin(
+    TrainMixin[Config, InitParams],
+    Generic[Config],
+    ABC,
+):
+    def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
+        """Gets the output from the model.
+
+        By default, we assume the model is a function that takes the batch as
+        input and returns the loss. This function can be patched to do more
+        complex operations instead.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            state: The current training state.
+        """
+        raise NotImplementedError("`get_output` must be implemented by the subclass")
+
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
+        """Gets the loss for the current batch.
+
+        By default, we assume the model is a function that takes the batch as
+        input and returns the loss. This function can be patched to do more
+        complex operations instead.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            output: The output from the model.
+            state: The current training state.
+
+        Returns:
+            The computed loss, as a tensor.
+        """
+        if not isinstance(output, Array):
+            raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
+        return output
+
+    def compute_metrics(
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        loss: Array,
+        state: State,
+    ) -> dict[str, Array]:
+        """Computes the metrics for the current batch.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            output: The output from the model.
+            loss: The loss for the current batch.
+            state: The current training state.
+
+        Returns:
+            A dictionary of metrics.
+        """
+        return {
+            "loss": loss,
+        }
+
+    @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
+    def get_output_and_loss(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        batch: Batch,
+        state: State,
+    ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
+        model = eqx.combine(model_arr, model_static)
+        output = self.get_output(model, batch, state)
+        loss = self.compute_loss(model, batch, output, state)
+        metrics = self.compute_metrics(model, batch, output, loss, state)
+        return loss, (output, metrics)
+
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
+    def update(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        batch: Batch,
+        state: State,
+    ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
+        grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+        grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
+        grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
+        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+        model_arr = eqx.apply_updates(model_arr, updates)
+        return model_arr, opt_state, output, metrics
+
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
+    def train_step(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        batches: Batch,
+        state: State,
+    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+        def update_fn(
+            carry: tuple[PyTree, optax.OptState],
+            batch: Batch,
+        ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
+            model_arr, opt_state = carry
+            model_arr, opt_state, output, metrics = self.update(
+                model_arr,
+                model_static,
+                optimizer,
+                opt_state,
+                batch,
+                state,
+            )
+            return (model_arr, opt_state), (output, FrozenDict(metrics))
+
+        (model_arr, opt_state), (output, metrics) = xax_scan(
+            update_fn,
+            (model_arr, opt_state),
+            batches,
+            jit_level=3,
+        )
+
+        # Only get the final output and metrics.
+        output = jax.tree.map(lambda x: x[-1], output)
+        metrics = jax.tree.map(lambda x: x[-1], metrics)
+
+        return model_arr, opt_state, output, metrics
+
+    @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
+    def val_step(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        batch: Batch,
+        state: State,
+    ) -> tuple[Output, FrozenDict[str, Array]]:
+        _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+        return output, FrozenDict(metrics)
+
+    def train_loop(
+        self,
+        models: Sequence[PyTree],
+        optimizers: Sequence[optax.GradientTransformation],
+        opt_states: Sequence[optax.OptState],
+        train_pf: Iterator[Batch],
+        valid_pf: Iterator[Batch],
+        state: State,
+    ) -> None:
+        if len(models) != 1 or len(optimizers) != 1 or len(opt_states) != 1:
+            raise ValueError(
+                "Vanilla training expects a single model, optimizer and optimizer state. "
+                f"Found {len(models)} models, {len(optimizers)} optimizers and {len(opt_states)} optimizer states."
+            )
+
+        model_arr, model_static = eqx.partition(models[0], self.model_partition_fn)
+        optimizer = optimizers[0]
+        opt_state = opt_states[0]
+
+        while not self.is_training_over(state):
+            valid_step = self.valid_step_timer(state)
+
+            if valid_step:
+                with ContextTimer() as timer:
+                    state = state.replace(phase="valid")
+                    valid_batch = next(valid_pf)
+                    output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+                    self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
+
+                state = state.replace(
+                    num_steps=state.num_steps + 1,
+                    num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
+                    elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
+                )
+
+            with ContextTimer() as timer:
+                state = self.on_step_start(state)
+                state = state.replace(phase="train")
+                train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
+                model_arr, opt_state, output, metrics = self.train_step(
+                    model_arr=model_arr,
+                    model_static=model_static,
+                    optimizer=optimizer,
+                    opt_state=opt_state,
+                    batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
+                    state=state,
+                )
+                self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
+                state = self.on_step_end(state)
+
+            state = state.replace(
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
+                elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
+            )
+
+            if state.num_steps <= 3:
+                logger.log(LOG_PING, "Step %d took %.2f second", state.num_steps, timer.elapsed_time)
+
+            if self.should_checkpoint(state):
+                model = eqx.combine(model_arr, model_static)
+                self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
+
+        # After finishing training, save the final checkpoint.
+        model = eqx.combine(model_arr, model_static)
+        self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
+
+    @contextlib.contextmanager
+    def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
+        try:
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
+            yield train_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
+        train_pf = self.get_prefetcher(train_dl)
+
+        try:
+            with train_pf as train_pf_ctx:
+                yield train_pf_ctx
+        finally:
+            logger.info("Closing train prefetcher")
+
+    @contextlib.contextmanager
+    def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
+        try:
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
+            yield valid_iterator
+            return
+        except NotImplementedError:
+            pass
+
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)
+
+        try:
+            with valid_pf as valid_pf_ctx:
+                yield valid_pf_ctx
+        finally:
+            logger.info("Closing valid prefetcher")
+
+    def run(self) -> None:
+        self.run_training()
+
+    def run_training(self) -> None:
+        """Runs the training loop.
+
+        Args:
+            model: The current model
+            task: The current task
+            optimizer: The current optimizer
+            lr_scheduler: The current learning rate scheduler
+
+        Raises:
+            ValueError: If the task is not a supervised learning task
+        """
+        with self:
+            key = self.prng_key()
+
+            self.set_loggers()
+
+            if is_master():
+                Thread(target=self.log_state, daemon=True).start()
+
+            key, model_key = jax.random.split(key)
+            init_params = InitParams(key=model_key)
+            models, optimizers, opt_states, state = self.load_initial_state(init_params, load_optimizer=True)
+            logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
+            logger.info("Optimizer size: %s", f"{get_pytree_param_count(opt_states):,}")
+
+            state = self.on_training_start(state)
+
+            def on_exit() -> None:
+                self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+            # Handle user-defined interrupts during the training loop.
+            self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
+
+            key, tkey, vkey = jax.random.split(key, 3)
+            with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
+                try:
+                    self.train_loop(
+                        models=models,
+                        optimizers=optimizers,
+                        opt_states=opt_states,
+                        train_pf=train_pf,
+                        valid_pf=valid_pf,
+                        state=state,
+                    )
+
+                except TrainingFinishedError:
+                    if is_master():
+                        num_steps, num_samples = int(state.num_steps), int(state.num_samples)
+                        show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
+                    self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+                except (KeyboardInterrupt, bdb.BdbQuit):
+                    if is_master():
+                        show_info("Interrupted training", important=True)
+
+                except BaseException:
+                    exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
+                    sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
+                    sys.stdout.flush()
+                    self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
+
+                finally:
+                    state = self.on_training_end(state)
xax/task/mixins/train.py
CHANGED
@@ -1,24 +1,15 @@
 """Defines a mixin for running the training loop."""
 
-import bdb
-import contextlib
 import functools
 import itertools
 import logging
-import signal
-import sys
-import textwrap
 import time
-import traceback
 from abc import ABC, abstractmethod
 from dataclasses import asdict, dataclass, is_dataclass
 from pathlib import Path
-from threading import Thread
 from typing import (
     Any,
-    Generator,
     Generic,
-    Iterator,
     Literal,
     Mapping,
     Sequence,
@@ -30,7 +21,6 @@ from typing import (
 
 import equinox as eqx
 import jax
-import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
@@ -38,7 +28,6 @@ from jaxtyping import Array, PRNGKeyArray, PyTree
 from xax.core.conf import field
 from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
-from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 from xax.task.mixins.checkpointing import (
     CheckpointingConfig,
@@ -51,19 +40,14 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
 from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
 from xax.utils.experiments import (
-    ContextTimer,
     StateTimer,
-    TrainingFinishedError,
     diff_configs,
     get_diff_string,
     get_info_json,
     get_state_file_string,
     get_training_code,
 )
-from xax.utils.jax import jit as xax_jit, scan as xax_scan
 from xax.utils.logging import LOG_PING, LOG_STATUS
-from xax.utils.pytree import get_pytree_param_count
-from xax.utils.text import highlight_exception_message, show_info
 from xax.utils.types.frozen_dict import FrozenDict
 
 logger = logging.getLogger(__name__)
@@ -159,6 +143,16 @@ class ValidStepTimer:
         return False
 
 
+@jax.tree_util.register_dataclass
+@dataclass(frozen=True)
+class InitParams:
+    key: PRNGKeyArray
+
+
+# Subclasses should be able to override the init params.
+InitParamsT = TypeVar("InitParamsT", bound=InitParams)
+
+
 @jax.tree_util.register_dataclass
 @dataclass
 class TrainConfig(
@@ -175,7 +169,6 @@ class TrainConfig(
     valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
     max_steps: int | None = field(None, help="Maximum number of steps to run")
    step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
-    updates_per_step: int = field(1, help="Number of updates to perform per step")
     random_seed: int = field(1337, help="Random seed for the task")
 
 
@@ -189,7 +182,7 @@ class TrainMixin(
     StepContextMixin[Config],
     ArtifactsMixin[Config],
     RunnableMixin[Config],
-    Generic[Config],
+    Generic[Config, InitParamsT],
     ABC,
 ):
     valid_step_timer: ValidStepTimer
@@ -309,15 +302,18 @@ class TrainMixin(
         self.write_logs(state)
 
     @abstractmethod
-    def get_model(self,
+    def get_model(self, params: InitParamsT) -> PyTree | Sequence[PyTree]:
         """Returns the Equinox model to train.
 
+        Args:
+            params: The parameters for initializing the model.
+
         Returns:
             The model to train.
         """
 
-    def _get_models(self,
-        models = self.get_model(
+    def _get_models(self, params: InitParamsT) -> list[PyTree]:
+        models = self.get_model(params)
         if isinstance(models, Sequence):
             models = list(models)
         elif isinstance(models, eqx.Module):
@@ -353,20 +349,20 @@ class TrainMixin(
     @overload
     def load_initial_state(
         self,
-
+        params: InitParamsT,
         load_optimizer: Literal[False] = False,
     ) -> tuple[PyTree, State]: ...
 
     @overload
     def load_initial_state(
         self,
-
+        params: InitParamsT,
         load_optimizer: Literal[True],
     ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]: ...
 
     def load_initial_state(
         self,
-
+        params: InitParamsT,
         load_optimizer: bool = False,
     ) -> (
         tuple[list[PyTree], State]
@@ -376,7 +372,7 @@ class TrainMixin(
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            model, state, config = self.load_ckpt(init_ckpt_path, part="model_state_config")
+            model, state, config = self.load_ckpt(init_ckpt_path, params, part="model_state_config")
             config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
             if config_diff:
                 logger.warning("Loaded config differs from current config:\n%s", config_diff)
@@ -384,12 +380,12 @@ class TrainMixin(
             if not load_optimizer:
                 return model, state
 
-            optimizer = self.load_ckpt(init_ckpt_path, part="opt")
-            opt_state = self.load_ckpt(init_ckpt_path, part="opt_state", model=model, optimizer=optimizer)
+            optimizer = self.load_ckpt(init_ckpt_path, params, part="opt")
+            opt_state = self.load_ckpt(init_ckpt_path, params, part="opt_state", model=model, optimizer=optimizer)
             return model, optimizer, opt_state, state
 
         logger.info("Starting a new training run")
-        models = self._get_models(
+        models = self._get_models(params)
         state = State.init_state()
 
         if not load_optimizer:
@@ -405,6 +401,7 @@
     def load_ckpt(
         self,
         path: Path,
+        init_params: InitParamsT,
         *,
         part: Literal["all"],
     ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, Config]: ...
@@ -413,6 +410,7 @@
     def load_ckpt(
         self,
         path: Path,
+        init_params: InitParamsT,
         *,
         part: Literal["model_state_config"],
     ) -> tuple[list[PyTree], State, Config]: ...
@@ -421,6 +419,7 @@
     def load_ckpt(
         self,
         path: Path,
+        init_params: InitParamsT,
         *,
         part: Literal["model"],
     ) -> list[PyTree]: ...
@@ -429,6 +428,7 @@
     def load_ckpt(
         self,
         path: Path,
+        init_params: InitParamsT,
         *,
         part: Literal["opt"],
     ) -> list[optax.GradientTransformation]: ...
@@ -437,6 +437,7 @@
     def load_ckpt(
         self,
         path: Path,
+        init_params: InitParamsT,
         *,
         part: Literal["opt_state"],
         model: PyTree | None = None,
@@ -447,6 +448,7 @@
     def load_ckpt(
         self,
         path: Path,
+        init_params: InitParamsT,
         *,
         part: Literal["state"],
     ) -> list[State]: ...
@@ -455,6 +457,7 @@
     def load_ckpt(
         self,
         path: Path,
+        init_params: InitParamsT,
         *,
         part: Literal["config"],
     ) -> list[Config]: ...
@@ -462,6 +465,7 @@
     def load_ckpt(
         self,
         path: str | Path,
+        init_params: InitParamsT,
         *,
         part: CheckpointPart = "all",
         model: PyTree | None = None,
@@ -477,18 +481,15 @@
     ):
         path = Path(path)
 
-        # This key isn't used for anything, it's just a required argument.
-        key = jax.random.PRNGKey(0)
-
         match part:
             case "model_state_config":
-                model_specs = eqx.filter_eval_shape(self._get_models,
+                model_specs = eqx.filter_eval_shape(self._get_models, init_params)
                 model, state, config = load_ckpt(path, part="model_state_config", model_templates=model_specs)
                 config = self.get_config(config, use_cli=False)
                 return model, state, config
 
             case "model":
-                model_specs = eqx.filter_eval_shape(self._get_models,
+                model_specs = eqx.filter_eval_shape(self._get_models, init_params)
                 return load_ckpt(path, part="model", model_templates=model_specs)
 
             case "opt":
@@ -497,7 +498,7 @@
 
             case "opt_state":
                 if model is None:
-                    model_specs = eqx.filter_eval_shape(self._get_models,
+                    model_specs = eqx.filter_eval_shape(self._get_models, init_params)
                     model = load_ckpt(path, part="model", model_templates=model_specs)
                 if optimizer is None:
                     optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
@@ -512,7 +513,7 @@
                 return self.get_config(load_ckpt(path, part="config"), use_cli=False)
 
             case "all":
-                model_specs = eqx.filter_eval_shape(self._get_models,
+                model_specs = eqx.filter_eval_shape(self._get_models, init_params)
                 model = load_ckpt(path, part="model", model_templates=model_specs)
                 optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
                 optimizer = load_ckpt(path, part="opt", optimizer_templates=optimizer_specs)
@@ -525,95 +526,6 @@
             case _:
                 raise ValueError(f"Unknown checkpoint part: {part}")
 
    [... 89 removed lines collapsed: the get_output, compute_loss, compute_metrics, get_output_and_loss, and update methods, identical to the definitions added in xax/task/mixins/supervised.py above ...]
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
 
@@ -687,224 +599,3 @@
 
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
    [... 221 removed lines collapsed: the train_step, val_step, train_loop, get_train_iterator, get_valid_iterator, run, and run_training methods, moved to xax/task/mixins/supervised.py above; the only change in the move is that run_training now wraps the model key as InitParams(key=model_key) before calling load_initial_state ...]
xax/task/task.py
CHANGED
@@ -19,6 +19,7 @@ from xax.task.mixins import (
     DataloadersMixin,
     GPUStatsConfig,
     GPUStatsMixin,
+    InitParams,
     LoggerConfig,
     LoggerMixin,
     ProcessConfig,
@@ -27,6 +28,8 @@ from xax.task.mixins import (
     RunnableMixin,
     StepContextConfig,
     StepContextMixin,
+    SupervisedConfig as BaseSupervisedConfig,
+    SupervisedMixin as BaseSupervisedMixin,
     TrainConfig,
     TrainMixin,
 )
@@ -52,10 +55,11 @@ class Config(
 
 
 ConfigT = TypeVar("ConfigT", bound=Config)
+InitParamsT = TypeVar("InitParamsT", bound=InitParams)
 
 
 class Task(
-    TrainMixin[ConfigT],
+    TrainMixin[ConfigT, InitParamsT],
     CheckpointingMixin[ConfigT],
     CompileMixin[ConfigT],
     DataloadersMixin[ConfigT],
@@ -67,6 +71,26 @@ class Task(
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
-    Generic[ConfigT],
+    Generic[ConfigT, InitParamsT],
+):
+    pass
+
+
+@jax.tree_util.register_dataclass
+@dataclass
+class SupervisedConfig(
+    BaseSupervisedConfig,
+    Config,
+):
+    pass
+
+
+SupervisedConfigT = TypeVar("SupervisedConfigT", bound=SupervisedConfig)
+
+
+class SupervisedTask(
+    BaseSupervisedMixin[SupervisedConfigT],
+    Task[SupervisedConfigT, InitParams],
+    Generic[SupervisedConfigT],
 ):
     pass
{xax-0.3.14.dist-info → xax-0.4.0.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=50NFQGS6aOMcJQAJ4U1mLpvMRtWc8Kbgtv4zIMWodfc,17164
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -21,7 +21,7 @@ xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=i6FRJ75aqlekWkzJNRWDUEX7P514pUjLVuxjhX1GBgw,8198
 xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
 xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
-xax/task/task.py,sha256=
+xax/task/task.py,sha256=Iy02wRUti5lDX1rfDHIgX87dGYeayJxJ9nzJzp_lMq0,1960
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
 xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
@@ -32,7 +32,7 @@ xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
 xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,1407
 xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
 xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
-xax/task/mixins/__init__.py,sha256=
+xax/task/mixins/__init__.py,sha256=wYc4zfutdMyEmzCVV421gSf25ZXW9htNTSY_TW6vL_8,894
 xax/task/mixins/artifacts.py,sha256=UN26TW22ARduO6Bjs0yRu4-V6-Md9MPbXLKDnS28m44,3861
 xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
@@ -43,7 +43,8 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/
+xax/task/mixins/supervised.py,sha256=IxAh-ywvjDNoqXtzHwv2WpVsXFOX45SZjyF3qpbN-2k,13757
+xax/task/mixins/train.py,sha256=0loO44W6vVjP5usWvN0D1TgYTJ7N3PDevR7brmw3ymQ,20493
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=85JYIdnzLnvXsuli-4YHei_3tE3DnX3rmDSARKW2u1M,2192
 xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -60,9 +61,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.
-xax-0.
-xax-0.
-xax-0.
-xax-0.
-xax-0.
+xax-0.4.0.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.4.0.dist-info/METADATA,sha256=oaK0oAc0WM428EAjuwTvFFvaa0JibJl-CpPOBUBVmUY,1246
+xax-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.4.0.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.4.0.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.4.0.dist-info/RECORD,,
{xax-0.3.14.dist-info → xax-0.4.0.dist-info}/WHEEL
File without changes
{xax-0.3.14.dist-info → xax-0.4.0.dist-info}/entry_points.txt
File without changes
{xax-0.3.14.dist-info → xax-0.4.0.dist-info}/licenses/LICENSE
File without changes
{xax-0.3.14.dist-info → xax-0.4.0.dist-info}/top_level.txt
File without changes