xax 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/nn/attention.py CHANGED
@@ -1,6 +1,11 @@
- """Attention mechanisms for transformer models."""
+ """Attention mechanisms for transformer models.
 
- from typing import Literal, cast, overload
+ This module implements standard attention mechanisms for transformers, but
+ supporting a fixed-size context window and caching that can be used to train
+ transformers which can be unrolled with a fixed-length cache.
+ """
+
+ from typing import NotRequired, TypedDict
 
  import chex
  import equinox as eqx
@@ -9,6 +14,114 @@ import jax.numpy as jnp
  from jaxtyping import Array, PRNGKeyArray
 
 
+ class RotaryEmbedding(eqx.Module):
+     """Rotary Position Embedding (RoPE) for transformer attention.
+
+     This implements the rotary position embedding as described in:
+     "RoFormer: Enhanced Transformer with Rotary Position Embedding"
+     https://arxiv.org/abs/2104.09864
+     """
+
+     head_dim: int = eqx.static_field()
+     base: float = eqx.static_field()
+
+     def __init__(
+         self,
+         head_dim: int,
+         base: float = 10000.0,
+     ) -> None:
+         """Initialize rotary embedding.
+
+         Args:
+             head_dim: Dimension of each attention head
+             base: Base for the frequency computation
+         """
+         self.head_dim = head_dim
+         self.base = base
+
+     def _get_rotary_embeddings(self, positions: Array, dtype: jnp.dtype) -> tuple[Array, Array]:
+         """Get rotary embeddings for a given sequence length.
+
+         Args:
+             positions: Positions of the sequence
+             dtype: Data type for the embeddings
+
+         Returns:
+             Tuple of (cos_embeddings, sin_embeddings) of shape (seq_len, head_dim//2)
+         """
+         # Create frequency bands
+         dim = self.head_dim // 2
+         freqs = jnp.exp(-jnp.arange(0, dim, dtype=dtype) * jnp.log(self.base) / dim)
+
+         # Compute angles
+         angles = positions[:, None] * freqs[None, :]  # (seq_len, dim)
+
+         # Compute cos and sin embeddings
+         cos_embeddings = jnp.cos(angles)
+         sin_embeddings = jnp.sin(angles)
+
+         return cos_embeddings, sin_embeddings
+
+     def apply_rotary_embeddings(
+         self,
+         x: Array,
+         positions: Array | None = None,
+     ) -> Array:
+         """Apply rotary embeddings to input tensor.
+
+         Args:
+             x: Input tensor of shape (seq_len, num_heads, head_dim)
+             positions: Optional position indices of shape (seq_len,)
+                 If None, uses sequential positions starting from 0
+
+         Returns:
+             Tensor with rotary embeddings applied, same shape as input
+         """
+         seq_len, _, head_dim = x.shape
+         assert head_dim == self.head_dim, f"Expected head_dim {self.head_dim}, got {head_dim}"
+
+         # Get rotary embeddings
+         if positions is None:
+             positions = jnp.arange(seq_len, dtype=x.dtype)
+         cos_emb, sin_emb = self._get_rotary_embeddings(positions, x.dtype)
+
+         # Reshape to (seq_len, 1, head_dim//2) for broadcasting
+         cos_emb = cos_emb[:, None, :]  # (seq_len, 1, head_dim//2)
+         sin_emb = sin_emb[:, None, :]  # (seq_len, 1, head_dim//2)
+
+         # Split input into even and odd dimensions
+         x_even = x[..., ::2]  # (seq_len, num_heads, head_dim//2)
+         x_odd = x[..., 1::2]  # (seq_len, num_heads, head_dim//2)
+
+         # Apply rotation
+         rotated_even = x_even * cos_emb - x_odd * sin_emb
+         rotated_odd = x_even * sin_emb + x_odd * cos_emb
+
+         # Interleave back together
+         result = jnp.zeros_like(x)
+         result = result.at[..., ::2].set(rotated_even)
+         result = result.at[..., 1::2].set(rotated_odd)
+
+         return result
+
+
+ class AttentionCache(TypedDict):
+     k: Array
+     v: Array
+     position: int  # Position counter for rotary embeddings
+
+
+ class AttentionCacheDict(TypedDict):
+     self_attn: AttentionCache
+     cross_attn: NotRequired[AttentionCache]
+
+
+ class TransformerCache(TypedDict):
+     """Cache for the entire transformer stack."""
+
+     layers: dict[str, AttentionCacheDict]
+
+
  class SelfAttentionBlock(eqx.Module):
      """Self-attention block using jax.nn.dot_product_attention."""
 
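The rotation applied by `apply_rotary_embeddings` pairs up even and odd channels and rotates each pair by a position-dependent angle. For reference, the same arithmetic can be reproduced outside the package; the snippet below is an illustrative sketch only (the `rope` helper and the toy shapes are not part of xax):

    import jax.numpy as jnp

    def rope(x, positions, base=10000.0):
        # Rotate channel pairs by position-dependent angles, mirroring
        # RotaryEmbedding.apply_rotary_embeddings above.
        dim = x.shape[-1] // 2
        freqs = jnp.exp(-jnp.arange(dim) * jnp.log(base) / dim)
        angles = positions[:, None] * freqs[None, :]  # (seq_len, dim)
        cos, sin = jnp.cos(angles)[:, None, :], jnp.sin(angles)[:, None, :]
        x_even, x_odd = x[..., ::2], x[..., 1::2]
        out = jnp.zeros_like(x)
        out = out.at[..., ::2].set(x_even * cos - x_odd * sin)
        out = out.at[..., 1::2].set(x_even * sin + x_odd * cos)
        return out

    x = jnp.ones((4, 2, 8))  # toy (seq_len=4, num_heads=2, head_dim=8)
    y = rope(x, jnp.arange(4, dtype=jnp.float32))
    assert y.shape == x.shape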
@@ -16,9 +129,11 @@ class SelfAttentionBlock(eqx.Module):
      k_proj: eqx.nn.Linear
      v_proj: eqx.nn.Linear
      output_proj: eqx.nn.Linear
+     rotary_emb: RotaryEmbedding | None
      num_heads: int = eqx.static_field()
      head_dim: int = eqx.static_field()
      causal: bool = eqx.static_field()
+     context_length: int | None = eqx.static_field()
 
      def __init__(
          self,
@@ -27,7 +142,13 @@ class SelfAttentionBlock(eqx.Module):
          *,
          key: PRNGKeyArray,
          causal: bool = False,
+         context_length: int | None = None,
+         use_rotary_embeddings: bool = False,
+         rotary_base: float = 10000.0,
      ) -> None:
+         if context_length is not None:
+             assert context_length > 1, "context_length must be at least 2"
+
          keys = jax.random.split(key, 4)
 
          self.num_heads = num_heads
@@ -39,7 +160,21 @@ class SelfAttentionBlock(eqx.Module):
          self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
          self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
 
+         # Initialize rotary embeddings if requested
+         if use_rotary_embeddings:
+             self.rotary_emb = RotaryEmbedding(
+                 head_dim=self.head_dim,
+                 base=rotary_base,
+             )
+         else:
+             self.rotary_emb = None
+
          self.causal = causal
+         self.context_length = context_length
+
+     @property
+     def embed_dim(self) -> int:
+         return self.head_dim * self.num_heads
 
      def _reshape_for_multihead(self, x: Array) -> Array:
          """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
@@ -48,58 +183,91 @@ class SelfAttentionBlock(eqx.Module):
 
      def _combine_heads(self, x: Array) -> Array:
          """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
-         seq_len, _, _ = x.shape
-         return x.reshape(seq_len, -1)
+         _, n, h = x.shape
+         return x.reshape(-1, n * h)
+
+     def init_cache(self, dtype: jnp.dtype | None = None) -> AttentionCache:
+         """Initialize cache for the input.
+
+         Args:
+             dtype: The dtype of the cache
+
+         Returns:
+             Cache with fixed-length k and v tensors
+         """
+         if self.context_length is None:
+             raise ValueError("context_length must be set for caching")
+
+         # Create fixed-length cache
+         k_cache = jnp.zeros((self.context_length - 1, self.num_heads, self.head_dim), dtype=dtype)
+         v_cache = jnp.zeros((self.context_length - 1, self.num_heads, self.head_dim), dtype=dtype)
+
+         return {"k": k_cache, "v": v_cache, "position": 0}
 
-     def __call__(
+     def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
+         in_dim, out_dim = seq_len, seq_len
+         if with_cache:
+             if self.context_length is None:
+                 raise ValueError("context_length must be set for caching")
+             in_dim = in_dim + self.context_length - 1
+
+         mask = jnp.tril(jnp.ones((in_dim, out_dim)))
+         if self.context_length is not None:
+             neg_mask = 1 - jnp.tril(jnp.ones((in_dim, out_dim)), -self.context_length)
+             mask = mask * neg_mask
+
+         return mask.astype(jnp.bool_).transpose()
+
+     def forward(
          self,
-         x: Array,
+         x_tn: Array,
          *,
-         key: PRNGKeyArray | None = None,
          mask: Array | None = None,
-         cache: dict[str, Array] | None = None,
-         update_cache: bool = False,
-     ) -> Array | tuple[Array, dict[str, Array]]:
-         """Apply self-attention to the input.
+         cache: AttentionCache | None = None,
+     ) -> tuple[Array, AttentionCache]:
+         """Apply self-attention.
 
          Args:
-             x: Input tensor of shape (seq_len, embed_dim)
-             key: PRNGKey for dropout randomness
-             mask: Optional mask tensor of shape (seq_len, seq_len) or broadcastable
-             cache: Optional dictionary containing cached key and value tensors
-             update_cache: Whether to update the cache and return it
+             x_tn: Input tensor of shape (seq_len, embed_dim)
+             mask: Optional mask tensor
+             cache: The cached key and value tensors (fixed-length)
 
          Returns:
-             If update_cache is False: Output tensor of shape (seq_len, embed_dim)
-             If update_cache is True: Tuple of (output tensor, updated cache)
+             The output tensor of shape (seq_len, embed_dim) and updated cache
          """
-         chex.assert_rank(x, 2)
+         chex.assert_rank(x_tn, 2)
 
          # Project inputs to queries, keys, and values
-         q = jax.vmap(self.q_proj)(x)
-
-         # Use cached key/value if provided and not updating cache
-         if cache is not None and not update_cache:
-             k = cache["k"]
-             v = cache["v"]
-         else:
-             k = jax.vmap(self.k_proj)(x)
-             v = jax.vmap(self.v_proj)(x)
-
-         # Update cache if needed
-         if update_cache:
-             if cache is None:
-                 cache = {}
-             cache = {"k": k, "v": v}
+         q = jax.vmap(self.q_proj)(x_tn)
+         k = jax.vmap(self.k_proj)(x_tn)
+         v = jax.vmap(self.v_proj)(x_tn)
 
          # Reshape to multihead format
          q = self._reshape_for_multihead(q)
          k = self._reshape_for_multihead(k)
          v = self._reshape_for_multihead(v)
 
-         # Apply dot product attention.
-         # Note that Apple Silicon struggles with this:
-         # https://github.com/jax-ml/jax/issues/20114
+         seq_len = q.shape[0]
+         if self.rotary_emb is not None:
+             # Determine position indices for rotary embeddings
+             if cache is not None:
+                 start_pos = cache["position"]
+             else:
+                 start_pos = 0
+             positions = jnp.arange(seq_len) + start_pos
+             q = self.rotary_emb.apply_rotary_embeddings(q, positions=positions)
+             k = self.rotary_emb.apply_rotary_embeddings(k, positions=positions)
+
+         if cache is not None:
+             k_cache = cache["k"]
+             v_cache = cache["v"]
+             k = jnp.concatenate([k_cache, k], axis=0)
+             v = jnp.concatenate([v_cache, v], axis=0)
+             new_position = cache["position"] + seq_len
+
+         else:
+             new_position = seq_len
+
          attn_output = jax.nn.dot_product_attention(
              q,
              k,
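The window produced by the new `init_mask` helper in the hunk above is easiest to see numerically. The following standalone sketch repeats the same mask arithmetic for illustration only (the `sliding_window_mask` function is hypothetical, not part of the package):

    import jax.numpy as jnp

    def sliding_window_mask(seq_len: int, context_length: int):
        # Same arithmetic as SelfAttentionBlock.init_mask with with_cache=True:
        # the key axis is (context_length - 1) cached steps plus seq_len new steps.
        in_dim = seq_len + context_length - 1
        mask = jnp.tril(jnp.ones((in_dim, seq_len)))
        mask = mask * (1 - jnp.tril(jnp.ones((in_dim, seq_len)), -context_length))
        return mask.astype(jnp.bool_).transpose()  # (queries, keys)

    print(sliding_window_mask(3, 3).astype(int))
    # [[1 1 1 0 0]
    #  [0 1 1 1 0]
    #  [0 0 1 1 1]]
    # Each query sees itself plus the previous context_length - 1 keys.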
@@ -108,15 +276,14 @@ class SelfAttentionBlock(eqx.Module):
              is_causal=self.causal and mask is None,
          )
 
-         # Combine heads
          attn_output = self._combine_heads(attn_output)
-
-         # Final projection
          output = jax.vmap(self.output_proj)(attn_output)
 
-         if update_cache:
-             return output, cast(dict[str, Array], cache)
-         return output
+         if self.context_length is not None:
+             k = k[-(self.context_length - 1) :]
+             v = v[-(self.context_length - 1) :]
+
+         return output, {"k": k, "v": v, "position": new_position}
 
 
  class CrossAttentionBlock(eqx.Module):
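With `init_cache`, `init_mask`, and `forward` in place, a `SelfAttentionBlock` can be unrolled one token at a time while the key/value cache keeps a fixed shape. A usage sketch, assuming the block is importable as `xax.nn.attention.SelfAttentionBlock` and using illustrative sizes:

    import jax
    import jax.numpy as jnp

    from xax.nn.attention import SelfAttentionBlock  # import path assumed from this wheel

    blk = SelfAttentionBlock(
        embed_dim=64,
        num_heads=4,
        key=jax.random.PRNGKey(0),
        causal=True,
        context_length=8,
        use_rotary_embeddings=True,
    )
    cache = blk.init_cache(dtype=jnp.float32)  # fixed-length k/v buffers
    mask = blk.init_mask(1, with_cache=True)   # (1, context_length) window mask
    x = jnp.zeros((1, 64))                     # one token per step
    for _ in range(16):
        out, cache = blk.forward(x, mask=mask, cache=cache)  # cache shape never grows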
@@ -126,6 +293,7 @@ class CrossAttentionBlock(eqx.Module):
      k_proj: eqx.nn.Linear
      v_proj: eqx.nn.Linear
      output_proj: eqx.nn.Linear
+     rotary_emb: RotaryEmbedding | None
      num_heads: int = eqx.static_field()
      head_dim: int = eqx.static_field()
 
@@ -135,6 +303,8 @@ class CrossAttentionBlock(eqx.Module):
          num_heads: int,
          *,
          key: PRNGKeyArray,
+         use_rotary_embeddings: bool = False,
+         rotary_base: float = 10000.0,
      ) -> None:
          keys = jax.random.split(key, 4)
 
@@ -147,6 +317,15 @@ class CrossAttentionBlock(eqx.Module):
          self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
          self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
 
+         # Initialize rotary embeddings if requested
+         if use_rotary_embeddings:
+             self.rotary_emb = RotaryEmbedding(
+                 head_dim=self.head_dim,
+                 base=rotary_base,
+             )
+         else:
+             self.rotary_emb = None
+
      def _reshape_for_multihead(self, x: Array) -> Array:
          """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
          seq_len, _ = x.shape
@@ -157,59 +336,73 @@ class CrossAttentionBlock(eqx.Module):
          seq_len, _, _ = x.shape
          return x.reshape(seq_len, -1)
 
-     def __call__(
+     def init_cache(self, kv_sn: Array) -> AttentionCache:
+         """Initialize cache for the input."""
+         chex.assert_rank(kv_sn, 2)
+         k = jax.vmap(self.k_proj)(kv_sn)
+         v = jax.vmap(self.v_proj)(kv_sn)
+         # Reshape to multihead format
+         k = self._reshape_for_multihead(k)
+         v = self._reshape_for_multihead(v)
+         return {"k": k, "v": v, "position": 0}
+
+     def forward(
          self,
-         q_input: Array,
-         kv_input: Array,
+         q_tn: Array,
          *,
-         key: PRNGKeyArray | None = None,
+         kv_sn: Array | None = None,
+         cache: AttentionCache | None = None,
          mask: Array | None = None,
-         cache: dict[str, Array] | None = None,
-         update_cache: bool = False,
-     ) -> Array | tuple[Array, dict[str, Array]]:
+     ) -> tuple[Array, AttentionCache]:
          """Apply cross-attention.
 
          Args:
-             q_input: Query input tensor of shape (q_seq_len, embed_dim)
-             kv_input: Key/value input tensor of shape (kv_seq_len, embed_dim)
-             key: PRNGKey for dropout randomness
+             q_tn: Query input tensor of shape (q_seq_len, embed_dim)
+             kv_sn: Key/value input tensor of shape (kv_seq_len, embed_dim).
+                 If not provided, then `cache` must be provided.
+             cache: The cached key and value tensors. If not provided, then
+                 `kv_sn` must be provided.
              mask: Optional mask tensor
-             cache: Optional dictionary containing cached key and value tensors
-             update_cache: Whether to update the cache and return it
 
          Returns:
-             If update_cache is False: Output tensor of shape (q_seq_len, embed_dim)
-             If update_cache is True: Tuple of (output tensor, updated cache)
+             The output tensor of shape (q_seq_len, embed_dim)
          """
-         chex.assert_rank(q_input, 2)
-         chex.assert_rank(kv_input, 2)
+         chex.assert_rank(q_tn, 2)
 
          # Project inputs to queries, keys, and values
-         q = jax.vmap(self.q_proj)(q_input)
+         q = jax.vmap(self.q_proj)(q_tn)
+         q = self._reshape_for_multihead(q)
+         q_seq_len = q.shape[0]
 
-         # Use cached key/value if provided and not updating cache
-         if cache is not None and not update_cache:
+         # Use cached key/value if provided
+         if cache is not None:
              k = cache["k"]
              v = cache["v"]
+             q_position = cache["position"]
+         elif kv_sn is not None:
+             chex.assert_rank(kv_sn, 2)
+             k = jax.vmap(self.k_proj)(kv_sn)
+             v = jax.vmap(self.v_proj)(kv_sn)
+             k = self._reshape_for_multihead(k)
+             v = self._reshape_for_multihead(v)
+             q_position = 0
          else:
-             k = jax.vmap(self.k_proj)(kv_input)
-             v = jax.vmap(self.v_proj)(kv_input)
-
-         # Update cache if needed
-         if update_cache:
-             if cache is None:
-                 cache = {}
-             cache = {"k": k, "v": v}
+             raise ValueError("Either `cache` or `kv_sn` must be provided.")
 
-         # Reshape to multihead format
-         q = self._reshape_for_multihead(q)
-         k = self._reshape_for_multihead(k)
-         v = self._reshape_for_multihead(v)
+         # Apply rotary embeddings to queries and keys if enabled
+         if self.rotary_emb is None:
+             q_rot = q
+             k_rot = k
+         else:
+             q_positions = jnp.arange(q_seq_len) + q_position
+             k_positions = jnp.arange(k.shape[0])
+             q_rot = self.rotary_emb.apply_rotary_embeddings(q, positions=q_positions)
+             k_rot = self.rotary_emb.apply_rotary_embeddings(k, positions=k_positions)
 
          # Apply dot product attention
          attn_output = jax.nn.dot_product_attention(
-             q,
-             k,
+             q_rot,
+             k_rot,
              v,
              mask=mask,
              is_causal=False,
@@ -221,9 +414,7 @@ class CrossAttentionBlock(eqx.Module):
          # Final projection
          output = jax.vmap(self.output_proj)(attn_output)
 
-         if update_cache:
-             return output, cast(dict[str, Array], cache)
-         return output
+         return output, {"k": k, "v": v, "position": q_position + q_seq_len}
 
 
  class TransformerBlock(eqx.Module):
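For cross-attention, the encoder keys and values can now be projected once via `init_cache` and reused on every decoding step instead of being recomputed from `kv_sn` each call. A sketch under the same import-path assumption, with illustrative shapes and sizes:

    import jax
    import jax.numpy as jnp

    from xax.nn.attention import CrossAttentionBlock  # import path assumed from this wheel

    xattn = CrossAttentionBlock(embed_dim=64, num_heads=4, key=jax.random.PRNGKey(0))
    encoder_out = jnp.zeros((10, 64))         # (kv_seq_len, embed_dim), illustrative
    kv_cache = xattn.init_cache(encoder_out)  # project k/v once

    q = jnp.zeros((1, 64))                    # one decoder token per step
    for _ in range(5):
        out, kv_cache = xattn.forward(q, cache=kv_cache)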
@@ -233,6 +424,10 @@ class TransformerBlock(eqx.Module):
      layer_norm1: eqx.nn.LayerNorm
      layer_norm2: eqx.nn.LayerNorm
      layer_norm3: eqx.nn.LayerNorm | None
+     num_heads: int = eqx.static_field()
+     head_dim: int = eqx.static_field()
+     causal: bool = eqx.static_field()
+     context_length: int | None = eqx.static_field()
 
      def __init__(
          self,
@@ -243,14 +438,20 @@ class TransformerBlock(eqx.Module):
          key: PRNGKeyArray,
          causal: bool = False,
          cross_attention: bool = False,
+         context_length: int | None = None,
+         use_rotary_embeddings: bool = False,
+         rotary_base: float = 10000.0,
      ) -> None:
-         keys = jax.random.split(key, 4)
+         keys = jax.random.split(key, 3)
 
          self.self_attn = SelfAttentionBlock(
              embed_dim=embed_dim,
              num_heads=num_heads,
              key=keys[0],
              causal=causal,
+             context_length=context_length,
+             use_rotary_embeddings=use_rotary_embeddings,
+             rotary_base=rotary_base,
          )
 
          if cross_attention:
@@ -258,8 +459,11 @@ class TransformerBlock(eqx.Module):
                  embed_dim=embed_dim,
                  num_heads=num_heads,
                  key=keys[1],
+                 use_rotary_embeddings=use_rotary_embeddings,
+                 rotary_base=rotary_base,
              )
              self.layer_norm3 = eqx.nn.LayerNorm(embed_dim)
+
          else:
              self.cross_attn = None
              self.layer_norm3 = None
@@ -276,390 +480,350 @@ class TransformerBlock(eqx.Module):
              key=keys[2],
          )
 
-     @overload
-     def __call__(
-         self,
-         x: Array,
-         *,
-         context: Array | None = None,
-         self_mask: Array | None = None,
-         cross_mask: Array | None = None,
-         key: PRNGKeyArray | None = None,
-         cache: dict[str, dict[str, Array]] | None = None,
-         update_cache: Literal[True],
-     ) -> tuple[Array, dict[str, dict[str, Array]]]: ...
-
-     @overload
-     def __call__(
-         self,
-         x: Array,
-         *,
-         context: Array | None = None,
-         self_mask: Array | None = None,
-         cross_mask: Array | None = None,
-         key: PRNGKeyArray | None = None,
-         cache: dict[str, dict[str, Array]] | None = None,
-         update_cache: Literal[False] = False,
-     ) -> Array: ...
-
-     def __call__(
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.causal = causal
+         self.context_length = context_length
+
+     @property
+     def embed_dim(self) -> int:
+         return self.head_dim * self.num_heads
+
+     def init_cache(self, dtype: jnp.dtype | None = None, context_sn: Array | None = None) -> AttentionCacheDict:
+         """Initialize cache for the input."""
+         if dtype is None and context_sn is not None:
+             dtype = context_sn.dtype
+         cache: AttentionCacheDict = {"self_attn": self.self_attn.init_cache(dtype=dtype)}
+         if self.cross_attn is not None:
+             if context_sn is None:
+                 raise ValueError("x_tn must be provided if cross_attn is not None")
+             cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
+         return cache
+
+     def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
+         return self.self_attn.init_mask(seq_len, with_cache=with_cache)
+
+     def forward(
          self,
-         x: Array,
+         x_tn: Array,
          *,
-         context: Array | None = None,
+         context_sn: Array | None = None,
          self_mask: Array | None = None,
          cross_mask: Array | None = None,
-         key: PRNGKeyArray | None = None,
-         cache: dict[str, dict[str, Array]] | None = None,
-         update_cache: bool = False,
-     ) -> Array | tuple[Array, dict[str, dict[str, Array]]]:
+         cache: AttentionCacheDict | None = None,
+     ) -> tuple[Array, AttentionCacheDict]:
          """Apply transformer block.
 
          Args:
-             x: Input tensor
-             context: Optional context for cross-attention
+             x_tn: Input tensor of shape (seq_len, embed_dim)
+             context_sn: Optional context for cross-attention
              self_mask: Mask for self-attention
              cross_mask: Mask for cross-attention
-             key: Optional PRNG key for dropout
             cache: Optional dictionary containing cached key and value tensors
-             update_cache: Whether to update the cache and return it
 
          Returns:
-             If update_cache is False: Output tensor
-             If update_cache is True: Tuple of (output tensor, updated cache)
+             The output tensor and the updated cache
          """
-         chex.assert_rank(x, 2)
-         if key is not None:
-             key1, key2 = jax.random.split(key)
-         else:
-             key1 = key2 = None
-
-         # Initialize cache if needed
-         updated_cache = {}
-         if cache is None:
-             cache = {}
+         chex.assert_rank(x_tn, 2)
 
          # Self-attention block with pre-norm
-         norm_x = jax.vmap(self.layer_norm1)(x)
+         norm_x = jax.vmap(self.layer_norm1)(x_tn)
 
-         self_attn_cache = cache.get("self_attn")
-         if update_cache:
-             attn_output, self_attn_cache = self.self_attn(
-                 norm_x, key=key1, mask=self_mask, cache=self_attn_cache, update_cache=True
-             )
-             updated_cache["self_attn"] = self_attn_cache
-         else:
-             attn_output = self.self_attn(norm_x, key=key1, mask=self_mask, cache=self_attn_cache)
+         attn_output, self_attn_cache = self.self_attn.forward(
+             x_tn=norm_x,
+             mask=self_mask,
+             cache=None if cache is None else cache["self_attn"],
+         )
+         updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
 
-         x = x + attn_output
+         x_tn = x_tn + attn_output
 
          # Cross-attention block (if enabled) with pre-norm
-         if self.cross_attn is not None and context is not None:
+         if self.cross_attn is not None:
              assert self.layer_norm3 is not None
 
-             norm_x = jax.vmap(self.layer_norm3)(x)
-             cross_attn_cache = cache.get("cross_attn")
+             norm_x = jax.vmap(self.layer_norm3)(x_tn)
 
-             if update_cache:
-                 cross_attn_output, cross_attn_cache = self.cross_attn(
-                     norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache, update_cache=True
-                 )
-                 updated_cache["cross_attn"] = cross_attn_cache
-             else:
-                 cross_attn_output = self.cross_attn(norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache)
+             cross_attn_output, updated_cache["cross_attn"] = self.cross_attn.forward(
+                 q_tn=norm_x,
+                 kv_sn=context_sn,
+                 mask=cross_mask,
+                 cache=None if cache is None else cache.get("cross_attn"),
+             )
 
-             x = x + cross_attn_output
+             x_tn = x_tn + cross_attn_output
 
          # Feed-forward block with pre-norm
-         norm_x = jax.vmap(self.layer_norm2)(x)
+         norm_x = jax.vmap(self.layer_norm2)(x_tn)
          ff_output = jax.vmap(self.feed_forward)(norm_x)
-         x = x + ff_output
+         x_tn = x_tn + ff_output
 
-         if update_cache:
-             return x, updated_cache
-         return x
+         return x_tn, updated_cache
 
 
-  class Transformer(eqx.Module):
-      token_embedding: eqx.nn.Embedding
-      position_embedding: eqx.nn.Embedding | None
+  class TransformerStack(eqx.Module):
+      """A stack of transformer blocks."""
+
       layers: list[TransformerBlock]
-      output_layer: eqx.nn.Linear | None
-      layer_norm: eqx.nn.LayerNorm
-      max_seq_len: int = eqx.static_field()
-      embed_dim: int = eqx.static_field()
+      num_layers: int = eqx.static_field()
+      causal: bool = eqx.static_field()
 
      def __init__(
          self,
-         vocab_size: int,
          embed_dim: int,
          num_heads: int,
          ff_dim: int,
          num_layers: int,
-         max_seq_len: int,
-         output_size: int | None = None,
          *,
          key: PRNGKeyArray,
          causal: bool = False,
          cross_attention: bool = False,
-         use_absolute_position: bool = True,
+         context_length: int | None = None,
+         use_rotary_embeddings: bool = False,
+         rotary_base: float = 10000.0,
      ) -> None:
-         keys = jax.random.split(key, num_layers + 3)
-
-         self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
-
-         # Position embeddings can be disabled
-         if use_absolute_position:
-             self.position_embedding = eqx.nn.Embedding(max_seq_len, embed_dim, key=keys[1])
-         else:
-             self.position_embedding = None
+         keys = jax.random.split(key, num_layers)
 
          self.layers = [
              TransformerBlock(
                  embed_dim=embed_dim,
                  num_heads=num_heads,
                  ff_dim=ff_dim,
-                 key=keys[i + 2],
+                 key=keys[i],
                  causal=causal,
                  cross_attention=cross_attention,
+                 context_length=context_length,
+                 use_rotary_embeddings=use_rotary_embeddings,
+                 rotary_base=rotary_base,
              )
              for i in range(num_layers)
          ]
 
-         self.layer_norm = eqx.nn.LayerNorm(embed_dim)
+         self.num_layers = num_layers
+         self.causal = causal
+
+     def init_cache(self, dtype: jnp.dtype | None = None, x_tn: Array | None = None) -> TransformerCache:
+         """Initialize cache for the input."""
+         cache = {}
+         for i, layer in enumerate(self.layers):
+             cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
+         return {"layers": cache}
+
+     def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
+         return self.layers[0].init_mask(seq_len, with_cache=with_cache)
+
+     def forward(
+         self,
+         x_tn: Array,
+         *,
+         context_sn: Array | None = None,
+         self_mask: Array | None = None,
+         cross_mask: Array | None = None,
+         cache: TransformerCache | None = None,
+     ) -> tuple[Array, TransformerCache]:
+         """Apply transformer stack.
+
+         Args:
+             x_tn: Input tensor of shape (seq_len, embed_dim)
+             context_sn: Optional context for cross-attention
+             self_mask: Mask for self-attention
+             cross_mask: Mask for cross-attention
+             cache: Optional dictionary containing cached key and value tensors
+
+         Returns:
+             The output tensor and the updated cache
+         """
+         # Initialize layer caches
+         if cache is None:
+             cache = {"layers": {}}
+
+         # Updated cache will be built
+         updated_cache: TransformerCache = {"layers": {}}
+
+         # Apply transformer layers
+         for i, layer in enumerate(self.layers):
+             layer_cache = cache["layers"].get(f"layer_{i}")
+
+             x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
+                 x_tn,
+                 context_sn=context_sn,
+                 self_mask=self_mask,
+                 cross_mask=cross_mask,
+                 cache=layer_cache,
+             )
+
+         return x_tn, updated_cache
+
+
+ class Transformer(eqx.Module):
+     token_embedding: eqx.nn.Embedding
+     layers: TransformerStack
+     output_layer: eqx.nn.Linear | None
+     layer_norm: eqx.nn.LayerNorm
+     embed_dim: int = eqx.static_field()
+     causal: bool = eqx.static_field()
+     context_length: int | None = eqx.static_field()
+
+     def __init__(
+         self,
+         vocab_size: int,
+         embed_dim: int,
+         num_heads: int,
+         ff_dim: int,
+         num_layers: int,
+         output_size: int | None = None,
+         *,
+         key: PRNGKeyArray,
+         causal: bool = False,
+         cross_attention: bool = False,
+         context_length: int | None = None,
+         use_rotary_embeddings: bool = False,
+         rotary_base: float = 10000.0,
+     ) -> None:
+         # Calculate number of keys needed
+         num_keys = 3 if output_size is None else 4
+         keys = jax.random.split(key, num_keys)
+
+         self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
+
+         self.layers = TransformerStack(
+             embed_dim=embed_dim,
+             num_heads=num_heads,
+             ff_dim=ff_dim,
+             num_layers=num_layers,
+             key=keys[2],
+             causal=causal,
+             cross_attention=cross_attention,
+             context_length=context_length,
+             use_rotary_embeddings=use_rotary_embeddings,
+             rotary_base=rotary_base,
+         )
 
+         self.layer_norm = eqx.nn.LayerNorm(embed_dim)
          if output_size is not None:
-             self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[-1])
+             self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[3])
          else:
              self.output_layer = None
 
-         self.max_seq_len = max_seq_len
          self.embed_dim = embed_dim
+         self.causal = causal
+         self.context_length = context_length
 
-     def _add_positional_embedding(self, x_embedded: Array, positions: Array | None = None) -> Array:
-         """Add positional embeddings to the token embeddings."""
-         if self.position_embedding is None:
-             return x_embedded
-
-         seq_len, _ = x_embedded.shape
-
-         if positions is None:
-             positions = jnp.arange(seq_len)
-         pos_embedded = jax.vmap(self.position_embedding)(positions)
+     def init_cache(self, dtype: jnp.dtype | None = None, x_tn: Array | None = None) -> TransformerCache:
+         """Initialize cache for the input."""
+         return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
 
-         return x_embedded + pos_embedded
+     def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
+         return self.layers.init_mask(seq_len, with_cache=with_cache)
 
      def encode(
          self,
          x: Array,
+         *,
          mask: Array | None = None,
-         positions: Array | None = None,
-         key: PRNGKeyArray | None = None,
-         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-         update_cache: bool = False,
-     ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+         cache: TransformerCache | None = None,
+     ) -> tuple[Array, TransformerCache]:
          """Encode the input sequence.
 
          Args:
              x: Input token indices of shape (seq_len)
              mask: Optional attention mask
-             positions: Optional positions
-             key: Optional PRNG key for dropout
             cache: Optional dictionary containing cached key and value tensors
-             update_cache: Whether to update the cache and return it
 
          Returns:
-             If update_cache is False: Encoded representation
-             If update_cache is True: Tuple of (encoded representation, updated cache)
+             The encoded representation and the updated cache
          """
          # Token embedding
          x_embedded = jax.vmap(self.token_embedding)(x)
 
-         # Add positional embedding
-         x_embedded = self._add_positional_embedding(x_embedded, positions)
-
-         # Initialize layer caches
-         if cache is None and update_cache:
-             cache = {f"layer_{i}": {} for i in range(len(self.layers))}
-
-         # Updated cache will be built if needed
-         updated_cache = {}
-
-         # Apply transformer layers
-         keys: Array | list[None] = [None] * len(self.layers)
-         if key is not None:
-             keys = jax.random.split(key, len(self.layers))
-
-         for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
-             layer_cache = None if cache is None else cache.get(f"layer_{i}")
-
-             if update_cache:
-                 x_embedded, layer_updated_cache = layer.__call__(
-                     x_embedded,
-                     self_mask=mask,
-                     key=layer_key,
-                     cache=layer_cache,
-                     update_cache=True,
-                 )
-                 updated_cache[f"layer_{i}"] = layer_updated_cache
-             else:
-                 x_embedded = layer.__call__(
-                     x_embedded,
-                     self_mask=mask,
-                     key=layer_key,
-                     cache=layer_cache,
-                 )
+         # Apply transformer stack
+         x_embedded, updated_cache = self.layers.forward(
+             x_embedded,
+             self_mask=mask,
+             cache=cache,
+         )
 
          # Apply final layer norm
          output = jax.vmap(self.layer_norm)(x_embedded)
 
-         if update_cache:
-             return output, updated_cache
-         return output
+         return output, updated_cache
 
      def decode(
          self,
-         x: Array,
-         context: Array,
+         x_t: Array,
+         context_s: Array,
+         *,
          self_mask: Array | None = None,
          cross_mask: Array | None = None,
-         positions: Array | None = None,
-         key: PRNGKeyArray | None = None,
-         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-         update_cache: bool = False,
-     ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+         cache: TransformerCache | None = None,
+     ) -> tuple[Array, TransformerCache]:
          """Decode with self-attention and cross-attention.
 
          Args:
-             x: Input token indices
-             context: Context from encoder
-             self_mask: Optional self-attention mask
-             cross_mask: Optional cross-attention mask
-             positions: Optional positions
-             key: Optional PRNG key for dropout
+             x_t: Input token indices, shape (seq_len)
+             context_s: Context from encoder (token indices or embedded),
+                 shape (context_len, embed_dim)
+             self_mask: Optional self-attention mask, shape (seq_len, seq_len)
+             cross_mask: Optional cross-attention mask, shape (seq_len, context_len)
             cache: Optional dictionary containing cached key and value tensors
-             update_cache: Whether to update the cache and return it
 
          Returns:
-             If update_cache is False: Decoded representation
-             If update_cache is True: Tuple of (decoded representation, updated cache)
+             The decoded representation and the updated cache
          """
-         # Token embedding
-         x_embedded = jax.vmap(lambda x_seq: jax.vmap(self.token_embedding)(x_seq))(x)
-
-         # Add positional embedding
-         x_embedded = self._add_positional_embedding(x_embedded, positions)
-
-         # Initialize layer caches
-         if cache is None and update_cache:
-             cache = {f"layer_{i}": {} for i in range(len(self.layers))}
-
-         # Updated cache will be built if needed
-         updated_cache = {}
-
-         # Apply transformer layers with cross-attention
-         keys: Array | list[None] = [None] * len(self.layers)
-         if key is not None:
-             keys = jax.random.split(key, len(self.layers))
-
-         for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
-             layer_cache = None if cache is None else cache.get(f"layer_{i}")
-
-             if update_cache:
-                 x_embedded, layer_updated_cache = layer.__call__(
-                     x_embedded,
-                     context=context,
-                     self_mask=self_mask,
-                     cross_mask=cross_mask,
-                     key=layer_key,
-                     cache=layer_cache,
-                     update_cache=True,
-                 )
-                 updated_cache[f"layer_{i}"] = layer_updated_cache
-             else:
-                 x_embedded = layer(
-                     x_embedded,
-                     context=context,
-                     self_mask=self_mask,
-                     cross_mask=cross_mask,
-                     key=layer_key,
-                     cache=layer_cache,
-                 )
+         # Token embedding for x
+         x_embedded = jax.vmap(self.token_embedding)(x_t)
+
+         # Token embedding for context if needed
+         context_embedded = jax.vmap(self.token_embedding)(context_s)
+
+         # Apply transformer stack with cross-attention
+         x_embedded, updated_cache = self.layers.forward(
+             x_embedded,
+             context_sn=context_embedded,
+             self_mask=self_mask,
+             cross_mask=cross_mask,
+             cache=cache,
+         )
 
          # Apply final layer norm
          output = jax.vmap(self.layer_norm)(x_embedded)
 
-         if update_cache:
-             return output, updated_cache
-         return output
-
-     @overload
-     def __call__(
-         self,
-         x: Array,
-         *,
-         mask: Array | None = None,
-         positions: Array | None = None,
-         key: PRNGKeyArray | None = None,
-         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-         update_cache: Literal[True],
-     ) -> tuple[Array, dict[str, dict[str, dict[str, Array]]]]: ...
-
-     @overload
-     def __call__(
-         self,
-         x: Array,
-         *,
-         mask: Array | None = None,
-         positions: Array | None = None,
-         key: PRNGKeyArray | None = None,
-         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-         update_cache: Literal[False] = False,
-     ) -> Array: ...
+         return output, updated_cache
 
-     def __call__(
+     def forward(
          self,
          x: Array,
          *,
          mask: Array | None = None,
-         positions: Array | None = None,
-         key: PRNGKeyArray | None = None,
-         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-         update_cache: bool = False,
-     ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+         cache: TransformerCache | None = None,
+     ) -> tuple[Array, TransformerCache]:
          """Forward pass for encoder-only or decoder-only transformers.
 
          Args:
              x: Input token indices of shape (seq_len)
              mask: Optional attention mask
-             positions: Optional positions
-             key: Optional PRNG key for dropout
             cache: Optional dictionary containing cached key and value tensors
-             update_cache: Whether to update the cache and return it
 
          Returns:
-             If update_cache is False: Output representation
-             If update_cache is True: Tuple of (output representation, updated cache)
+             The output representation and the updated cache
          """
          chex.assert_rank(x, 1)
 
-         if update_cache:
-             output, updated_cache = self.encode(
-                 x, mask=mask, positions=positions, key=key, cache=cache, update_cache=True
-             )
-         else:
-             output = self.encode(x, mask=mask, positions=positions, key=key, cache=cache)
+         output, updated_cache = self.encode(
+             x,
+             mask=mask,
+             cache=cache,
+         )
 
          # Apply output layer if it exists
          if self.output_layer is not None:
              output = jax.vmap(self.output_layer)(output)
 
-         if update_cache:
-             return output, updated_cache
-         return output
+         return output, updated_cache
 
      def predict_sequence(self, x_seq: Array) -> Array:
-         return self(x=x_seq)
+         output, _ = self.forward(x=x_seq)
+         return output
 
      def generate_sequence(
          self,
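The `Transformer` module now returns `(output, cache)` from `forward` and exposes `init_cache`/`init_mask` for cached unrolling. A usage sketch, assuming the module is importable as `xax.nn.attention.Transformer`, with illustrative hyperparameters:

    import jax
    import jax.numpy as jnp

    from xax.nn.attention import Transformer  # import path assumed from this wheel

    model = Transformer(
        vocab_size=1000,
        embed_dim=64,
        num_heads=4,
        ff_dim=128,
        num_layers=2,
        output_size=1000,
        key=jax.random.PRNGKey(0),
        causal=True,
        context_length=16,
        use_rotary_embeddings=True,
    )

    tokens = jnp.zeros((8,), dtype=jnp.int32)
    logits, _ = model.forward(tokens)            # full-sequence pass, cache discarded

    cache = model.init_cache(dtype=jnp.float32)  # fixed-size per-layer k/v buffers
    mask = model.init_mask(1, with_cache=True)   # sliding-window mask for one step
    step_logits, cache = model.forward(tokens[:1], mask=mask, cache=cache)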
@@ -685,54 +849,40 @@ class Transformer(eqx.Module):
              key = jax.random.PRNGKey(0)
 
          prompt_len = prompt_seq.shape[0]
-         sequence = prompt_seq
 
-         # Create causal mask for generation
-         causal_mask = jnp.tril(jnp.ones((self.max_seq_len, self.max_seq_len), dtype=jnp.bool_))
-
-         # Initialize cache with the prompt
-         _, cache = self(x=prompt_seq, mask=causal_mask[:prompt_len, :prompt_len], update_cache=True)
-
-         # Define decode step function (for clarity)
-         def decode_step(seq: Array, pos: int, cur_cache: dict, rng: PRNGKeyArray) -> tuple[Array, dict, PRNGKeyArray]:
-             # Get the next position and last token
-             pos_tensor = jnp.array([pos])
-             last_token = seq[-1:]
-
-             # Get logits for next token
-             rng, subrng = jax.random.split(rng)
-             logits, new_cache = self(
-                 x=last_token,
-                 positions=pos_tensor,
-                 key=subrng,
-                 cache=cur_cache,
-                 update_cache=True,
+         total_len = prompt_len + max_len
+         output_seq = jnp.zeros(total_len, dtype=prompt_seq.dtype)
+         output_seq = output_seq.at[:prompt_len].set(prompt_seq)
+
+         # Initialize cache with prompt
+         cache = self.init_cache()
+         _, cache = self.encode(prompt_seq, cache=cache)
+
+         # Define scan function for autoregressive generation
+         def scan_fn(
+             carry: tuple[Array, int, TransformerCache, PRNGKeyArray],
+             _: None,
+         ) -> tuple[tuple[Array, int, TransformerCache, PRNGKeyArray], Array]:
+             output_seq, pos, cache, rng = carry
+             current_token = jax.lax.dynamic_slice(output_seq, (pos,), (1,))
+
+             # Forward pass with cache update
+             logits, new_cache = self.forward(
+                 x=current_token,
+                 cache=cache,
              )
 
-             # Extract final logits and apply temperature
              logits = logits[-1] / temperature
-
-             # Apply top-k sampling if specified
             if top_k is not None:
                  top_logits, top_indices = jax.lax.top_k(logits, top_k)
                  logits = jnp.full_like(logits, float("-inf"))
                  logits = logits.at[top_indices].set(top_logits)
-
-             # Sample next token
              rng, subrng = jax.random.split(rng)
              next_token = jax.random.categorical(subrng, logits[None, ...])[0]
+             new_output_seq = jax.lax.dynamic_update_slice(output_seq, next_token[None], (pos + 1,))
 
-             # Add token to sequence
-             new_seq = jnp.concatenate([seq, next_token[None]], axis=0)
-             return new_seq, new_cache, rng
-
-         # Generate tokens one by one
-         for _ in range(max_len):
-             # Break if max sequence length reached
-             if sequence.shape[0] >= self.max_seq_len:
-                 break
-
-             # Decode next token
-             sequence, cache, key = decode_step(seq=sequence, pos=sequence.shape[0] - 1, cur_cache=cache, rng=key)
+             return (new_output_seq, pos + 1, new_cache, rng), next_token
 
-         return sequence
+         init_carry = (output_seq, prompt_len - 1, cache, key)
+         (final_seq, _, _, _), _ = jax.lax.scan(scan_fn, init_carry, length=max_len)
+         return final_seq
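The rewritten generation loop replaces the Python `for` loop with `jax.lax.scan` over a preallocated buffer of length `prompt_len + max_len`, so all shapes stay static and the whole loop can be jitted. The carry/update pattern in isolation, with a toy next-token function standing in for the transformer (everything below is an illustrative sketch, not part of xax):

    import jax
    import jax.numpy as jnp

    def decode(next_logits_fn, prompt, max_len, key):
        # Fixed-size output buffer; the scan body writes one sampled token per step.
        total = prompt.shape[0] + max_len
        buf = jnp.zeros((total,), dtype=prompt.dtype).at[: prompt.shape[0]].set(prompt)

        def step(carry, _):
            buf, pos, key = carry
            tok = jax.lax.dynamic_slice(buf, (pos,), (1,))
            key, sub = jax.random.split(key)
            nxt = jax.random.categorical(sub, next_logits_fn(tok))
            buf = jax.lax.dynamic_update_slice(buf, nxt[None], (pos + 1,))
            return (buf, pos + 1, key), nxt

        (buf, _, _), _ = jax.lax.scan(step, (buf, prompt.shape[0] - 1, key), length=max_len)
        return buf

    toy = lambda tok: jnp.zeros((16,))  # uniform logits over a 16-token vocabulary
    out = decode(toy, jnp.array([1, 2, 3]), max_len=5, key=jax.random.PRNGKey(0))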