xax 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +121 -3
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +101 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +97 -2
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -34
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jax.py +126 -0
- xax/utils/jaxpr.py +77 -0
- xax/utils/profile.py +61 -0
- xax/utils/pytree.py +238 -0
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/METADATA +23 -4
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/RECORD +28 -20
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/WHEEL +1 -1
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.6.dist-info → xax-0.1.0.dist-info}/top_level.txt +0 -0
xax/nn/norm.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
"""Normalization utilities."""
|
2
|
+
|
3
|
+
from typing import Literal, cast, get_args
|
4
|
+
|
5
|
+
import jax.numpy as jnp
|
6
|
+
|
7
|
+
# The norm names accepted by the helpers in this module.
NormType = Literal["l1", "l2"]


def cast_norm_type(norm: str) -> NormType:
    """Validates a norm name and narrows its static type to ``NormType``.

    Args:
        norm: The candidate norm name.

    Returns:
        The same string, typed as ``NormType``.

    Raises:
        ValueError: If ``norm`` is not one of the ``NormType`` literals.
    """
    valid_norms = get_args(NormType)
    if norm in valid_norms:
        return cast(NormType, norm)
    raise ValueError(f"Invalid norm: {norm}")
|
14
|
+
|
15
|
+
|
16
|
+
def get_norm(x: jnp.ndarray, norm: NormType) -> jnp.ndarray:
    """Applies the element-wise transform associated with a norm type.

    No reduction is performed here; the result has one entry per input entry.

    Args:
        x: The input array.
        norm: Which norm to use, either ``"l1"`` or ``"l2"``.

    Returns:
        ``jnp.abs(x)`` for ``"l1"`` and ``jnp.square(x)`` for ``"l2"``.

    Raises:
        ValueError: If ``norm`` is neither ``"l1"`` nor ``"l2"``.
    """
    if norm == "l1":
        return jnp.abs(x)
    if norm == "l2":
        return jnp.square(x)
    raise ValueError(f"Invalid norm: {norm}")
|
xax/requirements.txt
CHANGED
xax/task/base.py
CHANGED
@@ -81,6 +81,12 @@ class BaseTask(Generic[Config]):
|
|
81
81
|
def on_training_end(self, state: State) -> State:
|
82
82
|
return state
|
83
83
|
|
84
|
+
def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
|
85
|
+
return state
|
86
|
+
|
87
|
+
def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
|
88
|
+
pass
|
89
|
+
|
84
90
|
@functools.cached_property
|
85
91
|
def task_class_name(self) -> str:
|
86
92
|
return self.__class__.__name__
|
xax/task/logger.py
CHANGED
@@ -223,10 +223,29 @@ class LogVideo:
|
|
223
223
|
fps: int
|
224
224
|
|
225
225
|
|
226
|
+
@dataclass(kw_only=True)
|
227
|
+
class LogDistribution:
|
228
|
+
mean: Number
|
229
|
+
std: Number
|
230
|
+
|
231
|
+
|
232
|
+
@dataclass(kw_only=True)
|
233
|
+
class LogHistogram:
|
234
|
+
min: Number
|
235
|
+
max: Number
|
236
|
+
num: int
|
237
|
+
sum: Number
|
238
|
+
sum_squares: Number
|
239
|
+
bucket_limits: list[Number]
|
240
|
+
bucket_counts: list[int]
|
241
|
+
|
242
|
+
|
226
243
|
@dataclass(kw_only=True)
|
227
244
|
class LogLine:
|
228
245
|
state: State
|
229
246
|
scalars: dict[str, dict[str, Number]]
|
247
|
+
distributions: dict[str, dict[str, LogDistribution]]
|
248
|
+
histograms: dict[str, dict[str, LogHistogram]]
|
230
249
|
strings: dict[str, dict[str, str]]
|
231
250
|
images: dict[str, dict[str, LogImage]]
|
232
251
|
videos: dict[str, dict[str, LogVideo]]
|
@@ -329,9 +348,9 @@ def image_with_text(
|
|
329
348
|
else:
|
330
349
|
text = text[:max_num_lines]
|
331
350
|
width, height = image.size
|
332
|
-
font: ImageFont.ImageFont = ImageFont.load_default()
|
351
|
+
font: ImageFont.ImageFont | ImageFont.FreeTypeFont = ImageFont.load_default()
|
333
352
|
_, _, _, line_height = font.getbbox(text[0])
|
334
|
-
new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
|
353
|
+
new_width, new_height = width, int(height + line_spacing + max_num_lines * (line_height + line_spacing))
|
335
354
|
padded_image = Image.new(image.mode, (new_width, new_height), 255)
|
336
355
|
padded_image.paste(image, (0, 0))
|
337
356
|
drawer = ImageDraw.Draw(padded_image)
|
@@ -497,6 +516,8 @@ class Logger:
|
|
497
516
|
|
498
517
|
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
|
499
518
|
self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
|
519
|
+
self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
|
520
|
+
self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
|
500
521
|
self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
|
501
522
|
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
502
523
|
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
@@ -522,6 +543,8 @@ class Logger:
|
|
522
543
|
return LogLine(
|
523
544
|
state=state,
|
524
545
|
scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
|
546
|
+
distributions={k: {kk: v() for kk, v in v.items()} for k, v in self.distributions.items()},
|
547
|
+
histograms={k: {kk: v() for kk, v in v.items()} for k, v in self.histograms.items()},
|
525
548
|
strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
|
526
549
|
images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
|
527
550
|
videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
|
@@ -529,6 +552,8 @@ class Logger:
|
|
529
552
|
|
530
553
|
def clear(self) -> None:
|
531
554
|
self.scalars.clear()
|
555
|
+
self.distributions.clear()
|
556
|
+
self.histograms.clear()
|
532
557
|
self.strings.clear()
|
533
558
|
self.images.clear()
|
534
559
|
self.videos.clear()
|
@@ -612,6 +637,76 @@ class Logger:
|
|
612
637
|
|
613
638
|
self.scalars[namespace][key] = scalar_future
|
614
639
|
|
640
|
+
def log_distribution(
    self,
    key: str,
    value: Callable[[], tuple[Number, Number]] | tuple[Number, Number],
    *,
    namespace: str | None = None,
) -> None:
    """Registers a distribution, given as (mean, std), under a key.

    The value is not evaluated here; a memoized zero-argument callable is
    stored and evaluated when it is first called.

    Args:
        key: The key being logged
        value: The distribution value being logged, a tuple of (mean, std),
            or a callable returning such a tuple
        namespace: An optional logging namespace

    Raises:
        RuntimeError: If the logger is not active.
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    resolved_namespace = self.resolve_namespace(namespace)

    def distribution_future() -> LogDistribution:
        if callable(value):
            mean, std = value()
        else:
            mean, std = value
        return LogDistribution(mean=mean, std=std)

    # Memoize so that the value is computed at most once, however often it is read.
    self.distributions[resolved_namespace][key] = functools.lru_cache(maxsize=None)(distribution_future)
|
664
|
+
|
665
|
+
def log_histogram(
    self,
    key: str,
    value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
    *,
    bins: int = 100,
    namespace: str | None = None,
) -> None:
    """Logs a histogram value.

    The histogram is not computed here; a memoized zero-argument callable is
    stored and the statistics are computed when it is first called.

    Args:
        key: The key being logged
        value: The histogram value being logged, a NumPy or JAX array, or a
            callable returning one
        bins: The number of bins to use for the histogram
        namespace: An optional logging namespace

    Raises:
        RuntimeError: If the logger is not active.
        ValueError: Raised when the stored callable is evaluated, if the
            value is neither a NumPy array nor a JAX array.
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    namespace = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def histogram_future() -> LogHistogram:
        values = value() if callable(value) else value

        # Validate the type before using any array method, so that unsupported
        # inputs raise the intended ValueError instead of an AttributeError
        # from the ``reshape`` call below.
        if not isinstance(values, (Array, np.ndarray)):
            raise ValueError(f"Unsupported histogram type: {type(values)}")

        values = values.reshape(-1)  # Must be flat.

        if isinstance(values, Array):
            counts, limits = jnp.histogram(values, bins=bins)
            counts, limits = as_numpy(counts), as_numpy(limits)
        else:
            counts, limits = np.histogram(values, bins=bins)

        return LogHistogram(
            min=float(values.min()),
            max=float(values.max()),
            num=int(values.size),
            sum=float(values.sum()),
            sum_squares=float(values.dot(values)),
            # ``limits`` has one more entry than ``counts``; drop the first
            # edge so that each count is paired with its upper bin edge.
            bucket_limits=limits[1:].tolist(),
            bucket_counts=counts.tolist(),
        )

    self.histograms[namespace][key] = histogram_future
|
709
|
+
|
615
710
|
def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
|
616
711
|
"""Logs a string value.
|
617
712
|
|
xax/task/loggers/stdout.py
CHANGED
@@ -33,7 +33,7 @@ class StdoutLogger(LoggerImpl):
|
|
33
33
|
self,
|
34
34
|
write_fp: TextIO = sys.stdout,
|
35
35
|
precision: int = 4,
|
36
|
-
log_timers: bool =
|
36
|
+
log_timers: bool = True,
|
37
37
|
log_perf: bool = False,
|
38
38
|
log_optim: bool = False,
|
39
39
|
log_fp: bool = False,
|
@@ -98,7 +98,7 @@ class StdoutLogger(LoggerImpl):
|
|
98
98
|
|
99
99
|
def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
|
100
100
|
for namespace, values in log.items():
|
101
|
-
if not self.log_timers and namespace.startswith("
|
101
|
+
if not self.log_timers and namespace.startswith("⌛"):
|
102
102
|
continue
|
103
103
|
if not self.log_perf and namespace.startswith("🔧"):
|
104
104
|
continue
|
xax/task/loggers/tensorboard.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1
1
|
"""Defines a Tensorboard logger backend."""
|
2
2
|
|
3
3
|
import atexit
|
4
|
-
import functools
|
5
4
|
import logging
|
6
5
|
import os
|
7
6
|
import re
|
8
|
-
import shutil
|
9
7
|
import subprocess
|
10
8
|
import threading
|
11
9
|
import time
|
@@ -140,15 +138,6 @@ class TensorboardLogger(LoggerImpl):
|
|
140
138
|
def __del__(self) -> None:
|
141
139
|
self.cleanup()
|
142
140
|
|
143
|
-
@functools.lru_cache(None) # Avoid clearing logs multiple times.
|
144
|
-
def clear_logs(self) -> None:
|
145
|
-
if not self.log_directory.exists():
|
146
|
-
return
|
147
|
-
if not any(child.is_dir() for child in self.log_directory.iterdir()):
|
148
|
-
return
|
149
|
-
logger.warning("Clearing TensorBoard logs")
|
150
|
-
shutil.rmtree(self.log_directory)
|
151
|
-
|
152
141
|
def get_writer(self, phase: Phase) -> TensorboardWriter:
|
153
142
|
self._start()
|
154
143
|
return self.writers.writer(phase)
|
@@ -162,9 +151,6 @@ class TensorboardLogger(LoggerImpl):
|
|
162
151
|
if not is_master():
|
163
152
|
return
|
164
153
|
|
165
|
-
if line.state.num_steps == 0:
|
166
|
-
self.clear_logs()
|
167
|
-
|
168
154
|
writer = self.get_writer(line.state.phase)
|
169
155
|
walltime = line.state.start_time_s + line.state.elapsed_time_s
|
170
156
|
|
@@ -177,6 +163,31 @@ class TensorboardLogger(LoggerImpl):
|
|
177
163
|
walltime=walltime,
|
178
164
|
)
|
179
165
|
|
166
|
+
for namespace, distributions in line.distributions.items():
|
167
|
+
for distribution_key, distribution_value in distributions.items():
|
168
|
+
writer.add_gaussian_distribution(
|
169
|
+
f"{namespace}/{distribution_key}",
|
170
|
+
mean=float(distribution_value.mean),
|
171
|
+
std=float(distribution_value.std),
|
172
|
+
global_step=line.state.num_steps,
|
173
|
+
walltime=walltime,
|
174
|
+
)
|
175
|
+
|
176
|
+
for namespace, histograms in line.histograms.items():
|
177
|
+
for histogram_key, histogram_value in histograms.items():
|
178
|
+
writer.add_histogram_raw(
|
179
|
+
f"{namespace}/{histogram_key}",
|
180
|
+
min=float(histogram_value.min),
|
181
|
+
max=float(histogram_value.max),
|
182
|
+
num=int(histogram_value.num),
|
183
|
+
sum=float(histogram_value.sum),
|
184
|
+
sum_squares=float(histogram_value.sum_squares),
|
185
|
+
bucket_limits=[float(x) for x in histogram_value.bucket_limits],
|
186
|
+
bucket_counts=[int(x) for x in histogram_value.bucket_counts],
|
187
|
+
global_step=line.state.num_steps,
|
188
|
+
walltime=walltime,
|
189
|
+
)
|
190
|
+
|
180
191
|
for namespace, strings in line.strings.items():
|
181
192
|
for string_key, string_value in strings.items():
|
182
193
|
writer.add_text(
|
xax/task/mixins/artifacts.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3
3
|
import functools
|
4
4
|
import inspect
|
5
5
|
import logging
|
6
|
-
import os
|
7
6
|
from dataclasses import dataclass
|
8
7
|
from pathlib import Path
|
9
8
|
from typing import Self, TypeVar
|
@@ -54,20 +53,6 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
54
53
|
self._exp_dir = Path(exp_dir).expanduser().resolve()
|
55
54
|
return self
|
56
55
|
|
57
|
-
def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
|
58
|
-
if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
|
59
|
-
if not exists_ok:
|
60
|
-
raise RuntimeError(f"Lock file already exists at {lock_file}")
|
61
|
-
else:
|
62
|
-
with open(lock_file, "w", encoding="utf-8") as f:
|
63
|
-
f.write(f"PID: {os.getpid()}")
|
64
|
-
|
65
|
-
def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
|
66
|
-
if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
|
67
|
-
lock_file.unlink()
|
68
|
-
elif not missing_ok:
|
69
|
-
raise RuntimeError(f"Lock file not found at {lock_file}")
|
70
|
-
|
71
56
|
def get_exp_dir(self) -> Path:
|
72
57
|
if self._exp_dir is not None:
|
73
58
|
return self._exp_dir
|
@@ -82,13 +67,8 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
82
67
|
def get_exp_dir(run_id: int) -> Path:
|
83
68
|
return self.run_dir / f"run_{run_id}"
|
84
69
|
|
85
|
-
def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
|
86
|
-
if lock_type is not None:
|
87
|
-
return (exp_dir / f".lock_{lock_type}").exists()
|
88
|
-
return any(exp_dir.glob(".lock_*"))
|
89
|
-
|
90
70
|
run_id = 0
|
91
|
-
while (exp_dir := get_exp_dir(run_id)).is_dir()
|
71
|
+
while (exp_dir := get_exp_dir(run_id)).is_dir():
|
92
72
|
run_id += 1
|
93
73
|
exp_dir.mkdir(exist_ok=True, parents=True)
|
94
74
|
self._exp_dir = exp_dir.expanduser().resolve()
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -21,7 +21,7 @@ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
|
|
21
21
|
|
22
22
|
logger = logging.getLogger(__name__)
|
23
23
|
|
24
|
-
CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
|
24
|
+
CheckpointPart = Literal["model", "opt", "opt_state", "state", "config", "model_state_config", "all"]
|
25
25
|
|
26
26
|
|
27
27
|
def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
|
@@ -88,8 +88,16 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
88
88
|
def load_checkpoint(
|
89
89
|
self,
|
90
90
|
path: Path,
|
91
|
+
part: Literal["all"] = "all",
|
91
92
|
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
|
92
93
|
|
94
|
+
# No default for ``part`` here: a call that omits ``part`` must resolve to the
# "all" overload, which is also the default used by the implementation.
@overload
def load_checkpoint(
    self,
    path: Path,
    part: Literal["model_state_config"],
) -> tuple[PyTree, State, DictConfig]: ...
|
100
|
+
|
93
101
|
@overload
|
94
102
|
def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
|
95
103
|
|
@@ -108,15 +116,19 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
108
116
|
def load_checkpoint(
|
109
117
|
self,
|
110
118
|
path: Path,
|
111
|
-
part: CheckpointPart
|
119
|
+
part: CheckpointPart = "all",
|
112
120
|
) -> (
|
113
121
|
tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
|
122
|
+
| tuple[PyTree, State, DictConfig]
|
114
123
|
| PyTree
|
115
124
|
| optax.GradientTransformation
|
116
125
|
| optax.OptState
|
117
126
|
| State
|
118
127
|
| DictConfig
|
119
128
|
):
|
129
|
+
# Calls the base callback.
|
130
|
+
self.on_before_checkpoint_load(path)
|
131
|
+
|
120
132
|
with tarfile.open(path, "r:gz") as tar:
|
121
133
|
|
122
134
|
def get_model() -> PyTree:
|
@@ -155,7 +167,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
155
167
|
return get_state()
|
156
168
|
case "config":
|
157
169
|
return get_config()
|
158
|
-
case
|
170
|
+
case "model_state_config":
|
171
|
+
return get_model(), get_state(), get_config()
|
172
|
+
case "all":
|
159
173
|
return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
|
160
174
|
case _:
|
161
175
|
raise ValueError(f"Invalid checkpoint part: {part}")
|
@@ -215,7 +229,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
215
229
|
except FileExistsError:
|
216
230
|
logger.exception("Exception while trying to update %s", ckpt_path)
|
217
231
|
|
218
|
-
#
|
219
|
-
self.
|
232
|
+
# Calls the base callback.
|
233
|
+
self.on_after_checkpoint_save(ckpt_path, state)
|
220
234
|
|
221
235
|
return ckpt_path
|
xax/task/mixins/logger.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Generic, Self, TypeVar
|
|
8
8
|
|
9
9
|
import jax
|
10
10
|
|
11
|
+
from xax.core.conf import field
|
11
12
|
from xax.core.state import State
|
12
13
|
from xax.task.base import BaseConfig, BaseTask
|
13
14
|
from xax.task.logger import Logger, LoggerImpl
|
@@ -22,7 +23,14 @@ from xax.utils.text import is_interactive_session
|
|
22
23
|
@jax.tree_util.register_dataclass
|
23
24
|
@dataclass
|
24
25
|
class LoggerConfig(BaseConfig):
|
25
|
-
|
26
|
+
log_interval_seconds: float = field(
|
27
|
+
value=1.0,
|
28
|
+
help="The interval between successive log lines.",
|
29
|
+
)
|
30
|
+
tensorboard_log_interval_seconds: float = field(
|
31
|
+
value=10.0,
|
32
|
+
help="The interval between successive Tensorboard log lines.",
|
33
|
+
)
|
26
34
|
|
27
35
|
|
28
36
|
Config = TypeVar("Config", bound=LoggerConfig)
|
@@ -49,11 +57,27 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
|
|
49
57
|
self.logger.add_logger(*logger)
|
50
58
|
|
51
59
|
def set_loggers(self) -> None:
|
52
|
-
self.add_logger(
|
60
|
+
self.add_logger(
|
61
|
+
StdoutLogger(
|
62
|
+
log_interval_seconds=self.config.log_interval_seconds,
|
63
|
+
)
|
64
|
+
if is_interactive_session()
|
65
|
+
else JsonLogger(
|
66
|
+
log_interval_seconds=self.config.log_interval_seconds,
|
67
|
+
)
|
68
|
+
)
|
69
|
+
|
70
|
+
# If this is also an ArtifactsMixin, we should default add some
|
71
|
+
# additional loggers which log data to the artifacts directory.
|
53
72
|
if isinstance(self, ArtifactsMixin):
|
54
73
|
self.add_logger(
|
55
|
-
StateLogger(
|
56
|
-
|
74
|
+
StateLogger(
|
75
|
+
run_directory=self.exp_dir,
|
76
|
+
),
|
77
|
+
TensorboardLogger(
|
78
|
+
run_directory=self.exp_dir,
|
79
|
+
log_interval_seconds=self.config.tensorboard_log_interval_seconds,
|
80
|
+
),
|
57
81
|
)
|
58
82
|
|
59
83
|
def write_logs(self, state: State) -> None:
|
xax/task/mixins/step_wrapper.py
CHANGED
@@ -1,53 +1,39 @@
|
|
1
1
|
"""Defines a mixin to wrap some steps in a context manager."""
|
2
2
|
|
3
|
+
import time
|
3
4
|
from dataclasses import dataclass
|
4
5
|
from types import TracebackType
|
5
|
-
from typing import
|
6
|
+
from typing import Callable, ContextManager, TypeVar
|
6
7
|
|
7
|
-
import equinox as eqx
|
8
8
|
import jax
|
9
9
|
|
10
10
|
from xax.task.base import BaseConfig, BaseTask
|
11
11
|
|
12
|
-
StepType = Literal[
|
13
|
-
"backward",
|
14
|
-
"change_mode",
|
15
|
-
"clip_grads",
|
16
|
-
"create_optimizers",
|
17
|
-
"forward",
|
18
|
-
"get_dataloader",
|
19
|
-
"get_dataset",
|
20
|
-
"get_prefetcher",
|
21
|
-
"get_model",
|
22
|
-
"get_optimizer",
|
23
|
-
"get_initial_opt_state",
|
24
|
-
"get_update_fn",
|
25
|
-
"load_checkpoint",
|
26
|
-
"log_losses",
|
27
|
-
"model_to_device",
|
28
|
-
"on_step_end",
|
29
|
-
"on_step_start",
|
30
|
-
"save_checkpoint",
|
31
|
-
"step",
|
32
|
-
"update_state",
|
33
|
-
"write_logs",
|
34
|
-
"zero_grads",
|
35
|
-
]
|
36
|
-
|
37
12
|
|
38
13
|
class StepContext(ContextManager):
|
39
14
|
"""Context manager to get the current step type."""
|
40
15
|
|
41
|
-
CURRENT_STEP:
|
16
|
+
CURRENT_STEP: str | None = None
|
42
17
|
|
43
|
-
def __init__(
|
18
|
+
def __init__(
|
19
|
+
self,
|
20
|
+
step: str,
|
21
|
+
on_context_start: Callable[[str], None],
|
22
|
+
on_context_end: Callable[[str, float], None],
|
23
|
+
) -> None:
|
44
24
|
self.step = step
|
25
|
+
self.start_time = 0.0
|
26
|
+
self.on_context_start = on_context_start
|
27
|
+
self.on_context_end = on_context_end
|
45
28
|
|
46
29
|
def __enter__(self) -> None:
|
47
30
|
StepContext.CURRENT_STEP = self.step
|
31
|
+
self.start_time = time.time()
|
32
|
+
self.on_context_start(self.step)
|
48
33
|
|
49
34
|
def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
|
50
35
|
StepContext.CURRENT_STEP = None
|
36
|
+
self.on_context_end(self.step, time.time() - self.start_time)
|
51
37
|
|
52
38
|
|
53
39
|
@jax.tree_util.register_dataclass
|
@@ -63,6 +49,11 @@ class StepContextMixin(BaseTask[Config]):
|
|
63
49
|
def __init__(self, config: Config) -> None:
|
64
50
|
super().__init__(config)
|
65
51
|
|
66
|
-
|
67
|
-
|
68
|
-
|
52
|
+
def step_context(self, step: str) -> ContextManager:
|
53
|
+
return StepContext(step, self.on_context_start, self.on_context_stop)
|
54
|
+
|
55
|
+
def on_context_start(self, step: str) -> None:
|
56
|
+
pass
|
57
|
+
|
58
|
+
def on_context_stop(self, step: str, elapsed_time: float) -> None:
|
59
|
+
pass
|