xax-0.0.3-py3-none-any.whl → xax-0.0.5-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- xax/__init__.py +49 -7
- xax/core/conf.py +1 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +15 -10
- xax/task/base.py +0 -6
- xax/task/logger.py +328 -393
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/tensorboard.py +2 -5
- xax/task/mixins/__init__.py +2 -1
- xax/task/mixins/artifacts.py +14 -7
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +10 -10
- xax/task/mixins/data_loader.py +6 -9
- xax/task/mixins/gpu_stats.py +3 -3
- xax/task/mixins/logger.py +2 -250
- xax/task/mixins/process.py +4 -0
- xax/task/mixins/train.py +71 -40
- xax/task/task.py +6 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +49 -29
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/METADATA +15 -14
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/top_level.txt +0 -0
xax/task/mixins/logger.py
CHANGED
```diff
@@ -4,14 +4,12 @@ import os
 from dataclasses import dataclass
 from pathlib import Path
 from types import TracebackType
-from typing import
-
-from jaxtyping import Array
+from typing import Generic, Self, TypeVar
 
 from xax.core.conf import Device as BaseDeviceConfig, field
 from xax.core.state import State
 from xax.task.base import BaseConfig, BaseTask
-from xax.task.logger import
+from xax.task.logger import Logger, LoggerImpl
 from xax.task.loggers.json import JsonLogger
 from xax.task.loggers.state import StateLogger
 from xax.task.loggers.stdout import StdoutLogger
@@ -59,252 +57,6 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
     def write_logs(self, state: State) -> None:
         self.logger.write(state)
 
-    def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
-        self.logger.log_scalar(key, value, namespace=namespace)
-
-    def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
-        self.logger.log_string(key, value, namespace=namespace)
-
-    def log_image(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_image(
-            key,
-            value,
-            namespace=namespace,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_labeled_image(
-        self,
-        key: str,
-        value: Callable[[], tuple[Array, str]] | tuple[Array, str],
-        *,
-        namespace: str | None = None,
-        max_line_length: int | None = None,
-        keep_resolution: bool = False,
-        centered: bool = True,
-    ) -> None:
-        self.logger.log_labeled_image(
-            key,
-            value,
-            namespace=namespace,
-            max_line_length=max_line_length,
-            keep_resolution=keep_resolution,
-            centered=centered,
-        )
-
-    def log_images(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        keep_resolution: bool = False,
-        max_images: int | None = None,
-        sep: int = 0,
-    ) -> None:
-        self.logger.log_images(
-            key,
-            value,
-            namespace=namespace,
-            keep_resolution=keep_resolution,
-            max_images=max_images,
-            sep=sep,
-        )
-
-    def log_labeled_images(
-        self,
-        key: str,
-        value: Callable[[], tuple[Array, Sequence[str]]] | tuple[Array, Sequence[str]],
-        *,
-        namespace: str | None = None,
-        max_line_length: int | None = None,
-        keep_resolution: bool = False,
-        max_images: int | None = None,
-        sep: int = 0,
-        centered: bool = True,
-    ) -> None:
-        self.logger.log_labeled_images(
-            key,
-            value,
-            namespace=namespace,
-            max_line_length=max_line_length,
-            keep_resolution=keep_resolution,
-            max_images=max_images,
-            sep=sep,
-            centered=centered,
-        )
-
-    def log_audio(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sample_rate: int = 44100,
-        log_spec: bool = True,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_audio(
-            key,
-            value,
-            namespace=namespace,
-            sample_rate=sample_rate,
-            log_spec=log_spec,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_audios(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sep_ms: float = 0.0,
-        max_audios: int | None = None,
-        sample_rate: int = 44100,
-        log_spec: bool = True,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        spec_sep: int = 0,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_audios(
-            key,
-            value,
-            namespace=namespace,
-            sep_ms=sep_ms,
-            max_audios=max_audios,
-            sample_rate=sample_rate,
-            log_spec=log_spec,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            spec_sep=spec_sep,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_spectrogram(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sample_rate: int = 44100,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_spectrogram(
-            key,
-            value,
-            namespace=namespace,
-            sample_rate=sample_rate,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_spectrograms(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        max_audios: int | None = None,
-        sample_rate: int = 44100,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        spec_sep: int = 0,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_spectrograms(
-            key,
-            value,
-            namespace=namespace,
-            max_audios=max_audios,
-            sample_rate=sample_rate,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            spec_sep=spec_sep,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_video(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        fps: int | None = None,
-        length: float | None = None,
-    ) -> None:
-        self.logger.log_video(
-            key,
-            value,
-            namespace=namespace,
-            fps=fps,
-            length=length,
-        )
-
-    def log_videos(
-        self,
-        key: str,
-        value: Callable[[], Array | list[Array]] | Array | list[Array],
-        *,
-        namespace: str | None = None,
-        max_videos: int | None = None,
-        sep: int = 0,
-        fps: int | None = None,
-        length: int | None = None,
-    ) -> None:
-        self.logger.log_videos(
-            key,
-            value,
-            namespace=namespace,
-            max_videos=max_videos,
-            sep=sep,
-            fps=fps,
-            length=length,
-        )
-
-    def log_histogram(self, key: str, value: Callable[[], Array] | Array, *, namespace: str | None = None) -> None:
-        self.logger.log_histogram(key, value, namespace=namespace)
-
-    def log_point_cloud(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        max_points: int = 1000,
-        colors: Callable[[], Array] | Array | None = None,
-    ) -> None:
-        self.logger.log_point_cloud(
-            key,
-            value,
-            namespace=namespace,
-            max_points=max_points,
-            colors=colors,
-        )
-
     def __enter__(self) -> Self:
         self.logger.__enter__()
         return self
```
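With the pass-through methods gone, 0.0.5 reaches the logging API through the `Logger` object that `LoggerMixin` exposes as `self.logger`, rather than through methods on the task itself. A minimal, self-contained sketch of the migrated call pattern; the stub class and `MyTask` below are illustrative stand-ins, not xax code:

```python
# Stand-in for xax.task.logger.Logger, only so the sketch runs on its own.
class _StubLogger:
    def log_scalar(self, key, value, *, namespace=None):
        print(f"[{namespace}] {key} = {value}")


class MyTask:  # in xax this would be a task class that mixes in LoggerMixin
    def __init__(self):
        self.logger = _StubLogger()

    def log_train_step(self, model, batch, output, state):
        # 0.0.3 style (removed): self.log_scalar("loss", output, namespace="loss")
        # 0.0.5 style: go through the Logger object directly.
        self.logger.log_scalar("loss", output, namespace="loss")


MyTask().log_train_step(model=None, batch=None, output=0.25, state=None)
```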
xax/task/mixins/process.py
CHANGED
```diff
@@ -38,6 +38,10 @@ class ProcessMixin(BaseTask[Config], Generic[Config]):
     def multiprocessing_context(self) -> BaseContext:
         return self._mp_ctx
 
+    @property
+    def multiprocessing_manager(self) -> SyncManager:
+        return self._mp_manager
+
     def on_training_end(self, state: State) -> State:
         state = super().on_training_end(state)
 
```
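The new property exposes a `SyncManager` next to the existing multiprocessing context. How the mixin constructs `_mp_manager` is not shown in this diff; the sketch below only illustrates the standard-library pairing that the type hints suggest (the "spawn" start method is an assumption):

```python
import multiprocessing as mp
from multiprocessing.context import BaseContext
from multiprocessing.managers import SyncManager


def make_mp_pair() -> tuple[BaseContext, SyncManager]:
    # A context fixes the process start method; a manager built from it hands out
    # proxy objects (queues, dicts, ...) that child processes can share safely.
    ctx: BaseContext = mp.get_context("spawn")  # "spawn" is an assumption here
    manager: SyncManager = ctx.Manager()
    return ctx, manager


if __name__ == "__main__":
    ctx, manager = make_mp_pair()
    status_queue = manager.Queue()  # the kind of object a monitoring process might share
    manager.shutdown()
```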
xax/task/mixins/train.py
CHANGED
```diff
@@ -1,9 +1,11 @@
 """Defines a mixin for running the training loop."""
 
+import bdb
 import contextlib
 import functools
 import itertools
 import logging
+import signal
 import sys
 import textwrap
 import time
@@ -11,20 +13,21 @@ import traceback
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
 from threading import Thread
-from typing import Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args
+from typing import Any, Generic, Literal, Mapping, Sequence, TypeVar, cast, get_args
 
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
-from jaxtyping import Array
+from jaxtyping import Array, PyTree
 from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -32,6 +35,8 @@ from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
 from xax.utils.experiments import (
     StateTimer,
     TrainingFinishedError,
+    diff_configs,
+    get_diff_string,
     get_git_state,
     get_training_code,
 )
@@ -40,9 +45,11 @@ from xax.utils.text import highlight_exception_message, show_info
 
 logger = logging.getLogger(__name__)
 
-
-
-
+# Batch = TypeVar("Batch")
+# Output = TypeVar("Output")
+
+Batch = Any
+Output = Any
 
 StepKind = Literal["step", "sample", "second"]
 
@@ -125,6 +132,7 @@ class ValidStepTimer:
 
 @dataclass
 class TrainConfig(
+    CheckpointingConfig,
     DataloadersConfig,
     LoggerConfig,
     StepContextConfig,
@@ -145,12 +153,13 @@ Config = TypeVar("Config", bound=TrainConfig)
 
 
 class TrainMixin(
+    CheckpointingMixin[Config],
     DataloadersMixin[Config],
     LoggerMixin[Config],
     StepContextMixin[Config],
     ArtifactsMixin[Config],
     RunnableMixin[Config],
-    Generic[Config
+    Generic[Config],
     ABC,
 ):
     valid_step_timer: ValidStepTimer
@@ -159,7 +168,6 @@ class TrainMixin(
     _training_over_flag: bool
     _last_printed_remaining_time: float
     _step_kind: StepKind
-    _prng_key: jnp.ndarray
 
     def __init__(self, config: Config) -> None:
         super().__init__(config)
@@ -183,12 +191,8 @@ class TrainMixin(
         # The kind of step that was specified in the config.
         self._step_kind = cast_step_kind(self.config.step_kind)
 
-        # Defines a PRNG key for the task.
-        self._prng_key = jax.random.PRNGKey(self.config.random_seed)
-
-    @property
     def prng_key(self) -> jnp.ndarray:
-        return self.
+        return jax.random.PRNGKey(self.config.random_seed)
 
     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
@@ -198,7 +202,7 @@ class TrainMixin(
             },
         )
 
-    def log_train_step(self, model:
+    def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
@@ -211,7 +215,7 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def log_valid_step(self, model:
+    def log_valid_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
@@ -224,7 +228,7 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def log_step(self, model:
+    def log_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
         phase = state.phase
 
         # Log the state timers.
@@ -232,7 +236,7 @@ class TrainMixin(
         timer.step(state)
         for ns, d in timer.log_dict().items():
             for k, v in d.items():
-                self.log_scalar(k, v, namespace=ns)
+                self.logger.log_scalar(k, v, namespace=ns)
 
         # Delegate to the appropriate logging function based on the phase.
         match phase:
@@ -244,7 +248,7 @@ class TrainMixin(
                 raise KeyError(f"Unknown phase: {phase}")
 
     @abstractmethod
-    def get_model(self) ->
+    def get_model(self) -> PyTree:
         """Returns the Equinox model to train.
 
         Returns:
@@ -259,11 +263,34 @@ class TrainMixin(
             The optimizer to use to train the model.
         """
 
-    def get_initial_opt_state(self, model:
+    def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
+    def load_initial_state(self) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+        init_ckpt_path = self.get_init_ckpt_path()
+
+        if init_ckpt_path is not None:
+            logger.info("Loading checkpoint from %s", init_ckpt_path)
+            with self.step_context("load_checkpoint"):
+                model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, optimizer, opt_state, state
+
+        with self.step_context("get_model"):
+            model = self.get_model()
+
+        with self.step_context("get_optimizer"):
+            optimizer = self.get_optimizer()
+
+        with self.step_context("get_initial_opt_state"):
+            opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
+
+        return model, optimizer, opt_state, State.init_state()
+
     @abstractmethod
-    def get_output(self, model:
+    def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
         """Gets the output from the model.
 
         By default, we assume the model is a function that takes the batch as
@@ -276,7 +303,7 @@ class TrainMixin(
             state: The current training state.
         """
 
-    def compute_loss(self, model:
+    def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
         """Gets the loss for the current batch.
 
         By default, we assume the model is a function that takes the batch as
@@ -296,7 +323,7 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
-    def get_output_and_loss(self, model:
+    def get_output_and_loss(self, model: PyTree, batch: Batch, state: State) -> tuple[Array, Output]:
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
         return loss, output
@@ -304,12 +331,12 @@ class TrainMixin(
     @eqx.filter_jit
     def update(
         self,
-        model:
+        model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[Array,
+    ) -> tuple[Array, PyTree, optax.OptState, Output]:
         (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch, state)
         updates, opt_state = optimizer.update(grads, opt_state)
         model = eqx.apply_updates(model, updates)
@@ -358,7 +385,7 @@ class TrainMixin(
         remaining_percent = self.get_remaining_percent(state)
         if remaining_percent is None:
             return False
-        self.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
+        self.logger.log_scalar("percent", remaining_percent, namespace="⏰ remaining")
         self.maybe_log_termination_time(remaining_percent, state)
         return remaining_percent <= 0.0
 
@@ -387,15 +414,15 @@ class TrainMixin(
 
     def train_step(
         self,
-        model:
+        model: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[
+    ) -> tuple[PyTree, optax.OptState, State]:
         state = state.with_phase("train")
         loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch, state)
-        self.log_scalar("loss", loss, namespace="loss")
+        self.logger.log_scalar("loss", loss, namespace="loss")
         self.log_step(model, batch, output, state)
         self.write_logs(state)
         return (
@@ -409,10 +436,10 @@ class TrainMixin(
             ),
         )
 
-    def val_step(self, model:
+    def val_step(self, model: PyTree, batch: Batch, state: State) -> tuple[PyTree, State]:
         state = state.with_phase("valid")
         loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch, state)
-        self.log_scalar("loss", loss, namespace="loss")
+        self.logger.log_scalar("loss", loss, namespace="loss")
         self.log_step(model, batch, output, state)
         self.write_logs(state)
         return model, state.replace(
@@ -462,20 +489,15 @@ class TrainMixin(
             ctx.enter_context(train_pf)
             ctx.enter_context(valid_pf)
 
-
-            with self.step_context("get_model"):
-                model = self.get_model()
+            model, optimizer, opt_state, state = self.load_initial_state()
 
-
-            with self.step_context("get_optimizer"):
-                optimizer = self.get_optimizer()
+            state = self.on_training_start(state)
 
-
-
-            opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
+            def on_exit() -> None:
+                self.save_checkpoint(model, optimizer, opt_state, state)
 
-
-
+            # Handle user-defined interrupts during the training loop.
+            self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
 
             try:
                 while True:
@@ -494,17 +516,26 @@ class TrainMixin(
                     with self.step_context("on_step_end"):
                         state = self.on_step_end(state)
 
+                    if self.should_checkpoint(state):
+                        self.save_checkpoint(model, optimizer, opt_state, state)
+
             except TrainingFinishedError:
                 if is_master():
                     show_info(
                         f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
                         important=True,
                     )
+                self.save_checkpoint(model, optimizer, opt_state, state)
+
+            except (KeyboardInterrupt, bdb.BdbQuit):
+                if is_master():
+                    show_info("Interrupted training", important=True)
 
             except BaseException:
                 exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                 sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                 sys.stdout.flush()
+                self.save_checkpoint(model, optimizer, opt_state, state)
 
             finally:
                 state = self.on_training_end(state)
```
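Two behavioural notes on the reworked loop. Model, optimizer, and state now come from `load_initial_state`, which resumes from a checkpoint when `get_init_ckpt_path()` finds one, and the loop saves checkpoints on SIGUSR1/SIGTERM, on normal completion, and when an exception escapes. Separately, `prng_key` no longer caches a key on the instance (the `@property` decorator is dropped as well); it rebuilds the same `jax.random.PRNGKey(self.config.random_seed)` on every call, so code that needs fresh randomness per step should split keys itself. A small plain-JAX sketch of that splitting pattern, independent of xax:

```python
import jax

seed = 0  # stands in for config.random_seed
key = jax.random.PRNGKey(seed)  # rebuilding from the seed always yields this same key

for step in range(3):
    key, step_key = jax.random.split(key)      # derive a fresh subkey for this step
    noise = jax.random.normal(step_key, (2,))  # different values each iteration
    print(step, noise)
```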
xax/task/task.py
CHANGED
```diff
@@ -7,7 +7,8 @@ from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
     ArtifactsMixin,
-
+    CheckpointingConfig,
+    CheckpointingMixin,
     CPUStatsConfig,
     CPUStatsMixin,
     DataloadersConfig,
@@ -16,8 +17,6 @@ from xax.task.mixins import (
     GPUStatsMixin,
     LoggerConfig,
     LoggerMixin,
-    Model,
-    Output,
     ProcessConfig,
     ProcessMixin,
     RunnableConfig,
@@ -32,6 +31,7 @@ from xax.task.mixins import (
 @dataclass
 class Config(
     TrainConfig,
+    CheckpointingConfig,
     DataloadersConfig,
     CPUStatsConfig,
     GPUStatsConfig,
@@ -49,7 +49,8 @@ ConfigT = TypeVar("ConfigT", bound=Config)
 
 
 class Task(
-    TrainMixin[ConfigT
+    TrainMixin[ConfigT],
+    CheckpointingMixin[ConfigT],
     DataloadersMixin[ConfigT],
     CPUStatsMixin[ConfigT],
     GPUStatsMixin[ConfigT],
@@ -59,6 +60,6 @@ class Task(
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
-    Generic[ConfigT
+    Generic[ConfigT],
 ):
     pass
```
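With checkpointing folded into both the default config and the mixin stack, anything built on `Task` and `Config` picks up the checkpointing options and the `save_checkpoint`/`load_checkpoint` machinery from the train.py changes above without opting in. A hedged sketch of a downstream definition; `MyConfig`, `hidden_size`, and `MyTask` are hypothetical, only the inheritance pattern comes from this file:

```python
from dataclasses import dataclass

from xax.task.task import Config, Task


@dataclass
class MyConfig(Config):
    hidden_size: int = 128  # hypothetical task-specific option


class MyTask(Task[MyConfig]):
    # TrainMixin's abstract methods (get_model, get_output, and friends) would be
    # implemented here; checkpointing support now comes along automatically.
    ...
```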
xax/utils/data/collate.py
CHANGED
```diff
@@ -167,9 +167,9 @@ def collate(
     # Collate dictionaries if they have the same keys.
     if isinstance(item, dict) and all(set(i.keys()) == set(item.keys()) for i in items):
         output_dict = {}
-
-        for
-            output_dict[
+        item_keys_set = set(item.keys())
+        for key_in_set in item_keys_set:
+            output_dict[key_in_set] = collate([i[key_in_set] for i in items], mode=mode, pad=pad)
         return output_dict
 
     # Collate lists and tuples if they have the same lengths.
@@ -186,9 +186,9 @@ def collate(
     # Handles dataclasses.
     if is_dataclass(item):
         output_dict = {}
-
-        for
-            output_dict[
+        item_keys_dict = item.__dict__.keys()
+        for key_in_dict in item_keys_dict:
+            output_dict[key_in_dict] = collate([getattr(i, key_in_dict) for i in items], mode=mode, pad=pad)
         return item.__class__(**output_dict)
 
     # By default, don't do anything.
```
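The rewritten branches keep the same recursive strategy, just with explicit key variables: take the keys from the first item, recurse into `collate` per key, then rebuild the dict (or dataclass) from the collated values. A self-contained mirror of the dict branch, with `np.stack` standing in for whatever the real `mode`/`pad` handling does at the leaves:

```python
import numpy as np


def collate_dicts(items: list[dict]) -> dict:
    # Same shape as the new code above: iterate the first item's keys and
    # collate the per-key values across every item in the batch.
    first = items[0]
    return {key: np.stack([item[key] for item in items]) for key in first}


batch = collate_dicts([{"x": np.zeros(3), "y": 0}, {"x": np.ones(3), "y": 1}])
print(batch["x"].shape, batch["y"].shape)  # (2, 3) (2,)
```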