xax 0.3.4__tar.gz → 0.3.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {xax-0.3.4/xax.egg-info → xax-0.3.9}/PKG-INFO +1 -1
  2. {xax-0.3.4 → xax-0.3.9}/xax/__init__.py +13 -1
  3. {xax-0.3.4 → xax-0.3.9}/xax/nn/attention.py +144 -92
  4. xax-0.3.9/xax/nn/distributions.py +181 -0
  5. {xax-0.3.4 → xax-0.3.9}/xax/nn/embeddings.py +10 -10
  6. {xax-0.3.4 → xax-0.3.9}/xax/nn/geom.py +5 -5
  7. {xax-0.3.4 → xax-0.3.9}/xax/nn/ssm.py +6 -6
  8. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/data_loader.py +7 -2
  9. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/train.py +51 -58
  10. {xax-0.3.4 → xax-0.3.9}/xax/utils/pytree.py +13 -0
  11. {xax-0.3.4 → xax-0.3.9/xax.egg-info}/PKG-INFO +1 -1
  12. {xax-0.3.4 → xax-0.3.9}/xax.egg-info/SOURCES.txt +1 -0
  13. {xax-0.3.4 → xax-0.3.9}/LICENSE +0 -0
  14. {xax-0.3.4 → xax-0.3.9}/MANIFEST.in +0 -0
  15. {xax-0.3.4 → xax-0.3.9}/README.md +0 -0
  16. {xax-0.3.4 → xax-0.3.9}/pyproject.toml +0 -0
  17. {xax-0.3.4 → xax-0.3.9}/setup.cfg +0 -0
  18. {xax-0.3.4 → xax-0.3.9}/setup.py +0 -0
  19. {xax-0.3.4 → xax-0.3.9}/xax/cli/__init__.py +0 -0
  20. {xax-0.3.4 → xax-0.3.9}/xax/cli/edit_config.py +0 -0
  21. {xax-0.3.4 → xax-0.3.9}/xax/core/__init__.py +0 -0
  22. {xax-0.3.4 → xax-0.3.9}/xax/core/conf.py +0 -0
  23. {xax-0.3.4 → xax-0.3.9}/xax/core/state.py +0 -0
  24. {xax-0.3.4 → xax-0.3.9}/xax/nn/__init__.py +0 -0
  25. {xax-0.3.4 → xax-0.3.9}/xax/nn/functions.py +0 -0
  26. {xax-0.3.4 → xax-0.3.9}/xax/nn/losses.py +0 -0
  27. {xax-0.3.4 → xax-0.3.9}/xax/nn/metrics.py +0 -0
  28. {xax-0.3.4 → xax-0.3.9}/xax/nn/parallel.py +0 -0
  29. {xax-0.3.4 → xax-0.3.9}/xax/py.typed +0 -0
  30. {xax-0.3.4 → xax-0.3.9}/xax/requirements-dev.txt +0 -0
  31. {xax-0.3.4 → xax-0.3.9}/xax/requirements.txt +0 -0
  32. {xax-0.3.4 → xax-0.3.9}/xax/task/__init__.py +0 -0
  33. {xax-0.3.4 → xax-0.3.9}/xax/task/base.py +0 -0
  34. {xax-0.3.4 → xax-0.3.9}/xax/task/launchers/__init__.py +0 -0
  35. {xax-0.3.4 → xax-0.3.9}/xax/task/launchers/base.py +0 -0
  36. {xax-0.3.4 → xax-0.3.9}/xax/task/launchers/cli.py +0 -0
  37. {xax-0.3.4 → xax-0.3.9}/xax/task/launchers/single_process.py +0 -0
  38. {xax-0.3.4 → xax-0.3.9}/xax/task/logger.py +0 -0
  39. {xax-0.3.4 → xax-0.3.9}/xax/task/loggers/__init__.py +0 -0
  40. {xax-0.3.4 → xax-0.3.9}/xax/task/loggers/callback.py +0 -0
  41. {xax-0.3.4 → xax-0.3.9}/xax/task/loggers/json.py +0 -0
  42. {xax-0.3.4 → xax-0.3.9}/xax/task/loggers/state.py +0 -0
  43. {xax-0.3.4 → xax-0.3.9}/xax/task/loggers/stdout.py +0 -0
  44. {xax-0.3.4 → xax-0.3.9}/xax/task/loggers/tensorboard.py +0 -0
  45. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/__init__.py +0 -0
  46. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/artifacts.py +0 -0
  47. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/checkpointing.py +0 -0
  48. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/compile.py +0 -0
  49. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/cpu_stats.py +0 -0
  50. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/gpu_stats.py +0 -0
  51. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/logger.py +0 -0
  52. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/process.py +0 -0
  53. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/runnable.py +0 -0
  54. {xax-0.3.4 → xax-0.3.9}/xax/task/mixins/step_wrapper.py +0 -0
  55. {xax-0.3.4 → xax-0.3.9}/xax/task/script.py +0 -0
  56. {xax-0.3.4 → xax-0.3.9}/xax/task/task.py +0 -0
  57. {xax-0.3.4 → xax-0.3.9}/xax/utils/__init__.py +0 -0
  58. {xax-0.3.4 → xax-0.3.9}/xax/utils/data/__init__.py +0 -0
  59. {xax-0.3.4 → xax-0.3.9}/xax/utils/data/collate.py +0 -0
  60. {xax-0.3.4 → xax-0.3.9}/xax/utils/debugging.py +0 -0
  61. {xax-0.3.4 → xax-0.3.9}/xax/utils/experiments.py +0 -0
  62. {xax-0.3.4 → xax-0.3.9}/xax/utils/jax.py +0 -0
  63. {xax-0.3.4 → xax-0.3.9}/xax/utils/jaxpr.py +0 -0
  64. {xax-0.3.4 → xax-0.3.9}/xax/utils/logging.py +0 -0
  65. {xax-0.3.4 → xax-0.3.9}/xax/utils/numpy.py +0 -0
  66. {xax-0.3.4 → xax-0.3.9}/xax/utils/profile.py +0 -0
  67. {xax-0.3.4 → xax-0.3.9}/xax/utils/tensorboard.py +0 -0
  68. {xax-0.3.4 → xax-0.3.9}/xax/utils/text.py +0 -0
  69. {xax-0.3.4 → xax-0.3.9}/xax/utils/types/__init__.py +0 -0
  70. {xax-0.3.4 → xax-0.3.9}/xax/utils/types/frozen_dict.py +0 -0
  71. {xax-0.3.4 → xax-0.3.9}/xax/utils/types/hashable_array.py +0 -0
  72. {xax-0.3.4 → xax-0.3.9}/xax.egg-info/dependency_links.txt +0 -0
  73. {xax-0.3.4 → xax-0.3.9}/xax.egg-info/entry_points.txt +0 -0
  74. {xax-0.3.4 → xax-0.3.9}/xax.egg-info/requires.txt +0 -0
  75. {xax-0.3.4 → xax-0.3.9}/xax.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.3.4
+Version: 0.3.9
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
xax/__init__.py
@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.4"
+__version__ = "0.3.9"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -31,6 +31,10 @@ __all__ = [
     "TransformerBlock",
     "TransformerCache",
     "TransformerStack",
+    "Categorical",
+    "Distribution",
+    "MixtureOfGaussians",
+    "Normal",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -136,6 +140,7 @@ __all__ = [
     "compute_nan_ratio",
     "flatten_array",
     "flatten_pytree",
+    "get_pytree_mapping",
     "get_pytree_param_count",
     "pytree_has_nans",
     "reshuffle_pytree",
@@ -218,6 +223,10 @@ NAME_MAP: dict[str, str] = {
     "TransformerBlock": "nn.attention",
     "TransformerCache": "nn.attention",
     "TransformerStack": "nn.attention",
+    "Categorical": "nn.distributions",
+    "Distribution": "nn.distributions",
+    "MixtureOfGaussians": "nn.distributions",
+    "Normal": "nn.distributions",
     "FourierEmbeddings": "nn.embeddings",
     "IdentityPositionalEmbeddings": "nn.embeddings",
     "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -323,6 +332,7 @@ NAME_MAP: dict[str, str] = {
     "compute_nan_ratio": "utils.pytree",
     "flatten_array": "utils.pytree",
     "flatten_pytree": "utils.pytree",
+    "get_pytree_mapping": "utils.pytree",
     "get_pytree_param_count": "utils.pytree",
     "pytree_has_nans": "utils.pytree",
     "reshuffle_pytree": "utils.pytree",
@@ -403,6 +413,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         TransformerCache,
         TransformerStack,
     )
+    from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
     from xax.nn.embeddings import (
         EmbeddingKind,
         FourierEmbeddings,
@@ -509,6 +520,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         compute_nan_ratio,
         flatten_array,
         flatten_pytree,
+        get_pytree_mapping,
         get_pytree_param_count,
         pytree_has_nans,
         reshuffle_pytree,
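
The new distribution classes and the pytree helper are registered in both __all__ and NAME_MAP, so they resolve from the top-level package as well as from their defining modules. A minimal import sketch, assuming only what this diff shows (the constructor signatures of the new classes are not visible here):

    # Illustrative only: names come from the __all__ / NAME_MAP entries above.
    import xax

    print(xax.NAME_MAP["Normal"])              # "nn.distributions"
    print(xax.NAME_MAP["get_pytree_mapping"])  # "utils.pytree"

    # Equivalent direct imports:
    from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
    from xax.utils.pytree import get_pytree_mapping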
xax/nn/attention.py
@@ -5,6 +5,8 @@ supporting a fixed-size context window and caching that can be used to train
 transformers which can be unrolled with a fixed-length cache.
 """
 
+import math
+import warnings
 from typing import NotRequired, TypedDict
 
 import chex
@@ -13,6 +15,8 @@ import jax
 import jax.numpy as jnp
 from jaxtyping import Array, PRNGKeyArray
 
+from xax.utils.jax import scan as xax_scan
+
 
 class RotaryEmbedding(eqx.Module):
     """Rotary Position Embedding (RoPE) for transformer attention.
@@ -22,8 +26,8 @@ class RotaryEmbedding(eqx.Module):
     https://arxiv.org/abs/2104.09864
     """
 
-    head_dim: int = eqx.static_field()
-    base: float = eqx.static_field()
+    head_dim: int = eqx.field()
+    base: float = eqx.field()
 
     def __init__(
         self,
@@ -125,15 +129,15 @@ class TransformerCache(TypedDict):
 class SelfAttentionBlock(eqx.Module):
     """Self-attention block using jax.nn.dot_product_attention."""
 
-    q_proj: eqx.nn.Linear
-    k_proj: eqx.nn.Linear
-    v_proj: eqx.nn.Linear
-    output_proj: eqx.nn.Linear
-    rotary_emb: RotaryEmbedding | None
-    num_heads: int = eqx.static_field()
-    head_dim: int = eqx.static_field()
-    causal: bool = eqx.static_field()
-    context_length: int | None = eqx.static_field()
+    q_proj: eqx.nn.Linear = eqx.field()
+    k_proj: eqx.nn.Linear = eqx.field()
+    v_proj: eqx.nn.Linear = eqx.field()
+    output_proj: eqx.nn.Linear = eqx.field()
+    rotary_emb: RotaryEmbedding | None = eqx.field()
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    local_window_size: int | None = eqx.field()
 
     def __init__(
         self,
@@ -169,8 +173,12 @@ class SelfAttentionBlock(eqx.Module):
         else:
             self.rotary_emb = None
 
+        if context_length is not None and not causal:
+            warnings.warn("context_length is set but causal is False; overriding causal to True", stacklevel=2)
+            causal = True
+
         self.causal = causal
-        self.context_length = context_length
+        self.local_window_size = None if context_length is None else context_length - 1
 
     @property
     def embed_dim(self) -> int:
@@ -195,28 +203,44 @@ class SelfAttentionBlock(eqx.Module):
         Returns:
             Cache with fixed-length k and v tensors
         """
-        if self.context_length is None:
+        if self.local_window_size is None:
             raise ValueError("context_length must be set for caching")
 
         # Create fixed-length cache
-        k_cache = jnp.zeros((self.context_length - 1, self.num_heads, self.head_dim), dtype=dtype)
-        v_cache = jnp.zeros((self.context_length - 1, self.num_heads, self.head_dim), dtype=dtype)
+        k_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
+        v_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
 
         return {"k": k_cache, "v": v_cache, "position": 0}
 
-    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
-        in_dim, out_dim = seq_len, seq_len
-        if with_cache:
-            if self.context_length is None:
-                raise ValueError("context_length must be set for caching")
-            in_dim = in_dim + self.context_length - 1
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        """Initialize the attention matrix mask.
 
-        mask = jnp.tril(jnp.ones((in_dim, out_dim)))
-        if self.context_length is not None:
-            neg_mask = 1 - jnp.tril(jnp.ones((in_dim, out_dim)), -self.context_length)
-            mask = mask * neg_mask
+        Args:
+            seq_len: The length of the sequence
+            add_cache: Whether to add the cache to the mask
+            batch_dim: Whether to add a batch dimension to the mask
 
-        return mask.astype(jnp.bool_).transpose()
+        Returns:
+            The attention matrix mask of shape (bsz, 1, seq_len, seq_len + cache_len)
+            if batch_dim is True, otherwise (seq_len, seq_len + cache_len).
+        """
+        t, s, o = seq_len, seq_len, 0
+        if add_cache:
+            if self.local_window_size is None:
+                raise ValueError("local_window_size must be set for caching")
+            s += self.local_window_size
+            o -= self.local_window_size
+        mask = jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-o)
+        if self.local_window_size is not None:
+            neg_mask = ~jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-(self.local_window_size + 1 + o))
+            mask = mask & neg_mask
+        mask = mask.reshape(1, 1, t, s) if batch_dim else mask.reshape(t, s)
+        return mask
 
     def forward(
         self,
@@ -229,7 +253,8 @@ class SelfAttentionBlock(eqx.Module):
 
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
-            mask: Optional mask tensor
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: The cached key and value tensors (fixed-length)
 
         Returns:
@@ -263,25 +288,36 @@ class SelfAttentionBlock(eqx.Module):
             v_cache = cache["v"]
             k = jnp.concatenate([k_cache, k], axis=0)
             v = jnp.concatenate([v_cache, v], axis=0)
+
             new_position = cache["position"] + seq_len
 
         else:
             new_position = seq_len
 
-        attn_output = jax.nn.dot_product_attention(
-            q,
-            k,
-            v,
-            mask=mask,
-            is_causal=self.causal and mask is None,
-        )
+        if seq_len == 1:
+            attn_output = jax.nn.dot_product_attention(q, k, v)
+
+        elif mask is not None:
+            attn_output = jax.nn.dot_product_attention(q, k, v, mask=mask)
+
+        elif cache is not None:
+            raise NotImplementedError("For training with a cache, provide a mask instead.")
+
+        else:
+            attn_output = jax.nn.dot_product_attention(
+                q,
+                k,
+                v,
+                is_causal=self.causal,
+                local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
+            )
 
         attn_output = self._combine_heads(attn_output)
         output = jax.vmap(self.output_proj)(attn_output)
 
-        if self.context_length is not None:
-            k = k[-(self.context_length - 1) :]
-            v = v[-(self.context_length - 1) :]
+        if self.local_window_size is not None:
+            k = k[-self.local_window_size :]
+            v = v[-self.local_window_size :]
 
         return output, {"k": k, "v": v, "position": new_position}
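
The maskless branch above relies on the sliding-window support built into jax.nn.dot_product_attention. A small sketch with illustrative shapes, under the assumption (consistent with how the diff uses it) that local_window_size=(left, right) lets each query see at most left keys behind it and right keys ahead:

    import jax
    import jax.numpy as jnp

    T, H, D, window = 6, 2, 8, 3    # window == context_length - 1
    q = k = v = jax.random.normal(jax.random.PRNGKey(0), (1, T, H, D))

    # Explicit banded mask, broadcastable to (batch, heads, q_len, kv_len).
    i, j = jnp.arange(T)[:, None], jnp.arange(T)[None, :]
    banded = (j <= i) & (j >= i - window)
    out_masked = jax.nn.dot_product_attention(q, k, v, mask=banded[None, None])

    # Same pattern expressed via is_causal + local_window_size, as in the new branch.
    out_local = jax.nn.dot_product_attention(
        q, k, v, is_causal=True, local_window_size=(window, 0)
    )
    print(jnp.allclose(out_masked, out_local, atol=1e-5))  # expected: True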
 
@@ -294,8 +330,8 @@ class CrossAttentionBlock(eqx.Module):
     v_proj: eqx.nn.Linear
     output_proj: eqx.nn.Linear
     rotary_emb: RotaryEmbedding | None
-    num_heads: int = eqx.static_field()
-    head_dim: int = eqx.static_field()
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
 
     def __init__(
         self,
@@ -352,7 +388,6 @@ class CrossAttentionBlock(eqx.Module):
         *,
         kv_sn: Array | None = None,
         cache: AttentionCache | None = None,
-        mask: Array | None = None,
     ) -> tuple[Array, AttentionCache]:
         """Apply cross-attention.
 
@@ -362,7 +397,6 @@ class CrossAttentionBlock(eqx.Module):
                 If not provided, then `cache` must be provided.
             cache: The cached key and value tensors. If not provided, then
                 `kv_sn` must be provided.
-            mask: Optional mask tensor
 
         Returns:
             The output tensor of shape (q_seq_len, embed_dim)
@@ -404,7 +438,7 @@ class CrossAttentionBlock(eqx.Module):
             q_rot,
             k_rot,
             v,
-            mask=mask,
+            scale=1.0 / math.sqrt(self.head_dim),
             is_causal=False,
         )
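
With the cross-attention mask argument gone, the only explicit tuning left on this call is the scale. jax.nn.dot_product_attention already defaults to 1/sqrt(head_dim) when scale is None, so passing it explicitly should match the default numerically; a quick check with illustrative shapes:

    import math

    import jax
    import jax.numpy as jnp

    S, H, D = 5, 2, 16
    q = k = v = jax.random.normal(jax.random.PRNGKey(1), (1, S, H, D))

    out_default = jax.nn.dot_product_attention(q, k, v, is_causal=False)
    out_explicit = jax.nn.dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D), is_causal=False)
    print(jnp.allclose(out_default, out_explicit))  # expected: True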
 
@@ -424,10 +458,10 @@ class TransformerBlock(eqx.Module):
     layer_norm1: eqx.nn.LayerNorm
     layer_norm2: eqx.nn.LayerNorm
     layer_norm3: eqx.nn.LayerNorm | None
-    num_heads: int = eqx.static_field()
-    head_dim: int = eqx.static_field()
-    causal: bool = eqx.static_field()
-    context_length: int | None = eqx.static_field()
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    context_length: int | None = eqx.field()
 
     def __init__(
         self,
@@ -500,16 +534,24 @@ class TransformerBlock(eqx.Module):
             cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
         return cache
 
-    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
-        return self.self_attn.init_mask(seq_len, with_cache=with_cache)
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.self_attn.init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
 
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
-        self_mask: Array | None = None,
-        cross_mask: Array | None = None,
+        mask: Array | None = None,
         cache: AttentionCacheDict | None = None,
     ) -> tuple[Array, AttentionCacheDict]:
         """Apply transformer block.
@@ -517,8 +559,8 @@ class TransformerBlock(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
-            self_mask: Mask for self-attention
-            cross_mask: Mask for cross-attention
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -531,7 +573,7 @@ class TransformerBlock(eqx.Module):
 
         attn_output, self_attn_cache = self.self_attn.forward(
             x_tn=norm_x,
-            mask=self_mask,
+            mask=mask,
             cache=None if cache is None else cache["self_attn"],
         )
         updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
@@ -547,7 +589,6 @@ class TransformerBlock(eqx.Module):
             cross_attn_output, updated_cache["cross_attn"] = self.cross_attn.forward(
                 q_tn=norm_x,
                 kv_sn=context_sn,
-                mask=cross_mask,
                 cache=None if cache is None else cache.get("cross_attn"),
             )
 
@@ -564,9 +605,9 @@ class TransformerBlock(eqx.Module):
 class TransformerStack(eqx.Module):
     """A stack of transformer blocks."""
 
-    layers: list[TransformerBlock]
-    num_layers: int = eqx.static_field()
-    causal: bool = eqx.static_field()
+    layers: tuple[TransformerBlock, ...]
+    num_layers: int = eqx.field()
+    causal: bool = eqx.field()
 
     def __init__(
         self,
@@ -584,7 +625,7 @@ class TransformerStack(eqx.Module):
     ) -> None:
         keys = jax.random.split(key, num_layers)
 
-        self.layers = [
+        self.layers = tuple(
             TransformerBlock(
                 embed_dim=embed_dim,
                 num_heads=num_heads,
@@ -597,7 +638,7 @@ class TransformerStack(eqx.Module):
                 rotary_base=rotary_base,
             )
             for i in range(num_layers)
-        ]
+        )
 
         self.num_layers = num_layers
         self.causal = causal
@@ -609,16 +650,24 @@ class TransformerStack(eqx.Module):
             cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
         return {"layers": cache}
 
-    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
-        return self.layers[0].init_mask(seq_len, with_cache=with_cache)
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.layers[0].init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
 
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
-        self_mask: Array | None = None,
-        cross_mask: Array | None = None,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Apply transformer stack.
@@ -626,8 +675,8 @@ class TransformerStack(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
-            self_mask: Mask for self-attention
-            cross_mask: Mask for cross-attention
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -647,8 +696,7 @@ class TransformerStack(eqx.Module):
             x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
                 x_tn,
                 context_sn=context_sn,
-                self_mask=self_mask,
-                cross_mask=cross_mask,
+                mask=mask,
                 cache=layer_cache,
             )
 
@@ -660,9 +708,9 @@ class Transformer(eqx.Module):
     layers: TransformerStack
     output_layer: eqx.nn.Linear | None
     layer_norm: eqx.nn.LayerNorm
-    embed_dim: int = eqx.static_field()
-    causal: bool = eqx.static_field()
-    context_length: int | None = eqx.static_field()
+    embed_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    context_length: int | None = eqx.field()
 
     def __init__(
         self,
@@ -713,8 +761,17 @@ class Transformer(eqx.Module):
         """Initialize cache for the input."""
         return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
 
-    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
-        return self.layers.init_mask(seq_len, with_cache=with_cache)
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.layers.init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
 
     def encode(
         self,
@@ -727,7 +784,8 @@ class Transformer(eqx.Module):
 
         Args:
             x: Input token indices of shape (seq_len)
-            mask: Optional attention mask
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -737,11 +795,7 @@ class Transformer(eqx.Module):
         x_embedded = jax.vmap(self.token_embedding)(x)
 
         # Apply transformer stack
-        x_embedded, updated_cache = self.layers.forward(
-            x_embedded,
-            self_mask=mask,
-            cache=cache,
-        )
+        x_embedded, updated_cache = self.layers.forward(x_embedded, mask=mask, cache=cache)
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
@@ -753,8 +807,7 @@ class Transformer(eqx.Module):
         x_t: Array,
         context_s: Array,
         *,
-        self_mask: Array | None = None,
-        cross_mask: Array | None = None,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Decode with self-attention and cross-attention.
@@ -763,8 +816,8 @@ class Transformer(eqx.Module):
             x_t: Input token indices, shape (seq_len)
             context_s: Context from encoder (token indices or embedded),
                 shape (context_len, embed_dim)
-            self_mask: Optional self-attention mask, shape (seq_len, seq_len)
-            cross_mask: Optional cross-attention mask, shape (seq_len, context_len)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -780,8 +833,7 @@ class Transformer(eqx.Module):
         x_embedded, updated_cache = self.layers.forward(
             x_embedded,
             context_sn=context_embedded,
-            self_mask=self_mask,
-            cross_mask=cross_mask,
+            mask=mask,
             cache=cache,
         )
 
@@ -801,7 +853,8 @@ class Transformer(eqx.Module):
 
         Args:
             x: Input token indices of shape (seq_len)
-            mask: Optional attention mask
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -809,11 +862,7 @@ class Transformer(eqx.Module):
         """
         chex.assert_rank(x, 1)
 
-        output, updated_cache = self.encode(
-            x,
-            mask=mask,
-            cache=cache,
-        )
+        output, updated_cache = self.encode(x, mask=mask, cache=cache)
 
         # Apply output layer if it exists
         if self.output_layer is not None:
@@ -832,6 +881,7 @@ class Transformer(eqx.Module):
         temperature: float = 1.0,
         top_k: int | None = None,
         key: PRNGKeyArray | None = None,
+        jit_level: int | None = None,
     ) -> Array:
         """Generate a sequence autoregressively with KV caching.
 
@@ -841,6 +891,7 @@ class Transformer(eqx.Module):
             temperature: Sampling temperature
             top_k: Optional top-k sampling parameter
             key: PRNG key for sampling
+            jit_level: JIT level for the scan function
 
         Returns:
             Generated sequence of shape (prompt_len + max_len,)
@@ -856,7 +907,8 @@ class Transformer(eqx.Module):
 
         # Initialize cache with prompt
         cache = self.init_cache()
-        _, cache = self.encode(prompt_seq, cache=cache)
+        mask = self.init_mask(prompt_len, add_cache=True, batch_dim=False)
+        _, cache = self.encode(prompt_seq, cache=cache, mask=mask)
 
         # Define scan function for autoregressive generation
         def scan_fn(
@@ -884,5 +936,5 @@ class Transformer(eqx.Module):
             return (new_output_seq, pos + 1, new_cache, rng), next_token
 
         init_carry = (output_seq, prompt_len - 1, cache, key)
-        (final_seq, _, _, _), _ = jax.lax.scan(scan_fn, init_carry, length=max_len)
+        (final_seq, _, _, _), _ = xax_scan(scan_fn, init_carry, length=max_len, jit_level=jit_level)
         return final_seq
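
For reference, here is a hedged sketch of the prompt-priming step from the caller's side, written as a standalone helper. The model argument is assumed to be an already constructed causal xax.Transformer with context_length set, and only calls visible in this diff (init_cache, init_mask, encode) are used:

    from jaxtyping import Array


    def prime_cache(model, prompt: Array) -> tuple[Array, dict]:
        """Encode a 1-D prompt under the cache-aware banded mask, as the
        updated generate() now does before its autoregressive scan."""
        cache = model.init_cache()
        mask = model.init_mask(prompt.shape[0], add_cache=True, batch_dim=False)
        encoded, cache = model.encode(prompt, cache=cache, mask=mask)
        return encoded, cache

generate() performs this priming internally and then runs its decoding loop through xax.utils.jax.scan, forwarding the new jit_level argument.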