xax 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +94 -4
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +26 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +97 -2
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -34
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jaxpr.py +77 -0
- xax/utils/pytree.py +189 -1
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/METADATA +23 -4
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/RECORD +26 -21
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/WHEEL +1 -1
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/top_level.txt +0 -0
xax/task/logger.py
CHANGED
@@ -223,10 +223,29 @@ class LogVideo:
     fps: int


+@dataclass(kw_only=True)
+class LogDistribution:
+    mean: Number
+    std: Number
+
+
+@dataclass(kw_only=True)
+class LogHistogram:
+    min: Number
+    max: Number
+    num: int
+    sum: Number
+    sum_squares: Number
+    bucket_limits: list[Number]
+    bucket_counts: list[int]
+
+
 @dataclass(kw_only=True)
 class LogLine:
     state: State
     scalars: dict[str, dict[str, Number]]
+    distributions: dict[str, dict[str, LogDistribution]]
+    histograms: dict[str, dict[str, LogHistogram]]
     strings: dict[str, dict[str, str]]
     images: dict[str, dict[str, LogImage]]
     videos: dict[str, dict[str, LogVideo]]

@@ -329,9 +348,9 @@ def image_with_text(
     else:
         text = text[:max_num_lines]
     width, height = image.size
-    font: ImageFont.ImageFont = ImageFont.load_default()
+    font: ImageFont.ImageFont | ImageFont.FreeTypeFont = ImageFont.load_default()
     _, _, _, line_height = font.getbbox(text[0])
-    new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
+    new_width, new_height = width, int(height + line_spacing + max_num_lines * (line_height + line_spacing))
     padded_image = Image.new(image.mode, (new_width, new_height), 255)
     padded_image.paste(image, (0, 0))
     drawer = ImageDraw.Draw(padded_image)

@@ -497,6 +516,8 @@ class Logger:

     def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
         self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
+        self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
+        self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
         self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
         self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
         self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)

@@ -522,6 +543,8 @@ class Logger:
         return LogLine(
             state=state,
             scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
+            distributions={k: {kk: v() for kk, v in v.items()} for k, v in self.distributions.items()},
+            histograms={k: {kk: v() for kk, v in v.items()} for k, v in self.histograms.items()},
             strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
             images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
             videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},

@@ -529,6 +552,8 @@ class Logger:

     def clear(self) -> None:
         self.scalars.clear()
+        self.distributions.clear()
+        self.histograms.clear()
         self.strings.clear()
         self.images.clear()
         self.videos.clear()

@@ -612,6 +637,76 @@ class Logger:

         self.scalars[namespace][key] = scalar_future

+    def log_distribution(
+        self,
+        key: str,
+        value: Callable[[], tuple[Number, Number]] | tuple[Number, Number],
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        """Logs a distribution value.
+
+        Args:
+            key: The key being logged
+            value: The distribution value being logged, a tuple of (mean, std)
+            namespace: An optional logging namespace
+        """
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def distribution_future() -> LogDistribution:
+            mean, std = value() if callable(value) else value
+            return LogDistribution(mean=mean, std=std)
+
+        self.distributions[namespace][key] = distribution_future
+
+    def log_histogram(
+        self,
+        key: str,
+        value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
+        *,
+        bins: int = 100,
+        namespace: str | None = None,
+    ) -> None:
+        """Logs a histogram value.
+
+        Args:
+            key: The key being logged
+            value: The histogram value being logged
+            bins: The number of bins to use for the histogram
+            namespace: An optional logging namespace
+        """
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def histogram_future() -> LogHistogram:
+            values = value() if callable(value) else value
+            values = values.reshape(-1)  # Must be flat.
+
+            if isinstance(values, Array):
+                counts, limits = jnp.histogram(values, bins=bins)
+                counts, limits = as_numpy(counts), as_numpy(limits)
+            elif isinstance(values, np.ndarray):
+                counts, limits = np.histogram(values, bins=bins)
+            else:
+                raise ValueError(f"Unsupported histogram type: {type(values)}")
+
+            return LogHistogram(
+                min=float(values.min()),
+                max=float(values.max()),
+                num=int(values.size),
+                sum=float(values.sum()),
+                sum_squares=float(values.dot(values)),
+                bucket_limits=limits[1:].tolist(),
+                bucket_counts=counts.tolist(),
+            )
+
+        self.histograms[namespace][key] = histogram_future
+
     def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
         """Logs a string value.

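The new log_distribution and log_histogram calls mirror the existing log_scalar API: each accepts either a concrete value or a zero-argument callable, and the value is only materialized (and cached) when the log line is built. A minimal usage sketch follows; the logger instance, key names, and namespace are illustrative and not part of the diff.

import jax.numpy as jnp

values = jnp.asarray([-0.2, 0.1, 0.05, 0.3])

# A distribution is summarized as a (mean, std) tuple.
logger.log_distribution("grads", (float(values.mean()), float(values.std())), namespace="stats")

# A histogram takes the raw array (NumPy or JAX) and bins it lazily.
logger.log_histogram("grads", values, bins=50, namespace="stats")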
xax/task/loggers/stdout.py
CHANGED
@@ -33,7 +33,7 @@ class StdoutLogger(LoggerImpl):
         self,
         write_fp: TextIO = sys.stdout,
         precision: int = 4,
-        log_timers: bool =
+        log_timers: bool = True,
         log_perf: bool = False,
         log_optim: bool = False,
         log_fp: bool = False,

@@ -98,7 +98,7 @@ class StdoutLogger(LoggerImpl):

         def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
             for namespace, values in log.items():
-                if not self.log_timers and namespace.startswith("
+                if not self.log_timers and namespace.startswith("⌛"):
                     continue
                 if not self.log_perf and namespace.startswith("🔧"):
                     continue
xax/task/loggers/tensorboard.py
CHANGED
@@ -1,11 +1,9 @@
 """Defines a Tensorboard logger backend."""

 import atexit
-import functools
 import logging
 import os
 import re
-import shutil
 import subprocess
 import threading
 import time

@@ -140,15 +138,6 @@ class TensorboardLogger(LoggerImpl):
     def __del__(self) -> None:
         self.cleanup()

-    @functools.lru_cache(None)  # Avoid clearing logs multiple times.
-    def clear_logs(self) -> None:
-        if not self.log_directory.exists():
-            return
-        if not any(child.is_dir() for child in self.log_directory.iterdir()):
-            return
-        logger.warning("Clearing TensorBoard logs")
-        shutil.rmtree(self.log_directory)
-
     def get_writer(self, phase: Phase) -> TensorboardWriter:
         self._start()
         return self.writers.writer(phase)

@@ -162,9 +151,6 @@ class TensorboardLogger(LoggerImpl):
         if not is_master():
             return

-        if line.state.num_steps == 0:
-            self.clear_logs()
-
         writer = self.get_writer(line.state.phase)
         walltime = line.state.start_time_s + line.state.elapsed_time_s

@@ -177,6 +163,31 @@ class TensorboardLogger(LoggerImpl):
                 walltime=walltime,
             )

+        for namespace, distributions in line.distributions.items():
+            for distribution_key, distribution_value in distributions.items():
+                writer.add_gaussian_distribution(
+                    f"{namespace}/{distribution_key}",
+                    mean=float(distribution_value.mean),
+                    std=float(distribution_value.std),
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
+        for namespace, histograms in line.histograms.items():
+            for histogram_key, histogram_value in histograms.items():
+                writer.add_histogram_raw(
+                    f"{namespace}/{histogram_key}",
+                    min=float(histogram_value.min),
+                    max=float(histogram_value.max),
+                    num=int(histogram_value.num),
+                    sum=float(histogram_value.sum),
+                    sum_squares=float(histogram_value.sum_squares),
+                    bucket_limits=[float(x) for x in histogram_value.bucket_limits],
+                    bucket_counts=[int(x) for x in histogram_value.bucket_counts],
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
         for namespace, strings in line.strings.items():
             for string_key, string_value in strings.items():
                 writer.add_text(
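The distribution path relies on a new add_gaussian_distribution writer method, presumably defined in xax/utils/tensorboard.py (which gains 177 lines in this release) but not shown in these hunks. As a hedged illustration only, one plausible way to turn a (mean, std) summary into the raw-histogram fields that TensorBoard expects is sketched below; the bucket layout and notional sample count are assumptions, and this is not the library's implementation.

import math

def gaussian_to_raw_histogram(mean: float, std: float, n: int = 128, bins: int = 16) -> dict:
    # Spread notional buckets over +/- 3 standard deviations (assumes std > 0).
    edges = [mean + std * (-3.0 + 6.0 * i / bins) for i in range(bins + 1)]

    def cdf(x: float) -> float:  # Normal CDF at x.
        return 0.5 * (1.0 + math.erf((x - mean) / (std * math.sqrt(2.0))))

    counts = [n * (cdf(hi) - cdf(lo)) for lo, hi in zip(edges[:-1], edges[1:])]
    return {
        "min": edges[0],
        "max": edges[-1],
        "num": n,
        "sum": n * mean,  # sum of n samples with this mean
        "sum_squares": n * (std ** 2 + mean ** 2),  # since E[x^2] = var + mean^2
        "bucket_limits": edges[1:],
        "bucket_counts": counts,
    }

These dictionary keys correspond one-to-one with the keyword arguments passed to writer.add_histogram_raw in the hunk above.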
xax/task/mixins/artifacts.py
CHANGED
@@ -3,7 +3,6 @@
 import functools
 import inspect
 import logging
-import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Self, TypeVar

@@ -54,20 +53,6 @@ class ArtifactsMixin(BaseTask[Config]):
         self._exp_dir = Path(exp_dir).expanduser().resolve()
         return self

-    def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
-        if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
-            if not exists_ok:
-                raise RuntimeError(f"Lock file already exists at {lock_file}")
-        else:
-            with open(lock_file, "w", encoding="utf-8") as f:
-                f.write(f"PID: {os.getpid()}")
-
-    def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
-        if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
-            lock_file.unlink()
-        elif not missing_ok:
-            raise RuntimeError(f"Lock file not found at {lock_file}")
-
     def get_exp_dir(self) -> Path:
         if self._exp_dir is not None:
             return self._exp_dir

@@ -82,13 +67,8 @@ class ArtifactsMixin(BaseTask[Config]):
         def get_exp_dir(run_id: int) -> Path:
             return self.run_dir / f"run_{run_id}"

-        def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
-            if lock_type is not None:
-                return (exp_dir / f".lock_{lock_type}").exists()
-            return any(exp_dir.glob(".lock_*"))
-
         run_id = 0
-        while (exp_dir := get_exp_dir(run_id)).is_dir()
+        while (exp_dir := get_exp_dir(run_id)).is_dir():
            run_id += 1
         exp_dir.mkdir(exist_ok=True, parents=True)
         self._exp_dir = exp_dir.expanduser().resolve()
xax/task/mixins/checkpointing.py
CHANGED
@@ -21,7 +21,7 @@ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin

 logger = logging.getLogger(__name__)

-CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+CheckpointPart = Literal["model", "opt", "opt_state", "state", "config", "model_state_config", "all"]


 def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:

@@ -88,8 +88,16 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
+        part: Literal["all"] = "all",
     ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...

+    @overload
+    def load_checkpoint(
+        self,
+        path: Path,
+        part: Literal["model_state_config"] = "model_state_config",
+    ) -> tuple[PyTree, State, DictConfig]: ...
+
     @overload
     def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...

@@ -108,15 +116,19 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
-        part: CheckpointPart
+        part: CheckpointPart = "all",
     ) -> (
         tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+        | tuple[PyTree, State, DictConfig]
         | PyTree
         | optax.GradientTransformation
         | optax.OptState
         | State
         | DictConfig
     ):
+        # Calls the base callback.
+        self.on_before_checkpoint_load(path)
+
         with tarfile.open(path, "r:gz") as tar:

             def get_model() -> PyTree:

@@ -155,7 +167,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
                     return get_state()
                 case "config":
                     return get_config()
-                case
+                case "model_state_config":
+                    return get_model(), get_state(), get_config()
+                case "all":
                     return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
                 case _:
                     raise ValueError(f"Invalid checkpoint part: {part}")

@@ -215,7 +229,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
         except FileExistsError:
             logger.exception("Exception while trying to update %s", ckpt_path)

-        #
-        self.
+        # Calls the base callback.
+        self.on_after_checkpoint_save(ckpt_path, state)

         return ckpt_path
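In practice the new "model_state_config" part lets callers restore a model for inference without constructing an optimizer. A short sketch of the two common call patterns, based on the overloads above; task and ckpt_path are placeholders.

# Restore only the model, training state, and config.
model, state, config = task.load_checkpoint(ckpt_path, "model_state_config")

# The default part is now "all", which also restores the optimizer and its state.
model, optimizer, opt_state, state, config = task.load_checkpoint(ckpt_path)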
xax/task/mixins/logger.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Generic, Self, TypeVar

 import jax

+from xax.core.conf import field
 from xax.core.state import State
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.logger import Logger, LoggerImpl

@@ -22,7 +23,14 @@ from xax.utils.text import is_interactive_session
 @jax.tree_util.register_dataclass
 @dataclass
 class LoggerConfig(BaseConfig):
-
+    log_interval_seconds: float = field(
+        value=1.0,
+        help="The interval between successive log lines.",
+    )
+    tensorboard_log_interval_seconds: float = field(
+        value=10.0,
+        help="The interval between successive Tensorboard log lines.",
+    )


 Config = TypeVar("Config", bound=LoggerConfig)

@@ -49,11 +57,27 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
         self.logger.add_logger(*logger)

     def set_loggers(self) -> None:
-        self.add_logger(
+        self.add_logger(
+            StdoutLogger(
+                log_interval_seconds=self.config.log_interval_seconds,
+            )
+            if is_interactive_session()
+            else JsonLogger(
+                log_interval_seconds=self.config.log_interval_seconds,
+            )
+        )
+
+        # If this is also an ArtifactsMixin, we should default add some
+        # additional loggers which log data to the artifacts directory.
         if isinstance(self, ArtifactsMixin):
             self.add_logger(
-                StateLogger(
-
+                StateLogger(
+                    run_directory=self.exp_dir,
+                ),
+                TensorboardLogger(
+                    run_directory=self.exp_dir,
+                    log_interval_seconds=self.config.tensorboard_log_interval_seconds,
+                ),
             )

     def write_logs(self, state: State) -> None:
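Both intervals are ordinary config fields, so they can be overridden per task. A hypothetical config instantiation follows; the task config class name is a placeholder.

config = MyTaskConfig(
    log_interval_seconds=0.5,               # stdout/JSON log lines at most every half second
    tensorboard_log_interval_seconds=30.0,  # Tensorboard writes at most every 30 seconds
)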
xax/task/mixins/step_wrapper.py
CHANGED
@@ -1,53 +1,39 @@
 """Defines a mixin to wrap some steps in a context manager."""

+import time
 from dataclasses import dataclass
 from types import TracebackType
-from typing import
+from typing import Callable, ContextManager, TypeVar

-import equinox as eqx
 import jax

 from xax.task.base import BaseConfig, BaseTask

-StepType = Literal[
-    "backward",
-    "change_mode",
-    "clip_grads",
-    "create_optimizers",
-    "forward",
-    "get_dataloader",
-    "get_dataset",
-    "get_prefetcher",
-    "get_model",
-    "get_optimizer",
-    "get_initial_opt_state",
-    "get_update_fn",
-    "load_checkpoint",
-    "log_losses",
-    "model_to_device",
-    "on_step_end",
-    "on_step_start",
-    "save_checkpoint",
-    "step",
-    "update_state",
-    "write_logs",
-    "zero_grads",
-]
-

 class StepContext(ContextManager):
     """Context manager to get the current step type."""

-    CURRENT_STEP:
+    CURRENT_STEP: str | None = None

-    def __init__(
+    def __init__(
+        self,
+        step: str,
+        on_context_start: Callable[[str], None],
+        on_context_end: Callable[[str, float], None],
+    ) -> None:
         self.step = step
+        self.start_time = 0.0
+        self.on_context_start = on_context_start
+        self.on_context_end = on_context_end

     def __enter__(self) -> None:
         StepContext.CURRENT_STEP = self.step
+        self.start_time = time.time()
+        self.on_context_start(self.step)

     def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
         StepContext.CURRENT_STEP = None
+        self.on_context_end(self.step, time.time() - self.start_time)


 @jax.tree_util.register_dataclass

@@ -63,6 +49,11 @@ class StepContextMixin(BaseTask[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)

-
-
-
+    def step_context(self, step: str) -> ContextManager:
+        return StepContext(step, self.on_context_start, self.on_context_stop)
+
+    def on_context_start(self, step: str) -> None:
+        pass
+
+    def on_context_stop(self, step: str, elapsed_time: float) -> None:
+        pass
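Step names are now plain strings, and the context manager reports wall-clock time to a pair of overridable hooks. A sketch of a task that records step durations is below; the subclass name, the logging call, and the omitted config generic are illustrative, not part of the library.

import logging

logger = logging.getLogger(__name__)

class TimedTask(StepContextMixin):  # generic config parameter omitted for brevity
    def on_context_stop(self, step: str, elapsed_time: float) -> None:
        logger.info("step %s took %.3fs", step, elapsed_time)

# Any block wrapped in a step context is then timed by name:
# with self.step_context("model_step"):
#     ...  # work to be timed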
xax/task/mixins/train.py
CHANGED
@@ -24,6 +24,7 @@ from typing import (
     TypeVar,
     cast,
     get_args,
+    overload,
 )

 import equinox as eqx

@@ -35,6 +36,7 @@ from omegaconf import DictConfig

 from xax.core.conf import field
 from xax.core.state import Phase, State
+from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin

@@ -115,7 +117,7 @@ class ValidStepTimer:
         if self.last_valid_time is None or self.last_valid_step is None:
             self.last_valid_time = state.elapsed_time_s
             self.last_valid_step = state.num_steps
-            return
+            return False

         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps

@@ -183,6 +185,9 @@ class TrainMixin(
     def __init__(self, config: Config) -> None:
         super().__init__(config)

+        # Sets the random seed whenever we instantiate a new train mixin.
+        set_random_seed(self.config.random_seed)
+
         # Timer for validation steps.
         self.valid_step_timer = ValidStepTimer(
             valid_every_n_steps=config.valid_every_n_steps,

@@ -279,31 +284,53 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))

+    @overload
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: Literal[False] = False,
+    ) -> tuple[PyTree, State]: ...
+
+    @overload
     def load_initial_state(
         self,
         key: PRNGKeyArray,
-
+        load_optimizer: Literal[True],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
+
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: bool = False,
+    ) -> tuple[PyTree, State] | tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()

         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-
+            if load_optimizer:
                 model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
                 config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
                 if config_diff:
                     logger.warning("Loaded config differs from current config:\n%s", config_diff)
                 return model, optimizer, opt_state, state

-
-
+            else:
+                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, state
+
+        model = self.get_model(key)
+        state = State.init_state()

-
-
+        if not load_optimizer:
+            return model, state

-
-
+        optimizer = self.get_optimizer()
+        opt_state = self.get_initial_opt_state(model, optimizer)

-        return model, optimizer, opt_state,
+        return model, optimizer, opt_state, state

     @eqx.filter_jit
     def get_output(self, model: PyTree, batch: Batch) -> Output:

@@ -424,6 +451,7 @@ class TrainMixin(
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
+        logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
         self.logger.log_file("git_state.txt", get_git_state(self))
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))

@@ -456,7 +484,8 @@ class TrainMixin(
         while not self.is_training_over(state):
             if self.valid_step_timer.is_valid_step(state):
                 valid_batch = next(valid_pf)
-
+                with self.step_context("model_step"):
+                    model, loss, output = self.val_step(model, valid_batch)

                 # Perform logging.
                 with self.step_context("write_logs"):

@@ -464,22 +493,19 @@ class TrainMixin(
                     self.log_step(model, valid_batch, output, loss, state)
                     state.num_valid_samples += 1

-
-                state = self.on_step_start(state)
+            state = self.on_step_start(state)

-            with self.step_context("
+            with self.step_context("model_step"):
                 train_batch = next(train_pf)
                 model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)

-            # Perform logging.
             with self.step_context("write_logs"):
                 state.phase = "train"
                 self.log_step(model, train_batch, output, loss, state)
                 state.num_steps += 1
                 state.num_samples += self.get_size_of_batch(train_batch) or 0

-
-                state = self.on_step_end(state)
+            state = self.on_step_end(state)

             if self.should_checkpoint(state):
                 self.save_checkpoint(model, optimizer, opt_state, state)

@@ -496,14 +522,9 @@ class TrainMixin(
         except NotImplementedError:
             pass

-
-
-
-        with self.step_context("get_dataloader"):
-            train_dl = self.get_dataloader(train_ds, "train")
-
-        with self.step_context("get_prefetcher"):
-            train_pf = self.get_prefetcher(train_dl)
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train")
+        train_pf = self.get_prefetcher(train_dl)

         try:
             with train_pf as train_pf_ctx:

@@ -520,14 +541,9 @@ class TrainMixin(
         except NotImplementedError:
             pass

-
-
-
-        with self.step_context("get_dataloader"):
-            valid_dl = self.get_dataloader(valid_ds, "valid")
-
-        with self.step_context("get_prefetcher"):
-            valid_pf = self.get_prefetcher(valid_dl)
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)

         try:
             with valid_pf as valid_pf_ctx:

@@ -559,7 +575,7 @@ class TrainMixin(
         Thread(target=self.log_state, daemon=True).start()

         key, model_key = jax.random.split(key)
-        model, optimizer, opt_state, state = self.load_initial_state(model_key)
+        model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
         state = self.on_training_start(state)

         def on_exit() -> None:
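The load_optimizer flag gives load_initial_state two return shapes. A sketch of the two call patterns given the overloads above; task and key are placeholders.

# Model-only restore (or fresh init), e.g. for evaluation or export.
model, state = task.load_initial_state(key)

# Full training restore, as the training loop now requests explicitly.
model, optimizer, opt_state, state = task.load_initial_state(key, load_optimizer=True)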
|