xax 0.3.4__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/nn/attention.py +56 -95
- xax/nn/embeddings.py +10 -10
- xax/nn/geom.py +5 -5
- xax/nn/ssm.py +6 -6
- xax/task/mixins/train.py +6 -1
- {xax-0.3.4.dist-info → xax-0.3.5.dist-info}/METADATA +1 -1
- {xax-0.3.4.dist-info → xax-0.3.5.dist-info}/RECORD +12 -12
- {xax-0.3.4.dist-info → xax-0.3.5.dist-info}/WHEEL +0 -0
- {xax-0.3.4.dist-info → xax-0.3.5.dist-info}/entry_points.txt +0 -0
- {xax-0.3.4.dist-info → xax-0.3.5.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.4.dist-info → xax-0.3.5.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/nn/attention.py
CHANGED
```diff
@@ -5,6 +5,8 @@ supporting a fixed-size context window and caching that can be used to train
 transformers which can be unrolled with a fixed-length cache.
 """
 
+import math
+import warnings
 from typing import NotRequired, TypedDict
 
 import chex
@@ -13,6 +15,8 @@ import jax
 import jax.numpy as jnp
 from jaxtyping import Array, PRNGKeyArray
 
+from xax.utils.jax import scan as xax_scan
+
 
 class RotaryEmbedding(eqx.Module):
     """Rotary Position Embedding (RoPE) for transformer attention.
@@ -22,8 +26,8 @@ class RotaryEmbedding(eqx.Module):
     https://arxiv.org/abs/2104.09864
     """
 
-    head_dim: int = eqx.
-    base: float = eqx.
+    head_dim: int = eqx.field()
+    base: float = eqx.field()
 
     def __init__(
         self,
@@ -125,15 +129,15 @@ class TransformerCache(TypedDict):
 class SelfAttentionBlock(eqx.Module):
     """Self-attention block using jax.nn.dot_product_attention."""
 
-    q_proj: eqx.nn.Linear
-    k_proj: eqx.nn.Linear
-    v_proj: eqx.nn.Linear
-    output_proj: eqx.nn.Linear
-    rotary_emb: RotaryEmbedding | None
-    num_heads: int = eqx.
-    head_dim: int = eqx.
-    causal: bool = eqx.
-
+    q_proj: eqx.nn.Linear = eqx.field()
+    k_proj: eqx.nn.Linear = eqx.field()
+    v_proj: eqx.nn.Linear = eqx.field()
+    output_proj: eqx.nn.Linear = eqx.field()
+    rotary_emb: RotaryEmbedding | None = eqx.field()
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    local_window_size: int | None = eqx.field()
 
     def __init__(
         self,
@@ -169,8 +173,12 @@ class SelfAttentionBlock(eqx.Module):
         else:
             self.rotary_emb = None
 
+        if context_length is not None and not causal:
+            warnings.warn("context_length is set but causal is False; overriding causal to True", stacklevel=2)
+            causal = True
+
         self.causal = causal
-        self.
+        self.local_window_size = None if context_length is None else context_length - 1
 
     @property
     def embed_dim(self) -> int:
@@ -195,41 +203,25 @@ class SelfAttentionBlock(eqx.Module):
         Returns:
             Cache with fixed-length k and v tensors
         """
-        if self.
+        if self.local_window_size is None:
            raise ValueError("context_length must be set for caching")
 
         # Create fixed-length cache
-        k_cache = jnp.zeros((self.
-        v_cache = jnp.zeros((self.
+        k_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
+        v_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
 
         return {"k": k_cache, "v": v_cache, "position": 0}
 
-    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
-        in_dim, out_dim = seq_len, seq_len
-        if with_cache:
-            if self.context_length is None:
-                raise ValueError("context_length must be set for caching")
-            in_dim = in_dim + self.context_length - 1
-
-        mask = jnp.tril(jnp.ones((in_dim, out_dim)))
-        if self.context_length is not None:
-            neg_mask = 1 - jnp.tril(jnp.ones((in_dim, out_dim)), -self.context_length)
-            mask = mask * neg_mask
-
-        return mask.astype(jnp.bool_).transpose()
-
     def forward(
         self,
         x_tn: Array,
         *,
-        mask: Array | None = None,
         cache: AttentionCache | None = None,
     ) -> tuple[Array, AttentionCache]:
         """Apply self-attention.
 
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
-            mask: Optional mask tensor
             cache: The cached key and value tensors (fixed-length)
 
         Returns:
@@ -263,6 +255,10 @@ class SelfAttentionBlock(eqx.Module):
             v_cache = cache["v"]
             k = jnp.concatenate([k_cache, k], axis=0)
             v = jnp.concatenate([v_cache, v], axis=0)
+
+            # Pads query with `k_cache.shape[0]` zeros.
+            q = jnp.pad(q, ((k_cache.shape[0], 0), (0, 0), (0, 0)), mode="constant", constant_values=0)
+
             new_position = cache["position"] + seq_len
 
         else:
@@ -272,16 +268,21 @@ class SelfAttentionBlock(eqx.Module):
             q,
             k,
             v,
-
-            is_causal=self.causal
+            scale=1.0 / math.sqrt(self.head_dim),
+            is_causal=self.causal,
+            local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
         )
 
+        if cache is not None:
+            # Remove the padding.
+            attn_output = attn_output[cache["k"].shape[0] :]
+
         attn_output = self._combine_heads(attn_output)
         output = jax.vmap(self.output_proj)(attn_output)
 
-        if self.
-        k = k[-
-        v = v[-
+        if self.local_window_size is not None:
+            k = k[-self.local_window_size :]
+            v = v[-self.local_window_size :]
 
         return output, {"k": k, "v": v, "position": new_position}
 
```
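The hunks above replace the hand-built banded causal mask from `init_mask` with `jax.nn.dot_product_attention`'s own `is_causal` and `local_window_size` arguments. The following standalone sketch (not part of xax) checks that the two formulations agree, assuming a recent JAX release whose `jax.nn.dot_product_attention` accepts rank-3 `(seq, heads, head_dim)` inputs and a `(left, right)` window tuple:

```python
import jax
import jax.numpy as jnp

seq_len, num_heads, head_dim, window = 8, 2, 4, 3
rng = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(r, (seq_len, num_heads, head_dim)) for r in jax.random.split(rng, 3))

# Banded causal mask: query i may attend to keys j with i - window <= j <= i,
# which is what the removed init_mask() built out of jnp.tril calls.
idx = jnp.arange(seq_len)
band_mask = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] <= window)

out_masked = jax.nn.dot_product_attention(q, k, v, mask=band_mask)
out_windowed = jax.nn.dot_product_attention(q, k, v, is_causal=True, local_window_size=(window, 0))
assert jnp.allclose(out_masked, out_windowed, atol=1e-5)
```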
```diff
@@ -294,8 +295,8 @@ class CrossAttentionBlock(eqx.Module):
     v_proj: eqx.nn.Linear
     output_proj: eqx.nn.Linear
     rotary_emb: RotaryEmbedding | None
-    num_heads: int = eqx.
-    head_dim: int = eqx.
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
 
     def __init__(
         self,
@@ -352,7 +353,6 @@ class CrossAttentionBlock(eqx.Module):
         *,
         kv_sn: Array | None = None,
         cache: AttentionCache | None = None,
-        mask: Array | None = None,
     ) -> tuple[Array, AttentionCache]:
         """Apply cross-attention.
 
@@ -362,7 +362,6 @@ class CrossAttentionBlock(eqx.Module):
                 If not provided, then `cache` must be provided.
             cache: The cached key and value tensors. If not provided, then
                 `kv_sn` must be provided.
-            mask: Optional mask tensor
 
         Returns:
             The output tensor of shape (q_seq_len, embed_dim)
@@ -404,7 +403,6 @@ class CrossAttentionBlock(eqx.Module):
             q_rot,
             k_rot,
             v,
-            mask=mask,
             is_causal=False,
         )
 
@@ -424,10 +422,10 @@ class TransformerBlock(eqx.Module):
     layer_norm1: eqx.nn.LayerNorm
     layer_norm2: eqx.nn.LayerNorm
     layer_norm3: eqx.nn.LayerNorm | None
-    num_heads: int = eqx.
-    head_dim: int = eqx.
-    causal: bool = eqx.
-    context_length: int | None = eqx.
+    num_heads: int = eqx.field()
+    head_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    context_length: int | None = eqx.field()
 
     def __init__(
         self,
@@ -500,16 +498,11 @@ class TransformerBlock(eqx.Module):
             cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
         return cache
 
-    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
-        return self.self_attn.init_mask(seq_len, with_cache=with_cache)
-
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
-        self_mask: Array | None = None,
-        cross_mask: Array | None = None,
         cache: AttentionCacheDict | None = None,
     ) -> tuple[Array, AttentionCacheDict]:
         """Apply transformer block.
@@ -517,8 +510,6 @@ class TransformerBlock(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
-            self_mask: Mask for self-attention
-            cross_mask: Mask for cross-attention
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -531,7 +522,6 @@ class TransformerBlock(eqx.Module):
 
         attn_output, self_attn_cache = self.self_attn.forward(
             x_tn=norm_x,
-            mask=self_mask,
             cache=None if cache is None else cache["self_attn"],
         )
         updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
@@ -547,7 +537,6 @@ class TransformerBlock(eqx.Module):
             cross_attn_output, updated_cache["cross_attn"] = self.cross_attn.forward(
                 q_tn=norm_x,
                 kv_sn=context_sn,
-                mask=cross_mask,
                 cache=None if cache is None else cache.get("cross_attn"),
             )
 
@@ -564,9 +553,9 @@ class TransformerBlock(eqx.Module):
 class TransformerStack(eqx.Module):
     """A stack of transformer blocks."""
 
-    layers:
-    num_layers: int = eqx.
-    causal: bool = eqx.
+    layers: tuple[TransformerBlock, ...]
+    num_layers: int = eqx.field()
+    causal: bool = eqx.field()
 
     def __init__(
         self,
@@ -584,7 +573,7 @@ class TransformerStack(eqx.Module):
     ) -> None:
         keys = jax.random.split(key, num_layers)
 
-        self.layers =
+        self.layers = tuple(
             TransformerBlock(
                 embed_dim=embed_dim,
                 num_heads=num_heads,
@@ -597,7 +586,7 @@ class TransformerStack(eqx.Module):
                 rotary_base=rotary_base,
             )
             for i in range(num_layers)
-
+        )
 
         self.num_layers = num_layers
         self.causal = causal
@@ -609,16 +598,11 @@ class TransformerStack(eqx.Module):
             cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
         return {"layers": cache}
 
-    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
-        return self.layers[0].init_mask(seq_len, with_cache=with_cache)
-
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
-        self_mask: Array | None = None,
-        cross_mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Apply transformer stack.
@@ -626,8 +610,6 @@ class TransformerStack(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
-            self_mask: Mask for self-attention
-            cross_mask: Mask for cross-attention
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -647,8 +629,6 @@ class TransformerStack(eqx.Module):
             x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
                 x_tn,
                 context_sn=context_sn,
-                self_mask=self_mask,
-                cross_mask=cross_mask,
                 cache=layer_cache,
             )
 
@@ -660,9 +640,9 @@ class Transformer(eqx.Module):
     layers: TransformerStack
     output_layer: eqx.nn.Linear | None
     layer_norm: eqx.nn.LayerNorm
-    embed_dim: int = eqx.
-    causal: bool = eqx.
-    context_length: int | None = eqx.
+    embed_dim: int = eqx.field()
+    causal: bool = eqx.field()
+    context_length: int | None = eqx.field()
 
     def __init__(
         self,
@@ -713,21 +693,16 @@ class Transformer(eqx.Module):
        """Initialize cache for the input."""
         return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
 
-    def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
-        return self.layers.init_mask(seq_len, with_cache=with_cache)
-
     def encode(
         self,
         x: Array,
         *,
-        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Encode the input sequence.
 
         Args:
             x: Input token indices of shape (seq_len)
-            mask: Optional attention mask
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -737,11 +712,7 @@ class Transformer(eqx.Module):
         x_embedded = jax.vmap(self.token_embedding)(x)
 
         # Apply transformer stack
-        x_embedded, updated_cache = self.layers.forward(
-            x_embedded,
-            self_mask=mask,
-            cache=cache,
-        )
+        x_embedded, updated_cache = self.layers.forward(x_embedded, cache=cache)
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
@@ -753,8 +724,6 @@ class Transformer(eqx.Module):
         x_t: Array,
         context_s: Array,
         *,
-        self_mask: Array | None = None,
-        cross_mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Decode with self-attention and cross-attention.
@@ -763,8 +732,6 @@ class Transformer(eqx.Module):
             x_t: Input token indices, shape (seq_len)
             context_s: Context from encoder (token indices or embedded),
                 shape (context_len, embed_dim)
-            self_mask: Optional self-attention mask, shape (seq_len, seq_len)
-            cross_mask: Optional cross-attention mask, shape (seq_len, context_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -780,8 +747,6 @@ class Transformer(eqx.Module):
         x_embedded, updated_cache = self.layers.forward(
             x_embedded,
             context_sn=context_embedded,
-            self_mask=self_mask,
-            cross_mask=cross_mask,
             cache=cache,
         )
 
@@ -794,14 +759,12 @@ class Transformer(eqx.Module):
         self,
         x: Array,
         *,
-        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Forward pass for encoder-only or decoder-only transformers.
 
         Args:
             x: Input token indices of shape (seq_len)
-            mask: Optional attention mask
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -809,11 +772,7 @@ class Transformer(eqx.Module):
         """
         chex.assert_rank(x, 1)
 
-        output, updated_cache = self.encode(
-            x,
-            mask=mask,
-            cache=cache,
-        )
+        output, updated_cache = self.encode(x, cache=cache)
 
         # Apply output layer if it exists
         if self.output_layer is not None:
@@ -832,6 +791,7 @@ class Transformer(eqx.Module):
         temperature: float = 1.0,
         top_k: int | None = None,
         key: PRNGKeyArray | None = None,
+        jit_level: int | None = None,
     ) -> Array:
         """Generate a sequence autoregressively with KV caching.
 
@@ -841,6 +801,7 @@ class Transformer(eqx.Module):
             temperature: Sampling temperature
             top_k: Optional top-k sampling parameter
             key: PRNG key for sampling
+            jit_level: JIT level for the scan function
 
         Returns:
             Generated sequence of shape (prompt_len + max_len,)
@@ -884,5 +845,5 @@ class Transformer(eqx.Module):
             return (new_output_seq, pos + 1, new_cache, rng), next_token
 
         init_carry = (output_seq, prompt_len - 1, cache, key)
-        (final_seq, _, _, _), _ =
+        (final_seq, _, _, _), _ = xax_scan(scan_fn, init_carry, length=max_len, jit_level=jit_level)
         return final_seq
```
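The last hunk routes generation through `xax_scan` with the new `jit_level` argument; the wrapper itself is not shown in this diff, but the loop shape matches a standard `jax.lax.scan` over a carry of `(output_seq, pos, cache, rng)`. A toy sketch of that loop (my own illustration, with a fake decode step standing in for the cached transformer):

```python
import jax
import jax.numpy as jnp

def fake_decode_step(token, cache):
    """Stand-in for a cached transformer step: strongly 'predicts' token + 1 (mod 10)."""
    logits = 100.0 * jax.nn.one_hot((token + 1) % 10, 10)
    return logits, cache

def scan_fn(carry, _):
    output_seq, pos, cache, rng = carry
    rng, sample_rng = jax.random.split(rng)
    logits, cache = fake_decode_step(output_seq[pos], cache)
    next_token = jax.random.categorical(sample_rng, logits)
    output_seq = output_seq.at[pos + 1].set(next_token)
    return (output_seq, pos + 1, cache, rng), next_token

max_len, prompt_len = 5, 1
output_seq = jnp.zeros(prompt_len + max_len, dtype=jnp.int32).at[0].set(3)
init_carry = (output_seq, prompt_len - 1, {}, jax.random.PRNGKey(0))
(final_seq, _, _, _), _ = jax.lax.scan(scan_fn, init_carry, length=max_len)
print(final_seq)  # [3 4 5 6 7 8] with overwhelming probability, given the peaked logits
```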
xax/nn/embeddings.py
CHANGED
```diff
@@ -33,10 +33,10 @@ class LearnedPositionalEmbeddings(eqx.Module):
         learnable: Whether the embeddings are learnable.
     """
 
-    max_tsz: int = eqx.field(
-    embed_dim: int = eqx.field(
-    learnable: bool = eqx.field(
-    embeddings_tc: Array
+    max_tsz: int = eqx.field()
+    embed_dim: int = eqx.field()
+    learnable: bool = eqx.field()
+    embeddings_tc: Array = eqx.field()
 
     def __init__(
         self,
@@ -74,10 +74,10 @@ class SinusoidalEmbeddings(eqx.Module):
         base: The base for the sinusoidal embeddings.
     """
 
-    base: int = eqx.field(
-    max_tsz: int | None = eqx.field(
-    embed_dim: int | None = eqx.field(
-    embeddings_tc: Array | None
+    base: int = eqx.field()
+    max_tsz: int | None = eqx.field()
+    embed_dim: int | None = eqx.field()
+    embeddings_tc: Array | None = eqx.field()
 
     def __init__(
         self,
@@ -91,8 +91,8 @@ class SinusoidalEmbeddings(eqx.Module):
         self.max_tsz = max_tsz
         self.embed_dim = embed_dim
         self.base = base
+        self.embeddings_tc = None
 
-        self.embeddings_tc: Array | None = None
         if learnable:
             assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
             assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
@@ -192,7 +192,7 @@ class RotaryEmbeddings(eqx.Module):
         base: The base for the sinusoidal embeddings.
     """
 
-    base: int = eqx.field(
+    base: int = eqx.field()
 
     def __init__(self, base: int = 10_000) -> None:
         """Defines a rotary embeddings module.
```
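The embeddings changes above follow the same pattern as the attention and SSM modules: every declared attribute gets an explicit `eqx.field()` call. As a minimal Equinox sketch (not taken from xax), field declarations are what decide whether an attribute is a traced array leaf or static treedef metadata:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Scale(eqx.Module):
    weight: jax.Array = eqx.field()      # array leaf: traced and differentiable
    dim: int = eqx.field(static=True)    # static metadata: stored in the treedef, not a leaf

    def __call__(self, x: jax.Array) -> jax.Array:
        return x * self.weight

model = Scale(weight=jnp.ones(4), dim=4)
print(jax.tree_util.tree_leaves(model))  # only `weight` appears as a leaf
```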
xax/nn/geom.py
CHANGED
```diff
@@ -207,7 +207,7 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
 
 def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
     norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
-    return v / jnp.clip(norm,
+    return v / jnp.clip(norm, min=eps)
 
 
 def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
```
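A small standalone check (not from xax) of the `normalize` behavior above: clipping the norm from below keeps the division finite for zero vectors, and the `min=` keyword matches the Array API-style `jnp.clip` signature used by recent JAX releases:

```python
import jax.numpy as jnp

def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
    norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
    return v / jnp.clip(norm, min=eps)

print(normalize(jnp.array([3.0, 4.0])))  # [0.6 0.8]
print(normalize(jnp.zeros(3)))           # [0. 0. 0.], no NaNs from dividing by zero
```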
```diff
@@ -299,28 +299,28 @@ def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
     trace = m00 + m11 + m22
 
     # Case 0: trace is positive
-    s0 = jnp.sqrt(jnp.clip(trace + 1.0,
+    s0 = jnp.sqrt(jnp.clip(trace + 1.0, min=0.0)) * 2.0  # S = 4 * qw
     w0 = 0.25 * s0
     x0 = (m21 - m12) / jnp.where(s0 < eps, 1.0, s0)
     y0 = (m02 - m20) / jnp.where(s0 < eps, 1.0, s0)
     z0 = (m10 - m01) / jnp.where(s0 < eps, 1.0, s0)
 
     # Case 1: m00 is the largest diagonal term
-    s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22,
+    s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22, min=0.0)) * 2.0  # S = 4 * qx
     w1 = (m21 - m12) / jnp.where(s1 < eps, 1.0, s1)
     x1 = 0.25 * s1
     y1 = (m01 + m10) / jnp.where(s1 < eps, 1.0, s1)
     z1 = (m02 + m20) / jnp.where(s1 < eps, 1.0, s1)
 
     # Case 2: m11 is the largest diagonal term
-    s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22,
+    s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22, min=0.0)) * 2.0  # S = 4 * qy
     w2 = (m02 - m20) / jnp.where(s2 < eps, 1.0, s2)
     x2 = (m01 + m10) / jnp.where(s2 < eps, 1.0, s2)
     y2 = 0.25 * s2
     z2 = (m12 + m21) / jnp.where(s2 < eps, 1.0, s2)
 
     # Case 3: m22 is the largest diagonal term
-    s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11,
+    s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11, min=0.0)) * 2.0  # S = 4 * qz
     w3 = (m10 - m01) / jnp.where(s3 < eps, 1.0, s3)
     x3 = (m02 + m20) / jnp.where(s3 < eps, 1.0, s3)
     y3 = (m12 + m21) / jnp.where(s3 < eps, 1.0, s3)
```
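As a quick sanity check of the Case 0 formulas above (again my own sketch rather than xax code): for the identity rotation the trace is 3, so S = sqrt(trace + 1) * 2 = 4 and the recovered quaternion is (w, x, y, z) = (1, 0, 0, 0).

```python
import jax.numpy as jnp

m = jnp.eye(3)
trace = m[0, 0] + m[1, 1] + m[2, 2]

# Case 0 branch: S = 4 * qw (the eps guard is skipped here since s0 = 4)
s0 = jnp.sqrt(jnp.clip(trace + 1.0, min=0.0)) * 2.0
w0 = 0.25 * s0
x0 = (m[2, 1] - m[1, 2]) / s0
y0 = (m[0, 2] - m[2, 0]) / s0
z0 = (m[1, 0] - m[0, 1]) / s0
print(w0, x0, y0, z0)  # 1.0 0.0 0.0 0.0
```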
xax/nn/ssm.py
CHANGED
```diff
@@ -222,12 +222,12 @@ class DiscreteDiagSSMBlock(DiagSSMBlock):
 
 
 class SSM(eqx.Module):
-    vocab_embedding: eqx.nn.Embedding
-    output_layer: eqx.nn.Linear
-    blocks: list[BaseSSMBlock]
-    num_layers: int = eqx.
-    hidden_size: int = eqx.
-    skip_connections: bool = eqx.
+    vocab_embedding: eqx.nn.Embedding = eqx.field()
+    output_layer: eqx.nn.Linear = eqx.field()
+    blocks: list[BaseSSMBlock] = eqx.field()
+    num_layers: int = eqx.field()
+    hidden_size: int = eqx.field()
+    skip_connections: bool = eqx.field()
 
     def __init__(
         self,
```
xax/task/mixins/train.py
CHANGED
```diff
@@ -40,7 +40,12 @@ from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import
+from xax.task.mixins.checkpointing import (
+    CheckpointingConfig,
+    CheckpointingMixin,
+    CheckpointPart,
+    load_ckpt,
+)
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
```
{xax-0.3.4.dist-info → xax-0.3.5.dist-info}/RECORD
CHANGED
```diff
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=OeW6UObyosw6eJSEQ96AfRJIKHg5WyZ6xuZLJdcR6cg,16240
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -8,14 +8,14 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
 xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/nn/attention.py,sha256=
-xax/nn/embeddings.py,sha256=
+xax/nn/attention.py,sha256=ESO6THJ5ORKxSM8LziRLEkj1d_QXtDndPi80Puyo-xA,28033
+xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
-xax/nn/geom.py,sha256=
+xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
 xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
 xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
-xax/nn/ssm.py,sha256=
+xax/nn/ssm.py,sha256=qSBv_FobnaFA5jt87OF5P2q5ih6sj4SlehhEhEFaPjA,10766
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=i6FRJ75aqlekWkzJNRWDUEX7P514pUjLVuxjhX1GBgw,8198
 xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
@@ -42,7 +42,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=bjBoigTCjbq9H4hcqIO32irHBc9rC2zkgXrnGNI2RtI,33266
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
+xax-0.3.5.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.3.5.dist-info/METADATA,sha256=kMRKGih6o7SfqGrvGQW_7OkFST6PDnbPuopnfx_bAOs,1246
+xax-0.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.3.5.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.3.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.3.5.dist-info/RECORD,,
```
{xax-0.3.4.dist-info → xax-0.3.5.dist-info}/WHEEL
File without changes
{xax-0.3.4.dist-info → xax-0.3.5.dist-info}/entry_points.txt
File without changes
{xax-0.3.4.dist-info → xax-0.3.5.dist-info}/licenses/LICENSE
File without changes
{xax-0.3.4.dist-info → xax-0.3.5.dist-info}/top_level.txt
File without changes