xax 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/nn/attention.py CHANGED
@@ -1,6 +1,13 @@
1
- """Attention mechanisms for transformer models."""
1
+ """Attention mechanisms for transformer models.
2
2
 
3
- from typing import Literal, cast, overload
3
+ This module implements standard attention mechanisms for transformers, with
4
+ support for a fixed-size context window and key/value caching, so that trained
5
+ models can be unrolled at inference time with a fixed-length cache.
6
+ """
7
+
8
+ import math
9
+ import warnings
10
+ from typing import NotRequired, TypedDict
4
11
 
5
12
  import chex
6
13
  import equinox as eqx
@@ -8,17 +15,129 @@ import jax
8
15
  import jax.numpy as jnp
9
16
  from jaxtyping import Array, PRNGKeyArray
10
17
 
18
+ from xax.utils.jax import scan as xax_scan
19
+
20
+
21
+ class RotaryEmbedding(eqx.Module):
22
+ """Rotary Position Embedding (RoPE) for transformer attention.
23
+
24
+ This implements the rotary position embedding as described in:
25
+ "RoFormer: Enhanced Transformer with Rotary Position Embedding"
26
+ https://arxiv.org/abs/2104.09864
27
+ """
28
+
29
+ head_dim: int = eqx.field()
30
+ base: float = eqx.field()
31
+
32
+ def __init__(
33
+ self,
34
+ head_dim: int,
35
+ base: float = 10000.0,
36
+ ) -> None:
37
+ """Initialize rotary embedding.
38
+
39
+ Args:
40
+ head_dim: Dimension of each attention head
41
+ base: Base for the frequency computation
42
+ """
43
+ self.head_dim = head_dim
44
+ self.base = base
45
+
46
+ def _get_rotary_embeddings(self, positions: Array, dtype: jnp.dtype) -> tuple[Array, Array]:
47
+ """Get rotary embeddings for the given positions.
48
+
49
+ Args:
50
+ positions: Positions of the sequence
51
+ dtype: Data type for the embeddings
52
+
53
+ Returns:
54
+ Tuple of (cos_embeddings, sin_embeddings) of shape (seq_len, head_dim//2)
55
+ """
56
+ # Create frequency bands
57
+ dim = self.head_dim // 2
58
+ freqs = jnp.exp(-jnp.arange(0, dim, dtype=dtype) * jnp.log(self.base) / dim)
59
+
60
+ # Compute angles
61
+ angles = positions[:, None] * freqs[None, :] # (seq_len, dim)
62
+
63
+ # Compute cos and sin embeddings
64
+ cos_embeddings = jnp.cos(angles)
65
+ sin_embeddings = jnp.sin(angles)
66
+
67
+ return cos_embeddings, sin_embeddings
68
+
69
+ def apply_rotary_embeddings(
70
+ self,
71
+ x: Array,
72
+ positions: Array | None = None,
73
+ ) -> Array:
74
+ """Apply rotary embeddings to input tensor.
75
+
76
+ Args:
77
+ x: Input tensor of shape (seq_len, num_heads, head_dim)
78
+ positions: Optional position indices of shape (seq_len,)
79
+ If None, uses sequential positions starting from 0
80
+
81
+ Returns:
82
+ Tensor with rotary embeddings applied, same shape as input
83
+ """
84
+ seq_len, _, head_dim = x.shape
85
+ assert head_dim == self.head_dim, f"Expected head_dim {self.head_dim}, got {head_dim}"
86
+
87
+ # Get rotary embeddings
88
+ if positions is None:
89
+ positions = jnp.arange(seq_len, dtype=x.dtype)
90
+ cos_emb, sin_emb = self._get_rotary_embeddings(positions, x.dtype)
91
+
92
+ # Reshape to (seq_len, 1, head_dim//2) for broadcasting
93
+ cos_emb = cos_emb[:, None, :] # (seq_len, 1, head_dim//2)
94
+ sin_emb = sin_emb[:, None, :] # (seq_len, 1, head_dim//2)
95
+
96
+ # Split input into even and odd dimensions
97
+ x_even = x[..., ::2] # (seq_len, num_heads, head_dim//2)
98
+ x_odd = x[..., 1::2] # (seq_len, num_heads, head_dim//2)
99
+
100
+ # Apply rotation
101
+ rotated_even = x_even * cos_emb - x_odd * sin_emb
102
+ rotated_odd = x_even * sin_emb + x_odd * cos_emb
103
+
104
+ # Interleave back together
105
+ result = jnp.zeros_like(x)
106
+ result = result.at[..., ::2].set(rotated_even)
107
+ result = result.at[..., 1::2].set(rotated_odd)
108
+
109
+ return result
110
+
111
+
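
A minimal usage sketch of the class above (sizes are illustrative, not from the package; assumes the class is importable from xax.nn.attention). It shows the expected (seq_len, num_heads, head_dim) layout and that the pairwise rotation preserves per-head norms.

import jax.numpy as jnp
from xax.nn.attention import RotaryEmbedding

rope = RotaryEmbedding(head_dim=64)
x = jnp.ones((8, 4, 64))                     # (seq_len, num_heads, head_dim)
x_rot = rope.apply_rotary_embeddings(x)      # sequential positions 0..7
# Offset positions, e.g. while decoding with a cache that is already 16 tokens in.
x_dec = rope.apply_rotary_embeddings(x, positions=jnp.arange(8) + 16)

# The rotation acts on (even, odd) pairs, so per-head norms are unchanged.
assert jnp.allclose(jnp.linalg.norm(x_rot, axis=-1), jnp.linalg.norm(x, axis=-1), atol=1e-4)
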
112
+ class AttentionCache(TypedDict):
113
+ k: Array
114
+ v: Array
115
+ position: int # Position counter for rotary embeddings
116
+
117
+
118
+ class AttentionCacheDict(TypedDict):
119
+ self_attn: AttentionCache
120
+ cross_attn: NotRequired[AttentionCache]
121
+
122
+
123
+ class TransformerCache(TypedDict):
124
+ """Cache for the entire transformer stack."""
125
+
126
+ layers: dict[str, AttentionCacheDict]
127
+
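
For reference, a hypothetical instance of these cache types for a two-layer, self-attention-only stack with context_length=16 (a 15-slot window), 4 heads and head_dim 32; the shapes here are illustrative, the structure follows the TypedDicts above.

import jax.numpy as jnp
from xax.nn.attention import AttentionCacheDict, TransformerCache

layer_cache: AttentionCacheDict = {
    "self_attn": {
        "k": jnp.zeros((15, 4, 32)),   # (local_window_size, num_heads, head_dim)
        "v": jnp.zeros((15, 4, 32)),
        "position": 0,                 # running position, used for rotary embeddings
    },
    # "cross_attn" appears only for blocks built with cross_attention=True
}
cache: TransformerCache = {"layers": {"layer_0": layer_cache, "layer_1": layer_cache}}
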
11
128
 
12
129
  class SelfAttentionBlock(eqx.Module):
13
130
  """Self-attention block using jax.nn.dot_product_attention."""
14
131
 
15
- q_proj: eqx.nn.Linear
16
- k_proj: eqx.nn.Linear
17
- v_proj: eqx.nn.Linear
18
- output_proj: eqx.nn.Linear
19
- num_heads: int = eqx.static_field()
20
- head_dim: int = eqx.static_field()
21
- causal: bool = eqx.static_field()
132
+ q_proj: eqx.nn.Linear = eqx.field()
133
+ k_proj: eqx.nn.Linear = eqx.field()
134
+ v_proj: eqx.nn.Linear = eqx.field()
135
+ output_proj: eqx.nn.Linear = eqx.field()
136
+ rotary_emb: RotaryEmbedding | None = eqx.field()
137
+ num_heads: int = eqx.field()
138
+ head_dim: int = eqx.field()
139
+ causal: bool = eqx.field()
140
+ local_window_size: int | None = eqx.field()
22
141
 
23
142
  def __init__(
24
143
  self,
@@ -27,7 +146,13 @@ class SelfAttentionBlock(eqx.Module):
27
146
  *,
28
147
  key: PRNGKeyArray,
29
148
  causal: bool = False,
149
+ context_length: int | None = None,
150
+ use_rotary_embeddings: bool = False,
151
+ rotary_base: float = 10000.0,
30
152
  ) -> None:
153
+ if context_length is not None:
154
+ assert context_length > 1, "context_length must be at least 2"
155
+
31
156
  keys = jax.random.split(key, 4)
32
157
 
33
158
  self.num_heads = num_heads
@@ -39,7 +164,25 @@ class SelfAttentionBlock(eqx.Module):
39
164
  self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
40
165
  self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
41
166
 
167
+ # Initialize rotary embeddings if requested
168
+ if use_rotary_embeddings:
169
+ self.rotary_emb = RotaryEmbedding(
170
+ head_dim=self.head_dim,
171
+ base=rotary_base,
172
+ )
173
+ else:
174
+ self.rotary_emb = None
175
+
176
+ if context_length is not None and not causal:
177
+ warnings.warn("context_length is set but causal is False; overriding causal to True", stacklevel=2)
178
+ causal = True
179
+
42
180
  self.causal = causal
181
+ self.local_window_size = None if context_length is None else context_length - 1
182
+
183
+ @property
184
+ def embed_dim(self) -> int:
185
+ return self.head_dim * self.num_heads
43
186
 
44
187
  def _reshape_for_multihead(self, x: Array) -> Array:
45
188
  """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
@@ -48,75 +191,100 @@ class SelfAttentionBlock(eqx.Module):
48
191
 
49
192
  def _combine_heads(self, x: Array) -> Array:
50
193
  """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
51
- seq_len, _, _ = x.shape
52
- return x.reshape(seq_len, -1)
194
+ _, n, h = x.shape
195
+ return x.reshape(-1, n * h)
196
+
197
+ def init_cache(self, dtype: jnp.dtype | None = None) -> AttentionCache:
198
+ """Initialize cache for the input.
199
+
200
+ Args:
201
+ dtype: The dtype of the cache
53
202
 
54
- def __call__(
203
+ Returns:
204
+ Cache with fixed-length k and v tensors
205
+ """
206
+ if self.local_window_size is None:
207
+ raise ValueError("context_length must be set for caching")
208
+
209
+ # Create fixed-length cache
210
+ k_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
211
+ v_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
212
+
213
+ return {"k": k_cache, "v": v_cache, "position": 0}
214
+
215
+ def forward(
55
216
  self,
56
- x: Array,
217
+ x_tn: Array,
57
218
  *,
58
- key: PRNGKeyArray | None = None,
59
- mask: Array | None = None,
60
- cache: dict[str, Array] | None = None,
61
- update_cache: bool = False,
62
- ) -> Array | tuple[Array, dict[str, Array]]:
63
- """Apply self-attention to the input.
219
+ cache: AttentionCache | None = None,
220
+ ) -> tuple[Array, AttentionCache]:
221
+ """Apply self-attention.
64
222
 
65
223
  Args:
66
- x: Input tensor of shape (seq_len, embed_dim)
67
- key: PRNGKey for dropout randomness
68
- mask: Optional mask tensor of shape (seq_len, seq_len) or broadcastable
69
- cache: Optional dictionary containing cached key and value tensors
70
- update_cache: Whether to update the cache and return it
224
+ x_tn: Input tensor of shape (seq_len, embed_dim)
225
+ cache: The cached key and value tensors (fixed-length)
71
226
 
72
227
  Returns:
73
- If update_cache is False: Output tensor of shape (seq_len, embed_dim)
74
- If update_cache is True: Tuple of (output tensor, updated cache)
228
+ The output tensor of shape (seq_len, embed_dim) and updated cache
75
229
  """
76
- chex.assert_rank(x, 2)
230
+ chex.assert_rank(x_tn, 2)
77
231
 
78
232
  # Project inputs to queries, keys, and values
79
- q = jax.vmap(self.q_proj)(x)
80
-
81
- # Use cached key/value if provided and not updating cache
82
- if cache is not None and not update_cache:
83
- k = cache["k"]
84
- v = cache["v"]
85
- else:
86
- k = jax.vmap(self.k_proj)(x)
87
- v = jax.vmap(self.v_proj)(x)
88
-
89
- # Update cache if needed
90
- if update_cache:
91
- if cache is None:
92
- cache = {}
93
- cache = {"k": k, "v": v}
233
+ q = jax.vmap(self.q_proj)(x_tn)
234
+ k = jax.vmap(self.k_proj)(x_tn)
235
+ v = jax.vmap(self.v_proj)(x_tn)
94
236
 
95
237
  # Reshape to multihead format
96
238
  q = self._reshape_for_multihead(q)
97
239
  k = self._reshape_for_multihead(k)
98
240
  v = self._reshape_for_multihead(v)
99
241
 
100
- # Apply dot product attention.
101
- # Note that Apple Silicon struggles with this:
102
- # https://github.com/jax-ml/jax/issues/20114
242
+ seq_len = q.shape[0]
243
+ if self.rotary_emb is not None:
244
+ # Determine position indices for rotary embeddings
245
+ if cache is not None:
246
+ start_pos = cache["position"]
247
+ else:
248
+ start_pos = 0
249
+ positions = jnp.arange(seq_len) + start_pos
250
+ q = self.rotary_emb.apply_rotary_embeddings(q, positions=positions)
251
+ k = self.rotary_emb.apply_rotary_embeddings(k, positions=positions)
252
+
253
+ if cache is not None:
254
+ k_cache = cache["k"]
255
+ v_cache = cache["v"]
256
+ k = jnp.concatenate([k_cache, k], axis=0)
257
+ v = jnp.concatenate([v_cache, v], axis=0)
258
+
259
+ # Pad the query with `k_cache.shape[0]` leading zeros so it lines up with the concatenated keys for causal masking.
260
+ q = jnp.pad(q, ((k_cache.shape[0], 0), (0, 0), (0, 0)), mode="constant", constant_values=0)
261
+
262
+ new_position = cache["position"] + seq_len
263
+
264
+ else:
265
+ new_position = seq_len
266
+
103
267
  attn_output = jax.nn.dot_product_attention(
104
268
  q,
105
269
  k,
106
270
  v,
107
- mask=mask,
108
- is_causal=self.causal and mask is None,
271
+ scale=1.0 / math.sqrt(self.head_dim),
272
+ is_causal=self.causal,
273
+ local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
109
274
  )
110
275
 
111
- # Combine heads
112
- attn_output = self._combine_heads(attn_output)
276
+ if cache is not None:
277
+ # Remove the padding.
278
+ attn_output = attn_output[cache["k"].shape[0] :]
113
279
 
114
- # Final projection
280
+ attn_output = self._combine_heads(attn_output)
115
281
  output = jax.vmap(self.output_proj)(attn_output)
116
282
 
117
- if update_cache:
118
- return output, cast(dict[str, Array], cache)
119
- return output
283
+ if self.local_window_size is not None:
284
+ k = k[-self.local_window_size :]
285
+ v = v[-self.local_window_size :]
286
+
287
+ return output, {"k": k, "v": v, "position": new_position}
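
A minimal sketch of the unrolling pattern this enables (hypothetical sizes): the cache holds exactly context_length - 1 past key/value entries, so its shape never grows however many steps are taken.

import jax
import jax.numpy as jnp
from xax.nn.attention import SelfAttentionBlock

block = SelfAttentionBlock(
    embed_dim=128,
    num_heads=4,
    key=jax.random.PRNGKey(0),
    causal=True,
    context_length=16,                 # window of 15 cached tokens plus the current one
    use_rotary_embeddings=True,
)
cache = block.init_cache(dtype=jnp.float32)        # k/v are (15, 4, 32), position 0

x_tn = jax.random.normal(jax.random.PRNGKey(1), (10, 128))
out, cache = block.forward(x_tn, cache=cache)        # process a 10-token chunk
out1, cache = block.forward(x_tn[-1:], cache=cache)  # then unroll one token at a time
assert cache["k"].shape == (15, 4, 32)               # the cache stays fixed-length
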
120
288
 
121
289
 
122
290
  class CrossAttentionBlock(eqx.Module):
@@ -126,8 +294,9 @@ class CrossAttentionBlock(eqx.Module):
126
294
  k_proj: eqx.nn.Linear
127
295
  v_proj: eqx.nn.Linear
128
296
  output_proj: eqx.nn.Linear
129
- num_heads: int = eqx.static_field()
130
- head_dim: int = eqx.static_field()
297
+ rotary_emb: RotaryEmbedding | None
298
+ num_heads: int = eqx.field()
299
+ head_dim: int = eqx.field()
131
300
 
132
301
  def __init__(
133
302
  self,
@@ -135,6 +304,8 @@ class CrossAttentionBlock(eqx.Module):
135
304
  num_heads: int,
136
305
  *,
137
306
  key: PRNGKeyArray,
307
+ use_rotary_embeddings: bool = False,
308
+ rotary_base: float = 10000.0,
138
309
  ) -> None:
139
310
  keys = jax.random.split(key, 4)
140
311
 
@@ -147,6 +318,15 @@ class CrossAttentionBlock(eqx.Module):
147
318
  self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
148
319
  self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
149
320
 
321
+ # Initialize rotary embeddings if requested
322
+ if use_rotary_embeddings:
323
+ self.rotary_emb = RotaryEmbedding(
324
+ head_dim=self.head_dim,
325
+ base=rotary_base,
326
+ )
327
+ else:
328
+ self.rotary_emb = None
329
+
150
330
  def _reshape_for_multihead(self, x: Array) -> Array:
151
331
  """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
152
332
  seq_len, _ = x.shape
@@ -157,61 +337,72 @@ class CrossAttentionBlock(eqx.Module):
157
337
  seq_len, _, _ = x.shape
158
338
  return x.reshape(seq_len, -1)
159
339
 
160
- def __call__(
340
+ def init_cache(self, kv_sn: Array) -> AttentionCache:
341
+ """Initialize cache for the input."""
342
+ chex.assert_rank(kv_sn, 2)
343
+ k = jax.vmap(self.k_proj)(kv_sn)
344
+ v = jax.vmap(self.v_proj)(kv_sn)
345
+ # Reshape to multihead format
346
+ k = self._reshape_for_multihead(k)
347
+ v = self._reshape_for_multihead(v)
348
+ return {"k": k, "v": v, "position": 0}
349
+
350
+ def forward(
161
351
  self,
162
- q_input: Array,
163
- kv_input: Array,
352
+ q_tn: Array,
164
353
  *,
165
- key: PRNGKeyArray | None = None,
166
- mask: Array | None = None,
167
- cache: dict[str, Array] | None = None,
168
- update_cache: bool = False,
169
- ) -> Array | tuple[Array, dict[str, Array]]:
354
+ kv_sn: Array | None = None,
355
+ cache: AttentionCache | None = None,
356
+ ) -> tuple[Array, AttentionCache]:
170
357
  """Apply cross-attention.
171
358
 
172
359
  Args:
173
- q_input: Query input tensor of shape (q_seq_len, embed_dim)
174
- kv_input: Key/value input tensor of shape (kv_seq_len, embed_dim)
175
- key: PRNGKey for dropout randomness
176
- mask: Optional mask tensor
177
- cache: Optional dictionary containing cached key and value tensors
178
- update_cache: Whether to update the cache and return it
360
+ q_tn: Query input tensor of shape (q_seq_len, embed_dim)
361
+ kv_sn: Key/value input tensor of shape (kv_seq_len, embed_dim).
362
+ If not provided, then `cache` must be provided.
363
+ cache: The cached key and value tensors. If not provided, then
364
+ `kv_sn` must be provided.
179
365
 
180
366
  Returns:
181
- If update_cache is False: Output tensor of shape (q_seq_len, embed_dim)
182
- If update_cache is True: Tuple of (output tensor, updated cache)
367
+ The output tensor of shape (q_seq_len, embed_dim) and the updated cache
183
368
  """
184
- chex.assert_rank(q_input, 2)
185
- chex.assert_rank(kv_input, 2)
369
+ chex.assert_rank(q_tn, 2)
186
370
 
187
371
  # Project inputs to queries, keys, and values
188
- q = jax.vmap(self.q_proj)(q_input)
372
+ q = jax.vmap(self.q_proj)(q_tn)
373
+ q = self._reshape_for_multihead(q)
374
+ q_seq_len = q.shape[0]
189
375
 
190
- # Use cached key/value if provided and not updating cache
191
- if cache is not None and not update_cache:
376
+ # Use cached key/value if provided
377
+ if cache is not None:
192
378
  k = cache["k"]
193
379
  v = cache["v"]
380
+ q_position = cache["position"]
381
+ elif kv_sn is not None:
382
+ chex.assert_rank(kv_sn, 2)
383
+ k = jax.vmap(self.k_proj)(kv_sn)
384
+ v = jax.vmap(self.v_proj)(kv_sn)
385
+ k = self._reshape_for_multihead(k)
386
+ v = self._reshape_for_multihead(v)
387
+ q_position = 0
194
388
  else:
195
- k = jax.vmap(self.k_proj)(kv_input)
196
- v = jax.vmap(self.v_proj)(kv_input)
389
+ raise ValueError("Either `cache` or `kv_sn` must be provided.")
197
390
 
198
- # Update cache if needed
199
- if update_cache:
200
- if cache is None:
201
- cache = {}
202
- cache = {"k": k, "v": v}
203
-
204
- # Reshape to multihead format
205
- q = self._reshape_for_multihead(q)
206
- k = self._reshape_for_multihead(k)
207
- v = self._reshape_for_multihead(v)
391
+ # Apply rotary embeddings to queries and keys if enabled
392
+ if self.rotary_emb is None:
393
+ q_rot = q
394
+ k_rot = k
395
+ else:
396
+ q_positions = jnp.arange(q_seq_len) + q_position
397
+ k_positions = jnp.arange(k.shape[0])
398
+ q_rot = self.rotary_emb.apply_rotary_embeddings(q, positions=q_positions)
399
+ k_rot = self.rotary_emb.apply_rotary_embeddings(k, positions=k_positions)
208
400
 
209
401
  # Apply dot product attention
210
402
  attn_output = jax.nn.dot_product_attention(
211
- q,
212
- k,
403
+ q_rot,
404
+ k_rot,
213
405
  v,
214
- mask=mask,
215
406
  is_causal=False,
216
407
  )
217
408
 
@@ -221,9 +412,7 @@ class CrossAttentionBlock(eqx.Module):
221
412
  # Final projection
222
413
  output = jax.vmap(self.output_proj)(attn_output)
223
414
 
224
- if update_cache:
225
- return output, cast(dict[str, Array], cache)
226
- return output
415
+ return output, {"k": k, "v": v, "position": q_position + q_seq_len}
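
A minimal sketch of the cross-attention pattern (hypothetical sizes): the encoder states are projected once in init_cache and reused at every decode step, while calling forward with kv_sn recomputes the projections each time.

import jax
import jax.numpy as jnp
from xax.nn.attention import CrossAttentionBlock

xattn = CrossAttentionBlock(embed_dim=128, num_heads=4, key=jax.random.PRNGKey(0))
context_sn = jax.random.normal(jax.random.PRNGKey(1), (20, 128))  # (kv_seq_len, embed_dim)

cache = xattn.init_cache(context_sn)                # projects and caches k/v once
q_tn = jax.random.normal(jax.random.PRNGKey(2), (1, 128))
out, cache = xattn.forward(q_tn, cache=cache)       # reuses cached k/v, advances "position"
out_again, _ = xattn.forward(q_tn, kv_sn=context_sn)  # same computation, but re-projects k/v
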
227
416
 
228
417
 
229
418
  class TransformerBlock(eqx.Module):
@@ -233,6 +422,10 @@ class TransformerBlock(eqx.Module):
233
422
  layer_norm1: eqx.nn.LayerNorm
234
423
  layer_norm2: eqx.nn.LayerNorm
235
424
  layer_norm3: eqx.nn.LayerNorm | None
425
+ num_heads: int = eqx.field()
426
+ head_dim: int = eqx.field()
427
+ causal: bool = eqx.field()
428
+ context_length: int | None = eqx.field()
236
429
 
237
430
  def __init__(
238
431
  self,
@@ -243,14 +436,20 @@ class TransformerBlock(eqx.Module):
243
436
  key: PRNGKeyArray,
244
437
  causal: bool = False,
245
438
  cross_attention: bool = False,
439
+ context_length: int | None = None,
440
+ use_rotary_embeddings: bool = False,
441
+ rotary_base: float = 10000.0,
246
442
  ) -> None:
247
- keys = jax.random.split(key, 4)
443
+ keys = jax.random.split(key, 3)
248
444
 
249
445
  self.self_attn = SelfAttentionBlock(
250
446
  embed_dim=embed_dim,
251
447
  num_heads=num_heads,
252
448
  key=keys[0],
253
449
  causal=causal,
450
+ context_length=context_length,
451
+ use_rotary_embeddings=use_rotary_embeddings,
452
+ rotary_base=rotary_base,
254
453
  )
255
454
 
256
455
  if cross_attention:
@@ -258,8 +457,11 @@ class TransformerBlock(eqx.Module):
258
457
  embed_dim=embed_dim,
259
458
  num_heads=num_heads,
260
459
  key=keys[1],
460
+ use_rotary_embeddings=use_rotary_embeddings,
461
+ rotary_base=rotary_base,
261
462
  )
262
463
  self.layer_norm3 = eqx.nn.LayerNorm(embed_dim)
464
+
263
465
  else:
264
466
  self.cross_attn = None
265
467
  self.layer_norm3 = None
@@ -276,390 +478,311 @@ class TransformerBlock(eqx.Module):
276
478
  key=keys[2],
277
479
  )
278
480
 
279
- @overload
280
- def __call__(
281
- self,
282
- x: Array,
283
- *,
284
- context: Array | None = None,
285
- self_mask: Array | None = None,
286
- cross_mask: Array | None = None,
287
- key: PRNGKeyArray | None = None,
288
- cache: dict[str, dict[str, Array]] | None = None,
289
- update_cache: Literal[True],
290
- ) -> tuple[Array, dict[str, dict[str, Array]]]: ...
291
-
292
- @overload
293
- def __call__(
294
- self,
295
- x: Array,
296
- *,
297
- context: Array | None = None,
298
- self_mask: Array | None = None,
299
- cross_mask: Array | None = None,
300
- key: PRNGKeyArray | None = None,
301
- cache: dict[str, dict[str, Array]] | None = None,
302
- update_cache: Literal[False] = False,
303
- ) -> Array: ...
304
-
305
- def __call__(
481
+ self.num_heads = num_heads
482
+ self.head_dim = embed_dim // num_heads
483
+ self.causal = causal
484
+ self.context_length = context_length
485
+
486
+ @property
487
+ def embed_dim(self) -> int:
488
+ return self.head_dim * self.num_heads
489
+
490
+ def init_cache(self, dtype: jnp.dtype | None = None, context_sn: Array | None = None) -> AttentionCacheDict:
491
+ """Initialize cache for the input."""
492
+ if dtype is None and context_sn is not None:
493
+ dtype = context_sn.dtype
494
+ cache: AttentionCacheDict = {"self_attn": self.self_attn.init_cache(dtype=dtype)}
495
+ if self.cross_attn is not None:
496
+ if context_sn is None:
497
+ raise ValueError("context_sn must be provided if cross_attn is not None")
498
+ cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
499
+ return cache
500
+
501
+ def forward(
306
502
  self,
307
- x: Array,
503
+ x_tn: Array,
308
504
  *,
309
- context: Array | None = None,
310
- self_mask: Array | None = None,
311
- cross_mask: Array | None = None,
312
- key: PRNGKeyArray | None = None,
313
- cache: dict[str, dict[str, Array]] | None = None,
314
- update_cache: bool = False,
315
- ) -> Array | tuple[Array, dict[str, dict[str, Array]]]:
505
+ context_sn: Array | None = None,
506
+ cache: AttentionCacheDict | None = None,
507
+ ) -> tuple[Array, AttentionCacheDict]:
316
508
  """Apply transformer block.
317
509
 
318
510
  Args:
319
- x: Input tensor
320
- context: Optional context for cross-attention
321
- self_mask: Mask for self-attention
322
- cross_mask: Mask for cross-attention
323
- key: Optional PRNG key for dropout
511
+ x_tn: Input tensor of shape (seq_len, embed_dim)
512
+ context_sn: Optional context for cross-attention
324
513
  cache: Optional dictionary containing cached key and value tensors
325
- update_cache: Whether to update the cache and return it
326
514
 
327
515
  Returns:
328
- If update_cache is False: Output tensor
329
- If update_cache is True: Tuple of (output tensor, updated cache)
516
+ The output tensor and the updated cache
330
517
  """
331
- chex.assert_rank(x, 2)
332
- if key is not None:
333
- key1, key2 = jax.random.split(key)
334
- else:
335
- key1 = key2 = None
336
-
337
- # Initialize cache if needed
338
- updated_cache = {}
339
- if cache is None:
340
- cache = {}
518
+ chex.assert_rank(x_tn, 2)
341
519
 
342
520
  # Self-attention block with pre-norm
343
- norm_x = jax.vmap(self.layer_norm1)(x)
521
+ norm_x = jax.vmap(self.layer_norm1)(x_tn)
344
522
 
345
- self_attn_cache = cache.get("self_attn")
346
- if update_cache:
347
- attn_output, self_attn_cache = self.self_attn(
348
- norm_x, key=key1, mask=self_mask, cache=self_attn_cache, update_cache=True
349
- )
350
- updated_cache["self_attn"] = self_attn_cache
351
- else:
352
- attn_output = self.self_attn(norm_x, key=key1, mask=self_mask, cache=self_attn_cache)
523
+ attn_output, self_attn_cache = self.self_attn.forward(
524
+ x_tn=norm_x,
525
+ cache=None if cache is None else cache["self_attn"],
526
+ )
527
+ updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
353
528
 
354
- x = x + attn_output
529
+ x_tn = x_tn + attn_output
355
530
 
356
531
  # Cross-attention block (if enabled) with pre-norm
357
- if self.cross_attn is not None and context is not None:
532
+ if self.cross_attn is not None:
358
533
  assert self.layer_norm3 is not None
359
534
 
360
- norm_x = jax.vmap(self.layer_norm3)(x)
361
- cross_attn_cache = cache.get("cross_attn")
535
+ norm_x = jax.vmap(self.layer_norm3)(x_tn)
362
536
 
363
- if update_cache:
364
- cross_attn_output, cross_attn_cache = self.cross_attn(
365
- norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache, update_cache=True
366
- )
367
- updated_cache["cross_attn"] = cross_attn_cache
368
- else:
369
- cross_attn_output = self.cross_attn(norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache)
537
+ cross_attn_output, updated_cache["cross_attn"] = self.cross_attn.forward(
538
+ q_tn=norm_x,
539
+ kv_sn=context_sn,
540
+ cache=None if cache is None else cache.get("cross_attn"),
541
+ )
370
542
 
371
- x = x + cross_attn_output
543
+ x_tn = x_tn + cross_attn_output
372
544
 
373
545
  # Feed-forward block with pre-norm
374
- norm_x = jax.vmap(self.layer_norm2)(x)
546
+ norm_x = jax.vmap(self.layer_norm2)(x_tn)
375
547
  ff_output = jax.vmap(self.feed_forward)(norm_x)
376
- x = x + ff_output
548
+ x_tn = x_tn + ff_output
377
549
 
378
- if update_cache:
379
- return x, updated_cache
380
- return x
550
+ return x_tn, updated_cache
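
A minimal sketch of stepping a single block with both caches (hypothetical sizes): the self-attention cache is bounded by context_length, while the cross-attention cache is sized by the provided context.

import jax
import jax.numpy as jnp
from xax.nn.attention import TransformerBlock

block = TransformerBlock(
    embed_dim=128,
    num_heads=4,
    ff_dim=256,
    key=jax.random.PRNGKey(0),
    causal=True,
    cross_attention=True,
    context_length=16,
)
context_sn = jax.random.normal(jax.random.PRNGKey(1), (20, 128))  # already-embedded encoder states

cache = block.init_cache(context_sn=context_sn)
x_tn = jax.random.normal(jax.random.PRNGKey(2), (1, 128))
y_tn, cache = block.forward(x_tn, cache=cache)     # one decode step; both sub-caches are updated
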
381
551
 
382
552
 
383
- class Transformer(eqx.Module):
384
- token_embedding: eqx.nn.Embedding
385
- position_embedding: eqx.nn.Embedding | None
386
- layers: list[TransformerBlock]
387
- output_layer: eqx.nn.Linear | None
388
- layer_norm: eqx.nn.LayerNorm
389
- max_seq_len: int = eqx.static_field()
390
- embed_dim: int = eqx.static_field()
553
+ class TransformerStack(eqx.Module):
554
+ """A stack of transformer blocks."""
555
+
556
+ layers: tuple[TransformerBlock, ...]
557
+ num_layers: int = eqx.field()
558
+ causal: bool = eqx.field()
391
559
 
392
560
  def __init__(
393
561
  self,
394
- vocab_size: int,
395
562
  embed_dim: int,
396
563
  num_heads: int,
397
564
  ff_dim: int,
398
565
  num_layers: int,
399
- max_seq_len: int,
400
- output_size: int | None = None,
401
566
  *,
402
567
  key: PRNGKeyArray,
403
568
  causal: bool = False,
404
569
  cross_attention: bool = False,
405
- use_absolute_position: bool = True,
570
+ context_length: int | None = None,
571
+ use_rotary_embeddings: bool = False,
572
+ rotary_base: float = 10000.0,
406
573
  ) -> None:
407
- keys = jax.random.split(key, num_layers + 3)
408
-
409
- self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
410
-
411
- # Position embeddings can be disabled
412
- if use_absolute_position:
413
- self.position_embedding = eqx.nn.Embedding(max_seq_len, embed_dim, key=keys[1])
414
- else:
415
- self.position_embedding = None
574
+ keys = jax.random.split(key, num_layers)
416
575
 
417
- self.layers = [
576
+ self.layers = tuple(
418
577
  TransformerBlock(
419
578
  embed_dim=embed_dim,
420
579
  num_heads=num_heads,
421
580
  ff_dim=ff_dim,
422
- key=keys[i + 2],
581
+ key=keys[i],
423
582
  causal=causal,
424
583
  cross_attention=cross_attention,
584
+ context_length=context_length,
585
+ use_rotary_embeddings=use_rotary_embeddings,
586
+ rotary_base=rotary_base,
425
587
  )
426
588
  for i in range(num_layers)
427
- ]
589
+ )
428
590
 
429
- self.layer_norm = eqx.nn.LayerNorm(embed_dim)
591
+ self.num_layers = num_layers
592
+ self.causal = causal
593
+
594
+ def init_cache(self, dtype: jnp.dtype | None = None, x_tn: Array | None = None) -> TransformerCache:
595
+ """Initialize cache for the input."""
596
+ cache = {}
597
+ for i, layer in enumerate(self.layers):
598
+ cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
599
+ return {"layers": cache}
600
+
601
+ def forward(
602
+ self,
603
+ x_tn: Array,
604
+ *,
605
+ context_sn: Array | None = None,
606
+ cache: TransformerCache | None = None,
607
+ ) -> tuple[Array, TransformerCache]:
608
+ """Apply transformer stack.
609
+
610
+ Args:
611
+ x_tn: Input tensor of shape (seq_len, embed_dim)
612
+ context_sn: Optional context for cross-attention
613
+ cache: Optional dictionary containing cached key and value tensors
430
614
 
615
+ Returns:
616
+ The output tensor and the updated cache
617
+ """
618
+ # Initialize layer caches
619
+ if cache is None:
620
+ cache = {"layers": {}}
621
+
622
+ # Updated cache will be built
623
+ updated_cache: TransformerCache = {"layers": {}}
624
+
625
+ # Apply transformer layers
626
+ for i, layer in enumerate(self.layers):
627
+ layer_cache = cache["layers"].get(f"layer_{i}")
628
+
629
+ x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
630
+ x_tn,
631
+ context_sn=context_sn,
632
+ cache=layer_cache,
633
+ )
634
+
635
+ return x_tn, updated_cache
636
+
637
+
638
+ class Transformer(eqx.Module):
639
+ token_embedding: eqx.nn.Embedding
640
+ layers: TransformerStack
641
+ output_layer: eqx.nn.Linear | None
642
+ layer_norm: eqx.nn.LayerNorm
643
+ embed_dim: int = eqx.field()
644
+ causal: bool = eqx.field()
645
+ context_length: int | None = eqx.field()
646
+
647
+ def __init__(
648
+ self,
649
+ vocab_size: int,
650
+ embed_dim: int,
651
+ num_heads: int,
652
+ ff_dim: int,
653
+ num_layers: int,
654
+ output_size: int | None = None,
655
+ *,
656
+ key: PRNGKeyArray,
657
+ causal: bool = False,
658
+ cross_attention: bool = False,
659
+ context_length: int | None = None,
660
+ use_rotary_embeddings: bool = False,
661
+ rotary_base: float = 10000.0,
662
+ ) -> None:
663
+ # Calculate number of keys needed
664
+ num_keys = 3 if output_size is None else 4
665
+ keys = jax.random.split(key, num_keys)
666
+
667
+ self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
668
+
669
+ self.layers = TransformerStack(
670
+ embed_dim=embed_dim,
671
+ num_heads=num_heads,
672
+ ff_dim=ff_dim,
673
+ num_layers=num_layers,
674
+ key=keys[2],
675
+ causal=causal,
676
+ cross_attention=cross_attention,
677
+ context_length=context_length,
678
+ use_rotary_embeddings=use_rotary_embeddings,
679
+ rotary_base=rotary_base,
680
+ )
681
+
682
+ self.layer_norm = eqx.nn.LayerNorm(embed_dim)
431
683
  if output_size is not None:
432
- self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[-1])
684
+ self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[3])
433
685
  else:
434
686
  self.output_layer = None
435
687
 
436
- self.max_seq_len = max_seq_len
437
688
  self.embed_dim = embed_dim
689
+ self.causal = causal
690
+ self.context_length = context_length
438
691
 
439
- def _add_positional_embedding(self, x_embedded: Array, positions: Array | None = None) -> Array:
440
- """Add positional embeddings to the token embeddings."""
441
- if self.position_embedding is None:
442
- return x_embedded
443
-
444
- seq_len, _ = x_embedded.shape
445
-
446
- if positions is None:
447
- positions = jnp.arange(seq_len)
448
- pos_embedded = jax.vmap(self.position_embedding)(positions)
449
-
450
- return x_embedded + pos_embedded
692
+ def init_cache(self, dtype: jnp.dtype | None = None, x_tn: Array | None = None) -> TransformerCache:
693
+ """Initialize cache for the input."""
694
+ return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
451
695
 
452
696
  def encode(
453
697
  self,
454
698
  x: Array,
455
- mask: Array | None = None,
456
- positions: Array | None = None,
457
- key: PRNGKeyArray | None = None,
458
- cache: dict[str, dict[str, dict[str, Array]]] | None = None,
459
- update_cache: bool = False,
460
- ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
699
+ *,
700
+ cache: TransformerCache | None = None,
701
+ ) -> tuple[Array, TransformerCache]:
461
702
  """Encode the input sequence.
462
703
 
463
704
  Args:
464
705
  x: Input token indices of shape (seq_len)
465
- mask: Optional attention mask
466
- positions: Optional positions
467
- key: Optional PRNG key for dropout
468
706
  cache: Optional dictionary containing cached key and value tensors
469
- update_cache: Whether to update the cache and return it
470
707
 
471
708
  Returns:
472
- If update_cache is False: Encoded representation
473
- If update_cache is True: Tuple of (encoded representation, updated cache)
709
+ The encoded representation and the updated cache
474
710
  """
475
711
  # Token embedding
476
712
  x_embedded = jax.vmap(self.token_embedding)(x)
477
713
 
478
- # Add positional embedding
479
- x_embedded = self._add_positional_embedding(x_embedded, positions)
480
-
481
- # Initialize layer caches
482
- if cache is None and update_cache:
483
- cache = {f"layer_{i}": {} for i in range(len(self.layers))}
484
-
485
- # Updated cache will be built if needed
486
- updated_cache = {}
487
-
488
- # Apply transformer layers
489
- keys: Array | list[None] = [None] * len(self.layers)
490
- if key is not None:
491
- keys = jax.random.split(key, len(self.layers))
492
-
493
- for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
494
- layer_cache = None if cache is None else cache.get(f"layer_{i}")
495
-
496
- if update_cache:
497
- x_embedded, layer_updated_cache = layer.__call__(
498
- x_embedded,
499
- self_mask=mask,
500
- key=layer_key,
501
- cache=layer_cache,
502
- update_cache=True,
503
- )
504
- updated_cache[f"layer_{i}"] = layer_updated_cache
505
- else:
506
- x_embedded = layer.__call__(
507
- x_embedded,
508
- self_mask=mask,
509
- key=layer_key,
510
- cache=layer_cache,
511
- )
714
+ # Apply transformer stack
715
+ x_embedded, updated_cache = self.layers.forward(x_embedded, cache=cache)
512
716
 
513
717
  # Apply final layer norm
514
718
  output = jax.vmap(self.layer_norm)(x_embedded)
515
719
 
516
- if update_cache:
517
- return output, updated_cache
518
- return output
720
+ return output, updated_cache
519
721
 
520
722
  def decode(
521
723
  self,
522
- x: Array,
523
- context: Array,
524
- self_mask: Array | None = None,
525
- cross_mask: Array | None = None,
526
- positions: Array | None = None,
527
- key: PRNGKeyArray | None = None,
528
- cache: dict[str, dict[str, dict[str, Array]]] | None = None,
529
- update_cache: bool = False,
530
- ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
724
+ x_t: Array,
725
+ context_s: Array,
726
+ *,
727
+ cache: TransformerCache | None = None,
728
+ ) -> tuple[Array, TransformerCache]:
531
729
  """Decode with self-attention and cross-attention.
532
730
 
533
731
  Args:
534
- x: Input token indices
535
- context: Context from encoder
536
- self_mask: Optional self-attention mask
537
- cross_mask: Optional cross-attention mask
538
- positions: Optional positions
539
- key: Optional PRNG key for dropout
732
+ x_t: Input token indices, shape (seq_len)
733
+ context_s: Context from encoder (token indices or embedded),
734
+ shape (context_len, embed_dim)
540
735
  cache: Optional dictionary containing cached key and value tensors
541
- update_cache: Whether to update the cache and return it
542
736
 
543
737
  Returns:
544
- If update_cache is False: Decoded representation
545
- If update_cache is True: Tuple of (decoded representation, updated cache)
738
+ The decoded representation and the updated cache
546
739
  """
547
- # Token embedding
548
- x_embedded = jax.vmap(lambda x_seq: jax.vmap(self.token_embedding)(x_seq))(x)
740
+ # Token embedding for x
741
+ x_embedded = jax.vmap(self.token_embedding)(x_t)
549
742
 
550
- # Add positional embedding
551
- x_embedded = self._add_positional_embedding(x_embedded, positions)
743
+ # Token embedding for the context token indices
744
+ context_embedded = jax.vmap(self.token_embedding)(context_s)
552
745
 
553
- # Initialize layer caches
554
- if cache is None and update_cache:
555
- cache = {f"layer_{i}": {} for i in range(len(self.layers))}
556
-
557
- # Updated cache will be built if needed
558
- updated_cache = {}
559
-
560
- # Apply transformer layers with cross-attention
561
- keys: Array | list[None] = [None] * len(self.layers)
562
- if key is not None:
563
- keys = jax.random.split(key, len(self.layers))
564
-
565
- for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
566
- layer_cache = None if cache is None else cache.get(f"layer_{i}")
567
-
568
- if update_cache:
569
- x_embedded, layer_updated_cache = layer.__call__(
570
- x_embedded,
571
- context=context,
572
- self_mask=self_mask,
573
- cross_mask=cross_mask,
574
- key=layer_key,
575
- cache=layer_cache,
576
- update_cache=True,
577
- )
578
- updated_cache[f"layer_{i}"] = layer_updated_cache
579
- else:
580
- x_embedded = layer(
581
- x_embedded,
582
- context=context,
583
- self_mask=self_mask,
584
- cross_mask=cross_mask,
585
- key=layer_key,
586
- cache=layer_cache,
587
- )
746
+ # Apply transformer stack with cross-attention
747
+ x_embedded, updated_cache = self.layers.forward(
748
+ x_embedded,
749
+ context_sn=context_embedded,
750
+ cache=cache,
751
+ )
588
752
 
589
753
  # Apply final layer norm
590
754
  output = jax.vmap(self.layer_norm)(x_embedded)
591
755
 
592
- if update_cache:
593
- return output, updated_cache
594
- return output
595
-
596
- @overload
597
- def __call__(
598
- self,
599
- x: Array,
600
- *,
601
- mask: Array | None = None,
602
- positions: Array | None = None,
603
- key: PRNGKeyArray | None = None,
604
- cache: dict[str, dict[str, dict[str, Array]]] | None = None,
605
- update_cache: Literal[True],
606
- ) -> tuple[Array, dict[str, dict[str, dict[str, Array]]]]: ...
607
-
608
- @overload
609
- def __call__(
610
- self,
611
- x: Array,
612
- *,
613
- mask: Array | None = None,
614
- positions: Array | None = None,
615
- key: PRNGKeyArray | None = None,
616
- cache: dict[str, dict[str, dict[str, Array]]] | None = None,
617
- update_cache: Literal[False] = False,
618
- ) -> Array: ...
756
+ return output, updated_cache
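
A minimal sketch of the encoder-decoder path (hypothetical hyperparameters). Note that decode embeds context_s with the same token embedding as x_t, so the context here is a token sequence rather than encoder hidden states, and the result is a hidden representation (no output projection).

import jax
import jax.numpy as jnp
from xax.nn.attention import Transformer

model = Transformer(
    vocab_size=1000,
    embed_dim=128,
    num_heads=4,
    ff_dim=256,
    num_layers=2,
    key=jax.random.PRNGKey(0),
    causal=True,
    cross_attention=True,
    context_length=32,
)

src = jnp.array([5, 6, 7, 8, 9], dtype=jnp.int32)   # context token ids
tgt = jnp.array([1, 2, 3], dtype=jnp.int32)         # decoder token ids
decoded, cache = model.decode(tgt, src)             # (3, embed_dim) plus the updated cache
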
619
757
 
620
- def __call__(
758
+ def forward(
621
759
  self,
622
760
  x: Array,
623
761
  *,
624
- mask: Array | None = None,
625
- positions: Array | None = None,
626
- key: PRNGKeyArray | None = None,
627
- cache: dict[str, dict[str, dict[str, Array]]] | None = None,
628
- update_cache: bool = False,
629
- ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
762
+ cache: TransformerCache | None = None,
763
+ ) -> tuple[Array, TransformerCache]:
630
764
  """Forward pass for encoder-only or decoder-only transformers.
631
765
 
632
766
  Args:
633
767
  x: Input token indices of shape (seq_len)
634
- mask: Optional attention mask
635
- positions: Optional positions
636
- key: Optional PRNG key for dropout
637
768
  cache: Optional dictionary containing cached key and value tensors
638
- update_cache: Whether to update the cache and return it
639
769
 
640
770
  Returns:
641
- If update_cache is False: Output representation
642
- If update_cache is True: Tuple of (output representation, updated cache)
771
+ The output representation and the updated cache
643
772
  """
644
773
  chex.assert_rank(x, 1)
645
774
 
646
- if update_cache:
647
- output, updated_cache = self.encode(
648
- x, mask=mask, positions=positions, key=key, cache=cache, update_cache=True
649
- )
650
- else:
651
- output = self.encode(x, mask=mask, positions=positions, key=key, cache=cache)
775
+ output, updated_cache = self.encode(x, cache=cache)
652
776
 
653
777
  # Apply output layer if it exists
654
778
  if self.output_layer is not None:
655
779
  output = jax.vmap(self.output_layer)(output)
656
780
 
657
- if update_cache:
658
- return output, updated_cache
659
- return output
781
+ return output, updated_cache
660
782
 
661
783
  def predict_sequence(self, x_seq: Array) -> Array:
662
- return self(x=x_seq)
784
+ output, _ = self.forward(x=x_seq)
785
+ return output
663
786
 
664
787
  def generate_sequence(
665
788
  self,
@@ -668,6 +791,7 @@ class Transformer(eqx.Module):
668
791
  temperature: float = 1.0,
669
792
  top_k: int | None = None,
670
793
  key: PRNGKeyArray | None = None,
794
+ jit_level: int | None = None,
671
795
  ) -> Array:
672
796
  """Generate a sequence autoregressively with KV caching.
673
797
 
@@ -677,6 +801,7 @@ class Transformer(eqx.Module):
677
801
  temperature: Sampling temperature
678
802
  top_k: Optional top-k sampling parameter
679
803
  key: PRNG key for sampling
804
+ jit_level: JIT level for the scan function
680
805
 
681
806
  Returns:
682
807
  Generated sequence of shape (prompt_len + max_len,)
@@ -685,54 +810,40 @@ class Transformer(eqx.Module):
685
810
  key = jax.random.PRNGKey(0)
686
811
 
687
812
  prompt_len = prompt_seq.shape[0]
688
- sequence = prompt_seq
689
-
690
- # Create causal mask for generation
691
- causal_mask = jnp.tril(jnp.ones((self.max_seq_len, self.max_seq_len), dtype=jnp.bool_))
692
813
 
693
- # Initialize cache with the prompt
694
- _, cache = self(x=prompt_seq, mask=causal_mask[:prompt_len, :prompt_len], update_cache=True)
695
-
696
- # Define decode step function (for clarity)
697
- def decode_step(seq: Array, pos: int, cur_cache: dict, rng: PRNGKeyArray) -> tuple[Array, dict, PRNGKeyArray]:
698
- # Get the next position and last token
699
- pos_tensor = jnp.array([pos])
700
- last_token = seq[-1:]
701
-
702
- # Get logits for next token
703
- rng, subrng = jax.random.split(rng)
704
- logits, new_cache = self(
705
- x=last_token,
706
- positions=pos_tensor,
707
- key=subrng,
708
- cache=cur_cache,
709
- update_cache=True,
814
+ total_len = prompt_len + max_len
815
+ output_seq = jnp.zeros(total_len, dtype=prompt_seq.dtype)
816
+ output_seq = output_seq.at[:prompt_len].set(prompt_seq)
817
+
818
+ # Initialize cache with prompt
819
+ cache = self.init_cache()
820
+ _, cache = self.encode(prompt_seq, cache=cache)
821
+
822
+ # Define scan function for autoregressive generation
823
+ def scan_fn(
824
+ carry: tuple[Array, int, TransformerCache, PRNGKeyArray],
825
+ _: None,
826
+ ) -> tuple[tuple[Array, int, TransformerCache, PRNGKeyArray], Array]:
827
+ output_seq, pos, cache, rng = carry
828
+ current_token = jax.lax.dynamic_slice(output_seq, (pos,), (1,))
829
+
830
+ # Forward pass with cache update
831
+ logits, new_cache = self.forward(
832
+ x=current_token,
833
+ cache=cache,
710
834
  )
711
835
 
712
- # Extract final logits and apply temperature
713
836
  logits = logits[-1] / temperature
714
-
715
- # Apply top-k sampling if specified
716
837
  if top_k is not None:
717
838
  top_logits, top_indices = jax.lax.top_k(logits, top_k)
718
839
  logits = jnp.full_like(logits, float("-inf"))
719
840
  logits = logits.at[top_indices].set(top_logits)
720
-
721
- # Sample next token
722
841
  rng, subrng = jax.random.split(rng)
723
842
  next_token = jax.random.categorical(subrng, logits[None, ...])[0]
843
+ new_output_seq = jax.lax.dynamic_update_slice(output_seq, next_token[None], (pos + 1,))
724
844
 
725
- # Add token to sequence
726
- new_seq = jnp.concatenate([seq, next_token[None]], axis=0)
727
- return new_seq, new_cache, rng
728
-
729
- # Generate tokens one by one
730
- for _ in range(max_len):
731
- # Break if max sequence length reached
732
- if sequence.shape[0] >= self.max_seq_len:
733
- break
734
-
735
- # Decode next token
736
- sequence, cache, key = decode_step(seq=sequence, pos=sequence.shape[0] - 1, cur_cache=cache, rng=key)
845
+ return (new_output_seq, pos + 1, new_cache, rng), next_token
737
846
 
738
- return sequence
847
+ init_carry = (output_seq, prompt_len - 1, cache, key)
848
+ (final_seq, _, _, _), _ = xax_scan(scan_fn, init_carry, length=max_len, jit_level=jit_level)
849
+ return final_seq
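
Finally, a minimal decoder-only sketch tying the pieces together (hypothetical hyperparameters): generate_sequence builds its fixed-length cache via init_cache, so the model must be constructed with context_length set.

import jax
import jax.numpy as jnp
from xax.nn.attention import Transformer

model = Transformer(
    vocab_size=1000,
    embed_dim=128,
    num_heads=4,
    ff_dim=256,
    num_layers=2,
    output_size=1000,              # project hidden states back to vocabulary logits
    key=jax.random.PRNGKey(0),
    causal=True,
    context_length=64,
    use_rotary_embeddings=True,
)

prompt = jnp.array([1, 2, 3, 4], dtype=jnp.int32)
logits, _ = model.forward(prompt)                   # (4, 1000) logits plus a cache

generated = model.generate_sequence(
    prompt_seq=prompt,
    max_len=16,
    top_k=50,
    key=jax.random.PRNGKey(1),
)
# generated has shape (prompt_len + max_len,) == (20,)
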