xax 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +17 -1
- xax/nn/losses.py +9 -0
- xax/nn/norm.py +2 -1
- xax/nn/ssm.py +296 -0
- xax/task/mixins/train.py +41 -8
- {xax-0.1.10.dist-info → xax-0.1.11.dist-info}/METADATA +1 -1
- {xax-0.1.10.dist-info → xax-0.1.11.dist-info}/RECORD +10 -8
- {xax-0.1.10.dist-info → xax-0.1.11.dist-info}/WHEEL +0 -0
- {xax-0.1.10.dist-info → xax-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.10.dist-info → xax-0.1.11.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """

-__version__ = "0.1.10"
+__version__ = "0.1.11"

 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,9 +43,16 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "DiagSSMBlock",
+    "DiscreteTimeS4",
+    "S4",
+    "S4Layer",
+    "S6Layer",
+    "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
     "SingleProcessLauncher",
@@ -196,9 +203,16 @@ NAME_MAP: dict[str, str] = {
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "cross_entropy": "nn.losses",
     "cast_norm_type": "nn.norm",
     "get_norm": "nn.norm",
     "is_master": "nn.parallel",
+    "DiagSSMBlock": "nn.ssm",
+    "DiscreteTimeS4": "nn.ssm",
+    "S4": "nn.ssm",
+    "S4Layer": "nn.ssm",
+    "S6Layer": "nn.ssm",
+    "SSMBlock": "nn.ssm",
     "BaseLauncher": "task.launchers.base",
     "CliLauncher": "task.launchers.cli",
     "SingleProcessLauncher": "task.launchers.single_process",
@@ -351,8 +365,10 @@ if IMPORT_ALL or TYPE_CHECKING:
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
     )
+    from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
+    from xax.nn.ssm import S4, DiagSSMBlock, DiscreteTimeS4, S4Layer, S6Layer, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
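The new symbols follow the package's existing re-export pattern: each name is added to __all__ and mapped to its submodule in NAME_MAP, with the real imports guarded by IMPORT_ALL or TYPE_CHECKING. A minimal sketch of how a downstream user would reach them, assuming the top-level lazy lookup resolves names through NAME_MAP as it does for the pre-existing entries:

    import xax

    # Resolved through NAME_MAP to xax.nn.losses and xax.nn.ssm respectively
    # (assumes the package's lazy attribute lookup; not shown in this diff).
    loss_fn = xax.cross_entropy
    ssm_cls = xax.S4

    # Equivalent direct imports, which bypass the top-level lookup entirely.
    from xax.nn.losses import cross_entropy
    from xax.nn.ssm import S4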
xax/nn/losses.py
ADDED
@@ -0,0 +1,9 @@
+"""Defines some common loss functions."""
+
+import jax.numpy as jnp
+from jaxtyping import Array
+
+
+def cross_entropy(y: Array, pred_y: Array, axis: int = 1) -> Array:
+    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, axis), axis=axis)
+    return -jnp.mean(pred_y)
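Judging from the implementation above, cross_entropy expects integer class labels in y and log-probabilities in pred_y, and returns the mean negative log-likelihood. A minimal usage sketch under that reading (the shapes here are illustrative, not part of the package):

    import jax
    import jax.numpy as jnp
    from xax.nn.losses import cross_entropy

    logits = jnp.zeros((8, 10))                      # batch of 8, 10 classes
    log_probs = jax.nn.log_softmax(logits, axis=1)   # convert logits to log-probabilities
    labels = jnp.zeros((8,), dtype=jnp.int32)        # integer class labels
    loss = cross_entropy(labels, log_probs, axis=1)  # scalar mean negative log-likelihood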
xax/nn/norm.py
CHANGED
@@ -3,6 +3,7 @@
 from typing import Literal, cast, get_args

 import jax.numpy as jnp
+from jaxtyping import Array

 NormType = Literal["l1", "l2"]

@@ -13,7 +14,7 @@ def cast_norm_type(norm: str) -> NormType:
     return cast(NormType, norm)


-def get_norm(x:
+def get_norm(x: Array, norm: NormType) -> Array:
     match norm:
         case "l1":
             return jnp.abs(x)
xax/nn/ssm.py
ADDED
@@ -0,0 +1,296 @@
+"""State space models."""
+
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
+    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
+
+
+class DiscreteTimeS4(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        h = self.a * h + self.B.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class S4Layer(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+        self.delta = jax.random.uniform(key, (hidden_size,))
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        delta_a = self.delta * self.a
+        a_bar = jnp.exp(delta_a)
+        b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
+        h = a_bar * h + b_bar.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class S6Layer(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+        self.delta = jax.random.uniform(key, (hidden_size,))
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        h = self.a * h + self.B.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class BaseSSMBlock(eqx.Module, ABC):
+    @abstractmethod
+    def forward(self, h: Array, x: Array) -> Array:
+        pass
+
+
+class SSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        key_a, key_b = jax.random.split(key)
+        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
+        self.b_mat = glorot(key_b, (hidden_size, hidden_size))
+
+    def forward(self, h: Array, x: Array) -> Array:
+        h = self.a_mat @ h + self.b_mat.T @ x
+        return h
+
+    def get_kernel(self, length: int) -> Array:
+        return self.a_mat
+
+
+class DiagSSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        keys = jax.random.split(key, 2)
+        self.a_mat = glorot(keys[0], (hidden_size,))
+        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
+
+    def forward(self, h: Array, x: Array) -> Array:
+        h = self.a_mat * h + self.b_mat.T @ x
+        h = jax.nn.tanh(h)
+        return h
+
+    def get_kernel(self, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        exponents = jnp.arange(length)
+        kernel = jnp.power(self.a_mat[:, None], exponents)  # (H, L)
+        kernel = kernel[:, None, :]  # (H, 1, L)
+        return kernel
+
+    def forward_across_time(self, x: Array) -> Array:
+        """Convolves x (T, H) across time using the kernel."""
+        tsz, nhid = x.shape
+
+        # Compute s = x @ U.T + b, with shape (N, T, H)
+        s = self.b_mat.T @ x
+        s = s.T  # (H, T)
+
+        kernel = self.get_kernel(tsz)  # (H, 1, T)
+        kernel_flipped = jnp.flip(kernel, axis=-1)
+
+        # Pad s on the left along the time axis (pad length T-1)
+        s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
+
+        # Perform depthwise (grouped) 1D convolution.
+        # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
+        # The dimension_numbers are chosen so that the channel dimension is second.
+        conv_out = jax.lax.conv_general_dilated(
+            s_padded,
+            kernel_flipped,
+            window_strides=(1,),
+            padding="VALID",
+            dimension_numbers=("NCH", "OIH", "NCH"),
+            feature_group_count=nhid,
+        )
+        # conv_out has shape (N, H, T); transpose to (N, T, H)
+        conv_out = jnp.transpose(conv_out, (0, 2, 1))
+        return conv_out
+
+    def naive_forward_accross_time(self, x: Array) -> Array:
+        """Naively forward across time."""
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        h_0 = jnp.zeros(self.a_mat.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x)
+        return h_seq
+
+
+class S4(eqx.Module):
+    vocab_embedding: eqx.nn.Embedding
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    blocks: list[BaseSSMBlock]
+    num_layers: int = eqx.static_field()
+    hidden_size: int = eqx.static_field()
+    skip_connections: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        num_layers: int,
+        block_type: Literal["ssm", "diag"] = "ssm",
+        skip_connections: bool = False,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        vocab_key, s4_key = jax.random.split(key, 2)
+        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
+        self.proj_in = eqx.nn.Linear(hidden_size, hidden_size, key=key)
+        self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
+
+        block_keys = jax.random.split(s4_key, num_layers)
+
+        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
+            match block_type:
+                case "ssm":
+                    return SSMBlock(hidden_size, key=key)
+                case "diag":
+                    return DiagSSMBlock(hidden_size, key=key)
+                case _:
+                    raise ValueError(f"Unknown block type: {block_type}")
+
+        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
+        self.skip_connections = skip_connections
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+
+    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+        new_hs = []
+        for i, block in enumerate(self.blocks):
+            h = block.forward(hs[i], x)
+            new_hs.append(h)
+            xh = jax.nn.gelu(h)
+            x = xh + x if self.skip_connections else xh
+        y = self.proj_out(x)
+        return new_hs, y
+
+    def _embed_input(self, x: Array) -> Array:
+        """U is the input to the S4 cell."""
+        embedded = self.vocab_embedding(x)
+        return jax.nn.gelu(self.proj_in(embedded))
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_emb = jax.vmap(self._embed_input)(x_seq)
+        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
+
+        def step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+            hs, y = self(hs, x)
+            return hs, y
+
+        _, y_seq = jax.lax.scan(step, hs, x_emb)
+        return y_seq
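A minimal sketch of driving the new S4 wrapper over a token sequence, based only on the signatures above; the sizes and the block_type choice are illustrative, not a documented configuration:

    import jax
    import jax.numpy as jnp
    from xax.nn.ssm import S4

    key = jax.random.PRNGKey(0)
    model = S4(
        input_size=32,      # vocabulary size for the embedding
        hidden_size=64,
        output_size=32,
        num_layers=2,
        block_type="diag",  # or "ssm" for the dense block
        key=key,
    )

    tokens = jnp.zeros((16,), dtype=jnp.int32)  # a length-16 token sequence
    outputs = model.predict_sequence(tokens)    # shape (16, 32): one projection per step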
xax/task/mixins/train.py
CHANGED
@@ -29,6 +29,7 @@ from typing import (

 import equinox as eqx
 import jax
+import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
@@ -162,6 +163,7 @@ class TrainConfig(
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
     random_seed: int = field(1337, help="Random seed for the task")
+    global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")


 Config = TypeVar("Config", bound=TrainConfig)
@@ -403,12 +405,12 @@ class TrainMixin(
         model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, tuple[Output,
+    ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
         model = eqx.combine(model_arr, model_static)
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
         metrics = self.compute_metrics(model, batch, output, loss, state)
-        return loss, (output,
+        return loss, (output, metrics)

     def update(
         self,
@@ -418,13 +420,44 @@ class TrainMixin(
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[PyTree, optax.OptState, Output,
+    ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
         grad_fn = xax_jit(static_argnums=[1])(grad_fn)
         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
-
-        model_arr
-
+        model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
+        return model_arr, opt_state, output, metrics | grad_metrics
+
+    @xax_jit(static_argnames=["self", "optimizer"])
+    def apply_gradients_with_clipping(
+        self,
+        model_arr: PyTree,
+        grads: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
+        grad_norm = optax.global_norm(grads)
+        grad_metrics = {"grad_norm": grad_norm}
+
+        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
+            # Clip the global gradient norm to some desired range.
+            grad_factor = self.config.global_grad_clip / jnp.maximum(grad_norm, 1e-6)
+            grads = jax.tree.map(lambda x: x * grad_factor, grads)
+
+            # Apply the gradient updates.
+            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
+            new_model_arr = eqx.apply_updates(model_arr, updates)
+            return new_model_arr, new_opt_state
+
+        # Don't apply updates if the gradient is NaN or Inf.
+        new_model_arr, new_opt_state = jax.lax.cond(
+            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
+            lambda *_: (model_arr, opt_state),
+            apply,
+            grads,
+            grad_norm,
+        )
+
+        return new_model_arr, new_opt_state, grad_metrics

     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
@@ -512,7 +545,7 @@ class TrainMixin(
         state: State,
     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
         model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
-        return model_arr, opt_state, output, metrics
+        return model_arr, opt_state, output, FrozenDict(metrics)

     @xax_jit(static_argnames=["self", "model_static"])
     def val_step(
@@ -523,7 +556,7 @@ class TrainMixin(
         state: State,
     ) -> tuple[Output, FrozenDict[str, Array]]:
         _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
-        return output, metrics
+        return output, FrozenDict(metrics)

     def train_loop(
         self,
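For context, the new apply_gradients_with_clipping rescales the gradients by global_grad_clip / max(grad_norm, 1e-6), logs the norm as grad_norm, and skips the update entirely when the norm is NaN or Inf. A standalone sketch of just the rescaling rule, with illustrative names outside the mixin (not the library's API):

    import jax
    import jax.numpy as jnp
    import optax

    def rescale_by_global_norm(grads, max_norm: float = 10.0):
        # Mirrors the hunk above: scale every leaf by max_norm / max(||grads||, 1e-6),
        # so the global norm of the result is (approximately) max_norm.
        grad_norm = optax.global_norm(grads)
        factor = max_norm / jnp.maximum(grad_norm, 1e-6)
        return jax.tree_util.tree_map(lambda g: g * factor, grads)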
{xax-0.1.10.dist-info → xax-0.1.11.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=2JdSxsZphJJFVMGBVXNc0hP2p0FVOu5y7xSgPRNeyNY,13835
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -11,8 +11,10 @@ xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
 xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
-xax/nn/
+xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
+xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
+xax/nn/ssm.py,sha256=eFeGkV1pkVGc0vNrQbykCbFnlPXQqsqVA_JVzLBHD28,9865
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
@@ -39,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=lgLHiHQtnDK0XS3SwHTYZtDv5CTbPRN1-p_K9KiIpHQ,26000
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -56,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.
-xax-0.1.
-xax-0.1.
-xax-0.1.
-xax-0.1.
+xax-0.1.11.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.11.dist-info/METADATA,sha256=qDhn5EGxdiuEe5gQUZiBC430sXhJOPRWboTvsh2onxs,1878
+xax-0.1.11.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.11.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.11.dist-info/RECORD,,
{xax-0.1.10.dist-info → xax-0.1.11.dist-info}/WHEEL
File without changes
{xax-0.1.10.dist-info → xax-0.1.11.dist-info}/licenses/LICENSE
File without changes
{xax-0.1.10.dist-info → xax-0.1.11.dist-info}/top_level.txt
File without changes