xax 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xax might be problematic. Click here for more details.
- xax/__init__.py +6 -2
- xax/utils/jax.py +22 -0
- {xax-0.1.5.dist-info → xax-0.1.6.dist-info}/METADATA +1 -1
- {xax-0.1.5.dist-info → xax-0.1.6.dist-info}/RECORD +7 -7
- {xax-0.1.5.dist-info → xax-0.1.6.dist-info}/WHEEL +0 -0
- {xax-0.1.5.dist-info → xax-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.5.dist-info → xax-0.1.6.dist-info}/top_level.txt +0 -0
    
        xax/__init__.py
    CHANGED
    
    | @@ -12,7 +12,7 @@ and running the update script: | |
| 12 12 | 
             
                python -m scripts.update_api --inplace
         | 
| 13 13 | 
             
            """
         | 
| 14 14 |  | 
| 15 | 
            -
            __version__ = "0.1. | 
| 15 | 
            +
            __version__ = "0.1.6"
         | 
| 16 16 |  | 
| 17 17 | 
             
            # This list shouldn't be modified by hand; instead, run the update script.
         | 
| 18 18 | 
             
            __all__ = [
         | 
| @@ -97,6 +97,8 @@ __all__ = [ | |
| 97 97 | 
             
                "save_config",
         | 
| 98 98 | 
             
                "stage_environment",
         | 
| 99 99 | 
             
                "to_markdown_table",
         | 
| 100 | 
            +
                "HashableArray",
         | 
| 101 | 
            +
                "hashable_array",
         | 
| 100 102 | 
             
                "jit",
         | 
| 101 103 | 
             
                "save_jaxpr_dot",
         | 
| 102 104 | 
             
                "ColoredFormatter",
         | 
| @@ -251,6 +253,8 @@ NAME_MAP: dict[str, str] = { | |
| 251 253 | 
             
                "save_config": "utils.experiments",
         | 
| 252 254 | 
             
                "stage_environment": "utils.experiments",
         | 
| 253 255 | 
             
                "to_markdown_table": "utils.experiments",
         | 
| 256 | 
            +
                "HashableArray": "utils.jax",
         | 
| 257 | 
            +
                "hashable_array": "utils.jax",
         | 
| 254 258 | 
             
                "jit": "utils.jax",
         | 
| 255 259 | 
             
                "save_jaxpr_dot": "utils.jaxpr",
         | 
| 256 260 | 
             
                "ColoredFormatter": "utils.logging",
         | 
| @@ -404,7 +408,7 @@ if IMPORT_ALL or TYPE_CHECKING: | |
| 404 408 | 
             
                    stage_environment,
         | 
| 405 409 | 
             
                    to_markdown_table,
         | 
| 406 410 | 
             
                )
         | 
| 407 | 
            -
                from xax.utils.jax import jit
         | 
| 411 | 
            +
                from xax.utils.jax import HashableArray, hashable_array, jit
         | 
| 408 412 | 
             
                from xax.utils.jaxpr import save_jaxpr_dot
         | 
| 409 413 | 
             
                from xax.utils.logging import (
         | 
| 410 414 | 
             
                    LOG_ERROR_SUMMARY,
         | 
    
        xax/utils/jax.py
    CHANGED
    
    | @@ -138,3 +138,25 @@ def jit( | |
| 138 138 | 
             
                    return wrapped
         | 
| 139 139 |  | 
| 140 140 | 
             
                return decorator
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            class HashableArray:
         | 
| 144 | 
            +
                def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
         | 
| 145 | 
            +
                    if not isinstance(array, (np.ndarray, jnp.ndarray)):
         | 
| 146 | 
            +
                        raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
         | 
| 147 | 
            +
                    self.array = array
         | 
| 148 | 
            +
                    self._hash: int | None = None
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                def __hash__(self) -> int:
         | 
| 151 | 
            +
                    if self._hash is None:
         | 
| 152 | 
            +
                        self._hash = hash(self.array.tobytes())
         | 
| 153 | 
            +
                    return self._hash
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                def __eq__(self, other: object) -> bool:
         | 
| 156 | 
            +
                    if not isinstance(other, HashableArray):
         | 
| 157 | 
            +
                        return False
         | 
| 158 | 
            +
                    return bool(jnp.array_equal(self.array, other.array))
         | 
| 159 | 
            +
             | 
| 160 | 
            +
             | 
| 161 | 
            +
            def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
         | 
| 162 | 
            +
                return HashableArray(array)
         | 
| @@ -1,4 +1,4 @@ | |
| 1 | 
            -
            xax/__init__.py,sha256= | 
| 1 | 
            +
            xax/__init__.py,sha256=rjqydWhxQVUAj3lXgFpzj4iLFOdDJGHArfAH7_QSkhk,13504
         | 
| 2 2 | 
             
            xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 3 3 | 
             
            xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
         | 
| 4 4 | 
             
            xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
         | 
| @@ -43,7 +43,7 @@ xax/task/mixins/train.py,sha256=vsH_QpyrThlh9AzWnyvDJv58Y8U_516oi8gmMq_0iMg,2233 | |
| 43 43 | 
             
            xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 44 44 | 
             
            xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
         | 
| 45 45 | 
             
            xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
         | 
| 46 | 
            -
            xax/utils/jax.py,sha256= | 
| 46 | 
            +
            xax/utils/jax.py,sha256=eObvWt2DraCs2IMDZSdQ0rRk8tA3P5XBlF_UeVq7Aro,5480
         | 
| 47 47 | 
             
            xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
         | 
| 48 48 | 
             
            xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
         | 
| 49 49 | 
             
            xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
         | 
| @@ -53,8 +53,8 @@ xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,1705 | |
| 53 53 | 
             
            xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
         | 
| 54 54 | 
             
            xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 55 55 | 
             
            xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
         | 
| 56 | 
            -
            xax-0.1. | 
| 57 | 
            -
            xax-0.1. | 
| 58 | 
            -
            xax-0.1. | 
| 59 | 
            -
            xax-0.1. | 
| 60 | 
            -
            xax-0.1. | 
| 56 | 
            +
            xax-0.1.6.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
         | 
| 57 | 
            +
            xax-0.1.6.dist-info/METADATA,sha256=vKxhuOt02ALjFV9fAt-rPVTwvqX4uNr_shL1DGEotA4,1877
         | 
| 58 | 
            +
            xax-0.1.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
         | 
| 59 | 
            +
            xax-0.1.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
         | 
| 60 | 
            +
            xax-0.1.6.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         |