xax 0.3.5__tar.gz → 0.3.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {xax-0.3.5/xax.egg-info → xax-0.3.6}/PKG-INFO +1 -1
  2. {xax-0.3.5 → xax-0.3.6}/xax/__init__.py +4 -1
  3. {xax-0.3.5 → xax-0.3.6}/xax/nn/attention.py +108 -17
  4. {xax-0.3.5 → xax-0.3.6}/xax/utils/pytree.py +13 -0
  5. {xax-0.3.5 → xax-0.3.6/xax.egg-info}/PKG-INFO +1 -1
  6. {xax-0.3.5 → xax-0.3.6}/LICENSE +0 -0
  7. {xax-0.3.5 → xax-0.3.6}/MANIFEST.in +0 -0
  8. {xax-0.3.5 → xax-0.3.6}/README.md +0 -0
  9. {xax-0.3.5 → xax-0.3.6}/pyproject.toml +0 -0
  10. {xax-0.3.5 → xax-0.3.6}/setup.cfg +0 -0
  11. {xax-0.3.5 → xax-0.3.6}/setup.py +0 -0
  12. {xax-0.3.5 → xax-0.3.6}/xax/cli/__init__.py +0 -0
  13. {xax-0.3.5 → xax-0.3.6}/xax/cli/edit_config.py +0 -0
  14. {xax-0.3.5 → xax-0.3.6}/xax/core/__init__.py +0 -0
  15. {xax-0.3.5 → xax-0.3.6}/xax/core/conf.py +0 -0
  16. {xax-0.3.5 → xax-0.3.6}/xax/core/state.py +0 -0
  17. {xax-0.3.5 → xax-0.3.6}/xax/nn/__init__.py +0 -0
  18. {xax-0.3.5 → xax-0.3.6}/xax/nn/embeddings.py +0 -0
  19. {xax-0.3.5 → xax-0.3.6}/xax/nn/functions.py +0 -0
  20. {xax-0.3.5 → xax-0.3.6}/xax/nn/geom.py +0 -0
  21. {xax-0.3.5 → xax-0.3.6}/xax/nn/losses.py +0 -0
  22. {xax-0.3.5 → xax-0.3.6}/xax/nn/metrics.py +0 -0
  23. {xax-0.3.5 → xax-0.3.6}/xax/nn/parallel.py +0 -0
  24. {xax-0.3.5 → xax-0.3.6}/xax/nn/ssm.py +0 -0
  25. {xax-0.3.5 → xax-0.3.6}/xax/py.typed +0 -0
  26. {xax-0.3.5 → xax-0.3.6}/xax/requirements-dev.txt +0 -0
  27. {xax-0.3.5 → xax-0.3.6}/xax/requirements.txt +0 -0
  28. {xax-0.3.5 → xax-0.3.6}/xax/task/__init__.py +0 -0
  29. {xax-0.3.5 → xax-0.3.6}/xax/task/base.py +0 -0
  30. {xax-0.3.5 → xax-0.3.6}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.3.5 → xax-0.3.6}/xax/task/launchers/base.py +0 -0
  32. {xax-0.3.5 → xax-0.3.6}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.3.5 → xax-0.3.6}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.3.5 → xax-0.3.6}/xax/task/logger.py +0 -0
  35. {xax-0.3.5 → xax-0.3.6}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.3.5 → xax-0.3.6}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.3.5 → xax-0.3.6}/xax/task/loggers/json.py +0 -0
  38. {xax-0.3.5 → xax-0.3.6}/xax/task/loggers/state.py +0 -0
  39. {xax-0.3.5 → xax-0.3.6}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.3.5 → xax-0.3.6}/xax/task/loggers/tensorboard.py +0 -0
  41. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/artifacts.py +0 -0
  43. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/checkpointing.py +0 -0
  44. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/compile.py +0 -0
  45. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/process.py +0 -0
  50. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/runnable.py +0 -0
  51. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.3.5 → xax-0.3.6}/xax/task/mixins/train.py +0 -0
  53. {xax-0.3.5 → xax-0.3.6}/xax/task/script.py +0 -0
  54. {xax-0.3.5 → xax-0.3.6}/xax/task/task.py +0 -0
  55. {xax-0.3.5 → xax-0.3.6}/xax/utils/__init__.py +0 -0
  56. {xax-0.3.5 → xax-0.3.6}/xax/utils/data/__init__.py +0 -0
  57. {xax-0.3.5 → xax-0.3.6}/xax/utils/data/collate.py +0 -0
  58. {xax-0.3.5 → xax-0.3.6}/xax/utils/debugging.py +0 -0
  59. {xax-0.3.5 → xax-0.3.6}/xax/utils/experiments.py +0 -0
  60. {xax-0.3.5 → xax-0.3.6}/xax/utils/jax.py +0 -0
  61. {xax-0.3.5 → xax-0.3.6}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.3.5 → xax-0.3.6}/xax/utils/logging.py +0 -0
  63. {xax-0.3.5 → xax-0.3.6}/xax/utils/numpy.py +0 -0
  64. {xax-0.3.5 → xax-0.3.6}/xax/utils/profile.py +0 -0
  65. {xax-0.3.5 → xax-0.3.6}/xax/utils/tensorboard.py +0 -0
  66. {xax-0.3.5 → xax-0.3.6}/xax/utils/text.py +0 -0
  67. {xax-0.3.5 → xax-0.3.6}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.3.5 → xax-0.3.6}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.3.5 → xax-0.3.6}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.3.5 → xax-0.3.6}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.3.5 → xax-0.3.6}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.3.5 → xax-0.3.6}/xax.egg-info/entry_points.txt +0 -0
  73. {xax-0.3.5 → xax-0.3.6}/xax.egg-info/requires.txt +0 -0
  74. {xax-0.3.5 → xax-0.3.6}/xax.egg-info/top_level.txt +0 -0
{xax-0.3.5/xax.egg-info → xax-0.3.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.3.5
+Version: 0.3.6
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
{xax-0.3.5 → xax-0.3.6}/xax/__init__.py

@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.5"
+__version__ = "0.3.6"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -136,6 +136,7 @@ __all__ = [
     "compute_nan_ratio",
     "flatten_array",
     "flatten_pytree",
+    "get_pytree_mapping",
     "get_pytree_param_count",
     "pytree_has_nans",
     "reshuffle_pytree",
@@ -323,6 +324,7 @@ NAME_MAP: dict[str, str] = {
     "compute_nan_ratio": "utils.pytree",
     "flatten_array": "utils.pytree",
     "flatten_pytree": "utils.pytree",
+    "get_pytree_mapping": "utils.pytree",
     "get_pytree_param_count": "utils.pytree",
     "pytree_has_nans": "utils.pytree",
     "reshuffle_pytree": "utils.pytree",
@@ -509,6 +511,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         compute_nan_ratio,
         flatten_array,
         flatten_pytree,
+        get_pytree_mapping,
         get_pytree_param_count,
         pytree_has_nans,
         reshuffle_pytree,
{xax-0.3.5 → xax-0.3.6}/xax/nn/attention.py

@@ -212,16 +212,49 @@ class SelfAttentionBlock(eqx.Module):
 
         return {"k": k_cache, "v": v_cache, "position": 0}
 
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        """Initialize the attention matrix mask.
+
+        Args:
+            seq_len: The length of the sequence
+            add_cache: Whether to add the cache to the mask
+            batch_dim: Whether to add a batch dimension to the mask
+
+        Returns:
+            The attention matrix mask of shape (bsz, 1, seq_len, seq_len + cache_len)
+            if batch_dim is True, otherwise (seq_len, seq_len + cache_len).
+        """
+        t, s, o = seq_len, seq_len, 0
+        if add_cache:
+            if self.local_window_size is None:
+                raise ValueError("local_window_size must be set for caching")
+            s += self.local_window_size
+            o -= self.local_window_size
+        mask = jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-o)
+        if self.local_window_size is not None:
+            neg_mask = ~jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-(self.local_window_size + 1 + o))
+            mask = mask & neg_mask
+        mask = mask.reshape(1, 1, t, s) if batch_dim else mask.reshape(t, s)
+        return mask
+
     def forward(
         self,
         x_tn: Array,
         *,
+        mask: Array | None = None,
         cache: AttentionCache | None = None,
     ) -> tuple[Array, AttentionCache]:
         """Apply self-attention.
 
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: The cached key and value tensors (fixed-length)
 
         Returns:
@@ -256,26 +289,28 @@ class SelfAttentionBlock(eqx.Module):
             k = jnp.concatenate([k_cache, k], axis=0)
             v = jnp.concatenate([v_cache, v], axis=0)
 
-            # Pads query with `k_cache.shape[0]` zeros.
-            q = jnp.pad(q, ((k_cache.shape[0], 0), (0, 0), (0, 0)), mode="constant", constant_values=0)
-
             new_position = cache["position"] + seq_len
 
         else:
             new_position = seq_len
 
-        attn_output = jax.nn.dot_product_attention(
-            q,
-            k,
-            v,
-            scale=1.0 / math.sqrt(self.head_dim),
-            is_causal=self.causal,
-            local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
-        )
+        if seq_len == 1:
+            attn_output = jax.nn.dot_product_attention(q, k, v)
 
-        if cache is not None:
-            # Remove the padding.
-            attn_output = attn_output[cache["k"].shape[0] :]
+        elif mask is not None:
+            attn_output = jax.nn.dot_product_attention(q, k, v, mask=mask)
+
+        elif cache is not None:
+            raise NotImplementedError("For training with a cache, provide a mask instead.")
+
+        else:
+            attn_output = jax.nn.dot_product_attention(
+                q,
+                k,
+                v,
+                is_causal=self.causal,
+                local_window_size=(self.local_window_size, 0) if self.local_window_size is not None else None,
+            )
 
         attn_output = self._combine_heads(attn_output)
         output = jax.vmap(self.output_proj)(attn_output)
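
Note: the mask returned by the new SelfAttentionBlock.init_mask is a banded causal matrix whose extra leading columns (when add_cache=True) line up with the fixed-length KV cache. The snippet below is a minimal standalone sketch that mirrors that logic so the pattern can be inspected without constructing a SelfAttentionBlock; the function name banded_causal_mask is illustrative and not part of xax.

    import jax.numpy as jnp

    def banded_causal_mask(seq_len: int, local_window_size: int | None, add_cache: bool = False) -> jnp.ndarray:
        # Rows are query positions, columns are key positions; when add_cache is
        # True the first local_window_size columns correspond to cache slots.
        t, s, o = seq_len, seq_len, 0
        if add_cache:
            assert local_window_size is not None, "caching requires a local window"
            s += local_window_size
            o -= local_window_size
        mask = jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-o)
        if local_window_size is not None:
            # Drop keys older than the local window, matching the diff above.
            mask = mask & ~jnp.tril(jnp.ones((t, s), dtype=jnp.bool_), k=-(local_window_size + 1 + o))
        return mask

    print(banded_causal_mask(4, local_window_size=2).astype(jnp.int32))
    # [[1 0 0 0]
    #  [1 1 0 0]
    #  [1 1 1 0]
    #  [0 1 1 1]]
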
@@ -403,6 +438,7 @@ class CrossAttentionBlock(eqx.Module):
             q_rot,
             k_rot,
             v,
+            scale=1.0 / math.sqrt(self.head_dim),
             is_causal=False,
         )
 
@@ -498,11 +534,24 @@ class TransformerBlock(eqx.Module):
             cache["cross_attn"] = self.cross_attn.init_cache(kv_sn=context_sn)
         return cache
 
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.self_attn.init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
+
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
+        mask: Array | None = None,
         cache: AttentionCacheDict | None = None,
     ) -> tuple[Array, AttentionCacheDict]:
         """Apply transformer block.
@@ -510,6 +559,8 @@ class TransformerBlock(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -522,6 +573,7 @@ class TransformerBlock(eqx.Module):
 
         attn_output, self_attn_cache = self.self_attn.forward(
             x_tn=norm_x,
+            mask=mask,
             cache=None if cache is None else cache["self_attn"],
         )
         updated_cache: AttentionCacheDict = {"self_attn": self_attn_cache}
@@ -598,11 +650,24 @@ class TransformerStack(eqx.Module):
             cache[f"layer_{i}"] = layer.init_cache(dtype=dtype, context_sn=x_tn)
         return {"layers": cache}
 
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.layers[0].init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
+
     def forward(
         self,
         x_tn: Array,
         *,
         context_sn: Array | None = None,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Apply transformer stack.
@@ -610,6 +675,8 @@ class TransformerStack(eqx.Module):
         Args:
             x_tn: Input tensor of shape (seq_len, embed_dim)
             context_sn: Optional context for cross-attention
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -629,6 +696,7 @@ class TransformerStack(eqx.Module):
             x_tn, updated_cache["layers"][f"layer_{i}"] = layer.forward(
                 x_tn,
                 context_sn=context_sn,
+                mask=mask,
                 cache=layer_cache,
             )
 
@@ -693,16 +761,31 @@ class Transformer(eqx.Module):
         """Initialize cache for the input."""
         return self.layers.init_cache(dtype=dtype, x_tn=x_tn)
 
+    def init_mask(
+        self,
+        seq_len: int,
+        add_cache: bool = False,
+        batch_dim: bool = False,
+    ) -> Array:
+        return self.layers.init_mask(
+            seq_len,
+            add_cache=add_cache,
+            batch_dim=batch_dim,
+        )
+
     def encode(
         self,
         x: Array,
         *,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Encode the input sequence.
 
         Args:
             x: Input token indices of shape (seq_len)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -712,7 +795,7 @@ class Transformer(eqx.Module):
         x_embedded = jax.vmap(self.token_embedding)(x)
 
         # Apply transformer stack
-        x_embedded, updated_cache = self.layers.forward(x_embedded, cache=cache)
+        x_embedded, updated_cache = self.layers.forward(x_embedded, mask=mask, cache=cache)
 
         # Apply final layer norm
         output = jax.vmap(self.layer_norm)(x_embedded)
@@ -724,6 +807,7 @@ class Transformer(eqx.Module):
         x_t: Array,
         context_s: Array,
         *,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Decode with self-attention and cross-attention.
@@ -732,6 +816,8 @@ class Transformer(eqx.Module):
             x_t: Input token indices, shape (seq_len)
             context_s: Context from encoder (token indices or embedded),
                 shape (context_len, embed_dim)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -747,6 +833,7 @@ class Transformer(eqx.Module):
         x_embedded, updated_cache = self.layers.forward(
             x_embedded,
             context_sn=context_embedded,
+            mask=mask,
            cache=cache,
         )
 
@@ -759,12 +846,15 @@ class Transformer(eqx.Module):
         self,
         x: Array,
         *,
+        mask: Array | None = None,
         cache: TransformerCache | None = None,
     ) -> tuple[Array, TransformerCache]:
         """Forward pass for encoder-only or decoder-only transformers.
 
         Args:
             x: Input token indices of shape (seq_len)
+            mask: Optional mask of shape (batch_size, num_heads, seq_len,
+                seq_len + cache_len)
             cache: Optional dictionary containing cached key and value tensors
 
         Returns:
@@ -772,7 +862,7 @@ class Transformer(eqx.Module):
         """
         chex.assert_rank(x, 1)
 
-        output, updated_cache = self.encode(x, cache=cache)
+        output, updated_cache = self.encode(x, mask=mask, cache=cache)
 
         # Apply output layer if it exists
         if self.output_layer is not None:
@@ -817,7 +907,8 @@ class Transformer(eqx.Module):
 
         # Initialize cache with prompt
         cache = self.init_cache()
-        _, cache = self.encode(prompt_seq, cache=cache)
+        mask = self.init_mask(prompt_len, add_cache=True, batch_dim=False)
+        _, cache = self.encode(prompt_seq, cache=cache, mask=mask)
 
         # Define scan function for autoregressive generation
         def scan_fn(
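
Note: a hedged sketch of the prefill step that the updated generate method performs, written against only names visible in this diff (init_cache, init_mask, encode, and the types annotated in xax/nn/attention.py); the model is assumed to be an already-constructed Transformer with local_window_size set, which caching requires.

    import jax.numpy as jnp

    from xax.nn.attention import Transformer, TransformerCache

    def prefill(model: Transformer, prompt: jnp.ndarray) -> TransformerCache:
        # Run the whole prompt through the model once, returning a primed KV cache.
        prompt_len = prompt.shape[0]
        cache = model.init_cache()
        # Mask shape (prompt_len, prompt_len + local_window_size): the extra leading
        # columns correspond to the fixed-length cache prepended to keys/values.
        mask = model.init_mask(prompt_len, add_cache=True, batch_dim=False)
        _, cache = model.encode(prompt, mask=mask, cache=cache)
        return cache

Subsequent single-token decode steps take the new seq_len == 1 fast path in SelfAttentionBlock.forward and need no mask; passing a cache with a longer sequence and no mask now raises NotImplementedError, so training-style runs over a cache must supply one.
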
{xax-0.3.5 → xax-0.3.6}/xax/utils/pytree.py

@@ -253,3 +253,16 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
     mut = list(t)
     mut[index] = value
     return tuple(mut)
+
+
+def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
+    leaves: dict[str, Array] = {}
+
+    def _get_leaf(path: tuple, x: PyTree) -> None:
+        if isinstance(x, jnp.ndarray):
+            # Convert path tuple to string, e.g. (1, 'a', 2) -> '1/a/2'
+            path_str = "/".join(str(p) for p in path)
+            leaves[path_str] = x
+
+    jax.tree.map_with_path(_get_leaf, pytree)
+    return leaves
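
Note: an illustrative use of the new get_pytree_mapping helper, which is exported at the top level per the __init__.py changes above (the parameter pytree here is made up). It flattens a nested pytree into a flat {path: array} dict keyed by the stringified tree path, which is handy for logging per-leaf statistics.

    import jax.numpy as jnp

    import xax

    params = {"encoder": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}, "scale": jnp.array(1.0)}
    mapping = xax.get_pytree_mapping(params)
    for path, leaf in sorted(mapping.items()):
        # Each key is the str() of every entry along the tree path, joined by "/".
        print(path, leaf.shape, float(jnp.abs(leaf).mean()))
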
{xax-0.3.5 → xax-0.3.6/xax.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.3.5
+Version: 0.3.6
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte