xax 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -1
 - xax/nn/geom.py +26 -5
 - xax/task/mixins/train.py +3 -5
 - xax/utils/experiments.py +14 -0
 - {xax-0.1.14.dist-info → xax-0.1.15.dist-info}/METADATA +1 -1
 - {xax-0.1.14.dist-info → xax-0.1.15.dist-info}/RECORD +9 -9
 - {xax-0.1.14.dist-info → xax-0.1.15.dist-info}/WHEEL +0 -0
 - {xax-0.1.14.dist-info → xax-0.1.15.dist-info}/licenses/LICENSE +0 -0
 - {xax-0.1.14.dist-info → xax-0.1.15.dist-info}/top_level.txt +0 -0
 
    
        xax/__init__.py
    CHANGED
    
    | 
         @@ -12,7 +12,7 @@ and running the update script: 
     | 
|
| 
       12 
12 
     | 
    
         
             
                python -m scripts.update_api --inplace
         
     | 
| 
       13 
13 
     | 
    
         
             
            """
         
     | 
| 
       14 
14 
     | 
    
         | 
| 
       15 
     | 
    
         
            -
            __version__ = "0.1. 
     | 
| 
      
 15 
     | 
    
         
            +
            __version__ = "0.1.15"
         
     | 
| 
       16 
16 
     | 
    
         | 
| 
       17 
17 
     | 
    
         
             
            # This list shouldn't be modified by hand; instead, run the update script.
         
     | 
| 
       18 
18 
     | 
    
         
             
            __all__ = [
         
     | 
| 
         @@ -40,6 +40,7 @@ __all__ = [ 
     | 
|
| 
       40 
40 
     | 
    
         
             
                "load_eqx_mlp",
         
     | 
| 
       41 
41 
     | 
    
         
             
                "make_eqx_mlp",
         
     | 
| 
       42 
42 
     | 
    
         
             
                "save_eqx",
         
     | 
| 
      
 43 
     | 
    
         
            +
                "cubic_bezier_interpolation",
         
     | 
| 
       43 
44 
     | 
    
         
             
                "euler_to_quat",
         
     | 
| 
       44 
45 
     | 
    
         
             
                "get_projected_gravity_vector_from_quat",
         
     | 
| 
       45 
46 
     | 
    
         
             
                "quat_to_euler",
         
     | 
| 
         @@ -201,6 +202,7 @@ NAME_MAP: dict[str, str] = { 
     | 
|
| 
       201 
202 
     | 
    
         
             
                "load_eqx_mlp": "nn.equinox",
         
     | 
| 
       202 
203 
     | 
    
         
             
                "make_eqx_mlp": "nn.equinox",
         
     | 
| 
       203 
204 
     | 
    
         
             
                "save_eqx": "nn.equinox",
         
     | 
| 
      
 205 
     | 
    
         
            +
                "cubic_bezier_interpolation": "nn.geom",
         
     | 
| 
       204 
206 
     | 
    
         
             
                "euler_to_quat": "nn.geom",
         
     | 
| 
       205 
207 
     | 
    
         
             
                "get_projected_gravity_vector_from_quat": "nn.geom",
         
     | 
| 
       206 
208 
     | 
    
         
             
                "quat_to_euler": "nn.geom",
         
     | 
| 
         @@ -363,6 +365,7 @@ if IMPORT_ALL or TYPE_CHECKING: 
     | 
|
| 
       363 
365 
     | 
    
         
             
                    save_eqx,
         
     | 
| 
       364 
366 
     | 
    
         
             
                )
         
     | 
| 
       365 
367 
     | 
    
         
             
                from xax.nn.geom import (
         
     | 
| 
      
 368 
     | 
    
         
            +
                    cubic_bezier_interpolation,
         
     | 
| 
       366 
369 
     | 
    
         
             
                    euler_to_quat,
         
     | 
| 
       367 
370 
     | 
    
         
             
                    get_projected_gravity_vector_from_quat,
         
     | 
| 
       368 
371 
     | 
    
         
             
                    quat_to_euler,
         
     | 
    
        xax/nn/geom.py
    CHANGED
    
    | 
         @@ -1,10 +1,10 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            """Defines geometry functions."""
         
     | 
| 
       2 
2 
     | 
    
         | 
| 
       3 
     | 
    
         
            -
            import jax
         
     | 
| 
       4 
3 
     | 
    
         
             
            from jax import numpy as jnp
         
     | 
| 
      
 4 
     | 
    
         
            +
            from jaxtyping import Array
         
     | 
| 
       5 
5 
     | 
    
         | 
| 
       6 
6 
     | 
    
         | 
| 
       7 
     | 
    
         
            -
            def quat_to_euler(quat_4:  
     | 
| 
      
 7 
     | 
    
         
            +
            def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
         
     | 
| 
       8 
8 
     | 
    
         
             
                """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
10 
     | 
    
         
             
                Args:
         
     | 
| 
         @@ -40,7 +40,7 @@ def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array: 
     | 
|
| 
       40 
40 
     | 
    
         
             
                return jnp.concatenate([roll, pitch, yaw], axis=-1)
         
     | 
| 
       41 
41 
     | 
    
         | 
| 
       42 
42 
     | 
    
         | 
| 
       43 
     | 
    
         
            -
            def euler_to_quat(euler_3:  
     | 
| 
      
 43 
     | 
    
         
            +
            def euler_to_quat(euler_3: Array) -> Array:
         
     | 
| 
       44 
44 
     | 
    
         
             
                """Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).
         
     | 
| 
       45 
45 
     | 
    
         | 
| 
       46 
46 
     | 
    
         
             
                Args:
         
     | 
| 
         @@ -75,7 +75,7 @@ def euler_to_quat(euler_3: jax.Array) -> jax.Array: 
     | 
|
| 
       75 
75 
     | 
    
         
             
                return quat
         
     | 
| 
       76 
76 
     | 
    
         | 
| 
       77 
77 
     | 
    
         | 
| 
       78 
     | 
    
         
            -
            def get_projected_gravity_vector_from_quat(quat:  
     | 
| 
      
 78 
     | 
    
         
            +
            def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Array:
         
     | 
| 
       79 
79 
     | 
    
         
             
                """Calculates the gravity vector projected onto the local frame given a quaternion orientation.
         
     | 
| 
       80 
80 
     | 
    
         | 
| 
       81 
81 
     | 
    
         
             
                Args:
         
     | 
| 
         @@ -101,7 +101,7 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) - 
     | 
|
| 
       101 
101 
     | 
    
         
             
                return jnp.concatenate([gx, gy, -gz], axis=-1)
         
     | 
| 
       102 
102 
     | 
    
         | 
| 
       103 
103 
     | 
    
         | 
| 
       104 
     | 
    
         
            -
            def rotate_vector_by_quat(vector:  
     | 
| 
      
 104 
     | 
    
         
            +
            def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
         
     | 
| 
       105 
105 
     | 
    
         
             
                """Rotates a vector by a quaternion.
         
     | 
| 
       106 
106 
     | 
    
         | 
| 
       107 
107 
     | 
    
         
             
                Args:
         
     | 
| 
         @@ -156,3 +156,24 @@ def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) 
     | 
|
| 
       156 
156 
     | 
    
         
             
                )
         
     | 
| 
       157 
157 
     | 
    
         | 
| 
       158 
158 
     | 
    
         
             
                return jnp.concatenate([xx, yy, zz], axis=-1)
         
     | 
| 
      
 159 
     | 
    
         
            +
             
     | 
| 
      
 160 
     | 
    
         
            +
             
     | 
| 
      
 161 
     | 
    
         
            +
            def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
                """Eases from ``y_start`` to ``y_end`` along a cubic Bezier curve.

                The easing weight is the 1-D cubic Bezier with control values
                (0, 0, 1, 1), which expands to

                    w(x) = x**3 + 3 * x**2 * (1 - x)    (equivalently 3x^2 - 2x^3)

                and the result is ``y_start + (y_end - y_start) * w(x)``. The weight is
                0 at x = 0 and 1 at x = 1, with zero slope at both ends, so the output
                starts and stops smoothly. ``x`` is not clamped; values outside
                [0, 1] extrapolate along the same polynomial.

                Args:
                    y_start: Value returned at x = 0, shape (*).
                    y_end: Value returned at x = 1, shape (*).
                    x: Interpolation parameter, shape (*).

                Returns:
                    The eased value, shape (*) (standard broadcasting of the inputs).
                """
                weight = x**3 + 3 * (x**2 * (1 - x))
                span = y_end - y_start
                return y_start + span * weight
         
     | 
    
        xax/task/mixins/train.py
    CHANGED
    
    | 
         @@ -50,8 +50,7 @@ from xax.utils.experiments import ( 
     | 
|
| 
       50 
50 
     | 
    
         
             
                TrainingFinishedError,
         
     | 
| 
       51 
51 
     | 
    
         
             
                diff_configs,
         
     | 
| 
       52 
52 
     | 
    
         
             
                get_diff_string,
         
     | 
| 
       53 
     | 
    
         
            -
                 
     | 
| 
       54 
     | 
    
         
            -
                get_packages_with_versions,
         
     | 
| 
      
 53 
     | 
    
         
            +
                get_state_file_string,
         
     | 
| 
       55 
54 
     | 
    
         
             
                get_training_code,
         
     | 
| 
       56 
55 
     | 
    
         
             
            )
         
     | 
| 
       57 
56 
     | 
    
         
             
            from xax.utils.jax import jit as xax_jit
         
     | 
| 
         @@ -534,9 +533,8 @@ class TrainMixin( 
     | 
|
| 
       534 
533 
     | 
    
         
             
                    logger.log(LOG_STATUS, self.task_path)
         
     | 
| 
       535 
534 
     | 
    
         
             
                    logger.log(LOG_STATUS, self.task_name)
         
     | 
| 
       536 
535 
     | 
    
         
             
                    logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
         
     | 
| 
       537 
     | 
    
         
            -
                    self.logger.log_file(" 
     | 
| 
       538 
     | 
    
         
            -
                    self.logger.log_file(" 
     | 
| 
       539 
     | 
    
         
            -
                    self.logger.log_file("training_code.txt", get_training_code(self))
         
     | 
| 
      
 536 
     | 
    
         
            +
                    self.logger.log_file("state.txt", get_state_file_string(self))
         
     | 
| 
      
 537 
     | 
    
         
            +
                    self.logger.log_file("training_code.py", get_training_code(self))
         
     | 
| 
       540 
538 
     | 
    
         
             
                    self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
         
     | 
| 
       541 
539 
     | 
    
         | 
| 
       542 
540 
     | 
    
         
             
                def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         
     | 
    
        xax/utils/experiments.py
    CHANGED
    
    | 
         @@ -479,6 +479,20 @@ def get_packages_with_versions() -> str: 
     | 
|
| 
       479 
479 
     | 
    
         
             
                return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
         
     | 
| 
       480 
480 
     | 
    
         | 
| 
       481 
481 
     | 
    
         | 
| 
      
 482 
     | 
    
         
            +
            def get_command_line_string() -> str:
                """Returns the command line that launched the current process.

                The entries of ``sys.argv`` are joined with ``shlex.join`` so that an
                argument containing whitespace or shell metacharacters is quoted rather
                than silently merged with its neighbours. The recorded string is
                therefore unambiguous and, on POSIX shells, can be pasted back to re-run
                the same command. Arguments made only of shell-safe characters are
                emitted unchanged, so for ordinary invocations the output is the same
                as a plain space-join.

                Returns:
                    The shell-quoted command line.
                """
                # Local import keeps this helper self-contained; shlex is stdlib (3.8+).
                import shlex

                return shlex.join(sys.argv)
         
     | 
| 
      
 484 
     | 
    
         
            +
             
     | 
| 
      
 485 
     | 
    
         
            +
             
     | 
| 
      
 486 
     | 
    
         
            +
            def get_state_file_string(obj: object) -> str:
                """Builds the text of the run's state snapshot.

                The snapshot has three titled sections, in this order: the command
                line, the git state and the installed package versions. Sections are
                separated by blank lines.

                Args:
                    obj: Forwarded to ``get_git_state``. Presumably it is used to locate
                        the repository containing the task's code; confirm against that
                        helper.

                Returns:
                    The concatenated sections as a single string.
                """
                # The helpers are called in section order, the same order as the output.
                sections = (
                    f"=== Command Line ===\n\n{get_command_line_string()}",
                    f"=== Git State ===\n\n{get_git_state(obj)}",
                    f"=== Packages ===\n\n{get_packages_with_versions()}",
                )
                return "\n\n".join(sections)
         
     | 
| 
      
 494 
     | 
    
         
            +
             
     | 
| 
      
 495 
     | 
    
         
            +
             
     | 
| 
       482 
496 
     | 
    
         
             
            def get_training_code(obj: object) -> str:
         
     | 
| 
       483 
497 
     | 
    
         
             
                """Gets the text from the file containing the provided object.
         
     | 
| 
       484 
498 
     | 
    
         | 
| 
         @@ -1,4 +1,4 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            xax/__init__.py,sha256= 
     | 
| 
      
 1 
     | 
    
         
            +
            xax/__init__.py,sha256=bV2mTcuiVaVNvwgbDgg7dKDkMeuyA0mqF0muU5KZHeg,14104
         
     | 
| 
       2 
2 
     | 
    
         
             
            xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       3 
3 
     | 
    
         
             
            xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
         
     | 
| 
       4 
4 
     | 
    
         
             
            xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
         
     | 
| 
         @@ -10,7 +10,7 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847 
     | 
|
| 
       10 
10 
     | 
    
         
             
            xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
         
     | 
| 
       11 
11 
     | 
    
         
             
            xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
         
     | 
| 
       12 
12 
     | 
    
         
             
            xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
         
     | 
| 
       13 
     | 
    
         
            -
            xax/nn/geom.py,sha256= 
     | 
| 
      
 13 
     | 
    
         
            +
            xax/nn/geom.py,sha256=PN0Ndn575aVtsSfxi67RghHB7luRkqtpS7bPbT1LpLE,5201
         
     | 
| 
       14 
14 
     | 
    
         
             
            xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
         
     | 
| 
       15 
15 
     | 
    
         
             
            xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
         
     | 
| 
       16 
16 
     | 
    
         
             
            xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
         
     | 
| 
         @@ -41,10 +41,10 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280 
     | 
|
| 
       41 
41 
     | 
    
         
             
            xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
         
     | 
| 
       42 
42 
     | 
    
         
             
            xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
         
     | 
| 
       43 
43 
     | 
    
         
             
            xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
         
     | 
| 
       44 
     | 
    
         
            -
            xax/task/mixins/train.py,sha256= 
     | 
| 
      
 44 
     | 
    
         
            +
            xax/task/mixins/train.py,sha256=1hmUx1HIL8HKfwOnupS3Knsw1CiK2YCbIQnUTYyDEms,26157
         
     | 
| 
       45 
45 
     | 
    
         
             
            xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       46 
46 
     | 
    
         
             
            xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
         
     | 
| 
       47 
     | 
    
         
            -
            xax/utils/experiments.py,sha256= 
     | 
| 
      
 47 
     | 
    
         
            +
            xax/utils/experiments.py,sha256=X6MESZ3z_Z0DLH6NQucuPzibuOc6rZmlf5UZt4in458,29591
         
     | 
| 
       48 
48 
     | 
    
         
             
            xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
         
     | 
| 
       49 
49 
     | 
    
         
             
            xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
         
     | 
| 
       50 
50 
     | 
    
         
             
            xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
         
     | 
| 
         @@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706 
     | 
|
| 
       58 
58 
     | 
    
         
             
            xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         
     | 
| 
       59 
59 
     | 
    
         
             
            xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
         
     | 
| 
       60 
60 
     | 
    
         
             
            xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
         
     | 
| 
       61 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
       62 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
       63 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
       64 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
       65 
     | 
    
         
            -
            xax-0.1. 
     | 
| 
      
 61 
     | 
    
         
            +
            xax-0.1.15.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
         
     | 
| 
      
 62 
     | 
    
         
            +
            xax-0.1.15.dist-info/METADATA,sha256=i5thFSTL1Zx03UpnCj7f71rxSgs0P3L6ZDd6vYEtM7U,1878
         
     | 
| 
      
 63 
     | 
    
         
            +
            xax-0.1.15.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
         
     | 
| 
      
 64 
     | 
    
         
            +
            xax-0.1.15.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
         
     | 
| 
      
 65 
     | 
    
         
            +
            xax-0.1.15.dist-info/RECORD,,
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |