xax 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -1
- xax/nn/geom.py +34 -0
- xax/task/mixins/checkpointing.py +9 -12
- xax/task/mixins/train.py +133 -19
- {xax-0.2.0.dist-info → xax-0.2.1.dist-info}/METADATA +1 -1
- {xax-0.2.0.dist-info → xax-0.2.1.dist-info}/RECORD +9 -9
- {xax-0.2.0.dist-info → xax-0.2.1.dist-info}/WHEEL +0 -0
- {xax-0.2.0.dist-info → xax-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.0.dist-info → xax-0.2.1.dist-info}/top_level.txt +0 -0
    
        xax/__init__.py
    CHANGED
    
    | @@ -12,7 +12,7 @@ and running the update script: | |
| 12 12 | 
             
                python -m scripts.update_api --inplace
         | 
| 13 13 | 
             
            """
         | 
| 14 14 |  | 
| 15 | 
            -
            __version__ = "0.2. | 
| 15 | 
            +
            __version__ = "0.2.1"
         | 
| 16 16 |  | 
| 17 17 | 
             
            # This list shouldn't be modified by hand; instead, run the update script.
         | 
| 18 18 | 
             
            __all__ = [
         | 
| @@ -44,6 +44,7 @@ __all__ = [ | |
| 44 44 | 
             
                "euler_to_quat",
         | 
| 45 45 | 
             
                "get_projected_gravity_vector_from_quat",
         | 
| 46 46 | 
             
                "quat_to_euler",
         | 
| 47 | 
            +
                "quat_to_rotmat",
         | 
| 47 48 | 
             
                "rotate_vector_by_quat",
         | 
| 48 49 | 
             
                "cross_entropy",
         | 
| 49 50 | 
             
                "cast_norm_type",
         | 
| @@ -206,6 +207,7 @@ NAME_MAP: dict[str, str] = { | |
| 206 207 | 
             
                "euler_to_quat": "nn.geom",
         | 
| 207 208 | 
             
                "get_projected_gravity_vector_from_quat": "nn.geom",
         | 
| 208 209 | 
             
                "quat_to_euler": "nn.geom",
         | 
| 210 | 
            +
                "quat_to_rotmat": "nn.geom",
         | 
| 209 211 | 
             
                "rotate_vector_by_quat": "nn.geom",
         | 
| 210 212 | 
             
                "cross_entropy": "nn.losses",
         | 
| 211 213 | 
             
                "cast_norm_type": "nn.norm",
         | 
| @@ -369,6 +371,7 @@ if IMPORT_ALL or TYPE_CHECKING: | |
| 369 371 | 
             
                    euler_to_quat,
         | 
| 370 372 | 
             
                    get_projected_gravity_vector_from_quat,
         | 
| 371 373 | 
             
                    quat_to_euler,
         | 
| 374 | 
            +
                    quat_to_rotmat,
         | 
| 372 375 | 
             
                    rotate_vector_by_quat,
         | 
| 373 376 | 
             
                )
         | 
| 374 377 | 
             
                from xax.nn.losses import cross_entropy
         | 
    
        xax/nn/geom.py
    CHANGED
    
    | @@ -177,3 +177,37 @@ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array: | |
| 177 177 | 
             
                y_diff = y_end - y_start
         | 
| 178 178 | 
             
                bezier = x**3 + 3 * (x**2 * (1 - x))
         | 
| 179 179 | 
             
                return y_start + y_diff * bezier
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
                """Converts a quaternion to a rotation matrix.

                The quaternion components are read in ``(w, x, y, z)`` order (scalar
                first). The input is normalized before conversion, so it does not have
                to be unit length.

                Args:
                    quat: The quaternion to convert, shape (*, 4).
                    eps: A small epsilon value to avoid division by zero.

                Returns:
                    The rotation matrix, shape (*, 3, 3).
                """
                norm = jnp.linalg.norm(quat, axis=-1, keepdims=True)
                unit = quat / (norm + eps)
                w, x, y, z = jnp.split(unit, 4, axis=-1)

                # Every component keeps the trailing singleton axis produced by the
                # split, so each row is joined along the last axis; the three rows are
                # then stacked on a new second-to-last axis (row-major layout).
                rows = (
                    (1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)),
                    (2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)),
                    (2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)),
                )
                return jnp.stack([jnp.concatenate(row, axis=-1) for row in rows], axis=-2)
         | 
    
        xax/task/mixins/checkpointing.py
    CHANGED
    
    | @@ -63,10 +63,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): | |
| 63 63 |  | 
| 64 64 | 
             
                def get_init_ckpt_path(self) -> Path | None:
         | 
| 65 65 | 
             
                    if self._exp_dir is not None:
         | 
| 66 | 
            -
                        ckpt_path  | 
| 67 | 
            -
                        if not ckpt_path.exists():
         | 
| 68 | 
            -
                            logger.warning("No checkpoint found in experiment directory: %s", ckpt_path)
         | 
| 69 | 
            -
                        else:
         | 
| 66 | 
            +
                        if (ckpt_path := self.get_ckpt_path()).exists():
         | 
| 70 67 | 
             
                            return ckpt_path
         | 
| 71 68 | 
             
                    if self.config.load_from_ckpt_path is not None:
         | 
| 72 69 | 
             
                        ckpt_path = Path(self.config.load_from_ckpt_path)
         | 
| @@ -86,7 +83,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): | |
| 86 83 | 
             
                    return False
         | 
| 87 84 |  | 
| 88 85 | 
             
                @overload
         | 
| 89 | 
            -
                def  | 
| 86 | 
            +
                def load_ckpt_with_template(
         | 
| 90 87 | 
             
                    self,
         | 
| 91 88 | 
             
                    path: Path,
         | 
| 92 89 | 
             
                    *,
         | 
| @@ -97,7 +94,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): | |
| 97 94 | 
             
                ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
         | 
| 98 95 |  | 
| 99 96 | 
             
                @overload
         | 
| 100 | 
            -
                def  | 
| 97 | 
            +
                def load_ckpt_with_template(
         | 
| 101 98 | 
             
                    self,
         | 
| 102 99 | 
             
                    path: Path,
         | 
| 103 100 | 
             
                    *,
         | 
| @@ -106,7 +103,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): | |
| 106 103 | 
             
                ) -> tuple[PyTree, State, Config]: ...
         | 
| 107 104 |  | 
| 108 105 | 
             
                @overload
         | 
| 109 | 
            -
                def  | 
| 106 | 
            +
                def load_ckpt_with_template(
         | 
| 110 107 | 
             
                    self,
         | 
| 111 108 | 
             
                    path: Path,
         | 
| 112 109 | 
             
                    *,
         | 
| @@ -115,7 +112,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): | |
| 115 112 | 
             
                ) -> PyTree: ...
         | 
| 116 113 |  | 
| 117 114 | 
             
                @overload
         | 
| 118 | 
            -
                def  | 
| 115 | 
            +
                def load_ckpt_with_template(
         | 
| 119 116 | 
             
                    self,
         | 
| 120 117 | 
             
                    path: Path,
         | 
| 121 118 | 
             
                    *,
         | 
| @@ -124,7 +121,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): | |
| 124 121 | 
             
                ) -> optax.GradientTransformation: ...
         | 
| 125 122 |  | 
| 126 123 | 
             
                @overload
         | 
| 127 | 
            -
                def  | 
| 124 | 
            +
                def load_ckpt_with_template(
         | 
| 128 125 | 
             
                    self,
         | 
| 129 126 | 
             
                    path: Path,
         | 
| 130 127 | 
             
                    *,
         | 
| @@ -133,7 +130,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): | |
| 133 130 | 
             
                ) -> optax.OptState: ...
         | 
| 134 131 |  | 
| 135 132 | 
             
                @overload
         | 
| 136 | 
            -
                def  | 
| 133 | 
            +
                def load_ckpt_with_template(
         | 
| 137 134 | 
             
                    self,
         | 
| 138 135 | 
             
                    path: Path,
         | 
| 139 136 | 
             
                    *,
         | 
| @@ -141,14 +138,14 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): | |
| 141 138 | 
             
                ) -> State: ...
         | 
| 142 139 |  | 
| 143 140 | 
             
                @overload
         | 
| 144 | 
            -
                def  | 
| 141 | 
            +
                def load_ckpt_with_template(
         | 
| 145 142 | 
             
                    self,
         | 
| 146 143 | 
             
                    path: Path,
         | 
| 147 144 | 
             
                    *,
         | 
| 148 145 | 
             
                    part: Literal["config"],
         | 
| 149 146 | 
             
                ) -> Config: ...
         | 
| 150 147 |  | 
| 151 | 
            -
                def  | 
| 148 | 
            +
                def load_ckpt_with_template(
         | 
| 152 149 | 
             
                    self,
         | 
| 153 150 | 
             
                    path: Path,
         | 
| 154 151 | 
             
                    *,
         | 
    
        xax/task/mixins/train.py
    CHANGED
    
    | @@ -12,6 +12,7 @@ import time | |
| 12 12 | 
             
            import traceback
         | 
| 13 13 | 
             
            from abc import ABC, abstractmethod
         | 
| 14 14 | 
             
            from dataclasses import asdict, dataclass, is_dataclass
         | 
| 15 | 
            +
            from pathlib import Path
         | 
| 15 16 | 
             
            from threading import Thread
         | 
| 16 17 | 
             
            from typing import (
         | 
| 17 18 | 
             
                Any,
         | 
| @@ -39,7 +40,7 @@ from xax.core.state import Phase, State | |
| 39 40 | 
             
            from xax.nn.functions import set_random_seed
         | 
| 40 41 | 
             
            from xax.nn.parallel import is_master
         | 
| 41 42 | 
             
            from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
         | 
| 42 | 
            -
            from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
         | 
| 43 | 
            +
            from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
         | 
| 43 44 | 
             
            from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
         | 
| 44 45 | 
             
            from xax.task.mixins.logger import LoggerConfig, LoggerMixin
         | 
| 45 46 | 
             
            from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
         | 
| @@ -54,7 +55,7 @@ from xax.utils.experiments import ( | |
| 54 55 | 
             
                get_training_code,
         | 
| 55 56 | 
             
            )
         | 
| 56 57 | 
             
            from xax.utils.jax import jit as xax_jit
         | 
| 57 | 
            -
            from xax.utils.logging import LOG_STATUS
         | 
| 58 | 
            +
            from xax.utils.logging import LOG_PING, LOG_STATUS
         | 
| 58 59 | 
             
            from xax.utils.text import highlight_exception_message, show_info
         | 
| 59 60 | 
             
            from xax.utils.types.frozen_dict import FrozenDict
         | 
| 60 61 |  | 
| @@ -340,12 +341,7 @@ class TrainMixin( | |
| 340 341 |  | 
| 341 342 | 
             
                    if init_ckpt_path is not None:
         | 
| 342 343 | 
             
                        logger.info("Loading checkpoint from %s", init_ckpt_path)
         | 
| 343 | 
            -
                         | 
| 344 | 
            -
                        model, state, config = self.load_checkpoint(
         | 
| 345 | 
            -
                            init_ckpt_path,
         | 
| 346 | 
            -
                            part="model_state_config",
         | 
| 347 | 
            -
                            model_template=model_spec,
         | 
| 348 | 
            -
                        )
         | 
| 344 | 
            +
                        model, state, config = self.load_ckpt(init_ckpt_path, part="model_state_config")
         | 
| 349 345 | 
             
                        config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
         | 
| 350 346 | 
             
                        if config_diff:
         | 
| 351 347 | 
             
                            logger.warning("Loaded config differs from current config:\n%s", config_diff)
         | 
| @@ -353,17 +349,11 @@ class TrainMixin( | |
| 353 349 | 
             
                        if not load_optimizer:
         | 
| 354 350 | 
             
                            return model, state
         | 
| 355 351 |  | 
| 356 | 
            -
                         | 
| 357 | 
            -
                         | 
| 358 | 
            -
                        optimizer = self.load_checkpoint(init_ckpt_path, part="opt", optimizer_template=optimizer_spec)
         | 
| 359 | 
            -
             | 
| 360 | 
            -
                        # Loads the optimizer state.
         | 
| 361 | 
            -
                        opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
         | 
| 362 | 
            -
                        opt_state = self.load_checkpoint(init_ckpt_path, part="opt_state", opt_state_template=opt_state_spec)
         | 
| 363 | 
            -
             | 
| 352 | 
            +
                        optimizer = self.load_ckpt(init_ckpt_path, part="opt")
         | 
| 353 | 
            +
                        opt_state = self.load_ckpt(init_ckpt_path, part="opt_state", model=model, optimizer=optimizer)
         | 
| 364 354 | 
             
                        return model, optimizer, opt_state, state
         | 
| 365 355 |  | 
| 366 | 
            -
                    logger.info(" | 
| 356 | 
            +
                    logger.info("Starting a new training run")
         | 
| 367 357 | 
             
                    model = self.get_model(key)
         | 
| 368 358 | 
             
                    state = State.init_state()
         | 
| 369 359 |  | 
| @@ -375,6 +365,131 @@ class TrainMixin( | |
| 375 365 |  | 
| 376 366 | 
             
                    return model, optimizer, opt_state, state
         | 
| 377 367 |  | 
| 368 | 
            +
                @overload
         | 
| 369 | 
            +
                def load_ckpt(
         | 
| 370 | 
            +
                    self,
         | 
| 371 | 
            +
                    path: Path,
         | 
| 372 | 
            +
                    *,
         | 
| 373 | 
            +
                    part: Literal["all"],
         | 
| 374 | 
            +
                ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                @overload
         | 
| 377 | 
            +
                def load_ckpt(
         | 
| 378 | 
            +
                    self,
         | 
| 379 | 
            +
                    path: Path,
         | 
| 380 | 
            +
                    *,
         | 
| 381 | 
            +
                    part: Literal["model_state_config"],
         | 
| 382 | 
            +
                ) -> tuple[PyTree, State, Config]: ...
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                @overload
         | 
| 385 | 
            +
                def load_ckpt(
         | 
| 386 | 
            +
                    self,
         | 
| 387 | 
            +
                    path: Path,
         | 
| 388 | 
            +
                    *,
         | 
| 389 | 
            +
                    part: Literal["model"],
         | 
| 390 | 
            +
                ) -> PyTree: ...
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                @overload
         | 
| 393 | 
            +
                def load_ckpt(
         | 
| 394 | 
            +
                    self,
         | 
| 395 | 
            +
                    path: Path,
         | 
| 396 | 
            +
                    *,
         | 
| 397 | 
            +
                    part: Literal["opt"],
         | 
| 398 | 
            +
                ) -> optax.GradientTransformation: ...
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                @overload
         | 
| 401 | 
            +
                def load_ckpt(
         | 
| 402 | 
            +
                    self,
         | 
| 403 | 
            +
                    path: Path,
         | 
| 404 | 
            +
                    *,
         | 
| 405 | 
            +
                    part: Literal["opt_state"],
         | 
| 406 | 
            +
                    model: PyTree | None = None,
         | 
| 407 | 
            +
                    optimizer: optax.GradientTransformation | None = None,
         | 
| 408 | 
            +
                ) -> optax.OptState: ...
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                @overload
         | 
| 411 | 
            +
                def load_ckpt(
         | 
| 412 | 
            +
                    self,
         | 
| 413 | 
            +
                    path: Path,
         | 
| 414 | 
            +
                    *,
         | 
| 415 | 
            +
                    part: Literal["state"],
         | 
| 416 | 
            +
                ) -> State: ...
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                @overload
         | 
| 419 | 
            +
                def load_ckpt(
         | 
| 420 | 
            +
                    self,
         | 
| 421 | 
            +
                    path: Path,
         | 
| 422 | 
            +
                    *,
         | 
| 423 | 
            +
                    part: Literal["config"],
         | 
| 424 | 
            +
                ) -> Config: ...
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                def load_ckpt(
         | 
| 427 | 
            +
                    self,
         | 
| 428 | 
            +
                    path: str | Path,
         | 
| 429 | 
            +
                    *,
         | 
| 430 | 
            +
                    part: CheckpointPart = "all",
         | 
| 431 | 
            +
                    model: PyTree | None = None,
         | 
| 432 | 
            +
                    optimizer: optax.GradientTransformation | None = None,
         | 
| 433 | 
            +
                ) -> (
         | 
| 434 | 
            +
                    tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
         | 
| 435 | 
            +
                    | tuple[PyTree, State, Config]
         | 
| 436 | 
            +
                    | PyTree
         | 
| 437 | 
            +
                    | optax.GradientTransformation
         | 
| 438 | 
            +
                    | optax.OptState
         | 
| 439 | 
            +
                    | State
         | 
| 440 | 
            +
                    | Config
         | 
| 441 | 
            +
                ):
         | 
| 442 | 
            +
                    path = Path(path)
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    # This key isn't used for anything, it's just a required argument.
         | 
| 445 | 
            +
                    key = jax.random.PRNGKey(0)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                    match part:
         | 
| 448 | 
            +
                        case "model_state_config":
         | 
| 449 | 
            +
                            model_spec = eqx.filter_eval_shape(self.get_model, key)
         | 
| 450 | 
            +
                            return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                        case "model":
         | 
| 453 | 
            +
                            model_spec = eqx.filter_eval_shape(self.get_model, key)
         | 
| 454 | 
            +
                            return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                        case "config":
         | 
| 457 | 
            +
                            return self.load_ckpt_with_template(path, part="config")
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                        case "opt":
         | 
| 460 | 
            +
                            optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
         | 
| 461 | 
            +
                            return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                        case "opt_state":
         | 
| 464 | 
            +
                            if model is None:
         | 
| 465 | 
            +
                                model_spec = eqx.filter_eval_shape(self.get_model, key)
         | 
| 466 | 
            +
                                model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
         | 
| 467 | 
            +
                            if optimizer is None:
         | 
| 468 | 
            +
                                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
         | 
| 469 | 
            +
                                optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
         | 
| 470 | 
            +
                            opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
         | 
| 471 | 
            +
                            return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                        case "state":
         | 
| 474 | 
            +
                            return self.load_ckpt_with_template(path, part="state")
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                        case "config":
         | 
| 477 | 
            +
                            return self.load_ckpt_with_template(path, part="config")
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                        case "all":
         | 
| 480 | 
            +
                            model_spec = eqx.filter_eval_shape(self.get_model, key)
         | 
| 481 | 
            +
                            model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
         | 
| 482 | 
            +
                            optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
         | 
| 483 | 
            +
                            optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
         | 
| 484 | 
            +
                            opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
         | 
| 485 | 
            +
                            opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
         | 
| 486 | 
            +
                            state = self.load_ckpt_with_template(path, part="state")
         | 
| 487 | 
            +
                            config = self.load_ckpt_with_template(path, part="config")
         | 
| 488 | 
            +
                            return model, optimizer, opt_state, state, config
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                        case _:
         | 
| 491 | 
            +
                            raise ValueError(f"Unknown checkpoint part: {part}")
         | 
| 492 | 
            +
             | 
| 378 493 | 
             
                def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
         | 
| 379 494 | 
             
                    """Gets the output from the model.
         | 
| 380 495 |  | 
| @@ -529,8 +644,7 @@ class TrainMixin( | |
| 529 644 | 
             
                    self._last_printed_remaining_time = state.elapsed_time_s
         | 
| 530 645 | 
             
                    remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
         | 
| 531 646 | 
             
                    termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
         | 
| 532 | 
            -
                     | 
| 533 | 
            -
                    jax.debug.print("Estimated finish time: {}", termination_time)
         | 
| 647 | 
            +
                    logger.log(LOG_PING, "Estimated finish time: %s", termination_time)
         | 
| 534 648 |  | 
| 535 649 | 
             
                def get_remaining_percent(self, state: State) -> float | None:
         | 
| 536 650 | 
             
                    if self.config.max_steps is None:
         | 
| @@ -1,4 +1,4 @@ | |
| 1 | 
            -
            xax/__init__.py,sha256= | 
| 1 | 
            +
            xax/__init__.py,sha256=kd-88OQGnuHb91PXwroAfLb0bMfbe37fXqpECRrjhoU,14182
         | 
| 2 2 | 
             
            xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 3 3 | 
             
            xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
         | 
| 4 4 | 
             
            xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
         | 
| @@ -10,7 +10,7 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847 | |
| 10 10 | 
             
            xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
         | 
| 11 11 | 
             
            xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
         | 
| 12 12 | 
             
            xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
         | 
| 13 | 
            -
            xax/nn/geom.py,sha256= | 
| 13 | 
            +
            xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
         | 
| 14 14 | 
             
            xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
         | 
| 15 15 | 
             
            xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
         | 
| 16 16 | 
             
            xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
         | 
| @@ -32,7 +32,7 @@ xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,65 | |
| 32 32 | 
             
            xax/task/loggers/tensorboard.py,sha256=KOL9l60tLctX-VAdNwe49H48SAJeGxph3sflJpojA-4,8337
         | 
| 33 33 | 
             
            xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
         | 
| 34 34 | 
             
            xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
         | 
| 35 | 
            -
            xax/task/mixins/checkpointing.py,sha256= | 
| 35 | 
            +
            xax/task/mixins/checkpointing.py,sha256=2nJgqFcV-D8W-4j8TR3PvVh1g5hQUOo-_quKO-XlE4U,11398
         | 
| 36 36 | 
             
            xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
         | 
| 37 37 | 
             
            xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
         | 
| 38 38 | 
             
            xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
         | 
| @@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280 | |
| 41 41 | 
             
            xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
         | 
| 42 42 | 
             
            xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
         | 
| 43 43 | 
             
            xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
         | 
| 44 | 
            -
            xax/task/mixins/train.py,sha256= | 
| 44 | 
            +
            xax/task/mixins/train.py,sha256=v9oi9tNsNBYo-Ne_98nCG9qHX6sxvymHjsRDnL6GL-U,30871
         | 
| 45 45 | 
             
            xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 46 46 | 
             
            xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
         | 
| 47 47 | 
             
            xax/utils/experiments.py,sha256=Hzl46_9IH5_9cKzxit-FyVUWBH-_lBs00ZciuIdnWO8,29811
         | 
| @@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706 | |
| 58 58 | 
             
            xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 59 59 | 
             
            xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
         | 
| 60 60 | 
             
            xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
         | 
| 61 | 
            -
            xax-0.2. | 
| 62 | 
            -
            xax-0.2. | 
| 63 | 
            -
            xax-0.2. | 
| 64 | 
            -
            xax-0.2. | 
| 65 | 
            -
            xax-0.2. | 
| 61 | 
            +
            xax-0.2.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
         | 
| 62 | 
            +
            xax-0.2.1.dist-info/METADATA,sha256=2pOZLKMIcLoQTM-tRqRvVkF57PZyMoALM87UI5B4dtk,1882
         | 
| 63 | 
            +
            xax-0.2.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
         | 
| 64 | 
            +
            xax-0.2.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
         | 
| 65 | 
            +
            xax-0.2.1.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         |