xax 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +16 -1
- xax/nn/geom.py +57 -0
- xax/nn/losses.py +9 -0
- xax/nn/norm.py +2 -1
- xax/nn/ssm.py +316 -0
- xax/task/mixins/train.py +56 -15
- {xax-0.1.10.dist-info → xax-0.1.12.dist-info}/METADATA +1 -1
- {xax-0.1.10.dist-info → xax-0.1.12.dist-info}/RECORD +11 -9
- {xax-0.1.10.dist-info → xax-0.1.12.dist-info}/WHEEL +0 -0
- {xax-0.1.10.dist-info → xax-0.1.12.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.10.dist-info → xax-0.1.12.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.10"
+__version__ = "0.1.12"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,9 +43,15 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "rotate_vector_by_quat",
+    "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "BaseSSMBlock",
+    "DiagSSMBlock",
+    "SSM",
+    "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
     "SingleProcessLauncher",
@@ -196,9 +202,15 @@ NAME_MAP: dict[str, str] = {
    "euler_to_quat": "nn.geom",
    "get_projected_gravity_vector_from_quat": "nn.geom",
    "quat_to_euler": "nn.geom",
+    "rotate_vector_by_quat": "nn.geom",
+    "cross_entropy": "nn.losses",
    "cast_norm_type": "nn.norm",
    "get_norm": "nn.norm",
    "is_master": "nn.parallel",
+    "BaseSSMBlock": "nn.ssm",
+    "DiagSSMBlock": "nn.ssm",
+    "SSM": "nn.ssm",
+    "SSMBlock": "nn.ssm",
    "BaseLauncher": "task.launchers.base",
    "CliLauncher": "task.launchers.cli",
    "SingleProcessLauncher": "task.launchers.single_process",
@@ -350,9 +362,12 @@ if IMPORT_ALL or TYPE_CHECKING:
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
+        rotate_vector_by_quat,
     )
+    from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
+    from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
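
The practical effect of these changes is that the new 0.1.12 symbols are registered in `__all__` and `NAME_MAP`, so they resolve from the package root as well as from their submodules. A minimal import sketch, not part of the diff, assuming xax 0.1.12 is installed and that the root module lazily resolves names through `NAME_MAP` as the hunks above suggest:

```python
# Illustrative only: importing the symbols added in xax 0.1.12.
import xax

print(xax.__version__)  # "0.1.12"

# Resolved at the package root (lazily, via NAME_MAP / IMPORT_ALL):
rotate = xax.rotate_vector_by_quat

# Or imported directly from the submodules they map to:
from xax.nn.losses import cross_entropy
from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
```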
xax/nn/geom.py
CHANGED
@@ -99,3 +99,60 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
 
     # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
     return jnp.concatenate([gx, gy, -gz], axis=-1)
+
+
+def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
+    """Rotates a vector by a quaternion.
+
+    Args:
+        vector: The vector to rotate, shape (*, 3).
+        quat: The quaternion to rotate by, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The rotated vector, shape (*, 3).
+    """
+    # Normalize quaternion
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    # Extract vector components
+    vx, vy, vz = jnp.split(vector, 3, axis=-1)
+
+    # Terms for x component
+    xx = (
+        w * w * vx
+        + 2 * y * w * vz
+        - 2 * z * w * vy
+        + x * x * vx
+        + 2 * y * x * vy
+        + 2 * z * x * vz
+        - z * z * vx
+        - y * y * vx
+    )
+
+    # Terms for y component
+    yy = (
+        2 * x * y * vx
+        + y * y * vy
+        + 2 * z * y * vz
+        + 2 * w * z * vx
+        - z * z * vy
+        + w * w * vy
+        - 2 * w * x * vz
+        - x * x * vy
+    )
+
+    # Terms for z component
+    zz = (
+        2 * x * z * vx
+        + 2 * y * z * vy
+        + z * z * vz
+        - 2 * w * y * vx
+        + w * w * vz
+        + 2 * w * x * vy
+        - y * y * vz
+        - x * x * vz
+    )
+
+    return jnp.concatenate([xx, yy, zz], axis=-1)
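
The new `rotate_vector_by_quat` is the component-wise expansion of the active rotation v' = q v q⁻¹ for quaternions stored as (w, x, y, z); the quaternion is normalized first, so unnormalized inputs are accepted. A usage sketch, not part of the diff, rotating the x-axis 90° about z:

```python
# Illustrative only: rotating a vector with the new helper.
import jax.numpy as jnp
from xax.nn.geom import rotate_vector_by_quat

half_angle = jnp.pi / 4  # quaternion half-angle for a 90 degree rotation about z
quat = jnp.array([jnp.cos(half_angle), 0.0, 0.0, jnp.sin(half_angle)])  # (w, x, y, z)
vec = jnp.array([1.0, 0.0, 0.0])

print(rotate_vector_by_quat(vec, quat))  # approximately [0., 1., 0.]
```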
xax/nn/losses.py
ADDED
@@ -0,0 +1,9 @@
+"""Defines some common loss functions."""
+
+import jax.numpy as jnp
+from jaxtyping import Array
+
+
+def cross_entropy(y: Array, pred_y: Array, axis: int = 1) -> Array:
+    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, axis), axis=axis)
+    return -jnp.mean(pred_y)
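
Note that `cross_entropy` only gathers the entries of `pred_y` at the label indices and negates their mean; it does not apply a log-softmax itself, so callers are expected to pass log-probabilities. A usage sketch, not part of the diff:

```python
# Illustrative only: cross_entropy expects integer labels and log-probabilities.
import jax
import jax.numpy as jnp
from xax.nn.losses import cross_entropy

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # (batch, classes)
labels = jnp.array([0, 2])                                # integer class indices

log_probs = jax.nn.log_softmax(logits, axis=1)
loss = cross_entropy(labels, log_probs, axis=1)  # mean negative log-likelihood
```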
xax/nn/norm.py
CHANGED
@@ -3,6 +3,7 @@
 from typing import Literal, cast, get_args
 
 import jax.numpy as jnp
+from jaxtyping import Array
 
 NormType = Literal["l1", "l2"]
 
@@ -13,7 +14,7 @@ def cast_norm_type(norm: str) -> NormType:
     return cast(NormType, norm)
 
 
-def get_norm(x:
+def get_norm(x: Array, norm: NormType) -> Array:
     match norm:
         case "l1":
             return jnp.abs(x)
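
The only functional change here is the fully typed `get_norm` signature. A small sketch, not part of the diff (only the "l1" branch is visible in the hunk, so only that case is exercised):

```python
# Illustrative only: element-wise norm helper with the updated signature.
import jax.numpy as jnp
from xax.nn.norm import cast_norm_type, get_norm

x = jnp.array([-2.0, 3.0])
print(get_norm(x, cast_norm_type("l1")))  # element-wise absolute value -> [2., 3.]
```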
xax/nn/ssm.py
ADDED
@@ -0,0 +1,316 @@
+"""State space models."""
+
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
+    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
+
+
+class BaseSSMBlock(eqx.Module, ABC):
+    @abstractmethod
+    def forward(self, h: Array, x: Array) -> Array: ...
+
+    @abstractmethod
+    def forward_sequence(self, x_seq: Array) -> Array: ...
+
+    @abstractmethod
+    def get_a_mat(self, x: Array) -> Array: ...
+
+    @abstractmethod
+    def get_b_mat(self, x: Array) -> Array: ...
+
+
+class SSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        key_a, key_b = jax.random.split(key)
+        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
+        self.b_mat = glorot(key_b, (hidden_size, hidden_size))
+
+    def get_a_mat(self, x: Array) -> Array:
+        return self.a_mat
+
+    def get_b_mat(self, x: Array) -> Array:
+        return self.b_mat
+
+    def forward(self, h: Array, x: Array) -> Array:
+        """Perform a forward pass.
+
+        Args:
+            h: Hidden state of shape (H,).
+            x: Input of shape (H,).
+
+        Returns:
+            Hidden state of shape (H,).
+        """
+        a_mat = self.get_a_mat(x)
+        b_mat = self.get_b_mat(x)
+        h = a_mat @ h + b_mat.T @ x
+        return h
+
+    def forward_sequence(self, x_seq: Array) -> Array:
+        """Perform a forward pass across time.
+
+        Args:
+            x_seq: Input sequence of shape (T, H).
+
+        Returns:
+            Hidden state sequence of shape (T, H).
+        """
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        a_mat = self.get_a_mat(x_seq)
+        h_0 = jnp.zeros(a_mat.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x_seq)
+        return h_seq
+
+
+class DiagSSMBlock(BaseSSMBlock):
+    a_diag: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        keys = jax.random.split(key, 2)
+        self.a_diag = glorot(keys[0], (hidden_size,))
+        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
+
+    def get_a_mat(self, x: Array) -> Array:
+        return self.a_diag
+
+    def get_b_mat(self, x: Array) -> Array:
+        return self.b_mat
+
+    def forward(self, h: Array, x: Array) -> Array:
+        """Perform a forward pass.
+
+        Args:
+            h: Hidden state of shape (H,).
+            x: Input of shape (H,).
+
+        Returns:
+            Hidden state of shape (H,).
+        """
+        a_diag = self.get_a_mat(x)
+        b_mat = self.get_b_mat(x)
+        h = a_diag * h + b_mat.T @ x
+        return h
+
+    def forward_sequence(self, x_seq: Array, *, use_conv: bool = True, recursive_kernel_calc: bool = False) -> Array:
+        """Perform a potentially parallelized forward pass across time.
+
+        Args:
+            x_seq: Input sequence of shape (T, H).
+            use_conv: Whether to use convolution to compute the sequence.
+            recursive_kernel_calc: Whether to use a recursive kernel calculation.
+
+        Returns:
+            Hidden state sequence of shape (T, H).
+        """
+        if use_conv:
+            return self._forward_sequence_conv(x_seq, recursive_kernel_calc=recursive_kernel_calc)
+        else:
+            return self._forward_sequence_scan(x_seq)
+
+    def _get_kernel(self, x_seq: Array, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        exponents = jnp.arange(length)
+        a_diag = self.get_a_mat(x_seq)
+        kernel = jnp.power(a_diag[:, None], exponents)  # (H, T)
+        kernel = kernel[:, None, :]  # (H, 1, T)
+        return kernel
+
+    def _get_kernel_recursive(self, x_seq: Array, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        assert length % 2 == 0, "Length must be even."
+        a_diag = self.get_a_mat(x_seq)
+
+        def helper(length: int) -> tuple[Array, Array]:
+            """Returns the kernel and the sqrt of the diagonal."""
+            if length == 1:
+                return jnp.ones_like(a_diag)[:, None], a_diag[:, None]
+
+            half_length = length // 2
+            kernel_half, a_half = helper(half_length)
+            kernel = jnp.concatenate([kernel_half, a_half * kernel_half], axis=-1)
+            return kernel, a_half * a_half
+
+        kernel, a_diag = helper(length)
+        return kernel[:, None, :]  # (H, 1, L)
+
+    def _forward_sequence_conv(self, x_seq: Array, *, recursive_kernel_calc: bool = False) -> Array:
+        """Convolves x (T, H) across time using the kernel."""
+        seq_len, hidden_size = x_seq.shape
+        b_mat = self.get_b_mat(x_seq)
+
+        s = b_mat.T @ x_seq.T  # (H, T)
+        s_padded = jnp.pad(s, ((0, 0), (seq_len - 1, 0)))[None, :, :]  # (1, H, 2T-1)
+
+        if recursive_kernel_calc:
+            kernel = self._get_kernel_recursive(x_seq, seq_len)
+        else:
+            kernel = self._get_kernel(x_seq, seq_len)
+
+        kernel_flipped = jnp.flip(kernel, axis=-1)  # (H, 1, L)
+
+        conv_out = jax.lax.conv_general_dilated(
+            s_padded,
+            kernel_flipped,
+            window_strides=(1,),
+            padding="VALID",
+            dimension_numbers=("NCT", "OIT", "NCT"),  # convolving over time
+            feature_group_count=hidden_size,
+        )
+        conv_out = conv_out[0].T  # (T, H)
+        return conv_out
+
+    def _forward_sequence_scan(self, x_seq: Array) -> Array:
+        """Naively forward across time."""
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        a_diag = self.get_a_mat(x_seq)
+        h_0 = jnp.zeros(a_diag.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x_seq)
+        return h_seq
+
+
+class DiscreteDiagSSMBlock(DiagSSMBlock):
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        *,
+        key: PRNGKeyArray,
+        init_delta: float = 1.0,
+        init_scale: float = 10.0,
+    ) -> None:
+        super().__init__(hidden_size, key=key)
+        self.delta = jnp.array(init_delta)
+
+        # A positive scale helps reduce the gradient at the start.
+        self.a_diag = jax.random.uniform(key, (hidden_size,), minval=-1.0, maxval=0.0) * init_scale
+
+    def get_a_mat(self, x: Array) -> Array:
+        """Discretize the diagonal matrix using zero-order hold."""
+        a_diag_discrete = jnp.exp(self.a_diag * self.delta)
+        return a_diag_discrete
+
+    def get_b_mat(self, x: Array) -> Array:
+        """Discretize the input matrix using zero-order hold."""
+        delta_a_diag = self.a_diag * self.delta
+        exp_a_diag = jnp.exp(delta_a_diag)
+        delta_a_inv = 1 / delta_a_diag
+        delta_b_mat = self.delta * self.b_mat
+
+        b_discrete = delta_a_inv * (exp_a_diag - 1) * delta_b_mat
+        return b_discrete
+
+
+class SSM(eqx.Module):
+    vocab_embedding: eqx.nn.Embedding
+    output_layer: eqx.nn.Linear
+    blocks: list[BaseSSMBlock]
+    num_layers: int = eqx.static_field()
+    hidden_size: int = eqx.static_field()
+    skip_connections: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        num_layers: int,
+        block_type: Literal["diagonal", "full_rank"] = "full_rank",
+        skip_connections: bool = False,
+        discretize: bool = False,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        vocab_key, s4_key = jax.random.split(key, 2)
+        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
+        self.output_layer = eqx.nn.Linear(hidden_size, output_size, key=key)
+
+        block_keys = jax.random.split(s4_key, num_layers)
+
+        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
+            match block_type:
+                case "diagonal":
+                    return (
+                        DiscreteDiagSSMBlock(hidden_size, key=key, init_delta=0.1)
+                        if discretize
+                        else DiagSSMBlock(hidden_size, key=key)
+                    )
+                case "full_rank":
+                    if discretize:
+                        raise ValueError("Full rank blocks do not support discretization due to instability.")
+                    return SSMBlock(hidden_size, key=key)
+                case _:
+                    raise ValueError(f"Unknown block type: {block_type}")
+
+        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
+        self.skip_connections = skip_connections
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+
+    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+        new_hs = []
+        for i, block in enumerate(self.blocks):
+            h = block.forward(hs[i], x)
+            new_hs.append(h)
+            xh = jax.nn.gelu(h)
+            x = xh + x if self.skip_connections else xh
+        y = self.output_layer(x)
+        return new_hs, y
+
+    def _embed_input(self, x: Array) -> Array:
+        """U is the input to the S4 cell."""
+        return self.vocab_embedding(x)
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_emb = jax.vmap(self._embed_input)(x_seq)
+        for block in self.blocks:
+            h = block.forward_sequence(x_emb)
+            # h = block.naive_forward_sequence(x_emb)
+            h = jax.nn.gelu(h)
+            x_emb = h + x_emb if self.skip_connections else h
+        y = jax.vmap(self.output_layer)(x_emb)
+        return y
+
+    def generate_sequence(self, prompt_seq: Array, max_len: int) -> Array:
+        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
+        prompt_seq_embedded = jax.vmap(self._embed_input)(prompt_seq)
+
+        def encode_step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+            hs, y = self(hs, x)
+            return hs, y
+
+        def decode_step(
+            carry: tuple[list[Array], Array, PRNGKeyArray],
+            _: None,
+        ) -> tuple[tuple[list[Array], Array, PRNGKeyArray], Array]:
+            hs, last_token, rng = carry
+            token_embedded = self._embed_input(last_token)
+            hs, y = self(hs, token_embedded)
+            token = jax.random.categorical(rng, y)
+            rng = jax.random.split(rng)[0]
+            return (hs, token, rng), token
+
+        hs, _ = jax.lax.scan(encode_step, hs, prompt_seq_embedded)
+        _, sequence = jax.lax.scan(decode_step, (hs, prompt_seq[-1], jax.random.PRNGKey(0)), None, length=max_len)
+
+        return sequence
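
The new module stacks `SSMBlock` / `DiagSSMBlock` layers behind an embedding and a linear readout: `predict_sequence` runs each block over the whole sequence (by convolution or scan), while `generate_sequence` encodes a prompt and then samples autoregressively one token at a time. A toy usage sketch, not part of the diff; the sizes and token ids below are made up:

```python
# Illustrative only: building and running the new SSM on toy token ids.
import jax
import jax.numpy as jnp
from xax.nn.ssm import SSM

model = SSM(
    input_size=32,        # vocabulary size
    hidden_size=16,
    output_size=32,
    num_layers=2,
    block_type="diagonal",
    discretize=True,      # uses the zero-order-hold DiscreteDiagSSMBlock
    key=jax.random.PRNGKey(0),
)

tokens = jnp.array([1, 4, 9, 2])                      # (T,) integer token ids
logits = model.predict_sequence(tokens)               # (T, 32), full-sequence pass
sampled = model.generate_sequence(tokens, max_len=8)  # (8,) sampled continuation
```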
xax/task/mixins/train.py
CHANGED
@@ -29,6 +29,7 @@ from typing import (
 
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
@@ -162,6 +163,7 @@ class TrainConfig(
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
     random_seed: int = field(1337, help="Random seed for the task")
+    global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
 
 
 Config = TypeVar("Config", bound=TrainConfig)
@@ -216,26 +218,32 @@ class TrainMixin(
         state = super().on_step_end(state)
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
-    def log_train_step(
+    def log_train_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
         backward pass. It is called in the training phase.
 
         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
             state: The current training state.
         """
 
-    def log_valid_step(
+    def log_valid_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
         the validation phase.
 
         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
@@ -249,7 +257,9 @@ class TrainMixin(
         for k, v in d.items():
             self.logger.log_scalar(k, v, namespace=ns)
 
-    def log_step(
+    def log_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         phase = state.phase
 
         for k, v in metrics.items():
@@ -263,9 +273,9 @@ class TrainMixin(
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(batch, output, metrics, state)
+                self.log_train_step(model, batch, output, metrics, state)
             case "valid":
-                self.log_valid_step(batch, output, metrics, state)
+                self.log_valid_step(model, batch, output, metrics, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
@@ -403,12 +413,12 @@ class TrainMixin(
         model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, tuple[Output,
+    ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
         model = eqx.combine(model_arr, model_static)
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
         metrics = self.compute_metrics(model, batch, output, loss, state)
-        return loss, (output,
+        return loss, (output, metrics)
 
     def update(
         self,
@@ -418,13 +428,44 @@ class TrainMixin(
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[PyTree, optax.OptState, Output,
+    ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
         grad_fn = xax_jit(static_argnums=[1])(grad_fn)
         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
-
-        model_arr
-
+        model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
+        return model_arr, opt_state, output, metrics | grad_metrics
+
+    @xax_jit(static_argnames=["self", "optimizer"])
+    def apply_gradients_with_clipping(
+        self,
+        model_arr: PyTree,
+        grads: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
+        grad_norm = optax.global_norm(grads)
+        grad_metrics = {"grad_norm": grad_norm}
+
+        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
+            # Clip the global gradient norm to some desired range.
+            grad_factor = self.config.global_grad_clip / jnp.maximum(grad_norm, 1e-6)
+            grads = jax.tree.map(lambda x: x * grad_factor, grads)
+
+            # Apply the gradient updates.
+            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
+            new_model_arr = eqx.apply_updates(model_arr, updates)
+            return new_model_arr, new_opt_state
+
+        # Don't apply updates if the gradient is NaN or Inf.
+        new_model_arr, new_opt_state = jax.lax.cond(
+            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
+            lambda *_: (model_arr, opt_state),
+            apply,
+            grads,
+            grad_norm,
+        )
+
+        return new_model_arr, new_opt_state, grad_metrics
 
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
@@ -512,7 +553,7 @@ class TrainMixin(
         state: State,
     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
         model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
-        return model_arr, opt_state, output, metrics
+        return model_arr, opt_state, output, FrozenDict(metrics)
 
     @xax_jit(static_argnames=["self", "model_static"])
     def val_step(
@@ -523,7 +564,7 @@ class TrainMixin(
         state: State,
     ) -> tuple[Output, FrozenDict[str, Array]]:
         _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
-        return output, metrics
+        return output, FrozenDict(metrics)
 
     def train_loop(
         self,
@@ -546,7 +587,7 @@ class TrainMixin(
             )
 
             output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
-            self.log_step(valid_batch, output, metrics, state)
+            self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
 
         state = self.on_step_start(state)
         train_batch = next(train_pf)
@@ -564,7 +605,7 @@ class TrainMixin(
                 batch=train_batch,
                 state=state,
             )
-            self.log_step(train_batch, output, metrics, state)
+            self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
 
         state = self.on_step_end(state)
 
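
The training-loop changes thread the combined model into the logging hooks and route every update through `apply_gradients_with_clipping`, which rescales gradients to the configured global norm and skips the step entirely when that norm is NaN or Inf. A standalone sketch of the same clipping logic, not part of the diff, with hypothetical `params`/`grads`/`optimizer` names and `optax.apply_updates` standing in for `eqx.apply_updates`:

```python
# Illustrative only: global-norm clipping with a NaN/Inf guard, mirroring the hunk above.
import jax
import jax.numpy as jnp
import optax

global_grad_clip = 10.0  # corresponds to the new TrainConfig.global_grad_clip field

def clipped_update(params, grads, optimizer, opt_state):
    grad_norm = optax.global_norm(grads)

    def apply(grads, grad_norm):
        # Rescale so the global gradient norm equals global_grad_clip
        # (like the code above, this also scales small gradients up).
        factor = global_grad_clip / jnp.maximum(grad_norm, 1e-6)
        grads = jax.tree_util.tree_map(lambda g: g * factor, grads)
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), new_opt_state

    # Skip the update entirely when the gradient norm is NaN or Inf.
    return jax.lax.cond(
        jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
        lambda *_: (params, opt_state),
        apply,
        grads,
        grad_norm,
    )

# Example wiring:
# optimizer = optax.adam(1e-3)
# params, opt_state = clipped_update(params, grads, optimizer, optimizer.init(params))
```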
{xax-0.1.10.dist-info → xax-0.1.12.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=7vdTYO7jAJdDxKZURlFxc3Y5kr5mVQcTQjeh_sYjD6I,13834
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -10,9 +10,11 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
-xax/nn/geom.py,sha256=
-xax/nn/
+xax/nn/geom.py,sha256=Bj9Z4Y-uoNQuaA_eB_MyG7yImZLuOq8KCLUj1l3daoc,4545
+xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
+xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
+xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
@@ -39,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -56,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.
-xax-0.1.
-xax-0.1.
-xax-0.1.
-xax-0.1.
+xax-0.1.12.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.12.dist-info/METADATA,sha256=hLRAX5__7QjBgjzhxbRftGvEsNrt8IAdgd22dMtHu_Y,1878
+xax-0.1.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.12.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.12.dist-info/RECORD,,
{xax-0.1.10.dist-info → xax-0.1.12.dist-info}/WHEEL
File without changes
{xax-0.1.10.dist-info → xax-0.1.12.dist-info}/licenses/LICENSE
File without changes
{xax-0.1.10.dist-info → xax-0.1.12.dist-info}/top_level.txt
File without changes