xax 0.3.5-py3-none-any.whl → 0.3.7-py3-none-any.whl
- xax/__init__.py +4 -1
- xax/nn/attention.py +108 -17
- xax/utils/pytree.py +13 -0
- {xax-0.3.5.dist-info → xax-0.3.7.dist-info}/METADATA +1 -1
- {xax-0.3.5.dist-info → xax-0.3.7.dist-info}/RECORD +9 -9
- {xax-0.3.5.dist-info → xax-0.3.7.dist-info}/WHEEL +0 -0
- {xax-0.3.5.dist-info → xax-0.3.7.dist-info}/entry_points.txt +0 -0
- {xax-0.3.5.dist-info → xax-0.3.7.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.5.dist-info → xax-0.3.7.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.5"
+__version__ = "0.3.7"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -136,6 +136,7 @@ __all__ = [
     "compute_nan_ratio",
     "flatten_array",
     "flatten_pytree",
+    "get_pytree_mapping",
     "get_pytree_param_count",
     "pytree_has_nans",
     "reshuffle_pytree",
@@ -323,6 +324,7 @@ NAME_MAP: dict[str, str] = {
     "compute_nan_ratio": "utils.pytree",
     "flatten_array": "utils.pytree",
     "flatten_pytree": "utils.pytree",
+    "get_pytree_mapping": "utils.pytree",
     "get_pytree_param_count": "utils.pytree",
     "pytree_has_nans": "utils.pytree",
     "reshuffle_pytree": "utils.pytree",
@@ -509,6 +511,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         compute_nan_ratio,
         flatten_array,
         flatten_pytree,
+        get_pytree_mapping,
         get_pytree_param_count,
         pytree_has_nans,
         reshuffle_pytree,
xax/nn/attention.py
CHANGED
@@ -212,16 +212,49 @@ class SelfAttentionBlock(eqx.Module):
 
         return {"k": k_cache, "v": v_cache, "position": 0}
 
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        """Initialize the attention matrix mask.
+
+        Args:
+            seq_len: The length of the sequence
+            add_cache: Whether to add the cache to the mask
+            batch_dim: Whether to add a batch dimension to the mask
+
+        Returns:
+            The attention matrix mask of shape (bsz, 1, seq_len, seq_len + cache_len)
+            if batch_dim is True, otherwise (seq_len, seq_len + cache_len).
+        """
+        t, s, o = seq_len, seq_len, 0
+        if add_cache:
+            if self.local_window_size is None:
+                raise ValueError("local_window_size must be set for caching")
+            s += self.local_window_size
+            o -= self.local_window_size
+        mask = jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-o)
+        if self.local_window_size is not None:
+            neg_mask = ~jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-(self.local_window_size + 1 + o))
+            mask = mask & neg_mask
+        mask = mask.reshape(1, 1, t, s) if batch_dim else mask.reshape(t, s)
+        return mask
+
     def forward(
         self,
         x_tn: Array,
         *,
+        mask: Array | None = None,
         cache: AttentionCache | None = None,
     ) -> tuple[Array, AttentionCache]:
         """Apply self-attention.
 
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: The cached key and value tensors (fixed-length)
 
         Returns:
@@ -256,26 +289,28 @@ class SelfAttentionBlock(eqx.Module):
             k = jnp.concatenate([k_cache, k], axis=0)
             v = jnp.concatenate([v_cache, v], axis=0)
 
-            # Pads query with `k_cache.shape[0]` zeros.
-            q = jnp.pad(q, ((k_cache.shape[0], 0), (0, 0), (0, 0)), mode="constant", constant_values=0)
-
             new_position = cache["position"] + seq_len
 
         else:
             new_position = seq_len
 
-        attn_output = jax.nn.dot_product_attention(
-            q,
-            k,
-            v,
-            scale=1.0 / math.sqrt(self.head_dim),
-            is_causal=self.causal,
-            local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
-        )
+        if seq_len == 1:
+            attn_output = jax.nn.dot_product_attention(q, k, v)
 
-
-
-
+        elif mask is not None:
+            attn_output = jax.nn.dot_product_attention(q, k, v, mask=mask)
+
+        elif cache is not None:
+            raise NotImplementedError("For training with a cache, provide a mask instead.")
+
+        else:
+            attn_output = jax.nn.dot_product_attention(
+                q,
+                k,
+                v,
+                is_causal=self.causal,
+                local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
+            )
 
         attn_output = self._combine_heads(attn_output)
         output = jax.vmap(self.output_proj)(attn_output)
@@ -403,6 +438,7 @@ class CrossAttentionBlock(eqx.Module):
             q_rot,
             k_rot,
             v,
+            scale=1.0 / math.sqrt(self.head_dim),
             is_causal=False,
         )
 
@@ -498,11 +534,24 @@ class TransformerBlock(eqx.Module):
         cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
         return cache
 
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.self_attn.init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
+
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
+        mask: Array | None = None,
         cache: AttentionCacheDict | None = None,
     ) -> tuple[Array, AttentionCacheDict]:
         """Apply transformer block.
@@ -510,6 +559,8 @@ class TransformerBlock(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -522,6 +573,7 @@ class TransformerBlock(eqx.Module):
 
         attn_output, self_attn_cache = self.self_attn.forward(
             x_tn=norm_x,
+            mask=mask,
             cache=None if cache is None else cache["self_attn"],
         )
         updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
@@ -598,11 +650,24 @@ class TransformerStack(eqx.Module):
             cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
         return {"layers": cache}
 
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.layers[0].init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
+
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Apply transformer stack.
@@ -610,6 +675,8 @@ class TransformerStack(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -629,6 +696,7 @@ class TransformerStack(eqx.Module):
             x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
                 x_tn,
                 context_sn=context_sn,
+                mask=mask,
                 cache=layer_cache,
             )
 
@@ -693,16 +761,31 @@ class Transformer(eqx.Module):
         """Initialize cache for the input."""
         return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
 
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.layers.init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
+
     def encode(
         self,
         x: Array,
         *,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Encode the input sequence.
 
         Args:
             x: Input token indices of shape (seq_len)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -712,7 +795,7 @@ class Transformer(eqx.Module):
         x_embedded = jax.vmap(self.token_embedding)(x)
 
         # Apply transformer stack
-        x_embedded, updated_cache = self.layers.forward(x_embedded, cache=cache)
+        x_embedded, updated_cache = self.layers.forward(x_embedded, mask=mask, cache=cache)
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
@@ -724,6 +807,7 @@ class Transformer(eqx.Module):
         x_t: Array,
         context_s: Array,
         *,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Decode with self-attention and cross-attention.
@@ -732,6 +816,8 @@ class Transformer(eqx.Module):
             x_t: Input token indices, shape (seq_len)
             context_s: Context from encoder (token indices or embedded),
                 shape (context_len, embed_dim)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -747,6 +833,7 @@ class Transformer(eqx.Module):
         x_embedded, updated_cache = self.layers.forward(
             x_embedded,
             context_sn=context_embedded,
+            mask=mask,
             cache=cache,
         )
 
@@ -759,12 +846,15 @@ class Transformer(eqx.Module):
         self,
         x: Array,
         *,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Forward pass for encoder-only or decoder-only transformers.
 
         Args:
             x: Input token indices of shape (seq_len)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -772,7 +862,7 @@ class Transformer(eqx.Module):
         """
         chex.assert_rank(x, 1)
 
-        output, updated_cache = self.encode(x, cache=cache)
+        output, updated_cache = self.encode(x, mask=mask, cache=cache)
 
         # Apply output layer if it exists
         if self.output_layer is not None:
@@ -817,7 +907,8 @@ class Transformer(eqx.Module):
 
         # Initialize cache with prompt
         cache = self.init_cache()
-        _, cache = self.encode(prompt_seq, cache=cache)
+        mask = self.init_mask(prompt_len, add_cache=True, batch_dim=False)
+        _, cache = self.encode(prompt_seq, cache=cache, mask=mask)
 
         # Define scan function for autoregressive generation
         def scan_fn(
xax/utils/pytree.py
CHANGED
@@ -253,3 +253,16 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
     mut = list(t)
     mut[index] = value
     return tuple(mut)
+
+
+def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
+    leaves: dict[str, Array] = {}
+
+    def _get_leaf(path: tuple, x: PyTree) -> None:
+        if isinstance(x, jnp.ndarray):
+            # Convert path tuple to string, e.g. (1, 'a', 2) -> '1/a/2'
+            path_str = "/".join(str(p) for p in path)
+            leaves[path_str] = x
+
+    jax.tree.map_with_path(_get_leaf, pytree)
+    return leaves
{xax-0.3.5.dist-info → xax-0.3.7.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=YCDjLRwliJCyEmNFC56PNQXV9Vn9Fr13VJS_am4h3To,16336
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -8,7 +8,7 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
 xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/nn/attention.py,sha256=
+xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
 xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
 xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
@@ -51,7 +51,7 @@ xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
 xax/utils/logging.py,sha256=Kkyma_LJXqrN2HTQ214gRP_9ih3_bKk115MWC60lQWM,6656
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=
+xax/utils/pytree.py,sha256=cLZRSd5xc-DqcbRfWnBy87pAiUU5fT8U4CHoLi_i_v4,9642
 xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
 xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
+xax-0.3.7.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.3.7.dist-info/METADATA,sha256=8Zb0pvTJOjrCHK7giM2MbhlGCPREQewJK3GgRDQNWY0,1246
+xax-0.3.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.3.7.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.3.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.3.7.dist-info/RECORD,,
The remaining dist-info files ({xax-0.3.5.dist-info → xax-0.3.7.dist-info}/WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt) are renamed without content changes.