xax-0.3.3-py3-none-any.whl → xax-0.3.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +23 -8
- xax/nn/attention.py +519 -408
- xax/nn/embeddings.py +10 -10
- xax/nn/geom.py +5 -5
- xax/nn/ssm.py +6 -6
- xax/task/mixins/train.py +6 -1
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/METADATA +1 -1
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/RECORD +12 -12
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/WHEEL +0 -0
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/entry_points.txt +0 -0
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/top_level.txt +0 -0
xax/nn/attention.py
CHANGED
@@ -1,6 +1,13 @@
-"""Attention mechanisms for transformer models.
+"""Attention mechanisms for transformer models.
 
-
+This module implements standard attention mechanisms for transformers, but
+supporting a fixed-size context window and caching that can be used to train
+transformers which can be unrolled with a fixed-length cache.
+"""
+
+import math
+import warnings
+from typing import NotRequired, TypedDict
 
 import chex
 import equinox as eqx
@@ -8,17 +15,129 @@ import jax
 import jax.numpy as jnp
 from jaxtyping import Array, PRNGKeyArray
 
+from xax.utils.jax import scan as xax_scan
+
+
+class RotaryEmbedding(eqx.Module):
+    """Rotary Position Embedding (RoPE) for transformer attention.
+
+    This implements the rotary position embedding as described in:
+    "RoFormer: Enhanced Transformer with Rotary Position Embedding"
+    https://arxiv.org/abs/2104.09864
+    """
+
+    head_dim: int = eqx.field()
+    base: float = eqx.field()
+
+    def __init__(
+        self,
+        head_dim: int,
+        base: float = 10000.0,
+    ) -> None:
+        """Initialize rotary embedding.
+
+        Args:
+            head_dim: Dimension of each attention head
+            base: Base for the frequency computation
+        """
+        self.head_dim = head_dim
+        self.base = base
+
+    def _get_rotary_embeddings(self, positions: Array, dtype: jnp.dtype) -> tuple[Array, Array]:
+        """Get rotary embeddings for a given sequence length.
+
+        Args:
+            positions: Positions of the sequence
+            dtype: Data type for the embeddings
+
+        Returns:
+            Tuple of (cos_embeddings, sin_embeddings) of shape (seq_len, head_dim//2)
+        """
+        # Create frequency bands
+        dim = self.head_dim // 2
+        freqs = jnp.exp(-jnp.arange(0, dim, dtype=dtype) * jnp.log(self.base) / dim)
+
+        # Compute angles
+        angles = positions[:, None] * freqs[None, :]  # (seq_len, dim)
+
+        # Compute cos and sin embeddings
+        cos_embeddings = jnp.cos(angles)
+        sin_embeddings = jnp.sin(angles)
+
+        return cos_embeddings, sin_embeddings
+
+    def apply_rotary_embeddings(
+        self,
+        x: Array,
+        positions: Array | None = None,
+    ) -> Array:
+        """Apply rotary embeddings to input tensor.
+
+        Args:
+            x: Input tensor of shape (seq_len, num_heads, head_dim)
+            positions: Optional position indices of shape (seq_len,)
+                If None, uses sequential positions starting from 0
+
+        Returns:
+            Tensor with rotary embeddings applied, same shape as input
+        """
+        seq_len, _, head_dim = x.shape
+        assert head_dim == self.head_dim, f"Expected head_dim {self.head_dim}, got {head_dim}"
+
+        # Get rotary embeddings
+        if positions is None:
+            positions = jnp.arange(seq_len, dtype=x.dtype)
+        cos_emb, sin_emb = self._get_rotary_embeddings(positions, x.dtype)
+
+        # Reshape to (seq_len, 1, head_dim//2) for broadcasting
+        cos_emb = cos_emb[:, None, :]  # (seq_len, 1, head_dim//2)
+        sin_emb = sin_emb[:, None, :]  # (seq_len, 1, head_dim//2)
+
+        # Split input into even and odd dimensions
+        x_even = x[..., ::2]  # (seq_len, num_heads, head_dim//2)
+        x_odd = x[..., 1::2]  # (seq_len, num_heads, head_dim//2)
+
+        # Apply rotation
+        rotated_even = x_even * cos_emb - x_odd * sin_emb
+        rotated_odd = x_even * sin_emb + x_odd * cos_emb
+
+        # Interleave back together
+        result = jnp.zeros_like(x)
+        result = result.at[..., ::2].set(rotated_even)
+        result = result.at[..., 1::2].set(rotated_odd)
+
+        return result
+
+
+class AttentionCache(TypedDict):
+    k: Array
+    v: Array
+    position: int  # Position counter for rotary embeddings
+
+
+class AttentionCacheDict(TypedDict):
+    self_attn: AttentionCache
+    cross_attn: NotRequired[AttentionCache]
+
+
+class TransformerCache(TypedDict):
+    """Cache for the entire transformer stack."""
+
+    layers: dict[str, AttentionCacheDict]
+
 
 class SelfAttentionBlock(eqx.Module):
     """Self-attention block using jax.nn.dot_product_attention."""
 
-    q_proj: eqx.nn.Linear
-    k_proj: eqx.nn.Linear
-    v_proj: eqx.nn.Linear
-    output_proj: eqx.nn.Linear
-
-
-
+    q_proj: eqx.nn.Linear = eqx.field()
+    k_proj: eqx.nn.Linear = eqx.field()
+    v_proj: eqx.nn.Linear = eqx.field()
+    output_proj: eqx.nn.Linear = eqx.field()
+    rotary_emb: RotaryEmbedding | None = eqx.field()
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    local_window_size: int | None = eqx.field()
 
     def __init__(
         self,
@@ -27,7 +146,13 @@ class SelfAttentionBlock(eqx.Module):
         *,
         key: PRNGKeyArray,
         causal: bool = False,
+        context_length: int | None = None,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
     ) -> None:
+        if context_length is not None:
+            assert context_length > 1, "context_length must be at least 2"
+
         keys = jax.random.split(key, 4)
 
         self.num_heads = num_heads
@@ -39,7 +164,25 @@ class SelfAttentionBlock(eqx.Module):
         self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
         self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
 
+        # Initialize rotary embeddings if requested
+        if use_rotary_embeddings:
+            self.rotary_emb = RotaryEmbedding(
+                head_dim=self.head_dim,
+                base=rotary_base,
+            )
+        else:
+            self.rotary_emb = None
+
+        if context_length is not None and not causal:
+            warnings.warn("context_length is set but causal is False; overriding causal to True", stacklevel=2)
+            causal = True
+
         self.causal = causal
+        self.local_window_size = None if context_length is None else context_length - 1
+
+    @property
+    def embed_dim(self) -> int:
+        return self.head_dim * self.num_heads
 
     def _reshape_for_multihead(self, x: Array) -> Array:
         """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
@@ -48,75 +191,100 @@ class SelfAttentionBlock(eqx.Module):
 
     def _combine_heads(self, x: Array) -> Array:
         """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
-
-        return x.reshape(
+        _, n, h = x.shape
+        return x.reshape(-1, n * h)
+
+    def init_cache(self, dtype: jnp.dtype | None = None) -> AttentionCache:
+        """Initialize cache for the input.
+
+        Args:
+            dtype: The dtype of the cache
 
-
+        Returns:
+            Cache with fixed-length k and v tensors
+        """
+        if self.local_window_size is None:
+            raise ValueError("context_length must be set for caching")
+
+        # Create fixed-length cache
+        k_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
+        v_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
+
+        return {"k": k_cache, "v": v_cache, "position": 0}
+
+    def forward(
         self,
-
+        x_tn: Array,
         *,
-
-
-
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, Array]]:
-        """Apply self-attention to the input.
+        cache: AttentionCache | None = None,
+    ) -> tuple[Array, AttentionCache]:
+        """Apply self-attention.
 
         Args:
-
-
-            mask: Optional mask tensor of shape (seq_len, seq_len) or broadcastable
-            cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
+            x_tn: Input tensor of shape (seq_len, embed_dim)
+            cache: The cached key and value tensors (fixed-length)
 
         Returns:
-
-            If update_cache is True: Tuple of (output tensor, updated cache)
+            The output tensor of shape (seq_len, embed_dim) and updated cache
         """
-        chex.assert_rank(
+        chex.assert_rank(x_tn, 2)
 
         # Project inputs to queries, keys, and values
-        q = jax.vmap(self.q_proj)(
-
-
-        if cache is not None and not update_cache:
-            k = cache["k"]
-            v = cache["v"]
-        else:
-            k = jax.vmap(self.k_proj)(x)
-            v = jax.vmap(self.v_proj)(x)
-
-        # Update cache if needed
-        if update_cache:
-            if cache is None:
-                cache = {}
-            cache = {"k": k, "v": v}
+        q = jax.vmap(self.q_proj)(x_tn)
+        k = jax.vmap(self.k_proj)(x_tn)
+        v = jax.vmap(self.v_proj)(x_tn)
 
         # Reshape to multihead format
         q = self._reshape_for_multihead(q)
         k = self._reshape_for_multihead(k)
         v = self._reshape_for_multihead(v)
 
-
-
-
+        seq_len = q.shape[0]
+        if self.rotary_emb is not None:
+            # Determine position indices for rotary embeddings
+            if cache is not None:
+                start_pos = cache["position"]
+            else:
+                start_pos = 0
+            positions = jnp.arange(seq_len) + start_pos
+            q = self.rotary_emb.apply_rotary_embeddings(q, positions=positions)
+            k = self.rotary_emb.apply_rotary_embeddings(k, positions=positions)
+
+        if cache is not None:
+            k_cache = cache["k"]
+            v_cache = cache["v"]
+            k = jnp.concatenate([k_cache, k], axis=0)
+            v = jnp.concatenate([v_cache, v], axis=0)
+
+            # Pads query with `k_cache.shape[0]` zeros.
+            q = jnp.pad(q, ((k_cache.shape[0], 0), (0, 0), (0, 0)), mode="constant", constant_values=0)
+
+            new_position = cache["position"] + seq_len
+
+        else:
+            new_position = seq_len
+
         attn_output = jax.nn.dot_product_attention(
             q,
             k,
             v,
-
-            is_causal=self.causal
+            scale=1.0 / math.sqrt(self.head_dim),
+            is_causal=self.causal,
+            local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
         )
 
-
-
+        if cache is not None:
+            # Remove the padding.
+            attn_output = attn_output[cache["k"].shape[0] :]
 
-
+        attn_output = self._combine_heads(attn_output)
         output = jax.vmap(self.output_proj)(attn_output)
 
-        if
-
-
+        if self.local_window_size is not None:
+            k = k[-self.local_window_size :]
+            v = v[-self.local_window_size :]
+
+        return output, {"k": k, "v": v, "position": new_position}
 
 
 class CrossAttentionBlock(eqx.Module):
@@ -126,8 +294,9 @@ class CrossAttentionBlock(eqx.Module):
     k_proj: eqx.nn.Linear
     v_proj: eqx.nn.Linear
     output_proj: eqx.nn.Linear
-
-
+    rotary_emb: RotaryEmbedding | None
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
 
     def __init__(
         self,
@@ -135,6 +304,8 @@ class CrossAttentionBlock(eqx.Module):
         num_heads: int,
         *,
         key: PRNGKeyArray,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
     ) -> None:
         keys = jax.random.split(key, 4)
 
@@ -147,6 +318,15 @@ class CrossAttentionBlock(eqx.Module):
         self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
         self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
 
+        # Initialize rotary embeddings if requested
+        if use_rotary_embeddings:
+            self.rotary_emb = RotaryEmbedding(
+                head_dim=self.head_dim,
+                base=rotary_base,
+            )
+        else:
+            self.rotary_emb = None
+
     def _reshape_for_multihead(self, x: Array) -> Array:
         """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
         seq_len, _ = x.shape
@@ -157,61 +337,72 @@ class CrossAttentionBlock(eqx.Module):
         seq_len, _, _ = x.shape
         return x.reshape(seq_len, -1)
 
-    def
+    def init_cache(self, kv_sn: Array) -> AttentionCache:
+        """Initialize cache for the input."""
+        chex.assert_rank(kv_sn, 2)
+        k = jax.vmap(self.k_proj)(kv_sn)
+        v = jax.vmap(self.v_proj)(kv_sn)
+        # Reshape to multihead format
+        k = self._reshape_for_multihead(k)
+        v = self._reshape_for_multihead(v)
+        return {"k": k, "v": v, "position": 0}
+
+    def forward(
         self,
-
-        kv_input: Array,
+        q_tn: Array,
         *,
-
-
-
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, Array]]:
+        kv_sn: Array | None = None,
+        cache: AttentionCache | None = None,
+    ) -> tuple[Array, AttentionCache]:
         """Apply cross-attention.
 
         Args:
-
-
-
-
-
-            update_cache: Whether to update the cache and return it
+            q_tn: Query input tensor of shape (q_seq_len, embed_dim)
+            kv_sn: Key/value input tensor of shape (kv_seq_len, embed_dim).
+                If not provided, then `cache` must be provided.
+            cache: The cached key and value tensors. If not provided, then
+                `kv_sn` must be provided.
 
         Returns:
-
-            If update_cache is True: Tuple of (output tensor, updated cache)
+            The output tensor of shape (q_seq_len, embed_dim)
         """
-        chex.assert_rank(
-        chex.assert_rank(kv_input, 2)
+        chex.assert_rank(q_tn, 2)
 
         # Project inputs to queries, keys, and values
-        q = jax.vmap(self.q_proj)(
+        q = jax.vmap(self.q_proj)(q_tn)
+        q = self._reshape_for_multihead(q)
+        q_seq_len = q.shape[0]
 
-        # Use cached key/value if provided
-        if cache is not None
+        # Use cached key/value if provided
+        if cache is not None:
            k = cache["k"]
            v = cache["v"]
+            q_position = cache["position"]
+        elif kv_sn is not None:
+            chex.assert_rank(kv_sn, 2)
+            k = jax.vmap(self.k_proj)(kv_sn)
+            v = jax.vmap(self.v_proj)(kv_sn)
+            k = self._reshape_for_multihead(k)
+            v = self._reshape_for_multihead(v)
+            q_position = 0
         else:
-
-            v = jax.vmap(self.v_proj)(kv_input)
+            raise ValueError("Either `cache` or `kv_sn` must be provided.")
 
-
-
-
-
-
-
-
-
-            v = self._reshape_for_multihead(v)
+        # Apply rotary embeddings to queries and keys if enabled
+        if self.rotary_emb is None:
+            q_rot = q
+            k_rot = k
+        else:
+            q_positions = jnp.arange(q_seq_len) + q_position
+            k_positions = jnp.arange(k.shape[0])
+            q_rot = self.rotary_emb.apply_rotary_embeddings(q, positions=q_positions)
+            k_rot = self.rotary_emb.apply_rotary_embeddings(k, positions=k_positions)
 
         # Apply dot product attention
         attn_output = jax.nn.dot_product_attention(
-
-
+            q_rot,
+            k_rot,
             v,
-            mask=mask,
             is_causal=False,
         )
 
@@ -221,9 +412,7 @@ class CrossAttentionBlock(eqx.Module):
         # Final projection
         output = jax.vmap(self.output_proj)(attn_output)
 
-
-            return output, cast(dict[str, Array], cache)
-        return output
+        return output, {"k": k, "v": v, "position": q_position + q_seq_len}
 
 
 class TransformerBlock(eqx.Module):
@@ -233,6 +422,10 @@ class TransformerBlock(eqx.Module):
     layer_norm1: eqx.nn.LayerNorm
     layer_norm2: eqx.nn.LayerNorm
     layer_norm3: eqx.nn.LayerNorm | None
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    context_length: int | None = eqx.field()
 
     def __init__(
         self,
@@ -243,14 +436,20 @@ class TransformerBlock(eqx.Module):
         key: PRNGKeyArray,
         causal: bool = False,
         cross_attention: bool = False,
+        context_length: int | None = None,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
     ) -> None:
-        keys = jax.random.split(key,
+        keys = jax.random.split(key, 3)
 
         self.self_attn = SelfAttentionBlock(
             embed_dim=embed_dim,
             num_heads=num_heads,
             key=keys[0],
             causal=causal,
+            context_length=context_length,
+            use_rotary_embeddings=use_rotary_embeddings,
+            rotary_base=rotary_base,
         )
 
         if cross_attention:
@@ -258,8 +457,11 @@ class TransformerBlock(eqx.Module):
                 embed_dim=embed_dim,
                 num_heads=num_heads,
                 key=keys[1],
+                use_rotary_embeddings=use_rotary_embeddings,
+                rotary_base=rotary_base,
             )
             self.layer_norm3 = eqx.nn.LayerNorm(embed_dim)
+
         else:
             self.cross_attn = None
             self.layer_norm3 = None
@@ -276,390 +478,311 @@ class TransformerBlock(eqx.Module):
             key=keys[2],
         )
 
-
-
-        self
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        key: PRNGKeyArray | None = None,
-        cache: dict[str, dict[str, Array]] | None = None,
-        update_cache: Literal[False] = False,
-    ) -> Array: ...
-
-    def __call__(
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.causal = causal
+        self.context_length = context_length
+
+    @property
+    def embed_dim(self) -> int:
+        return self.head_dim * self.num_heads
+
+    def init_cache(self, dtype: jnp.dtype | None = None, context_sn: Array | None = None) -> AttentionCacheDict:
+        """Initialize cache for the input."""
+        if dtype is None and context_sn is not None:
+            dtype = context_sn.dtype
+        cache: AttentionCacheDict = {"self_attn": self.self_attn.init_cache(dtype=dtype)}
+        if self.cross_attn is not None:
+            if context_sn is None:
+                raise ValueError("x_tn must be provided if cross_attn is not None")
+            cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
+        return cache
+
+    def forward(
         self,
-
+        x_tn: Array,
         *,
-
-
-
-        key: PRNGKeyArray | None = None,
-        cache: dict[str, dict[str, Array]] | None = None,
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, dict[str, Array]]]:
+        context_sn: Array | None = None,
+        cache: AttentionCacheDict | None = None,
+    ) -> tuple[Array, AttentionCacheDict]:
         """Apply transformer block.
 
         Args:
-
-
-            self_mask: Mask for self-attention
-            cross_mask: Mask for cross-attention
-            key: Optional PRNG key for dropout
+            x_tn: Input tensor of shape (seq_len, embed_dim)
+            context_sn: Optional context for cross-attention
             cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (output tensor, updated cache)
+            The output tensor and the updated cache
         """
-        chex.assert_rank(
-        if key is not None:
-            key1, key2 = jax.random.split(key)
-        else:
-            key1 = key2 = None
-
-        # Initialize cache if needed
-        updated_cache = {}
-        if cache is None:
-            cache = {}
+        chex.assert_rank(x_tn, 2)
 
         # Self-attention block with pre-norm
-        norm_x = jax.vmap(self.layer_norm1)(
+        norm_x = jax.vmap(self.layer_norm1)(x_tn)
 
-        self_attn_cache =
-
-
-
-
-            updated_cache["self_attn"] = self_attn_cache
-        else:
-            attn_output = self.self_attn(norm_x, key=key1, mask=self_mask, cache=self_attn_cache)
+        attn_output, self_attn_cache = self.self_attn.forward(
+            x_tn=norm_x,
+            cache=None if cache is None else cache["self_attn"],
+        )
+        updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
 
-
+        x_tn = x_tn + attn_output
 
         # Cross-attention block (if enabled) with pre-norm
-        if self.cross_attn is not None
+        if self.cross_attn is not None:
             assert self.layer_norm3 is not None
 
-            norm_x = jax.vmap(self.layer_norm3)(
-            cross_attn_cache = cache.get("cross_attn")
+            norm_x = jax.vmap(self.layer_norm3)(x_tn)
 
-
-
-
-                )
-
-            else:
-                cross_attn_output = self.cross_attn(norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache)
+            cross_attn_output, updated_cache["cross_attn"] = self.cross_attn.forward(
+                q_tn=norm_x,
+                kv_sn=context_sn,
+                cache=None if cache is None else cache.get("cross_attn"),
+            )
 
-
+            x_tn = x_tn + cross_attn_output
 
         # Feed-forward block with pre-norm
-        norm_x = jax.vmap(self.layer_norm2)(
+        norm_x = jax.vmap(self.layer_norm2)(x_tn)
         ff_output = jax.vmap(self.feed_forward)(norm_x)
-
+        x_tn = x_tn + ff_output
 
-
-            return x, updated_cache
-        return x
+        return x_tn, updated_cache
 
 
-class
-
-
-    layers:
-
-
-    max_seq_len: int = eqx.static_field()
-    embed_dim: int = eqx.static_field()
+class TransformerStack(eqx.Module):
+    """A stack of transformer blocks."""
+
+    layers: tuple[TransformerBlock, ...]
+    num_layers: int = eqx.field()
+    causal: bool = eqx.field()
 
     def __init__(
         self,
-        vocab_size: int,
         embed_dim: int,
         num_heads: int,
         ff_dim: int,
         num_layers: int,
-        max_seq_len: int,
-        output_size: int | None = None,
         *,
         key: PRNGKeyArray,
         causal: bool = False,
         cross_attention: bool = False,
-
+        context_length: int | None = None,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
     ) -> None:
-        keys = jax.random.split(key, num_layers
-
-        self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
-
-        # Position embeddings can be disabled
-        if use_absolute_position:
-            self.position_embedding = eqx.nn.Embedding(max_seq_len, embed_dim, key=keys[1])
-        else:
-            self.position_embedding = None
+        keys = jax.random.split(key, num_layers)
 
-        self.layers =
+        self.layers = tuple(
             TransformerBlock(
                 embed_dim=embed_dim,
                 num_heads=num_heads,
                 ff_dim=ff_dim,
-                key=keys[i
+                key=keys[i],
                 causal=causal,
                 cross_attention=cross_attention,
+                context_length=context_length,
+                use_rotary_embeddings=use_rotary_embeddings,
+                rotary_base=rotary_base,
             )
             for i in range(num_layers)
-
+        )
 
-        self.
+        self.num_layers = num_layers
+        self.causal = causal
+
+    def init_cache(self, dtype: jnp.dtype | None = None, x_tn: Array | None = None) -> TransformerCache:
+        """Initialize cache for the input."""
+        cache = {}
+        for i, layer in enumerate(self.layers):
+            cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
+        return {"layers": cache}
+
+    def forward(
+        self,
+        x_tn: Array,
+        *,
+        context_sn: Array | None = None,
+        cache: TransformerCache | None = None,
+    ) -> tuple[Array, TransformerCache]:
+        """Apply transformer stack.
+
+        Args:
+            x_tn: Input tensor of shape (seq_len, embed_dim)
+            context_sn: Optional context for cross-attention
+            cache: Optional dictionary containing cached key and value tensors
 
+        Returns:
+            The output tensor and the updated cache
+        """
+        # Initialize layer caches
+        if cache is None:
+            cache = {"layers": {}}
+
+        # Updated cache will be built
+        updated_cache: TransformerCache = {"layers": {}}
+
+        # Apply transformer layers
+        for i, layer in enumerate(self.layers):
+            layer_cache = cache["layers"].get(f"layer_{i}")
+
+            x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
+                x_tn,
+                context_sn=context_sn,
+                cache=layer_cache,
+            )
+
+        return x_tn, updated_cache
+
+
+class Transformer(eqx.Module):
+    token_embedding: eqx.nn.Embedding
+    layers: TransformerStack
+    output_layer: eqx.nn.Linear | None
+    layer_norm: eqx.nn.LayerNorm
+    embed_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    context_length: int | None = eqx.field()
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embed_dim: int,
+        num_heads: int,
+        ff_dim: int,
+        num_layers: int,
+        output_size: int | None = None,
+        *,
+        key: PRNGKeyArray,
+        causal: bool = False,
+        cross_attention: bool = False,
+        context_length: int | None = None,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
+    ) -> None:
+        # Calculate number of keys needed
+        num_keys = 3 if output_size is None else 4
+        keys = jax.random.split(key, num_keys)
+
+        self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
+
+        self.layers = TransformerStack(
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            ff_dim=ff_dim,
+            num_layers=num_layers,
+            key=keys[2],
+            causal=causal,
+            cross_attention=cross_attention,
+            context_length=context_length,
+            use_rotary_embeddings=use_rotary_embeddings,
+            rotary_base=rotary_base,
+        )
+
+        self.layer_norm = eqx.nn.LayerNorm(embed_dim)
         if output_size is not None:
-            self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[
+            self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[3])
         else:
             self.output_layer = None
 
-        self.max_seq_len = max_seq_len
         self.embed_dim = embed_dim
+        self.causal = causal
+        self.context_length = context_length
 
-    def
-        """
-
-            return x_embedded
-
-        seq_len, _ = x_embedded.shape
-
-        if positions is None:
-            positions = jnp.arange(seq_len)
-        pos_embedded = jax.vmap(self.position_embedding)(positions)
-
-        return x_embedded + pos_embedded
+    def init_cache(self, dtype: jnp.dtype | None = None, x_tn: Array | None = None) -> TransformerCache:
+        """Initialize cache for the input."""
+        return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
 
     def encode(
         self,
         x: Array,
-
-
-
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        *,
+        cache: TransformerCache | None = None,
+    ) -> tuple[Array, TransformerCache]:
         """Encode the input sequence.
 
         Args:
             x: Input token indices of shape (seq_len)
-            mask: Optional attention mask
-            positions: Optional positions
-            key: Optional PRNG key for dropout
             cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (encoded representation, updated cache)
+            The encoded representation and the updated cache
         """
         # Token embedding
         x_embedded = jax.vmap(self.token_embedding)(x)
 
-        #
-        x_embedded = self.
-
-        # Initialize layer caches
-        if cache is None and update_cache:
-            cache = {f"layer_{i}": {} for i in range(len(self.layers))}
-
-        # Updated cache will be built if needed
-        updated_cache = {}
-
-        # Apply transformer layers
-        keys: Array | list[None] = [None] * len(self.layers)
-        if key is not None:
-            keys = jax.random.split(key, len(self.layers))
-
-        for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
-            layer_cache = None if cache is None else cache.get(f"layer_{i}")
-
-            if update_cache:
-                x_embedded, layer_updated_cache = layer.__call__(
-                    x_embedded,
-                    self_mask=mask,
-                    key=layer_key,
-                    cache=layer_cache,
-                    update_cache=True,
-                )
-                updated_cache[f"layer_{i}"] = layer_updated_cache
-            else:
-                x_embedded = layer.__call__(
-                    x_embedded,
-                    self_mask=mask,
-                    key=layer_key,
-                    cache=layer_cache,
-                )
+        # Apply transformer stack
+        x_embedded, updated_cache = self.layers.forward(x_embedded, cache=cache)
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
 
-
-            return output, updated_cache
-        return output
+        return output, updated_cache
 
     def decode(
         self,
-
-
-
-
-
-        key: PRNGKeyArray | None = None,
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        x_t: Array,
+        context_s: Array,
+        *,
+        cache: TransformerCache | None = None,
+    ) -> tuple[Array, TransformerCache]:
         """Decode with self-attention and cross-attention.
 
         Args:
-
-
-
-            cross_mask: Optional cross-attention mask
-            positions: Optional positions
-            key: Optional PRNG key for dropout
+            x_t: Input token indices, shape (seq_len)
+            context_s: Context from encoder (token indices or embedded),
+                shape (context_len, embed_dim)
             cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (decoded representation, updated cache)
+            The decoded representation and the updated cache
         """
-        # Token embedding
-        x_embedded = jax.vmap(
+        # Token embedding for x
+        x_embedded = jax.vmap(self.token_embedding)(x_t)
 
-        #
-
+        # Token embedding for context if needed
+        context_embedded = jax.vmap(self.token_embedding)(context_s)
 
-        #
-
-
-
-
-
-
-        # Apply transformer layers with cross-attention
-        keys: Array | list[None] = [None] * len(self.layers)
-        if key is not None:
-            keys = jax.random.split(key, len(self.layers))
-
-        for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
-            layer_cache = None if cache is None else cache.get(f"layer_{i}")
-
-            if update_cache:
-                x_embedded, layer_updated_cache = layer.__call__(
-                    x_embedded,
-                    context=context,
-                    self_mask=self_mask,
-                    cross_mask=cross_mask,
-                    key=layer_key,
-                    cache=layer_cache,
-                    update_cache=True,
-                )
-                updated_cache[f"layer_{i}"] = layer_updated_cache
-            else:
-                x_embedded = layer(
-                    x_embedded,
-                    context=context,
-                    self_mask=self_mask,
-                    cross_mask=cross_mask,
-                    key=layer_key,
-                    cache=layer_cache,
-                )
+        # Apply transformer stack with cross-attention
+        x_embedded, updated_cache = self.layers.forward(
+            x_embedded,
+            context_sn=context_embedded,
+            cache=cache,
+        )
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
 
-
-            return output, updated_cache
-        return output
-
-    @overload
-    def __call__(
-        self,
-        x: Array,
-        *,
-        mask: Array | None = None,
-        positions: Array | None = None,
-        key: PRNGKeyArray | None = None,
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: Literal[True],
-    ) -> tuple[Array, dict[str, dict[str, dict[str, Array]]]]: ...
-
-    @overload
-    def __call__(
-        self,
-        x: Array,
-        *,
-        mask: Array | None = None,
-        positions: Array | None = None,
-        key: PRNGKeyArray | None = None,
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: Literal[False] = False,
-    ) -> Array: ...
+        return output, updated_cache
 
-    def
+    def forward(
         self,
         x: Array,
         *,
-
-
-        key: PRNGKeyArray | None = None,
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        cache: TransformerCache | None = None,
+    ) -> tuple[Array, TransformerCache]:
         """Forward pass for encoder-only or decoder-only transformers.
 
         Args:
             x: Input token indices of shape (seq_len)
-            mask: Optional attention mask
-            positions: Optional positions
-            key: Optional PRNG key for dropout
             cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (output representation, updated cache)
+            The output representation and the updated cache
         """
         chex.assert_rank(x, 1)
 
-
-            output, updated_cache = self.encode(
-                x, mask=mask, positions=positions, key=key, cache=cache, update_cache=True
-            )
-        else:
-            output = self.encode(x, mask=mask, positions=positions, key=key, cache=cache)
+        output, updated_cache = self.encode(x, cache=cache)
 
         # Apply output layer if it exists
         if self.output_layer is not None:
             output = jax.vmap(self.output_layer)(output)
 
-
-            return output, updated_cache
-        return output
+        return output, updated_cache
 
     def predict_sequence(self, x_seq: Array) -> Array:
-
+        output, _ = self.forward(x=x_seq)
+        return output
 
     def generate_sequence(
         self,
@@ -668,6 +791,7 @@ class Transformer(eqx.Module):
         temperature: float = 1.0,
         top_k: int | None = None,
         key: PRNGKeyArray | None = None,
+        jit_level: int | None = None,
     ) -> Array:
         """Generate a sequence autoregressively with KV caching.
 
@@ -677,6 +801,7 @@ class Transformer(eqx.Module):
             temperature: Sampling temperature
             top_k: Optional top-k sampling parameter
             key: PRNG key for sampling
+            jit_level: JIT level for the scan function
 
         Returns:
             Generated sequence of shape (prompt_len + max_len,)
@@ -685,54 +810,40 @@ class Transformer(eqx.Module):
             key = jax.random.PRNGKey(0)
 
         prompt_len = prompt_seq.shape[0]
-        sequence = prompt_seq
-
-        # Create causal mask for generation
-        causal_mask = jnp.tril(jnp.ones((self.max_seq_len, self.max_seq_len), dtype=jnp.bool_))
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        total_len = prompt_len + max_len
+        output_seq = jnp.zeros(total_len, dtype=prompt_seq.dtype)
+        output_seq = output_seq.at[:prompt_len].set(prompt_seq)
+
+        # Initialize cache with prompt
+        cache = self.init_cache()
+        _, cache = self.encode(prompt_seq, cache=cache)
+
+        # Define scan function for autoregressive generation
+        def scan_fn(
+            carry: tuple[Array, int, TransformerCache, PRNGKeyArray],
+            _: None,
+        ) -> tuple[tuple[Array, int, TransformerCache, PRNGKeyArray], Array]:
+            output_seq, pos, cache, rng = carry
+            current_token = jax.lax.dynamic_slice(output_seq, (pos,), (1,))
+
+            # Forward pass with cache update
+            logits, new_cache = self.forward(
+                x=current_token,
+                cache=cache,
            )
 
-            # Extract final logits and apply temperature
            logits = logits[-1] / temperature
-
-            # Apply top-k sampling if specified
            if top_k is not None:
                top_logits, top_indices = jax.lax.top_k(logits, top_k)
                logits = jnp.full_like(logits, float("-inf"))
                logits = logits.at[top_indices].set(top_logits)
-
-            # Sample next token
            rng, subrng = jax.random.split(rng)
            next_token = jax.random.categorical(subrng, logits[None, ...])[0]
+            new_output_seq = jax.lax.dynamic_update_slice(output_seq, next_token[None], (pos + 1,))
 
-
-            new_seq = jnp.concatenate([seq, next_token[None]], axis=0)
-            return new_seq, new_cache, rng
-
-        # Generate tokens one by one
-        for _ in range(max_len):
-            # Break if max sequence length reached
-            if sequence.shape[0] >= self.max_seq_len:
-                break
-
-            # Decode next token
-            sequence, cache, key = decode_step(seq=sequence, pos=sequence.shape[0] - 1, cur_cache=cache, rng=key)
+            return (new_output_seq, pos + 1, new_cache, rng), next_token
 
-
+        init_carry = (output_seq, prompt_len - 1, cache, key)
+        (final_seq, _, _, _), _ = xax_scan(scan_fn, init_carry, length=max_len, jit_level=jit_level)
+        return final_seq