xax 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/nn/norm.py ADDED
@@ -0,0 +1,23 @@
1
+ """Normalization utilities."""
2
+
3
+ from typing import Literal, cast, get_args
4
+
5
+ import jax.numpy as jnp
6
+
7
# The closed set of norm names understood by this module.
NormType = Literal["l1", "l2"]


def cast_norm_type(norm: str) -> NormType:
    """Validate a norm name and narrow it to ``NormType``.

    Args:
        norm: Candidate norm name.

    Returns:
        The same string, typed as ``NormType``.

    Raises:
        ValueError: If ``norm`` is not one of the literals in ``NormType``.
    """
    allowed = get_args(NormType)
    if norm in allowed:
        return cast(NormType, norm)
    raise ValueError(f"Invalid norm: {norm}")


def get_norm(x: jnp.ndarray, norm: NormType) -> jnp.ndarray:
    """Apply the element-wise term of the requested norm to ``x``.

    No reduction is performed: the result has one entry per entry of ``x``.

    Args:
        x: Input array.
        norm: ``"l1"`` for the absolute value, ``"l2"`` for the square.

    Returns:
        ``jnp.abs(x)`` for ``"l1"`` or ``jnp.square(x)`` for ``"l2"``.

    Raises:
        ValueError: If ``norm`` is neither ``"l1"`` nor ``"l2"``.
    """
    if norm == "l1":
        return jnp.abs(x)
    if norm == "l2":
        return jnp.square(x)
    raise ValueError(f"Invalid norm: {norm}")
xax/requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
1
  # requirements.txt
2
2
 
3
3
  # Core ML/JAX dependencies
4
+ attrs
4
5
  jax
5
6
  jaxtyping
6
7
  equinox
xax/task/base.py CHANGED
@@ -81,6 +81,12 @@ class BaseTask(Generic[Config]):
81
81
  def on_training_end(self, state: State) -> State:
82
82
  return state
83
83
 
84
def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
    """Hook for subclasses to react to a checkpoint save.

    This base implementation does nothing with ``ckpt_path`` and hands
    ``state`` straight back.

    Args:
        ckpt_path: Path of the checkpoint passed in by the caller.
        state: The state passed in by the caller.

    Returns:
        The ``state`` argument, unchanged.
    """
    unchanged_state = state
    return unchanged_state
86
+
87
def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
    """Hook for subclasses to react before a checkpoint is loaded.

    This base implementation is a no-op.

    Args:
        ckpt_path: Path of the checkpoint passed in by the caller.
    """
    return None
89
+
84
90
  @functools.cached_property
85
91
  def task_class_name(self) -> str:
86
92
  return self.__class__.__name__
xax/task/logger.py CHANGED
@@ -223,10 +223,29 @@ class LogVideo:
223
223
  fps: int
224
224
 
225
225
 
226
@dataclass(kw_only=True)
class LogDistribution:
    """A logged distribution, summarised by its mean and standard deviation."""

    # Mean of the distribution.
    mean: Number
    # Standard deviation of the distribution.
    std: Number
230
+
231
+
232
@dataclass(kw_only=True)
class LogHistogram:
    """Pre-computed summary statistics and bucket data for a logged histogram."""

    # Smallest observed value.
    min: Number
    # Largest observed value.
    max: Number
    # Number of observed values.
    num: int
    # Sum of the observed values.
    sum: Number
    # Sum of the squares of the observed values.
    sum_squares: Number
    # One limit per bucket.
    bucket_limits: list[Number]
    # Number of values falling into each bucket.
    bucket_counts: list[int]
241
+
242
+
226
243
  @dataclass(kw_only=True)
227
244
  class LogLine:
228
245
  state: State
229
246
  scalars: dict[str, dict[str, Number]]
247
+ distributions: dict[str, dict[str, LogDistribution]]
248
+ histograms: dict[str, dict[str, LogHistogram]]
230
249
  strings: dict[str, dict[str, str]]
231
250
  images: dict[str, dict[str, LogImage]]
232
251
  videos: dict[str, dict[str, LogVideo]]
@@ -329,9 +348,9 @@ def image_with_text(
329
348
  else:
330
349
  text = text[:max_num_lines]
331
350
  width, height = image.size
332
- font: ImageFont.ImageFont = ImageFont.load_default()
351
+ font: ImageFont.ImageFont | ImageFont.FreeTypeFont = ImageFont.load_default()
333
352
  _, _, _, line_height = font.getbbox(text[0])
334
- new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
353
+ new_width, new_height = width, int(height + line_spacing + max_num_lines * (line_height + line_spacing))
335
354
  padded_image = Image.new(image.mode, (new_width, new_height), 255)
336
355
  padded_image.paste(image, (0, 0))
337
356
  drawer = ImageDraw.Draw(padded_image)
@@ -497,6 +516,8 @@ class Logger:
497
516
 
498
517
  def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
499
518
  self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
519
+ self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
520
+ self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
500
521
  self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
501
522
  self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
502
523
  self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
@@ -522,6 +543,8 @@ class Logger:
522
543
  return LogLine(
523
544
  state=state,
524
545
  scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
546
+ distributions={k: {kk: v() for kk, v in v.items()} for k, v in self.distributions.items()},
547
+ histograms={k: {kk: v() for kk, v in v.items()} for k, v in self.histograms.items()},
525
548
  strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
526
549
  images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
527
550
  videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
@@ -529,6 +552,8 @@ class Logger:
529
552
 
530
553
  def clear(self) -> None:
531
554
  self.scalars.clear()
555
+ self.distributions.clear()
556
+ self.histograms.clear()
532
557
  self.strings.clear()
533
558
  self.images.clear()
534
559
  self.videos.clear()
@@ -612,6 +637,76 @@ class Logger:
612
637
 
613
638
  self.scalars[namespace][key] = scalar_future
614
639
 
640
def log_distribution(
    self,
    key: str,
    value: Callable[[], tuple[Number, Number]] | tuple[Number, Number],
    *,
    namespace: str | None = None,
) -> None:
    """Registers a distribution to be logged under ``key``.

    The value is not evaluated here; a cached zero-argument callable is
    stored and only invoked when something reads the registered entry.

    Args:
        key: The key being logged
        value: A ``(mean, std)`` tuple, or a callable returning one
        namespace: An optional logging namespace

    Raises:
        RuntimeError: If the logger is not active.
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    resolved_namespace = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def build_distribution() -> LogDistribution:
        # Resolve lazily supplied values only when the entry is read.
        pair = value() if callable(value) else value
        mean, std = pair
        return LogDistribution(mean=mean, std=std)

    self.distributions[resolved_namespace][key] = build_distribution
664
+
665
def log_histogram(
    self,
    key: str,
    value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
    *,
    bins: int = 100,
    namespace: str | None = None,
) -> None:
    """Registers a histogram to be logged under ``key``.

    The histogram is not computed here; a cached zero-argument callable is
    stored and only invoked when something reads the registered entry.

    Args:
        key: The key being logged
        value: The array to histogram, or a callable returning it
        bins: The number of bins to use for the histogram
        namespace: An optional logging namespace

    Raises:
        RuntimeError: If the logger is not active.
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    resolved_namespace = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def build_histogram() -> LogHistogram:
        raw = value() if callable(value) else value
        flat = raw.reshape(-1)  # Must be flat.

        if isinstance(flat, Array):
            # Bin with JAX, then convert the results to NumPy so that
            # ``.tolist()`` below behaves the same for both branches.
            jax_counts, jax_limits = jnp.histogram(flat, bins=bins)
            counts, limits = as_numpy(jax_counts), as_numpy(jax_limits)
        elif isinstance(flat, np.ndarray):
            counts, limits = np.histogram(flat, bins=bins)
        else:
            raise ValueError(f"Unsupported histogram type: {type(flat)}")

        smallest = float(flat.min())
        largest = float(flat.max())
        count = int(flat.size)
        total = float(flat.sum())
        total_squares = float(flat.dot(flat))

        return LogHistogram(
            min=smallest,
            max=largest,
            num=count,
            sum=total,
            sum_squares=total_squares,
            # The first edge is dropped, leaving one limit per bucket.
            bucket_limits=limits[1:].tolist(),
            bucket_counts=counts.tolist(),
        )

    self.histograms[resolved_namespace][key] = build_histogram
709
+
615
710
  def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
616
711
  """Logs a string value.
617
712
 
@@ -33,7 +33,7 @@ class StdoutLogger(LoggerImpl):
33
33
  self,
34
34
  write_fp: TextIO = sys.stdout,
35
35
  precision: int = 4,
36
- log_timers: bool = False,
36
+ log_timers: bool = True,
37
37
  log_perf: bool = False,
38
38
  log_optim: bool = False,
39
39
  log_fp: bool = False,
@@ -98,7 +98,7 @@ class StdoutLogger(LoggerImpl):
98
98
 
99
99
  def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
100
100
  for namespace, values in log.items():
101
- if not self.log_timers and namespace.startswith(""):
101
+ if not self.log_timers and namespace.startswith(""):
102
102
  continue
103
103
  if not self.log_perf and namespace.startswith("🔧"):
104
104
  continue
@@ -1,11 +1,9 @@
1
1
  """Defines a Tensorboard logger backend."""
2
2
 
3
3
  import atexit
4
- import functools
5
4
  import logging
6
5
  import os
7
6
  import re
8
- import shutil
9
7
  import subprocess
10
8
  import threading
11
9
  import time
@@ -140,15 +138,6 @@ class TensorboardLogger(LoggerImpl):
140
138
  def __del__(self) -> None:
141
139
  self.cleanup()
142
140
 
143
- @functools.lru_cache(None) # Avoid clearing logs multiple times.
144
- def clear_logs(self) -> None:
145
- if not self.log_directory.exists():
146
- return
147
- if not any(child.is_dir() for child in self.log_directory.iterdir()):
148
- return
149
- logger.warning("Clearing TensorBoard logs")
150
- shutil.rmtree(self.log_directory)
151
-
152
141
  def get_writer(self, phase: Phase) -> TensorboardWriter:
153
142
  self._start()
154
143
  return self.writers.writer(phase)
@@ -162,9 +151,6 @@ class TensorboardLogger(LoggerImpl):
162
151
  if not is_master():
163
152
  return
164
153
 
165
- if line.state.num_steps == 0:
166
- self.clear_logs()
167
-
168
154
  writer = self.get_writer(line.state.phase)
169
155
  walltime = line.state.start_time_s + line.state.elapsed_time_s
170
156
 
@@ -177,6 +163,31 @@ class TensorboardLogger(LoggerImpl):
177
163
  walltime=walltime,
178
164
  )
179
165
 
166
+ for namespace, distributions in line.distributions.items():
167
+ for distribution_key, distribution_value in distributions.items():
168
+ writer.add_gaussian_distribution(
169
+ f"{namespace}/{distribution_key}",
170
+ mean=float(distribution_value.mean),
171
+ std=float(distribution_value.std),
172
+ global_step=line.state.num_steps,
173
+ walltime=walltime,
174
+ )
175
+
176
+ for namespace, histograms in line.histograms.items():
177
+ for histogram_key, histogram_value in histograms.items():
178
+ writer.add_histogram_raw(
179
+ f"{namespace}/{histogram_key}",
180
+ min=float(histogram_value.min),
181
+ max=float(histogram_value.max),
182
+ num=int(histogram_value.num),
183
+ sum=float(histogram_value.sum),
184
+ sum_squares=float(histogram_value.sum_squares),
185
+ bucket_limits=[float(x) for x in histogram_value.bucket_limits],
186
+ bucket_counts=[int(x) for x in histogram_value.bucket_counts],
187
+ global_step=line.state.num_steps,
188
+ walltime=walltime,
189
+ )
190
+
180
191
  for namespace, strings in line.strings.items():
181
192
  for string_key, string_value in strings.items():
182
193
  writer.add_text(
@@ -3,7 +3,6 @@
3
3
  import functools
4
4
  import inspect
5
5
  import logging
6
- import os
7
6
  from dataclasses import dataclass
8
7
  from pathlib import Path
9
8
  from typing import Self, TypeVar
@@ -54,20 +53,6 @@ class ArtifactsMixin(BaseTask[Config]):
54
53
  self._exp_dir = Path(exp_dir).expanduser().resolve()
55
54
  return self
56
55
 
57
- def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
58
- if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
59
- if not exists_ok:
60
- raise RuntimeError(f"Lock file already exists at {lock_file}")
61
- else:
62
- with open(lock_file, "w", encoding="utf-8") as f:
63
- f.write(f"PID: {os.getpid()}")
64
-
65
- def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
66
- if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
67
- lock_file.unlink()
68
- elif not missing_ok:
69
- raise RuntimeError(f"Lock file not found at {lock_file}")
70
-
71
56
  def get_exp_dir(self) -> Path:
72
57
  if self._exp_dir is not None:
73
58
  return self._exp_dir
@@ -82,13 +67,8 @@ class ArtifactsMixin(BaseTask[Config]):
82
67
  def get_exp_dir(run_id: int) -> Path:
83
68
  return self.run_dir / f"run_{run_id}"
84
69
 
85
- def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
86
- if lock_type is not None:
87
- return (exp_dir / f".lock_{lock_type}").exists()
88
- return any(exp_dir.glob(".lock_*"))
89
-
90
70
  run_id = 0
91
- while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
71
+ while (exp_dir := get_exp_dir(run_id)).is_dir():
92
72
  run_id += 1
93
73
  exp_dir.mkdir(exist_ok=True, parents=True)
94
74
  self._exp_dir = exp_dir.expanduser().resolve()
@@ -21,7 +21,7 @@ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
21
21
 
22
22
  logger = logging.getLogger(__name__)
23
23
 
24
- CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
24
+ CheckpointPart = Literal["model", "opt", "opt_state", "state", "config", "model_state_config", "all"]
25
25
 
26
26
 
27
27
  def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
@@ -88,8 +88,16 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
88
88
  def load_checkpoint(
89
89
  self,
90
90
  path: Path,
91
+ part: Literal["all"] = "all",
91
92
  ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
92
93
 
94
# Typing-only overload: ``part="model_state_config"`` returns the
# (model, state, config) triple.
#
# NOTE: ``part`` deliberately has no default here. The ``"all"`` overload
# already declares a default, and the runtime default is ``"all"``; giving
# this overload a default as well would make ``load_checkpoint(path)`` match
# two overloads with different return types.
@overload
def load_checkpoint(
    self,
    path: Path,
    part: Literal["model_state_config"],
) -> tuple[PyTree, State, DictConfig]: ...
100
+
93
101
  @overload
94
102
  def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
95
103
 
@@ -108,15 +116,19 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
108
116
  def load_checkpoint(
109
117
  self,
110
118
  path: Path,
111
- part: CheckpointPart | None = None,
119
+ part: CheckpointPart = "all",
112
120
  ) -> (
113
121
  tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
122
+ | tuple[PyTree, State, DictConfig]
114
123
  | PyTree
115
124
  | optax.GradientTransformation
116
125
  | optax.OptState
117
126
  | State
118
127
  | DictConfig
119
128
  ):
129
+ # Calls the base callback.
130
+ self.on_before_checkpoint_load(path)
131
+
120
132
  with tarfile.open(path, "r:gz") as tar:
121
133
 
122
134
  def get_model() -> PyTree:
@@ -155,7 +167,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
155
167
  return get_state()
156
168
  case "config":
157
169
  return get_config()
158
- case None:
170
+ case "model_state_config":
171
+ return get_model(), get_state(), get_config()
172
+ case "all":
159
173
  return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
160
174
  case _:
161
175
  raise ValueError(f"Invalid checkpoint part: {part}")
@@ -215,7 +229,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
215
229
  except FileExistsError:
216
230
  logger.exception("Exception while trying to update %s", ckpt_path)
217
231
 
218
- # Marks directory as having artifacts which shouldn't be overwritten.
219
- self.add_lock_file("ckpt", exists_ok=True)
232
+ # Calls the base callback.
233
+ self.on_after_checkpoint_save(ckpt_path, state)
220
234
 
221
235
  return ckpt_path
xax/task/mixins/logger.py CHANGED
@@ -8,6 +8,7 @@ from typing import Generic, Self, TypeVar
8
8
 
9
9
  import jax
10
10
 
11
+ from xax.core.conf import field
11
12
  from xax.core.state import State
12
13
  from xax.task.base import BaseConfig, BaseTask
13
14
  from xax.task.logger import Logger, LoggerImpl
@@ -22,7 +23,14 @@ from xax.utils.text import is_interactive_session
22
23
@jax.tree_util.register_dataclass
@dataclass
class LoggerConfig(BaseConfig):
    """Configuration controlling how often log lines are emitted."""

    # Interval for the general-purpose loggers; defaults to 1.0.
    log_interval_seconds: float = field(value=1.0, help="The interval between successive log lines.")
    # Interval for the Tensorboard logger; defaults to 10.0.
    tensorboard_log_interval_seconds: float = field(
        value=10.0, help="The interval between successive Tensorboard log lines."
    )
26
34
 
27
35
 
28
36
  Config = TypeVar("Config", bound=LoggerConfig)
@@ -49,11 +57,27 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
49
57
  self.logger.add_logger(*logger)
50
58
 
51
59
def set_loggers(self) -> None:
    """Registers the default logger backends for this task.

    Adds a ``StdoutLogger`` when ``is_interactive_session()`` is true and a
    ``JsonLogger`` otherwise. If the task is also an ``ArtifactsMixin``, a
    ``StateLogger`` and a ``TensorboardLogger`` pointed at the experiment
    directory are added as well.
    """
    console_logger: LoggerImpl
    if is_interactive_session():
        console_logger = StdoutLogger(log_interval_seconds=self.config.log_interval_seconds)
    else:
        console_logger = JsonLogger(log_interval_seconds=self.config.log_interval_seconds)
    self.add_logger(console_logger)

    # Only tasks with an artifacts directory get the loggers that write
    # into that directory.
    if not isinstance(self, ArtifactsMixin):
        return

    state_logger = StateLogger(run_directory=self.exp_dir)
    tensorboard_logger = TensorboardLogger(
        run_directory=self.exp_dir,
        log_interval_seconds=self.config.tensorboard_log_interval_seconds,
    )
    self.add_logger(state_logger, tensorboard_logger)
58
82
 
59
83
  def write_logs(self, state: State) -> None:
@@ -1,53 +1,39 @@
1
1
  """Defines a mixin to wrap some steps in a context manager."""
2
2
 
3
+ import time
3
4
  from dataclasses import dataclass
4
5
  from types import TracebackType
5
- from typing import ContextManager, Literal, TypeVar
6
+ from typing import Callable, ContextManager, TypeVar
6
7
 
7
- import equinox as eqx
8
8
  import jax
9
9
 
10
10
  from xax.task.base import BaseConfig, BaseTask
11
11
 
12
- StepType = Literal[
13
- "backward",
14
- "change_mode",
15
- "clip_grads",
16
- "create_optimizers",
17
- "forward",
18
- "get_dataloader",
19
- "get_dataset",
20
- "get_prefetcher",
21
- "get_model",
22
- "get_optimizer",
23
- "get_initial_opt_state",
24
- "get_update_fn",
25
- "load_checkpoint",
26
- "log_losses",
27
- "model_to_device",
28
- "on_step_end",
29
- "on_step_start",
30
- "save_checkpoint",
31
- "step",
32
- "update_state",
33
- "write_logs",
34
- "zero_grads",
35
- ]
36
-
37
12
 
38
13
  class StepContext(ContextManager):
39
14
  """Context manager to get the current step type."""
40
15
 
41
- CURRENT_STEP: StepType | None = None
16
+ CURRENT_STEP: str | None = None
42
17
 
43
- def __init__(self, step: StepType) -> None:
18
+ def __init__(
19
+ self,
20
+ step: str,
21
+ on_context_start: Callable[[str], None],
22
+ on_context_end: Callable[[str, float], None],
23
+ ) -> None:
44
24
  self.step = step
25
+ self.start_time = 0.0
26
+ self.on_context_start = on_context_start
27
+ self.on_context_end = on_context_end
45
28
 
46
29
  def __enter__(self) -> None:
47
30
  StepContext.CURRENT_STEP = self.step
31
+ self.start_time = time.time()
32
+ self.on_context_start(self.step)
48
33
 
49
34
  def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
50
35
  StepContext.CURRENT_STEP = None
36
+ self.on_context_end(self.step, time.time() - self.start_time)
51
37
 
52
38
 
53
39
  @jax.tree_util.register_dataclass
@@ -63,6 +49,11 @@ class StepContextMixin(BaseTask[Config]):
63
49
  def __init__(self, config: Config) -> None:
64
50
  super().__init__(config)
65
51
 
66
- @eqx.filter_jit
67
- def step_context(self, step: StepType) -> ContextManager:
68
- return StepContext(step)
52
def step_context(self, step: str) -> ContextManager:
    """Builds a ``StepContext`` for ``step`` wired to this task's hooks.

    Args:
        step: Name of the step being entered.

    Returns:
        A context manager that calls ``on_context_start`` on entry and
        ``on_context_stop`` on exit.
    """
    start_hook = self.on_context_start
    stop_hook = self.on_context_stop
    return StepContext(step, start_hook, stop_hook)
54
+
55
def on_context_start(self, step: str) -> None:
    """Hook called when a step context is entered; a no-op by default.

    Args:
        step: Name of the step being entered.
    """
    return None
57
+
58
def on_context_stop(self, step: str, elapsed_time: float) -> None:
    """Hook called when a step context is exited; a no-op by default.

    Args:
        step: Name of the step being exited.
        elapsed_time: Time spent inside the context, in seconds.
    """
    return None