xax 0.3.4-py3-none-any.whl → 0.3.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """

- __version__ = "0.3.4"
+ __version__ = "0.3.5"

 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
xax/nn/attention.py CHANGED
@@ -5,6 +5,8 @@ supporting a fixed-size context window and caching that can be used to train
 transformers which can be unrolled with a fixed-length cache.
 """

+ import math
+ import warnings
 from typing import NotRequired, TypedDict

 import chex
@@ -13,6 +15,8 @@ import jax
 import jax.numpy as jnp
 from jaxtyping import Array, PRNGKeyArray

+ from xax.utils.jax import scan as xax_scan
+

 class RotaryEmbedding(eqx.Module):
 """Rotary Position Embedding (RoPE) for transformer attention.
@@ -22,8 +26,8 @@ class RotaryEmbedding(eqx.Module):
 https://arxiv.org/abs/2104.09864
 """

- head_dim: int = eqx.static_field()
- base: float = eqx.static_field()
+ head_dim: int = eqx.field()
+ base: float = eqx.field()

 def __init__(
 self,
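Several hunks in this file (and in embeddings.py and ssm.py below) swap eqx.static_field() or eqx.field(static=True) for a plain eqx.field(). For reference, and not taken from the package itself: in Equinox a static field is stored in the pytree structure rather than as a leaf, while a plain field shows up as a pytree leaf that jax.grad, jax.vmap, and friends can see. A minimal sketch with a made-up module:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class Toy(eqx.Module):
        # Hypothetical module, only to show the two field kinds.
        weight: jax.Array = eqx.field()    # ordinary field: appears as a pytree leaf
        dim: int = eqx.field(static=True)  # static field: stored in the treedef, not a leaf

    toy = Toy(weight=jnp.ones(3), dim=3)
    print(jax.tree_util.tree_leaves(toy))  # [Array([1., 1., 1.], dtype=float32)] -- dim is not a leaf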
@@ -125,15 +129,15 @@ class TransformerCache(TypedDict):
 class SelfAttentionBlock(eqx.Module):
 """Self-attention block using jax.nn.dot_product_attention."""

- q_proj: eqx.nn.Linear
- k_proj: eqx.nn.Linear
- v_proj: eqx.nn.Linear
- output_proj: eqx.nn.Linear
- rotary_emb: RotaryEmbedding | None
- num_heads: int = eqx.static_field()
- head_dim: int = eqx.static_field()
- causal: bool = eqx.static_field()
- context_length: int | None = eqx.static_field()
+ q_proj: eqx.nn.Linear = eqx.field()
+ k_proj: eqx.nn.Linear = eqx.field()
+ v_proj: eqx.nn.Linear = eqx.field()
+ output_proj: eqx.nn.Linear = eqx.field()
+ rotary_emb: RotaryEmbedding | None = eqx.field()
+ num_heads: int = eqx.field()
+ head_dim: int = eqx.field()
+ causal: bool = eqx.field()
+ local_window_size: int | None = eqx.field()

 def __init__(
 self,
@@ -169,8 +173,12 @@ class SelfAttentionBlock(eqx.Module):
 else:
 self.rotary_emb = None

+ if context_length is not None and not causal:
+ warnings.warn("context_length is set but causal is False; overriding causal to True", stacklevel=2)
+ causal = True
+
 self.causal = causal
- self.context_length = context_length
+ self.local_window_size = None if context_length is None else context_length - 1

 @property
 def embed_dim(self) -> int:
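Rather than building an explicit banded mask, the block now derives local_window_size = context_length - 1 and hands it straight to jax.nn.dot_product_attention (see the forward hunks below). A minimal standalone sketch of that call pattern, with made-up shapes and outside the package's own classes:

    import math
    import jax
    import jax.numpy as jnp

    seq_len, num_heads, head_dim = 8, 2, 16
    context_length = 4                      # each token sees itself plus 3 past tokens
    local_window_size = context_length - 1

    q = k = v = jax.random.normal(jax.random.PRNGKey(0), (seq_len, num_heads, head_dim))

    # Causal attention restricted to a backward-looking window:
    # (left, right) = (local_window_size, 0) means "attend this many steps back, none forward".
    out = jax.nn.dot_product_attention(
        q,
        k,
        v,
        scale=1.0 / math.sqrt(head_dim),
        is_causal=True,
        local_window_size=(local_window_size, 0),
    )
    print(out.shape)  # (8, 2, 16)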
@@ -195,41 +203,25 @@ class SelfAttentionBlock(eqx.Module):
 Returns:
 Cache with fixed-length k and v tensors
 """
- if self.context_length is None:
+ if self.local_window_size is None:
 raise ValueError("context_length must be set for caching")

 # Create fixed-length cache
- k_cache = jnp.zeros((self.context_length - 1, self.num_heads, self.head_dim), dtype=dtype)
- v_cache = jnp.zeros((self.context_length - 1, self.num_heads, self.head_dim), dtype=dtype)
+ k_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)
+ v_cache = jnp.zeros((self.local_window_size, self.num_heads, self.head_dim), dtype=dtype)

 return {"k": k_cache, "v": v_cache, "position": 0}

- def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
- in_dim, out_dim = seq_len, seq_len
- if with_cache:
- if self.context_length is None:
- raise ValueError("context_length must be set for caching")
- in_dim = in_dim + self.context_length - 1
-
- mask = jnp.tril(jnp.ones((in_dim, out_dim)))
- if self.context_length is not None:
- neg_mask = 1 - jnp.tril(jnp.ones((in_dim, out_dim)), -self.context_length)
- mask = mask * neg_mask
-
- return mask.astype(jnp.bool_).transpose()
-
 def forward(
 self,
 x_tn: Array,
 *,
- mask: Array | None = None,
 cache: AttentionCache | None = None,
 ) -> tuple[Array, AttentionCache]:
 """Apply self-attention.

 Args:
 x_tn: Input tensor of shape (seq_len, embed_dim)
- mask: Optional mask tensor
 cache: The cached key and value tensors (fixed-length)

 Returns:
@@ -263,6 +255,10 @@ class SelfAttentionBlock(eqx.Module):
 v_cache = cache["v"]
 k = jnp.concatenate([k_cache, k], axis=0)
 v = jnp.concatenate([v_cache, v], axis=0)
+
+ # Pads query with `k_cache.shape[0]` zeros.
+ q = jnp.pad(q, ((k_cache.shape[0], 0), (0, 0), (0, 0)), mode="constant", constant_values=0)
+
 new_position = cache["position"] + seq_len

 else:
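The zero-padding added above keeps the query aligned with the concatenated key/value tensors, so the built-in causal and local-window masking in jax.nn.dot_product_attention lines positions up; the padded rows are sliced off again right after the attention call (next hunk). A hedged sketch of driving the fixed-length cache one token at a time, assuming a constructed block: SelfAttentionBlock with context_length set and an embeddings_tn array of shape (seq_len, embed_dim); both names are placeholders, not from the package:

    import jax.numpy as jnp

    # Hypothetical driver loop; `block` and `embeddings_tn` are assumed to exist.
    cache = block.init_cache(dtype=jnp.float32)  # k/v of shape (context_length - 1, num_heads, head_dim)
    outputs = []
    for t in range(embeddings_tn.shape[0]):
        # Feed one (1, embed_dim) slice per step; forward returns the output and the rolled cache.
        out_1n, cache = block.forward(embeddings_tn[t : t + 1], cache=cache)
        outputs.append(out_1n)
    out_tn = jnp.concatenate(outputs, axis=0)    # (seq_len, embed_dim)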
@@ -272,16 +268,21 @@ class SelfAttentionBlock(eqx.Module):
 q,
 k,
 v,
- mask=mask,
- is_causal=self.causal and mask is None,
+ scale=1.0 / math.sqrt(self.head_dim),
+ is_causal=self.causal,
+ local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
 )

+ if cache is not None:
+ # Remove the padding.
+ attn_output = attn_output[cache["k"].shape[0] :]
+
 attn_output = self._combine_heads(attn_output)
 output = jax.vmap(self.output_proj)(attn_output)

- if self.context_length is not None:
- k = k[-(self.context_length - 1) :]
- v = v[-(self.context_length - 1) :]
+ if self.local_window_size is not None:
+ k = k[-self.local_window_size :]
+ v = v[-self.local_window_size :]

 return output, {"k": k, "v": v, "position": new_position}

@@ -294,8 +295,8 @@ class CrossAttentionBlock(eqx.Module):
 v_proj: eqx.nn.Linear
 output_proj: eqx.nn.Linear
 rotary_emb: RotaryEmbedding | None
- num_heads: int = eqx.static_field()
- head_dim: int = eqx.static_field()
+ num_heads: int = eqx.field()
+ head_dim: int = eqx.field()

 def __init__(
 self,
@@ -352,7 +353,6 @@ class CrossAttentionBlock(eqx.Module):
 *,
 kv_sn: Array | None = None,
 cache: AttentionCache | None = None,
- mask: Array | None = None,
 ) -> tuple[Array, AttentionCache]:
 """Apply cross-attention.

@@ -362,7 +362,6 @@ class CrossAttentionBlock(eqx.Module):
 If not provided, then `cache` must be provided.
 cache: The cached key and value tensors. If not provided, then
 `kv_sn` must be provided.
- mask: Optional mask tensor

 Returns:
 The output tensor of shape (q_seq_len, embed_dim)
@@ -404,7 +403,6 @@ class CrossAttentionBlock(eqx.Module):
 q_rot,
 k_rot,
 v,
- mask=mask,
 is_causal=False,
 )

@@ -424,10 +422,10 @@ class TransformerBlock(eqx.Module):
 layer_norm1: eqx.nn.LayerNorm
 layer_norm2: eqx.nn.LayerNorm
 layer_norm3: eqx.nn.LayerNorm | None
- num_heads: int = eqx.static_field()
- head_dim: int = eqx.static_field()
- causal: bool = eqx.static_field()
- context_length: int | None = eqx.static_field()
+ num_heads: int = eqx.field()
+ head_dim: int = eqx.field()
+ causal: bool = eqx.field()
+ context_length: int | None = eqx.field()

 def __init__(
 self,
@@ -500,16 +498,11 @@ class TransformerBlock(eqx.Module):
 cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
 return cache

- def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
- return self.self_attn.init_mask(seq_len, with_cache=with_cache)
-
 def forward(
 self,
 x_tn: Array,
 *,
 context_sn: Array | None = None,
- self_mask: Array | None = None,
- cross_mask: Array | None = None,
 cache: AttentionCacheDict | None = None,
 ) -> tuple[Array, AttentionCacheDict]:
 """Apply transformer block.
@@ -517,8 +510,6 @@ class TransformerBlock(eqx.Module):
 Args:
 x_tn: Input tensor of shape (seq_len, embed_dim)
 context_sn: Optional context for cross-attention
- self_mask: Mask for self-attention
- cross_mask: Mask for cross-attention
 cache: Optional dictionary containing cached key and value tensors

 Returns:
@@ -531,7 +522,6 @@ class TransformerBlock(eqx.Module):

 attn_output, self_attn_cache = self.self_attn.forward(
 x_tn=norm_x,
- mask=self_mask,
 cache=None if cache is None else cache["self_attn"],
 )
 updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
@@ -547,7 +537,6 @@ class TransformerBlock(eqx.Module):
 cross_attn_output, updated_cache["cross_attn"] = self.cross_attn.forward(
 q_tn=norm_x,
 kv_sn=context_sn,
- mask=cross_mask,
 cache=None if cache is None else cache.get("cross_attn"),
 )

@@ -564,9 +553,9 @@ class TransformerBlock(eqx.Module):
 class TransformerStack(eqx.Module):
 """A stack of transformer blocks."""

- layers: list[TransformerBlock]
- num_layers: int = eqx.static_field()
- causal: bool = eqx.static_field()
+ layers: tuple[TransformerBlock, ...]
+ num_layers: int = eqx.field()
+ causal: bool = eqx.field()

 def __init__(
 self,
@@ -584,7 +573,7 @@ class TransformerStack(eqx.Module):
 ) -> None:
 keys = jax.random.split(key, num_layers)

- self.layers = [
+ self.layers = tuple(
 TransformerBlock(
 embed_dim=embed_dim,
 num_heads=num_heads,
@@ -597,7 +586,7 @@ class TransformerStack(eqx.Module):
 rotary_base=rotary_base,
 )
 for i in range(num_layers)
- ]
+ )

 self.num_layers = num_layers
 self.causal = causal
@@ -609,16 +598,11 @@ class TransformerStack(eqx.Module):
 cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
 return {"layers": cache}

- def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
- return self.layers[0].init_mask(seq_len, with_cache=with_cache)
-
 def forward(
 self,
 x_tn: Array,
 *,
 context_sn: Array | None = None,
- self_mask: Array | None = None,
- cross_mask: Array | None = None,
 cache: TransformerCache | None = None,
 ) -> tuple[Array, TransformerCache]:
 """Apply transformer stack.
@@ -626,8 +610,6 @@ class TransformerStack(eqx.Module):
 Args:
 x_tn: Input tensor of shape (seq_len, embed_dim)
 context_sn: Optional context for cross-attention
- self_mask: Mask for self-attention
- cross_mask: Mask for cross-attention
 cache: Optional dictionary containing cached key and value tensors

 Returns:
@@ -647,8 +629,6 @@ class TransformerStack(eqx.Module):
 x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
 x_tn,
 context_sn=context_sn,
- self_mask=self_mask,
- cross_mask=cross_mask,
 cache=layer_cache,
 )

@@ -660,9 +640,9 @@ class Transformer(eqx.Module):
 layers: TransformerStack
 output_layer: eqx.nn.Linear | None
 layer_norm: eqx.nn.LayerNorm
- embed_dim: int = eqx.static_field()
- causal: bool = eqx.static_field()
- context_length: int | None = eqx.static_field()
+ embed_dim: int = eqx.field()
+ causal: bool = eqx.field()
+ context_length: int | None = eqx.field()

 def __init__(
 self,
@@ -713,21 +693,16 @@ class Transformer(eqx.Module):
 """Initialize cache for the input."""
 return self.layers.init_cache(dtype=dtype, x_tn=x_tn)

- def init_mask(self, seq_len: int, with_cache: bool = True) -> Array:
- return self.layers.init_mask(seq_len, with_cache=with_cache)
-
 def encode(
 self,
 x: Array,
 *,
- mask: Array | None = None,
 cache: TransformerCache | None = None,
 ) -> tuple[Array, TransformerCache]:
 """Encode the input sequence.

 Args:
 x: Input token indices of shape (seq_len)
- mask: Optional attention mask
 cache: Optional dictionary containing cached key and value tensors

 Returns:
@@ -737,11 +712,7 @@ class Transformer(eqx.Module):
 x_embedded = jax.vmap(self.token_embedding)(x)

 # Apply transformer stack
- x_embedded, updated_cache = self.layers.forward(
- x_embedded,
- self_mask=mask,
- cache=cache,
- )
+ x_embedded, updated_cache = self.layers.forward(x_embedded, cache=cache)

 # Apply final layer norm
 output = jax.vmap(self.layer_norm)(x_embedded)
@@ -753,8 +724,6 @@ class Transformer(eqx.Module):
 x_t: Array,
 context_s: Array,
 *,
- self_mask: Array | None = None,
- cross_mask: Array | None = None,
 cache: TransformerCache | None = None,
 ) -> tuple[Array, TransformerCache]:
 """Decode with self-attention and cross-attention.
@@ -763,8 +732,6 @@ class Transformer(eqx.Module):
 x_t: Input token indices, shape (seq_len)
 context_s: Context from encoder (token indices or embedded),
 shape (context_len, embed_dim)
- self_mask: Optional self-attention mask, shape (seq_len, seq_len)
- cross_mask: Optional cross-attention mask, shape (seq_len, context_len)
 cache: Optional dictionary containing cached key and value tensors

 Returns:
@@ -780,8 +747,6 @@ class Transformer(eqx.Module):
 x_embedded, updated_cache = self.layers.forward(
 x_embedded,
 context_sn=context_embedded,
- self_mask=self_mask,
- cross_mask=cross_mask,
 cache=cache,
 )

@@ -794,14 +759,12 @@ class Transformer(eqx.Module):
 self,
 x: Array,
 *,
- mask: Array | None = None,
 cache: TransformerCache | None = None,
 ) -> tuple[Array, TransformerCache]:
 """Forward pass for encoder-only or decoder-only transformers.

 Args:
 x: Input token indices of shape (seq_len)
- mask: Optional attention mask
 cache: Optional dictionary containing cached key and value tensors

 Returns:
@@ -809,11 +772,7 @@ class Transformer(eqx.Module):
 """
 chex.assert_rank(x, 1)

- output, updated_cache = self.encode(
- x,
- mask=mask,
- cache=cache,
- )
+ output, updated_cache = self.encode(x, cache=cache)

 # Apply output layer if it exists
 if self.output_layer is not None:
@@ -832,6 +791,7 @@ class Transformer(eqx.Module):
 temperature: float = 1.0,
 top_k: int | None = None,
 key: PRNGKeyArray | None = None,
+ jit_level: int | None = None,
 ) -> Array:
 """Generate a sequence autoregressively with KV caching.

@@ -841,6 +801,7 @@ class Transformer(eqx.Module):
 temperature: Sampling temperature
 top_k: Optional top-k sampling parameter
 key: PRNG key for sampling
+ jit_level: JIT level for the scan function

 Returns:
 Generated sequence of shape (prompt_len + max_len,)
@@ -884,5 +845,5 @@ class Transformer(eqx.Module):
 return (new_output_seq, pos + 1, new_cache, rng), next_token

 init_carry = (output_seq, prompt_len - 1, cache, key)
- (final_seq, _, _, _), _ = jax.lax.scan(scan_fn, init_carry, length=max_len)
+ (final_seq, _, _, _), _ = xax_scan(scan_fn, init_carry, length=max_len, jit_level=jit_level)
 return final_seq
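The generation loop now routes through xax.utils.jax.scan instead of calling jax.lax.scan directly, threading the new jit_level argument through generate. The wrapper's body is not part of this diff; the toy below only mirrors the call shape used above and assumes the wrapper keeps jax.lax.scan's (carry, ys) return convention, with jit_level=None as a pass-through default:

    import jax.numpy as jnp
    from xax.utils.jax import scan as xax_scan  # signature assumed to mirror jax.lax.scan plus jit_level

    def count_up(carry, _):
        # Toy scan body: bump a counter and emit it at every step.
        new_carry = carry + 1.0
        return new_carry, new_carry

    final, ys = xax_scan(count_up, jnp.float32(0.0), length=5, jit_level=None)
    print(final, ys)  # expected, assuming lax.scan semantics: 5.0 [1. 2. 3. 4. 5.]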
xax/nn/embeddings.py CHANGED
@@ -33,10 +33,10 @@ class LearnedPositionalEmbeddings(eqx.Module):
 learnable: Whether the embeddings are learnable.
 """

- max_tsz: int = eqx.field(static=True)
- embed_dim: int = eqx.field(static=True)
- learnable: bool = eqx.field(static=True)
- embeddings_tc: Array
+ max_tsz: int = eqx.field()
+ embed_dim: int = eqx.field()
+ learnable: bool = eqx.field()
+ embeddings_tc: Array = eqx.field()

 def __init__(
 self,
@@ -74,10 +74,10 @@ class SinusoidalEmbeddings(eqx.Module):
 base: The base for the sinusoidal embeddings.
 """

- base: int = eqx.field(static=True)
- max_tsz: int | None = eqx.field(static=True)
- embed_dim: int | None = eqx.field(static=True)
- embeddings_tc: Array | None
+ base: int = eqx.field()
+ max_tsz: int | None = eqx.field()
+ embed_dim: int | None = eqx.field()
+ embeddings_tc: Array | None = eqx.field()

 def __init__(
 self,
@@ -91,8 +91,8 @@ class SinusoidalEmbeddings(eqx.Module):
 self.max_tsz = max_tsz
 self.embed_dim = embed_dim
 self.base = base
+ self.embeddings_tc = None

- self.embeddings_tc: Array | None = None
 if learnable:
 assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
 assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
@@ -192,7 +192,7 @@ class RotaryEmbeddings(eqx.Module):
 base: The base for the sinusoidal embeddings.
 """

- base: int = eqx.field(static=True)
+ base: int = eqx.field()

 def __init__(self, base: int = 10_000) -> None:
 """Defines a rotary embeddings module.
xax/nn/geom.py CHANGED
@@ -207,7 +207,7 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:

 def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
 norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
- return v / jnp.clip(norm, a_min=eps)
+ return v / jnp.clip(norm, min=eps)


 def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
@@ -299,28 +299,28 @@ def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
 trace = m00 + m11 + m22

 # Case 0: trace is positive
- s0 = jnp.sqrt(jnp.clip(trace + 1.0, a_min=0.0)) * 2.0 # S = 4 * qw
+ s0 = jnp.sqrt(jnp.clip(trace + 1.0, min=0.0)) * 2.0 # S = 4 * qw
 w0 = 0.25 * s0
 x0 = (m21 - m12) / jnp.where(s0 < eps, 1.0, s0)
 y0 = (m02 - m20) / jnp.where(s0 < eps, 1.0, s0)
 z0 = (m10 - m01) / jnp.where(s0 < eps, 1.0, s0)

 # Case 1: m00 is the largest diagonal term
- s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22, a_min=0.0)) * 2.0 # S = 4 * qx
+ s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22, min=0.0)) * 2.0 # S = 4 * qx
 w1 = (m21 - m12) / jnp.where(s1 < eps, 1.0, s1)
 x1 = 0.25 * s1
 y1 = (m01 + m10) / jnp.where(s1 < eps, 1.0, s1)
 z1 = (m02 + m20) / jnp.where(s1 < eps, 1.0, s1)

 # Case 2: m11 is the largest diagonal term
- s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22, a_min=0.0)) * 2.0 # S = 4 * qy
+ s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22, min=0.0)) * 2.0 # S = 4 * qy
 w2 = (m02 - m20) / jnp.where(s2 < eps, 1.0, s2)
 x2 = (m01 + m10) / jnp.where(s2 < eps, 1.0, s2)
 y2 = 0.25 * s2
 z2 = (m12 + m21) / jnp.where(s2 < eps, 1.0, s2)

 # Case 3: m22 is the largest diagonal term
- s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11, a_min=0.0)) * 2.0 # S = 4 * qz
+ s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11, min=0.0)) * 2.0 # S = 4 * qz
 w3 = (m10 - m01) / jnp.where(s3 < eps, 1.0, s3)
 x3 = (m02 + m20) / jnp.where(s3 < eps, 1.0, s3)
 y3 = (m12 + m21) / jnp.where(s3 < eps, 1.0, s3)
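These hunks track JAX's rename of jnp.clip's keyword arguments from the NumPy-style a_min/a_max to the Array API names min/max; the clipping behaviour itself is unchanged. A small standalone illustration of the new spelling:

    import jax.numpy as jnp

    x = jnp.array([-2.0, 0.5, 3.0])
    print(jnp.clip(x, min=0.0))           # [0.  0.5 3. ]
    print(jnp.clip(x, min=0.0, max=1.0))  # [0.  0.5 1. ]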
xax/nn/ssm.py CHANGED
@@ -222,12 +222,12 @@ class DiscreteDiagSSMBlock(DiagSSMBlock):


 class SSM(eqx.Module):
- vocab_embedding: eqx.nn.Embedding
- output_layer: eqx.nn.Linear
- blocks: list[BaseSSMBlock]
- num_layers: int = eqx.static_field()
- hidden_size: int = eqx.static_field()
- skip_connections: bool = eqx.static_field()
+ vocab_embedding: eqx.nn.Embedding = eqx.field()
+ output_layer: eqx.nn.Linear = eqx.field()
+ blocks: list[BaseSSMBlock] = eqx.field()
+ num_layers: int = eqx.field()
+ hidden_size: int = eqx.field()
+ skip_connections: bool = eqx.field()

 def __init__(
 self,
xax/task/mixins/train.py CHANGED
@@ -40,7 +40,12 @@ from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
- from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart, load_ckpt
+ from xax.task.mixins.checkpointing import (
+ CheckpointingConfig,
+ CheckpointingMixin,
+ CheckpointPart,
+ load_ckpt,
+ )
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
xax-0.3.4.dist-info/METADATA → xax-0.3.5.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
- Version: 0.3.4
+ Version: 0.3.5
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
xax-0.3.4.dist-info/RECORD → xax-0.3.5.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
- xax/__init__.py,sha256=LJFB4xQplzC08tkbkZMxaCd-7jIB7aJZzBMcs9AuqiM,16240
+ xax/__init__.py,sha256=OeW6UObyosw6eJSEQ96AfRJIKHg5WyZ6xuZLJdcR6cg,16240
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -8,14 +8,14 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
 xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- xax/nn/attention.py,sha256=aIEtrM7vAQtaXTPKmsqGcYqt03CyiUQMccXj8Cjw3vc,29514
- xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
+ xax/nn/attention.py,sha256=ESO6THJ5ORKxSM8LziRLEkj1d_QXtDndPi80Puyo-xA,28033
+ xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
- xax/nn/geom.py,sha256=6rBQrZRX1miG08VG-s8phPjA6MEFxUAfQVPt5F0RQQI,10645
+ xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
 xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
 xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
- xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
+ xax/nn/ssm.py,sha256=qSBv_FobnaFA5jt87OF5P2q5ih6sj4SlehhEhEFaPjA,10766
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=i6FRJ75aqlekWkzJNRWDUEX7P514pUjLVuxjhX1GBgw,8198
 xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
@@ -42,7 +42,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
- xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
+ xax/task/mixins/train.py,sha256=bjBoigTCjbq9H4hcqIO32irHBc9rC2zkgXrnGNI2RtI,33266
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
- xax-0.3.4.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
- xax-0.3.4.dist-info/METADATA,sha256=j_UQdK4iPYbhzMH0osmHm5XJnYnFY1A_Z5MwSJwXr-4,1246
- xax-0.3.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- xax-0.3.4.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
- xax-0.3.4.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
- xax-0.3.4.dist-info/RECORD,,
+ xax-0.3.5.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+ xax-0.3.5.dist-info/METADATA,sha256=kMRKGih6o7SfqGrvGQW_7OkFST6PDnbPuopnfx_bAOs,1246
+ xax-0.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ xax-0.3.5.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+ xax-0.3.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+ xax-0.3.5.dist-info/RECORD,,
File without changes