xax 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +9 -10
- xax/nn/geom.py +57 -0
- xax/nn/ssm.py +194 -174
- xax/task/mixins/train.py +15 -7
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/METADATA +1 -1
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/RECORD +9 -9
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/WHEEL +0 -0
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/top_level.txt +0 -0
    
        xax/__init__.py
    CHANGED
    
    | @@ -12,7 +12,7 @@ and running the update script: | |
| 12 12 | 
             
                python -m scripts.update_api --inplace
         | 
| 13 13 | 
             
            """
         | 
| 14 14 |  | 
| 15 | 
            -
            __version__ = "0.1. | 
| 15 | 
            +
            __version__ = "0.1.12"
         | 
| 16 16 |  | 
| 17 17 | 
             
            # This list shouldn't be modified by hand; instead, run the update script.
         | 
| 18 18 | 
             
            __all__ = [
         | 
| @@ -43,15 +43,14 @@ __all__ = [ | |
| 43 43 | 
             
                "euler_to_quat",
         | 
| 44 44 | 
             
                "get_projected_gravity_vector_from_quat",
         | 
| 45 45 | 
             
                "quat_to_euler",
         | 
| 46 | 
            +
                "rotate_vector_by_quat",
         | 
| 46 47 | 
             
                "cross_entropy",
         | 
| 47 48 | 
             
                "cast_norm_type",
         | 
| 48 49 | 
             
                "get_norm",
         | 
| 49 50 | 
             
                "is_master",
         | 
| 51 | 
            +
                "BaseSSMBlock",
         | 
| 50 52 | 
             
                "DiagSSMBlock",
         | 
| 51 | 
            -
                " | 
| 52 | 
            -
                "S4",
         | 
| 53 | 
            -
                "S4Layer",
         | 
| 54 | 
            -
                "S6Layer",
         | 
| 53 | 
            +
                "SSM",
         | 
| 55 54 | 
             
                "SSMBlock",
         | 
| 56 55 | 
             
                "BaseLauncher",
         | 
| 57 56 | 
             
                "CliLauncher",
         | 
| @@ -203,15 +202,14 @@ NAME_MAP: dict[str, str] = { | |
| 203 202 | 
             
                "euler_to_quat": "nn.geom",
         | 
| 204 203 | 
             
                "get_projected_gravity_vector_from_quat": "nn.geom",
         | 
| 205 204 | 
             
                "quat_to_euler": "nn.geom",
         | 
| 205 | 
            +
                "rotate_vector_by_quat": "nn.geom",
         | 
| 206 206 | 
             
                "cross_entropy": "nn.losses",
         | 
| 207 207 | 
             
                "cast_norm_type": "nn.norm",
         | 
| 208 208 | 
             
                "get_norm": "nn.norm",
         | 
| 209 209 | 
             
                "is_master": "nn.parallel",
         | 
| 210 | 
            +
                "BaseSSMBlock": "nn.ssm",
         | 
| 210 211 | 
             
                "DiagSSMBlock": "nn.ssm",
         | 
| 211 | 
            -
                " | 
| 212 | 
            -
                "S4": "nn.ssm",
         | 
| 213 | 
            -
                "S4Layer": "nn.ssm",
         | 
| 214 | 
            -
                "S6Layer": "nn.ssm",
         | 
| 212 | 
            +
                "SSM": "nn.ssm",
         | 
| 215 213 | 
             
                "SSMBlock": "nn.ssm",
         | 
| 216 214 | 
             
                "BaseLauncher": "task.launchers.base",
         | 
| 217 215 | 
             
                "CliLauncher": "task.launchers.cli",
         | 
| @@ -364,11 +362,12 @@ if IMPORT_ALL or TYPE_CHECKING: | |
| 364 362 | 
             
                    euler_to_quat,
         | 
| 365 363 | 
             
                    get_projected_gravity_vector_from_quat,
         | 
| 366 364 | 
             
                    quat_to_euler,
         | 
| 365 | 
            +
                    rotate_vector_by_quat,
         | 
| 367 366 | 
             
                )
         | 
| 368 367 | 
             
                from xax.nn.losses import cross_entropy
         | 
| 369 368 | 
             
                from xax.nn.norm import NormType, cast_norm_type, get_norm
         | 
| 370 369 | 
             
                from xax.nn.parallel import is_master
         | 
| 371 | 
            -
                from xax.nn.ssm import  | 
| 370 | 
            +
                from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
         | 
| 372 371 | 
             
                from xax.task.base import RawConfigType
         | 
| 373 372 | 
             
                from xax.task.launchers.base import BaseLauncher
         | 
| 374 373 | 
             
                from xax.task.launchers.cli import CliLauncher
         | 
    
        xax/nn/geom.py
    CHANGED
    
    | @@ -99,3 +99,60 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) - | |
| 99 99 |  | 
| 100 100 | 
             
                # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
         | 
| 101 101 | 
             
                return jnp.concatenate([gx, gy, -gz], axis=-1)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
         | 
| 105 | 
            +
                """Rotates a vector by a quaternion.
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                Args:
         | 
| 108 | 
            +
                    vector: The vector to rotate, shape (*, 3).
         | 
| 109 | 
            +
                    quat: The quaternion to rotate by, shape (*, 4).
         | 
| 110 | 
            +
                    eps: A small epsilon value to avoid division by zero.
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                Returns:
         | 
| 113 | 
            +
                    The rotated vector, shape (*, 3).
         | 
| 114 | 
            +
                """
         | 
| 115 | 
            +
                # Normalize quaternion
         | 
| 116 | 
            +
                quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
         | 
| 117 | 
            +
                w, x, y, z = jnp.split(quat, 4, axis=-1)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                # Extract vector components
         | 
| 120 | 
            +
                vx, vy, vz = jnp.split(vector, 3, axis=-1)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                # Terms for x component
         | 
| 123 | 
            +
                xx = (
         | 
| 124 | 
            +
                    w * w * vx
         | 
| 125 | 
            +
                    + 2 * y * w * vz
         | 
| 126 | 
            +
                    - 2 * z * w * vy
         | 
| 127 | 
            +
                    + x * x * vx
         | 
| 128 | 
            +
                    + 2 * y * x * vy
         | 
| 129 | 
            +
                    + 2 * z * x * vz
         | 
| 130 | 
            +
                    - z * z * vx
         | 
| 131 | 
            +
                    - y * y * vx
         | 
| 132 | 
            +
                )
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                # Terms for y component
         | 
| 135 | 
            +
                yy = (
         | 
| 136 | 
            +
                    2 * x * y * vx
         | 
| 137 | 
            +
                    + y * y * vy
         | 
| 138 | 
            +
                    + 2 * z * y * vz
         | 
| 139 | 
            +
                    + 2 * w * z * vx
         | 
| 140 | 
            +
                    - z * z * vy
         | 
| 141 | 
            +
                    + w * w * vy
         | 
| 142 | 
            +
                    - 2 * w * x * vz
         | 
| 143 | 
            +
                    - x * x * vy
         | 
| 144 | 
            +
                )
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                # Terms for z component
         | 
| 147 | 
            +
                zz = (
         | 
| 148 | 
            +
                    2 * x * z * vx
         | 
| 149 | 
            +
                    + 2 * y * z * vy
         | 
| 150 | 
            +
                    + z * z * vz
         | 
| 151 | 
            +
                    - 2 * w * y * vx
         | 
| 152 | 
            +
                    + w * w * vz
         | 
| 153 | 
            +
                    + 2 * w * x * vy
         | 
| 154 | 
            +
                    - y * y * vz
         | 
| 155 | 
            +
                    - x * x * vz
         | 
| 156 | 
            +
                )
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                return jnp.concatenate([xx, yy, zz], axis=-1)
         | 
    
        xax/nn/ssm.py
    CHANGED
    
    | @@ -13,140 +13,18 @@ def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array: | |
| 13 13 | 
             
                return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
         | 
| 14 14 |  | 
| 15 15 |  | 
| 16 | 
            -
            class  | 
| 17 | 
            -
                 | 
| 18 | 
            -
                 | 
| 19 | 
            -
                C: Array
         | 
| 20 | 
            -
                proj_in: eqx.nn.Linear
         | 
| 21 | 
            -
                proj_out: eqx.nn.Linear
         | 
| 22 | 
            -
             | 
| 23 | 
            -
                def __init__(
         | 
| 24 | 
            -
                    self,
         | 
| 25 | 
            -
                    hidden_size: int,
         | 
| 26 | 
            -
                    projection_size: int,
         | 
| 27 | 
            -
                    input_size: int,
         | 
| 28 | 
            -
                    output_size: int,
         | 
| 29 | 
            -
                    *,
         | 
| 30 | 
            -
                    key: PRNGKeyArray,
         | 
| 31 | 
            -
                ) -> None:
         | 
| 32 | 
            -
                    self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
         | 
| 33 | 
            -
                    self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
         | 
| 34 | 
            -
                    self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
         | 
| 35 | 
            -
                    self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
         | 
| 36 | 
            -
                    self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
         | 
| 37 | 
            -
             | 
| 38 | 
            -
                def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
         | 
| 39 | 
            -
                    h = self.a * h + self.B.T @ x
         | 
| 40 | 
            -
                    y = self.C.T @ h
         | 
| 41 | 
            -
                    return h, y
         | 
| 42 | 
            -
             | 
| 43 | 
            -
                def predict_sequence(self, x_seq: Array) -> Array:
         | 
| 44 | 
            -
                    x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
         | 
| 45 | 
            -
                    h = jnp.zeros(self.a.shape[0])
         | 
| 46 | 
            -
             | 
| 47 | 
            -
                    def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
         | 
| 48 | 
            -
                        h = self.a * h + self.B.T @ x
         | 
| 49 | 
            -
                        y = self.C.T @ h
         | 
| 50 | 
            -
                        return h, y
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                    _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
         | 
| 53 | 
            -
                    y_out = jax.vmap(self.proj_out)(y_seq)
         | 
| 54 | 
            -
                    return y_out
         | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
            class S4Layer(eqx.Module):
         | 
| 58 | 
            -
                a: Array
         | 
| 59 | 
            -
                B: Array
         | 
| 60 | 
            -
                C: Array
         | 
| 61 | 
            -
                proj_in: eqx.nn.Linear
         | 
| 62 | 
            -
                proj_out: eqx.nn.Linear
         | 
| 63 | 
            -
                delta: Array
         | 
| 64 | 
            -
             | 
| 65 | 
            -
                def __init__(
         | 
| 66 | 
            -
                    self,
         | 
| 67 | 
            -
                    hidden_size: int,
         | 
| 68 | 
            -
                    projection_size: int,
         | 
| 69 | 
            -
                    input_size: int,
         | 
| 70 | 
            -
                    output_size: int,
         | 
| 71 | 
            -
                    *,
         | 
| 72 | 
            -
                    key: PRNGKeyArray,
         | 
| 73 | 
            -
                ) -> None:
         | 
| 74 | 
            -
                    self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
         | 
| 75 | 
            -
                    self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
         | 
| 76 | 
            -
                    self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
         | 
| 77 | 
            -
                    self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
         | 
| 78 | 
            -
                    self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
         | 
| 79 | 
            -
                    self.delta = jax.random.uniform(key, (hidden_size,))
         | 
| 80 | 
            -
             | 
| 81 | 
            -
                def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
         | 
| 82 | 
            -
                    delta_a = self.delta * self.a
         | 
| 83 | 
            -
                    a_bar = jnp.exp(delta_a)
         | 
| 84 | 
            -
                    b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
         | 
| 85 | 
            -
                    h = a_bar * h + b_bar.T @ x
         | 
| 86 | 
            -
                    y = self.C.T @ h
         | 
| 87 | 
            -
                    return h, y
         | 
| 88 | 
            -
             | 
| 89 | 
            -
                def predict_sequence(self, x_seq: Array) -> Array:
         | 
| 90 | 
            -
                    x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
         | 
| 91 | 
            -
                    h = jnp.zeros(self.a.shape[0])
         | 
| 92 | 
            -
             | 
| 93 | 
            -
                    def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
         | 
| 94 | 
            -
                        h = self.a * h + self.B.T @ x
         | 
| 95 | 
            -
                        y = self.C.T @ h
         | 
| 96 | 
            -
                        return h, y
         | 
| 97 | 
            -
             | 
| 98 | 
            -
                    _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
         | 
| 99 | 
            -
                    y_out = jax.vmap(self.proj_out)(y_seq)
         | 
| 100 | 
            -
                    return y_out
         | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
| 103 | 
            -
            class S6Layer(eqx.Module):
         | 
| 104 | 
            -
                a: Array
         | 
| 105 | 
            -
                B: Array
         | 
| 106 | 
            -
                C: Array
         | 
| 107 | 
            -
                proj_in: eqx.nn.Linear
         | 
| 108 | 
            -
                proj_out: eqx.nn.Linear
         | 
| 109 | 
            -
                delta: Array
         | 
| 110 | 
            -
             | 
| 111 | 
            -
                def __init__(
         | 
| 112 | 
            -
                    self,
         | 
| 113 | 
            -
                    hidden_size: int,
         | 
| 114 | 
            -
                    projection_size: int,
         | 
| 115 | 
            -
                    input_size: int,
         | 
| 116 | 
            -
                    output_size: int,
         | 
| 117 | 
            -
                    *,
         | 
| 118 | 
            -
                    key: PRNGKeyArray,
         | 
| 119 | 
            -
                ) -> None:
         | 
| 120 | 
            -
                    self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
         | 
| 121 | 
            -
                    self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
         | 
| 122 | 
            -
                    self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
         | 
| 123 | 
            -
                    self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
         | 
| 124 | 
            -
                    self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
         | 
| 125 | 
            -
                    self.delta = jax.random.uniform(key, (hidden_size,))
         | 
| 126 | 
            -
             | 
| 127 | 
            -
                def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
         | 
| 128 | 
            -
                    h = self.a * h + self.B.T @ x
         | 
| 129 | 
            -
                    y = self.C.T @ h
         | 
| 130 | 
            -
                    return h, y
         | 
| 131 | 
            -
             | 
| 132 | 
            -
                def predict_sequence(self, x_seq: Array) -> Array:
         | 
| 133 | 
            -
                    x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
         | 
| 134 | 
            -
                    h = jnp.zeros(self.a.shape[0])
         | 
| 135 | 
            -
             | 
| 136 | 
            -
                    def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
         | 
| 137 | 
            -
                        h = self.a * h + self.B.T @ x
         | 
| 138 | 
            -
                        y = self.C.T @ h
         | 
| 139 | 
            -
                        return h, y
         | 
| 16 | 
            +
            class BaseSSMBlock(eqx.Module, ABC):
         | 
| 17 | 
            +
                @abstractmethod
         | 
| 18 | 
            +
                def forward(self, h: Array, x: Array) -> Array: ...
         | 
| 140 19 |  | 
| 141 | 
            -
             | 
| 142 | 
            -
             | 
| 143 | 
            -
                    return y_out
         | 
| 20 | 
            +
                @abstractmethod
         | 
| 21 | 
            +
                def forward_sequence(self, x_seq: Array) -> Array: ...
         | 
| 144 22 |  | 
| 23 | 
            +
                @abstractmethod
         | 
| 24 | 
            +
                def get_a_mat(self, x: Array) -> Array: ...
         | 
| 145 25 |  | 
| 146 | 
            -
            class BaseSSMBlock(eqx.Module, ABC):
         | 
| 147 26 | 
             
                @abstractmethod
         | 
| 148 | 
            -
                def  | 
| 149 | 
            -
                    pass
         | 
| 27 | 
            +
                def get_b_mat(self, x: Array) -> Array: ...
         | 
| 150 28 |  | 
| 151 29 |  | 
| 152 30 | 
             
            class SSMBlock(BaseSSMBlock):
         | 
| @@ -158,80 +36,194 @@ class SSMBlock(BaseSSMBlock): | |
| 158 36 | 
             
                    self.a_mat = glorot(key_a, (hidden_size, hidden_size))
         | 
| 159 37 | 
             
                    self.b_mat = glorot(key_b, (hidden_size, hidden_size))
         | 
| 160 38 |  | 
| 39 | 
            +
                def get_a_mat(self, x: Array) -> Array:
         | 
| 40 | 
            +
                    return self.a_mat
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def get_b_mat(self, x: Array) -> Array:
         | 
| 43 | 
            +
                    return self.b_mat
         | 
| 44 | 
            +
             | 
| 161 45 | 
             
                def forward(self, h: Array, x: Array) -> Array:
         | 
| 162 | 
            -
                     | 
| 46 | 
            +
                    """Perform a forward pass.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    Args:
         | 
| 49 | 
            +
                        h: Hidden state of shape (H,).
         | 
| 50 | 
            +
                        x: Input of shape (H,).
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    Returns:
         | 
| 53 | 
            +
                        Hidden state of shape (H,).
         | 
| 54 | 
            +
                    """
         | 
| 55 | 
            +
                    a_mat = self.get_a_mat(x)
         | 
| 56 | 
            +
                    b_mat = self.get_b_mat(x)
         | 
| 57 | 
            +
                    h = a_mat @ h + b_mat.T @ x
         | 
| 163 58 | 
             
                    return h
         | 
| 164 59 |  | 
| 165 | 
            -
                def  | 
| 166 | 
            -
                     | 
| 60 | 
            +
                def forward_sequence(self, x_seq: Array) -> Array:
         | 
| 61 | 
            +
                    """Perform a forward pass across time.
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    Args:
         | 
| 64 | 
            +
                        x_seq: Input sequence of shape (T, H).
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    Returns:
         | 
| 67 | 
            +
                        Hidden state sequence of shape (T, H).
         | 
| 68 | 
            +
                    """
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    def step(h: Array, x: Array) -> tuple[Array, Array]:
         | 
| 71 | 
            +
                        h = self.forward(h, x)
         | 
| 72 | 
            +
                        return h, h
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    a_mat = self.get_a_mat(x_seq)
         | 
| 75 | 
            +
                    h_0 = jnp.zeros(a_mat.shape[0])
         | 
| 76 | 
            +
                    _, h_seq = jax.lax.scan(step, h_0, x_seq)
         | 
| 77 | 
            +
                    return h_seq
         | 
| 167 78 |  | 
| 168 79 |  | 
| 169 80 | 
             
            class DiagSSMBlock(BaseSSMBlock):
         | 
| 170 | 
            -
                 | 
| 81 | 
            +
                a_diag: Array
         | 
| 171 82 | 
             
                b_mat: Array
         | 
| 172 83 |  | 
| 173 84 | 
             
                def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
         | 
| 174 85 | 
             
                    keys = jax.random.split(key, 2)
         | 
| 175 | 
            -
                    self. | 
| 86 | 
            +
                    self.a_diag = glorot(keys[0], (hidden_size,))
         | 
| 176 87 | 
             
                    self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
         | 
| 177 88 |  | 
| 89 | 
            +
                def get_a_mat(self, x: Array) -> Array:
         | 
| 90 | 
            +
                    return self.a_diag
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                def get_b_mat(self, x: Array) -> Array:
         | 
| 93 | 
            +
                    return self.b_mat
         | 
| 94 | 
            +
             | 
| 178 95 | 
             
                def forward(self, h: Array, x: Array) -> Array:
         | 
| 179 | 
            -
                     | 
| 180 | 
            -
             | 
| 96 | 
            +
                    """Perform a forward pass.
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    Args:
         | 
| 99 | 
            +
                        h: Hidden state of shape (H,).
         | 
| 100 | 
            +
                        x: Input of shape (H,).
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    Returns:
         | 
| 103 | 
            +
                        Hidden state of shape (H,).
         | 
| 104 | 
            +
                    """
         | 
| 105 | 
            +
                    a_diag = self.get_a_mat(x)
         | 
| 106 | 
            +
                    b_mat = self.get_b_mat(x)
         | 
| 107 | 
            +
                    h = a_diag * h + b_mat.T @ x
         | 
| 181 108 | 
             
                    return h
         | 
| 182 109 |  | 
| 183 | 
            -
                def  | 
| 110 | 
            +
                def forward_sequence(self, x_seq: Array, *, use_conv: bool = True, recursive_kernel_calc: bool = False) -> Array:
         | 
| 111 | 
            +
                    """Perform a potentially parallelized forward pass across time.
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    Args:
         | 
| 114 | 
            +
                        x_seq: Input sequence of shape (T, H).
         | 
| 115 | 
            +
                        use_conv: Whether to use convolution to compute the sequence.
         | 
| 116 | 
            +
                        recursive_kernel_calc: Whether to use a recursive kernel calculation.
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    Returns:
         | 
| 119 | 
            +
                        Hidden state sequence of shape (T, H).
         | 
| 120 | 
            +
                    """
         | 
| 121 | 
            +
                    if use_conv:
         | 
| 122 | 
            +
                        return self._forward_sequence_conv(x_seq, recursive_kernel_calc=recursive_kernel_calc)
         | 
| 123 | 
            +
                    else:
         | 
| 124 | 
            +
                        return self._forward_sequence_scan(x_seq)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                def _get_kernel(self, x_seq: Array, length: int) -> Array:
         | 
| 184 127 | 
             
                    """Returns the kernel with time as the final dimension."""
         | 
| 185 128 | 
             
                    exponents = jnp.arange(length)
         | 
| 186 | 
            -
                     | 
| 187 | 
            -
                    kernel =  | 
| 129 | 
            +
                    a_diag = self.get_a_mat(x_seq)
         | 
| 130 | 
            +
                    kernel = jnp.power(a_diag[:, None], exponents)  # (H, T)
         | 
| 131 | 
            +
                    kernel = kernel[:, None, :]  # (H, 1, T)
         | 
| 188 132 | 
             
                    return kernel
         | 
| 189 133 |  | 
| 190 | 
            -
                def  | 
| 134 | 
            +
                def _get_kernel_recursive(self, x_seq: Array, length: int) -> Array:
         | 
| 135 | 
            +
                    """Returns the kernel with time as the final dimension."""
         | 
| 136 | 
            +
                    assert length % 2 == 0, "Length must be even."
         | 
| 137 | 
            +
                    a_diag = self.get_a_mat(x_seq)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    def helper(length: int) -> tuple[Array, Array]:
         | 
| 140 | 
            +
                        """Returns the kernel and the sqrt of the diagonal."""
         | 
| 141 | 
            +
                        if length == 1:
         | 
| 142 | 
            +
                            return jnp.ones_like(a_diag)[:, None], a_diag[:, None]
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                        half_length = length // 2
         | 
| 145 | 
            +
                        kernel_half, a_half = helper(half_length)
         | 
| 146 | 
            +
                        kernel = jnp.concatenate([kernel_half, a_half * kernel_half], axis=-1)
         | 
| 147 | 
            +
                        return kernel, a_half * a_half
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    kernel, a_diag = helper(length)
         | 
| 150 | 
            +
                    return kernel[:, None, :]  # (H, 1, L)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                def _forward_sequence_conv(self, x_seq: Array, *, recursive_kernel_calc: bool = False) -> Array:
         | 
| 191 153 | 
             
                    """Convolves x (T, H) across time using the kernel."""
         | 
| 192 | 
            -
                     | 
| 154 | 
            +
                    seq_len, hidden_size = x_seq.shape
         | 
| 155 | 
            +
                    b_mat = self.get_b_mat(x_seq)
         | 
| 193 156 |  | 
| 194 | 
            -
                     | 
| 195 | 
            -
                     | 
| 196 | 
            -
                    s = s.T  # (H, T)
         | 
| 157 | 
            +
                    s = b_mat.T @ x_seq.T  # (H, T)
         | 
| 158 | 
            +
                    s_padded = jnp.pad(s, ((0, 0), (seq_len - 1, 0)))[None, :, :]  # (1, H, 2T-1)
         | 
| 197 159 |  | 
| 198 | 
            -
                     | 
| 199 | 
            -
             | 
| 160 | 
            +
                    if recursive_kernel_calc:
         | 
| 161 | 
            +
                        kernel = self._get_kernel_recursive(x_seq, seq_len)
         | 
| 162 | 
            +
                    else:
         | 
| 163 | 
            +
                        kernel = self._get_kernel(x_seq, seq_len)
         | 
| 200 164 |  | 
| 201 | 
            -
                     | 
| 202 | 
            -
                    s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
         | 
| 165 | 
            +
                    kernel_flipped = jnp.flip(kernel, axis=-1)  # (H, 1, L)
         | 
| 203 166 |  | 
| 204 | 
            -
                    # Perform depthwise (grouped) 1D convolution.
         | 
| 205 | 
            -
                    # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
         | 
| 206 | 
            -
                    # The dimension_numbers are chosen so that the channel dimension is second.
         | 
| 207 167 | 
             
                    conv_out = jax.lax.conv_general_dilated(
         | 
| 208 168 | 
             
                        s_padded,
         | 
| 209 169 | 
             
                        kernel_flipped,
         | 
| 210 170 | 
             
                        window_strides=(1,),
         | 
| 211 171 | 
             
                        padding="VALID",
         | 
| 212 | 
            -
                        dimension_numbers=(" | 
| 213 | 
            -
                        feature_group_count= | 
| 172 | 
            +
                        dimension_numbers=("NCT", "OIT", "NCT"),  # convolving over time
         | 
| 173 | 
            +
                        feature_group_count=hidden_size,
         | 
| 214 174 | 
             
                    )
         | 
| 215 | 
            -
                     | 
| 216 | 
            -
                    conv_out = jnp.transpose(conv_out, (0, 2, 1))
         | 
| 175 | 
            +
                    conv_out = conv_out[0].T  # (T, H)
         | 
| 217 176 | 
             
                    return conv_out
         | 
| 218 177 |  | 
| 219 | 
            -
                def  | 
| 178 | 
            +
                def _forward_sequence_scan(self, x_seq: Array) -> Array:
         | 
| 220 179 | 
             
                    """Naively forward across time."""
         | 
| 221 180 |  | 
| 222 181 | 
             
                    def step(h: Array, x: Array) -> tuple[Array, Array]:
         | 
| 223 182 | 
             
                        h = self.forward(h, x)
         | 
| 224 183 | 
             
                        return h, h
         | 
| 225 184 |  | 
| 226 | 
            -
                     | 
| 227 | 
            -
                     | 
| 185 | 
            +
                    a_diag = self.get_a_mat(x_seq)
         | 
| 186 | 
            +
                    h_0 = jnp.zeros(a_diag.shape[0])
         | 
| 187 | 
            +
                    _, h_seq = jax.lax.scan(step, h_0, x_seq)
         | 
| 228 188 | 
             
                    return h_seq
         | 
| 229 189 |  | 
| 230 190 |  | 
| 231 | 
            -
            class  | 
| 191 | 
            +
            class DiscreteDiagSSMBlock(DiagSSMBlock):
         | 
| 192 | 
            +
                delta: Array
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                def __init__(
         | 
| 195 | 
            +
                    self,
         | 
| 196 | 
            +
                    hidden_size: int,
         | 
| 197 | 
            +
                    *,
         | 
| 198 | 
            +
                    key: PRNGKeyArray,
         | 
| 199 | 
            +
                    init_delta: float = 1.0,
         | 
| 200 | 
            +
                    init_scale: float = 10.0,
         | 
| 201 | 
            +
                ) -> None:
         | 
| 202 | 
            +
                    super().__init__(hidden_size, key=key)
         | 
| 203 | 
            +
                    self.delta = jnp.array(init_delta)
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                    # A positive scale helps reduce the gradient at the start.
         | 
| 206 | 
            +
                    self.a_diag = jax.random.uniform(key, (hidden_size,), minval=-1.0, maxval=0.0) * init_scale
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                def get_a_mat(self, x: Array) -> Array:
         | 
| 209 | 
            +
                    """Discretize the diagonal matrix using zero-order hold."""
         | 
| 210 | 
            +
                    a_diag_discrete = jnp.exp(self.a_diag * self.delta)
         | 
| 211 | 
            +
                    return a_diag_discrete
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                def get_b_mat(self, x: Array) -> Array:
         | 
| 214 | 
            +
                    """Discretize the input matrix using zero-order hold."""
         | 
| 215 | 
            +
                    delta_a_diag = self.a_diag * self.delta
         | 
| 216 | 
            +
                    exp_a_diag = jnp.exp(delta_a_diag)
         | 
| 217 | 
            +
                    delta_a_inv = 1 / delta_a_diag
         | 
| 218 | 
            +
                    delta_b_mat = self.delta * self.b_mat
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    b_discrete = delta_a_inv * (exp_a_diag - 1) * delta_b_mat
         | 
| 221 | 
            +
                    return b_discrete
         | 
| 222 | 
            +
             | 
| 223 | 
            +
             | 
| 224 | 
            +
            class SSM(eqx.Module):
         | 
| 232 225 | 
             
                vocab_embedding: eqx.nn.Embedding
         | 
| 233 | 
            -
                 | 
| 234 | 
            -
                proj_out: eqx.nn.Linear
         | 
| 226 | 
            +
                output_layer: eqx.nn.Linear
         | 
| 235 227 | 
             
                blocks: list[BaseSSMBlock]
         | 
| 236 228 | 
             
                num_layers: int = eqx.static_field()
         | 
| 237 229 | 
             
                hidden_size: int = eqx.static_field()
         | 
| @@ -243,24 +235,30 @@ class S4(eqx.Module): | |
| 243 235 | 
             
                    hidden_size: int,
         | 
| 244 236 | 
             
                    output_size: int,
         | 
| 245 237 | 
             
                    num_layers: int,
         | 
| 246 | 
            -
                    block_type: Literal[" | 
| 238 | 
            +
                    block_type: Literal["diagonal", "full_rank"] = "full_rank",
         | 
| 247 239 | 
             
                    skip_connections: bool = False,
         | 
| 240 | 
            +
                    discretize: bool = False,
         | 
| 248 241 | 
             
                    *,
         | 
| 249 242 | 
             
                    key: PRNGKeyArray,
         | 
| 250 243 | 
             
                ) -> None:
         | 
| 251 244 | 
             
                    vocab_key, s4_key = jax.random.split(key, 2)
         | 
| 252 245 | 
             
                    self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
         | 
| 253 | 
            -
                    self. | 
| 254 | 
            -
                    self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
         | 
| 246 | 
            +
                    self.output_layer = eqx.nn.Linear(hidden_size, output_size, key=key)
         | 
| 255 247 |  | 
| 256 248 | 
             
                    block_keys = jax.random.split(s4_key, num_layers)
         | 
| 257 249 |  | 
| 258 250 | 
             
                    def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
         | 
| 259 251 | 
             
                        match block_type:
         | 
| 260 | 
            -
                            case " | 
| 252 | 
            +
                            case "diagonal":
         | 
| 253 | 
            +
                                return (
         | 
| 254 | 
            +
                                    DiscreteDiagSSMBlock(hidden_size, key=key, init_delta=0.1)
         | 
| 255 | 
            +
                                    if discretize
         | 
| 256 | 
            +
                                    else DiagSSMBlock(hidden_size, key=key)
         | 
| 257 | 
            +
                                )
         | 
| 258 | 
            +
                            case "full_rank":
         | 
| 259 | 
            +
                                if discretize:
         | 
| 260 | 
            +
                                    raise ValueError("Full rank blocks do not support discretization due to instability.")
         | 
| 261 261 | 
             
                                return SSMBlock(hidden_size, key=key)
         | 
| 262 | 
            -
                            case "diag":
         | 
| 263 | 
            -
                                return DiagSSMBlock(hidden_size, key=key)
         | 
| 264 262 | 
             
                            case _:
         | 
| 265 263 | 
             
                                raise ValueError(f"Unknown block type: {block_type}")
         | 
| 266 264 |  | 
| @@ -276,21 +274,43 @@ class S4(eqx.Module): | |
| 276 274 | 
             
                        new_hs.append(h)
         | 
| 277 275 | 
             
                        xh = jax.nn.gelu(h)
         | 
| 278 276 | 
             
                        x = xh + x if self.skip_connections else xh
         | 
| 279 | 
            -
                    y = self. | 
| 277 | 
            +
                    y = self.output_layer(x)
         | 
| 280 278 | 
             
                    return new_hs, y
         | 
| 281 279 |  | 
| 282 280 | 
             
                def _embed_input(self, x: Array) -> Array:
         | 
| 283 281 | 
             
                    """U is the input to the S4 cell."""
         | 
| 284 | 
            -
                     | 
| 285 | 
            -
                    return jax.nn.gelu(self.proj_in(embedded))
         | 
| 282 | 
            +
                    return self.vocab_embedding(x)
         | 
| 286 283 |  | 
| 287 284 | 
             
                def predict_sequence(self, x_seq: Array) -> Array:
         | 
| 288 285 | 
             
                    x_emb = jax.vmap(self._embed_input)(x_seq)
         | 
| 286 | 
            +
                    for block in self.blocks:
         | 
| 287 | 
            +
                        h = block.forward_sequence(x_emb)
         | 
| 288 | 
            +
                        # h = block.naive_forward_sequence(x_emb)
         | 
| 289 | 
            +
                        h = jax.nn.gelu(h)
         | 
| 290 | 
            +
                        x_emb = h + x_emb if self.skip_connections else h
         | 
| 291 | 
            +
                    y = jax.vmap(self.output_layer)(x_emb)
         | 
| 292 | 
            +
                    return y
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                def generate_sequence(self, prompt_seq: Array, max_len: int) -> Array:
         | 
| 289 295 | 
             
                    hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
         | 
| 296 | 
            +
                    prompt_seq_embedded = jax.vmap(self._embed_input)(prompt_seq)
         | 
| 290 297 |  | 
| 291 | 
            -
                    def  | 
| 298 | 
            +
                    def encode_step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
         | 
| 292 299 | 
             
                        hs, y = self(hs, x)
         | 
| 293 300 | 
             
                        return hs, y
         | 
| 294 301 |  | 
| 295 | 
            -
                     | 
| 296 | 
            -
             | 
| 302 | 
            +
                    def decode_step(
         | 
| 303 | 
            +
                        carry: tuple[list[Array], Array, PRNGKeyArray],
         | 
| 304 | 
            +
                        _: None,
         | 
| 305 | 
            +
                    ) -> tuple[tuple[list[Array], Array, PRNGKeyArray], Array]:
         | 
| 306 | 
            +
                        hs, last_token, rng = carry
         | 
| 307 | 
            +
                        token_embedded = self._embed_input(last_token)
         | 
| 308 | 
            +
                        hs, y = self(hs, token_embedded)
         | 
| 309 | 
            +
                        token = jax.random.categorical(rng, y)
         | 
| 310 | 
            +
                        rng = jax.random.split(rng)[0]
         | 
| 311 | 
            +
                        return (hs, token, rng), token
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    hs, _ = jax.lax.scan(encode_step, hs, prompt_seq_embedded)
         | 
| 314 | 
            +
                    _, sequence = jax.lax.scan(decode_step, (hs, prompt_seq[-1], jax.random.PRNGKey(0)), None, length=max_len)
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    return sequence
         | 
    
        xax/task/mixins/train.py
    CHANGED
    
    | @@ -218,26 +218,32 @@ class TrainMixin( | |
| 218 218 | 
             
                    state = super().on_step_end(state)
         | 
| 219 219 | 
             
                    return state.replace(elapsed_time_s=time.time() - state.start_time_s)
         | 
| 220 220 |  | 
| 221 | 
            -
                def log_train_step( | 
| 221 | 
            +
                def log_train_step(
         | 
| 222 | 
            +
                    self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
         | 
| 223 | 
            +
                ) -> None:
         | 
| 222 224 | 
             
                    """Override this function to do logging during the training phase.
         | 
| 223 225 |  | 
| 224 226 | 
             
                    This function is called after the model forward pass and before the
         | 
| 225 227 | 
             
                    backward pass. It is called in the training phase.
         | 
| 226 228 |  | 
| 227 229 | 
             
                    Args:
         | 
| 230 | 
            +
                        model: The current model.
         | 
| 228 231 | 
             
                        batch: The batch from the dataloader.
         | 
| 229 232 | 
             
                        output: The model output.
         | 
| 230 233 | 
             
                        metrics: The metrics for the current batch.
         | 
| 231 234 | 
             
                        state: The current training state.
         | 
| 232 235 | 
             
                    """
         | 
| 233 236 |  | 
| 234 | 
            -
                def log_valid_step( | 
| 237 | 
            +
                def log_valid_step(
         | 
| 238 | 
            +
                    self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
         | 
| 239 | 
            +
                ) -> None:
         | 
| 235 240 | 
             
                    """Override this function to do logging during the validation phase.
         | 
| 236 241 |  | 
| 237 242 | 
             
                    This function is called after the model forward pass. It is called in
         | 
| 238 243 | 
             
                    the validation phase.
         | 
| 239 244 |  | 
| 240 245 | 
             
                    Args:
         | 
| 246 | 
            +
                        model: The current model.
         | 
| 241 247 | 
             
                        batch: The batch from the dataloader.
         | 
| 242 248 | 
             
                        output: The model output.
         | 
| 243 249 | 
             
                        metrics: The metrics for the current batch.
         | 
| @@ -251,7 +257,9 @@ class TrainMixin( | |
| 251 257 | 
             
                        for k, v in d.items():
         | 
| 252 258 | 
             
                            self.logger.log_scalar(k, v, namespace=ns)
         | 
| 253 259 |  | 
| 254 | 
            -
                def log_step( | 
| 260 | 
            +
                def log_step(
         | 
| 261 | 
            +
                    self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
         | 
| 262 | 
            +
                ) -> None:
         | 
| 255 263 | 
             
                    phase = state.phase
         | 
| 256 264 |  | 
| 257 265 | 
             
                    for k, v in metrics.items():
         | 
| @@ -265,9 +273,9 @@ class TrainMixin( | |
| 265 273 | 
             
                    # Delegate to the appropriate logging function based on the phase.
         | 
| 266 274 | 
             
                    match phase:
         | 
| 267 275 | 
             
                        case "train":
         | 
| 268 | 
            -
                            self.log_train_step(batch, output, metrics, state)
         | 
| 276 | 
            +
                            self.log_train_step(model, batch, output, metrics, state)
         | 
| 269 277 | 
             
                        case "valid":
         | 
| 270 | 
            -
                            self.log_valid_step(batch, output, metrics, state)
         | 
| 278 | 
            +
                            self.log_valid_step(model, batch, output, metrics, state)
         | 
| 271 279 | 
             
                        case _:
         | 
| 272 280 | 
             
                            raise KeyError(f"Unknown phase: {phase}")
         | 
| 273 281 |  | 
| @@ -579,7 +587,7 @@ class TrainMixin( | |
| 579 587 | 
             
                            )
         | 
| 580 588 |  | 
| 581 589 | 
             
                            output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
         | 
| 582 | 
            -
                            self.log_step(valid_batch, output, metrics, state)
         | 
| 590 | 
            +
                            self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
         | 
| 583 591 |  | 
| 584 592 | 
             
                        state = self.on_step_start(state)
         | 
| 585 593 | 
             
                        train_batch = next(train_pf)
         | 
| @@ -597,7 +605,7 @@ class TrainMixin( | |
| 597 605 | 
             
                            batch=train_batch,
         | 
| 598 606 | 
             
                            state=state,
         | 
| 599 607 | 
             
                        )
         | 
| 600 | 
            -
                        self.log_step(train_batch, output, metrics, state)
         | 
| 608 | 
            +
                        self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
         | 
| 601 609 |  | 
| 602 610 | 
             
                        state = self.on_step_end(state)
         | 
| 603 611 |  | 
| @@ -1,4 +1,4 @@ | |
| 1 | 
            -
            xax/__init__.py,sha256= | 
| 1 | 
            +
            xax/__init__.py,sha256=7vdTYO7jAJdDxKZURlFxc3Y5kr5mVQcTQjeh_sYjD6I,13834
         | 
| 2 2 | 
             
            xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 3 3 | 
             
            xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
         | 
| 4 4 | 
             
            xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
         | 
| @@ -10,11 +10,11 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847 | |
| 10 10 | 
             
            xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
         | 
| 11 11 | 
             
            xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
         | 
| 12 12 | 
             
            xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
         | 
| 13 | 
            -
            xax/nn/geom.py,sha256= | 
| 13 | 
            +
            xax/nn/geom.py,sha256=Bj9Z4Y-uoNQuaA_eB_MyG7yImZLuOq8KCLUj1l3daoc,4545
         | 
| 14 14 | 
             
            xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
         | 
| 15 15 | 
             
            xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
         | 
| 16 16 | 
             
            xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
         | 
| 17 | 
            -
            xax/nn/ssm.py,sha256= | 
| 17 | 
            +
            xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
         | 
| 18 18 | 
             
            xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 19 19 | 
             
            xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
         | 
| 20 20 | 
             
            xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
         | 
| @@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280 | |
| 41 41 | 
             
            xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
         | 
| 42 42 | 
             
            xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
         | 
| 43 43 | 
             
            xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
         | 
| 44 | 
            -
            xax/task/mixins/train.py,sha256= | 
| 44 | 
            +
            xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
         | 
| 45 45 | 
             
            xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 46 46 | 
             
            xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
         | 
| 47 47 | 
             
            xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
         | 
| @@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706 | |
| 58 58 | 
             
            xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 59 59 | 
             
            xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
         | 
| 60 60 | 
             
            xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
         | 
| 61 | 
            -
            xax-0.1. | 
| 62 | 
            -
            xax-0.1. | 
| 63 | 
            -
            xax-0.1. | 
| 64 | 
            -
            xax-0.1. | 
| 65 | 
            -
            xax-0.1. | 
| 61 | 
            +
            xax-0.1.12.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
         | 
| 62 | 
            +
            xax-0.1.12.dist-info/METADATA,sha256=hLRAX5__7QjBgjzhxbRftGvEsNrt8IAdgd22dMtHu_Y,1878
         | 
| 63 | 
            +
            xax-0.1.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
         | 
| 64 | 
            +
            xax-0.1.12.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
         | 
| 65 | 
            +
            xax-0.1.12.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         |