xax 0.3.4-py3-none-any.whl → 0.3.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -1
- xax/nn/attention.py +144 -92
- xax/nn/embeddings.py +10 -10
- xax/nn/geom.py +5 -5
- xax/nn/ssm.py +6 -6
- xax/task/mixins/train.py +6 -1
- xax/utils/pytree.py +13 -0
- {xax-0.3.4.dist-info → xax-0.3.6.dist-info}/METADATA +1 -1
- {xax-0.3.4.dist-info → xax-0.3.6.dist-info}/RECORD +13 -13
- {xax-0.3.4.dist-info → xax-0.3.6.dist-info}/WHEEL +0 -0
- {xax-0.3.4.dist-info → xax-0.3.6.dist-info}/entry_points.txt +0 -0
- {xax-0.3.4.dist-info → xax-0.3.6.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.4.dist-info → xax-0.3.6.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.4"
+__version__ = "0.3.6"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -136,6 +136,7 @@ __all__ = [
     "compute_nan_ratio",
     "flatten_array",
     "flatten_pytree",
+    "get_pytree_mapping",
     "get_pytree_param_count",
     "pytree_has_nans",
     "reshuffle_pytree",
@@ -323,6 +324,7 @@ NAME_MAP: dict[str, str] = {
     "compute_nan_ratio": "utils.pytree",
     "flatten_array": "utils.pytree",
     "flatten_pytree": "utils.pytree",
+    "get_pytree_mapping": "utils.pytree",
     "get_pytree_param_count": "utils.pytree",
     "pytree_has_nans": "utils.pytree",
     "reshuffle_pytree": "utils.pytree",
@@ -509,6 +511,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         compute_nan_ratio,
         flatten_array,
         flatten_pytree,
+        get_pytree_mapping,
         get_pytree_param_count,
         pytree_has_nans,
         reshuffle_pytree,
xax/nn/attention.py
CHANGED
@@ -5,6 +5,8 @@ supporting a fixed-size context window and caching that can be used to train
 transformers which can be unrolled with a fixed-length cache.
 """
 
+import math
+import warnings
 from typing import NotRequired, TypedDict
 
 import chex
@@ -13,6 +15,8 @@ import jax
 import jax.numpy as jnp
 from jaxtyping import Array, PRNGKeyArray
 
+from xax.utils.jax import scan as xax_scan
+
 
 class RotaryEmbedding(eqx.Module):
     """Rotary Position Embedding (RoPE) for transformer attention.
@@ -22,8 +26,8 @@ class RotaryEmbedding(eqx.Module):
     https://arxiv.org/abs/2104.09864
     """
 
-    head_dim: int = eqx.
-    base: float = eqx.
+    head_dim: int = eqx.field()
+    base: float = eqx.field()
 
     def __init__(
         self,
@@ -125,15 +129,15 @@ class TransformerCache(TypedDict):
 class SelfAttentionBlock(eqx.Module):
     """Self-attention block using jax.nn.dot_product_attention."""
 
-    q_proj: eqx.nn.Linear
-    k_proj: eqx.nn.Linear
-    v_proj: eqx.nn.Linear
-    output_proj: eqx.nn.Linear
-    rotary_emb: RotaryEmbedding | None
-    num_heads: int = eqx.
-    head_dim: int = eqx.
-    causal: bool = eqx.
-
+    q_proj: eqx.nn.Linear = eqx.field()
+    k_proj: eqx.nn.Linear = eqx.field()
+    v_proj: eqx.nn.Linear = eqx.field()
+    output_proj: eqx.nn.Linear = eqx.field()
+    rotary_emb: RotaryEmbedding | None = eqx.field()
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    local_window_size: int | None = eqx.field()
 
     def __init__(
         self,
@@ -169,8 +173,12 @@ class SelfAttentionBlock(eqx.Module):
         else:
             self.rotary_emb = None
 
+        if context_length is not None and not causal:
+            warnings.warn("context_length is set but causal is False; overriding causal to True", stacklevel=2)
+            causal = True
+
         self.causal = causal
-        self.
+        self.local_window_size = None if context_length is None else context_length - 1
 
     @property
     def embed_dim(self) -> int:
@@ -195,28 +203,44 @@ class SelfAttentionBlock(eqx.Module):
         Returns:
             Cache with fixed-length k and v tensors
         """
-        if self.
+        if self.local_window_size is None:
             raise ValueError("context_length must be set for caching")
 
         # Create fixed-length cache
-        k_cache = jnp.zeros((self.
-        v_cache = jnp.zeros((self.
+        k_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
+        v_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
 
         return {"k": k_cache, "v": v_cache, "position": 0}
 
-    def init_mask(
-
-
-
-
-
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        """Initialize the attention matrix mask.
 
-
-
-
-
+        Args:
+            seq_len: The length of the sequence
+            add_cache: Whether to add the cache to the mask
+            batch_dim: Whether to add a batch dimension to the mask
 
-
+        Returns:
+            The attention matrix mask of shape (bsz, 1, seq_len, seq_len + cache_len)
+            if batch_dim is True, otherwise (seq_len, seq_len + cache_len).
+        """
+        t, s, o = seq_len, seq_len, 0
+        if add_cache:
+            if self.local_window_size is None:
+                raise ValueError("local_window_size must be set for caching")
+            s += self.local_window_size
+            o -= self.local_window_size
+        mask = jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-o)
+        if self.local_window_size is not None:
+            neg_mask = ~jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-(self.local_window_size + 1 + o))
+            mask = mask & neg_mask
+        mask = mask.reshape(1, 1, t, s) if batch_dim else mask.reshape(t, s)
+        return mask
 
     def forward(
         self,
@@ -229,7 +253,8 @@ class SelfAttentionBlock(eqx.Module):
 
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
-            mask: Optional mask
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: The cached key and value tensors (fixed-length)
 
         Returns:
@@ -263,25 +288,36 @@ class SelfAttentionBlock(eqx.Module):
             v_cache = cache["v"]
             k = jnp.concatenate([k_cache, k], axis=0)
             v = jnp.concatenate([v_cache, v], axis=0)
+
             new_position = cache["position"] + seq_len
 
         else:
            new_position = seq_len
 
-
-            q,
-
-
-            mask=mask
-
-
+        if seq_len == 1:
+            attn_output = jax.nn.dot_product_attention(q, k, v)
+
+        elif mask is not None:
+            attn_output = jax.nn.dot_product_attention(q, k, v, mask=mask)
+
+        elif cache is not None:
+            raise NotImplementedError("For training with a cache, provide a mask instead.")
+
+        else:
+            attn_output = jax.nn.dot_product_attention(
+                q,
+                k,
+                v,
+                is_causal=self.causal,
+                local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
+            )
 
         attn_output = self._combine_heads(attn_output)
         output = jax.vmap(self.output_proj)(attn_output)
 
-        if self.
-            k = k[-
-            v = v[-
+        if self.local_window_size is not None:
+            k = k[-self.local_window_size :]
+            v = v[-self.local_window_size :]
 
         return output, {"k": k, "v": v, "position": new_position}
 
@@ -294,8 +330,8 @@ class CrossAttentionBlock(eqx.Module):
     v_proj: eqx.nn.Linear
     output_proj: eqx.nn.Linear
     rotary_emb: RotaryEmbedding | None
-    num_heads: int = eqx.
-    head_dim: int = eqx.
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
 
     def __init__(
         self,
@@ -352,7 +388,6 @@ class CrossAttentionBlock(eqx.Module):
         *,
         kv_sn: Array | None = None,
         cache: AttentionCache | None = None,
-        mask: Array | None = None,
     ) -> tuple[Array, AttentionCache]:
         """Apply cross-attention.
 
@@ -362,7 +397,6 @@ class CrossAttentionBlock(eqx.Module):
                 If not provided, then `cache` must be provided.
             cache: The cached key and value tensors. If not provided, then
                 `kv_sn` must be provided.
-            mask: Optional mask tensor
 
         Returns:
             The output tensor of shape (q_seq_len, embed_dim)
@@ -404,7 +438,7 @@ class CrossAttentionBlock(eqx.Module):
             q_rot,
             k_rot,
             v,
-
+            scale=1.0 / math.sqrt(self.head_dim),
             is_causal=False,
         )
 
@@ -424,10 +458,10 @@ class TransformerBlock(eqx.Module):
     layer_norm1: eqx.nn.LayerNorm
     layer_norm2: eqx.nn.LayerNorm
     layer_norm3: eqx.nn.LayerNorm | None
-    num_heads: int = eqx.
-    head_dim: int = eqx.
-    causal: bool = eqx.
-    context_length: int | None = eqx.
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    context_length: int | None = eqx.field()
 
     def __init__(
         self,
@@ -500,16 +534,24 @@ class TransformerBlock(eqx.Module):
             cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
         return cache
 
-    def init_mask(
-
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.self_attn.init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
 
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
-
-        cross_mask: Array | None = None,
+        mask: Array | None = None,
         cache: AttentionCacheDict | None = None,
     ) -> tuple[Array, AttentionCacheDict]:
         """Apply transformer block.
@@ -517,8 +559,8 @@ class TransformerBlock(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
-
-
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -531,7 +573,7 @@ class TransformerBlock(eqx.Module):
 
         attn_output, self_attn_cache = self.self_attn.forward(
             x_tn=norm_x,
-            mask=
+            mask=mask,
             cache=None if cache is None else cache["self_attn"],
         )
         updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
@@ -547,7 +589,6 @@ class TransformerBlock(eqx.Module):
             cross_attn_output, updated_cache["cross_attn"] = self.cross_attn.forward(
                 q_tn=norm_x,
                 kv_sn=context_sn,
-                mask=cross_mask,
                 cache=None if cache is None else cache.get("cross_attn"),
             )
 
@@ -564,9 +605,9 @@ class TransformerBlock(eqx.Module):
 class TransformerStack(eqx.Module):
     """A stack of transformer blocks."""
 
-    layers:
-    num_layers: int = eqx.
-    causal: bool = eqx.
+    layers: tuple[TransformerBlock, ...]
+    num_layers: int = eqx.field()
+    causal: bool = eqx.field()
 
     def __init__(
         self,
@@ -584,7 +625,7 @@ class TransformerStack(eqx.Module):
     ) -> None:
         keys = jax.random.split(key, num_layers)
 
-        self.layers =
+        self.layers = tuple(
             TransformerBlock(
                 embed_dim=embed_dim,
                 num_heads=num_heads,
@@ -597,7 +638,7 @@ class TransformerStack(eqx.Module):
                 rotary_base=rotary_base,
             )
             for i in range(num_layers)
-
+        )
 
         self.num_layers = num_layers
         self.causal = causal
@@ -609,16 +650,24 @@ class TransformerStack(eqx.Module):
             cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
         return {"layers": cache}
 
-    def init_mask(
-
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.layers[0].init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
 
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
-
-        cross_mask: Array | None = None,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Apply transformer stack.
@@ -626,8 +675,8 @@ class TransformerStack(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
-
-
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -647,8 +696,7 @@ class TransformerStack(eqx.Module):
             x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
                 x_tn,
                 context_sn=context_sn,
-
-                cross_mask=cross_mask,
+                mask=mask,
                 cache=layer_cache,
             )
 
@@ -660,9 +708,9 @@ class Transformer(eqx.Module):
     layers: TransformerStack
     output_layer: eqx.nn.Linear | None
     layer_norm: eqx.nn.LayerNorm
-    embed_dim: int = eqx.
-    causal: bool = eqx.
-    context_length: int | None = eqx.
+    embed_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    context_length: int | None = eqx.field()
 
     def __init__(
         self,
@@ -713,8 +761,17 @@ class Transformer(eqx.Module):
         """Initialize cache for the input."""
         return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
 
-    def init_mask(
-
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.layers.init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
 
     def encode(
         self,
@@ -727,7 +784,8 @@ class Transformer(eqx.Module):
 
         Args:
             x: Input token indices of shape (seq_len)
-            mask: Optional
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -737,11 +795,7 @@ class Transformer(eqx.Module):
         x_embedded = jax.vmap(self.token_embedding)(x)
 
         # Apply transformer stack
-        x_embedded, updated_cache = self.layers.forward(
-            x_embedded,
-            self_mask=mask,
-            cache=cache,
-        )
+        x_embedded, updated_cache = self.layers.forward(x_embedded, mask=mask, cache=cache)
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
@@ -753,8 +807,7 @@ class Transformer(eqx.Module):
         x_t: Array,
         context_s: Array,
         *,
-
-        cross_mask: Array | None = None,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Decode with self-attention and cross-attention.
@@ -763,8 +816,8 @@ class Transformer(eqx.Module):
             x_t: Input token indices, shape (seq_len)
             context_s: Context from encoder (token indices or embedded),
                 shape (context_len, embed_dim)
-
-
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -780,8 +833,7 @@ class Transformer(eqx.Module):
         x_embedded, updated_cache = self.layers.forward(
             x_embedded,
             context_sn=context_embedded,
-
-            cross_mask=cross_mask,
+            mask=mask,
            cache=cache,
         )
 
@@ -801,7 +853,8 @@ class Transformer(eqx.Module):
 
         Args:
             x: Input token indices of shape (seq_len)
-            mask: Optional
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -809,11 +862,7 @@ class Transformer(eqx.Module):
         """
         chex.assert_rank(x, 1)
 
-        output, updated_cache = self.encode(
-            x,
-            mask=mask,
-            cache=cache,
-        )
+        output, updated_cache = self.encode(x, mask=mask, cache=cache)
 
         # Apply output layer if it exists
         if self.output_layer is not None:
@@ -832,6 +881,7 @@ class Transformer(eqx.Module):
         temperature: float = 1.0,
         top_k: int | None = None,
         key: PRNGKeyArray | None = None,
+        jit_level: int | None = None,
     ) -> Array:
         """Generate a sequence autoregressively with KV caching.
 
@@ -841,6 +891,7 @@ class Transformer(eqx.Module):
             temperature: Sampling temperature
             top_k: Optional top-k sampling parameter
             key: PRNG key for sampling
+            jit_level: JIT level for the scan function
 
         Returns:
             Generated sequence of shape (prompt_len + max_len,)
@@ -856,7 +907,8 @@ class Transformer(eqx.Module):
 
         # Initialize cache with prompt
         cache = self.init_cache()
-
+        mask = self.init_mask(prompt_len, add_cache=True, batch_dim=False)
+        _, cache = self.encode(prompt_seq, cache=cache, mask=mask)
 
         # Define scan function for autoregressive generation
         def scan_fn(
@@ -884,5 +936,5 @@ class Transformer(eqx.Module):
             return (new_output_seq, pos + 1, new_cache, rng), next_token
 
         init_carry = (output_seq, prompt_len - 1, cache, key)
-        (final_seq, _, _, _), _ =
+        (final_seq, _, _, _), _ = xax_scan(scan_fn, init_carry, length=max_len, jit_level=jit_level)
         return final_seq
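To summarize the new masking scheme: `local_window_size` is set to `context_length - 1`, and `init_mask` builds a banded causal mask in which position i may attend to positions i - local_window_size through i; with `add_cache=True` the mask gains `local_window_size` extra columns on the left so it lines up with the fixed-length KV cache, which is how the generation path now primes the cache (`init_mask(prompt_len, add_cache=True)` followed by `encode(..., mask=mask)`). A standalone sketch of the same banding logic, with an illustrative helper name and example values (not part of the package):

import jax.numpy as jnp

def banded_causal_mask(seq_len: int, local_window_size: int | None, add_cache: bool = False) -> jnp.ndarray:
    # Mirrors the mask construction added to SelfAttentionBlock.init_mask above.
    t, s, o = seq_len, seq_len, 0
    if add_cache:
        assert local_window_size is not None, "caching requires a fixed window"
        s += local_window_size  # extra columns for the cached keys/values
        o -= local_window_size  # shift the band so the cache sits to the left
    mask = jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-o)
    if local_window_size is not None:
        # Drop everything more than local_window_size positions in the past.
        mask = mask & ~jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-(local_window_size + 1 + o))
    return mask

# Row i attends to columns max(0, i - 2) .. i, i.e. a window of 3 positions.
print(banded_causal_mask(4, local_window_size=2).astype(jnp.int32))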
xax/nn/embeddings.py
CHANGED
@@ -33,10 +33,10 @@ class LearnedPositionalEmbeddings(eqx.Module):
         learnable: Whether the embeddings are learnable.
     """
 
-    max_tsz: int = eqx.field(
-    embed_dim: int = eqx.field(
-    learnable: bool = eqx.field(
-    embeddings_tc: Array
+    max_tsz: int = eqx.field()
+    embed_dim: int = eqx.field()
+    learnable: bool = eqx.field()
+    embeddings_tc: Array = eqx.field()
 
     def __init__(
         self,
@@ -74,10 +74,10 @@ class SinusoidalEmbeddings(eqx.Module):
         base: The base for the sinusoidal embeddings.
     """
 
-    base: int = eqx.field(
-    max_tsz: int | None = eqx.field(
-    embed_dim: int | None = eqx.field(
-    embeddings_tc: Array | None
+    base: int = eqx.field()
+    max_tsz: int | None = eqx.field()
+    embed_dim: int | None = eqx.field()
+    embeddings_tc: Array | None = eqx.field()
 
     def __init__(
         self,
@@ -91,8 +91,8 @@ class SinusoidalEmbeddings(eqx.Module):
         self.max_tsz = max_tsz
         self.embed_dim = embed_dim
         self.base = base
+        self.embeddings_tc = None
 
-        self.embeddings_tc: Array | None = None
         if learnable:
             assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
             assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
@@ -192,7 +192,7 @@ class RotaryEmbeddings(eqx.Module):
         base: The base for the sinusoidal embeddings.
     """
 
-    base: int = eqx.field(
+    base: int = eqx.field()
 
     def __init__(self, base: int = 10_000) -> None:
         """Defines a rotary embeddings module.
xax/nn/geom.py
CHANGED
@@ -207,7 +207,7 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
 
 def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
     norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
-    return v / jnp.clip(norm,
+    return v / jnp.clip(norm, min=eps)
 
 
 def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
@@ -299,28 +299,28 @@ def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
     trace = m00 + m11 + m22
 
     # Case 0: trace is positive
-    s0 = jnp.sqrt(jnp.clip(trace + 1.0,
+    s0 = jnp.sqrt(jnp.clip(trace + 1.0, min=0.0)) * 2.0  # S = 4 * qw
     w0 = 0.25 * s0
     x0 = (m21 - m12) / jnp.where(s0 < eps, 1.0, s0)
     y0 = (m02 - m20) / jnp.where(s0 < eps, 1.0, s0)
     z0 = (m10 - m01) / jnp.where(s0 < eps, 1.0, s0)
 
     # Case 1: m00 is the largest diagonal term
-    s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22,
+    s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22, min=0.0)) * 2.0  # S = 4 * qx
     w1 = (m21 - m12) / jnp.where(s1 < eps, 1.0, s1)
     x1 = 0.25 * s1
     y1 = (m01 + m10) / jnp.where(s1 < eps, 1.0, s1)
     z1 = (m02 + m20) / jnp.where(s1 < eps, 1.0, s1)
 
     # Case 2: m11 is the largest diagonal term
-    s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22,
+    s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22, min=0.0)) * 2.0  # S = 4 * qy
     w2 = (m02 - m20) / jnp.where(s2 < eps, 1.0, s2)
     x2 = (m01 + m10) / jnp.where(s2 < eps, 1.0, s2)
     y2 = 0.25 * s2
     z2 = (m12 + m21) / jnp.where(s2 < eps, 1.0, s2)
 
     # Case 3: m22 is the largest diagonal term
-    s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11,
+    s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11, min=0.0)) * 2.0  # S = 4 * qz
     w3 = (m10 - m01) / jnp.where(s3 < eps, 1.0, s3)
     x3 = (m02 + m20) / jnp.where(s3 < eps, 1.0, s3)
     y3 = (m12 + m21) / jnp.where(s3 < eps, 1.0, s3)
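These hunks pass the lower clip bound as `jnp.clip(..., min=...)`. The guard pattern itself is simple: the norm (or radicand) is clamped from below so zero-length vectors and negative round-off do not produce NaNs. A minimal sketch mirroring the updated `normalize` (example inputs are illustrative):

import jax.numpy as jnp

def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
    # Same guard as xax.nn.geom.normalize: the norm is clamped to at least eps
    # before dividing, so an all-zero vector maps to zeros instead of NaNs.
    norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
    return v / jnp.clip(norm, min=eps)

print(normalize(jnp.array([3.0, 4.0])))  # [0.6 0.8]
print(normalize(jnp.zeros(3)))           # [0. 0. 0.] rather than NaN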
xax/nn/ssm.py
CHANGED
@@ -222,12 +222,12 @@ class DiscreteDiagSSMBlock(DiagSSMBlock):
 
 
 class SSM(eqx.Module):
-    vocab_embedding: eqx.nn.Embedding
-    output_layer: eqx.nn.Linear
-    blocks: list[BaseSSMBlock]
-    num_layers: int = eqx.
-    hidden_size: int = eqx.
-    skip_connections: bool = eqx.
+    vocab_embedding: eqx.nn.Embedding = eqx.field()
+    output_layer: eqx.nn.Linear = eqx.field()
+    blocks: list[BaseSSMBlock] = eqx.field()
+    num_layers: int = eqx.field()
+    hidden_size: int = eqx.field()
+    skip_connections: bool = eqx.field()
 
     def __init__(
         self,
xax/task/mixins/train.py
CHANGED
@@ -40,7 +40,12 @@ from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import
+from xax.task.mixins.checkpointing import (
+    CheckpointingConfig,
+    CheckpointingMixin,
+    CheckpointPart,
+    load_ckpt,
+)
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
xax/utils/pytree.py
CHANGED
@@ -253,3 +253,16 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
     mut = list(t)
     mut[index] = value
     return tuple(mut)
+
+
+def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
+    leaves: dict[str, Array] = {}
+
+    def _get_leaf(path: tuple, x: PyTree) -> None:
+        if isinstance(x, jnp.ndarray):
+            # Convert path tuple to string, e.g. (1, 'a', 2) -> '1/a/2'
+            path_str = "/".join(str(p) for p in path)
+            leaves[path_str] = x
+
+    jax.tree.map_with_path(_get_leaf, pytree)
+    return leaves
{xax-0.3.4.dist-info → xax-0.3.6.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=9i6UlrAP1wLDh1lod-4ETWll4pcADIir_Tk3O6OvH7g,16336
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -8,14 +8,14 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
 xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/nn/attention.py,sha256=
-xax/nn/embeddings.py,sha256=
+xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
+xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
-xax/nn/geom.py,sha256=
+xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
 xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
 xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
-xax/nn/ssm.py,sha256=
+xax/nn/ssm.py,sha256=qSBv_FobnaFA5jt87OF5P2q5ih6sj4SlehhEhEFaPjA,10766
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=i6FRJ75aqlekWkzJNRWDUEX7P514pUjLVuxjhX1GBgw,8198
 xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
@@ -42,7 +42,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=bjBoigTCjbq9H4hcqIO32irHBc9rC2zkgXrnGNI2RtI,33266
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -51,7 +51,7 @@ xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
 xax/utils/logging.py,sha256=Kkyma_LJXqrN2HTQ214gRP_9ih3_bKk115MWC60lQWM,6656
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=
+xax/utils/pytree.py,sha256=cLZRSd5xc-DqcbRfWnBy87pAiUU5fT8U4CHoLi_i_v4,9642
 xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
 xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
+xax-0.3.6.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.3.6.dist-info/METADATA,sha256=PI1onOBOY7vwwjDdg_fDoQIDSQ6tyUfwDK3nPnE_fcE,1246
+xax-0.3.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.3.6.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.3.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.3.6.dist-info/RECORD,,
{xax-0.3.4.dist-info → xax-0.3.6.dist-info}/WHEEL
File without changes
{xax-0.3.4.dist-info → xax-0.3.6.dist-info}/entry_points.txt
File without changes
{xax-0.3.4.dist-info → xax-0.3.6.dist-info}/licenses/LICENSE
File without changes
{xax-0.3.4.dist-info → xax-0.3.6.dist-info}/top_level.txt
File without changes