xax 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.9"
+__version__ = "0.1.11"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,9 +43,16 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "DiagSSMBlock",
+    "DiscreteTimeS4",
+    "S4",
+    "S4Layer",
+    "S6Layer",
+    "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
     "SingleProcessLauncher",
@@ -196,9 +203,16 @@ NAME_MAP: dict[str, str] = {
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "cross_entropy": "nn.losses",
     "cast_norm_type": "nn.norm",
     "get_norm": "nn.norm",
     "is_master": "nn.parallel",
+    "DiagSSMBlock": "nn.ssm",
+    "DiscreteTimeS4": "nn.ssm",
+    "S4": "nn.ssm",
+    "S4Layer": "nn.ssm",
+    "S6Layer": "nn.ssm",
+    "SSMBlock": "nn.ssm",
     "BaseLauncher": "task.launchers.base",
     "CliLauncher": "task.launchers.cli",
     "SingleProcessLauncher": "task.launchers.single_process",
@@ -351,8 +365,10 @@ if IMPORT_ALL or TYPE_CHECKING:
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
     )
+    from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
+    from xax.nn.ssm import S4, DiagSSMBlock, DiscreteTimeS4, S4Layer, S6Layer, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
xax/nn/losses.py ADDED
@@ -0,0 +1,9 @@
+"""Defines some common loss functions."""
+
+import jax.numpy as jnp
+from jaxtyping import Array
+
+
+def cross_entropy(y: Array, pred_y: Array, axis: int = 1) -> Array:
+    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, axis), axis=axis)
+    return -jnp.mean(pred_y)
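Note on the new helper: a minimal usage sketch, not part of the package. It assumes pred_y holds log-probabilities with the class dimension on `axis` and y holds integer class indices, which is what the take_along_axis call implies.

import jax
import jax.numpy as jnp

from xax.nn.losses import cross_entropy

# Hypothetical data: 4 samples, 3 classes.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
log_probs = jax.nn.log_softmax(logits, axis=1)  # log-probabilities, class axis = 1
labels = jnp.array([0, 2, 1, 2])                # integer class indices

# Picks log_probs[i, labels[i]] for each row and returns the negated mean.
loss = cross_entropy(labels, log_probs, axis=1)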
xax/nn/norm.py CHANGED
@@ -3,6 +3,7 @@
 from typing import Literal, cast, get_args
 
 import jax.numpy as jnp
+from jaxtyping import Array
 
 NormType = Literal["l1", "l2"]
 
@@ -13,7 +14,7 @@ def cast_norm_type(norm: str) -> NormType:
     return cast(NormType, norm)
 
 
-def get_norm(x: jnp.ndarray, norm: NormType) -> jnp.ndarray:
+def get_norm(x: Array, norm: NormType) -> Array:
     match norm:
         case "l1":
             return jnp.abs(x)
xax/nn/ssm.py ADDED
@@ -0,0 +1,296 @@
+"""State space models."""
+
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
+    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
+
+
+class DiscreteTimeS4(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        h = self.a * h + self.B.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class S4Layer(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+        self.delta = jax.random.uniform(key, (hidden_size,))
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        delta_a = self.delta * self.a
+        a_bar = jnp.exp(delta_a)
+        b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
+        h = a_bar * h + b_bar.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class S6Layer(eqx.Module):
+    a: Array
+    B: Array
+    C: Array
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        projection_size: int,
+        input_size: int,
+        output_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
+        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
+        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
+        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
+        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
+        self.delta = jax.random.uniform(key, (hidden_size,))
+
+    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
+        h = self.a * h + self.B.T @ x
+        y = self.C.T @ h
+        return h, y
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
+        h = jnp.zeros(self.a.shape[0])
+
+        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.a * h + self.B.T @ x
+            y = self.C.T @ h
+            return h, y
+
+        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
+        y_out = jax.vmap(self.proj_out)(y_seq)
+        return y_out
+
+
+class BaseSSMBlock(eqx.Module, ABC):
+    @abstractmethod
+    def forward(self, h: Array, x: Array) -> Array:
+        pass
+
+
+class SSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        key_a, key_b = jax.random.split(key)
+        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
+        self.b_mat = glorot(key_b, (hidden_size, hidden_size))
+
+    def forward(self, h: Array, x: Array) -> Array:
+        h = self.a_mat @ h + self.b_mat.T @ x
+        return h
+
+    def get_kernel(self, length: int) -> Array:
+        return self.a_mat
+
+
+class DiagSSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        keys = jax.random.split(key, 2)
+        self.a_mat = glorot(keys[0], (hidden_size,))
+        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
+
+    def forward(self, h: Array, x: Array) -> Array:
+        h = self.a_mat * h + self.b_mat.T @ x
+        h = jax.nn.tanh(h)
+        return h
+
+    def get_kernel(self, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        exponents = jnp.arange(length)
+        kernel = jnp.power(self.a_mat[:, None], exponents)  # (H, L)
+        kernel = kernel[:, None, :]  # (H, 1, L)
+        return kernel
+
+    def forward_across_time(self, x: Array) -> Array:
+        """Convolves x (T, H) across time using the kernel."""
+        tsz, nhid = x.shape
+
+        # Compute s = x @ U.T + b, with shape (N, T, H)
+        s = self.b_mat.T @ x
+        s = s.T  # (H, T)
+
+        kernel = self.get_kernel(tsz)  # (H, 1, T)
+        kernel_flipped = jnp.flip(kernel, axis=-1)
+
+        # Pad s on the left along the time axis (pad length T-1)
+        s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
+
+        # Perform depthwise (grouped) 1D convolution.
+        # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
+        # The dimension_numbers are chosen so that the channel dimension is second.
+        conv_out = jax.lax.conv_general_dilated(
+            s_padded,
+            kernel_flipped,
+            window_strides=(1,),
+            padding="VALID",
+            dimension_numbers=("NCH", "OIH", "NCH"),
+            feature_group_count=nhid,
+        )
+        # conv_out has shape (N, H, T); transpose to (N, T, H)
+        conv_out = jnp.transpose(conv_out, (0, 2, 1))
+        return conv_out
+
+    def naive_forward_accross_time(self, x: Array) -> Array:
+        """Naively forward across time."""
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        h_0 = jnp.zeros(self.a_mat.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x)
+        return h_seq
+
+
+class S4(eqx.Module):
+    vocab_embedding: eqx.nn.Embedding
+    proj_in: eqx.nn.Linear
+    proj_out: eqx.nn.Linear
+    blocks: list[BaseSSMBlock]
+    num_layers: int = eqx.static_field()
+    hidden_size: int = eqx.static_field()
+    skip_connections: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        num_layers: int,
+        block_type: Literal["ssm", "diag"] = "ssm",
+        skip_connections: bool = False,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        vocab_key, s4_key = jax.random.split(key, 2)
+        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
+        self.proj_in = eqx.nn.Linear(hidden_size, hidden_size, key=key)
+        self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
+
+        block_keys = jax.random.split(s4_key, num_layers)
+
+        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
+            match block_type:
+                case "ssm":
+                    return SSMBlock(hidden_size, key=key)
+                case "diag":
+                    return DiagSSMBlock(hidden_size, key=key)
+                case _:
+                    raise ValueError(f"Unknown block type: {block_type}")
+
+        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
+        self.skip_connections = skip_connections
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+
+    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+        new_hs = []
+        for i, block in enumerate(self.blocks):
+            h = block.forward(hs[i], x)
+            new_hs.append(h)
+            xh = jax.nn.gelu(h)
+            x = xh + x if self.skip_connections else xh
+        y = self.proj_out(x)
+        return new_hs, y
+
+    def _embed_input(self, x: Array) -> Array:
+        """U is the input to the S4 cell."""
+        embedded = self.vocab_embedding(x)
+        return jax.nn.gelu(self.proj_in(embedded))
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_emb = jax.vmap(self._embed_input)(x_seq)
+        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
+
+        def step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+            hs, y = self(hs, x)
+            return hs, y
+
+        _, y_seq = jax.lax.scan(step, hs, x_emb)
+        return y_seq
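Note on the new module: a hedged usage sketch of the recurrent path (hyperparameters and the token sequence are illustrative, not from the package). S4.predict_sequence embeds a sequence of integer token ids and scans the stacked SSM blocks over time, returning one output vector per step.

import jax
import jax.numpy as jnp

from xax.nn.ssm import S4

key = jax.random.PRNGKey(0)
model = S4(
    input_size=32,       # vocabulary size for the embedding
    hidden_size=64,
    output_size=32,
    num_layers=2,
    block_type="diag",   # or "ssm" for the dense-transition block
    key=key,
)

tokens = jnp.array([1, 5, 9, 3])        # (T,) integer token ids
y_seq = model.predict_sequence(tokens)  # (T, output_size), one output per time step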
xax/task/mixins/train.py CHANGED
@@ -29,6 +29,7 @@ from typing import (
 
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
@@ -56,6 +57,7 @@ from xax.utils.experiments import (
 from xax.utils.jax import jit as xax_jit
 from xax.utils.logging import LOG_STATUS
 from xax.utils.text import highlight_exception_message, show_info
+from xax.utils.types.frozen_dict import FrozenDict
 
 logger = logging.getLogger(__name__)
 
@@ -161,6 +163,7 @@ class TrainConfig(
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
     random_seed: int = field(1337, help="Random seed for the task")
+    global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
 
 
 Config = TypeVar("Config", bound=TrainConfig)
@@ -215,7 +218,7 @@ class TrainMixin(
         state = super().on_step_end(state)
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
-    def log_train_step(self, batch: Batch, output: Output, state: State) -> None:
+    def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
@@ -224,10 +227,11 @@ class TrainMixin(
         Args:
             batch: The batch from the dataloader.
             output: The model output.
+            metrics: The metrics for the current batch.
             state: The current training state.
         """
 
-    def log_valid_step(self, batch: Batch, output: Output, state: State) -> None:
+    def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
@@ -236,6 +240,7 @@ class TrainMixin(
         Args:
             batch: The batch from the dataloader.
            output: The model output.
+            metrics: The metrics for the current batch.
            state: The current training state.
        """
 
@@ -246,18 +251,23 @@ class TrainMixin(
        for k, v in d.items():
            self.logger.log_scalar(k, v, namespace=ns)
 
-    def log_step(self, batch: Batch, output: Output, loss: Array, state: State) -> None:
+    def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
        phase = state.phase
 
-        self.logger.log_scalar("loss", loss, namespace="loss")
+        for k, v in metrics.items():
+            if v.size == 1:
+                self.logger.log_scalar(k, v.item())
+            else:
+                self.logger.log_histogram(k, v)
+
        self.log_state_timers(state)
 
        # Delegate to the appropriate logging function based on the phase.
        match phase:
            case "train":
-                self.log_train_step(batch, output, state)
+                self.log_train_step(batch, output, metrics, state)
            case "valid":
-                self.log_valid_step(batch, output, state)
+                self.log_valid_step(batch, output, metrics, state)
            case _:
                raise KeyError(f"Unknown phase: {phase}")
 
@@ -364,32 +374,90 @@ class TrainMixin(
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
+    def compute_metrics(
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        loss: Array,
+        state: State,
+    ) -> dict[str, Array]:
+        """Computes the metrics for the current batch.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            output: The output from the model.
+            loss: The loss for the current batch.
+            state: The current training state.
+
+        Returns:
+            A dictionary of metrics.
+        """
+        return {
+            "loss": loss,
+        }
+
+    @xax_jit(static_argnames=["self", "model_static"])
     def get_output_and_loss(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, Output]:
+    ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
         model = eqx.combine(model_arr, model_static)
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
-        return loss, output
+        metrics = self.compute_metrics(model, batch, output, loss, state)
+        return loss, (output, metrics)
 
     def update(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, PyTree, optax.OptState, Output]:
-        grad_fn = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)
-        (loss, output), grads = grad_fn(model_static, model_arr, batch, state)
-        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
-        model_arr = eqx.apply_updates(model_arr, updates)
-        return loss, model_arr, opt_state, output
+    ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
+        grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+        grad_fn = xax_jit(static_argnums=[1])(grad_fn)
+        grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
+        model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
+        return model_arr, opt_state, output, metrics | grad_metrics
+
+    @xax_jit(static_argnames=["self", "optimizer"])
+    def apply_gradients_with_clipping(
+        self,
+        model_arr: PyTree,
+        grads: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
+        grad_norm = optax.global_norm(grads)
+        grad_metrics = {"grad_norm": grad_norm}
+
+        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
+            # Clip the global gradient norm to some desired range.
+            grad_factor = self.config.global_grad_clip / jnp.maximum(grad_norm, 1e-6)
+            grads = jax.tree.map(lambda x: x * grad_factor, grads)
+
+            # Apply the gradient updates.
+            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
+            new_model_arr = eqx.apply_updates(model_arr, updates)
+            return new_model_arr, new_opt_state
+
+        # Don't apply updates if the gradient is NaN or Inf.
+        new_model_arr, new_opt_state = jax.lax.cond(
+            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
+            lambda *_: (model_arr, opt_state),
+            apply,
+            grads,
+            grad_norm,
+        )
+
+        return new_model_arr, new_opt_state, grad_metrics
 
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
@@ -469,25 +537,26 @@ class TrainMixin(
     @xax_jit(static_argnames=["self", "model_static", "optimizer"])
     def train_step(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[PyTree, optax.OptState, Array, Output]:
-        loss, model_arr, opt_state, output = self.update(model_static, model_arr, optimizer, opt_state, batch, state)
-        return model_arr, opt_state, loss, output
+    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+        model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
+        return model_arr, opt_state, output, FrozenDict(metrics)
 
     @xax_jit(static_argnames=["self", "model_static"])
     def val_step(
         self,
-        model_static: PyTree,
         model_arr: PyTree,
+        model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, Output]:
-        return self.get_output_and_loss(model_static, model_arr, batch, state)
+    ) -> tuple[Output, FrozenDict[str, Array]]:
+        _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+        return output, FrozenDict(metrics)
 
     def train_loop(
         self,
@@ -509,8 +578,8 @@ class TrainMixin(
                 num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
             )
 
-            loss, output = self.val_step(model_static, model_arr, valid_batch, state)
-            self.log_step(valid_batch, output, loss, state)
+            output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+            self.log_step(valid_batch, output, metrics, state)
 
             state = self.on_step_start(state)
             train_batch = next(train_pf)
@@ -520,15 +589,15 @@ class TrainMixin(
                 num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
             )
 
-            model_arr, opt_state, loss, output = self.train_step(
-                model_static=model_static,
+            model_arr, opt_state, output, metrics = self.train_step(
                 model_arr=model_arr,
+                model_static=model_static,
                 optimizer=optimizer,
                 opt_state=opt_state,
                 batch=train_batch,
                 state=state,
             )
-            self.log_step(train_batch, output, loss, state)
+            self.log_step(train_batch, output, metrics, state)
 
             state = self.on_step_end(state)
 
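Note on the train-mixin changes: per-step logging now flows through a metrics dictionary (the default compute_metrics returns only "loss", and update adds "grad_norm"), and gradients are rescaled using the new global_grad_clip setting before the optimizer step. Below is a minimal sketch of that rescaling rule, assuming the default global_grad_clip=10.0; the actual method additionally skips the update entirely when the gradient norm is NaN or Inf.

import jax
import jax.numpy as jnp
import optax

grads = {"w": jnp.array([30.0, 40.0])}  # toy gradient pytree with global norm 50
grad_norm = optax.global_norm(grads)

# Scale the whole gradient tree by global_grad_clip / max(norm, 1e-6), as in
# apply_gradients_with_clipping, before handing it to the optax optimizer.
grad_factor = 10.0 / jnp.maximum(grad_norm, 1e-6)
clipped = jax.tree.map(lambda g: g * grad_factor, grads)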
xax/utils/pytree.py CHANGED
@@ -31,7 +31,7 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:
 
 def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
     """Get a slice of a pytree."""
-    return jax.tree_util.tree_map(lambda x: slice_array(x, start, slice_length), pytree)
+    return jax.tree.map(lambda x: slice_array(x, start, slice_length), pytree)
 
 
 def flatten_array(x: Array, flatten_size: int) -> Array:
@@ -43,14 +43,14 @@ def flatten_array(x: Array, flatten_size: int) -> Array:
 
 def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
     """Flatten a pytree into a (flatten_size, ...) pytree."""
-    return jax.tree_util.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
+    return jax.tree.map(lambda x: flatten_array(x, flatten_size), pytree)
 
 
 def pytree_has_nans(pytree: PyTree) -> Array:
     """Check if a pytree has any NaNs."""
     has_nans = jax.tree_util.tree_reduce(
         lambda a, b: jnp.logical_or(a, b),
-        jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree),
+        jax.tree.map(lambda x: jnp.any(jnp.isnan(x)), pytree),
     )
     return has_nans
 
@@ -58,13 +58,13 @@ def pytree_has_nans(pytree: PyTree) -> Array:
 def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
     """Update a pytree based on a condition."""
     # Tricky, need use tree_map because where expects array leafs.
-    return jax.tree_util.tree_map(lambda x, y: jnp.where(cond, x, y), new, original)
+    return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
 
 
 def compute_nan_ratio(pytree: PyTree) -> Array:
     """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
-    nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
-    total_counts = jax.tree_util.tree_map(lambda x: x.size, pytree)
+    nan_counts = jax.tree.map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
+    total_counts = jax.tree.map(lambda x: x.size, pytree)
 
     total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
     total_elements = jax.tree_util.tree_reduce(lambda a, b: a + b, total_counts, 0)
@@ -118,7 +118,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
         # Reshape back to the original shape
         return permuted.reshape(orig_shape)
 
-    return jax.tree_util.tree_map(permute_array, data)
+    return jax.tree.map(permute_array, data)
 
 
 def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
@@ -133,7 +133,7 @@ def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], r
             return x[tuple(idx_grids)]
         return x
 
-    return jax.tree_util.tree_map(permute_array, data)
+    return jax.tree.map(permute_array, data)
 
 
 TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
@@ -215,7 +215,7 @@ def reshuffle_pytree_along_dims(
             transpose_info[path] = (transpose_order, original_shape)
         return x
 
-    jax.tree_util.tree_map_with_path(prepare_for_shuffle, data)
+    jax.tree.map_with_path(prepare_for_shuffle, data)
 
     # Create a transposed pytree
     def get_transposed(path: PathType, x: PyTree) -> PyTree:
@@ -223,7 +223,7 @@ def reshuffle_pytree_along_dims(
             return transposed_data[path]
         return x
 
-    transposed_pytree = jax.tree_util.tree_map_with_path(get_transposed, data)
+    transposed_pytree = jax.tree.map_with_path(get_transposed, data)
 
     # Reshuffle the transposed pytree along the leading dimensions
     reshuffled_transposed = reshuffle_pytree(transposed_pytree, shape_dims, rng)
@@ -235,4 +235,4 @@ def reshuffle_pytree_along_dims(
             return transpose_back(x, transpose_order, original_shape)
         return x
 
-    return jax.tree_util.tree_map_with_path(restore_transpose, reshuffled_transposed)
+    return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
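Note on the pytree utility changes: these are mechanical renames from jax.tree_util.tree_map / tree_map_with_path to the newer jax.tree.map / jax.tree.map_with_path spellings; behavior is unchanged. A small equivalence sketch for jax.tree.map (the jax.tree namespace requires a reasonably recent JAX release):

import jax
import jax.numpy as jnp

tree = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
doubled_new = jax.tree.map(lambda x: x * 2, tree)            # newer spelling
doubled_old = jax.tree_util.tree_map(lambda x: x * 2, tree)  # older spelling, same result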
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.9
+Version: 0.1.11
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=_xb60-jl7arZEleSwUw4ElPaq4MzD24_ZYQrnWO5_cs,13391
+xax/__init__.py,sha256=2JdSxsZphJJFVMGBVXNc0hP2p0FVOu5y7xSgPRNeyNY,13835
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -11,8 +11,10 @@ xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
 xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
-xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
+xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
+xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
+xax/nn/ssm.py,sha256=eFeGkV1pkVGc0vNrQbykCbFnlPXQqsqVA_JVzLBHD28,9865
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
@@ -39,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=JbrSiBqpgOrdDanNYuAzzh2radPrXOVrHYA6VcxjIzY,23248
+xax/task/mixins/train.py,sha256=lgLHiHQtnDK0XS3SwHTYZtDv5CTbPRN1-p_K9KiIpHQ,26000
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -48,7 +50,7 @@ xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=7GjQoPc_ZSZt3QS_9qXoBWl1jfMp1qZa7aViQoWJ0OQ,8864
+xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
 xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.9.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.1.9.dist-info/METADATA,sha256=Ou8KmYWWNxgo_9ZAU2KLaeGeXAxd6b9qJ95ky4HRm-o,1877
-xax-0.1.9.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-xax-0.1.9.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.1.9.dist-info/RECORD,,
+xax-0.1.11.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.11.dist-info/METADATA,sha256=qDhn5EGxdiuEe5gQUZiBC430sXhJOPRWboTvsh2onxs,1878
+xax-0.1.11.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.11.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.11.dist-info/RECORD,,