xax 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/logger.py CHANGED
@@ -18,7 +18,7 @@ from abc import ABC, abstractmethod
18
18
  from collections import defaultdict
19
19
  from dataclasses import dataclass
20
20
  from types import TracebackType
21
- from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
21
+ from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, cast, get_args
22
22
 
23
23
  import jax
24
24
  import jax.numpy as jnp
@@ -28,7 +28,7 @@ from PIL import Image, ImageDraw, ImageFont
28
28
  from PIL.Image import Image as PILImage
29
29
 
30
30
  from xax.core.state import Phase, State
31
- from xax.utils.experiments import IntervalTicker
31
+ from xax.utils.experiments import ContextTimer, IntervalTicker
32
32
  from xax.utils.logging import LOG_ERROR_SUMMARY, LOG_PING, LOG_STATUS
33
33
 
34
34
  logger = logging.getLogger(__name__)
@@ -223,10 +223,29 @@ class LogVideo:
223
223
  fps: int
224
224
 
225
225
 
226
@dataclass(kw_only=True)
class LogDistribution:
    """Distribution summary logged as a ``(mean, std)`` pair.

    Instances are built by ``Logger.log_distribution`` from the two-element
    tuple supplied by the caller.
    """

    mean: Number  # Mean of the logged distribution.
    std: Number  # Standard deviation of the logged distribution.
230
+
231
+
232
@dataclass(kw_only=True)
class LogHistogram:
    """Precomputed histogram: overall statistics plus per-bucket data.

    ``Logger.log_histogram`` fills the statistics from the flattened input
    values and stores one limit and one count per bucket.
    """

    min: Number  # Smallest summarized value.
    max: Number  # Largest summarized value.
    num: int  # Number of summarized values.
    sum: Number  # Sum of the summarized values.
    sum_squares: Number  # Sum of the squares of the summarized values.
    bucket_limits: list[float]  # Bucket limits (``log_histogram`` stores each bucket's upper edge).
    bucket_counts: list[int]  # Number of values counted in each bucket.
241
+
242
+
226
243
  @dataclass(kw_only=True)
227
244
  class LogLine:
228
245
  state: State
229
246
  scalars: dict[str, dict[str, Number]]
247
+ distributions: dict[str, dict[str, LogDistribution]]
248
+ histograms: dict[str, dict[str, LogHistogram]]
230
249
  strings: dict[str, dict[str, str]]
231
250
  images: dict[str, dict[str, LogImage]]
232
251
  videos: dict[str, dict[str, LogVideo]]
@@ -329,9 +348,9 @@ def image_with_text(
329
348
  else:
330
349
  text = text[:max_num_lines]
331
350
  width, height = image.size
332
- font: ImageFont.ImageFont = ImageFont.load_default()
351
+ font: ImageFont.ImageFont | ImageFont.FreeTypeFont = ImageFont.load_default()
333
352
  _, _, _, line_height = font.getbbox(text[0])
334
- new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
353
+ new_width, new_height = width, int(height + line_spacing + max_num_lines * (line_height + line_spacing))
335
354
  padded_image = Image.new(image.mode, (new_width, new_height), 255)
336
355
  padded_image.paste(image, (0, 0))
337
356
  drawer = ImageDraw.Draw(padded_image)
@@ -497,6 +516,8 @@ class Logger:
497
516
 
498
517
  def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
499
518
  self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
519
+ self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
520
+ self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
500
521
  self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
501
522
  self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
502
523
  self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
@@ -522,6 +543,8 @@ class Logger:
522
543
  return LogLine(
523
544
  state=state,
524
545
  scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
546
+ distributions={k: {kk: v() for kk, v in v.items()} for k, v in self.distributions.items()},
547
+ histograms={k: {kk: v() for kk, v in v.items()} for k, v in self.histograms.items()},
525
548
  strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
526
549
  images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
527
550
  videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
@@ -529,6 +552,8 @@ class Logger:
529
552
 
530
553
  def clear(self) -> None:
531
554
  self.scalars.clear()
555
+ self.distributions.clear()
556
+ self.histograms.clear()
532
557
  self.strings.clear()
533
558
  self.images.clear()
534
559
  self.videos.clear()
@@ -539,14 +564,14 @@ class Logger:
539
564
  Args:
540
565
  state: The current step's state.
541
566
  """
542
- should_log = [logger.should_log(state) for logger in self.loggers]
567
+ should_log = [lg.should_log(state) for lg in self.loggers]
543
568
  if not any(should_log):
544
569
  self.clear()
545
570
  return
546
571
  line = self.pack(state)
547
572
  self.clear()
548
- for logger in (logger for logger, should_log in zip(self.loggers, should_log) if should_log):
549
- logger.write(line)
573
+ for lg in (lg for lg, should_log in zip(self.loggers, should_log) if should_log):
574
+ lg.write(line)
550
575
 
551
576
  def write_error_summary(self, error_summary: str) -> None:
552
577
  for logger in self.loggers:
@@ -608,10 +633,143 @@ class Logger:
608
633
 
609
634
  @functools.lru_cache(maxsize=None)
610
635
  def scalar_future() -> Number:
611
- return value() if callable(value) else value
636
+ with ContextTimer() as timer:
637
+ value_concrete = value() if callable(value) else value
638
+ logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
639
+ return value_concrete
612
640
 
613
641
  self.scalars[namespace][key] = scalar_future
614
642
 
643
def log_distribution(
    self,
    key: str,
    value: Callable[[], tuple[Number, Number]] | tuple[Number, Number],
    *,
    namespace: str | None = None,
) -> None:
    """Registers a distribution, given as ``(mean, std)``, under ``key``.

    Nothing is evaluated at call time. A memoized zero-argument function is
    stored instead, and it resolves ``value`` the first time it is invoked.

    Args:
        key: The key being logged.
        value: A ``(mean, std)`` tuple, or a zero-argument callable that
            returns one.
        namespace: An optional logging namespace, resolved through
            ``resolve_namespace``.

    Raises:
        RuntimeError: If the logger is not active.
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    resolved_namespace = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def distribution_future() -> LogDistribution:
        # Time only the (possibly expensive) resolution of the value.
        with ContextTimer() as timer:
            if callable(value):
                mean_and_std = value()
            else:
                mean_and_std = value
            mean, std = mean_and_std
        logger.debug("Distribution Key: %s, Time: %s", key, timer.elapsed_time)
        return LogDistribution(mean=mean, std=std)

    self.distributions[resolved_namespace][key] = distribution_future
669
+
670
def log_histogram(
    self,
    key: str,
    value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
    *,
    bins: int = 100,
    namespace: str | None = None,
) -> None:
    """Registers a histogram of an array of values under ``key``.

    Nothing is evaluated at call time. A memoized zero-argument function is
    stored instead; when invoked it flattens the values, bins them, and
    packs the result into a ``LogHistogram``.

    Args:
        key: The key being logged.
        value: The values to histogram, or a zero-argument callable that
            returns them.
        bins: The number of bins to use for the histogram.
        namespace: An optional logging namespace, resolved through
            ``resolve_namespace``.

    Raises:
        RuntimeError: If the logger is not active.
        ValueError: Raised when the stored function is evaluated and the
            flattened values are neither a JAX array nor a NumPy array.
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    resolved_namespace = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def histogram_future() -> LogHistogram:
        with ContextTimer() as timer:
            if callable(value):
                raw_values = value()
            else:
                raw_values = value
            flat_values = raw_values.reshape(-1)  # Must be flat.

            if isinstance(flat_values, Array):
                jax_counts, jax_limits = jnp.histogram(flat_values, bins=bins)
                bucket_counts = as_numpy(jax_counts)
                bin_edges = as_numpy(jax_limits)
            elif isinstance(flat_values, np.ndarray):
                bucket_counts, bin_edges = np.histogram(flat_values, bins=bins)
            else:
                raise ValueError(f"Unsupported histogram type: {type(flat_values)}")

            histogram_values = LogHistogram(
                min=float(flat_values.min()),
                max=float(flat_values.max()),
                num=int(flat_values.size),
                sum=float(flat_values.sum()),
                sum_squares=float(flat_values.dot(flat_values)),
                # Drop the first edge so there is one (upper) limit per bucket.
                bucket_limits=cast(list[float], bin_edges[1:].tolist()),
                bucket_counts=cast(list[int], bucket_counts.tolist()),
            )

        logger.debug("Histogram Key: %s, Time: %s", key, timer.elapsed_time)
        return histogram_values

    self.histograms[resolved_namespace][key] = histogram_future
718
+
719
def log_histogram_raw(
    self,
    key: str,
    counts: Array | np.ndarray,
    limits: Array | np.ndarray,
    minv: Number | None = None,
    maxv: Number | None = None,
    sumv: Number | None = None,
    sum_squaresv: Number | None = None,
    *,
    namespace: str | None = None,
) -> None:
    """Logs a histogram from raw bucket counts and limits.

    Nothing is evaluated at call time. A memoized zero-argument function is
    stored instead, and it builds the ``LogHistogram`` when invoked.

    ``limits`` may hold either one upper limit per bucket (same length as
    ``counts``) or the full set of bin edges as returned by ``np.histogram``
    (one more entry than ``counts``). In the latter case the first edge is
    dropped so the stored limits line up with the counts, matching what
    ``log_histogram`` stores.

    Any statistic that is not supplied is derived from the histogram itself
    rather than from the bucket counts alone:

    * ``minv`` / ``maxv`` default to the smallest / largest limit.
    * ``sumv`` / ``sum_squaresv`` default to an approximation that treats
      every value in a bucket as sitting at the bucket's midpoint (or at its
      upper limit when only upper limits are given).

    Args:
        key: The key being logged
        counts: The number of values in each bucket
        limits: The bucket upper limits, or the full bin edges
        minv: The minimum of the histogrammed values
        maxv: The maximum of the histogrammed values
        sumv: The sum of the histogrammed values
        sum_squaresv: The sum of the squares of the histogrammed values
        namespace: An optional logging namespace

    Raises:
        RuntimeError: If the logger is not active.
        ValueError: Raised when the stored function is evaluated and
            ``limits`` has neither ``len(counts)`` nor ``len(counts) + 1``
            entries.
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    namespace = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def histogram_future() -> LogHistogram:
        with ContextTimer() as timer:
            counts_np = (as_numpy(counts) if isinstance(counts, Array) else counts).astype(int).reshape(-1)
            limits_np = (as_numpy(limits) if isinstance(limits, Array) else limits).astype(float).reshape(-1)

            if limits_np.size == counts_np.size + 1:
                # Full bin edges: keep the upper edge of every bucket.
                upper_limits = limits_np[1:]
                representatives = 0.5 * (limits_np[:-1] + limits_np[1:])
            elif limits_np.size == counts_np.size:
                # Already one (upper) limit per bucket.
                upper_limits = limits_np
                representatives = limits_np
            else:
                raise ValueError(
                    f"Histogram limits must have {counts_np.size} or {counts_np.size + 1} entries "
                    f"to match {counts_np.size} counts, got {limits_np.size}"
                )

            # Defaults come from the value axis (the limits), not from the
            # per-bucket counts, which live on a different axis entirely.
            minv_ = limits_np.min() if minv is None else minv
            maxv_ = limits_np.max() if maxv is None else maxv
            sumv_ = (counts_np * representatives).sum() if sumv is None else sumv
            sum_squaresv_ = (
                (counts_np * representatives**2).sum() if sum_squaresv is None else sum_squaresv
            )

            histogram_values = LogHistogram(
                min=float(minv_),
                max=float(maxv_),
                # Number of histogrammed values, i.e. the total of all bucket
                # counts (not the number of buckets).
                num=int(counts_np.sum()),
                sum=float(sumv_),
                sum_squares=float(sum_squaresv_),
                bucket_limits=cast(list[float], upper_limits.tolist()),
                bucket_counts=cast(list[int], counts_np.tolist()),
            )

        logger.debug("Raw Histogram Key: %s, Time: %s", key, timer.elapsed_time)
        return histogram_values

    self.histograms[namespace][key] = histogram_future
772
+
615
773
  def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
616
774
  """Logs a string value.
617
775
 
@@ -653,7 +811,10 @@ class Logger:
653
811
 
654
812
  @functools.lru_cache(maxsize=None)
655
813
  def image_future() -> LogImage:
656
- return get_image(value() if callable(value) else value, target_resolution)
814
+ with ContextTimer() as timer:
815
+ image = get_image(value() if callable(value) else value, target_resolution)
816
+ logger.debug("Image Key: %s, Time: %s", key, timer.elapsed_time)
817
+ return image
657
818
 
658
819
  self.images[namespace][key] = image_future
659
820
 
@@ -689,15 +850,20 @@ class Logger:
689
850
 
690
851
  @functools.lru_cache(maxsize=None)
691
852
  def image_future() -> LogImage:
692
- image, label = value() if callable(value) else value
693
- image = get_image(image, target_resolution)
694
- return image_with_text(
695
- image.image,
696
- standardize_text(label, max_line_length),
697
- max_num_lines=max_num_lines,
698
- line_spacing=line_spacing,
699
- centered=centered,
700
- )
853
+ with ContextTimer() as timer:
854
+ image, label = value() if callable(value) else value
855
+ image = get_image(image, target_resolution)
856
+
857
+ image_value = image_with_text(
858
+ image.image,
859
+ standardize_text(label, max_line_length),
860
+ max_num_lines=max_num_lines,
861
+ line_spacing=line_spacing,
862
+ centered=centered,
863
+ )
864
+
865
+ logger.debug("Labeled Image Key: %s, Time: %s", key, timer.elapsed_time)
866
+ return image_value
701
867
 
702
868
  self.images[namespace][key] = image_future
703
869
 
@@ -736,15 +902,18 @@ class Logger:
736
902
 
737
903
  @functools.lru_cache(maxsize=None)
738
904
  def images_future() -> LogImage:
739
- images = value() if callable(value) else value
740
- if max_images is not None:
741
- images = images[:max_images]
742
- if isinstance(images, Array):
743
- images = as_numpy(images)
744
- if isinstance(images, Sequence):
745
- images = list(images)
746
- images = [get_image(image, target_resolution) for image in images]
747
- tiled = tile_images([img.image for img in images], sep)
905
+ with ContextTimer() as timer:
906
+ images = value() if callable(value) else value
907
+ if max_images is not None:
908
+ images = images[:max_images]
909
+ if isinstance(images, Array):
910
+ images = as_numpy(images)
911
+ if isinstance(images, Sequence):
912
+ images = list(images)
913
+ images = [get_image(image, target_resolution) for image in images]
914
+ tiled = tile_images([img.image for img in images], sep)
915
+
916
+ logger.debug("Images Key: %s, Time: %s", key, timer.elapsed_time)
748
917
  return LogImage(image=tiled)
749
918
 
750
919
  self.images[namespace][key] = images_future
@@ -791,22 +960,25 @@ class Logger:
791
960
 
792
961
  @functools.lru_cache(maxsize=None)
793
962
  def images_future() -> LogImage:
794
- images, labels = value() if callable(value) else value
795
- if max_images is not None:
796
- images = images[:max_images]
797
- labels = labels[:max_images]
798
- images = [get_image(image, target_resolution) for image in images]
799
- labeled = [
800
- image_with_text(
801
- img.image,
802
- standardize_text(label, max_line_length),
803
- max_num_lines=max_num_lines,
804
- line_spacing=line_spacing,
805
- centered=centered,
806
- )
807
- for img, label in zip(images, labels)
808
- ]
809
- tiled = tile_images([img.image for img in labeled], sep)
963
+ with ContextTimer() as timer:
964
+ images, labels = value() if callable(value) else value
965
+ if max_images is not None:
966
+ images = images[:max_images]
967
+ labels = labels[:max_images]
968
+ images = [get_image(image, target_resolution) for image in images]
969
+ labeled = [
970
+ image_with_text(
971
+ img.image,
972
+ standardize_text(label, max_line_length),
973
+ max_num_lines=max_num_lines,
974
+ line_spacing=line_spacing,
975
+ centered=centered,
976
+ )
977
+ for img, label in zip(images, labels)
978
+ ]
979
+ tiled = tile_images([img.image for img in labeled], sep)
980
+
981
+ logger.debug("Labeled Images Key: %s, Time: %s", key, timer.elapsed_time)
810
982
  return LogImage(image=tiled)
811
983
 
812
984
  self.images[namespace][key] = images_future
@@ -839,7 +1011,11 @@ class Logger:
839
1011
 
840
1012
  @functools.lru_cache(maxsize=None)
841
1013
  def video_future() -> LogVideo:
842
- return get_video(value() if callable(value) else value, fps=fps)
1014
+ with ContextTimer() as timer:
1015
+ video = get_video(value() if callable(value) else value, fps=fps)
1016
+
1017
+ logger.debug("Video Key: %s, Time: %s", key, timer.elapsed_time)
1018
+ return video
843
1019
 
844
1020
  self.videos[namespace][key] = video_future
845
1021
 
@@ -33,7 +33,7 @@ class StdoutLogger(LoggerImpl):
33
33
  self,
34
34
  write_fp: TextIO = sys.stdout,
35
35
  precision: int = 4,
36
- log_timers: bool = False,
36
+ log_timers: bool = True,
37
37
  log_perf: bool = False,
38
38
  log_optim: bool = False,
39
39
  log_fp: bool = False,
@@ -98,7 +98,7 @@ class StdoutLogger(LoggerImpl):
98
98
 
99
99
  def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
100
100
  for namespace, values in log.items():
101
- if not self.log_timers and namespace.startswith(""):
101
+ if not self.log_timers and namespace.startswith(""):
102
102
  continue
103
103
  if not self.log_perf and namespace.startswith("🔧"):
104
104
  continue
@@ -1,11 +1,9 @@
1
1
  """Defines a Tensorboard logger backend."""
2
2
 
3
3
  import atexit
4
- import functools
5
4
  import logging
6
5
  import os
7
6
  import re
8
- import shutil
9
7
  import subprocess
10
8
  import threading
11
9
  import time
@@ -140,15 +138,6 @@ class TensorboardLogger(LoggerImpl):
140
138
  def __del__(self) -> None:
141
139
  self.cleanup()
142
140
 
143
- @functools.lru_cache(None) # Avoid clearing logs multiple times.
144
- def clear_logs(self) -> None:
145
- if not self.log_directory.exists():
146
- return
147
- if not any(child.is_dir() for child in self.log_directory.iterdir()):
148
- return
149
- logger.warning("Clearing TensorBoard logs")
150
- shutil.rmtree(self.log_directory)
151
-
152
141
  def get_writer(self, phase: Phase) -> TensorboardWriter:
153
142
  self._start()
154
143
  return self.writers.writer(phase)
@@ -162,9 +151,6 @@ class TensorboardLogger(LoggerImpl):
162
151
  if not is_master():
163
152
  return
164
153
 
165
- if line.state.num_steps == 0:
166
- self.clear_logs()
167
-
168
154
  writer = self.get_writer(line.state.phase)
169
155
  walltime = line.state.start_time_s + line.state.elapsed_time_s
170
156
 
@@ -177,6 +163,31 @@ class TensorboardLogger(LoggerImpl):
177
163
  walltime=walltime,
178
164
  )
179
165
 
166
+ for namespace, distributions in line.distributions.items():
167
+ for distribution_key, distribution_value in distributions.items():
168
+ writer.add_gaussian_distribution(
169
+ f"{namespace}/{distribution_key}",
170
+ mean=float(distribution_value.mean),
171
+ std=float(distribution_value.std),
172
+ global_step=line.state.num_steps,
173
+ walltime=walltime,
174
+ )
175
+
176
+ for namespace, histograms in line.histograms.items():
177
+ for histogram_key, histogram_value in histograms.items():
178
+ writer.add_histogram_raw(
179
+ f"{namespace}/{histogram_key}",
180
+ min=float(histogram_value.min),
181
+ max=float(histogram_value.max),
182
+ num=int(histogram_value.num),
183
+ sum=float(histogram_value.sum),
184
+ sum_squares=float(histogram_value.sum_squares),
185
+ bucket_limits=[float(x) for x in histogram_value.bucket_limits],
186
+ bucket_counts=[int(x) for x in histogram_value.bucket_counts],
187
+ global_step=line.state.num_steps,
188
+ walltime=walltime,
189
+ )
190
+
180
191
  for namespace, strings in line.strings.items():
181
192
  for string_key, string_value in strings.items():
182
193
  writer.add_text(
@@ -3,7 +3,6 @@
3
3
  import functools
4
4
  import inspect
5
5
  import logging
6
- import os
7
6
  from dataclasses import dataclass
8
7
  from pathlib import Path
9
8
  from typing import Self, TypeVar
@@ -54,20 +53,6 @@ class ArtifactsMixin(BaseTask[Config]):
54
53
  self._exp_dir = Path(exp_dir).expanduser().resolve()
55
54
  return self
56
55
 
57
- def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
58
- if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
59
- if not exists_ok:
60
- raise RuntimeError(f"Lock file already exists at {lock_file}")
61
- else:
62
- with open(lock_file, "w", encoding="utf-8") as f:
63
- f.write(f"PID: {os.getpid()}")
64
-
65
- def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
66
- if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
67
- lock_file.unlink()
68
- elif not missing_ok:
69
- raise RuntimeError(f"Lock file not found at {lock_file}")
70
-
71
56
  def get_exp_dir(self) -> Path:
72
57
  if self._exp_dir is not None:
73
58
  return self._exp_dir
@@ -82,13 +67,8 @@ class ArtifactsMixin(BaseTask[Config]):
82
67
  def get_exp_dir(run_id: int) -> Path:
83
68
  return self.run_dir / f"run_{run_id}"
84
69
 
85
- def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
86
- if lock_type is not None:
87
- return (exp_dir / f".lock_{lock_type}").exists()
88
- return any(exp_dir.glob(".lock_*"))
89
-
90
70
  run_id = 0
91
- while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
71
+ while (exp_dir := get_exp_dir(run_id)).is_dir():
92
72
  run_id += 1
93
73
  exp_dir.mkdir(exist_ok=True, parents=True)
94
74
  self._exp_dir = exp_dir.expanduser().resolve()
@@ -21,7 +21,7 @@ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
21
21
 
22
22
  logger = logging.getLogger(__name__)
23
23
 
24
- CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
24
+ CheckpointPart = Literal["model", "opt", "opt_state", "state", "config", "model_state_config", "all"]
25
25
 
26
26
 
27
27
  def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
@@ -88,8 +88,16 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
88
88
  def load_checkpoint(
89
89
  self,
90
90
  path: Path,
91
+ part: Literal["all"] = "all",
91
92
  ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
92
93
 
94
+ @overload
95
+ def load_checkpoint(
96
+ self,
97
+ path: Path,
98
+ part: Literal["model_state_config"] = "model_state_config",
99
+ ) -> tuple[PyTree, State, DictConfig]: ...
100
+
93
101
  @overload
94
102
  def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
95
103
 
@@ -108,15 +116,19 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
108
116
  def load_checkpoint(
109
117
  self,
110
118
  path: Path,
111
- part: CheckpointPart | None = None,
119
+ part: CheckpointPart = "all",
112
120
  ) -> (
113
121
  tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
122
+ | tuple[PyTree, State, DictConfig]
114
123
  | PyTree
115
124
  | optax.GradientTransformation
116
125
  | optax.OptState
117
126
  | State
118
127
  | DictConfig
119
128
  ):
129
+ # Calls the base callback.
130
+ self.on_before_checkpoint_load(path)
131
+
120
132
  with tarfile.open(path, "r:gz") as tar:
121
133
 
122
134
  def get_model() -> PyTree:
@@ -155,7 +167,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
155
167
  return get_state()
156
168
  case "config":
157
169
  return get_config()
158
- case None:
170
+ case "model_state_config":
171
+ return get_model(), get_state(), get_config()
172
+ case "all":
159
173
  return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
160
174
  case _:
161
175
  raise ValueError(f"Invalid checkpoint part: {part}")
@@ -215,7 +229,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
215
229
  except FileExistsError:
216
230
  logger.exception("Exception while trying to update %s", ckpt_path)
217
231
 
218
- # Marks directory as having artifacts which shouldn't be overwritten.
219
- self.add_lock_file("ckpt", exists_ok=True)
232
+ # Calls the base callback.
233
+ self.on_after_checkpoint_save(ckpt_path, state)
220
234
 
221
235
  return ckpt_path
xax/task/mixins/logger.py CHANGED
@@ -8,6 +8,7 @@ from typing import Generic, Self, TypeVar
8
8
 
9
9
  import jax
10
10
 
11
+ from xax.core.conf import field
11
12
  from xax.core.state import State
12
13
  from xax.task.base import BaseConfig, BaseTask
13
14
  from xax.task.logger import Logger, LoggerImpl
@@ -22,7 +23,14 @@ from xax.utils.text import is_interactive_session
22
23
  @jax.tree_util.register_dataclass
23
24
  @dataclass
24
25
  class LoggerConfig(BaseConfig):
25
- pass
26
+ log_interval_seconds: float = field(
27
+ value=1.0,
28
+ help="The interval between successive log lines.",
29
+ )
30
+ tensorboard_log_interval_seconds: float = field(
31
+ value=10.0,
32
+ help="The interval between successive Tensorboard log lines.",
33
+ )
26
34
 
27
35
 
28
36
  Config = TypeVar("Config", bound=LoggerConfig)
@@ -49,11 +57,27 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
49
57
  self.logger.add_logger(*logger)
50
58
 
51
59
  def set_loggers(self) -> None:
52
- self.add_logger(StdoutLogger() if is_interactive_session() else JsonLogger())
60
+ self.add_logger(
61
+ StdoutLogger(
62
+ log_interval_seconds=self.config.log_interval_seconds,
63
+ )
64
+ if is_interactive_session()
65
+ else JsonLogger(
66
+ log_interval_seconds=self.config.log_interval_seconds,
67
+ )
68
+ )
69
+
70
+ # If this is also an ArtifactsMixin, we should default add some
71
+ # additional loggers which log data to the artifacts directory.
53
72
  if isinstance(self, ArtifactsMixin):
54
73
  self.add_logger(
55
- StateLogger(self.exp_dir),
56
- TensorboardLogger(self.exp_dir),
74
+ StateLogger(
75
+ run_directory=self.exp_dir,
76
+ ),
77
+ TensorboardLogger(
78
+ run_directory=self.exp_dir,
79
+ log_interval_seconds=self.config.tensorboard_log_interval_seconds,
80
+ ),
57
81
  )
58
82
 
59
83
  def write_logs(self, state: State) -> None: