xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
 - xax/core/conf.py +193 -0
 - xax/core/state.py +81 -0
 - xax/nn/__init__.py +0 -0
 - xax/nn/embeddings.py +355 -0
 - xax/nn/functions.py +77 -0
 - xax/nn/parallel.py +211 -0
 - xax/requirements-dev.txt +15 -0
 - xax/requirements.txt +23 -0
 - xax/task/__init__.py +0 -0
 - xax/task/base.py +207 -0
 - xax/task/launchers/__init__.py +0 -0
 - xax/task/launchers/base.py +28 -0
 - xax/task/launchers/cli.py +42 -0
 - xax/task/launchers/single_process.py +30 -0
 - xax/task/launchers/staged.py +29 -0
 - xax/task/logger.py +783 -0
 - xax/task/loggers/__init__.py +0 -0
 - xax/task/loggers/callback.py +56 -0
 - xax/task/loggers/json.py +121 -0
 - xax/task/loggers/state.py +45 -0
 - xax/task/loggers/stdout.py +170 -0
 - xax/task/loggers/tensorboard.py +223 -0
 - xax/task/mixins/__init__.py +12 -0
 - xax/task/mixins/artifacts.py +114 -0
 - xax/task/mixins/checkpointing.py +209 -0
 - xax/task/mixins/cpu_stats.py +251 -0
 - xax/task/mixins/data_loader.py +149 -0
 - xax/task/mixins/gpu_stats.py +257 -0
 - xax/task/mixins/logger.py +66 -0
 - xax/task/mixins/process.py +51 -0
 - xax/task/mixins/runnable.py +63 -0
 - xax/task/mixins/step_wrapper.py +63 -0
 - xax/task/mixins/train.py +541 -0
 - xax/task/script.py +53 -0
 - xax/task/task.py +65 -0
 - xax/utils/__init__.py +0 -0
 - xax/utils/data/__init__.py +0 -0
 - xax/utils/data/collate.py +206 -0
 - xax/utils/experiments.py +802 -0
 - xax/utils/jax.py +14 -0
 - xax/utils/logging.py +223 -0
 - xax/utils/numpy.py +47 -0
 - xax/utils/tensorboard.py +258 -0
 - xax/utils/text.py +350 -0
 - xax-0.0.5.dist-info/METADATA +40 -0
 - xax-0.0.5.dist-info/RECORD +52 -0
 - {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
 - xax-0.0.5.dist-info/top_level.txt +1 -0
 - examples/mnist.py +0 -148
 - xax-0.0.1.dist-info/METADATA +0 -21
 - xax-0.0.1.dist-info/RECORD +0 -9
 - xax-0.0.1.dist-info/top_level.txt +0 -2
 - {examples → xax/core}/__init__.py +0 -0
 - {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
 
    
        xax/__init__.py
    CHANGED
    
    | 
         @@ -1 +1,256 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
             
     | 
| 
      
 1 
     | 
    
         
            +
            """Defines the top-level xax API.
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            This package is structured so that all the important stuff can be accessed
         
     | 
| 
      
 4 
     | 
    
         
            +
            without having to dig around through the internals. This is done by lazily
         
     | 
| 
      
 5 
     | 
    
         
            +
            importing the module by name.
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            This file can be maintained by running the update script:
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            .. code-block:: bash
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
                python -m scripts.update_api --inplace
         
     | 
| 
      
 12 
     | 
    
         
            +
            """
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
            __version__ = "0.0.5"
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            # This list shouldn't be modified by hand; instead, run the update script.
         
     | 
| 
      
 17 
     | 
    
         
            +
            __all__ = [
         
     | 
| 
      
 18 
     | 
    
         
            +
                "UserConfig",
         
     | 
| 
      
 19 
     | 
    
         
            +
                "field",
         
     | 
| 
      
 20 
     | 
    
         
            +
                "get_data_dir",
         
     | 
| 
      
 21 
     | 
    
         
            +
                "get_pretrained_models_dir",
         
     | 
| 
      
 22 
     | 
    
         
            +
                "get_run_dir",
         
     | 
| 
      
 23 
     | 
    
         
            +
                "load_user_config",
         
     | 
| 
      
 24 
     | 
    
         
            +
                "State",
         
     | 
| 
      
 25 
     | 
    
         
            +
                "cast_phase",
         
     | 
| 
      
 26 
     | 
    
         
            +
                "FourierEmbeddings",
         
     | 
| 
      
 27 
     | 
    
         
            +
                "IdentityPositionalEmbeddings",
         
     | 
| 
      
 28 
     | 
    
         
            +
                "LearnedPositionalEmbeddings",
         
     | 
| 
      
 29 
     | 
    
         
            +
                "RotaryEmbeddings",
         
     | 
| 
      
 30 
     | 
    
         
            +
                "SinusoidalEmbeddings",
         
     | 
| 
      
 31 
     | 
    
         
            +
                "apply_rotary_embeddings",
         
     | 
| 
      
 32 
     | 
    
         
            +
                "cast_embedding_kind",
         
     | 
| 
      
 33 
     | 
    
         
            +
                "fourier_embeddings",
         
     | 
| 
      
 34 
     | 
    
         
            +
                "get_positional_embeddings",
         
     | 
| 
      
 35 
     | 
    
         
            +
                "get_rotary_embeddings",
         
     | 
| 
      
 36 
     | 
    
         
            +
                "rotary_embeddings",
         
     | 
| 
      
 37 
     | 
    
         
            +
                "BaseLauncher",
         
     | 
| 
      
 38 
     | 
    
         
            +
                "CliLauncher",
         
     | 
| 
      
 39 
     | 
    
         
            +
                "SingleProcessLauncher",
         
     | 
| 
      
 40 
     | 
    
         
            +
                "LogImage",
         
     | 
| 
      
 41 
     | 
    
         
            +
                "LogLine",
         
     | 
| 
      
 42 
     | 
    
         
            +
                "Logger",
         
     | 
| 
      
 43 
     | 
    
         
            +
                "LoggerImpl",
         
     | 
| 
      
 44 
     | 
    
         
            +
                "CallbackLogger",
         
     | 
| 
      
 45 
     | 
    
         
            +
                "JsonLogger",
         
     | 
| 
      
 46 
     | 
    
         
            +
                "StateLogger",
         
     | 
| 
      
 47 
     | 
    
         
            +
                "StdoutLogger",
         
     | 
| 
      
 48 
     | 
    
         
            +
                "TensorboardLogger",
         
     | 
| 
      
 49 
     | 
    
         
            +
                "CPUStatsOptions",
         
     | 
| 
      
 50 
     | 
    
         
            +
                "DataloaderConfig",
         
     | 
| 
      
 51 
     | 
    
         
            +
                "GPUStatsOptions",
         
     | 
| 
      
 52 
     | 
    
         
            +
                "Script",
         
     | 
| 
      
 53 
     | 
    
         
            +
                "ScriptConfig",
         
     | 
| 
      
 54 
     | 
    
         
            +
                "Config",
         
     | 
| 
      
 55 
     | 
    
         
            +
                "Task",
         
     | 
| 
      
 56 
     | 
    
         
            +
                "collate",
         
     | 
| 
      
 57 
     | 
    
         
            +
                "collate_non_null",
         
     | 
| 
      
 58 
     | 
    
         
            +
                "BaseFileDownloader",
         
     | 
| 
      
 59 
     | 
    
         
            +
                "DataDownloader",
         
     | 
| 
      
 60 
     | 
    
         
            +
                "ModelDownloader",
         
     | 
| 
      
 61 
     | 
    
         
            +
                "check_md5",
         
     | 
| 
      
 62 
     | 
    
         
            +
                "check_sha256",
         
     | 
| 
      
 63 
     | 
    
         
            +
                "get_git_state",
         
     | 
| 
      
 64 
     | 
    
         
            +
                "get_state_dict_prefix",
         
     | 
| 
      
 65 
     | 
    
         
            +
                "get_training_code",
         
     | 
| 
      
 66 
     | 
    
         
            +
                "save_config",
         
     | 
| 
      
 67 
     | 
    
         
            +
                "ColoredFormatter",
         
     | 
| 
      
 68 
     | 
    
         
            +
                "configure_logging",
         
     | 
| 
      
 69 
     | 
    
         
            +
                "one_hot",
         
     | 
| 
      
 70 
     | 
    
         
            +
                "partial_flatten",
         
     | 
| 
      
 71 
     | 
    
         
            +
                "worker_chunk",
         
     | 
| 
      
 72 
     | 
    
         
            +
                "TextBlock",
         
     | 
| 
      
 73 
     | 
    
         
            +
                "colored",
         
     | 
| 
      
 74 
     | 
    
         
            +
                "format_datetime",
         
     | 
| 
      
 75 
     | 
    
         
            +
                "format_timedelta",
         
     | 
| 
      
 76 
     | 
    
         
            +
                "outlined",
         
     | 
| 
      
 77 
     | 
    
         
            +
                "render_text_blocks",
         
     | 
| 
      
 78 
     | 
    
         
            +
                "show_error",
         
     | 
| 
      
 79 
     | 
    
         
            +
                "show_warning",
         
     | 
| 
      
 80 
     | 
    
         
            +
                "uncolored",
         
     | 
| 
      
 81 
     | 
    
         
            +
                "wrapped",
         
     | 
| 
      
 82 
     | 
    
         
            +
            ]
         
     | 
| 
      
 83 
     | 
    
         
            +
             
     | 
| 
      
 84 
     | 
    
         
            +
            __all__ += [
         
     | 
| 
      
 85 
     | 
    
         
            +
                "Batch",
         
     | 
| 
      
 86 
     | 
    
         
            +
                "CollateMode",
         
     | 
| 
      
 87 
     | 
    
         
            +
                "EmbeddingKind",
         
     | 
| 
      
 88 
     | 
    
         
            +
                "Output",
         
     | 
| 
      
 89 
     | 
    
         
            +
                "Phase",
         
     | 
| 
      
 90 
     | 
    
         
            +
            ]
         
     | 
| 
      
 91 
     | 
    
         
            +
             
     | 
| 
      
 92 
     | 
    
         
            +
            import os
         
     | 
| 
      
 93 
     | 
    
         
            +
            from typing import TYPE_CHECKING
         
     | 
| 
      
 94 
     | 
    
         
            +
             
     | 
| 
      
 95 
     | 
    
         
            +
            # If this flag is set, eagerly imports the entire package (not recommended).
         
     | 
| 
      
 96 
     | 
    
         
            +
            IMPORT_ALL = int(os.environ.get("XAX_IMPORT_ALL", "0")) != 0
         
     | 
| 
      
 97 
     | 
    
         
            +
             
     | 
| 
      
 98 
     | 
    
         
            +
            del os
         
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
      
 100 
     | 
    
         
            +
            # This dictionary is auto-generated and shouldn't be modified by hand; instead,
         
     | 
| 
      
 101 
     | 
    
         
            +
            # run the update script.
         
     | 
| 
      
 102 
     | 
    
         
            +
            NAME_MAP: dict[str, str] = {
         
     | 
| 
      
 103 
     | 
    
         
            +
                "UserConfig": "core.conf",
         
     | 
| 
      
 104 
     | 
    
         
            +
                "field": "core.conf",
         
     | 
| 
      
 105 
     | 
    
         
            +
                "get_data_dir": "core.conf",
         
     | 
| 
      
 106 
     | 
    
         
            +
                "get_pretrained_models_dir": "core.conf",
         
     | 
| 
      
 107 
     | 
    
         
            +
                "get_run_dir": "core.conf",
         
     | 
| 
      
 108 
     | 
    
         
            +
                "load_user_config": "core.conf",
         
     | 
| 
      
 109 
     | 
    
         
            +
                "State": "core.state",
         
     | 
| 
      
 110 
     | 
    
         
            +
                "cast_phase": "core.state",
         
     | 
| 
      
 111 
     | 
    
         
            +
                "FourierEmbeddings": "nn.embeddings",
         
     | 
| 
      
 112 
     | 
    
         
            +
                "IdentityPositionalEmbeddings": "nn.embeddings",
         
     | 
| 
      
 113 
     | 
    
         
            +
                "LearnedPositionalEmbeddings": "nn.embeddings",
         
     | 
| 
      
 114 
     | 
    
         
            +
                "RotaryEmbeddings": "nn.embeddings",
         
     | 
| 
      
 115 
     | 
    
         
            +
                "SinusoidalEmbeddings": "nn.embeddings",
         
     | 
| 
      
 116 
     | 
    
         
            +
                "apply_rotary_embeddings": "nn.embeddings",
         
     | 
| 
      
 117 
     | 
    
         
            +
                "cast_embedding_kind": "nn.embeddings",
         
     | 
| 
      
 118 
     | 
    
         
            +
                "fourier_embeddings": "nn.embeddings",
         
     | 
| 
      
 119 
     | 
    
         
            +
                "get_positional_embeddings": "nn.embeddings",
         
     | 
| 
      
 120 
     | 
    
         
            +
                "get_rotary_embeddings": "nn.embeddings",
         
     | 
| 
      
 121 
     | 
    
         
            +
                "rotary_embeddings": "nn.embeddings",
         
     | 
| 
      
 122 
     | 
    
         
            +
                "BaseLauncher": "task.launchers.base",
         
     | 
| 
      
 123 
     | 
    
         
            +
                "CliLauncher": "task.launchers.cli",
         
     | 
| 
      
 124 
     | 
    
         
            +
                "SingleProcessLauncher": "task.launchers.single_process",
         
     | 
| 
      
 125 
     | 
    
         
            +
                "LogImage": "task.logger",
         
     | 
| 
      
 126 
     | 
    
         
            +
                "LogLine": "task.logger",
         
     | 
| 
      
 127 
     | 
    
         
            +
                "Logger": "task.logger",
         
     | 
| 
      
 128 
     | 
    
         
            +
                "LoggerImpl": "task.logger",
         
     | 
| 
      
 129 
     | 
    
         
            +
                "CallbackLogger": "task.loggers.callback",
         
     | 
| 
      
 130 
     | 
    
         
            +
                "JsonLogger": "task.loggers.json",
         
     | 
| 
      
 131 
     | 
    
         
            +
                "StateLogger": "task.loggers.state",
         
     | 
| 
      
 132 
     | 
    
         
            +
                "StdoutLogger": "task.loggers.stdout",
         
     | 
| 
      
 133 
     | 
    
         
            +
                "TensorboardLogger": "task.loggers.tensorboard",
         
     | 
| 
      
 134 
     | 
    
         
            +
                "CPUStatsOptions": "task.mixins.cpu_stats",
         
     | 
| 
      
 135 
     | 
    
         
            +
                "DataloaderConfig": "task.mixins.data_loader",
         
     | 
| 
      
 136 
     | 
    
         
            +
                "GPUStatsOptions": "task.mixins.gpu_stats",
         
     | 
| 
      
 137 
     | 
    
         
            +
                "Script": "task.script",
         
     | 
| 
      
 138 
     | 
    
         
            +
                "ScriptConfig": "task.script",
         
     | 
| 
      
 139 
     | 
    
         
            +
                "Config": "task.task",
         
     | 
| 
      
 140 
     | 
    
         
            +
                "Task": "task.task",
         
     | 
| 
      
 141 
     | 
    
         
            +
                "collate": "utils.data.collate",
         
     | 
| 
      
 142 
     | 
    
         
            +
                "collate_non_null": "utils.data.collate",
         
     | 
| 
      
 143 
     | 
    
         
            +
                "BaseFileDownloader": "utils.experiments",
         
     | 
| 
      
 144 
     | 
    
         
            +
                "DataDownloader": "utils.experiments",
         
     | 
| 
      
 145 
     | 
    
         
            +
                "ModelDownloader": "utils.experiments",
         
     | 
| 
      
 146 
     | 
    
         
            +
                "check_md5": "utils.experiments",
         
     | 
| 
      
 147 
     | 
    
         
            +
                "check_sha256": "utils.experiments",
         
     | 
| 
      
 148 
     | 
    
         
            +
                "get_git_state": "utils.experiments",
         
     | 
| 
      
 149 
     | 
    
         
            +
                "get_state_dict_prefix": "utils.experiments",
         
     | 
| 
      
 150 
     | 
    
         
            +
                "get_training_code": "utils.experiments",
         
     | 
| 
      
 151 
     | 
    
         
            +
                "save_config": "utils.experiments",
         
     | 
| 
      
 152 
     | 
    
         
            +
                "ColoredFormatter": "utils.logging",
         
     | 
| 
      
 153 
     | 
    
         
            +
                "configure_logging": "utils.logging",
         
     | 
| 
      
 154 
     | 
    
         
            +
                "one_hot": "utils.numpy",
         
     | 
| 
      
 155 
     | 
    
         
            +
                "partial_flatten": "utils.numpy",
         
     | 
| 
      
 156 
     | 
    
         
            +
                "worker_chunk": "utils.numpy",
         
     | 
| 
      
 157 
     | 
    
         
            +
                "TextBlock": "utils.text",
         
     | 
| 
      
 158 
     | 
    
         
            +
                "colored": "utils.text",
         
     | 
| 
      
 159 
     | 
    
         
            +
                "format_datetime": "utils.text",
         
     | 
| 
      
 160 
     | 
    
         
            +
                "format_timedelta": "utils.text",
         
     | 
| 
      
 161 
     | 
    
         
            +
                "outlined": "utils.text",
         
     | 
| 
      
 162 
     | 
    
         
            +
                "render_text_blocks": "utils.text",
         
     | 
| 
      
 163 
     | 
    
         
            +
                "show_error": "utils.text",
         
     | 
| 
      
 164 
     | 
    
         
            +
                "show_warning": "utils.text",
         
     | 
| 
      
 165 
     | 
    
         
            +
                "uncolored": "utils.text",
         
     | 
| 
      
 166 
     | 
    
         
            +
                "wrapped": "utils.text",
         
     | 
| 
      
 167 
     | 
    
         
            +
            }
         
     | 
| 
      
 168 
     | 
    
         
            +
             
     | 
| 
      
 169 
     | 
    
         
            +
            # Need to manually set some values which can't be auto-generated.
         
     | 
| 
      
 170 
     | 
    
         
            +
            NAME_MAP.update(
         
     | 
| 
      
 171 
     | 
    
         
            +
                {
         
     | 
| 
      
 172 
     | 
    
         
            +
                    "Batch": "task.mixins.train",
         
     | 
| 
      
 173 
     | 
    
         
            +
                    "CollateMode": "utils.data.collate",
         
     | 
| 
      
 174 
     | 
    
         
            +
                    "EmbeddingKind": "nn.embeddings",
         
     | 
| 
      
 175 
     | 
    
         
            +
                    "Output": "task.mixins.output",
         
     | 
| 
      
 176 
     | 
    
         
            +
                    "Phase": "core.state",
         
     | 
| 
      
 177 
     | 
    
         
            +
                },
         
     | 
| 
      
 178 
     | 
    
         
            +
            )
         
     | 
| 
      
 179 
     | 
    
         
            +
             
     | 
| 
      
 180 
     | 
    
         
            +
             
     | 
| 
      
 181 
     | 
    
         
            +
            def __getattr__(name: str) -> object:
         
     | 
| 
      
 182 
     | 
    
         
            +
                if name not in NAME_MAP:
         
     | 
| 
      
 183 
     | 
    
         
            +
                    raise AttributeError(f"{__name__} has no attribute {name}")
         
     | 
| 
      
 184 
     | 
    
         
            +
             
     | 
| 
      
 185 
     | 
    
         
            +
                module_name = f"xax.{NAME_MAP[name]}"
         
     | 
| 
      
 186 
     | 
    
         
            +
                module = __import__(module_name, fromlist=[name])
         
     | 
| 
      
 187 
     | 
    
         
            +
                return getattr(module, name)
         
     | 
| 
      
 188 
     | 
    
         
            +
             
     | 
| 
      
 189 
     | 
    
         
            +
             
     | 
| 
      
 190 
     | 
    
         
            +
            if IMPORT_ALL or TYPE_CHECKING:
         
     | 
| 
      
 191 
     | 
    
         
            +
                from xax.core.conf import (
         
     | 
| 
      
 192 
     | 
    
         
            +
                    UserConfig,
         
     | 
| 
      
 193 
     | 
    
         
            +
                    field,
         
     | 
| 
      
 194 
     | 
    
         
            +
                    get_data_dir,
         
     | 
| 
      
 195 
     | 
    
         
            +
                    get_pretrained_models_dir,
         
     | 
| 
      
 196 
     | 
    
         
            +
                    get_run_dir,
         
     | 
| 
      
 197 
     | 
    
         
            +
                    load_user_config,
         
     | 
| 
      
 198 
     | 
    
         
            +
                )
         
     | 
| 
      
 199 
     | 
    
         
            +
                from xax.core.state import Phase, State, cast_phase
         
     | 
| 
      
 200 
     | 
    
         
            +
                from xax.nn.embeddings import (
         
     | 
| 
      
 201 
     | 
    
         
            +
                    EmbeddingKind,
         
     | 
| 
      
 202 
     | 
    
         
            +
                    FourierEmbeddings,
         
     | 
| 
      
 203 
     | 
    
         
            +
                    IdentityPositionalEmbeddings,
         
     | 
| 
      
 204 
     | 
    
         
            +
                    LearnedPositionalEmbeddings,
         
     | 
| 
      
 205 
     | 
    
         
            +
                    RotaryEmbeddings,
         
     | 
| 
      
 206 
     | 
    
         
            +
                    SinusoidalEmbeddings,
         
     | 
| 
      
 207 
     | 
    
         
            +
                    apply_rotary_embeddings,
         
     | 
| 
      
 208 
     | 
    
         
            +
                    cast_embedding_kind,
         
     | 
| 
      
 209 
     | 
    
         
            +
                    fourier_embeddings,
         
     | 
| 
      
 210 
     | 
    
         
            +
                    get_positional_embeddings,
         
     | 
| 
      
 211 
     | 
    
         
            +
                    get_rotary_embeddings,
         
     | 
| 
      
 212 
     | 
    
         
            +
                    rotary_embeddings,
         
     | 
| 
      
 213 
     | 
    
         
            +
                )
         
     | 
| 
      
 214 
     | 
    
         
            +
                from xax.task.launchers.base import BaseLauncher
         
     | 
| 
      
 215 
     | 
    
         
            +
                from xax.task.launchers.cli import CliLauncher
         
     | 
| 
      
 216 
     | 
    
         
            +
                from xax.task.launchers.single_process import SingleProcessLauncher
         
     | 
| 
      
 217 
     | 
    
         
            +
                from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
         
     | 
| 
      
 218 
     | 
    
         
            +
                from xax.task.loggers.callback import CallbackLogger
         
     | 
| 
      
 219 
     | 
    
         
            +
                from xax.task.loggers.json import JsonLogger
         
     | 
| 
      
 220 
     | 
    
         
            +
                from xax.task.loggers.state import StateLogger
         
     | 
| 
      
 221 
     | 
    
         
            +
                from xax.task.loggers.stdout import StdoutLogger
         
     | 
| 
      
 222 
     | 
    
         
            +
                from xax.task.loggers.tensorboard import TensorboardLogger
         
     | 
| 
      
 223 
     | 
    
         
            +
                from xax.task.mixins.cpu_stats import CPUStatsOptions
         
     | 
| 
      
 224 
     | 
    
         
            +
                from xax.task.mixins.data_loader import DataloaderConfig
         
     | 
| 
      
 225 
     | 
    
         
            +
                from xax.task.mixins.gpu_stats import GPUStatsOptions
         
     | 
| 
      
 226 
     | 
    
         
            +
                from xax.task.mixins.train import Batch, Output
         
     | 
| 
      
 227 
     | 
    
         
            +
                from xax.task.script import Script, ScriptConfig
         
     | 
| 
      
 228 
     | 
    
         
            +
                from xax.task.task import Config, Task
         
     | 
| 
      
 229 
     | 
    
         
            +
                from xax.utils.data.collate import CollateMode, collate, collate_non_null
         
     | 
| 
      
 230 
     | 
    
         
            +
                from xax.utils.experiments import (
         
     | 
| 
      
 231 
     | 
    
         
            +
                    BaseFileDownloader,
         
     | 
| 
      
 232 
     | 
    
         
            +
                    DataDownloader,
         
     | 
| 
      
 233 
     | 
    
         
            +
                    ModelDownloader,
         
     | 
| 
      
 234 
     | 
    
         
            +
                    check_md5,
         
     | 
| 
      
 235 
     | 
    
         
            +
                    check_sha256,
         
     | 
| 
      
 236 
     | 
    
         
            +
                    get_git_state,
         
     | 
| 
      
 237 
     | 
    
         
            +
                    get_state_dict_prefix,
         
     | 
| 
      
 238 
     | 
    
         
            +
                    get_training_code,
         
     | 
| 
      
 239 
     | 
    
         
            +
                    save_config,
         
     | 
| 
      
 240 
     | 
    
         
            +
                )
         
     | 
| 
      
 241 
     | 
    
         
            +
                from xax.utils.logging import ColoredFormatter, configure_logging
         
     | 
| 
      
 242 
     | 
    
         
            +
                from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
         
     | 
| 
      
 243 
     | 
    
         
            +
                from xax.utils.text import (
         
     | 
| 
      
 244 
     | 
    
         
            +
                    TextBlock,
         
     | 
| 
      
 245 
     | 
    
         
            +
                    colored,
         
     | 
| 
      
 246 
     | 
    
         
            +
                    format_datetime,
         
     | 
| 
      
 247 
     | 
    
         
            +
                    format_timedelta,
         
     | 
| 
      
 248 
     | 
    
         
            +
                    outlined,
         
     | 
| 
      
 249 
     | 
    
         
            +
                    render_text_blocks,
         
     | 
| 
      
 250 
     | 
    
         
            +
                    show_error,
         
     | 
| 
      
 251 
     | 
    
         
            +
                    show_warning,
         
     | 
| 
      
 252 
     | 
    
         
            +
                    uncolored,
         
     | 
| 
      
 253 
     | 
    
         
            +
                    wrapped,
         
     | 
| 
      
 254 
     | 
    
         
            +
                )
         
     | 
| 
      
 255 
     | 
    
         
            +
             
     | 
| 
      
 256 
     | 
    
         
            +
            del TYPE_CHECKING, IMPORT_ALL
         
     | 
    
        xax/core/conf.py
    ADDED
    
    | 
         @@ -0,0 +1,193 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            """Defines base configuration functions and utilities."""
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            import functools
         
     | 
| 
      
 4 
     | 
    
         
            +
            import os
         
     | 
| 
      
 5 
     | 
    
         
            +
            from dataclasses import dataclass, field as field_base
         
     | 
| 
      
 6 
     | 
    
         
            +
            from pathlib import Path
         
     | 
| 
      
 7 
     | 
    
         
            +
            from typing import Any, cast
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 10 
     | 
    
         
            +
            from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            from xax.utils.text import show_error
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
            FieldType = Any
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
            def field(value: FieldType, **kwargs: str) -> FieldType:
                """Short-hand function for getting a config field.

                Args:
                    value: The current field's default value. A callable is used
                        as the field's default factory. An unhashable value (for
                        example a list or dict) is wrapped in a factory that hands
                        out a fresh deep copy on every call, so instances never
                        share one mutable default. Anything else is stored as a
                        plain default.
                    kwargs: Additional metadata fields to supply.

                Returns:
                    The dataclass field.
                """
                import copy

                metadata: dict[str, Any] = dict(kwargs)

                if callable(value):
                    return field_base(default_factory=value, metadata=metadata)
                if value.__class__.__hash__ is None:
                    # Copy per call: returning ``value`` itself would make every
                    # instance created from this field mutate the same object.
                    return field_base(default_factory=lambda: copy.deepcopy(value), metadata=metadata)
                return field_base(default=value, metadata=metadata)
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
             
     | 
| 
      
 37 
     | 
    
         
            +
            def is_missing(cfg: Any, key: str) -> bool:  # noqa: ANN401
                """Reports whether ``key`` has no usable value on ``cfg``.

                Works for both OmegaConf containers and raw dataclass instances so
                that callers can treat the two the same way.

                Args:
                    cfg: The config to check
                    key: The key to check

                Returns:
                    True if the key is missing a value in the config, False otherwise
                """
                if isinstance(cfg, OmegaConfContainer):
                    if OmegaConf.is_missing(cfg, key):
                        return True
                    if OmegaConf.is_interpolation(cfg, key):
                        # An interpolation that cannot be resolved is treated as missing.
                        try:
                            getattr(cfg, key)
                        except Exception:
                            return True
                        else:
                            return False

                # Raw dataclasses (and remaining container values) are compared
                # against the ``MISSING`` sentinel directly.
                return getattr(cfg, key) is MISSING
         
     | 
| 
      
 62 
     | 
    
         
            +
             
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
            @dataclass
            class Logging:
                """Logging section of the user configuration."""

                # Toggle for hiding third-party logs.
                hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
                # Logging level, stored as a string.
                log_level: str = field("INFO", help="The logging level to use")
         
     | 
| 
      
 68 
     | 
    
         
            +
             
     | 
| 
      
 69 
     | 
    
         
            +
             
     | 
| 
      
 70 
     | 
    
         
            +
            @dataclass
            class Device:
                """Device section of the user configuration.

                The ``gpu`` and ``metal`` defaults are OmegaConf interpolations of the
                ``USE_GPU`` and ``USE_METAL`` environment variables, each written with
                ``1`` as the fallback inside the interpolation string. The ``use_*``
                flags are read by ``parse_dtype``.
                """

                # Accelerator selection.
                cpu: bool = field(True, help="Whether to use the CPU")
                gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
                metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")

                # Floating point overrides, all off by default.
                use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
                use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
                use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
                use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
         
     | 
| 
      
 79 
     | 
    
         
            +
             
     | 
| 
      
 80 
     | 
    
         
            +
             
     | 
| 
      
 81 
     | 
    
         
            +
            def parse_dtype(cfg: Device) -> jnp.dtype | None:
                """Picks the floating point type requested by the device config.

                Flags are checked in the order fp64, fp32, bf16, fp16 and the first
                one that is set wins.

                Args:
                    cfg: The device config whose ``use_*`` flags are inspected.

                Returns:
                    The matching ``jax.numpy`` type, or None if no flag is set.
                """
                precedence = (
                    ("use_fp64", jnp.float64),
                    ("use_fp32", jnp.float32),
                    ("use_bf16", jnp.bfloat16),
                    ("use_fp16", jnp.float16),
                )
                for flag_name, dtype in precedence:
                    # Flags are read lazily, one at a time, in priority order.
                    if getattr(cfg, flag_name):
                        return dtype
                return None
         
     | 
| 
      
 91 
     | 
    
         
            +
             
     | 
| 
      
 92 
     | 
    
         
            +
             
     | 
| 
      
 93 
     | 
    
         
            +
            @dataclass
            class Triton:
                """Triton section of the user configuration."""

                # Opt-in flag for Triton; enabled by default.
                use_triton_if_available: bool = field(True, help="Use Triton if available")
         
     | 
| 
      
 96 
     | 
    
         
            +
             
     | 
| 
      
 97 
     | 
    
         
            +
             
     | 
| 
      
 98 
     | 
    
         
            +
            @dataclass
            class Experiment:
                """Experiment section of the user configuration."""

                # Seed used when none is given explicitly.
                default_random_seed: int = field(1337, help="The default random seed to use")
                # Upper bound on the number of workers.
                max_workers: int = field(32, help="Maximum number of workers to use")
         
     | 
| 
      
 102 
     | 
    
         
            +
             
     | 
| 
      
 103 
     | 
    
         
            +
             
     | 
| 
      
 104 
     | 
    
         
            +
            @dataclass
            class Directories:
                """Directory section of the user configuration.

                Each default is an OmegaConf interpolation of an environment variable
                (``RUN_DIR``, ``DATA_DIR`` and ``MODEL_DIR`` respectively).
                """

                run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
                data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
                pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
         
     | 
| 
      
 109 
     | 
    
         
            +
             
     | 
| 
      
 110 
     | 
    
         
            +
             
     | 
| 
      
 111 
     | 
    
         
            +
            @dataclass
            class SlurmPartition:
                """Settings for a single Slurm partition entry."""

                # No default: the partition name has to be provided.
                partition: str = field(MISSING, help="The partition name")
                num_nodes: int = field(1, help="The number of nodes to use")
         
     | 
| 
      
 115 
     | 
    
         
            +
             
     | 
| 
      
 116 
     | 
    
         
            +
             
     | 
| 
      
 117 
     | 
    
         
            +
            @dataclass
            class Slurm:
                """Slurm section of the user configuration."""

                # Maps a launch configuration name to its partition settings; empty by default.
                launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
         
     | 
| 
      
 120 
     | 
    
         
            +
             
     | 
| 
      
 121 
     | 
    
         
            +
             
     | 
| 
      
 122 
     | 
    
         
            +
            @dataclass
            class UserConfig:
                """Top-level user configuration, one attribute per config section.

                Every section uses its own dataclass as the default factory, so a
                new ``UserConfig`` starts out with each section's defaults.
                """

                logging: Logging = field(Logging)
                device: Device = field(Device)
                triton: Triton = field(Triton)
                experiment: Experiment = field(Experiment)
                directories: Directories = field(Directories)
                slurm: Slurm = field(Slurm)
         
     | 
| 
      
 130 
     | 
    
         
            +
             
     | 
| 
      
 131 
     | 
    
         
            +
             
     | 
| 
      
 132 
     | 
    
         
            +
            def user_config_path() -> Path:
                """Returns the location of the user-level config file.

                The path is read from the ``XAXRC_PATH`` environment variable, falling
                back to ``~/.xax.yml``, with a leading ``~`` expanded.

                Returns:
                    The user config file path.
                """
                return Path(os.environ.get("XAXRC_PATH", "~/.xax.yml")).expanduser()
         
     | 
| 
      
 136 
     | 
    
         
            +
             
     | 
| 
      
 137 
     | 
    
         
            +
             
     | 
| 
      
 138 
     | 
    
         
            +
            @functools.lru_cache(maxsize=None)
            def _load_user_config_cached() -> UserConfig:
                """Builds the user configuration once and caches the result.

                Starts from the structured ``UserConfig`` defaults and merges in the
                file at ``user_config_path()`` when it exists; otherwise the defaults
                are written to that path. A ``xax.yml`` file in the current working
                directory, if present, is merged on top.

                Returns:
                    The merged configuration.
                """
                xaxrc_path = user_config_path()
                base_cfg = OmegaConf.structured(UserConfig)

                # Loads the user-level config file, creating it from the defaults
                # if it does not exist yet.
                if xaxrc_path.exists():
                    cfg = OmegaConf.merge(base_cfg, OmegaConf.load(xaxrc_path))
                else:
                    show_error(f"No config file was found in {xaxrc_path}; writing one...", important=True)
                    # XAXRC_PATH may point into a directory that has not been created
                    # yet; without this the save below fails on the missing parent.
                    xaxrc_path.parent.mkdir(parents=True, exist_ok=True)
                    OmegaConf.save(base_cfg, xaxrc_path)
                    cfg = base_cfg

                # Looks in the current directory for a config file.
                local_cfg_path = Path("xax.yml")
                if local_cfg_path.exists():
                    cfg = OmegaConf.merge(cfg, OmegaConf.load(local_cfg_path))

                return cast(UserConfig, cfg)
         
     | 
| 
      
 157 
     | 
    
         
            +
             
     | 
| 
      
 158 
     | 
    
         
            +
             
     | 
| 
      
 159 
     | 
    
         
            +
            def load_user_config() -> UserConfig:
                """Loads the user configuration.

                Delegates to the cached loader, so repeated calls return the same
                object without re-reading any files.

                Returns:
                    The loaded configuration.
                """
                user_cfg = _load_user_config_cached()
                return user_cfg
         
     | 
| 
      
 166 
     | 
    
         
            +
             
     | 
| 
      
 167 
     | 
    
         
            +
             
     | 
| 
      
 168 
     | 
    
         
            +
            def get_run_dir() -> Path | None:
                """Returns the configured run directory, creating it if needed.

                Returns:
                    The run directory, or None when no run directory is configured.
                """
                directories = load_user_config().directories
                if is_missing(directories, "run"):
                    return None

                run_dir = Path(directories.run)
                run_dir.mkdir(parents=True, exist_ok=True)
                return run_dir
         
     | 
| 
      
 174 
     | 
    
         
            +
             
     | 
| 
      
 175 
     | 
    
         
            +
             
     | 
| 
      
 176 
     | 
    
         
            +
            def get_data_dir() -> Path:
                """Returns the configured data directory.

                Returns:
                    The data directory path.

                Raises:
                    RuntimeError: If no data directory is configured.
                """
                directories = load_user_config().directories
                if not is_missing(directories, "data"):
                    return Path(directories.data)

                message = (
                    "The data directory has not been set! You should set it in your config file "
                    f"in {user_config_path()} or set the DATA_DIR environment variable."
                )
                raise RuntimeError(message)
         
     | 
| 
      
 184 
     | 
    
         
            +
             
     | 
| 
      
 185 
     | 
    
         
            +
             
     | 
| 
      
 186 
     | 
    
         
            +
            def get_pretrained_models_dir() -> Path:
                """Returns the configured pretrained models directory.

                Returns:
                    The pretrained models directory path.

                Raises:
                    RuntimeError: If no pretrained models directory is configured.
                """
                config = load_user_config().directories
                if is_missing(config, "pretrained_models"):
                    # The message names the pretrained models directory so that it
                    # matches the MODEL_DIR hint instead of pointing at the data dir.
                    raise RuntimeError(
                        "The pretrained models directory has not been set! You should set it in your config file "
                        f"in {user_config_path()} or set the MODEL_DIR environment variable."
                    )
                return Path(config.pretrained_models)
         
     | 
    
        xax/core/state.py
    ADDED
    
    | 
         @@ -0,0 +1,81 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            """Defines a dataclass for keeping track of the current training state."""
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            import time
         
     | 
| 
      
 4 
     | 
    
         
            +
            from dataclasses import dataclass
         
     | 
| 
      
 5 
     | 
    
         
            +
            from typing import Literal, TypedDict, cast, get_args
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            from omegaconf import MISSING
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            from xax.core.conf import field
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
            # The set of phases a training run can be in.
            Phase = Literal["train", "valid"]


            def cast_phase(raw_phase: str) -> Phase:
                """Validates a raw phase string and narrows it to ``Phase``.

                Args:
                    raw_phase: The phase name to validate.

                Returns:
                    The same string, typed as ``Phase``.

                Raises:
                    AssertionError: If ``raw_phase`` is not one of the ``Phase`` options.
                """
                args = get_args(Phase)
                is_valid = raw_phase in args
                assert is_valid, f"Invalid phase: '{raw_phase}' Valid options are {args}"
                return cast(Phase, raw_phase)
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
            class StateDict(TypedDict, total=False):
                """Partial set of ``State`` values; every key is optional."""

                # Training counters.
                num_steps: int
                num_samples: int

                # Validation counters.
                num_valid_steps: int
                num_valid_samples: int

                # Timing, in seconds.
                start_time_s: float
                elapsed_time_s: float

                # Phase name as a plain string.
                raw_phase: str
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
             
     | 
| 
      
 30 
     | 
    
         
            +
            @dataclass(frozen=True)
            class State:
                """Immutable record of training progress: counters, timing and phase.

                Instances are frozen; ``replace`` and ``with_phase`` return new
                ``State`` objects instead of modifying the current one.
                """

                num_steps: int = field(MISSING, help="Number of steps so far")
                num_samples: int = field(MISSING, help="Number of sample so far")
                num_valid_steps: int = field(MISSING, help="Number of validation steps so far")
                num_valid_samples: int = field(MISSING, help="Number of validation samples so far")
                start_time_s: float = field(MISSING, help="Start time of training")
                elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")
                raw_phase: str = field(MISSING, help="Current training phase")

                @property
                def phase(self) -> Phase:
                    """The current phase, validated through ``cast_phase``."""
                    return cast_phase(self.raw_phase)

                @classmethod
                def init_state(cls) -> "State":
                    """Creates a fresh state: zeroed counters, the current time, "train" phase.

                    Returns:
                        The initial state.
                    """
                    now = time.time()
                    return cls(
                        num_steps=0,
                        num_samples=0,
                        num_valid_steps=0,
                        num_valid_samples=0,
                        start_time_s=now,
                        elapsed_time_s=0.0,
                        raw_phase="train",
                    )

                @property
                def training(self) -> bool:
                    """Whether the current phase is "train"."""
                    current_phase = self.phase
                    return current_phase == "train"

                def num_phase_steps(self, phase: Phase) -> int:
                    """Returns the step counter that belongs to ``phase``.

                    Args:
                        phase: The phase to look up.

                    Returns:
                        ``num_steps`` for "train", ``num_valid_steps`` for "valid".

                    Raises:
                        ValueError: If ``phase`` is neither "train" nor "valid".
                    """
                    if phase == "train":
                        return self.num_steps
                    if phase == "valid":
                        return self.num_valid_steps
                    raise ValueError(f"Invalid phase: {phase}")

                def replace(self, values: StateDict) -> "State":
                    """Returns a new state with the given values overriding the current ones.

                    Args:
                        values: The fields to override; absent keys keep their current value.

                    Returns:
                        The updated state.
                    """
                    lookup = values.get
                    return State(
                        num_steps=lookup("num_steps", self.num_steps),
                        num_samples=lookup("num_samples", self.num_samples),
                        num_valid_steps=lookup("num_valid_steps", self.num_valid_steps),
                        num_valid_samples=lookup("num_valid_samples", self.num_valid_samples),
                        start_time_s=lookup("start_time_s", self.start_time_s),
                        elapsed_time_s=lookup("elapsed_time_s", self.elapsed_time_s),
                        raw_phase=lookup("raw_phase", self.raw_phase),
                    )

                def with_phase(self, phase: Phase) -> "State":
                    """Returns a copy of this state with its phase set to ``phase``."""
                    phase_update: StateDict = {"raw_phase": phase}
                    return self.replace(phase_update)
         
     | 
    
        xax/nn/__init__.py
    ADDED
    
    | 
         
            File without changes
         
     |