xax 0.1.11__tar.gz → 0.1.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.11/xax.egg-info → xax-0.1.13}/PKG-INFO +1 -1
- {xax-0.1.11 → xax-0.1.13}/xax/__init__.py +12 -11
- {xax-0.1.11 → xax-0.1.13}/xax/nn/geom.py +57 -0
- xax-0.1.13/xax/nn/ssm.py +316 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/loggers/stdout.py +5 -6
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/train.py +15 -7
- {xax-0.1.11 → xax-0.1.13}/xax/utils/debugging.py +6 -0
- {xax-0.1.11 → xax-0.1.13/xax.egg-info}/PKG-INFO +1 -1
- xax-0.1.11/xax/nn/ssm.py +0 -296
- {xax-0.1.11 → xax-0.1.13}/LICENSE +0 -0
- {xax-0.1.11 → xax-0.1.13}/MANIFEST.in +0 -0
- {xax-0.1.11 → xax-0.1.13}/README.md +0 -0
- {xax-0.1.11 → xax-0.1.13}/pyproject.toml +0 -0
- {xax-0.1.11 → xax-0.1.13}/setup.cfg +0 -0
- {xax-0.1.11 → xax-0.1.13}/setup.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/core/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/core/conf.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/core/state.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/nn/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/nn/embeddings.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/nn/equinox.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/nn/export.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/nn/functions.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/nn/losses.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/nn/norm.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/nn/parallel.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/py.typed +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/requirements-dev.txt +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/requirements.txt +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/base.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/launchers/base.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/logger.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/loggers/json.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/loggers/state.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/compile.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/process.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/script.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/task/task.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/data/collate.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/experiments.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/jax.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/logging.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/numpy.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/profile.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/pytree.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/tensorboard.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/text.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/types/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax.egg-info/requires.txt +0 -0
- {xax-0.1.11 → xax-0.1.13}/xax.egg-info/top_level.txt +0 -0
{xax-0.1.11 → xax-0.1.13}/xax/__init__.py

@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """

-__version__ = "0.1.11"
+__version__ = "0.1.13"

 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,15 +43,14 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "rotate_vector_by_quat",
     "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "BaseSSMBlock",
     "DiagSSMBlock",
-    "DiscreteTimeS4",
-    "S4",
-    "S4Layer",
-    "S6Layer",
+    "SSM",
     "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
@@ -76,6 +75,7 @@ __all__ = [
     "Task",
     "collate",
     "collate_non_null",
+    "breakpoint_if_nan",
     "get_named_leaves",
     "BaseFileDownloader",
     "ContextTimer",
@@ -203,15 +203,14 @@ NAME_MAP: dict[str, str] = {
    "euler_to_quat": "nn.geom",
    "get_projected_gravity_vector_from_quat": "nn.geom",
    "quat_to_euler": "nn.geom",
+    "rotate_vector_by_quat": "nn.geom",
    "cross_entropy": "nn.losses",
    "cast_norm_type": "nn.norm",
    "get_norm": "nn.norm",
    "is_master": "nn.parallel",
+    "BaseSSMBlock": "nn.ssm",
    "DiagSSMBlock": "nn.ssm",
-    "DiscreteTimeS4": "nn.ssm",
-    "S4": "nn.ssm",
-    "S4Layer": "nn.ssm",
-    "S6Layer": "nn.ssm",
+    "SSM": "nn.ssm",
    "SSMBlock": "nn.ssm",
    "BaseLauncher": "task.launchers.base",
    "CliLauncher": "task.launchers.cli",
@@ -236,6 +235,7 @@ NAME_MAP: dict[str, str] = {
    "Task": "task.task",
    "collate": "utils.data.collate",
    "collate_non_null": "utils.data.collate",
+    "breakpoint_if_nan": "utils.debugging",
    "get_named_leaves": "utils.debugging",
    "BaseFileDownloader": "utils.experiments",
    "ContextTimer": "utils.experiments",
@@ -364,11 +364,12 @@ if IMPORT_ALL or TYPE_CHECKING:
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
+        rotate_vector_by_quat,
     )
     from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
-    from xax.nn.ssm import S4, DiagSSMBlock, DiscreteTimeS4, S4Layer, S6Layer, SSMBlock
+    from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
@@ -387,7 +388,7 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.script import Script, ScriptConfig
     from xax.task.task import Config, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
-    from xax.utils.debugging import get_named_leaves
+    from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
     from xax.utils.experiments import (
         BaseFileDownloader,
         ContextTimer,
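Each new symbol is registered in `__all__`, in `NAME_MAP`, and in the `IMPORT_ALL or TYPE_CHECKING` import block, so — assuming the package's lazy-import machinery behaves as in prior releases — the additions resolve from the top-level namespace. A minimal sketch, not part of the diff:

    import xax

    xax.rotate_vector_by_quat  # lazily resolved from xax.nn.geom
    xax.SSM                    # lazily resolved from xax.nn.ssm
    xax.breakpoint_if_nan      # lazily resolved from xax.utils.debugging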
{xax-0.1.11 → xax-0.1.13}/xax/nn/geom.py

@@ -99,3 +99,60 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:

     # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
     return jnp.concatenate([gx, gy, -gz], axis=-1)
+
+
+def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
+    """Rotates a vector by a quaternion.
+
+    Args:
+        vector: The vector to rotate, shape (*, 3).
+        quat: The quaternion to rotate by, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The rotated vector, shape (*, 3).
+    """
+    # Normalize quaternion
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    # Extract vector components
+    vx, vy, vz = jnp.split(vector, 3, axis=-1)
+
+    # Terms for x component
+    xx = (
+        w * w * vx
+        + 2 * y * w * vz
+        - 2 * z * w * vy
+        + x * x * vx
+        + 2 * y * x * vy
+        + 2 * z * x * vz
+        - z * z * vx
+        - y * y * vx
+    )
+
+    # Terms for y component
+    yy = (
+        2 * x * y * vx
+        + y * y * vy
+        + 2 * z * y * vz
+        + 2 * w * z * vx
+        - z * z * vy
+        + w * w * vy
+        - 2 * w * x * vz
+        - x * x * vy
+    )
+
+    # Terms for z component
+    zz = (
+        2 * x * z * vx
+        + 2 * y * z * vy
+        + z * z * vz
+        - 2 * w * y * vx
+        + w * w * vz
+        + 2 * w * x * vy
+        - y * y * vz
+        - x * x * vz
+    )
+
+    return jnp.concatenate([xx, yy, zz], axis=-1)
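The expanded terms above are the component form of the standard quaternion rotation v' = q v q* (with q normalized first, so the conjugate equals the inverse). A quick sanity-check sketch, not part of the diff, assuming the (w, x, y, z) component order implied by the `jnp.split` call:

    import jax.numpy as jnp
    from xax.nn.geom import rotate_vector_by_quat

    # 90-degree rotation about the z-axis.
    quat = jnp.array([jnp.cos(jnp.pi / 4), 0.0, 0.0, jnp.sin(jnp.pi / 4)])
    vec = jnp.array([1.0, 0.0, 0.0])

    rotate_vector_by_quat(vec, quat)  # approximately [0.0, 1.0, 0.0]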
xax-0.1.13/xax/nn/ssm.py
ADDED
@@ -0,0 +1,316 @@

"""State space models."""

from abc import ABC, abstractmethod
from typing import Literal

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray


def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))


class BaseSSMBlock(eqx.Module, ABC):
    @abstractmethod
    def forward(self, h: Array, x: Array) -> Array: ...

    @abstractmethod
    def forward_sequence(self, x_seq: Array) -> Array: ...

    @abstractmethod
    def get_a_mat(self, x: Array) -> Array: ...

    @abstractmethod
    def get_b_mat(self, x: Array) -> Array: ...


class SSMBlock(BaseSSMBlock):
    a_mat: Array
    b_mat: Array

    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
        key_a, key_b = jax.random.split(key)
        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
        self.b_mat = glorot(key_b, (hidden_size, hidden_size))

    def get_a_mat(self, x: Array) -> Array:
        return self.a_mat

    def get_b_mat(self, x: Array) -> Array:
        return self.b_mat

    def forward(self, h: Array, x: Array) -> Array:
        """Perform a forward pass.

        Args:
            h: Hidden state of shape (H,).
            x: Input of shape (H,).

        Returns:
            Hidden state of shape (H,).
        """
        a_mat = self.get_a_mat(x)
        b_mat = self.get_b_mat(x)
        h = a_mat @ h + b_mat.T @ x
        return h

    def forward_sequence(self, x_seq: Array) -> Array:
        """Perform a forward pass across time.

        Args:
            x_seq: Input sequence of shape (T, H).

        Returns:
            Hidden state sequence of shape (T, H).
        """

        def step(h: Array, x: Array) -> tuple[Array, Array]:
            h = self.forward(h, x)
            return h, h

        a_mat = self.get_a_mat(x_seq)
        h_0 = jnp.zeros(a_mat.shape[0])
        _, h_seq = jax.lax.scan(step, h_0, x_seq)
        return h_seq


class DiagSSMBlock(BaseSSMBlock):
    a_diag: Array
    b_mat: Array

    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
        keys = jax.random.split(key, 2)
        self.a_diag = glorot(keys[0], (hidden_size,))
        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))

    def get_a_mat(self, x: Array) -> Array:
        return self.a_diag

    def get_b_mat(self, x: Array) -> Array:
        return self.b_mat

    def forward(self, h: Array, x: Array) -> Array:
        """Perform a forward pass.

        Args:
            h: Hidden state of shape (H,).
            x: Input of shape (H,).

        Returns:
            Hidden state of shape (H,).
        """
        a_diag = self.get_a_mat(x)
        b_mat = self.get_b_mat(x)
        h = a_diag * h + b_mat.T @ x
        return h

    def forward_sequence(self, x_seq: Array, *, use_conv: bool = True, recursive_kernel_calc: bool = False) -> Array:
        """Perform a potentially parallelized forward pass across time.

        Args:
            x_seq: Input sequence of shape (T, H).
            use_conv: Whether to use convolution to compute the sequence.
            recursive_kernel_calc: Whether to use a recursive kernel calculation.

        Returns:
            Hidden state sequence of shape (T, H).
        """
        if use_conv:
            return self._forward_sequence_conv(x_seq, recursive_kernel_calc=recursive_kernel_calc)
        else:
            return self._forward_sequence_scan(x_seq)

    def _get_kernel(self, x_seq: Array, length: int) -> Array:
        """Returns the kernel with time as the final dimension."""
        exponents = jnp.arange(length)
        a_diag = self.get_a_mat(x_seq)
        kernel = jnp.power(a_diag[:, None], exponents)  # (H, T)
        kernel = kernel[:, None, :]  # (H, 1, T)
        return kernel

    def _get_kernel_recursive(self, x_seq: Array, length: int) -> Array:
        """Returns the kernel with time as the final dimension."""
        assert length % 2 == 0, "Length must be even."
        a_diag = self.get_a_mat(x_seq)

        def helper(length: int) -> tuple[Array, Array]:
            """Returns the kernel and the sqrt of the diagonal."""
            if length == 1:
                return jnp.ones_like(a_diag)[:, None], a_diag[:, None]

            half_length = length // 2
            kernel_half, a_half = helper(half_length)
            kernel = jnp.concatenate([kernel_half, a_half * kernel_half], axis=-1)
            return kernel, a_half * a_half

        kernel, a_diag = helper(length)
        return kernel[:, None, :]  # (H, 1, L)

    def _forward_sequence_conv(self, x_seq: Array, *, recursive_kernel_calc: bool = False) -> Array:
        """Convolves x (T, H) across time using the kernel."""
        seq_len, hidden_size = x_seq.shape
        b_mat = self.get_b_mat(x_seq)

        s = b_mat.T @ x_seq.T  # (H, T)
        s_padded = jnp.pad(s, ((0, 0), (seq_len - 1, 0)))[None, :, :]  # (1, H, 2T-1)

        if recursive_kernel_calc:
            kernel = self._get_kernel_recursive(x_seq, seq_len)
        else:
            kernel = self._get_kernel(x_seq, seq_len)

        kernel_flipped = jnp.flip(kernel, axis=-1)  # (H, 1, L)

        conv_out = jax.lax.conv_general_dilated(
            s_padded,
            kernel_flipped,
            window_strides=(1,),
            padding="VALID",
            dimension_numbers=("NCT", "OIT", "NCT"),  # convolving over time
            feature_group_count=hidden_size,
        )
        conv_out = conv_out[0].T  # (T, H)
        return conv_out

    def _forward_sequence_scan(self, x_seq: Array) -> Array:
        """Naively forward across time."""

        def step(h: Array, x: Array) -> tuple[Array, Array]:
            h = self.forward(h, x)
            return h, h

        a_diag = self.get_a_mat(x_seq)
        h_0 = jnp.zeros(a_diag.shape[0])
        _, h_seq = jax.lax.scan(step, h_0, x_seq)
        return h_seq


class DiscreteDiagSSMBlock(DiagSSMBlock):
    delta: Array

    def __init__(
        self,
        hidden_size: int,
        *,
        key: PRNGKeyArray,
        init_delta: float = 1.0,
        init_scale: float = 10.0,
    ) -> None:
        super().__init__(hidden_size, key=key)
        self.delta = jnp.array(init_delta)

        # A positive scale helps reduce the gradient at the start.
        self.a_diag = jax.random.uniform(key, (hidden_size,), minval=-1.0, maxval=0.0) * init_scale

    def get_a_mat(self, x: Array) -> Array:
        """Discretize the diagonal matrix using zero-order hold."""
        a_diag_discrete = jnp.exp(self.a_diag * self.delta)
        return a_diag_discrete

    def get_b_mat(self, x: Array) -> Array:
        """Discretize the input matrix using zero-order hold."""
        delta_a_diag = self.a_diag * self.delta
        exp_a_diag = jnp.exp(delta_a_diag)
        delta_a_inv = 1 / delta_a_diag
        delta_b_mat = self.delta * self.b_mat

        b_discrete = delta_a_inv * (exp_a_diag - 1) * delta_b_mat
        return b_discrete


class SSM(eqx.Module):
    vocab_embedding: eqx.nn.Embedding
    output_layer: eqx.nn.Linear
    blocks: list[BaseSSMBlock]
    num_layers: int = eqx.static_field()
    hidden_size: int = eqx.static_field()
    skip_connections: bool = eqx.static_field()

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int,
        block_type: Literal["diagonal", "full_rank"] = "full_rank",
        skip_connections: bool = False,
        discretize: bool = False,
        *,
        key: PRNGKeyArray,
    ) -> None:
        vocab_key, s4_key = jax.random.split(key, 2)
        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
        self.output_layer = eqx.nn.Linear(hidden_size, output_size, key=key)

        block_keys = jax.random.split(s4_key, num_layers)

        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
            match block_type:
                case "diagonal":
                    return (
                        DiscreteDiagSSMBlock(hidden_size, key=key, init_delta=0.1)
                        if discretize
                        else DiagSSMBlock(hidden_size, key=key)
                    )
                case "full_rank":
                    if discretize:
                        raise ValueError("Full rank blocks do not support discretization due to instability.")
                    return SSMBlock(hidden_size, key=key)
                case _:
                    raise ValueError(f"Unknown block type: {block_type}")

        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
        self.skip_connections = skip_connections
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
        new_hs = []
        for i, block in enumerate(self.blocks):
            h = block.forward(hs[i], x)
            new_hs.append(h)
            xh = jax.nn.gelu(h)
            x = xh + x if self.skip_connections else xh
        y = self.output_layer(x)
        return new_hs, y

    def _embed_input(self, x: Array) -> Array:
        """U is the input to the S4 cell."""
        return self.vocab_embedding(x)

    def predict_sequence(self, x_seq: Array) -> Array:
        x_emb = jax.vmap(self._embed_input)(x_seq)
        for block in self.blocks:
            h = block.forward_sequence(x_emb)
            # h = block.naive_forward_sequence(x_emb)
            h = jax.nn.gelu(h)
            x_emb = h + x_emb if self.skip_connections else h
        y = jax.vmap(self.output_layer)(x_emb)
        return y

    def generate_sequence(self, prompt_seq: Array, max_len: int) -> Array:
        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
        prompt_seq_embedded = jax.vmap(self._embed_input)(prompt_seq)

        def encode_step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
            hs, y = self(hs, x)
            return hs, y

        def decode_step(
            carry: tuple[list[Array], Array, PRNGKeyArray],
            _: None,
        ) -> tuple[tuple[list[Array], Array, PRNGKeyArray], Array]:
            hs, last_token, rng = carry
            token_embedded = self._embed_input(last_token)
            hs, y = self(hs, token_embedded)
            token = jax.random.categorical(rng, y)
            rng = jax.random.split(rng)[0]
            return (hs, token, rng), token

        hs, _ = jax.lax.scan(encode_step, hs, prompt_seq_embedded)
        _, sequence = jax.lax.scan(decode_step, (hs, prompt_seq[-1], jax.random.PRNGKey(0)), None, length=max_len)

        return sequence
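In the discrete block, `get_a_mat`/`get_b_mat` implement the zero-order-hold discretization ā = exp(aΔ) and b̄ = (aΔ)⁻¹(exp(aΔ) − 1)Δb familiar from the S4/S5 literature, while the convolutional path computes every hidden state in parallel via a depthwise convolution against the kernel of powers of ā. A minimal driving sketch for the new module, not part of the diff (the sizes are arbitrary, for illustration only):

    import jax
    import jax.numpy as jnp
    from xax.nn.ssm import SSM

    model = SSM(
        input_size=128,   # vocabulary size for the embedding
        hidden_size=64,
        output_size=128,
        num_layers=2,
        block_type="diagonal",
        discretize=True,  # uses DiscreteDiagSSMBlock with init_delta=0.1
        key=jax.random.PRNGKey(0),
    )

    tokens = jnp.array([3, 17, 42])                        # (T,) token ids
    logits = model.predict_sequence(tokens)                # (T, 128), parallel over time
    sampled = model.generate_sequence(tokens, max_len=8)   # (8,) sampled token ids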
{xax-0.1.11 → xax-0.1.13}/xax/task/loggers/stdout.py

@@ -14,7 +14,7 @@ from xax.utils.text import Color, colored, format_timedelta

 def format_number(value: int | float, precision: int) -> str:
     if isinstance(value, int):
-        return str(value)
+        return f"{value:,}"  # Add commas to the number
     return f"{value:.{precision}g}"

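The fixed branch now formats integers with thousands separators instead of returning them bare: `format_number(1234567, 0)` yields `"1,234,567"`, while floats keep the `g`-style precision path, e.g. `format_number(3.14159, 3)` yields `"3.14"`.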
@@ -80,11 +80,10 @@ class StdoutLogger(LoggerImpl):
         self.write_fp.write("\033[2J\033[H")

     def write_state_window(self, line: LogLine) -> None:
-        elapsed_time = format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True)
-        state_info = {
-            "Steps": f"{line.state.num_steps}",
-            "Samples": f"{line.state.num_samples}",
-            "Elapsed Time": f"{elapsed_time}",
+        state_info: dict[str, str] = {
+            "Steps": format_number(line.state.num_steps, 0),
+            "Samples": format_number(line.state.num_samples, 0),
+            "Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True),
         }

         colored_prefix = colored("Phase: ", "grey", bold=True)
{xax-0.1.11 → xax-0.1.13}/xax/task/mixins/train.py

@@ -218,26 +218,32 @@ class TrainMixin(
         state = super().on_step_end(state)
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)

-    def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_train_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the training phase.

         This function is called after the model forward pass and before the
         backward pass. It is called in the training phase.

         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
             state: The current training state.
         """

-    def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_valid_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the validation phase.

         This function is called after the model forward pass. It is called in
         the validation phase.

         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
@@ -251,7 +257,9 @@ class TrainMixin(
         for k, v in d.items():
             self.logger.log_scalar(k, v, namespace=ns)

-    def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         phase = state.phase

         for k, v in metrics.items():
@@ -265,9 +273,9 @@ class TrainMixin(
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(batch, output, metrics, state)
+                self.log_train_step(model, batch, output, metrics, state)
             case "valid":
-                self.log_valid_step(batch, output, metrics, state)
+                self.log_valid_step(model, batch, output, metrics, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")

@@ -579,7 +587,7 @@ class TrainMixin(
         )

         output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
-        self.log_step(valid_batch, output, metrics, state)
+        self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)

         state = self.on_step_start(state)
         train_batch = next(train_pf)
@@ -597,7 +605,7 @@ class TrainMixin(
                 batch=train_batch,
                 state=state,
             )
-            self.log_step(train_batch, output, metrics, state)
+            self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)

             state = self.on_step_end(state)
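Because `log_step` and the phase-specific hooks now receive the recombined model PyTree, subclass overrides can log model-dependent quantities alongside the batch metrics. A minimal sketch of an override under the new signature, not part of the diff (the task subclass and metric name are hypothetical):

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class MyTask(SupervisedTask):  # hypothetical xax task subclass
        def log_valid_step(self, model, batch, output, metrics, state) -> None:
            # `model` is eqx.combine(model_arr, model_static), i.e. the full model.
            params = jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_inexact_array))
            weight_norm = jnp.sqrt(sum(jnp.sum(p * p) for p in params))
            self.logger.log_scalar("weight_norm", weight_norm, namespace="valid")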
{xax-0.1.11 → xax-0.1.13}/xax/utils/debugging.py

@@ -4,6 +4,8 @@ from collections import deque
 from collections.abc import Iterable, Mapping
 from typing import Any, Callable, Deque

+import jax
+import jax.numpy as jnp
 from jaxtyping import Array

@@ -47,3 +49,7 @@ def get_named_leaves(
             q.append((depth + 1, gname, cnode))

     return ret
+
+
+def breakpoint_if_nan(x: Array) -> None:
+    jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
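`breakpoint_if_nan` is traceable: the NaN check routes through `jax.lax.cond` and `jax.debug.breakpoint`, so it can sit inside jitted code. A minimal sketch, not part of the diff:

    import jax
    import jax.numpy as jnp
    from xax.utils.debugging import breakpoint_if_nan

    @jax.jit
    def log_sum(x):
        y = jnp.log(x)          # NaN for negative entries
        breakpoint_if_nan(y)    # pauses here if any NaN is present
        return jnp.sum(y)

    log_sum(jnp.array([1.0, -1.0]))  # triggers the breakpoint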
xax-0.1.11/xax/nn/ssm.py
DELETED
@@ -1,296 +0,0 @@

"""State space models."""

from abc import ABC, abstractmethod
from typing import Literal

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray


def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))


class DiscreteTimeS4(eqx.Module):
    a: Array
    B: Array
    C: Array
    proj_in: eqx.nn.Linear
    proj_out: eqx.nn.Linear

    def __init__(
        self,
        hidden_size: int,
        projection_size: int,
        input_size: int,
        output_size: int,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)

    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
        h = self.a * h + self.B.T @ x
        y = self.C.T @ h
        return h, y

    def predict_sequence(self, x_seq: Array) -> Array:
        x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
        h = jnp.zeros(self.a.shape[0])

        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
            h = self.a * h + self.B.T @ x
            y = self.C.T @ h
            return h, y

        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
        y_out = jax.vmap(self.proj_out)(y_seq)
        return y_out


class S4Layer(eqx.Module):
    a: Array
    B: Array
    C: Array
    proj_in: eqx.nn.Linear
    proj_out: eqx.nn.Linear
    delta: Array

    def __init__(
        self,
        hidden_size: int,
        projection_size: int,
        input_size: int,
        output_size: int,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
        self.delta = jax.random.uniform(key, (hidden_size,))

    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
        delta_a = self.delta * self.a
        a_bar = jnp.exp(delta_a)
        b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
        h = a_bar * h + b_bar.T @ x
        y = self.C.T @ h
        return h, y

    def predict_sequence(self, x_seq: Array) -> Array:
        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
        h = jnp.zeros(self.a.shape[0])

        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
            h = self.a * h + self.B.T @ x
            y = self.C.T @ h
            return h, y

        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
        y_out = jax.vmap(self.proj_out)(y_seq)
        return y_out


class S6Layer(eqx.Module):
    a: Array
    B: Array
    C: Array
    proj_in: eqx.nn.Linear
    proj_out: eqx.nn.Linear
    delta: Array

    def __init__(
        self,
        hidden_size: int,
        projection_size: int,
        input_size: int,
        output_size: int,
        *,
        key: PRNGKeyArray,
    ) -> None:
        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
        self.delta = jax.random.uniform(key, (hidden_size,))

    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
        h = self.a * h + self.B.T @ x
        y = self.C.T @ h
        return h, y

    def predict_sequence(self, x_seq: Array) -> Array:
        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
        h = jnp.zeros(self.a.shape[0])

        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
            h = self.a * h + self.B.T @ x
            y = self.C.T @ h
            return h, y

        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
        y_out = jax.vmap(self.proj_out)(y_seq)
        return y_out


class BaseSSMBlock(eqx.Module, ABC):
    @abstractmethod
    def forward(self, h: Array, x: Array) -> Array:
        pass


class SSMBlock(BaseSSMBlock):
    a_mat: Array
    b_mat: Array

    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
        key_a, key_b = jax.random.split(key)
        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
        self.b_mat = glorot(key_b, (hidden_size, hidden_size))

    def forward(self, h: Array, x: Array) -> Array:
        h = self.a_mat @ h + self.b_mat.T @ x
        return h

    def get_kernel(self, length: int) -> Array:
        return self.a_mat


class DiagSSMBlock(BaseSSMBlock):
    a_mat: Array
    b_mat: Array

    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
        keys = jax.random.split(key, 2)
        self.a_mat = glorot(keys[0], (hidden_size,))
        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))

    def forward(self, h: Array, x: Array) -> Array:
        h = self.a_mat * h + self.b_mat.T @ x
        h = jax.nn.tanh(h)
        return h

    def get_kernel(self, length: int) -> Array:
        """Returns the kernel with time as the final dimension."""
        exponents = jnp.arange(length)
        kernel = jnp.power(self.a_mat[:, None], exponents)  # (H, L)
        kernel = kernel[:, None, :]  # (H, 1, L)
        return kernel

    def forward_across_time(self, x: Array) -> Array:
        """Convolves x (T, H) across time using the kernel."""
        tsz, nhid = x.shape

        # Compute s = x @ U.T + b, with shape (N, T, H)
        s = self.b_mat.T @ x
        s = s.T  # (H, T)

        kernel = self.get_kernel(tsz)  # (H, 1, T)
        kernel_flipped = jnp.flip(kernel, axis=-1)

        # Pad s on the left along the time axis (pad length T-1)
        s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))

        # Perform depthwise (grouped) 1D convolution.
        # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
        # The dimension_numbers are chosen so that the channel dimension is second.
        conv_out = jax.lax.conv_general_dilated(
            s_padded,
            kernel_flipped,
            window_strides=(1,),
            padding="VALID",
            dimension_numbers=("NCH", "OIH", "NCH"),
            feature_group_count=nhid,
        )
        # conv_out has shape (N, H, T); transpose to (N, T, H)
        conv_out = jnp.transpose(conv_out, (0, 2, 1))
        return conv_out

    def naive_forward_accross_time(self, x: Array) -> Array:
        """Naively forward across time."""

        def step(h: Array, x: Array) -> tuple[Array, Array]:
            h = self.forward(h, x)
            return h, h

        h_0 = jnp.zeros(self.a_mat.shape[0])
        _, h_seq = jax.lax.scan(step, h_0, x)
        return h_seq


class S4(eqx.Module):
    vocab_embedding: eqx.nn.Embedding
    proj_in: eqx.nn.Linear
    proj_out: eqx.nn.Linear
    blocks: list[BaseSSMBlock]
    num_layers: int = eqx.static_field()
    hidden_size: int = eqx.static_field()
    skip_connections: bool = eqx.static_field()

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int,
        block_type: Literal["ssm", "diag"] = "ssm",
        skip_connections: bool = False,
        *,
        key: PRNGKeyArray,
    ) -> None:
        vocab_key, s4_key = jax.random.split(key, 2)
        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
        self.proj_in = eqx.nn.Linear(hidden_size, hidden_size, key=key)
        self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)

        block_keys = jax.random.split(s4_key, num_layers)

        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
            match block_type:
                case "ssm":
                    return SSMBlock(hidden_size, key=key)
                case "diag":
                    return DiagSSMBlock(hidden_size, key=key)
                case _:
                    raise ValueError(f"Unknown block type: {block_type}")

        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
        self.skip_connections = skip_connections
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
        new_hs = []
        for i, block in enumerate(self.blocks):
            h = block.forward(hs[i], x)
            new_hs.append(h)
            xh = jax.nn.gelu(h)
            x = xh + x if self.skip_connections else xh
        y = self.proj_out(x)
        return new_hs, y

    def _embed_input(self, x: Array) -> Array:
        """U is the input to the S4 cell."""
        embedded = self.vocab_embedding(x)
        return jax.nn.gelu(self.proj_in(embedded))

    def predict_sequence(self, x_seq: Array) -> Array:
        x_emb = jax.vmap(self._embed_input)(x_seq)
        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]

        def step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
            hs, y = self(hs, x)
            return hs, y

        _, y_seq = jax.lax.scan(step, hs, x_emb)
        return y_seq