xax 0.1.10-py3-none-any.whl → 0.1.12-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.10"
+__version__ = "0.1.12"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,9 +43,15 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "rotate_vector_by_quat",
+    "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "BaseSSMBlock",
+    "DiagSSMBlock",
+    "SSM",
+    "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
     "SingleProcessLauncher",
@@ -196,9 +202,15 @@ NAME_MAP: dict[str, str] = {
    "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "rotate_vector_by_quat": "nn.geom",
+    "cross_entropy": "nn.losses",
     "cast_norm_type": "nn.norm",
     "get_norm": "nn.norm",
     "is_master": "nn.parallel",
+    "BaseSSMBlock": "nn.ssm",
+    "DiagSSMBlock": "nn.ssm",
+    "SSM": "nn.ssm",
+    "SSMBlock": "nn.ssm",
     "BaseLauncher": "task.launchers.base",
     "CliLauncher": "task.launchers.cli",
     "SingleProcessLauncher": "task.launchers.single_process",
@@ -350,9 +362,12 @@ if IMPORT_ALL or TYPE_CHECKING:
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
+        rotate_vector_by_quat,
     )
+    from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
+    from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
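The `__init__.py` hunks above only register the new names; once 0.1.12 is installed they resolve from the package root through the `NAME_MAP` lazy-import table. A minimal sketch (values and sizes are illustrative, not from the source):

```python
# Sketch: exercising the symbols newly exported at the package root in 0.1.12.
import jax
import jax.numpy as jnp
import xax

vec = jnp.array([1.0, 0.0, 0.0])
identity_quat = jnp.array([1.0, 0.0, 0.0, 0.0])  # (w, x, y, z) order, per nn.geom below
print(xax.rotate_vector_by_quat(vec, identity_quat))  # identity rotation: [1., 0., 0.]

# The SSM classes are likewise reachable from the root.
model = xax.SSM(input_size=8, hidden_size=16, output_size=8, num_layers=2, key=jax.random.PRNGKey(0))
```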
xax/nn/geom.py CHANGED
@@ -99,3 +99,60 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
 
     # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
     return jnp.concatenate([gx, gy, -gz], axis=-1)
+
+
+def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
+    """Rotates a vector by a quaternion.
+
+    Args:
+        vector: The vector to rotate, shape (*, 3).
+        quat: The quaternion to rotate by, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The rotated vector, shape (*, 3).
+    """
+    # Normalize quaternion
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    # Extract vector components
+    vx, vy, vz = jnp.split(vector, 3, axis=-1)
+
+    # Terms for x component
+    xx = (
+        w * w * vx
+        + 2 * y * w * vz
+        - 2 * z * w * vy
+        + x * x * vx
+        + 2 * y * x * vy
+        + 2 * z * x * vz
+        - z * z * vx
+        - y * y * vx
+    )
+
+    # Terms for y component
+    yy = (
+        2 * x * y * vx
+        + y * y * vy
+        + 2 * z * y * vz
+        + 2 * w * z * vx
+        - z * z * vy
+        + w * w * vy
+        - 2 * w * x * vz
+        - x * x * vy
+    )
+
+    # Terms for z component
+    zz = (
+        2 * x * z * vx
+        + 2 * y * z * vy
+        + z * z * vz
+        - 2 * w * y * vx
+        + w * w * vz
+        + 2 * w * x * vy
+        - y * y * vz
+        - x * x * vz
+    )
+
+    return jnp.concatenate([xx, yy, zz], axis=-1)
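The xx/yy/zz terms are the expanded form of the rotation q v q⁻¹ for a quaternion in (w, x, y, z) order; the quaternion is renormalized first, so non-unit inputs are tolerated. A quick sanity check of the new helper, as a sketch assuming xax 0.1.12 is installed:

```python
# Sketch: rotating the x unit vector 90 degrees about the z axis.
import jax.numpy as jnp
from xax.nn.geom import rotate_vector_by_quat

theta = jnp.pi / 2
quat = jnp.array([jnp.cos(theta / 2), 0.0, 0.0, jnp.sin(theta / 2)])  # (w, x, y, z)
vec = jnp.array([1.0, 0.0, 0.0])
print(rotate_vector_by_quat(vec, quat))  # approximately [0., 1., 0.]
```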
xax/nn/losses.py ADDED
@@ -0,0 +1,9 @@
+"""Defines some common loss functions."""
+
+import jax.numpy as jnp
+from jaxtyping import Array
+
+
+def cross_entropy(y: Array, pred_y: Array, axis: int = 1) -> Array:
+    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, axis), axis=axis)
+    return -jnp.mean(pred_y)
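`cross_entropy` gathers `pred_y` at the integer labels `y` and negates the mean, so it assumes `pred_y` already holds log-probabilities; the source does not state this, but it follows from the `take_along_axis` indexing. A sketch of the implied calling convention:

```python
# Sketch: cross_entropy over (batch, classes) log-probabilities. The
# log_softmax step is an assumption about intended usage, not from the source.
import jax
import jax.numpy as jnp
from xax.nn.losses import cross_entropy

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])  # (batch, classes)
labels = jnp.array([0, 1])  # integer labels, shape (batch,)

log_probs = jax.nn.log_softmax(logits, axis=1)
loss = cross_entropy(labels, log_probs, axis=1)  # scalar
```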
xax/nn/norm.py CHANGED
@@ -3,6 +3,7 @@
 from typing import Literal, cast, get_args
 
 import jax.numpy as jnp
+from jaxtyping import Array
 
 NormType = Literal["l1", "l2"]
 
@@ -13,7 +14,7 @@ def cast_norm_type(norm: str) -> NormType:
     return cast(NormType, norm)
 
 
-def get_norm(x: jnp.ndarray, norm: NormType) -> jnp.ndarray:
+def get_norm(x: Array, norm: NormType) -> Array:
     match norm:
         case "l1":
             return jnp.abs(x)
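The `norm.py` change is purely a typing update (`jnp.ndarray` to jaxtyping's `Array`); behavior is unchanged. For reference, the elementwise contract of the untouched body:

```python
# Sketch: get_norm applies the norm elementwise; "l1" is jnp.abs per the hunk above.
import jax.numpy as jnp
from xax.nn.norm import cast_norm_type, get_norm

x = jnp.array([-3.0, 4.0])
print(get_norm(x, cast_norm_type("l1")))  # [3., 4.]
```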
xax/nn/ssm.py ADDED
@@ -0,0 +1,316 @@
+"""State space models."""
+
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
+    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
+
+
+class BaseSSMBlock(eqx.Module, ABC):
+    @abstractmethod
+    def forward(self, h: Array, x: Array) -> Array: ...
+
+    @abstractmethod
+    def forward_sequence(self, x_seq: Array) -> Array: ...
+
+    @abstractmethod
+    def get_a_mat(self, x: Array) -> Array: ...
+
+    @abstractmethod
+    def get_b_mat(self, x: Array) -> Array: ...
+
+
+class SSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        key_a, key_b = jax.random.split(key)
+        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
+        self.b_mat = glorot(key_b, (hidden_size, hidden_size))
+
+    def get_a_mat(self, x: Array) -> Array:
+        return self.a_mat
+
+    def get_b_mat(self, x: Array) -> Array:
+        return self.b_mat
+
+    def forward(self, h: Array, x: Array) -> Array:
+        """Perform a forward pass.
+
+        Args:
+            h: Hidden state of shape (H,).
+            x: Input of shape (H,).
+
+        Returns:
+            Hidden state of shape (H,).
+        """
+        a_mat = self.get_a_mat(x)
+        b_mat = self.get_b_mat(x)
+        h = a_mat @ h + b_mat.T @ x
+        return h
+
+    def forward_sequence(self, x_seq: Array) -> Array:
+        """Perform a forward pass across time.
+
+        Args:
+            x_seq: Input sequence of shape (T, H).
+
+        Returns:
+            Hidden state sequence of shape (T, H).
+        """
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        a_mat = self.get_a_mat(x_seq)
+        h_0 = jnp.zeros(a_mat.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x_seq)
+        return h_seq
+
+
+class DiagSSMBlock(BaseSSMBlock):
+    a_diag: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        keys = jax.random.split(key, 2)
+        self.a_diag = glorot(keys[0], (hidden_size,))
+        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
+
+    def get_a_mat(self, x: Array) -> Array:
+        return self.a_diag
+
+    def get_b_mat(self, x: Array) -> Array:
+        return self.b_mat
+
+    def forward(self, h: Array, x: Array) -> Array:
+        """Perform a forward pass.
+
+        Args:
+            h: Hidden state of shape (H,).
+            x: Input of shape (H,).
+
+        Returns:
+            Hidden state of shape (H,).
+        """
+        a_diag = self.get_a_mat(x)
+        b_mat = self.get_b_mat(x)
+        h = a_diag * h + b_mat.T @ x
+        return h
+
+    def forward_sequence(self, x_seq: Array, *, use_conv: bool = True, recursive_kernel_calc: bool = False) -> Array:
+        """Perform a potentially parallelized forward pass across time.
+
+        Args:
+            x_seq: Input sequence of shape (T, H).
+            use_conv: Whether to use convolution to compute the sequence.
+            recursive_kernel_calc: Whether to use a recursive kernel calculation.
+
+        Returns:
+            Hidden state sequence of shape (T, H).
+        """
+        if use_conv:
+            return self._forward_sequence_conv(x_seq, recursive_kernel_calc=recursive_kernel_calc)
+        else:
+            return self._forward_sequence_scan(x_seq)
+
+    def _get_kernel(self, x_seq: Array, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        exponents = jnp.arange(length)
+        a_diag = self.get_a_mat(x_seq)
+        kernel = jnp.power(a_diag[:, None], exponents)  # (H, T)
+        kernel = kernel[:, None, :]  # (H, 1, T)
+        return kernel
+
+    def _get_kernel_recursive(self, x_seq: Array, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        assert length % 2 == 0, "Length must be even."
+        a_diag = self.get_a_mat(x_seq)
+
+        def helper(length: int) -> tuple[Array, Array]:
+            """Returns the kernel and the sqrt of the diagonal."""
+            if length == 1:
+                return jnp.ones_like(a_diag)[:, None], a_diag[:, None]
+
+            half_length = length // 2
+            kernel_half, a_half = helper(half_length)
+            kernel = jnp.concatenate([kernel_half, a_half * kernel_half], axis=-1)
+            return kernel, a_half * a_half
+
+        kernel, a_diag = helper(length)
+        return kernel[:, None, :]  # (H, 1, L)
+
+    def _forward_sequence_conv(self, x_seq: Array, *, recursive_kernel_calc: bool = False) -> Array:
+        """Convolves x (T, H) across time using the kernel."""
+        seq_len, hidden_size = x_seq.shape
+        b_mat = self.get_b_mat(x_seq)
+
+        s = b_mat.T @ x_seq.T  # (H, T)
+        s_padded = jnp.pad(s, ((0, 0), (seq_len - 1, 0)))[None, :, :]  # (1, H, 2T-1)
+
+        if recursive_kernel_calc:
+            kernel = self._get_kernel_recursive(x_seq, seq_len)
+        else:
+            kernel = self._get_kernel(x_seq, seq_len)
+
+        kernel_flipped = jnp.flip(kernel, axis=-1)  # (H, 1, L)
+
+        conv_out = jax.lax.conv_general_dilated(
+            s_padded,
+            kernel_flipped,
+            window_strides=(1,),
+            padding="VALID",
+            dimension_numbers=("NCT", "OIT", "NCT"),  # convolving over time
+            feature_group_count=hidden_size,
+        )
+        conv_out = conv_out[0].T  # (T, H)
+        return conv_out
+
+    def _forward_sequence_scan(self, x_seq: Array) -> Array:
+        """Naively forward across time."""
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        a_diag = self.get_a_mat(x_seq)
+        h_0 = jnp.zeros(a_diag.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x_seq)
+        return h_seq
+
+
+class DiscreteDiagSSMBlock(DiagSSMBlock):
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        *,
+        key: PRNGKeyArray,
+        init_delta: float = 1.0,
+        init_scale: float = 10.0,
+    ) -> None:
+        super().__init__(hidden_size, key=key)
+        self.delta = jnp.array(init_delta)
+
+        # A positive scale helps reduce the gradient at the start.
+        self.a_diag = jax.random.uniform(key, (hidden_size,), minval=-1.0, maxval=0.0) * init_scale
+
+    def get_a_mat(self, x: Array) -> Array:
+        """Discretize the diagonal matrix using zero-order hold."""
+        a_diag_discrete = jnp.exp(self.a_diag * self.delta)
+        return a_diag_discrete
+
+    def get_b_mat(self, x: Array) -> Array:
+        """Discretize the input matrix using zero-order hold."""
+        delta_a_diag = self.a_diag * self.delta
+        exp_a_diag = jnp.exp(delta_a_diag)
+        delta_a_inv = 1 / delta_a_diag
+        delta_b_mat = self.delta * self.b_mat
+
+        b_discrete = delta_a_inv * (exp_a_diag - 1) * delta_b_mat
+        return b_discrete
+
+
+class SSM(eqx.Module):
+    vocab_embedding: eqx.nn.Embedding
+    output_layer: eqx.nn.Linear
+    blocks: list[BaseSSMBlock]
+    num_layers: int = eqx.static_field()
+    hidden_size: int = eqx.static_field()
+    skip_connections: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        num_layers: int,
+        block_type: Literal["diagonal", "full_rank"] = "full_rank",
+        skip_connections: bool = False,
+        discretize: bool = False,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        vocab_key, s4_key = jax.random.split(key, 2)
+        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
+        self.output_layer = eqx.nn.Linear(hidden_size, output_size, key=key)
+
+        block_keys = jax.random.split(s4_key, num_layers)
+
+        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
+            match block_type:
+                case "diagonal":
+                    return (
+                        DiscreteDiagSSMBlock(hidden_size, key=key, init_delta=0.1)
+                        if discretize
+                        else DiagSSMBlock(hidden_size, key=key)
+                    )
+                case "full_rank":
+                    if discretize:
+                        raise ValueError("Full rank blocks do not support discretization due to instability.")
+                    return SSMBlock(hidden_size, key=key)
+                case _:
+                    raise ValueError(f"Unknown block type: {block_type}")
+
+        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
+        self.skip_connections = skip_connections
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+
+    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+        new_hs = []
+        for i, block in enumerate(self.blocks):
+            h = block.forward(hs[i], x)
+            new_hs.append(h)
+            xh = jax.nn.gelu(h)
+            x = xh + x if self.skip_connections else xh
+        y = self.output_layer(x)
+        return new_hs, y
+
+    def _embed_input(self, x: Array) -> Array:
+        """U is the input to the S4 cell."""
+        return self.vocab_embedding(x)
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_emb = jax.vmap(self._embed_input)(x_seq)
+        for block in self.blocks:
+            h = block.forward_sequence(x_emb)
+            # h = block.naive_forward_sequence(x_emb)
+            h = jax.nn.gelu(h)
+            x_emb = h + x_emb if self.skip_connections else h
+        y = jax.vmap(self.output_layer)(x_emb)
+        return y
+
+    def generate_sequence(self, prompt_seq: Array, max_len: int) -> Array:
+        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
+        prompt_seq_embedded = jax.vmap(self._embed_input)(prompt_seq)
+
+        def encode_step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+            hs, y = self(hs, x)
+            return hs, y
+
+        def decode_step(
+            carry: tuple[list[Array], Array, PRNGKeyArray],
+            _: None,
+        ) -> tuple[tuple[list[Array], Array, PRNGKeyArray], Array]:
+            hs, last_token, rng = carry
+            token_embedded = self._embed_input(last_token)
+            hs, y = self(hs, token_embedded)
+            token = jax.random.categorical(rng, y)
+            rng = jax.random.split(rng)[0]
+            return (hs, token, rng), token
+
+        hs, _ = jax.lax.scan(encode_step, hs, prompt_seq_embedded)
+        _, sequence = jax.lax.scan(decode_step, (hs, prompt_seq[-1], jax.random.PRNGKey(0)), None, length=max_len)
+
+        return sequence
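The new module implements a recurrent state-space stack: `SSMBlock` holds dense state and input matrices, `DiagSSMBlock` restricts the state matrix to a diagonal so the whole sequence can be computed as a depthwise convolution against the kernel (1, a, a², ..., a^(T−1)), and `DiscreteDiagSSMBlock` derives its matrices from a continuous-time parameterization via zero-order hold (ā = exp(aΔ), b̄ = (aΔ)⁻¹(exp(aΔ) − 1)Δb, matching `get_a_mat` and `get_b_mat` above). A sketch of driving the model end to end, with illustrative sizes and tokens:

```python
# Sketch: constructing the new SSM over integer tokens and running both the
# parallel (convolutional) path and autoregressive generation.
import jax
import jax.numpy as jnp
from xax.nn.ssm import SSM

model = SSM(
    input_size=16,    # vocabulary size for the embedding
    hidden_size=32,
    output_size=16,   # logits over the vocabulary
    num_layers=2,
    block_type="diagonal",
    discretize=True,  # selects DiscreteDiagSSMBlock (zero-order hold)
    key=jax.random.PRNGKey(0),
)

tokens = jnp.array([1, 2, 3, 4])                        # (T,)
logits = model.predict_sequence(tokens)                 # (T, 16), conv across time
generated = model.generate_sequence(tokens, max_len=8)  # (8,) sampled tokens
```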
xax/task/mixins/train.py CHANGED
@@ -29,6 +29,7 @@ from typing import (
 
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
@@ -162,6 +163,7 @@ class TrainConfig(
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
     random_seed: int = field(1337, help="Random seed for the task")
+    global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
 
 
 Config = TypeVar("Config", bound=TrainConfig)
@@ -216,26 +218,32 @@ class TrainMixin(
         state = super().on_step_end(state)
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
-    def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_train_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
         backward pass. It is called in the training phase.
 
         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
             state: The current training state.
         """
 
-    def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_valid_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
         the validation phase.
 
         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
@@ -249,7 +257,9 @@ class TrainMixin(
         for k, v in d.items():
             self.logger.log_scalar(k, v, namespace=ns)
 
-    def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         phase = state.phase
 
         for k, v in metrics.items():
@@ -263,9 +273,9 @@ class TrainMixin(
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(batch, output, metrics, state)
+                self.log_train_step(model, batch, output, metrics, state)
             case "valid":
-                self.log_valid_step(batch, output, metrics, state)
+                self.log_valid_step(model, batch, output, metrics, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
@@ -403,12 +413,12 @@ class TrainMixin(
         model_static: PyTree,
         batch: Batch,
         state: State,
-    ) -> tuple[Array, tuple[Output, FrozenDict[str, Array]]]:
+    ) -> tuple[Array, tuple[Output, dict[str, Array]]]:
         model = eqx.combine(model_arr, model_static)
         output = self.get_output(model, batch, state)
         loss = self.compute_loss(model, batch, output, state)
         metrics = self.compute_metrics(model, batch, output, loss, state)
-        return loss, (output, FrozenDict(metrics))
+        return loss, (output, metrics)
 
     def update(
         self,
@@ -418,13 +428,44 @@ class TrainMixin(
         opt_state: optax.OptState,
         batch: Batch,
         state: State,
-    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+    ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
         grad_fn = xax_jit(static_argnums=[1])(grad_fn)
         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
-        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
-        model_arr = eqx.apply_updates(model_arr, updates)
-        return model_arr, opt_state, output, metrics
+        model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
+        return model_arr, opt_state, output, metrics | grad_metrics
+
+    @xax_jit(static_argnames=["self", "optimizer"])
+    def apply_gradients_with_clipping(
+        self,
+        model_arr: PyTree,
+        grads: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
+        grad_norm = optax.global_norm(grads)
+        grad_metrics = {"grad_norm": grad_norm}
+
+        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
+            # Clip the global gradient norm to some desired range.
+            grad_factor = self.config.global_grad_clip / jnp.maximum(grad_norm, 1e-6)
+            grads = jax.tree.map(lambda x: x * grad_factor, grads)
+
+            # Apply the gradient updates.
+            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
+            new_model_arr = eqx.apply_updates(model_arr, updates)
+            return new_model_arr, new_opt_state
+
+        # Don't apply updates if the gradient is NaN or Inf.
+        new_model_arr, new_opt_state = jax.lax.cond(
+            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
+            lambda *_: (model_arr, opt_state),
+            apply,
+            grads,
+            grad_norm,
+        )
+
+        return new_model_arr, new_opt_state, grad_metrics
 
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
@@ -512,7 +553,7 @@ class TrainMixin(
         state: State,
     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
         model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
-        return model_arr, opt_state, output, metrics
+        return model_arr, opt_state, output, FrozenDict(metrics)
 
     @xax_jit(static_argnames=["self", "model_static"])
     def val_step(
@@ -523,7 +564,7 @@ class TrainMixin(
         state: State,
     ) -> tuple[Output, FrozenDict[str, Array]]:
         _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
-        return output, metrics
+        return output, FrozenDict(metrics)
 
     def train_loop(
         self,
@@ -546,7 +587,7 @@ class TrainMixin(
         )
 
         output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
-        self.log_step(valid_batch, output, metrics, state)
+        self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
 
         state = self.on_step_start(state)
         train_batch = next(train_pf)
@@ -564,7 +605,7 @@ class TrainMixin(
             batch=train_batch,
             state=state,
         )
-        self.log_step(train_batch, output, metrics, state)
+        self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
 
         state = self.on_step_end(state)
xax-0.1.12.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.10
+Version: 0.1.12
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
xax-0.1.12.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=bvOBMlEVA46I7ILGfk5AbpwpcdTAjw-4vWI7ci7L7-g,13392
+xax/__init__.py,sha256=7vdTYO7jAJdDxKZURlFxc3Y5kr5mVQcTQjeh_sYjD6I,13834
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -10,9 +10,11 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
-xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
-xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
+xax/nn/geom.py,sha256=Bj9Z4Y-uoNQuaA_eB_MyG7yImZLuOq8KCLUj1l3daoc,4545
+xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
+xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
+xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
@@ -39,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=jAzc9RD25DbhekvItzsRQQrK9aEwtA_sXy0m2Hfkuxo,24594
+xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -56,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.1.10.dist-info/METADATA,sha256=kJ1lxZ6cWrtJ5R-adTorzEE_1l0VRJ67xfuBjYXG9Vo,1878
-xax-0.1.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-xax-0.1.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.1.10.dist-info/RECORD,,
+xax-0.1.12.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.12.dist-info/METADATA,sha256=hLRAX5__7QjBgjzhxbRftGvEsNrt8IAdgd22dMtHu_Y,1878
+xax-0.1.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.12.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.12.dist-info/RECORD,,
The remaining files are unchanged.