xax 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +94 -4
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +26 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +97 -2
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -34
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jaxpr.py +77 -0
- xax/utils/pytree.py +189 -1
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/METADATA +23 -4
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/RECORD +26 -21
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/WHEEL +1 -1
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/top_level.txt +0 -0
    
xax/task/logger.py
CHANGED

@@ -223,10 +223,29 @@ class LogVideo:
     fps: int
 
 
+@dataclass(kw_only=True)
+class LogDistribution:
+    mean: Number
+    std: Number
+
+
+@dataclass(kw_only=True)
+class LogHistogram:
+    min: Number
+    max: Number
+    num: int
+    sum: Number
+    sum_squares: Number
+    bucket_limits: list[Number]
+    bucket_counts: list[int]
+
+
 @dataclass(kw_only=True)
 class LogLine:
     state: State
     scalars: dict[str, dict[str, Number]]
+    distributions: dict[str, dict[str, LogDistribution]]
+    histograms: dict[str, dict[str, LogHistogram]]
     strings: dict[str, dict[str, str]]
     images: dict[str, dict[str, LogImage]]
     videos: dict[str, dict[str, LogVideo]]

@@ -329,9 +348,9 @@ def image_with_text(
     else:
         text = text[:max_num_lines]
     width, height = image.size
-    font: ImageFont.ImageFont = ImageFont.load_default()
+    font: ImageFont.ImageFont | ImageFont.FreeTypeFont = ImageFont.load_default()
     _, _, _, line_height = font.getbbox(text[0])
-    new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
+    new_width, new_height = width, int(height + line_spacing + max_num_lines * (line_height + line_spacing))
     padded_image = Image.new(image.mode, (new_width, new_height), 255)
     padded_image.paste(image, (0, 0))
     drawer = ImageDraw.Draw(padded_image)

@@ -497,6 +516,8 @@ class Logger:
 
     def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
         self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
+        self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
+        self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
         self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
         self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
         self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)

@@ -522,6 +543,8 @@ class Logger:
         return LogLine(
             state=state,
             scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
+            distributions={k: {kk: v() for kk, v in v.items()} for k, v in self.distributions.items()},
+            histograms={k: {kk: v() for kk, v in v.items()} for k, v in self.histograms.items()},
             strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
             images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
             videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},

@@ -529,6 +552,8 @@ class Logger:
 
     def clear(self) -> None:
         self.scalars.clear()
+        self.distributions.clear()
+        self.histograms.clear()
         self.strings.clear()
         self.images.clear()
         self.videos.clear()

@@ -612,6 +637,76 @@ class Logger:
 
         self.scalars[namespace][key] = scalar_future
 
+    def log_distribution(
+        self,
+        key: str,
+        value: Callable[[], tuple[Number, Number]] | tuple[Number, Number],
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        """Logs a distribution value.
+
+        Args:
+            key: The key being logged
+            value: The distribution value being logged, a tuple of (mean, std)
+            namespace: An optional logging namespace
+        """
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def distribution_future() -> LogDistribution:
+            mean, std = value() if callable(value) else value
+            return LogDistribution(mean=mean, std=std)
+
+        self.distributions[namespace][key] = distribution_future
+
+    def log_histogram(
+        self,
+        key: str,
+        value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
+        *,
+        bins: int = 100,
+        namespace: str | None = None,
+    ) -> None:
+        """Logs a histogram value.
+
+        Args:
+            key: The key being logged
+            value: The histogram value being logged
+            bins: The number of bins to use for the histogram
+            namespace: An optional logging namespace
+        """
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def histogram_future() -> LogHistogram:
+            values = value() if callable(value) else value
+            values = values.reshape(-1)  # Must be flat.
+
+            if isinstance(values, Array):
+                counts, limits = jnp.histogram(values, bins=bins)
+                counts, limits = as_numpy(counts), as_numpy(limits)
+            elif isinstance(values, np.ndarray):
+                counts, limits = np.histogram(values, bins=bins)
+            else:
+                raise ValueError(f"Unsupported histogram type: {type(values)}")
+
+            return LogHistogram(
+                min=float(values.min()),
+                max=float(values.max()),
+                num=int(values.size),
+                sum=float(values.sum()),
+                sum_squares=float(values.dot(values)),
+                bucket_limits=limits[1:].tolist(),
+                bucket_counts=counts.tolist(),
+            )
+
+        self.histograms[namespace][key] = histogram_future
+
     def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
         """Logs a string value.
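Taken together, these changes add two lazily evaluated log types alongside scalars, strings, images, and videos. A minimal usage sketch, with the signatures taken from this diff but the function and key names below being hypothetical:

import jax

def log_output_stats(logger, output: jax.Array) -> None:
    # A (mean, std) summary; passing a callable defers the computation
    # until the log line is actually built.
    logger.log_distribution(
        "output",
        lambda: (float(output.mean()), float(output.std())),
    )
    # A full histogram, binned lazily with jnp.histogram (for JAX arrays)
    # or np.histogram (for numpy arrays), per the log_histogram body above.
    logger.log_histogram("output", output, bins=50)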
    
xax/task/loggers/stdout.py
CHANGED

@@ -33,7 +33,7 @@ class StdoutLogger(LoggerImpl):
         self,
         write_fp: TextIO = sys.stdout,
         precision: int = 4,
-        log_timers: bool = False,
+        log_timers: bool = True,
         log_perf: bool = False,
         log_optim: bool = False,
         log_fp: bool = False,

@@ -98,7 +98,7 @@ class StdoutLogger(LoggerImpl):
 
         def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
             for namespace, values in log.items():
-                if not self.log_timers and namespace.startswith("
+                if not self.log_timers and namespace.startswith("⌛"):
                     continue
                 if not self.log_perf and namespace.startswith("🔧"):
                     continue
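With log_timers now defaulting to True, timer lines (the ⌛ namespace) print by default. A short sketch of opting back out, using only constructor arguments visible in this diff:

from xax.task.loggers.stdout import StdoutLogger

# Timer lines now print by default; disable them explicitly if unwanted.
logger = StdoutLogger(log_timers=False)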
    
xax/task/loggers/tensorboard.py
CHANGED

@@ -1,11 +1,9 @@
 """Defines a Tensorboard logger backend."""
 
 import atexit
-import functools
 import logging
 import os
 import re
-import shutil
 import subprocess
 import threading
 import time

@@ -140,15 +138,6 @@ class TensorboardLogger(LoggerImpl):
     def __del__(self) -> None:
         self.cleanup()
 
-    @functools.lru_cache(None)  # Avoid clearing logs multiple times.
-    def clear_logs(self) -> None:
-        if not self.log_directory.exists():
-            return
-        if not any(child.is_dir() for child in self.log_directory.iterdir()):
-            return
-        logger.warning("Clearing TensorBoard logs")
-        shutil.rmtree(self.log_directory)
-
     def get_writer(self, phase: Phase) -> TensorboardWriter:
         self._start()
         return self.writers.writer(phase)

@@ -162,9 +151,6 @@ class TensorboardLogger(LoggerImpl):
         if not is_master():
             return
 
-        if line.state.num_steps == 0:
-            self.clear_logs()
-
         writer = self.get_writer(line.state.phase)
         walltime = line.state.start_time_s + line.state.elapsed_time_s
 
@@ -177,6 +163,31 @@ class TensorboardLogger(LoggerImpl):
                 walltime=walltime,
             )
 
+        for namespace, distributions in line.distributions.items():
+            for distribution_key, distribution_value in distributions.items():
+                writer.add_gaussian_distribution(
+                    f"{namespace}/{distribution_key}",
+                    mean=float(distribution_value.mean),
+                    std=float(distribution_value.std),
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
+        for namespace, histograms in line.histograms.items():
+            for histogram_key, histogram_value in histograms.items():
+                writer.add_histogram_raw(
+                    f"{namespace}/{histogram_key}",
+                    min=float(histogram_value.min),
+                    max=float(histogram_value.max),
+                    num=int(histogram_value.num),
+                    sum=float(histogram_value.sum),
+                    sum_squares=float(histogram_value.sum_squares),
+                    bucket_limits=[float(x) for x in histogram_value.bucket_limits],
+                    bucket_counts=[int(x) for x in histogram_value.bucket_counts],
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
         for namespace, strings in line.strings.items():
             for string_key, string_value in strings.items():
                 writer.add_text(
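The writer-side methods used here (add_gaussian_distribution, add_histogram_raw) are added in xax/utils/tensorboard.py (+177 lines, not shown in this section). As a rough sketch of what a Gaussian-to-raw-histogram conversion could look like; every name and formula below is an assumption, not the package's actual implementation:

import numpy as np

def gaussian_to_histogram_raw(mean: float, std: float, bins: int = 16, n: int = 1000) -> dict:
    # Place bucket edges over mean ± 3 std and distribute `n` synthetic
    # observations according to the Gaussian density at each bucket center.
    edges = np.linspace(mean - 3 * std, mean + 3 * std, bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2
    density = np.exp(-0.5 * ((centers - mean) / std) ** 2)
    counts = np.round(n * density / density.sum())
    return dict(
        min=float(edges[0]),
        max=float(edges[-1]),
        num=int(counts.sum()),
        sum=float((centers * counts).sum()),
        sum_squares=float((centers**2 * counts).sum()),
        bucket_limits=edges[1:].tolist(),  # TensorBoard stores right bucket edges.
        bucket_counts=counts.astype(int).tolist(),
    )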
    
xax/task/mixins/artifacts.py
CHANGED

@@ -3,7 +3,6 @@
 import functools
 import inspect
 import logging
-import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Self, TypeVar

@@ -54,20 +53,6 @@ class ArtifactsMixin(BaseTask[Config]):
         self._exp_dir = Path(exp_dir).expanduser().resolve()
         return self
 
-    def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
-        if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
-            if not exists_ok:
-                raise RuntimeError(f"Lock file already exists at {lock_file}")
-        else:
-            with open(lock_file, "w", encoding="utf-8") as f:
-                f.write(f"PID: {os.getpid()}")
-
-    def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
-        if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
-            lock_file.unlink()
-        elif not missing_ok:
-            raise RuntimeError(f"Lock file not found at {lock_file}")
-
     def get_exp_dir(self) -> Path:
         if self._exp_dir is not None:
             return self._exp_dir

@@ -82,13 +67,8 @@ class ArtifactsMixin(BaseTask[Config]):
         def get_exp_dir(run_id: int) -> Path:
             return self.run_dir / f"run_{run_id}"
 
-        def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
-            if lock_type is not None:
-                return (exp_dir / f".lock_{lock_type}").exists()
-            return any(exp_dir.glob(".lock_*"))
-
         run_id = 0
-        while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
+        while (exp_dir := get_exp_dir(run_id)).is_dir():
             run_id += 1
         exp_dir.mkdir(exist_ok=True, parents=True)
         self._exp_dir = exp_dir.expanduser().resolve()
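With the lock-file machinery removed, run-directory allocation reduces to taking the first run_{id} that does not yet exist. The standalone sketch below replays that loop outside the mixin; the runs/ base directory is hypothetical:

from pathlib import Path

def next_run_dir(base: Path) -> Path:
    # Mirrors the simplified loop above: skip existing run directories
    # and create the first free one.
    run_id = 0
    while (exp_dir := base / f"run_{run_id}").is_dir():
        run_id += 1
    exp_dir.mkdir(exist_ok=True, parents=True)
    return exp_dir

# next_run_dir(Path("runs"))  # runs/run_0, then runs/run_1, ...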
    
xax/task/mixins/checkpointing.py
CHANGED

@@ -21,7 +21,7 @@ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 
 logger = logging.getLogger(__name__)
 
-CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+CheckpointPart = Literal["model", "opt", "opt_state", "state", "config", "model_state_config", "all"]
 
 
 def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:

@@ -88,8 +88,16 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
+        part: Literal["all"] = "all",
     ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
 
+    @overload
+    def load_checkpoint(
+        self,
+        path: Path,
+        part: Literal["model_state_config"] = "model_state_config",
+    ) -> tuple[PyTree, State, DictConfig]: ...
+
     @overload
     def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
 
@@ -108,15 +116,19 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
-        part: CheckpointPart 
+        part: CheckpointPart = "all",
     ) -> (
         tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+        | tuple[PyTree, State, DictConfig]
         | PyTree
         | optax.GradientTransformation
         | optax.OptState
         | State
         | DictConfig
     ):
+        # Calls the base callback.
+        self.on_before_checkpoint_load(path)
+
         with tarfile.open(path, "r:gz") as tar:
 
             def get_model() -> PyTree:

@@ -155,7 +167,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
                     return get_state()
                 case "config":
                     return get_config()
-                case 
+                case "model_state_config":
+                    return get_model(), get_state(), get_config()
+                case "all":
                     return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
                 case _:
                     raise ValueError(f"Invalid checkpoint part: {part}")

@@ -215,7 +229,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
         except FileExistsError:
             logger.exception("Exception while trying to update %s", ckpt_path)
 
-        # 
-        self.
+        # Calls the base callback.
+        self.on_after_checkpoint_save(ckpt_path, state)
 
         return ckpt_path
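A sketch of the widened load_checkpoint call patterns; `task` and `ckpt_path` are hypothetical, while the `part` values come from this diff:

# Everything, including the optimizer and its state (the new default):
model, opt, opt_state, state, config = task.load_checkpoint(ckpt_path)

# New in 0.1.0: model, training state, and config without the optimizer,
# which is what inference-style loading typically needs:
model, state, config = task.load_checkpoint(ckpt_path, "model_state_config")

# Individual parts still work as before:
model = task.load_checkpoint(ckpt_path, "model")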
    
xax/task/mixins/logger.py
CHANGED

@@ -8,6 +8,7 @@ from typing import Generic, Self, TypeVar
 
 import jax
 
+from xax.core.conf import field
 from xax.core.state import State
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.logger import Logger, LoggerImpl

@@ -22,7 +23,14 @@ from xax.utils.text import is_interactive_session
 @jax.tree_util.register_dataclass
 @dataclass
 class LoggerConfig(BaseConfig):
-
+    log_interval_seconds: float = field(
+        value=1.0,
+        help="The interval between successive log lines.",
+    )
+    tensorboard_log_interval_seconds: float = field(
+        value=10.0,
+        help="The interval between successive Tensorboard log lines.",
+    )
 
 
 Config = TypeVar("Config", bound=LoggerConfig)

@@ -49,11 +57,27 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
         self.logger.add_logger(*logger)
 
     def set_loggers(self) -> None:
-        self.add_logger(
+        self.add_logger(
+            StdoutLogger(
+                log_interval_seconds=self.config.log_interval_seconds,
+            )
+            if is_interactive_session()
+            else JsonLogger(
+                log_interval_seconds=self.config.log_interval_seconds,
+            )
+        )
+
+        # If this is also an ArtifactsMixin, we should default add some
+        # additional loggers which log data to the artifacts directory.
         if isinstance(self, ArtifactsMixin):
             self.add_logger(
-                StateLogger(
-
+                StateLogger(
+                    run_directory=self.exp_dir,
+                ),
+                TensorboardLogger(
+                    run_directory=self.exp_dir,
+                    log_interval_seconds=self.config.tensorboard_log_interval_seconds,
+                ),
             )
 
     def write_logs(self, state: State) -> None:
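The two new config fields flow into the stdout/JSON logger and the Tensorboard logger respectively. A hypothetical task config overriding both; the subclass name and values are illustrative, assuming plain dataclass-style overrides are accepted:

from dataclasses import dataclass

@dataclass
class MyTaskConfig(LoggerConfig):
    # Print console log lines twice per second, but only flush to
    # Tensorboard every 30 seconds to keep event files small.
    log_interval_seconds: float = 0.5
    tensorboard_log_interval_seconds: float = 30.0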
    
xax/task/mixins/step_wrapper.py
CHANGED

@@ -1,53 +1,39 @@
 """Defines a mixin to wrap some steps in a context manager."""
 
+import time
 from dataclasses import dataclass
 from types import TracebackType
-from typing import 
+from typing import Callable, ContextManager, TypeVar
 
-import equinox as eqx
 import jax
 
 from xax.task.base import BaseConfig, BaseTask
 
-StepType = Literal[
-    "backward",
-    "change_mode",
-    "clip_grads",
-    "create_optimizers",
-    "forward",
-    "get_dataloader",
-    "get_dataset",
-    "get_prefetcher",
-    "get_model",
-    "get_optimizer",
-    "get_initial_opt_state",
-    "get_update_fn",
-    "load_checkpoint",
-    "log_losses",
-    "model_to_device",
-    "on_step_end",
-    "on_step_start",
-    "save_checkpoint",
-    "step",
-    "update_state",
-    "write_logs",
-    "zero_grads",
-]
-
 
 class StepContext(ContextManager):
     """Context manager to get the current step type."""
 
-    CURRENT_STEP: 
+    CURRENT_STEP: str | None = None
 
-    def __init__(
+    def __init__(
+        self,
+        step: str,
+        on_context_start: Callable[[str], None],
+        on_context_end: Callable[[str, float], None],
+    ) -> None:
         self.step = step
+        self.start_time = 0.0
+        self.on_context_start = on_context_start
+        self.on_context_end = on_context_end
 
     def __enter__(self) -> None:
         StepContext.CURRENT_STEP = self.step
+        self.start_time = time.time()
+        self.on_context_start(self.step)
 
     def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
         StepContext.CURRENT_STEP = None
+        self.on_context_end(self.step, time.time() - self.start_time)
 
 
 @jax.tree_util.register_dataclass

@@ -63,6 +49,11 @@ class StepContextMixin(BaseTask[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
-
-
-
+    def step_context(self, step: str) -> ContextManager:
+        return StepContext(step, self.on_context_start, self.on_context_stop)
+
+    def on_context_start(self, step: str) -> None:
+        pass
+
+    def on_context_stop(self, step: str, elapsed_time: float) -> None:
+        pass
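Because StepContext now reports wall-clock duration through the on_context_start/on_context_stop hooks, a subclass can time arbitrary steps without the old Literal registry. A hypothetical override:

class TimedStepTask(StepContextMixin):
    def on_context_stop(self, step: str, elapsed_time: float) -> None:
        # Forward per-step wall-clock timings to any sink you like.
        print(f"step {step!r} took {elapsed_time * 1000:.1f} ms")

# Call sites are unchanged:
#     with self.step_context("model_step"):
#         ...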
    
xax/task/mixins/train.py
CHANGED

@@ -24,6 +24,7 @@ from typing import (
     TypeVar,
     cast,
     get_args,
+    overload,
 )
 
 import equinox as eqx

@@ -35,6 +36,7 @@ from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
+from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin

@@ -115,7 +117,7 @@ class ValidStepTimer:
         if self.last_valid_time is None or self.last_valid_step is None:
             self.last_valid_time = state.elapsed_time_s
             self.last_valid_step = state.num_steps
-            return True
+            return False
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps

@@ -183,6 +185,9 @@ class TrainMixin(
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
+        # Sets the random seed whenever we instantiate a new train mixin.
+        set_random_seed(self.config.random_seed)
+
         # Timer for validation steps.
         self.valid_step_timer = ValidStepTimer(
             valid_every_n_steps=config.valid_every_n_steps,

@@ -279,31 +284,53 @@
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
+    @overload
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: Literal[False] = False,
+    ) -> tuple[PyTree, State]: ...
+
+    @overload
     def load_initial_state(
         self,
         key: PRNGKeyArray,
-
+        load_optimizer: Literal[True],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
+
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: bool = False,
+    ) -> tuple[PyTree, State] | tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-
+            if load_optimizer:
                 model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
                 config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
                 if config_diff:
                     logger.warning("Loaded config differs from current config:\n%s", config_diff)
                 return model, optimizer, opt_state, state
 
-
-
+            else:
+                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, state
+
+        model = self.get_model(key)
+        state = State.init_state()
 
-
-
+        if not load_optimizer:
+            return model, state
 
-
-
+        optimizer = self.get_optimizer()
+        opt_state = self.get_initial_opt_state(model, optimizer)
 
-        return model, optimizer, opt_state, 
+        return model, optimizer, opt_state, state
 
     @eqx.filter_jit
     def get_output(self, model: PyTree, batch: Batch) -> Output:

@@ -424,6 +451,7 @@
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
+        logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
         self.logger.log_file("git_state.txt", get_git_state(self))
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))

@@ -456,7 +484,8 @@
         while not self.is_training_over(state):
             if self.valid_step_timer.is_valid_step(state):
                 valid_batch = next(valid_pf)
-                model, loss, output = self.val_step(model, valid_batch)
+                with self.step_context("model_step"):
+                    model, loss, output = self.val_step(model, valid_batch)
 
                 # Perform logging.
                 with self.step_context("write_logs"):

@@ -464,22 +493,19 @@
                     self.log_step(model, valid_batch, output, loss, state)
                     state.num_valid_samples += 1
 
-            with self.step_context("on_step_start"):
-                state = self.on_step_start(state)
+            state = self.on_step_start(state)
 
-            with self.step_context("
+            with self.step_context("model_step"):
                 train_batch = next(train_pf)
                 model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
 
-            # Perform logging.
             with self.step_context("write_logs"):
                 state.phase = "train"
                 self.log_step(model, train_batch, output, loss, state)
                 state.num_steps += 1
                 state.num_samples += self.get_size_of_batch(train_batch) or 0
 
-            with self.step_context("on_step_end"):
-                state = self.on_step_end(state)
+            state = self.on_step_end(state)
 
             if self.should_checkpoint(state):
                 self.save_checkpoint(model, optimizer, opt_state, state)

@@ -496,14 +522,9 @@
         except NotImplementedError:
             pass
 
-        with self.step_context("get_dataset"):
-            train_ds = self.get_dataset("train")
-
-        with self.step_context("get_dataloader"):
-            train_dl = self.get_dataloader(train_ds, "train")
-
-        with self.step_context("get_prefetcher"):
-            train_pf = self.get_prefetcher(train_dl)
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train")
+        train_pf = self.get_prefetcher(train_dl)
 
         try:
             with train_pf as train_pf_ctx:

@@ -520,14 +541,9 @@
         except NotImplementedError:
             pass
 
-        with self.step_context("get_dataset"):
-            valid_ds = self.get_dataset("valid")
-
-        with self.step_context("get_dataloader"):
-            valid_dl = self.get_dataloader(valid_ds, "valid")
-
-        with self.step_context("get_prefetcher"):
-            valid_pf = self.get_prefetcher(valid_dl)
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)
 
         try:
             with valid_pf as valid_pf_ctx:

@@ -559,7 +575,7 @@
                 Thread(target=self.log_state, daemon=True).start()
 
             key, model_key = jax.random.split(key)
-            model, optimizer, opt_state, state = self.load_initial_state(model_key)
+            model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
             state = self.on_training_start(state)
 
             def on_exit() -> None:
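The new load_initial_state overloads let callers choose the returned tuple shape statically, and the training loop now passes load_optimizer=True. A sketch of both call patterns; `task` and `key` are hypothetical:

# Model and state only, e.g. for evaluation or export:
model, state = task.load_initial_state(key)

# Model plus optimizer and optimizer state, as the training loop now does:
model, optimizer, opt_state, state = task.load_initial_state(key, load_optimizer=True)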