xax 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -2
 - xax/task/base.py +0 -3
 - xax/task/mixins/checkpointing.py +25 -8
 - xax/task/mixins/compile.py +8 -0
 - xax/utils/debugging.py +4 -0
 - {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/METADATA +1 -1
 - {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/RECORD +10 -10
 - {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/WHEEL +0 -0
 - {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/licenses/LICENSE +0 -0
 - {xax-0.1.13.dist-info → xax-0.1.14.dist-info}/top_level.txt +0 -0
 
    
        xax/__init__.py
    CHANGED
    
    | 
         @@ -12,7 +12,7 @@ and running the update script: 
     | 
|
| 
       12 
12 
     | 
    
         
             
                python -m scripts.update_api --inplace
         
     | 
| 
       13 
13 
     | 
    
         
             
            """
         
     | 
| 
       14 
14 
     | 
    
         | 
| 
       15 
     | 
    
         
            -
            __version__ = "0.1.13"
     | 
| 
      
 15 
     | 
    
         
            +
            __version__ = "0.1.14"
         
     | 
| 
       16 
16 
     | 
    
         | 
| 
       17 
17 
     | 
    
         
             
            # This list shouldn't be modified by hand; instead, run the update script.
         
     | 
| 
       18 
18 
     | 
    
         
             
            __all__ = [
         
     | 
| 
         @@ -77,6 +77,7 @@ __all__ = [ 
     | 
|
| 
       77 
77 
     | 
    
         
             
                "collate_non_null",
         
     | 
| 
       78 
78 
     | 
    
         
             
                "breakpoint_if_nan",
         
     | 
| 
       79 
79 
     | 
    
         
             
                "get_named_leaves",
         
     | 
| 
      
 80 
     | 
    
         
            +
                "log_if_nan",
         
     | 
| 
       80 
81 
     | 
    
         
             
                "BaseFileDownloader",
         
     | 
| 
       81 
82 
     | 
    
         
             
                "ContextTimer",
         
     | 
| 
       82 
83 
     | 
    
         
             
                "CumulativeTimer",
         
     | 
| 
         @@ -237,6 +238,7 @@ NAME_MAP: dict[str, str] = { 
     | 
|
| 
       237 
238 
     | 
    
         
             
                "collate_non_null": "utils.data.collate",
         
     | 
| 
       238 
239 
     | 
    
         
             
                "breakpoint_if_nan": "utils.debugging",
         
     | 
| 
       239 
240 
     | 
    
         
             
                "get_named_leaves": "utils.debugging",
         
     | 
| 
      
 241 
     | 
    
         
            +
                "log_if_nan": "utils.debugging",
         
     | 
| 
       240 
242 
     | 
    
         
             
                "BaseFileDownloader": "utils.experiments",
         
     | 
| 
       241 
243 
     | 
    
         
             
                "ContextTimer": "utils.experiments",
         
     | 
| 
       242 
244 
     | 
    
         
             
                "CumulativeTimer": "utils.experiments",
         
     | 
| 
         @@ -388,7 +390,7 @@ if IMPORT_ALL or TYPE_CHECKING: 
     | 
|
| 
       388 
390 
     | 
    
         
             
                from xax.task.script import Script, ScriptConfig
         
     | 
| 
       389 
391 
     | 
    
         
             
                from xax.task.task import Config, Task
         
     | 
| 
       390 
392 
     | 
    
         
             
                from xax.utils.data.collate import CollateMode, collate, collate_non_null
         
     | 
| 
       391 
     | 
    
         
            -
                from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
         
     | 
| 
      
 393 
     | 
    
         
            +
                from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
         
     | 
| 
       392 
394 
     | 
    
         
             
                from xax.utils.experiments import (
         
     | 
| 
       393 
395 
     | 
    
         
             
                    BaseFileDownloader,
         
     | 
| 
       394 
396 
     | 
    
         
             
                    ContextTimer,
         
     | 
    
        xax/task/base.py
    CHANGED
    
    | 
         @@ -82,9 +82,6 @@ class BaseTask(Generic[Config]): 
     | 
|
| 
       82 
82 
     | 
    
         
             
                def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
         
     | 
| 
       83 
83 
     | 
    
         
             
                    return state
         
     | 
| 
       84 
84 
     | 
    
         | 
| 
       85 
     | 
    
         
            -
                def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
         
     | 
| 
       86 
     | 
    
         
            -
                    pass
         
     | 
| 
       87 
     | 
    
         
            -
             
     | 
| 
       88 
85 
     | 
    
         
             
                @functools.cached_property
         
     | 
| 
       89 
86 
     | 
    
         
             
                def task_class_name(self) -> str:
         
     | 
| 
       90 
87 
     | 
    
         
             
                    return self.__class__.__name__
         
     | 
    
        xax/task/mixins/checkpointing.py
    CHANGED
    
    | 
         @@ -98,19 +98,39 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): 
     | 
|
| 
       98 
98 
     | 
    
         
             
                ) -> tuple[PyTree, State, DictConfig]: ...
         
     | 
| 
       99 
99 
     | 
    
         | 
| 
       100 
100 
     | 
    
         
             
                @overload
         
     | 
| 
       101 
     | 
    
         
            -
                def load_checkpoint( 
     | 
| 
      
 101 
     | 
    
         
            +
                def load_checkpoint(
         
     | 
| 
      
 102 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 103 
     | 
    
         
            +
                    path: Path,
         
     | 
| 
      
 104 
     | 
    
         
            +
                    part: Literal["model"],
         
     | 
| 
      
 105 
     | 
    
         
            +
                ) -> PyTree: ...
         
     | 
| 
       102 
106 
     | 
    
         | 
| 
       103 
107 
     | 
    
         
             
                @overload
         
     | 
| 
       104 
     | 
    
         
            -
                def load_checkpoint( 
     | 
| 
      
 108 
     | 
    
         
            +
                def load_checkpoint(
         
     | 
| 
      
 109 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 110 
     | 
    
         
            +
                    path: Path,
         
     | 
| 
      
 111 
     | 
    
         
            +
                    part: Literal["opt"],
         
     | 
| 
      
 112 
     | 
    
         
            +
                ) -> optax.GradientTransformation: ...
         
     | 
| 
       105 
113 
     | 
    
         | 
| 
       106 
114 
     | 
    
         
             
                @overload
         
     | 
| 
       107 
     | 
    
         
            -
                def load_checkpoint( 
     | 
| 
      
 115 
     | 
    
         
            +
                def load_checkpoint(
         
     | 
| 
      
 116 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 117 
     | 
    
         
            +
                    path: Path,
         
     | 
| 
      
 118 
     | 
    
         
            +
                    part: Literal["opt_state"],
         
     | 
| 
      
 119 
     | 
    
         
            +
                ) -> optax.OptState: ...
         
     | 
| 
       108 
120 
     | 
    
         | 
| 
       109 
121 
     | 
    
         
             
                @overload
         
     | 
| 
       110 
     | 
    
         
            -
                def load_checkpoint( 
     | 
| 
      
 122 
     | 
    
         
            +
                def load_checkpoint(
         
     | 
| 
      
 123 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 124 
     | 
    
         
            +
                    path: Path,
         
     | 
| 
      
 125 
     | 
    
         
            +
                    part: Literal["state"],
         
     | 
| 
      
 126 
     | 
    
         
            +
                ) -> State: ...
         
     | 
| 
       111 
127 
     | 
    
         | 
| 
       112 
128 
     | 
    
         
             
                @overload
         
     | 
| 
       113 
     | 
    
         
            -
                def load_checkpoint( 
     | 
| 
      
 129 
     | 
    
         
            +
                def load_checkpoint(
         
     | 
| 
      
 130 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 131 
     | 
    
         
            +
                    path: Path,
         
     | 
| 
      
 132 
     | 
    
         
            +
                    part: Literal["config"],
         
     | 
| 
      
 133 
     | 
    
         
            +
                ) -> DictConfig: ...
         
     | 
| 
       114 
134 
     | 
    
         | 
| 
       115 
135 
     | 
    
         
             
                def load_checkpoint(
         
     | 
| 
       116 
136 
     | 
    
         
             
                    self,
         
     | 
| 
         @@ -125,9 +145,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]): 
     | 
|
| 
       125 
145 
     | 
    
         
             
                    | State
         
     | 
| 
       126 
146 
     | 
    
         
             
                    | DictConfig
         
     | 
| 
       127 
147 
     | 
    
         
             
                ):
         
     | 
| 
       128 
     | 
    
         
            -
                    # Calls the base callback.
         
     | 
| 
       129 
     | 
    
         
            -
                    self.on_before_checkpoint_load(path)
         
     | 
| 
       130 
     | 
    
         
            -
             
     | 
| 
       131 
148 
     | 
    
         
             
                    with tarfile.open(path, "r:gz") as tar:
         
     | 
| 
       132 
149 
     | 
    
         | 
| 
       133 
150 
     | 
    
         
             
                        def get_model() -> PyTree:
         
     | 
    
        xax/task/mixins/compile.py
    CHANGED
    
    | 
         @@ -32,6 +32,10 @@ def get_cache_dir() -> str | None: 
     | 
|
| 
       32 
32 
     | 
    
         
             
            @dataclass
         
     | 
| 
       33 
33 
     | 
    
         
             
            class CompileOptions:
         
     | 
| 
       34 
34 
     | 
    
         
             
                # JAX compilation options
         
     | 
| 
      
 35 
     | 
    
         
            +
                debug_nans: bool = field(
         
     | 
| 
      
 36 
     | 
    
         
            +
                    value=False,
         
     | 
| 
      
 37 
     | 
    
         
            +
                    help="If True, breaks on NaNs",
         
     | 
| 
      
 38 
     | 
    
         
            +
                )
         
     | 
| 
       35 
39 
     | 
    
         
             
                disable_jit: bool = field(
         
     | 
| 
       36 
40 
     | 
    
         
             
                    value=False,
         
     | 
| 
       37 
41 
     | 
    
         
             
                    help="If True, disables JIT compilation",
         
     | 
| 
         @@ -89,6 +93,10 @@ class CompileMixin(BaseTask[Config], Generic[Config]): 
     | 
|
| 
       89 
93 
     | 
    
         
             
                    cc = self.config.compile
         
     | 
| 
       90 
94 
     | 
    
         | 
| 
       91 
95 
     | 
    
         
             
                    # Set basic compilation flags
         
     | 
| 
      
 96 
     | 
    
         
            +
                    if cc.debug_nans:
         
     | 
| 
      
 97 
     | 
    
         
            +
                        logger.info("Enabling NaNs debugging")
         
     | 
| 
      
 98 
     | 
    
         
            +
                        jax.config.update("jax_debug_nans", True)
         
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
       92 
100 
     | 
    
         
             
                    if cc.disable_jit:
         
     | 
| 
       93 
101 
     | 
    
         
             
                        logger.info("Disabling JIT compilation")
         
     | 
| 
       94 
102 
     | 
    
         
             
                        jax.config.update("jax_disable_jit", True)
         
     | 
    
        xax/utils/debugging.py
    CHANGED
    
    | 
         @@ -53,3 +53,7 @@ def get_named_leaves( 
     | 
|
| 
       53 
53 
     | 
    
         | 
| 
       54 
54 
     | 
    
         
             
            def breakpoint_if_nan(x: Array) -> None:
         
     | 
| 
       55 
55 
     | 
    
         
             
                jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
         
     | 
| 
      
 56 
     | 
    
         
            +
             
     | 
| 
      
 57 
     | 
    
         
            +
             
     | 
| 
      
 58 
     | 
    
         
            +
            def log_if_nan(x: Array, loc: str) -> None:
         
     | 
| 
      
 59 
     | 
    
         
            +
                jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
         
     | 
| 
         @@ -1,4 +1,4 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            xax/__init__.py,sha256= 
     | 
| 
      
 1 
     | 
    
         
            +
            xax/__init__.py,sha256=D7czvfKKQJlemPuatMPVYbAO4ST3U272QRIyTOru7JI,13989
         
     | 
| 
       2 
2 
     | 
    
         
             
            xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       3 
3 
     | 
    
         
             
            xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
         
     | 
| 
       4 
4 
     | 
    
         
             
            xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
         
     | 
| 
         @@ -16,7 +16,7 @@ xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564 
     | 
|
| 
       16 
16 
     | 
    
         
             
            xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
         
     | 
| 
       17 
17 
     | 
    
         
             
            xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
         
     | 
| 
       18 
18 
     | 
    
         
             
            xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       19 
     | 
    
         
            -
            xax/task/base.py,sha256= 
     | 
| 
      
 19 
     | 
    
         
            +
            xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
         
     | 
| 
       20 
20 
     | 
    
         
             
            xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
         
     | 
| 
       21 
21 
     | 
    
         
             
            xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
         
     | 
| 
       22 
22 
     | 
    
         
             
            xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
         
     | 
| 
         @@ -32,8 +32,8 @@ xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,67 
     | 
|
| 
       32 
32 
     | 
    
         
             
            xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
         
     | 
| 
       33 
33 
     | 
    
         
             
            xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
         
     | 
| 
       34 
34 
     | 
    
         
             
            xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
         
     | 
| 
       35 
     | 
    
         
            -
            xax/task/mixins/checkpointing.py,sha256= 
     | 
| 
       36 
     | 
    
         
            -
            xax/task/mixins/compile.py,sha256= 
     | 
| 
      
 35 
     | 
    
         
            +
            xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
         
     | 
| 
      
 36 
     | 
    
         
            +
            xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
         
     | 
| 
       37 
37 
     | 
    
         
             
            xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
         
     | 
| 
       38 
38 
     | 
    
         
             
            xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
         
     | 
| 
       39 
39 
     | 
    
         
             
            xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
         
     | 
| 
         @@ -43,7 +43,7 @@ xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1 
     | 
|
| 
       43 
43 
     | 
    
         
             
            xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
         
     | 
| 
       44 
44 
     | 
    
         
             
            xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
         
     | 
| 
       45 
45 
     | 
    
         
             
            xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       46 
     | 
    
         
            -
            xax/utils/debugging.py,sha256= 
     | 
| 
      
 46 
     | 
    
         
            +
            xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
         
     | 
| 
       47 
47 
     | 
    
         
             
            xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
         
     | 
| 
       48 
48 
     | 
    
         
             
            xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
         
     | 
| 
       49 
49 
     | 
    
         
             
            xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
         
     | 
| 
         @@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706 
     | 
|
| 
       58 
58 
     | 
    
         
             
            xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       59 
59 
     | 
    
         
             
            xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
         
     | 
| 
       60 
60 
     | 
    
         
             
            xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
         
     | 
| 
       61 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
       62 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
       63 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
       64 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
       65 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
      
 61 
     | 
    
         
            +
            xax-0.1.14.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
         
     | 
| 
      
 62 
     | 
    
         
            +
            xax-0.1.14.dist-info/METADATA,sha256=WbKtAXJUYKHvBrOJPEm_eXF9O9ekc0WdPmsQQCSGG5Q,1878
         
     | 
| 
      
 63 
     | 
    
         
            +
            xax-0.1.14.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
         
     | 
| 
      
 64 
     | 
    
         
            +
            xax-0.1.14.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
         
     | 
| 
      
 65 
     | 
    
         
            +
            xax-0.1.14.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |