lalamo 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/language_model.py +22 -23
  3. lalamo/main.py +2 -16
  4. lalamo/model_import/common.py +24 -6
  5. lalamo/model_import/decoder_configs/__init__.py +2 -0
  6. lalamo/model_import/decoder_configs/common.py +4 -4
  7. lalamo/model_import/decoder_configs/executorch.py +17 -10
  8. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  9. lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
  10. lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
  11. lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
  12. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
  13. lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
  14. lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
  15. lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
  16. lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
  17. lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
  18. lalamo/model_import/loaders/executorch.py +5 -4
  19. lalamo/model_import/loaders/huggingface.py +321 -69
  20. lalamo/model_import/model_specs/__init__.py +2 -0
  21. lalamo/model_import/model_specs/common.py +16 -5
  22. lalamo/model_import/model_specs/llamba.py +40 -0
  23. lalamo/model_import/model_specs/qwen.py +29 -1
  24. lalamo/modules/__init__.py +33 -6
  25. lalamo/modules/activations.py +9 -2
  26. lalamo/modules/common.py +10 -5
  27. lalamo/modules/decoder.py +93 -97
  28. lalamo/modules/decoder_layer.py +85 -103
  29. lalamo/modules/embedding.py +279 -5
  30. lalamo/modules/linear.py +335 -30
  31. lalamo/modules/mlp.py +6 -7
  32. lalamo/modules/mlx_interop.py +19 -0
  33. lalamo/modules/rope.py +1 -1
  34. lalamo/modules/token_mixers/__init__.py +30 -0
  35. lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
  36. lalamo/modules/token_mixers/common.py +78 -0
  37. lalamo/modules/token_mixers/mamba.py +553 -0
  38. lalamo/modules/token_mixers/state/__init__.py +12 -0
  39. lalamo/modules/token_mixers/state/common.py +26 -0
  40. lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
  41. lalamo/modules/token_mixers/state/mamba_state.py +51 -0
  42. lalamo/utils.py +24 -2
  43. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
  44. lalamo-0.5.0.dist-info/RECORD +80 -0
  45. lalamo-0.4.1.dist-info/RECORD +0 -71
  46. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
  47. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/modules/linear.py CHANGED
@@ -12,6 +12,7 @@ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree, dummy_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
+from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
 from .common import (
     LalamoModule,
@@ -59,7 +60,7 @@ class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):
         assert isinstance(self.output_dims, tuple)
 
     @staticmethod
-    def _get_split_points(output_dims: Sequence[int]) -> tuple[int, ...]:
+    def get_split_points(output_dims: Sequence[int]) -> tuple[int, ...]:
         result = []
         last_split_point = 0
         for dim in output_dims[:-1]:
@@ -258,7 +259,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
         result = self.weights @ inputs
         if self.biases is not None:
             result = result + self.biases
-        return tuple(jnp.split(result, self._get_split_points(self.output_dims)))
+        return tuple(jnp.split(result, self.get_split_points(self.output_dims)))
 
     def export_weights(self) -> ParameterTree:
         result = dict(weights=self.weights)
@@ -279,12 +280,39 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
 
 
 @dataclass(frozen=True)
-class GroupQuantizedLinearConfig(LinearConfigBase):
+class QuantizedLinearConfigBase(LinearConfigBase):
     group_size: int
     weight_quantization_mode: QuantizationMode
     activation_quantization_mode: QuantizationMode | None
     activation_precision: DTypeLike
 
+
+class QuantizedLinearBase[ConfigT: QuantizedLinearConfigBase](LinearBase[ConfigT]):
+    biases: Float[Array, "*components total_out_channels"] | None
+
+    @abstractmethod
+    def _prepare_scaled_weights(self) -> Float[Array, "*components in_channels total_out_channels"]: ...
+
+    def _apply_weights(self, inputs: Float[Array, " in_channels"]) -> Float[Array, " total_out_channels"]:
+        if self.config.activation_quantization_mode is not None:
+            inputs = dynamically_quantize_activations(inputs, self.config.activation_quantization_mode)
+        return self._prepare_scaled_weights() @ inputs
+
+    @eqx.filter_jit
+    def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
+        result = self._apply_weights(inputs)
+        if self.biases is not None:
+            result = result + self.biases
+        return tuple(jnp.split(result, self.get_split_points(self.output_dims)))
+
+
+@dataclass(frozen=True)
+class GroupQuantizedLinearConfig(QuantizedLinearConfigBase):
     def random_init(
         self,
         input_dim: int,
@@ -381,7 +409,7 @@ class GroupQuantizedLinearConfig(LinearConfigBase):
         return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)
 
 
-class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[ConfigT]):
+class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLinearBase[ConfigT]):
     weights: Float[Array, "*components total_out_channels in_channels"]
     scales: Float[Array, "*components total_out_channels groups"]
     zero_points: Float[Array, "*components total_out_channels groups"]
@@ -414,13 +442,27 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
 
     @property
     def int_weights(self) -> Int[Array, "*components in_channels out_channels"]:
-        result = quantize_weights(self.weights, self.config.weight_quantization_mode)
-        return result.astype(self.config.weight_quantization_mode.dtype)
+        quantized = quantize_weights(self.weights, self.config.weight_quantization_mode)
+        casted = quantized.astype(self.config.weight_quantization_mode.dtype)
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
 
     @property
     def int_zero_points(self) -> Int[Array, "*components groups out_channels"]:
-        result = quantize_weights(self.zero_points, self.config.weight_quantization_mode)
-        return result.astype(self.config.weight_quantization_mode.dtype)
+        quantized = quantize_weights(self.zero_points, self.config.weight_quantization_mode)
+        casted = quantized.astype(self.config.weight_quantization_mode.dtype)
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
 
     def __post_init__(self) -> None:  # noqa: PLR0912
         if self.weights.dtype != self.config.activation_precision:
@@ -520,28 +562,286 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
         )
         return result
 
-    def _apply_weights(self, inputs: Float[Array, " in_channels"]) -> Float[Array, " total_out_channels"]:
-        if self.config.activation_quantization_mode is not None:
-            inputs = dynamically_quantize_activations(inputs, self.config.activation_quantization_mode)
-        return self._prepare_scaled_weights() @ inputs
+    def export_weights(self) -> ParameterTree:
+        result = dict(
+            weights=self.int_weights,
+            zero_points=self.int_zero_points,
+            scales=self.scales,
+        )
+        if self.biases is not None:
+            result["biases"] = self.biases
+        return result
 
-    @eqx.filter_jit
-    def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
-        if self.mixture_size is not None:
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        assert isinstance(weights["zero_points"], Array)
+        unpacked_weights = weights["weights"]
+        unpacked_zero_points = weights["zero_points"]
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
+            unpacked_zero_points = jax_uint8_to_unpacked_uint4(weights["zero_points"])
+
+        return replace(
+            self,
+            weights=unpacked_weights.astype(self.weights.dtype),
+            scales=weights["scales"],
+            zero_points=unpacked_zero_points.astype(self.zero_points.dtype),
+            biases=weights["biases"] if self.has_biases else None,
+        )
+
+
+class GroupQuantizedLinear(GroupQuantizedLinearBase[GroupQuantizedLinearConfig]):
+    pass
+
+
+@dataclass(frozen=True)
+class MLXQuantizedLinearConfig(QuantizedLinearConfigBase):
+    def random_init(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        min_val, max_val = self.weight_quantization_mode.range
+        weights = jax.random.uniform(
+            key,
+            (sum(output_dims), input_dim),
+            minval=min_val - 1,
+            maxval=max_val + 1,
+            dtype=self.activation_precision,
+        )
+        num_groups = input_dim // self.group_size
+        scale = 1 / ((max_val - min_val) / 2 * math.sqrt(input_dim))
+        scales = scale * jnp.ones((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        if has_biases:
+            biases = jnp.zeros((sum(output_dims),), dtype=self.activation_precision)
+        else:
+            biases = None
+
+        deq_bias = min_val + 2 ** (self.weight_quantization_mode.bits - 1)
+        deq_biases = deq_bias * jnp.ones((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        return MLXQuantizedLinear(
+            config=self,
+            output_dims=output_dims,
+            weights=weights,
+            scales=scales,
+            deq_biases=deq_biases,
+            biases=biases,
+        )
+
+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        subkeys = jax.random.split(key, mixture_size)
+        return eqx.filter_vmap(lambda key: self.random_init(input_dim, output_dims, has_biases, key=key))(subkeys)
+
+    def _empty_general(
+        self,
+        leading_dims: tuple[int, ...],
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        weights = dummy_array(
+            (*leading_dims, sum(output_dims), input_dim),
+            dtype=self.activation_precision,
+        )
+        num_groups = input_dim // self.group_size
+        scales = dummy_array((*leading_dims, sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        if has_biases:
+            biases = dummy_array((*leading_dims, sum(output_dims)), dtype=self.activation_precision)
+        else:
+            biases = None
+        deq_biases = dummy_array((*leading_dims, sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        return MLXQuantizedLinear(
+            config=self,
+            output_dims=output_dims,
+            weights=weights,
+            scales=scales,
+            deq_biases=deq_biases,
+            biases=biases,
+        )
+
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((), input_dim, output_dims, has_biases)
+
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)
+
+
+class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearBase[ConfigT]):
+    weights: Float[Array, "*components total_out_channels in_channels"]
+    scales: Float[Array, "*components total_out_channels groups"]
+    deq_biases: Float[Array, "*components total_out_channels groups"]
+    biases: Float[Array, "*components total_out_channels"] | None
+
+    @property
+    def mixture_size(self) -> int | None:
+        match self.weights.shape:
+            case [num_components, _, _]:
+                return num_components
+            case _:
+                return None
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def input_dim(self) -> int:
+        *_, _, input_dim = self.weights.shape
+        return input_dim
+
+    @property
+    def has_biases(self) -> bool:
+        return self.biases is not None
+
+    @property
+    def num_groups(self) -> int:
+        return self.input_dim // self.config.group_size
+
+    @property
+    def int_weights(self) -> Int[Array, "*components in_channels out_channels"]:
+        quantized = quantize_weights(self.weights, self.config.weight_quantization_mode)
+        casted = quantized.astype(self.config.weight_quantization_mode.dtype)
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    def __post_init__(self) -> None:  # noqa: PLR0912
+        if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
-                "Mixtures of linear layers cannot be called directly."
-                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+                f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
+                f" ({self.config.activation_precision}).",
+                " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        result = self._apply_weights(inputs)
+        *w_num_components, w_output_dim, _ = self.weights.shape
+        if w_output_dim != sum(self.output_dims):
+            raise ValueError(
+                f"Number of output channels in weights ({w_output_dim}) is not"
+                f" equal to sum of output dims ({sum(self.output_dims)}).",
+            )
+
+        if self.scales.dtype != self.config.activation_precision:
+            raise ValueError(
+                f"Scale dtype ({self.scales.dtype}) is not equal to specified activation precision"
+                f" ({self.config.activation_precision}).",
+                " Quantized layers require parameter dtypes to be equal to the activation precision.",
+            )
+        *s_num_components, s_output_dim, s_num_groups = self.scales.shape
+        if w_output_dim != s_output_dim:
+            raise ValueError(
+                f"Number of output channels in weights ({w_output_dim}) is not"
+                f" equal to number of output channels in scales ({s_output_dim}).",
+            )
+        if tuple(s_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in scales ({s_num_components}).",
+            )
+        if s_num_groups != self.num_groups:
+            raise ValueError(
+                f"Number of groups in scales ({s_num_groups}) is incompatible with"
+                f" the specified group size ({self.config.group_size}).",
+            )
+
+        if self.deq_biases.dtype != self.config.activation_precision:
+            raise ValueError(
+                f"Dequantization bias dtype ({self.deq_biases.dtype}) is not equal to specified activation precision"
+                f" ({self.config.activation_precision}).",
+                " Quantized layers require parameter dtypes to be equal to the activation precision.",
+            )
+        *zp_num_components, zp_output_dim, zp_num_groups = self.deq_biases.shape
+        if w_output_dim != zp_output_dim:
+            raise ValueError(
+                f"Number of output channels in weights ({w_output_dim}) is not"
+                f" equal to number of output channels in zero points ({zp_output_dim}).",
+            )
+        if tuple(zp_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in zero points ({zp_num_components}).",
+            )
+        if self.num_groups != zp_num_groups:
+            raise ValueError(
+                f"Number of groups in zero points ({zp_num_groups}) is incompatible with"
+                f" the specified group size ({self.config.group_size}).",
+            )
+
         if self.biases is not None:
-            result = result + self.biases
-        return tuple(jnp.split(result, self._get_split_points(self.output_dims)))
+            if self.biases.dtype != self.config.activation_precision:
+                raise ValueError(
+                    f"Bias dtype ({self.biases.dtype}) is not equal to specified activation precision"
+                    f" ({self.config.activation_precision}).",
+                    " Quantized layers require parameter dtypes to be equal to the activation precision.",
+                )
+            *b_num_components, b_output_dim = self.biases.shape
+            if w_output_dim != b_output_dim:
+                raise ValueError(
+                    f"Number of output channels in weights ({w_output_dim}) is not"
+                    f" equal to number of output channels in biases ({b_output_dim}).",
+                )
+            if tuple(b_num_components) != tuple(w_num_components):
+                raise ValueError(
                    f"Number of mixture components in weights ({w_num_components}) is not"
+                    f" equal to number of mixture components in biases ({b_num_components}).",
+                )
+
+    def _prepare_scaled_weights(self) -> Float[Array, "*components in_channels total_out_channels"]:
+        quantized_weights = quantize_weights(self.weights, self.config.weight_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "... total_out_channels (groups group_channels) -> ... total_out_channels groups group_channels",
+            groups=self.num_groups,
+        )
+
+        scales = rearrange(self.scales, "... total_out_channels groups -> ... total_out_channels groups 1")
+        deq_biases = rearrange(self.deq_biases, "... total_out_channels groups -> ... total_out_channels groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + deq_biases
+        result = rearrange(
+            scaled_grouped_weights,
+            "... total_out_channels groups group_channels -> ... total_out_channels (groups group_channels)",
+        )
+        return result
 
     def export_weights(self) -> ParameterTree:
         result = dict(
             weights=self.int_weights,
-            zero_points=self.int_zero_points,
             scales=self.scales,
+            deq_biases=self.deq_biases,
         )
         if self.biases is not None:
             result["biases"] = self.biases
@@ -553,17 +853,22 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["weights"], Array)
-        assert isinstance(weights["zero_points"], Array)
+
+        unpacked_weights = weights["weights"]
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
+
        return replace(
             self,
-            weights=weights["weights"].astype(self.weights.dtype),
+            weights=unpacked_weights.astype(self.weights.dtype),
             scales=weights["scales"],
-            zero_points=weights["zero_points"].astype(self.zero_points.dtype),
+            deq_biases=weights["deq_biases"],
             biases=weights["biases"] if self.has_biases else None,
         )
 
 
-class GroupQuantizedLinear(GroupQuantizedLinearBase[GroupQuantizedLinearConfig]):
+class MLXQuantizedLinear(MLXQuantizedLinearBase[MLXQuantizedLinearConfig]):
     pass
 
 
@@ -714,7 +1019,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
 
     def _split_biases(self) -> tuple[Float[Array, "*components out_channels"] | None, ...]:
         if self.biases is not None:
-            return tuple(jnp.split(self.biases, self._get_split_points(self.output_dims)))
+            return tuple(jnp.split(self.biases, self.get_split_points(self.output_dims)))
         return (None,) * len(self.output_dims)
 
     def __post_init__(self) -> None:
@@ -778,10 +1083,10 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                 "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
             )
         joint_q_out = self._apply_weights(inputs)
-        q_outs = jnp.split(joint_q_out, self._get_split_points(self.output_dims))
+        q_outs = jnp.split(joint_q_out, self.get_split_points(self.output_dims))
 
         joint_lora_hidden = inputs @ self.lora_down_weights
-        lora_hiddens = jnp.split(joint_lora_hidden, self._get_split_points([self.config.lora_rank] * self.num_outputs))
+        lora_hiddens = jnp.split(joint_lora_hidden, self.get_split_points([self.config.lora_rank] * self.num_outputs))
         lora_outs = [
             lora_hidden @ lora_up_weight
             for lora_up_weight, lora_hidden in zip(self.lora_up_weights, lora_hiddens, strict=True)
@@ -818,7 +1123,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
         )
 
 
-LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | QLoRALinearConfig
+LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | MLXQuantizedLinearConfig | QLoRALinearConfig
 
 
-register_config_union(LinearConfig)
+register_config_union(LinearConfig)  # type: ignore (pyright bug)
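Note on the new quantization paths: the UINT4 branches in int_weights, int_zero_points, and import_weights round-trip 4-bit codes through packed uint8 buffers, and MLXQuantizedLinearBase._prepare_scaled_weights dequantizes each group as weight = code * scale + deq_bias. The sketch below illustrates that scheme; the nibble order and packing axis are assumptions made for illustration only, since the actual jax_uint4_to_packed_uint8 / jax_uint8_to_unpacked_uint4 helpers in lalamo.utils are not shown in this diff.

# Hypothetical sketch of uint4 <-> uint8 packing and MLX-style group dequantization.
# The pairing order (low nibble first, along the last axis) is an assumption.
import jax.numpy as jnp


def pack_uint4_to_uint8(codes):
    # Pair consecutive 4-bit codes along the last axis into one byte each.
    lo, hi = codes[..., 0::2], codes[..., 1::2]
    return lo.astype(jnp.uint8) | (hi.astype(jnp.uint8) << 4)


def unpack_uint8_to_uint4(packed):
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    # Interleave the nibbles back into their original positions.
    return jnp.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)


def dequantize_group(codes, scale, deq_bias):
    # Affine dequantization as in MLXQuantizedLinear._prepare_scaled_weights:
    # full-precision weight = quantized code * scale + dequantization bias.
    return codes * scale + deq_bias


codes = jnp.array([1, 15, 7, 0], dtype=jnp.uint8)      # 4-bit codes stored as uint8
packed = pack_uint4_to_uint8(codes)                    # two bytes
assert bool((unpack_uint8_to_uint4(packed) == codes).all())  # lossless round trip
print(dequantize_group(codes.astype(jnp.float32), 0.1, -0.8))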
lalamo/modules/mlp.py CHANGED
@@ -273,20 +273,19 @@ class SoftmaxRouting(RoutingFunctionBase):
 RoutingFunction = SoftmaxRouting | DummyUnionMember
 
 
-register_config_union(RoutingFunction)
+register_config_union(RoutingFunction)  # type: ignore (pyright bug)
 
 
 @dataclass(frozen=True)
 class MixtureOfExpertsConfig(ABC):
-    mixture_size: int
-    num_experts_per_token: int
+    expert_config: DenseMLPConfig
+    router_config: LinearConfig
     routing_function: RoutingFunction
 
-    router_config: LinearConfig
+    mixture_size: int
+    num_experts_per_token: int
     router_has_biases: bool
 
-    expert_config: DenseMLPConfig
-
     def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "MixtureOfExperts":
         experts_key, router_key = jax.random.split(key)
         router = self.router_config.random_init(
@@ -481,4 +480,4 @@ class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
 MLPConfig = DenseMLPConfig | MixtureOfExpertsConfig
 
 
-register_config_union(MLPConfig)
+register_config_union(MLPConfig)  # type: ignore (pyright bug)
lalamo/modules/mlx_interop.py ADDED
@@ -0,0 +1,19 @@
+import jax.numpy as jnp
+import mlx.core as mx
+from jaxtyping import Array
+
+__all__ = ["jax_to_mlx", "mlx_to_jax"]
+
+
+def mlx_to_jax(a: mx.array) -> Array:
+    if a.dtype == mx.bfloat16:
+        return jnp.asarray(a.view(mx.uint16)).view(jnp.bfloat16)
+
+    return jnp.asarray(a)
+
+
+def jax_to_mlx(a: Array) -> mx.array:
+    if a.dtype == jnp.bfloat16:
+        return mx.array(a.view(jnp.uint16)).view(mx.bfloat16)  # type: ignore
+
+    return mx.array(a)  # type: ignore
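The new interop module moves arrays between JAX and MLX, reinterpreting bfloat16 buffers as uint16 for the transfer (likely because the usual interchange path lacks a bfloat16 dtype) and viewing them back on the other side. A minimal usage sketch, assuming an environment with mlx installed:

import jax.numpy as jnp

from lalamo.modules.mlx_interop import jax_to_mlx, mlx_to_jax

x = jnp.arange(4, dtype=jnp.bfloat16)   # bfloat16 exercises the uint16 view path
mx_x = jax_to_mlx(x)                    # JAX -> MLX
x_back = mlx_to_jax(mx_x)               # MLX -> JAX
assert bool((x == x_back).all())        # values survive the bit reinterpretation unchanged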
lalamo/modules/rope.py CHANGED
@@ -281,4 +281,4 @@ class LinearScalingRoPEConfig(RoPEConfigBase):
 
 RoPEConfig = UnscaledRoPEConfig | LlamaRoPEConfig | YARNRoPEConfig | LinearScalingRoPEConfig
 
-register_config_union(RoPEConfig)
+register_config_union(RoPEConfig)  # type: ignore (pyright bug)
lalamo/modules/token_mixers/__init__.py ADDED
@@ -0,0 +1,30 @@
+from lalamo.modules.common import register_config_union
+
+from .attention import Attention, AttentionConfig, AttentionResult
+from .common import TokenMixerBase, TokenMixerResult
+from .mamba import Mamba2, Mamba2Config, Mamba2Result, SeparableCausalConv, SeparableCausalConvConfig
+from .state import DynamicKVCacheLayer, KVCacheLayer, Mamba2StateLayer, State, StateLayerBase, StaticKVCacheLayer
+
+TokenMixerConfig = AttentionConfig | Mamba2Config
+
+register_config_union(TokenMixerConfig)  # type: ignore (pyright bug)
+
+__all__ = [
+    "Attention",
+    "AttentionConfig",
+    "AttentionResult",
+    "DynamicKVCacheLayer",
+    "KVCacheLayer",
+    "Mamba2",
+    "Mamba2Config",
+    "Mamba2Result",
+    "Mamba2StateLayer",
+    "SeparableCausalConv",
+    "SeparableCausalConvConfig",
+    "State",
+    "StateLayerBase",
+    "StaticKVCacheLayer",
+    "TokenMixerBase",
+    "TokenMixerConfig",
+    "TokenMixerResult",
+]