xax 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +23 -8
- xax/nn/attention.py +531 -381
- {xax-0.3.3.dist-info → xax-0.3.4.dist-info}/METADATA +1 -1
- {xax-0.3.3.dist-info → xax-0.3.4.dist-info}/RECORD +8 -8
- {xax-0.3.3.dist-info → xax-0.3.4.dist-info}/WHEEL +0 -0
- {xax-0.3.3.dist-info → xax-0.3.4.dist-info}/entry_points.txt +0 -0
- {xax-0.3.3.dist-info → xax-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.3.dist-info → xax-0.3.4.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.3"
+__version__ = "0.3.4"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -23,10 +23,14 @@ __all__ = [
     "get_run_dir",
     "load_user_config",
     "State",
+    "AttentionCache",
+    "AttentionCacheDict",
     "CrossAttentionBlock",
     "SelfAttentionBlock",
     "Transformer",
     "TransformerBlock",
+    "TransformerCache",
+    "TransformerStack",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -206,10 +210,14 @@ NAME_MAP: dict[str, str] = {
     "get_run_dir": "core.conf",
     "load_user_config": "core.conf",
     "State": "core.state",
+    "AttentionCache": "nn.attention",
+    "AttentionCacheDict": "nn.attention",
     "CrossAttentionBlock": "nn.attention",
     "SelfAttentionBlock": "nn.attention",
     "Transformer": "nn.attention",
     "TransformerBlock": "nn.attention",
+    "TransformerCache": "nn.attention",
+    "TransformerStack": "nn.attention",
     "FourierEmbeddings": "nn.embeddings",
     "IdentityPositionalEmbeddings": "nn.embeddings",
     "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -362,6 +370,9 @@ NAME_MAP.update(
     },
 )
 
+# In NAME_MAP
+NAME_MAP["TransformerStack"] = "nn.attention"
+
 
 def __getattr__(name: str) -> object:
     if name not in NAME_MAP:
@@ -382,7 +393,16 @@ if IMPORT_ALL or TYPE_CHECKING:
         load_user_config,
     )
     from xax.core.state import Phase, State
-    from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
+    from xax.nn.attention import (
+        AttentionCache,
+        AttentionCacheDict,
+        CrossAttentionBlock,
+        SelfAttentionBlock,
+        Transformer,
+        TransformerBlock,
+        TransformerCache,
+        TransformerStack,
+    )
     from xax.nn.embeddings import (
         EmbeddingKind,
         FourierEmbeddings,
@@ -411,12 +431,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         rotation_matrix_to_rotation6d,
     )
     from xax.nn.losses import cross_entropy
-    from xax.nn.metrics import (
-        NormType,
-        cast_norm_type,
-        dynamic_time_warping,
-        get_norm,
-    )
+    from xax.nn.metrics import NormType, cast_norm_type, dynamic_time_warping, get_norm
     from xax.nn.parallel import is_master
     from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
     from xax.task.base import RawConfigType
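Editor's sketch (not part of the diff): the `__init__.py` change only threads the new attention names through the package's lazy-export table (`__all__`, `NAME_MAP`, and the `TYPE_CHECKING` import block), so after upgrading they resolve both from the top-level package and from `xax.nn.attention`. The usage below is assumed from the names added above rather than taken from documented API:

    import xax
    from xax.nn.attention import AttentionCache, TransformerCache

    # Top-level access goes through xax.__getattr__ + NAME_MAP and returns the
    # same objects as the direct nn.attention import.
    assert xax.TransformerCache is TransformerCache
    assert xax.AttentionCache is AttentionCache
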
xax/nn/attention.py
CHANGED
@@ -1,6 +1,11 @@
-"""Attention mechanisms for transformer models.
+"""Attention mechanisms for transformer models.
 
-
+This module implements standard attention mechanisms for transformers, but
+supporting a fixed-size context window and caching that can be used to train
+transformers which can be unrolled with a fixed-length cache.
+"""
+
+from typing import NotRequired, TypedDict
 
 import chex
 import equinox as eqx
@@ -9,6 +14,114 @@ import jax.numpy as jnp
 from jaxtyping import Array, PRNGKeyArray
 
 
+class RotaryEmbedding(eqx.Module):
+    """Rotary Position Embedding (RoPE) for transformer attention.
+
+    This implements the rotary position embedding as described in:
+    "RoFormer: Enhanced Transformer with Rotary Position Embedding"
+    https://arxiv.org/abs/2104.09864
+    """
+
+    head_dim: int = eqx.static_field()
+    base: float = eqx.static_field()
+
+    def __init__(
+        self,
+        head_dim: int,
+        base: float = 10000.0,
+    ) -> None:
+        """Initialize rotary embedding.
+
+        Args:
+            head_dim: Dimension of each attention head
+            base: Base for the frequency computation
+        """
+        self.head_dim = head_dim
+        self.base = base
+
+    def _get_rotary_embeddings(self, positions: Array, dtype: jnp.dtype) -> tuple[Array, Array]:
+        """Get rotary embeddings for a given sequence length.
+
+        Args:
+            positions: Positions of the sequence
+            dtype: Data type for the embeddings
+
+        Returns:
+            Tuple of (cos_embeddings, sin_embeddings) of shape (seq_len, head_dim//2)
+        """
+        # Create frequency bands
+        dim = self.head_dim // 2
+        freqs = jnp.exp(-jnp.arange(0, dim, dtype=dtype) * jnp.log(self.base) / dim)
+
+        # Compute angles
+        angles = positions[:, None] * freqs[None, :]  # (seq_len, dim)
+
+        # Compute cos and sin embeddings
+        cos_embeddings = jnp.cos(angles)
+        sin_embeddings = jnp.sin(angles)
+
+        return cos_embeddings, sin_embeddings
+
+    def apply_rotary_embeddings(
+        self,
+        x: Array,
+        positions: Array | None = None,
+    ) -> Array:
+        """Apply rotary embeddings to input tensor.
+
+        Args:
+            x: Input tensor of shape (seq_len, num_heads, head_dim)
+            positions: Optional position indices of shape (seq_len,)
+                If None, uses sequential positions starting from 0
+
+        Returns:
+            Tensor with rotary embeddings applied, same shape as input
+        """
+        seq_len, _, head_dim = x.shape
+        assert head_dim == self.head_dim, f"Expected head_dim {self.head_dim}, got {head_dim}"
+
+        # Get rotary embeddings
+        if positions is None:
+            positions = jnp.arange(seq_len, dtype=x.dtype)
+        cos_emb, sin_emb = self._get_rotary_embeddings(positions, x.dtype)
+
+        # Reshape to (seq_len, 1, head_dim//2) for broadcasting
+        cos_emb = cos_emb[:, None, :]  # (seq_len, 1, head_dim//2)
+        sin_emb = sin_emb[:, None, :]  # (seq_len, 1, head_dim//2)
+
+        # Split input into even and odd dimensions
+        x_even = x[..., ::2]  # (seq_len, num_heads, head_dim//2)
+        x_odd = x[..., 1::2]  # (seq_len, num_heads, head_dim//2)
+
+        # Apply rotation
+        rotated_even = x_even * cos_emb - x_odd * sin_emb
+        rotated_odd = x_even * sin_emb + x_odd * cos_emb
+
+        # Interleave back together
+        result = jnp.zeros_like(x)
+        result = result.at[..., ::2].set(rotated_even)
+        result = result.at[..., 1::2].set(rotated_odd)
+
+        return result
+
+
+class AttentionCache(TypedDict):
+    k: Array
+    v: Array
+    position: int  # Position counter for rotary embeddings
+
+
+class AttentionCacheDict(TypedDict):
+    self_attn: AttentionCache
+    cross_attn: NotRequired[AttentionCache]
+
+
+class TransformerCache(TypedDict):
+    """Cache for the entire transformer stack."""
+
+    layers: dict[str, AttentionCacheDict]
+
+
 class SelfAttentionBlock(eqx.Module):
     """Self-attention block using jax.nn.dot_product_attention."""
 
@@ -16,9 +129,11 @@ class SelfAttentionBlock(eqx.Module):
     k_proj: eqx.nn.Linear
     v_proj: eqx.nn.Linear
     output_proj: eqx.nn.Linear
+    rotary_emb: RotaryEmbedding | None
     num_heads: int = eqx.static_field()
     head_dim: int = eqx.static_field()
     causal: bool = eqx.static_field()
+    context_length: int | None = eqx.static_field()
 
     def __init__(
         self,
@@ -27,7 +142,13 @@ class SelfAttentionBlock(eqx.Module):
         *,
         key: PRNGKeyArray,
         causal: bool = False,
+        context_length: int | None = None,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
     ) -> None:
+        if context_length is not None:
+            assert context_length > 1, "context_length must be at least 2"
+
         keys = jax.random.split(key, 4)
 
         self.num_heads = num_heads
@@ -39,7 +160,21 @@ class SelfAttentionBlock(eqx.Module):
         self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
         self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
 
+        # Initialize rotary embeddings if requested
+        if use_rotary_embeddings:
+            self.rotary_emb = RotaryEmbedding(
+                head_dim=self.head_dim,
+                base=rotary_base,
+            )
+        else:
+            self.rotary_emb = None
+
         self.causal = causal
+        self.context_length = context_length
+
+    @property
+    def embed_dim(self) -> int:
+        return self.head_dim * self.num_heads
 
     def _reshape_for_multihead(self, x: Array) -> Array:
         """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
@@ -48,58 +183,91 @@ class SelfAttentionBlock(eqx.Module):
 
     def _combine_heads(self, x: Array) -> Array:
         """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
-
-        return x.reshape(
+        _, n, h = x.shape
+        return x.reshape(-1, n * h)
+
+    def init_cache(self, dtype: jnp.dtype | None = None) -> AttentionCache:
+        """Initialize cache for the input.
+
+        Args:
+            dtype: The dtype of the cache
+
+        Returns:
+            Cache with fixed-length k and v tensors
+        """
+        if self.context_length is None:
+            raise ValueError("context_length must be set for caching")
+
+        # Create fixed-length cache
+        k_cache = jnp.zeros((self.context_length - 1, self.num_heads, self.head_dim), dtype=dtype)
+        v_cache = jnp.zeros((self.context_length - 1, self.num_heads, self.head_dim), dtype=dtype)
+
+        return {"k": k_cache, "v": v_cache, "position": 0}
 
-    def
+    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
+        in_dim, out_dim = seq_len, seq_len
+        if with_cache:
+            if self.context_length is None:
+                raise ValueError("context_length must be set for caching")
+            in_dim = in_dim + self.context_length - 1
+
+        mask = jnp.tril(jnp.ones((in_dim, out_dim)))
+        if self.context_length is not None:
+            neg_mask = 1 - jnp.tril(jnp.ones((in_dim, out_dim)), -self.context_length)
+            mask = mask * neg_mask
+
+        return mask.astype(jnp.bool_).transpose()
+
+    def forward(
         self,
-
+        x_tn: Array,
         *,
-        key: PRNGKeyArray | None = None,
         mask: Array | None = None,
-        cache:
-
-
-        """Apply self-attention to the input.
+        cache: AttentionCache | None = None,
+    ) -> tuple[Array, AttentionCache]:
+        """Apply self-attention.
 
         Args:
-
-
-
-            cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
+            x_tn: Input tensor of shape (seq_len, embed_dim)
+            mask: Optional mask tensor
+            cache: The cached key and value tensors (fixed-length)
 
         Returns:
-
-            If update_cache is True: Tuple of (output tensor, updated cache)
+            The output tensor of shape (seq_len, embed_dim) and updated cache
         """
-        chex.assert_rank(
+        chex.assert_rank(x_tn, 2)
 
         # Project inputs to queries, keys, and values
-        q = jax.vmap(self.q_proj)(
-
-
-        if cache is not None and not update_cache:
-            k = cache["k"]
-            v = cache["v"]
-        else:
-            k = jax.vmap(self.k_proj)(x)
-            v = jax.vmap(self.v_proj)(x)
-
-        # Update cache if needed
-        if update_cache:
-            if cache is None:
-                cache = {}
-            cache = {"k": k, "v": v}
+        q = jax.vmap(self.q_proj)(x_tn)
+        k = jax.vmap(self.k_proj)(x_tn)
+        v = jax.vmap(self.v_proj)(x_tn)
 
         # Reshape to multihead format
         q = self._reshape_for_multihead(q)
         k = self._reshape_for_multihead(k)
         v = self._reshape_for_multihead(v)
 
-
-
-
+        seq_len = q.shape[0]
+        if self.rotary_emb is not None:
+            # Determine position indices for rotary embeddings
+            if cache is not None:
+                start_pos = cache["position"]
+            else:
+                start_pos = 0
+            positions = jnp.arange(seq_len) + start_pos
+            q = self.rotary_emb.apply_rotary_embeddings(q, positions=positions)
+            k = self.rotary_emb.apply_rotary_embeddings(k, positions=positions)
+
+        if cache is not None:
+            k_cache = cache["k"]
+            v_cache = cache["v"]
+            k = jnp.concatenate([k_cache, k], axis=0)
+            v = jnp.concatenate([v_cache, v], axis=0)
+            new_position = cache["position"] + seq_len
+
+        else:
+            new_position = seq_len
+
         attn_output = jax.nn.dot_product_attention(
             q,
             k,
@@ -108,15 +276,14 @@ class SelfAttentionBlock(eqx.Module):
             is_causal=self.causal and mask is None,
         )
 
-        # Combine heads
         attn_output = self._combine_heads(attn_output)
-
-        # Final projection
         output = jax.vmap(self.output_proj)(attn_output)
 
-        if
-
-
+        if self.context_length is not None:
+            k = k[-(self.context_length - 1) :]
+            v = v[-(self.context_length - 1) :]
+
+        return output, {"k": k, "v": v, "position": new_position}
 
 
 class CrossAttentionBlock(eqx.Module):
@@ -126,6 +293,7 @@ class CrossAttentionBlock(eqx.Module):
     k_proj: eqx.nn.Linear
     v_proj: eqx.nn.Linear
     output_proj: eqx.nn.Linear
+    rotary_emb: RotaryEmbedding | None
    num_heads: int = eqx.static_field()
    head_dim: int = eqx.static_field()
 
@@ -135,6 +303,8 @@ class CrossAttentionBlock(eqx.Module):
         num_heads: int,
         *,
         key: PRNGKeyArray,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
     ) -> None:
         keys = jax.random.split(key, 4)
 
@@ -147,6 +317,15 @@ class CrossAttentionBlock(eqx.Module):
         self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
         self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
 
+        # Initialize rotary embeddings if requested
+        if use_rotary_embeddings:
+            self.rotary_emb = RotaryEmbedding(
+                head_dim=self.head_dim,
+                base=rotary_base,
+            )
+        else:
+            self.rotary_emb = None
+
     def _reshape_for_multihead(self, x: Array) -> Array:
         """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
         seq_len, _ = x.shape
@@ -157,59 +336,73 @@ class CrossAttentionBlock(eqx.Module):
         seq_len, _, _ = x.shape
         return x.reshape(seq_len, -1)
 
-    def
+    def init_cache(self, kv_sn: Array) -> AttentionCache:
+        """Initialize cache for the input."""
+        chex.assert_rank(kv_sn, 2)
+        k = jax.vmap(self.k_proj)(kv_sn)
+        v = jax.vmap(self.v_proj)(kv_sn)
+        # Reshape to multihead format
+        k = self._reshape_for_multihead(k)
+        v = self._reshape_for_multihead(v)
+        return {"k": k, "v": v, "position": 0}
+
+    def forward(
         self,
-
-        kv_input: Array,
+        q_tn: Array,
         *,
-
+        kv_sn: Array | None = None,
+        cache: AttentionCache | None = None,
         mask: Array | None = None,
-
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, Array]]:
+    ) -> tuple[Array, AttentionCache]:
         """Apply cross-attention.
 
         Args:
-
-
-
+            q_tn: Query input tensor of shape (q_seq_len, embed_dim)
+            kv_sn: Key/value input tensor of shape (kv_seq_len, embed_dim).
+                If not provided, then `cache` must be provided.
+            cache: The cached key and value tensors. If not provided, then
+                `kv_sn` must be provided.
             mask: Optional mask tensor
-            cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (output tensor, updated cache)
+            The output tensor of shape (q_seq_len, embed_dim)
         """
-        chex.assert_rank(
-        chex.assert_rank(kv_input, 2)
+        chex.assert_rank(q_tn, 2)
 
         # Project inputs to queries, keys, and values
-        q = jax.vmap(self.q_proj)(
+        q = jax.vmap(self.q_proj)(q_tn)
+        q = self._reshape_for_multihead(q)
+        q_seq_len = q.shape[0]
 
-        # Use cached key/value if provided
-        if cache is not None
+        # Use cached key/value if provided
+        if cache is not None:
            k = cache["k"]
            v = cache["v"]
+            q_position = cache["position"]
+        elif kv_sn is not None:
+            chex.assert_rank(kv_sn, 2)
+            k = jax.vmap(self.k_proj)(kv_sn)
+            v = jax.vmap(self.v_proj)(kv_sn)
+            k = self._reshape_for_multihead(k)
+            v = self._reshape_for_multihead(v)
+            q_position = 0
         else:
-
-            v = jax.vmap(self.v_proj)(kv_input)
-
-        # Update cache if needed
-        if update_cache:
-            if cache is None:
-                cache = {}
-            cache = {"k": k, "v": v}
+            raise ValueError("Either `cache` or `kv_sn` must be provided.")
 
-        #
-
-
-
+        # Apply rotary embeddings to queries and keys if enabled
+        if self.rotary_emb is None:
+            q_rot = q
+            k_rot = k
+        else:
+            q_positions = jnp.arange(q_seq_len) + q_position
+            k_positions = jnp.arange(k.shape[0])
+            q_rot = self.rotary_emb.apply_rotary_embeddings(q, positions=q_positions)
+            k_rot = self.rotary_emb.apply_rotary_embeddings(k, positions=k_positions)
 
         # Apply dot product attention
         attn_output = jax.nn.dot_product_attention(
-
-
+            q_rot,
+            k_rot,
             v,
             mask=mask,
             is_causal=False,
@@ -221,9 +414,7 @@ class CrossAttentionBlock(eqx.Module):
         # Final projection
         output = jax.vmap(self.output_proj)(attn_output)
 
-
-        return output, cast(dict[str, Array], cache)
-        return output
+        return output, {"k": k, "v": v, "position": q_position + q_seq_len}
 
 
 class TransformerBlock(eqx.Module):
@@ -233,6 +424,10 @@ class TransformerBlock(eqx.Module):
     layer_norm1: eqx.nn.LayerNorm
     layer_norm2: eqx.nn.LayerNorm
     layer_norm3: eqx.nn.LayerNorm | None
+    num_heads: int = eqx.static_field()
+    head_dim: int = eqx.static_field()
+    causal: bool = eqx.static_field()
+    context_length: int | None = eqx.static_field()
 
     def __init__(
         self,
@@ -243,14 +438,20 @@ class TransformerBlock(eqx.Module):
         key: PRNGKeyArray,
         causal: bool = False,
         cross_attention: bool = False,
+        context_length: int | None = None,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
     ) -> None:
-        keys = jax.random.split(key,
+        keys = jax.random.split(key, 3)
 
         self.self_attn = SelfAttentionBlock(
             embed_dim=embed_dim,
             num_heads=num_heads,
             key=keys[0],
             causal=causal,
+            context_length=context_length,
+            use_rotary_embeddings=use_rotary_embeddings,
+            rotary_base=rotary_base,
         )
 
         if cross_attention:
@@ -258,8 +459,11 @@ class TransformerBlock(eqx.Module):
                 embed_dim=embed_dim,
                 num_heads=num_heads,
                 key=keys[1],
+                use_rotary_embeddings=use_rotary_embeddings,
+                rotary_base=rotary_base,
             )
             self.layer_norm3 = eqx.nn.LayerNorm(embed_dim)
+
         else:
             self.cross_attn = None
             self.layer_norm3 = None
@@ -276,390 +480,350 @@ class TransformerBlock(eqx.Module):
             key=keys[2],
         )
 
-
-
-        self
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    ) -> Array: ...
-
-    def __call__(
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.causal = causal
+        self.context_length = context_length
+
+    @property
+    def embed_dim(self) -> int:
+        return self.head_dim * self.num_heads
+
+    def init_cache(self, dtype: jnp.dtype | None = None, context_sn: Array | None = None) -> AttentionCacheDict:
+        """Initialize cache for the input."""
+        if dtype is None and context_sn is not None:
+            dtype = context_sn.dtype
+        cache: AttentionCacheDict = {"self_attn": self.self_attn.init_cache(dtype=dtype)}
+        if self.cross_attn is not None:
+            if context_sn is None:
+                raise ValueError("x_tn must be provided if cross_attn is not None")
+            cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
+        return cache
+
+    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
+        return self.self_attn.init_mask(seq_len, with_cache=with_cache)
+
+    def forward(
         self,
-
+        x_tn: Array,
         *,
-
+        context_sn: Array | None = None,
         self_mask: Array | None = None,
         cross_mask: Array | None = None,
-
-
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, dict[str, Array]]]:
+        cache: AttentionCacheDict | None = None,
+    ) -> tuple[Array, AttentionCacheDict]:
         """Apply transformer block.
 
         Args:
-
-
+            x_tn: Input tensor of shape (seq_len, embed_dim)
+            context_sn: Optional context for cross-attention
             self_mask: Mask for self-attention
            cross_mask: Mask for cross-attention
-            key: Optional PRNG key for dropout
            cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (output tensor, updated cache)
+            The output tensor and the updated cache
        """
-        chex.assert_rank(
-        if key is not None:
-            key1, key2 = jax.random.split(key)
-        else:
-            key1 = key2 = None
-
-        # Initialize cache if needed
-        updated_cache = {}
-        if cache is None:
-            cache = {}
+        chex.assert_rank(x_tn, 2)
 
         # Self-attention block with pre-norm
-        norm_x = jax.vmap(self.layer_norm1)(
+        norm_x = jax.vmap(self.layer_norm1)(x_tn)
 
-        self_attn_cache =
-
-
-
-
-
-        else:
-            attn_output = self.self_attn(norm_x, key=key1, mask=self_mask, cache=self_attn_cache)
+        attn_output, self_attn_cache = self.self_attn.forward(
+            x_tn=norm_x,
+            mask=self_mask,
+            cache=None if cache is None else cache["self_attn"],
+        )
+        updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
 
-
+        x_tn = x_tn + attn_output
 
         # Cross-attention block (if enabled) with pre-norm
-        if self.cross_attn is not None
+        if self.cross_attn is not None:
             assert self.layer_norm3 is not None
 
-            norm_x = jax.vmap(self.layer_norm3)(
-            cross_attn_cache = cache.get("cross_attn")
+            norm_x = jax.vmap(self.layer_norm3)(x_tn)
 
-
-
-
-
-
-
-            cross_attn_output = self.cross_attn(norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache)
+            cross_attn_output, updated_cache["cross_attn"] = self.cross_attn.forward(
+                q_tn=norm_x,
+                kv_sn=context_sn,
+                mask=cross_mask,
+                cache=None if cache is None else cache.get("cross_attn"),
+            )
 
-
+            x_tn = x_tn + cross_attn_output
 
         # Feed-forward block with pre-norm
-        norm_x = jax.vmap(self.layer_norm2)(
+        norm_x = jax.vmap(self.layer_norm2)(x_tn)
         ff_output = jax.vmap(self.feed_forward)(norm_x)
-
+        x_tn = x_tn + ff_output
 
-
-        return x, updated_cache
-        return x
+        return x_tn, updated_cache
 
 
-class
-
-
+class TransformerStack(eqx.Module):
+    """A stack of transformer blocks."""
+
     layers: list[TransformerBlock]
-
-
-    max_seq_len: int = eqx.static_field()
-    embed_dim: int = eqx.static_field()
+    num_layers: int = eqx.static_field()
+    causal: bool = eqx.static_field()
 
     def __init__(
         self,
-        vocab_size: int,
         embed_dim: int,
         num_heads: int,
         ff_dim: int,
         num_layers: int,
-        max_seq_len: int,
-        output_size: int | None = None,
         *,
         key: PRNGKeyArray,
        causal: bool = False,
        cross_attention: bool = False,
-
+        context_length: int | None = None,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
     ) -> None:
-        keys = jax.random.split(key, num_layers
-
-        self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
-
-        # Position embeddings can be disabled
-        if use_absolute_position:
-            self.position_embedding = eqx.nn.Embedding(max_seq_len, embed_dim, key=keys[1])
-        else:
-            self.position_embedding = None
+        keys = jax.random.split(key, num_layers)
 
         self.layers = [
             TransformerBlock(
                 embed_dim=embed_dim,
                 num_heads=num_heads,
                 ff_dim=ff_dim,
-                key=keys[i
+                key=keys[i],
                 causal=causal,
                 cross_attention=cross_attention,
+                context_length=context_length,
+                use_rotary_embeddings=use_rotary_embeddings,
+                rotary_base=rotary_base,
             )
             for i in range(num_layers)
         ]
 
-        self.
+        self.num_layers = num_layers
+        self.causal = causal
+
+    def init_cache(self, dtype: jnp.dtype | None = None, x_tn: Array | None = None) -> TransformerCache:
+        """Initialize cache for the input."""
+        cache = {}
+        for i, layer in enumerate(self.layers):
+            cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
+        return {"layers": cache}
+
+    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
+        return self.layers[0].init_mask(seq_len, with_cache=with_cache)
+
+    def forward(
+        self,
+        x_tn: Array,
+        *,
+        context_sn: Array | None = None,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        cache: TransformerCache | None = None,
+    ) -> tuple[Array, TransformerCache]:
+        """Apply transformer stack.
+
+        Args:
+            x_tn: Input tensor of shape (seq_len, embed_dim)
+            context_sn: Optional context for cross-attention
+            self_mask: Mask for self-attention
+            cross_mask: Mask for cross-attention
+            cache: Optional dictionary containing cached key and value tensors
+
+        Returns:
+            The output tensor and the updated cache
+        """
+        # Initialize layer caches
+        if cache is None:
+            cache = {"layers": {}}
+
+        # Updated cache will be built
+        updated_cache: TransformerCache = {"layers": {}}
+
+        # Apply transformer layers
+        for i, layer in enumerate(self.layers):
+            layer_cache = cache["layers"].get(f"layer_{i}")
+
+            x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
+                x_tn,
+                context_sn=context_sn,
+                self_mask=self_mask,
+                cross_mask=cross_mask,
+                cache=layer_cache,
+            )
+
+        return x_tn, updated_cache
+
+
+class Transformer(eqx.Module):
+    token_embedding: eqx.nn.Embedding
+    layers: TransformerStack
+    output_layer: eqx.nn.Linear | None
+    layer_norm: eqx.nn.LayerNorm
+    embed_dim: int = eqx.static_field()
+    causal: bool = eqx.static_field()
+    context_length: int | None = eqx.static_field()
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embed_dim: int,
+        num_heads: int,
+        ff_dim: int,
+        num_layers: int,
+        output_size: int | None = None,
+        *,
+        key: PRNGKeyArray,
+        causal: bool = False,
+        cross_attention: bool = False,
+        context_length: int | None = None,
+        use_rotary_embeddings: bool = False,
+        rotary_base: float = 10000.0,
+    ) -> None:
+        # Calculate number of keys needed
+        num_keys = 3 if output_size is None else 4
+        keys = jax.random.split(key, num_keys)
+
+        self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
+
+        self.layers = TransformerStack(
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            ff_dim=ff_dim,
+            num_layers=num_layers,
+            key=keys[2],
+            causal=causal,
+            cross_attention=cross_attention,
+            context_length=context_length,
+            use_rotary_embeddings=use_rotary_embeddings,
+            rotary_base=rotary_base,
+        )
 
+        self.layer_norm = eqx.nn.LayerNorm(embed_dim)
         if output_size is not None:
-            self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[
+            self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[3])
         else:
             self.output_layer = None
 
-        self.max_seq_len = max_seq_len
         self.embed_dim = embed_dim
+        self.causal = causal
+        self.context_length = context_length
 
-    def
-        """
-
-        return x_embedded
-
-        seq_len, _ = x_embedded.shape
-
-        if positions is None:
-            positions = jnp.arange(seq_len)
-        pos_embedded = jax.vmap(self.position_embedding)(positions)
+    def init_cache(self, dtype: jnp.dtype | None = None, x_tn: Array | None = None) -> TransformerCache:
+        """Initialize cache for the input."""
+        return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
 
-
+    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
+        return self.layers.init_mask(seq_len, with_cache=with_cache)
 
     def encode(
         self,
         x: Array,
+        *,
         mask: Array | None = None,
-
-
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        cache: TransformerCache | None = None,
+    ) -> tuple[Array, TransformerCache]:
         """Encode the input sequence.
 
         Args:
             x: Input token indices of shape (seq_len)
             mask: Optional attention mask
-            positions: Optional positions
-            key: Optional PRNG key for dropout
            cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (encoded representation, updated cache)
+            The encoded representation and the updated cache
        """
         # Token embedding
         x_embedded = jax.vmap(self.token_embedding)(x)
 
-        #
-        x_embedded = self.
-
-
-
-
-
-        # Updated cache will be built if needed
-        updated_cache = {}
-
-        # Apply transformer layers
-        keys: Array | list[None] = [None] * len(self.layers)
-        if key is not None:
-            keys = jax.random.split(key, len(self.layers))
-
-        for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
-            layer_cache = None if cache is None else cache.get(f"layer_{i}")
-
-            if update_cache:
-                x_embedded, layer_updated_cache = layer.__call__(
-                    x_embedded,
-                    self_mask=mask,
-                    key=layer_key,
-                    cache=layer_cache,
-                    update_cache=True,
-                )
-                updated_cache[f"layer_{i}"] = layer_updated_cache
-            else:
-                x_embedded = layer.__call__(
-                    x_embedded,
-                    self_mask=mask,
-                    key=layer_key,
-                    cache=layer_cache,
-                )
+        # Apply transformer stack
+        x_embedded, updated_cache = self.layers.forward(
+            x_embedded,
+            self_mask=mask,
+            cache=cache,
+        )
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
 
-
-            return output, updated_cache
-        return output
+        return output, updated_cache
 
     def decode(
         self,
-
-
+        x_t: Array,
+        context_s: Array,
+        *,
         self_mask: Array | None = None,
         cross_mask: Array | None = None,
-
-
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        cache: TransformerCache | None = None,
+    ) -> tuple[Array, TransformerCache]:
         """Decode with self-attention and cross-attention.
 
         Args:
-
-
-
-
-
-            key: Optional PRNG key for dropout
+            x_t: Input token indices, shape (seq_len)
+            context_s: Context from encoder (token indices or embedded),
+                shape (context_len, embed_dim)
+            self_mask: Optional self-attention mask, shape (seq_len, seq_len)
+            cross_mask: Optional cross-attention mask, shape (seq_len, context_len)
            cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (decoded representation, updated cache)
+            The decoded representation and the updated cache
        """
-        # Token embedding
-        x_embedded = jax.vmap(
-
-        #
-
-
-        #
-
-
-
-
-
-
-
-        keys: Array | list[None] = [None] * len(self.layers)
-        if key is not None:
-            keys = jax.random.split(key, len(self.layers))
-
-        for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
-            layer_cache = None if cache is None else cache.get(f"layer_{i}")
-
-            if update_cache:
-                x_embedded, layer_updated_cache = layer.__call__(
-                    x_embedded,
-                    context=context,
-                    self_mask=self_mask,
-                    cross_mask=cross_mask,
-                    key=layer_key,
-                    cache=layer_cache,
-                    update_cache=True,
-                )
-                updated_cache[f"layer_{i}"] = layer_updated_cache
-            else:
-                x_embedded = layer(
-                    x_embedded,
-                    context=context,
-                    self_mask=self_mask,
-                    cross_mask=cross_mask,
-                    key=layer_key,
-                    cache=layer_cache,
-                )
+        # Token embedding for x
+        x_embedded = jax.vmap(self.token_embedding)(x_t)
+
+        # Token embedding for context if needed
+        context_embedded = jax.vmap(self.token_embedding)(context_s)
+
+        # Apply transformer stack with cross-attention
+        x_embedded, updated_cache = self.layers.forward(
+            x_embedded,
+            context_sn=context_embedded,
+            self_mask=self_mask,
+            cross_mask=cross_mask,
+            cache=cache,
+        )
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
 
-
-            return output, updated_cache
-        return output
-
-    @overload
-    def __call__(
-        self,
-        x: Array,
-        *,
-        mask: Array | None = None,
-        positions: Array | None = None,
-        key: PRNGKeyArray | None = None,
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: Literal[True],
-    ) -> tuple[Array, dict[str, dict[str, dict[str, Array]]]]: ...
-
-    @overload
-    def __call__(
-        self,
-        x: Array,
-        *,
-        mask: Array | None = None,
-        positions: Array | None = None,
-        key: PRNGKeyArray | None = None,
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: Literal[False] = False,
-    ) -> Array: ...
+        return output, updated_cache
 
-    def
+    def forward(
         self,
         x: Array,
         *,
         mask: Array | None = None,
-
-
-        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
-        update_cache: bool = False,
-    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        cache: TransformerCache | None = None,
+    ) -> tuple[Array, TransformerCache]:
         """Forward pass for encoder-only or decoder-only transformers.
 
         Args:
             x: Input token indices of shape (seq_len)
             mask: Optional attention mask
-            positions: Optional positions
-            key: Optional PRNG key for dropout
            cache: Optional dictionary containing cached key and value tensors
-            update_cache: Whether to update the cache and return it
 
         Returns:
-
-            If update_cache is True: Tuple of (output representation, updated cache)
+            The output representation and the updated cache
        """
         chex.assert_rank(x, 1)
 
-
-
-
-
-
-        output = self.encode(x, mask=mask, positions=positions, key=key, cache=cache)
+        output, updated_cache = self.encode(
+            x,
+            mask=mask,
+            cache=cache,
+        )
 
         # Apply output layer if it exists
         if self.output_layer is not None:
             output = jax.vmap(self.output_layer)(output)
 
-
-            return output, updated_cache
-        return output
+        return output, updated_cache
 
     def predict_sequence(self, x_seq: Array) -> Array:
-
+        output, _ = self.forward(x=x_seq)
+        return output
 
     def generate_sequence(
         self,
@@ -685,54 +849,40 @@ class Transformer(eqx.Module):
         key = jax.random.PRNGKey(0)
 
         prompt_len = prompt_seq.shape[0]
-        sequence = prompt_seq
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        total_len = prompt_len + max_len
+        output_seq = jnp.zeros(total_len, dtype=prompt_seq.dtype)
+        output_seq = output_seq.at[:prompt_len].set(prompt_seq)
+
+        # Initialize cache with prompt
+        cache = self.init_cache()
+        _, cache = self.encode(prompt_seq, cache=cache)
+
+        # Define scan function for autoregressive generation
+        def scan_fn(
+            carry: tuple[Array, int, TransformerCache, PRNGKeyArray],
+            _: None,
+        ) -> tuple[tuple[Array, int, TransformerCache, PRNGKeyArray], Array]:
+            output_seq, pos, cache, rng = carry
+            current_token = jax.lax.dynamic_slice(output_seq, (pos,), (1,))
+
+            # Forward pass with cache update
+            logits, new_cache = self.forward(
+                x=current_token,
+                cache=cache,
             )
 
-            # Extract final logits and apply temperature
             logits = logits[-1] / temperature
-
-            # Apply top-k sampling if specified
             if top_k is not None:
                 top_logits, top_indices = jax.lax.top_k(logits, top_k)
                 logits = jnp.full_like(logits, float("-inf"))
                 logits = logits.at[top_indices].set(top_logits)
-
-            # Sample next token
             rng, subrng = jax.random.split(rng)
             next_token = jax.random.categorical(subrng, logits[None, ...])[0]
+            new_output_seq = jax.lax.dynamic_update_slice(output_seq, next_token[None], (pos + 1,))
 
-
-            new_seq = jnp.concatenate([seq, next_token[None]], axis=0)
-            return new_seq, new_cache, rng
-
-        # Generate tokens one by one
-        for _ in range(max_len):
-            # Break if max sequence length reached
-            if sequence.shape[0] >= self.max_seq_len:
-                break
-
-            # Decode next token
-            sequence, cache, key = decode_step(seq=sequence, pos=sequence.shape[0] - 1, cur_cache=cache, rng=key)
+            return (new_output_seq, pos + 1, new_cache, rng), next_token
 
-
+        init_carry = (output_seq, prompt_len - 1, cache, key)
+        (final_seq, _, _, _), _ = jax.lax.scan(scan_fn, init_carry, length=max_len)
+        return final_seq
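Editor's sketch (not part of the diff): pieced together from the signatures added above, the new fixed-length-cache API appears to be used roughly as follows. The hyperparameters here are invented, and the single-token loop only illustrates how `init_cache`, `init_mask`, and the cache returned by `forward` fit together:

    import jax
    import jax.numpy as jnp
    import xax

    key = jax.random.PRNGKey(0)
    model = xax.Transformer(
        vocab_size=256,
        embed_dim=64,
        num_heads=4,
        ff_dim=128,
        num_layers=2,
        output_size=256,
        key=key,
        causal=True,
        context_length=16,           # fixed attention window; required for init_cache()
        use_rotary_embeddings=True,  # RoPE positions advance with the cache's "position" counter
    )

    tokens = jnp.arange(8)

    # Full-sequence pass: forward() now always returns (output, cache).
    logits, cache = model.forward(tokens)

    # Unrolled single-token steps, carrying the fixed-length cache plus a
    # matching banded causal mask of shape (1, context_length).
    cache = model.init_cache(dtype=jnp.float32)
    step_mask = model.init_mask(seq_len=1, with_cache=True)
    for t in tokens:
        logits, cache = model.forward(t[None], mask=step_mask, cache=cache)
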
{xax-0.3.3.dist-info → xax-0.3.4.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=LJFB4xQplzC08tkbkZMxaCd-7jIB7aJZzBMcs9AuqiM,16240
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -8,7 +8,7 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
 xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/nn/attention.py,sha256=
+xax/nn/attention.py,sha256=aIEtrM7vAQtaXTPKmsqGcYqt03CyiUQMccXj8Cjw3vc,29514
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
 xax/nn/geom.py,sha256=6rBQrZRX1miG08VG-s8phPjA6MEFxUAfQVPt5F0RQQI,10645
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
+xax-0.3.4.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.3.4.dist-info/METADATA,sha256=j_UQdK4iPYbhzMH0osmHm5XJnYnFY1A_Z5MwSJwXr-4,1246
+xax-0.3.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.3.4.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.3.4.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.3.4.dist-info/RECORD,,
{xax-0.3.3.dist-info → xax-0.3.4.dist-info}/WHEEL: file without changes
{xax-0.3.3.dist-info → xax-0.3.4.dist-info}/entry_points.txt: file without changes
{xax-0.3.3.dist-info → xax-0.3.4.dist-info}/licenses/LICENSE: file without changes
{xax-0.3.3.dist-info → xax-0.3.4.dist-info}/top_level.txt: file without changes