xax 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +94 -4
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +26 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +220 -44
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -35
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jaxpr.py +77 -0
- xax/utils/logging.py +12 -2
- xax/utils/pytree.py +189 -1
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/METADATA +23 -4
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/RECORD +27 -22
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/WHEEL +1 -1
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/top_level.txt +0 -0
xax/task/logger.py
CHANGED
@@ -18,7 +18,7 @@ from abc import ABC, abstractmethod
|
|
18
18
|
from collections import defaultdict
|
19
19
|
from dataclasses import dataclass
|
20
20
|
from types import TracebackType
|
21
|
-
from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
|
21
|
+
from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, cast, get_args
|
22
22
|
|
23
23
|
import jax
|
24
24
|
import jax.numpy as jnp
|
@@ -28,7 +28,7 @@ from PIL import Image, ImageDraw, ImageFont
|
|
28
28
|
from PIL.Image import Image as PILImage
|
29
29
|
|
30
30
|
from xax.core.state import Phase, State
|
31
|
-
from xax.utils.experiments import IntervalTicker
|
31
|
+
from xax.utils.experiments import ContextTimer, IntervalTicker
|
32
32
|
from xax.utils.logging import LOG_ERROR_SUMMARY, LOG_PING, LOG_STATUS
|
33
33
|
|
34
34
|
logger = logging.getLogger(__name__)
|
@@ -223,10 +223,29 @@ class LogVideo:
|
|
223
223
|
fps: int
|
224
224
|
|
225
225
|
|
226
|
+
@dataclass(kw_only=True)
|
227
|
+
class LogDistribution:
|
228
|
+
mean: Number
|
229
|
+
std: Number
|
230
|
+
|
231
|
+
|
232
|
+
@dataclass(kw_only=True)
|
233
|
+
class LogHistogram:
|
234
|
+
min: Number
|
235
|
+
max: Number
|
236
|
+
num: int
|
237
|
+
sum: Number
|
238
|
+
sum_squares: Number
|
239
|
+
bucket_limits: list[float]
|
240
|
+
bucket_counts: list[int]
|
241
|
+
|
242
|
+
|
226
243
|
@dataclass(kw_only=True)
|
227
244
|
class LogLine:
|
228
245
|
state: State
|
229
246
|
scalars: dict[str, dict[str, Number]]
|
247
|
+
distributions: dict[str, dict[str, LogDistribution]]
|
248
|
+
histograms: dict[str, dict[str, LogHistogram]]
|
230
249
|
strings: dict[str, dict[str, str]]
|
231
250
|
images: dict[str, dict[str, LogImage]]
|
232
251
|
videos: dict[str, dict[str, LogVideo]]
|
@@ -329,9 +348,9 @@ def image_with_text(
|
|
329
348
|
else:
|
330
349
|
text = text[:max_num_lines]
|
331
350
|
width, height = image.size
|
332
|
-
font: ImageFont.ImageFont = ImageFont.load_default()
|
351
|
+
font: ImageFont.ImageFont | ImageFont.FreeTypeFont = ImageFont.load_default()
|
333
352
|
_, _, _, line_height = font.getbbox(text[0])
|
334
|
-
new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
|
353
|
+
new_width, new_height = width, int(height + line_spacing + max_num_lines * (line_height + line_spacing))
|
335
354
|
padded_image = Image.new(image.mode, (new_width, new_height), 255)
|
336
355
|
padded_image.paste(image, (0, 0))
|
337
356
|
drawer = ImageDraw.Draw(padded_image)
|
@@ -497,6 +516,8 @@ class Logger:
|
|
497
516
|
|
498
517
|
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
|
499
518
|
self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
|
519
|
+
self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
|
520
|
+
self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
|
500
521
|
self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
|
501
522
|
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
502
523
|
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
@@ -522,6 +543,8 @@ class Logger:
|
|
522
543
|
return LogLine(
|
523
544
|
state=state,
|
524
545
|
scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
|
546
|
+
distributions={k: {kk: v() for kk, v in v.items()} for k, v in self.distributions.items()},
|
547
|
+
histograms={k: {kk: v() for kk, v in v.items()} for k, v in self.histograms.items()},
|
525
548
|
strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
|
526
549
|
images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
|
527
550
|
videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
|
@@ -529,6 +552,8 @@ class Logger:
|
|
529
552
|
|
530
553
|
def clear(self) -> None:
|
531
554
|
self.scalars.clear()
|
555
|
+
self.distributions.clear()
|
556
|
+
self.histograms.clear()
|
532
557
|
self.strings.clear()
|
533
558
|
self.images.clear()
|
534
559
|
self.videos.clear()
|
@@ -539,14 +564,14 @@ class Logger:
|
|
539
564
|
Args:
|
540
565
|
state: The current step's state.
|
541
566
|
"""
|
542
|
-
should_log = [
|
567
|
+
should_log = [lg.should_log(state) for lg in self.loggers]
|
543
568
|
if not any(should_log):
|
544
569
|
self.clear()
|
545
570
|
return
|
546
571
|
line = self.pack(state)
|
547
572
|
self.clear()
|
548
|
-
for
|
549
|
-
|
573
|
+
for lg in (lg for lg, should_log in zip(self.loggers, should_log) if should_log):
|
574
|
+
lg.write(line)
|
550
575
|
|
551
576
|
def write_error_summary(self, error_summary: str) -> None:
|
552
577
|
for logger in self.loggers:
|
@@ -608,10 +633,143 @@ class Logger:
|
|
608
633
|
|
609
634
|
@functools.lru_cache(maxsize=None)
|
610
635
|
def scalar_future() -> Number:
|
611
|
-
|
636
|
+
with ContextTimer() as timer:
|
637
|
+
value_concrete = value() if callable(value) else value
|
638
|
+
logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
|
639
|
+
return value_concrete
|
612
640
|
|
613
641
|
self.scalars[namespace][key] = scalar_future
|
614
642
|
|
643
|
+
def log_distribution(
|
644
|
+
self,
|
645
|
+
key: str,
|
646
|
+
value: Callable[[], tuple[Number, Number]] | tuple[Number, Number],
|
647
|
+
*,
|
648
|
+
namespace: str | None = None,
|
649
|
+
) -> None:
|
650
|
+
"""Logs a distribution value.
|
651
|
+
|
652
|
+
Args:
|
653
|
+
key: The key being logged
|
654
|
+
value: The distribution value being logged, a tuple of (mean, std)
|
655
|
+
namespace: An optional logging namespace
|
656
|
+
"""
|
657
|
+
if not self.active:
|
658
|
+
raise RuntimeError("The logger is not active")
|
659
|
+
namespace = self.resolve_namespace(namespace)
|
660
|
+
|
661
|
+
@functools.lru_cache(maxsize=None)
|
662
|
+
def distribution_future() -> LogDistribution:
|
663
|
+
with ContextTimer() as timer:
|
664
|
+
mean, std = value() if callable(value) else value
|
665
|
+
logger.debug("Distribution Key: %s, Time: %s", key, timer.elapsed_time)
|
666
|
+
return LogDistribution(mean=mean, std=std)
|
667
|
+
|
668
|
+
self.distributions[namespace][key] = distribution_future
|
669
|
+
|
670
|
+
def log_histogram(
|
671
|
+
self,
|
672
|
+
key: str,
|
673
|
+
value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
|
674
|
+
*,
|
675
|
+
bins: int = 100,
|
676
|
+
namespace: str | None = None,
|
677
|
+
) -> None:
|
678
|
+
"""Logs a histogram value.
|
679
|
+
|
680
|
+
Args:
|
681
|
+
key: The key being logged
|
682
|
+
value: The histogram value being logged
|
683
|
+
bins: The number of bins to use for the histogram
|
684
|
+
namespace: An optional logging namespace
|
685
|
+
"""
|
686
|
+
if not self.active:
|
687
|
+
raise RuntimeError("The logger is not active")
|
688
|
+
namespace = self.resolve_namespace(namespace)
|
689
|
+
|
690
|
+
@functools.lru_cache(maxsize=None)
|
691
|
+
def histogram_future() -> LogHistogram:
|
692
|
+
with ContextTimer() as timer:
|
693
|
+
values = value() if callable(value) else value
|
694
|
+
values = values.reshape(-1) # Must be flat.
|
695
|
+
|
696
|
+
if isinstance(values, Array):
|
697
|
+
counts, limits = jnp.histogram(values, bins=bins)
|
698
|
+
counts, limits = as_numpy(counts), as_numpy(limits)
|
699
|
+
elif isinstance(values, np.ndarray):
|
700
|
+
counts, limits = np.histogram(values, bins=bins)
|
701
|
+
else:
|
702
|
+
raise ValueError(f"Unsupported histogram type: {type(values)}")
|
703
|
+
|
704
|
+
histogram_values = LogHistogram(
|
705
|
+
min=float(values.min()),
|
706
|
+
max=float(values.max()),
|
707
|
+
num=int(values.size),
|
708
|
+
sum=float(values.sum()),
|
709
|
+
sum_squares=float(values.dot(values)),
|
710
|
+
bucket_limits=cast(list[float], limits[1:].tolist()),
|
711
|
+
bucket_counts=cast(list[int], counts.tolist()),
|
712
|
+
)
|
713
|
+
|
714
|
+
logger.debug("Histogram Key: %s, Time: %s", key, timer.elapsed_time)
|
715
|
+
return histogram_values
|
716
|
+
|
717
|
+
self.histograms[namespace][key] = histogram_future
|
718
|
+
|
719
|
+
def log_histogram_raw(
|
720
|
+
self,
|
721
|
+
key: str,
|
722
|
+
counts: Array | np.ndarray,
|
723
|
+
limits: Array | np.ndarray,
|
724
|
+
minv: Number | None = None,
|
725
|
+
maxv: Number | None = None,
|
726
|
+
sumv: Number | None = None,
|
727
|
+
sum_squaresv: Number | None = None,
|
728
|
+
*,
|
729
|
+
namespace: str | None = None,
|
730
|
+
) -> None:
|
731
|
+
"""Logs a histogram from raw counts and limits.
|
732
|
+
|
733
|
+
Args:
|
734
|
+
key: The key being logged
|
735
|
+
counts: The counts of the histogram
|
736
|
+
limits: The limits of the histogram
|
737
|
+
minv: The minimum value of the histogram
|
738
|
+
maxv: The maximum value of the histogram
|
739
|
+
sumv: The sum of the histogram
|
740
|
+
sum_squaresv: The sum of the squares of the histogram
|
741
|
+
namespace: An optional logging namespace
|
742
|
+
"""
|
743
|
+
if not self.active:
|
744
|
+
raise RuntimeError("The logger is not active")
|
745
|
+
namespace = self.resolve_namespace(namespace)
|
746
|
+
|
747
|
+
@functools.lru_cache(maxsize=None)
|
748
|
+
def histogram_future() -> LogHistogram:
|
749
|
+
with ContextTimer() as timer:
|
750
|
+
counts_np = (as_numpy(counts) if isinstance(counts, Array) else counts).astype(int)
|
751
|
+
limits_np = (as_numpy(limits) if isinstance(limits, Array) else limits).astype(float)
|
752
|
+
|
753
|
+
minv_ = counts_np.min() if minv is None else minv
|
754
|
+
maxv_ = counts_np.max() if maxv is None else maxv
|
755
|
+
sumv_ = counts_np.sum() if sumv is None else sumv
|
756
|
+
sum_squaresv_ = counts_np.dot(counts_np) if sum_squaresv is None else sum_squaresv
|
757
|
+
|
758
|
+
histogram_values = LogHistogram(
|
759
|
+
min=float(minv_),
|
760
|
+
max=float(maxv_),
|
761
|
+
num=int(counts_np.size),
|
762
|
+
sum=float(sumv_),
|
763
|
+
sum_squares=float(sum_squaresv_),
|
764
|
+
bucket_limits=cast(list[float], limits_np.tolist()),
|
765
|
+
bucket_counts=cast(list[int], counts_np.tolist()),
|
766
|
+
)
|
767
|
+
|
768
|
+
logger.debug("Raw Histogram Key: %s, Time: %s", key, timer.elapsed_time)
|
769
|
+
return histogram_values
|
770
|
+
|
771
|
+
self.histograms[namespace][key] = histogram_future
|
772
|
+
|
615
773
|
def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
|
616
774
|
"""Logs a string value.
|
617
775
|
|
@@ -653,7 +811,10 @@ class Logger:
|
|
653
811
|
|
654
812
|
@functools.lru_cache(maxsize=None)
|
655
813
|
def image_future() -> LogImage:
|
656
|
-
|
814
|
+
with ContextTimer() as timer:
|
815
|
+
image = get_image(value() if callable(value) else value, target_resolution)
|
816
|
+
logger.debug("Image Key: %s, Time: %s", key, timer.elapsed_time)
|
817
|
+
return image
|
657
818
|
|
658
819
|
self.images[namespace][key] = image_future
|
659
820
|
|
@@ -689,15 +850,20 @@ class Logger:
|
|
689
850
|
|
690
851
|
@functools.lru_cache(maxsize=None)
|
691
852
|
def image_future() -> LogImage:
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
853
|
+
with ContextTimer() as timer:
|
854
|
+
image, label = value() if callable(value) else value
|
855
|
+
image = get_image(image, target_resolution)
|
856
|
+
|
857
|
+
image_value = image_with_text(
|
858
|
+
image.image,
|
859
|
+
standardize_text(label, max_line_length),
|
860
|
+
max_num_lines=max_num_lines,
|
861
|
+
line_spacing=line_spacing,
|
862
|
+
centered=centered,
|
863
|
+
)
|
864
|
+
|
865
|
+
logger.debug("Labeled Image Key: %s, Time: %s", key, timer.elapsed_time)
|
866
|
+
return image_value
|
701
867
|
|
702
868
|
self.images[namespace][key] = image_future
|
703
869
|
|
@@ -736,15 +902,18 @@ class Logger:
|
|
736
902
|
|
737
903
|
@functools.lru_cache(maxsize=None)
|
738
904
|
def images_future() -> LogImage:
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
905
|
+
with ContextTimer() as timer:
|
906
|
+
images = value() if callable(value) else value
|
907
|
+
if max_images is not None:
|
908
|
+
images = images[:max_images]
|
909
|
+
if isinstance(images, Array):
|
910
|
+
images = as_numpy(images)
|
911
|
+
if isinstance(images, Sequence):
|
912
|
+
images = list(images)
|
913
|
+
images = [get_image(image, target_resolution) for image in images]
|
914
|
+
tiled = tile_images([img.image for img in images], sep)
|
915
|
+
|
916
|
+
logger.debug("Images Key: %s, Time: %s", key, timer.elapsed_time)
|
748
917
|
return LogImage(image=tiled)
|
749
918
|
|
750
919
|
self.images[namespace][key] = images_future
|
@@ -791,22 +960,25 @@ class Logger:
|
|
791
960
|
|
792
961
|
@functools.lru_cache(maxsize=None)
|
793
962
|
def images_future() -> LogImage:
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
963
|
+
with ContextTimer() as timer:
|
964
|
+
images, labels = value() if callable(value) else value
|
965
|
+
if max_images is not None:
|
966
|
+
images = images[:max_images]
|
967
|
+
labels = labels[:max_images]
|
968
|
+
images = [get_image(image, target_resolution) for image in images]
|
969
|
+
labeled = [
|
970
|
+
image_with_text(
|
971
|
+
img.image,
|
972
|
+
standardize_text(label, max_line_length),
|
973
|
+
max_num_lines=max_num_lines,
|
974
|
+
line_spacing=line_spacing,
|
975
|
+
centered=centered,
|
976
|
+
)
|
977
|
+
for img, label in zip(images, labels)
|
978
|
+
]
|
979
|
+
tiled = tile_images([img.image for img in labeled], sep)
|
980
|
+
|
981
|
+
logger.debug("Labeled Images Key: %s, Time: %s", key, timer.elapsed_time)
|
810
982
|
return LogImage(image=tiled)
|
811
983
|
|
812
984
|
self.images[namespace][key] = images_future
|
@@ -839,7 +1011,11 @@ class Logger:
|
|
839
1011
|
|
840
1012
|
@functools.lru_cache(maxsize=None)
|
841
1013
|
def video_future() -> LogVideo:
|
842
|
-
|
1014
|
+
with ContextTimer() as timer:
|
1015
|
+
video = get_video(value() if callable(value) else value, fps=fps)
|
1016
|
+
|
1017
|
+
logger.debug("Video Key: %s, Time: %s", key, timer.elapsed_time)
|
1018
|
+
return video
|
843
1019
|
|
844
1020
|
self.videos[namespace][key] = video_future
|
845
1021
|
|
xax/task/loggers/stdout.py
CHANGED
@@ -33,7 +33,7 @@ class StdoutLogger(LoggerImpl):
|
|
33
33
|
self,
|
34
34
|
write_fp: TextIO = sys.stdout,
|
35
35
|
precision: int = 4,
|
36
|
-
log_timers: bool =
|
36
|
+
log_timers: bool = True,
|
37
37
|
log_perf: bool = False,
|
38
38
|
log_optim: bool = False,
|
39
39
|
log_fp: bool = False,
|
@@ -98,7 +98,7 @@ class StdoutLogger(LoggerImpl):
|
|
98
98
|
|
99
99
|
def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
|
100
100
|
for namespace, values in log.items():
|
101
|
-
if not self.log_timers and namespace.startswith("
|
101
|
+
if not self.log_timers and namespace.startswith("⌛"):
|
102
102
|
continue
|
103
103
|
if not self.log_perf and namespace.startswith("🔧"):
|
104
104
|
continue
|
xax/task/loggers/tensorboard.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1
1
|
"""Defines a Tensorboard logger backend."""
|
2
2
|
|
3
3
|
import atexit
|
4
|
-
import functools
|
5
4
|
import logging
|
6
5
|
import os
|
7
6
|
import re
|
8
|
-
import shutil
|
9
7
|
import subprocess
|
10
8
|
import threading
|
11
9
|
import time
|
@@ -140,15 +138,6 @@ class TensorboardLogger(LoggerImpl):
|
|
140
138
|
def __del__(self) -> None:
|
141
139
|
self.cleanup()
|
142
140
|
|
143
|
-
@functools.lru_cache(None) # Avoid clearing logs multiple times.
|
144
|
-
def clear_logs(self) -> None:
|
145
|
-
if not self.log_directory.exists():
|
146
|
-
return
|
147
|
-
if not any(child.is_dir() for child in self.log_directory.iterdir()):
|
148
|
-
return
|
149
|
-
logger.warning("Clearing TensorBoard logs")
|
150
|
-
shutil.rmtree(self.log_directory)
|
151
|
-
|
152
141
|
def get_writer(self, phase: Phase) -> TensorboardWriter:
|
153
142
|
self._start()
|
154
143
|
return self.writers.writer(phase)
|
@@ -162,9 +151,6 @@ class TensorboardLogger(LoggerImpl):
|
|
162
151
|
if not is_master():
|
163
152
|
return
|
164
153
|
|
165
|
-
if line.state.num_steps == 0:
|
166
|
-
self.clear_logs()
|
167
|
-
|
168
154
|
writer = self.get_writer(line.state.phase)
|
169
155
|
walltime = line.state.start_time_s + line.state.elapsed_time_s
|
170
156
|
|
@@ -177,6 +163,31 @@ class TensorboardLogger(LoggerImpl):
|
|
177
163
|
walltime=walltime,
|
178
164
|
)
|
179
165
|
|
166
|
+
for namespace, distributions in line.distributions.items():
|
167
|
+
for distribution_key, distribution_value in distributions.items():
|
168
|
+
writer.add_gaussian_distribution(
|
169
|
+
f"{namespace}/{distribution_key}",
|
170
|
+
mean=float(distribution_value.mean),
|
171
|
+
std=float(distribution_value.std),
|
172
|
+
global_step=line.state.num_steps,
|
173
|
+
walltime=walltime,
|
174
|
+
)
|
175
|
+
|
176
|
+
for namespace, histograms in line.histograms.items():
|
177
|
+
for histogram_key, histogram_value in histograms.items():
|
178
|
+
writer.add_histogram_raw(
|
179
|
+
f"{namespace}/{histogram_key}",
|
180
|
+
min=float(histogram_value.min),
|
181
|
+
max=float(histogram_value.max),
|
182
|
+
num=int(histogram_value.num),
|
183
|
+
sum=float(histogram_value.sum),
|
184
|
+
sum_squares=float(histogram_value.sum_squares),
|
185
|
+
bucket_limits=[float(x) for x in histogram_value.bucket_limits],
|
186
|
+
bucket_counts=[int(x) for x in histogram_value.bucket_counts],
|
187
|
+
global_step=line.state.num_steps,
|
188
|
+
walltime=walltime,
|
189
|
+
)
|
190
|
+
|
180
191
|
for namespace, strings in line.strings.items():
|
181
192
|
for string_key, string_value in strings.items():
|
182
193
|
writer.add_text(
|
xax/task/mixins/artifacts.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3
3
|
import functools
|
4
4
|
import inspect
|
5
5
|
import logging
|
6
|
-
import os
|
7
6
|
from dataclasses import dataclass
|
8
7
|
from pathlib import Path
|
9
8
|
from typing import Self, TypeVar
|
@@ -54,20 +53,6 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
54
53
|
self._exp_dir = Path(exp_dir).expanduser().resolve()
|
55
54
|
return self
|
56
55
|
|
57
|
-
def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
|
58
|
-
if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
|
59
|
-
if not exists_ok:
|
60
|
-
raise RuntimeError(f"Lock file already exists at {lock_file}")
|
61
|
-
else:
|
62
|
-
with open(lock_file, "w", encoding="utf-8") as f:
|
63
|
-
f.write(f"PID: {os.getpid()}")
|
64
|
-
|
65
|
-
def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
|
66
|
-
if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
|
67
|
-
lock_file.unlink()
|
68
|
-
elif not missing_ok:
|
69
|
-
raise RuntimeError(f"Lock file not found at {lock_file}")
|
70
|
-
|
71
56
|
def get_exp_dir(self) -> Path:
|
72
57
|
if self._exp_dir is not None:
|
73
58
|
return self._exp_dir
|
@@ -82,13 +67,8 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
82
67
|
def get_exp_dir(run_id: int) -> Path:
|
83
68
|
return self.run_dir / f"run_{run_id}"
|
84
69
|
|
85
|
-
def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
|
86
|
-
if lock_type is not None:
|
87
|
-
return (exp_dir / f".lock_{lock_type}").exists()
|
88
|
-
return any(exp_dir.glob(".lock_*"))
|
89
|
-
|
90
70
|
run_id = 0
|
91
|
-
while (exp_dir := get_exp_dir(run_id)).is_dir()
|
71
|
+
while (exp_dir := get_exp_dir(run_id)).is_dir():
|
92
72
|
run_id += 1
|
93
73
|
exp_dir.mkdir(exist_ok=True, parents=True)
|
94
74
|
self._exp_dir = exp_dir.expanduser().resolve()
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -21,7 +21,7 @@ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
|
|
21
21
|
|
22
22
|
logger = logging.getLogger(__name__)
|
23
23
|
|
24
|
-
CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
|
24
|
+
CheckpointPart = Literal["model", "opt", "opt_state", "state", "config", "model_state_config", "all"]
|
25
25
|
|
26
26
|
|
27
27
|
def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
|
@@ -88,8 +88,16 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
88
88
|
def load_checkpoint(
|
89
89
|
self,
|
90
90
|
path: Path,
|
91
|
+
part: Literal["all"] = "all",
|
91
92
|
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
|
92
93
|
|
94
|
+
@overload
|
95
|
+
def load_checkpoint(
|
96
|
+
self,
|
97
|
+
path: Path,
|
98
|
+
part: Literal["model_state_config"] = "model_state_config",
|
99
|
+
) -> tuple[PyTree, State, DictConfig]: ...
|
100
|
+
|
93
101
|
@overload
|
94
102
|
def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
|
95
103
|
|
@@ -108,15 +116,19 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
108
116
|
def load_checkpoint(
|
109
117
|
self,
|
110
118
|
path: Path,
|
111
|
-
part: CheckpointPart
|
119
|
+
part: CheckpointPart = "all",
|
112
120
|
) -> (
|
113
121
|
tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
|
122
|
+
| tuple[PyTree, State, DictConfig]
|
114
123
|
| PyTree
|
115
124
|
| optax.GradientTransformation
|
116
125
|
| optax.OptState
|
117
126
|
| State
|
118
127
|
| DictConfig
|
119
128
|
):
|
129
|
+
# Calls the base callback.
|
130
|
+
self.on_before_checkpoint_load(path)
|
131
|
+
|
120
132
|
with tarfile.open(path, "r:gz") as tar:
|
121
133
|
|
122
134
|
def get_model() -> PyTree:
|
@@ -155,7 +167,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
155
167
|
return get_state()
|
156
168
|
case "config":
|
157
169
|
return get_config()
|
158
|
-
case
|
170
|
+
case "model_state_config":
|
171
|
+
return get_model(), get_state(), get_config()
|
172
|
+
case "all":
|
159
173
|
return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
|
160
174
|
case _:
|
161
175
|
raise ValueError(f"Invalid checkpoint part: {part}")
|
@@ -215,7 +229,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
215
229
|
except FileExistsError:
|
216
230
|
logger.exception("Exception while trying to update %s", ckpt_path)
|
217
231
|
|
218
|
-
#
|
219
|
-
self.
|
232
|
+
# Calls the base callback.
|
233
|
+
self.on_after_checkpoint_save(ckpt_path, state)
|
220
234
|
|
221
235
|
return ckpt_path
|
xax/task/mixins/logger.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Generic, Self, TypeVar
|
|
8
8
|
|
9
9
|
import jax
|
10
10
|
|
11
|
+
from xax.core.conf import field
|
11
12
|
from xax.core.state import State
|
12
13
|
from xax.task.base import BaseConfig, BaseTask
|
13
14
|
from xax.task.logger import Logger, LoggerImpl
|
@@ -22,7 +23,14 @@ from xax.utils.text import is_interactive_session
|
|
22
23
|
@jax.tree_util.register_dataclass
|
23
24
|
@dataclass
|
24
25
|
class LoggerConfig(BaseConfig):
|
25
|
-
|
26
|
+
log_interval_seconds: float = field(
|
27
|
+
value=1.0,
|
28
|
+
help="The interval between successive log lines.",
|
29
|
+
)
|
30
|
+
tensorboard_log_interval_seconds: float = field(
|
31
|
+
value=10.0,
|
32
|
+
help="The interval between successive Tensorboard log lines.",
|
33
|
+
)
|
26
34
|
|
27
35
|
|
28
36
|
Config = TypeVar("Config", bound=LoggerConfig)
|
@@ -49,11 +57,27 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
|
|
49
57
|
self.logger.add_logger(*logger)
|
50
58
|
|
51
59
|
def set_loggers(self) -> None:
|
52
|
-
self.add_logger(
|
60
|
+
self.add_logger(
|
61
|
+
StdoutLogger(
|
62
|
+
log_interval_seconds=self.config.log_interval_seconds,
|
63
|
+
)
|
64
|
+
if is_interactive_session()
|
65
|
+
else JsonLogger(
|
66
|
+
log_interval_seconds=self.config.log_interval_seconds,
|
67
|
+
)
|
68
|
+
)
|
69
|
+
|
70
|
+
# If this is also an ArtifactsMixin, we should default add some
|
71
|
+
# additional loggers which log data to the artifacts directory.
|
53
72
|
if isinstance(self, ArtifactsMixin):
|
54
73
|
self.add_logger(
|
55
|
-
StateLogger(
|
56
|
-
|
74
|
+
StateLogger(
|
75
|
+
run_directory=self.exp_dir,
|
76
|
+
),
|
77
|
+
TensorboardLogger(
|
78
|
+
run_directory=self.exp_dir,
|
79
|
+
log_interval_seconds=self.config.tensorboard_log_interval_seconds,
|
80
|
+
),
|
57
81
|
)
|
58
82
|
|
59
83
|
def write_logs(self, state: State) -> None:
|