xax 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +17 -1
- xax/nn/losses.py +9 -0
- xax/nn/norm.py +2 -1
- xax/nn/ssm.py +296 -0
- xax/task/mixins/train.py +97 -28
- xax/utils/pytree.py +11 -11
- {xax-0.1.9.dist-info → xax-0.1.11.dist-info}/METADATA +1 -1
- {xax-0.1.9.dist-info → xax-0.1.11.dist-info}/RECORD +11 -9
- {xax-0.1.9.dist-info → xax-0.1.11.dist-info}/WHEEL +0 -0
- {xax-0.1.9.dist-info → xax-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.9.dist-info → xax-0.1.11.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED

@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.9"
+__version__ = "0.1.11"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,9 +43,16 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "DiagSSMBlock",
+    "DiscreteTimeS4",
+    "S4",
+    "S4Layer",
+    "S6Layer",
+    "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
     "SingleProcessLauncher",
@@ -196,9 +203,16 @@ NAME_MAP: dict[str, str] = {
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "cross_entropy": "nn.losses",
     "cast_norm_type": "nn.norm",
     "get_norm": "nn.norm",
     "is_master": "nn.parallel",
+    "DiagSSMBlock": "nn.ssm",
+    "DiscreteTimeS4": "nn.ssm",
+    "S4": "nn.ssm",
+    "S4Layer": "nn.ssm",
+    "S6Layer": "nn.ssm",
+    "SSMBlock": "nn.ssm",
     "BaseLauncher": "task.launchers.base",
     "CliLauncher": "task.launchers.cli",
     "SingleProcessLauncher": "task.launchers.single_process",
@@ -351,8 +365,10 @@ if IMPORT_ALL or TYPE_CHECKING:
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
     )
+    from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
+    from xax.nn.ssm import S4, DiagSSMBlock, DiscreteTimeS4, S4Layer, S6Layer, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
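The new symbols are re-exported at the package root. A hedged import sketch (it assumes the NAME_MAP-driven lazy importer resolves these names to the submodules listed above, as in previous releases):

    # Both spellings are expected to resolve to the same objects after upgrading.
    import xax
    from xax.nn.losses import cross_entropy
    from xax.nn.ssm import S4, DiagSSMBlock, SSMBlock

    print(xax.S4 is S4, xax.cross_entropy is cross_entropy)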
xax/nn/losses.py
ADDED

@@ -0,0 +1,9 @@
+"""Defines some common loss functions."""
+
+import jax.numpy as jnp
+from jaxtyping import Array
+
+
+def cross_entropy(y: Array, pred_y: Array, axis: int = 1) -> Array:
+    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, axis), axis=axis)
+    return -jnp.mean(pred_y)
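As written, `cross_entropy` expects integer class labels in `y` and log-probabilities in `pred_y`: it gathers the log-probability of the true class along `axis` and averages the negated values. A small usage sketch (shapes and values are illustrative only):

    import jax
    import jax.numpy as jnp

    from xax.nn.losses import cross_entropy

    logits = jax.random.normal(jax.random.PRNGKey(0), (4, 10))  # batch of 4, 10 classes
    log_probs = jax.nn.log_softmax(logits, axis=1)              # pred_y should already be log-probabilities
    labels = jnp.array([3, 1, 7, 0])                            # integer class labels, shape (4,)

    loss = cross_entropy(labels, log_probs, axis=1)             # scalar mean negative log-likelihood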
xax/nn/norm.py
CHANGED

@@ -3,6 +3,7 @@
 from typing import Literal, cast, get_args
 
 import jax.numpy as jnp
+from jaxtyping import Array
 
 NormType = Literal["l1", "l2"]
 
@@ -13,7 +14,7 @@ def cast_norm_type(norm: str) -> NormType:
     return cast(NormType, norm)
 
 
-def get_norm(x:
+def get_norm(x: Array, norm: NormType) -> Array:
     match norm:
         case "l1":
             return jnp.abs(x)
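A usage sketch for the retyped helper; only the "l1" branch is visible in the hunk above, so the example sticks to it:

    import jax.numpy as jnp

    from xax.nn.norm import cast_norm_type, get_norm

    x = jnp.array([-2.0, 0.5, 3.0])
    norm = cast_norm_type("l1")   # validates the string against NormType
    print(get_norm(x, norm))      # elementwise |x| -> [2.  0.5 3. ]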
xax/nn/ssm.py
ADDED

@@ -0,0 +1,296 @@
+"""State space models."""
+
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
+    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
+
+
+class DiscreteTimeS4(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        h = self.a * h + self.B.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class S4Layer(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+        self.delta = jax.random.uniform(key, (hidden_size,))
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        delta_a = self.delta * self.a
+        a_bar = jnp.exp(delta_a)
+        b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
+        h = a_bar * h + b_bar.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class S6Layer(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+        self.delta = jax.random.uniform(key, (hidden_size,))
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        h = self.a * h + self.B.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class BaseSSMBlock(eqx.Module, ABC):
+    @abstractmethod
+    def forward(self, h: Array, x: Array) -> Array:
+        pass
+
+
+class SSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        key_a, key_b = jax.random.split(key)
+        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
+        self.b_mat = glorot(key_b, (hidden_size, hidden_size))
+
+    def forward(self, h: Array, x: Array) -> Array:
+        h = self.a_mat @ h + self.b_mat.T @ x
+        return h
+
+    def get_kernel(self, length: int) -> Array:
+        return self.a_mat
+
+
+class DiagSSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        keys = jax.random.split(key, 2)
+        self.a_mat = glorot(keys[0], (hidden_size,))
+        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
+
+    def forward(self, h: Array, x: Array) -> Array:
+        h = self.a_mat * h + self.b_mat.T @ x
+        h = jax.nn.tanh(h)
+        return h
+
+    def get_kernel(self, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        exponents = jnp.arange(length)
+        kernel = jnp.power(self.a_mat[:, None], exponents)  # (H, L)
+        kernel = kernel[:, None, :]  # (H, 1, L)
+        return kernel
+
+    def forward_across_time(self, x: Array) -> Array:
+        """Convolves x (T, H) across time using the kernel."""
+        tsz, nhid = x.shape
+
+        # Compute s = x @ U.T + b, with shape (N, T, H)
+        s = self.b_mat.T @ x
+        s = s.T  # (H, T)
+
+        kernel = self.get_kernel(tsz)  # (H, 1, T)
+        kernel_flipped = jnp.flip(kernel, axis=-1)
+
+        # Pad s on the left along the time axis (pad length T-1)
+        s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
+
+        # Perform depthwise (grouped) 1D convolution.
+        # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
+        # The dimension_numbers are chosen so that the channel dimension is second.
+        conv_out = jax.lax.conv_general_dilated(
+            s_padded,
+            kernel_flipped,
+            window_strides=(1,),
+            padding="VALID",
+            dimension_numbers=("NCH", "OIH", "NCH"),
+            feature_group_count=nhid,
+        )
+        # conv_out has shape (N, H, T); transpose to (N, T, H)
+        conv_out = jnp.transpose(conv_out, (0, 2, 1))
+        return conv_out
+
+    def naive_forward_accross_time(self, x: Array) -> Array:
+        """Naively forward across time."""
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        h_0 = jnp.zeros(self.a_mat.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x)
+        return h_seq
+
+
+class S4(eqx.Module):
+    vocab_embedding: eqx.nn.Embedding
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    blocks: list[BaseSSMBlock]
+    num_layers: int = eqx.static_field()
+    hidden_size: int = eqx.static_field()
+    skip_connections: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        num_layers: int,
+        block_type: Literal["ssm", "diag"] = "ssm",
+        skip_connections: bool = False,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        vocab_key, s4_key = jax.random.split(key, 2)
+        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
+        self.proj_in = eqx.nn.Linear(hidden_size, hidden_size, key=key)
+        self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
+
+        block_keys = jax.random.split(s4_key, num_layers)
+
+        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
+            match block_type:
+                case "ssm":
+                    return SSMBlock(hidden_size, key=key)
+                case "diag":
+                    return DiagSSMBlock(hidden_size, key=key)
+                case _:
+                    raise ValueError(f"Unknown block type: {block_type}")
+
+        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
+        self.skip_connections = skip_connections
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+
+    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+        new_hs = []
+        for i, block in enumerate(self.blocks):
+            h = block.forward(hs[i], x)
+            new_hs.append(h)
+            xh = jax.nn.gelu(h)
+            x = xh + x if self.skip_connections else xh
+        y = self.proj_out(x)
+        return new_hs, y
+
+    def _embed_input(self, x: Array) -> Array:
+        """U is the input to the S4 cell."""
+        embedded = self.vocab_embedding(x)
+        return jax.nn.gelu(self.proj_in(embedded))
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_emb = jax.vmap(self._embed_input)(x_seq)
+        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
+
+        def step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+            hs, y = self(hs, x)
+            return hs, y
+
+        _, y_seq = jax.lax.scan(step, hs, x_emb)
+        return y_seq
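A hedged construction sketch for the new S4 wrapper (hyperparameter values are illustrative only; it assumes `predict_sequence` receives a 1-D array of integer token ids, which matches the `eqx.nn.Embedding` lookup in `_embed_input` above):

    import jax
    import jax.numpy as jnp

    from xax.nn.ssm import S4

    model = S4(
        input_size=128,       # vocabulary size for the embedding table
        hidden_size=64,
        output_size=128,
        num_layers=2,
        block_type="diag",    # or "ssm" for the dense-transition SSMBlock
        skip_connections=True,
        key=jax.random.PRNGKey(0),
    )

    tokens = jnp.array([1, 5, 42, 7])        # (T,) integer token ids
    logits = model.predict_sequence(tokens)  # (T, output_size), one output per step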
xax/task/mixins/train.py
CHANGED

@@ -29,6 +29,7 @@ from typing import (
 
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
@@ -56,6 +57,7 @@ from xax.utils.experiments import (
 from xax.utils.jax import jit as xax_jit
 from xax.utils.logging import LOG_STATUS
 from xax.utils.text import highlight_exception_message, show_info
+from xax.utils.types.frozen_dict import FrozenDict
 
 logger = logging.getLogger(__name__)
 
@@ -161,6 +163,7 @@ class TrainConfig(
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
     random_seed: int = field(1337, help="Random seed for the task")
+    global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
 
 
 Config = TypeVar("Config", bound=TrainConfig)
@@ -215,7 +218,7 @@ class TrainMixin(
         state = super().on_step_end(state)
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
-    def log_train_step(self, batch: Batch, output: Output, state: State) -> None:
+    def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
@@ -224,10 +227,11 @@ class TrainMixin(
         Args:
             batch: The batch from the dataloader.
             output: The model output.
+            metrics: The metrics for the current batch.
             state: The current training state.
         """
 
-    def log_valid_step(self, batch: Batch, output: Output, state: State) -> None:
+    def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
@@ -236,6 +240,7 @@ class TrainMixin(
         Args:
             batch: The batch from the dataloader.
             output: The model output.
+            metrics: The metrics for the current batch.
             state: The current training state.
         """
 
@@ -246,18 +251,23 @@ class TrainMixin(
             for k, v in d.items():
                 self.logger.log_scalar(k, v, namespace=ns)
 
-    def log_step(self, batch: Batch, output: Output,
+    def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         phase = state.phase
 
-
+        for k, v in metrics.items():
+            if v.size == 1:
+                self.logger.log_scalar(k, v.item())
+            else:
+                self.logger.log_histogram(k, v)
+
         self.log_state_timers(state)
 
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(batch, output, state)
+                self.log_train_step(batch, output, metrics, state)
             case "valid":
-                self.log_valid_step(batch, output, state)
+                self.log_valid_step(batch, output, metrics, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
@@ -364,32 +374,90 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
+    def compute_metrics(
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        loss: Array,
+        state: State,
+    ) -> dict[str, Array]:
+        """Computes the metrics for the current batch.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            output: The output from the model.
+            loss: The loss for the current batch.
+            state: The current training state.
+
+        Returns:
+            A dictionary of metrics.
+        """
+        return {
+            "loss": loss,
+        }
+
+    @xax_jit(static_argnames=["self", "model_static"])
     def get_output_and_loss(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, Output]:
+    ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
         model = eqx.combine(model_arr, model_static)
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
-
+        metrics = self.compute_metrics(model, batch, output, loss, state)
+        return loss, (output, metrics)
 
     def update(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[
-        grad_fn =
-
-
-        model_arr =
-        return
+    ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
+        grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+        grad_fn = xax_jit(static_argnums=[1])(grad_fn)
+        grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
+        model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
+        return model_arr, opt_state, output, metrics | grad_metrics
+
+    @xax_jit(static_argnames=["self", "optimizer"])
+    def apply_gradients_with_clipping(
+        self,
+        model_arr: PyTree,
+        grads: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
+        grad_norm = optax.global_norm(grads)
+        grad_metrics = {"grad_norm": grad_norm}
+
+        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
+            # Clip the global gradient norm to some desired range.
+            grad_factor = self.config.global_grad_clip / jnp.maximum(grad_norm, 1e-6)
+            grads = jax.tree.map(lambda x: x * grad_factor, grads)
+
+            # Apply the gradient updates.
+            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
+            new_model_arr = eqx.apply_updates(model_arr, updates)
+            return new_model_arr, new_opt_state
+
+        # Don't apply updates if the gradient is NaN or Inf.
+        new_model_arr, new_opt_state = jax.lax.cond(
+            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
+            lambda *_: (model_arr, opt_state),
+            apply,
+            grads,
+            grad_norm,
+        )
+
+        return new_model_arr, new_opt_state, grad_metrics
 
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
@@ -469,25 +537,26 @@ class TrainMixin(
     @xax_jit(static_argnames=["self", "model_static", "optimizer"])
     def train_step(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[PyTree, optax.OptState,
-
-        return model_arr, opt_state,
+    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+        model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
+        return model_arr, opt_state, output, FrozenDict(metrics)
 
     @xax_jit(static_argnames=["self", "model_static"])
     def val_step(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[
-
+    ) -> tuple[Output, FrozenDict[str, Array]]:
+        _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+        return output, FrozenDict(metrics)
 
     def train_loop(
         self,
@@ -509,8 +578,8 @@ class TrainMixin(
                     num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
                 )
 
-
-                self.log_step(valid_batch, output,
+                output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+                self.log_step(valid_batch, output, metrics, state)
 
             state = self.on_step_start(state)
             train_batch = next(train_pf)
@@ -520,15 +589,15 @@ class TrainMixin(
                 num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
             )
 
-            model_arr, opt_state,
-                model_static=model_static,
+            model_arr, opt_state, output, metrics = self.train_step(
                 model_arr=model_arr,
+                model_static=model_static,
                 optimizer=optimizer,
                 opt_state=opt_state,
                 batch=train_batch,
                 state=state,
            )
-            self.log_step(train_batch, output,
+            self.log_step(train_batch, output, metrics, state)
 
             state = self.on_step_end(state)
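The new `apply_gradients_with_clipping` rescales the gradients by `global_grad_clip / max(grad_norm, 1e-6)` and skips the update entirely when the norm is NaN or Inf. A standalone sketch of just the rescaling step (values are illustrative, not taken from the package):

    import jax
    import jax.numpy as jnp
    import optax

    global_grad_clip = 10.0  # default of the new TrainConfig.global_grad_clip field

    grads = {"w": jnp.full((3,), 30.0), "b": jnp.array([40.0])}
    grad_norm = optax.global_norm(grads)                       # ~65.6 for these values
    grad_factor = global_grad_clip / jnp.maximum(grad_norm, 1e-6)
    clipped = jax.tree.map(lambda g: g * grad_factor, grads)   # global norm rescaled to 10.0
    print(grad_norm, optax.global_norm(clipped))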
xax/utils/pytree.py
CHANGED

@@ -31,7 +31,7 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:
 
 def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
     """Get a slice of a pytree."""
-    return jax.
+    return jax.tree.map(lambda x: slice_array(x, start, slice_length), pytree)
 
 
 def flatten_array(x: Array, flatten_size: int) -> Array:
@@ -43,14 +43,14 @@ def flatten_array(x: Array, flatten_size: int) -> Array:
 
 def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
     """Flatten a pytree into a (flatten_size, ...) pytree."""
-    return jax.
+    return jax.tree.map(lambda x: flatten_array(x, flatten_size), pytree)
 
 
 def pytree_has_nans(pytree: PyTree) -> Array:
     """Check if a pytree has any NaNs."""
     has_nans = jax.tree_util.tree_reduce(
         lambda a, b: jnp.logical_or(a, b),
-        jax.
+        jax.tree.map(lambda x: jnp.any(jnp.isnan(x)), pytree),
     )
     return has_nans
 
@@ -58,13 +58,13 @@ def pytree_has_nans(pytree: PyTree) -> Array:
 def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
     """Update a pytree based on a condition."""
     # Tricky, need use tree_map because where expects array leafs.
-    return jax.
+    return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
 
 
 def compute_nan_ratio(pytree: PyTree) -> Array:
     """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
-    nan_counts = jax.
-    total_counts = jax.
+    nan_counts = jax.tree.map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
+    total_counts = jax.tree.map(lambda x: x.size, pytree)
 
     total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
     total_elements = jax.tree_util.tree_reduce(lambda a, b: a + b, total_counts, 0)
@@ -118,7 +118,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
         # Reshape back to the original shape
        return permuted.reshape(orig_shape)
 
-    return jax.
+    return jax.tree.map(permute_array, data)
 
 
 def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
@@ -133,7 +133,7 @@ def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], r
            return x[tuple(idx_grids)]
        return x
 
-    return jax.
+    return jax.tree.map(permute_array, data)
 
 
 TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
@@ -215,7 +215,7 @@ def reshuffle_pytree_along_dims(
            transpose_info[path] = (transpose_order, original_shape)
        return x
 
-    jax.
+    jax.tree.map_with_path(prepare_for_shuffle, data)
 
     # Create a transposed pytree
     def get_transposed(path: PathType, x: PyTree) -> PyTree:
@@ -223,7 +223,7 @@ def reshuffle_pytree_along_dims(
            return transposed_data[path]
        return x
 
-    transposed_pytree = jax.
+    transposed_pytree = jax.tree.map_with_path(get_transposed, data)
 
     # Reshuffle the transposed pytree along the leading dimensions
     reshuffled_transposed = reshuffle_pytree(transposed_pytree, shape_dims, rng)
@@ -235,4 +235,4 @@ def reshuffle_pytree_along_dims(
            return transpose_back(x, transpose_order, original_shape)
        return x
 
-    return jax.
+    return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
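The visible change in these helpers is that each leaf-wise map now goes through `jax.tree.map` (and `jax.tree.map_with_path`). A quick check of two of them, using values chosen only for illustration:

    import jax.numpy as jnp

    from xax.utils.pytree import pytree_has_nans, update_pytree

    new = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
    old = {"w": jnp.zeros(2), "b": jnp.zeros(1)}

    print(pytree_has_nans(new))                      # array(False): no NaN leaves
    print(update_pytree(jnp.array(True), new, old))  # selects `new` leaves where cond is True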
{xax-0.1.9.dist-info → xax-0.1.11.dist-info}/RECORD
CHANGED

@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=2JdSxsZphJJFVMGBVXNc0hP2p0FVOu5y7xSgPRNeyNY,13835
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -11,8 +11,10 @@ xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
 xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
-xax/nn/
+xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
+xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
+xax/nn/ssm.py,sha256=eFeGkV1pkVGc0vNrQbykCbFnlPXQqsqVA_JVzLBHD28,9865
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
@@ -39,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=lgLHiHQtnDK0XS3SwHTYZtDv5CTbPRN1-p_K9KiIpHQ,26000
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -48,7 +50,7 @@ xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=
+xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
 xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.
-xax-0.1.
-xax-0.1.
-xax-0.1.
-xax-0.1.
+xax-0.1.11.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.11.dist-info/METADATA,sha256=qDhn5EGxdiuEe5gQUZiBC430sXhJOPRWboTvsh2onxs,1878
+xax-0.1.11.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.11.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.11.dist-info/RECORD,,

File without changes
File without changes
File without changes