xax 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +9 -10
- xax/nn/geom.py +57 -0
- xax/nn/ssm.py +194 -174
- xax/task/mixins/train.py +15 -7
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/METADATA +1 -1
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/RECORD +9 -9
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/WHEEL +0 -0
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.11.dist-info → xax-0.1.12.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.11"
+__version__ = "0.1.12"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,15 +43,14 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "rotate_vector_by_quat",
     "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "BaseSSMBlock",
     "DiagSSMBlock",
-    "
-    "S4",
-    "S4Layer",
-    "S6Layer",
+    "SSM",
     "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
@@ -203,15 +202,14 @@ NAME_MAP: dict[str, str] = {
    "euler_to_quat": "nn.geom",
    "get_projected_gravity_vector_from_quat": "nn.geom",
    "quat_to_euler": "nn.geom",
+    "rotate_vector_by_quat": "nn.geom",
    "cross_entropy": "nn.losses",
    "cast_norm_type": "nn.norm",
    "get_norm": "nn.norm",
    "is_master": "nn.parallel",
+    "BaseSSMBlock": "nn.ssm",
    "DiagSSMBlock": "nn.ssm",
-    "
-    "S4": "nn.ssm",
-    "S4Layer": "nn.ssm",
-    "S6Layer": "nn.ssm",
+    "SSM": "nn.ssm",
    "SSMBlock": "nn.ssm",
    "BaseLauncher": "task.launchers.base",
    "CliLauncher": "task.launchers.cli",
@@ -364,11 +362,12 @@ if IMPORT_ALL or TYPE_CHECKING:
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
+        rotate_vector_by_quat,
     )
     from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
-    from xax.nn.ssm import
+    from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
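The renames above change the public import surface: S4, S4Layer, and S6Layer are dropped in favor of a single SSM export, while rotate_vector_by_quat and BaseSSMBlock are newly exported. A minimal migration sketch (illustrative, not taken from the package's own docs):

    # 0.1.11 (removed exports):
    #   from xax.nn.ssm import S4, S4Layer, S6Layer
    # 0.1.12 equivalents:
    from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
    from xax.nn.geom import rotate_vector_by_quat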
xax/nn/geom.py
CHANGED
@@ -99,3 +99,60 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:
 
     # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
     return jnp.concatenate([gx, gy, -gz], axis=-1)
+
+
+def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
+    """Rotates a vector by a quaternion.
+
+    Args:
+        vector: The vector to rotate, shape (*, 3).
+        quat: The quaternion to rotate by, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The rotated vector, shape (*, 3).
+    """
+    # Normalize quaternion
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    # Extract vector components
+    vx, vy, vz = jnp.split(vector, 3, axis=-1)
+
+    # Terms for x component
+    xx = (
+        w * w * vx
+        + 2 * y * w * vz
+        - 2 * z * w * vy
+        + x * x * vx
+        + 2 * y * x * vy
+        + 2 * z * x * vz
+        - z * z * vx
+        - y * y * vx
+    )
+
+    # Terms for y component
+    yy = (
+        2 * x * y * vx
+        + y * y * vy
+        + 2 * z * y * vz
+        + 2 * w * z * vx
+        - z * z * vy
+        + w * w * vy
+        - 2 * w * x * vz
+        - x * x * vy
+    )
+
+    # Terms for z component
+    zz = (
+        2 * x * z * vx
+        + 2 * y * z * vy
+        + z * z * vz
+        - 2 * w * y * vx
+        + w * w * vz
+        + 2 * w * x * vy
+        - y * y * vz
+        - x * x * vz
+    )
+
+    return jnp.concatenate([xx, yy, zz], axis=-1)
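A quick sanity check of the new helper (a usage sketch assuming xax 0.1.12 is installed; the quaternion is in (w, x, y, z) order, matching the jnp.split above):

    import jax.numpy as jnp
    from xax.nn.geom import rotate_vector_by_quat

    # Rotate the x unit vector 90 degrees about the z-axis.
    quat = jnp.array([jnp.cos(jnp.pi / 4), 0.0, 0.0, jnp.sin(jnp.pi / 4)])
    vec = jnp.array([1.0, 0.0, 0.0])
    print(rotate_vector_by_quat(vec, quat))  # approximately [0., 1., 0.]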
xax/nn/ssm.py
CHANGED
@@ -13,140 +13,18 @@ def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
     return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
 
 
-class
-    a: Array
-    B: Array
-    C: Array
-    proj_in: eqx.nn.Linear
-    proj_out: eqx.nn.Linear
-
-    def __init__(
-        self,
-        hidden_size: int,
-        projection_size: int,
-        input_size: int,
-        output_size: int,
-        *,
-        key: PRNGKeyArray,
-    ) -> None:
-        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
-        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
-        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
-        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
-        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
-
-    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
-        h = self.a * h + self.B.T @ x
-        y = self.C.T @ h
-        return h, y
-
-    def predict_sequence(self, x_seq: Array) -> Array:
-        x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
-        h = jnp.zeros(self.a.shape[0])
-
-        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
-            h = self.a * h + self.B.T @ x
-            y = self.C.T @ h
-            return h, y
-
-        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
-        y_out = jax.vmap(self.proj_out)(y_seq)
-        return y_out
-
-
-class S4Layer(eqx.Module):
-    a: Array
-    B: Array
-    C: Array
-    proj_in: eqx.nn.Linear
-    proj_out: eqx.nn.Linear
-    delta: Array
-
-    def __init__(
-        self,
-        hidden_size: int,
-        projection_size: int,
-        input_size: int,
-        output_size: int,
-        *,
-        key: PRNGKeyArray,
-    ) -> None:
-        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
-        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
-        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
-        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
-        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
-        self.delta = jax.random.uniform(key, (hidden_size,))
-
-    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
-        delta_a = self.delta * self.a
-        a_bar = jnp.exp(delta_a)
-        b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
-        h = a_bar * h + b_bar.T @ x
-        y = self.C.T @ h
-        return h, y
-
-    def predict_sequence(self, x_seq: Array) -> Array:
-        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
-        h = jnp.zeros(self.a.shape[0])
-
-        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
-            h = self.a * h + self.B.T @ x
-            y = self.C.T @ h
-            return h, y
-
-        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
-        y_out = jax.vmap(self.proj_out)(y_seq)
-        return y_out
-
-
-class S6Layer(eqx.Module):
-    a: Array
-    B: Array
-    C: Array
-    proj_in: eqx.nn.Linear
-    proj_out: eqx.nn.Linear
-    delta: Array
-
-    def __init__(
-        self,
-        hidden_size: int,
-        projection_size: int,
-        input_size: int,
-        output_size: int,
-        *,
-        key: PRNGKeyArray,
-    ) -> None:
-        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
-        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
-        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
-        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
-        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
-        self.delta = jax.random.uniform(key, (hidden_size,))
-
-    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
-        h = self.a * h + self.B.T @ x
-        y = self.C.T @ h
-        return h, y
-
-    def predict_sequence(self, x_seq: Array) -> Array:
-        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
-        h = jnp.zeros(self.a.shape[0])
-
-        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
-            h = self.a * h + self.B.T @ x
-            y = self.C.T @ h
-            return h, y
+class BaseSSMBlock(eqx.Module, ABC):
+    @abstractmethod
+    def forward(self, h: Array, x: Array) -> Array: ...
 
-        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
-        y_out = jax.vmap(self.proj_out)(y_seq)
-        return y_out
+    @abstractmethod
+    def forward_sequence(self, x_seq: Array) -> Array: ...
 
+    @abstractmethod
+    def get_a_mat(self, x: Array) -> Array: ...
 
-class BaseSSMBlock(eqx.Module, ABC):
     @abstractmethod
-    def
-        pass
+    def get_b_mat(self, x: Array) -> Array: ...
 
 
 class SSMBlock(BaseSSMBlock):
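The restructured base class makes get_a_mat and get_b_mat the single point of variation between blocks: forward and forward_sequence are written against those accessors, and DiscreteDiagSSMBlock (introduced below) only overrides the accessors to apply discretization. A hypothetical third-party block, sketched against the four abstract methods above (the class name and initialization scheme are illustrative, not from the package):

    import jax
    import jax.numpy as jnp
    from jaxtyping import Array, PRNGKeyArray

    from xax.nn.ssm import BaseSSMBlock


    class IdentityDiagBlock(BaseSSMBlock):
        """Hypothetical block: fixed identity transition, learned input matrix."""

        b_mat: Array

        def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
            self.b_mat = jax.random.normal(key, (hidden_size, hidden_size)) * 0.02

        def get_a_mat(self, x: Array) -> Array:
            return jnp.ones(self.b_mat.shape[0])  # diagonal of the identity

        def get_b_mat(self, x: Array) -> Array:
            return self.b_mat

        def forward(self, h: Array, x: Array) -> Array:
            return self.get_a_mat(x) * h + self.get_b_mat(x).T @ x

        def forward_sequence(self, x_seq: Array) -> Array:
            def step(h: Array, x: Array) -> tuple[Array, Array]:
                h = self.forward(h, x)
                return h, h

            h_0 = jnp.zeros(self.b_mat.shape[0])
            _, h_seq = jax.lax.scan(step, h_0, x_seq)
            return h_seq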
@@ -158,80 +36,194 @@ class SSMBlock(BaseSSMBlock):
         self.a_mat = glorot(key_a, (hidden_size, hidden_size))
         self.b_mat = glorot(key_b, (hidden_size, hidden_size))
 
+    def get_a_mat(self, x: Array) -> Array:
+        return self.a_mat
+
+    def get_b_mat(self, x: Array) -> Array:
+        return self.b_mat
+
     def forward(self, h: Array, x: Array) -> Array:
-        h = self.a_mat @ h + self.b_mat.T @ x
+        """Perform a forward pass.
+
+        Args:
+            h: Hidden state of shape (H,).
+            x: Input of shape (H,).
+
+        Returns:
+            Hidden state of shape (H,).
+        """
+        a_mat = self.get_a_mat(x)
+        b_mat = self.get_b_mat(x)
+        h = a_mat @ h + b_mat.T @ x
         return h
 
-    def
-
+    def forward_sequence(self, x_seq: Array) -> Array:
+        """Perform a forward pass across time.
+
+        Args:
+            x_seq: Input sequence of shape (T, H).
+
+        Returns:
+            Hidden state sequence of shape (T, H).
+        """
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        a_mat = self.get_a_mat(x_seq)
+        h_0 = jnp.zeros(a_mat.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x_seq)
+        return h_seq
 
 
 class DiagSSMBlock(BaseSSMBlock):
-    a: Array
+    a_diag: Array
     b_mat: Array
 
     def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
         keys = jax.random.split(key, 2)
-        self.a = glorot(keys[0], (hidden_size,))
+        self.a_diag = glorot(keys[0], (hidden_size,))
         self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
 
+    def get_a_mat(self, x: Array) -> Array:
+        return self.a_diag
+
+    def get_b_mat(self, x: Array) -> Array:
+        return self.b_mat
+
     def forward(self, h: Array, x: Array) -> Array:
-
-
+        """Perform a forward pass.
+
+        Args:
+            h: Hidden state of shape (H,).
+            x: Input of shape (H,).
+
+        Returns:
+            Hidden state of shape (H,).
+        """
+        a_diag = self.get_a_mat(x)
+        b_mat = self.get_b_mat(x)
+        h = a_diag * h + b_mat.T @ x
         return h
 
-    def
+    def forward_sequence(self, x_seq: Array, *, use_conv: bool = True, recursive_kernel_calc: bool = False) -> Array:
+        """Perform a potentially parallelized forward pass across time.
+
+        Args:
+            x_seq: Input sequence of shape (T, H).
+            use_conv: Whether to use convolution to compute the sequence.
+            recursive_kernel_calc: Whether to use a recursive kernel calculation.
+
+        Returns:
+            Hidden state sequence of shape (T, H).
+        """
+        if use_conv:
+            return self._forward_sequence_conv(x_seq, recursive_kernel_calc=recursive_kernel_calc)
+        else:
+            return self._forward_sequence_scan(x_seq)
+
+    def _get_kernel(self, x_seq: Array, length: int) -> Array:
         """Returns the kernel with time as the final dimension."""
         exponents = jnp.arange(length)
-
-        kernel =
+        a_diag = self.get_a_mat(x_seq)
+        kernel = jnp.power(a_diag[:, None], exponents)  # (H, T)
+        kernel = kernel[:, None, :]  # (H, 1, T)
         return kernel
 
-    def
+    def _get_kernel_recursive(self, x_seq: Array, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        assert length % 2 == 0, "Length must be even."
+        a_diag = self.get_a_mat(x_seq)
+
+        def helper(length: int) -> tuple[Array, Array]:
+            """Returns the kernel and the sqrt of the diagonal."""
+            if length == 1:
+                return jnp.ones_like(a_diag)[:, None], a_diag[:, None]
+
+            half_length = length // 2
+            kernel_half, a_half = helper(half_length)
+            kernel = jnp.concatenate([kernel_half, a_half * kernel_half], axis=-1)
+            return kernel, a_half * a_half
+
+        kernel, a_diag = helper(length)
+        return kernel[:, None, :]  # (H, 1, L)
+
+    def _forward_sequence_conv(self, x_seq: Array, *, recursive_kernel_calc: bool = False) -> Array:
         """Convolves x (T, H) across time using the kernel."""
-
+        seq_len, hidden_size = x_seq.shape
+        b_mat = self.get_b_mat(x_seq)
 
-
-
-        s = s.T  # (H, T)
+        s = b_mat.T @ x_seq.T  # (H, T)
+        s_padded = jnp.pad(s, ((0, 0), (seq_len - 1, 0)))[None, :, :]  # (1, H, 2T-1)
 
-
-
+        if recursive_kernel_calc:
+            kernel = self._get_kernel_recursive(x_seq, seq_len)
+        else:
+            kernel = self._get_kernel(x_seq, seq_len)
 
-
-        s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
+        kernel_flipped = jnp.flip(kernel, axis=-1)  # (H, 1, L)
 
-        # Perform depthwise (grouped) 1D convolution.
-        # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
-        # The dimension_numbers are chosen so that the channel dimension is second.
         conv_out = jax.lax.conv_general_dilated(
             s_padded,
             kernel_flipped,
             window_strides=(1,),
             padding="VALID",
-            dimension_numbers=("
-            feature_group_count=
+            dimension_numbers=("NCT", "OIT", "NCT"),  # convolving over time
+            feature_group_count=hidden_size,
         )
-
-        conv_out = jnp.transpose(conv_out, (0, 2, 1))
+        conv_out = conv_out[0].T  # (T, H)
         return conv_out
 
-    def
+    def _forward_sequence_scan(self, x_seq: Array) -> Array:
         """Naively forward across time."""
 
         def step(h: Array, x: Array) -> tuple[Array, Array]:
             h = self.forward(h, x)
             return h, h
 
-
-
+        a_diag = self.get_a_mat(x_seq)
+        h_0 = jnp.zeros(a_diag.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x_seq)
         return h_seq
 
 
-class S4(eqx.Module):
+class DiscreteDiagSSMBlock(DiagSSMBlock):
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        *,
+        key: PRNGKeyArray,
+        init_delta: float = 1.0,
+        init_scale: float = 10.0,
+    ) -> None:
+        super().__init__(hidden_size, key=key)
+        self.delta = jnp.array(init_delta)
+
+        # A positive scale helps reduce the gradient at the start.
+        self.a_diag = jax.random.uniform(key, (hidden_size,), minval=-1.0, maxval=0.0) * init_scale
+
+    def get_a_mat(self, x: Array) -> Array:
+        """Discretize the diagonal matrix using zero-order hold."""
+        a_diag_discrete = jnp.exp(self.a_diag * self.delta)
+        return a_diag_discrete
+
+    def get_b_mat(self, x: Array) -> Array:
+        """Discretize the input matrix using zero-order hold."""
+        delta_a_diag = self.a_diag * self.delta
+        exp_a_diag = jnp.exp(delta_a_diag)
+        delta_a_inv = 1 / delta_a_diag
+        delta_b_mat = self.delta * self.b_mat
+
+        b_discrete = delta_a_inv * (exp_a_diag - 1) * delta_b_mat
+        return b_discrete
+
+
+class SSM(eqx.Module):
     vocab_embedding: eqx.nn.Embedding
-    proj_in: eqx.nn.Linear
-    proj_out: eqx.nn.Linear
+    output_layer: eqx.nn.Linear
     blocks: list[BaseSSMBlock]
     num_layers: int = eqx.static_field()
     hidden_size: int = eqx.static_field()
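DiscreteDiagSSMBlock is the standard zero-order-hold discretization of a diagonal continuous-time system h'(t) = a * h(t) + B.T @ x(t): the transition becomes a_bar = exp(a * delta) and the input matrix becomes b_bar = (a * delta)^-1 * (exp(a * delta) - 1) * (delta * B), which is exactly what the two accessors above compute elementwise. A small numeric restatement (the values are my own worked example, not from the diff):

    import jax.numpy as jnp

    a, delta = -0.5, 0.1                               # one diagonal entry and the step size
    a_bar = jnp.exp(a * delta)                         # get_a_mat: exp(a * delta), ~0.9512
    b_scale = (jnp.exp(a * delta) - 1) / (a * delta)   # per-entry scale on (delta * b_mat), ~0.9754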
@@ -243,24 +235,30 @@ class S4(eqx.Module):
         hidden_size: int,
         output_size: int,
         num_layers: int,
-        block_type: Literal["
+        block_type: Literal["diagonal", "full_rank"] = "full_rank",
         skip_connections: bool = False,
+        discretize: bool = False,
         *,
         key: PRNGKeyArray,
     ) -> None:
         vocab_key, s4_key = jax.random.split(key, 2)
         self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
-        self.
-        self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
+        self.output_layer = eqx.nn.Linear(hidden_size, output_size, key=key)
 
         block_keys = jax.random.split(s4_key, num_layers)
 
         def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
             match block_type:
-                case "
+                case "diagonal":
+                    return (
+                        DiscreteDiagSSMBlock(hidden_size, key=key, init_delta=0.1)
+                        if discretize
+                        else DiagSSMBlock(hidden_size, key=key)
+                    )
+                case "full_rank":
+                    if discretize:
+                        raise ValueError("Full rank blocks do not support discretization due to instability.")
                     return SSMBlock(hidden_size, key=key)
-                case "diag":
-                    return DiagSSMBlock(hidden_size, key=key)
                 case _:
                     raise ValueError(f"Unknown block type: {block_type}")
 
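A construction sketch for the new block selection (the sizes are hypothetical; the argument names are read off the __init__ diff above, with input_size being the vocabulary size fed to eqx.nn.Embedding):

    import jax
    from xax.nn.ssm import SSM

    model = SSM(
        input_size=256,       # vocabulary size for the embedding
        hidden_size=128,
        output_size=256,
        num_layers=2,
        block_type="diagonal",
        discretize=True,      # selects DiscreteDiagSSMBlock(init_delta=0.1) per the match arm above
        key=jax.random.PRNGKey(0),
    )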
@@ -276,21 +274,43 @@ class S4(eqx.Module):
             new_hs.append(h)
             xh = jax.nn.gelu(h)
             x = xh + x if self.skip_connections else xh
-        y = self.proj_out(x)
+        y = self.output_layer(x)
         return new_hs, y
 
     def _embed_input(self, x: Array) -> Array:
         """U is the input to the S4 cell."""
-        embedded = self.vocab_embedding(x)
-        return jax.nn.gelu(self.proj_in(embedded))
+        return self.vocab_embedding(x)
 
     def predict_sequence(self, x_seq: Array) -> Array:
         x_emb = jax.vmap(self._embed_input)(x_seq)
+        for block in self.blocks:
+            h = block.forward_sequence(x_emb)
+            # h = block.naive_forward_sequence(x_emb)
+            h = jax.nn.gelu(h)
+            x_emb = h + x_emb if self.skip_connections else h
+        y = jax.vmap(self.output_layer)(x_emb)
+        return y
+
+    def generate_sequence(self, prompt_seq: Array, max_len: int) -> Array:
         hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
+        prompt_seq_embedded = jax.vmap(self._embed_input)(prompt_seq)
 
-        def
+        def encode_step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
             hs, y = self(hs, x)
             return hs, y
 
-
-
+        def decode_step(
+            carry: tuple[list[Array], Array, PRNGKeyArray],
+            _: None,
+        ) -> tuple[tuple[list[Array], Array, PRNGKeyArray], Array]:
+            hs, last_token, rng = carry
+            token_embedded = self._embed_input(last_token)
+            hs, y = self(hs, token_embedded)
+            token = jax.random.categorical(rng, y)
+            rng = jax.random.split(rng)[0]
+            return (hs, token, rng), token
+
+        hs, _ = jax.lax.scan(encode_step, hs, prompt_seq_embedded)
+        _, sequence = jax.lax.scan(decode_step, (hs, prompt_seq[-1], jax.random.PRNGKey(0)), None, length=max_len)
+
+        return sequence
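Continuing the construction sketch above, generation encodes the prompt with one scan and then samples autoregressively with a second scan. Note that, per the diff, decode_step seeds sampling from a fixed jax.random.PRNGKey(0), so repeated calls produce identical tokens:

    import jax.numpy as jnp

    prompt = jnp.array([1, 2, 3])                        # token ids
    tokens = model.generate_sequence(prompt, max_len=16)
    print(tokens.shape)                                  # (16,) sampled token ids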
xax/task/mixins/train.py
CHANGED
@@ -218,26 +218,32 @@ class TrainMixin(
         state = super().on_step_end(state)
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
-    def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_train_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
         backward pass. It is called in the training phase.
 
         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
             state: The current training state.
         """
 
-    def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_valid_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
         the validation phase.
 
         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
@@ -251,7 +257,9 @@ class TrainMixin(
         for k, v in d.items():
             self.logger.log_scalar(k, v, namespace=ns)
 
-    def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         phase = state.phase
 
         for k, v in metrics.items():
@@ -265,9 +273,9 @@ class TrainMixin(
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(batch, output, metrics, state)
+                self.log_train_step(model, batch, output, metrics, state)
             case "valid":
-                self.log_valid_step(batch, output, metrics, state)
+                self.log_valid_step(model, batch, output, metrics, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
@@ -579,7 +587,7 @@ class TrainMixin(
         )
 
         output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
-        self.log_step(valid_batch, output, metrics, state)
+        self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
 
         state = self.on_step_start(state)
         train_batch = next(train_pf)
@@ -597,7 +605,7 @@ class TrainMixin(
             batch=train_batch,
             state=state,
         )
-        self.log_step(train_batch, output, metrics, state)
+        self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
 
         state = self.on_step_end(state)
 
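The logging hooks now receive the recombined model as their first argument, so overrides can inspect parameters directly instead of stashing the model elsewhere. A hypothetical override sketch (the class is a stand-in; in practice this method would live on a task that mixes in TrainMixin, and self.logger.log_scalar is the same logger call used above):

    import equinox as eqx
    import jax
    from jaxtyping import PyTree


    class MyTask:  # stand-in for a task class built on TrainMixin
        def log_valid_step(self, model: PyTree, batch, output, metrics, state) -> None:
            # Count parameters in the freshly combined model passed by log_step.
            leaves = jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array))
            self.logger.log_scalar("num_params", sum(int(p.size) for p in leaves), namespace="model")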
{xax-0.1.11.dist-info → xax-0.1.12.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=7vdTYO7jAJdDxKZURlFxc3Y5kr5mVQcTQjeh_sYjD6I,13834
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -10,11 +10,11 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
-xax/nn/geom.py,sha256=
+xax/nn/geom.py,sha256=Bj9Z4Y-uoNQuaA_eB_MyG7yImZLuOq8KCLUj1l3daoc,4545
 xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
 xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
-xax/nn/ssm.py,sha256=
+xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.11.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.1.11.dist-info/METADATA,sha256=
-xax-0.1.11.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-xax-0.1.11.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.1.11.dist-info/RECORD,,
+xax-0.1.12.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.12.dist-info/METADATA,sha256=hLRAX5__7QjBgjzhxbRftGvEsNrt8IAdgd22dMtHu_Y,1878
+xax-0.1.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.12.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.12.dist-info/RECORD,,
{xax-0.1.11.dist-info → xax-0.1.12.dist-info}/WHEEL
File without changes
{xax-0.1.11.dist-info → xax-0.1.12.dist-info}/licenses/LICENSE
File without changes
{xax-0.1.11.dist-info → xax-0.1.12.dist-info}/top_level.txt
File without changes