lalamo 0.4.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/language_model.py +22 -23
- lalamo/main.py +2 -16
- lalamo/model_import/common.py +24 -6
- lalamo/model_import/decoder_configs/__init__.py +2 -0
- lalamo/model_import/decoder_configs/common.py +4 -4
- lalamo/model_import/decoder_configs/executorch.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +33 -26
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
- lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
- lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
- lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
- lalamo/model_import/loaders/executorch.py +5 -4
- lalamo/model_import/loaders/huggingface.py +321 -69
- lalamo/model_import/model_specs/__init__.py +2 -0
- lalamo/model_import/model_specs/common.py +16 -5
- lalamo/model_import/model_specs/llamba.py +40 -0
- lalamo/model_import/model_specs/qwen.py +29 -1
- lalamo/modules/__init__.py +33 -6
- lalamo/modules/activations.py +9 -2
- lalamo/modules/common.py +10 -5
- lalamo/modules/decoder.py +93 -97
- lalamo/modules/decoder_layer.py +85 -103
- lalamo/modules/embedding.py +279 -5
- lalamo/modules/linear.py +335 -30
- lalamo/modules/mlp.py +6 -7
- lalamo/modules/mlx_interop.py +19 -0
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +30 -0
- lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
- lalamo/modules/token_mixers/common.py +78 -0
- lalamo/modules/token_mixers/mamba.py +553 -0
- lalamo/modules/token_mixers/state/__init__.py +12 -0
- lalamo/modules/token_mixers/state/common.py +26 -0
- lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
- lalamo/modules/token_mixers/state/mamba_state.py +51 -0
- lalamo/utils.py +24 -2
- {lalamo-0.4.1.dist-info → lalamo-0.5.1.dist-info}/METADATA +3 -2
- lalamo-0.5.1.dist-info/RECORD +80 -0
- lalamo-0.4.1.dist-info/RECORD +0 -71
- {lalamo-0.4.1.dist-info → lalamo-0.5.1.dist-info}/WHEEL +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.1.dist-info}/entry_points.txt +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.1.dist-info}/top_level.txt +0 -0
lalamo/modules/linear.py
CHANGED
@@ -12,6 +12,7 @@ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree, dummy_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
+from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
 from .common import (
     LalamoModule,
@@ -59,7 +60,7 @@ class LinearBase[ConfigT: LinearConfigBase](LalamoModule[ConfigT]):
         assert isinstance(self.output_dims, tuple)
 
     @staticmethod
-    def
+    def get_split_points(output_dims: Sequence[int]) -> tuple[int, ...]:
         result = []
         last_split_point = 0
         for dim in output_dims[:-1]:
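The hunk above introduces a public `get_split_points` staticmethod (and later hunks update every caller to it). Its body is cut off in this view; a minimal sketch of the behaviour implied by the surrounding loop, namely cumulative boundaries between fused output blocks, could look like this (an inference from the visible lines, not the in-tree code):

```python
from collections.abc import Sequence

import jax.numpy as jnp


def get_split_points(output_dims: Sequence[int]) -> tuple[int, ...]:
    # Cumulative boundaries between fused output blocks, excluding the final edge.
    result = []
    last_split_point = 0
    for dim in output_dims[:-1]:
        last_split_point += dim
        result.append(last_split_point)
    return tuple(result)


# A fused QKV-style projection with output_dims (1024, 256, 256)
# splits a 1536-channel result at indices (1024, 1280).
points = get_split_points((1024, 256, 256))   # -> (1024, 1280)
chunks = jnp.split(jnp.zeros(1536), points)   # three arrays of 1024, 256, 256 channels
```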
@@ -258,7 +259,7 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
         result = self.weights @ inputs
         if self.biases is not None:
             result = result + self.biases
-        return tuple(jnp.split(result, self.
+        return tuple(jnp.split(result, self.get_split_points(self.output_dims)))
 
     def export_weights(self) -> ParameterTree:
         result = dict(weights=self.weights)
@@ -279,12 +280,39 @@ class FullPrecisionLinear(LinearBase[FullPrecisionLinearConfig]):
 
 
 @dataclass(frozen=True)
-class
+class QuantizedLinearConfigBase(LinearConfigBase):
     group_size: int
     weight_quantization_mode: QuantizationMode
     activation_quantization_mode: QuantizationMode | None
     activation_precision: DTypeLike
 
+
+class QuantizedLinearBase[ConfigT: QuantizedLinearConfigBase](LinearBase[ConfigT]):
+    biases: Float[Array, "*components total_out_channels"] | None
+
+    @abstractmethod
+    def _prepare_scaled_weights(self) -> Float[Array, "*components in_channels total_out_channels"]: ...
+
+    def _apply_weights(self, inputs: Float[Array, " in_channels"]) -> Float[Array, " total_out_channels"]:
+        if self.config.activation_quantization_mode is not None:
+            inputs = dynamically_quantize_activations(inputs, self.config.activation_quantization_mode)
+        return self._prepare_scaled_weights() @ inputs
+
+    @eqx.filter_jit
+    def __call__(self, inputs: Float[Array, " in_channels"]) -> tuple[Float[Array, " out_channels"], ...]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
+        result = self._apply_weights(inputs)
+        if self.biases is not None:
+            result = result + self.biases
+        return tuple(jnp.split(result, self.get_split_points(self.output_dims)))
+
+
+@dataclass(frozen=True)
+class GroupQuantizedLinearConfig(QuantizedLinearConfigBase):
     def random_init(
         self,
         input_dim: int,
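The new `QuantizedLinearBase` factors the shared call path out of the quantized layer variants: optionally fake-quantize the activations, multiply by the dequantized weights from `_prepare_scaled_weights`, add biases, and split the fused output. `dynamically_quantize_activations` itself lives in `lalamo.quantization` and is not part of this diff; the sketch below uses a hypothetical name and a generic per-tensor symmetric scheme purely to illustrate what dynamic activation quantization usually means in this context.

```python
import jax.numpy as jnp
from jaxtyping import Array, Float


def fake_quantize_activations(x: Float[Array, " channels"], bits: int = 8) -> Float[Array, " channels"]:
    # Hypothetical stand-in, NOT lalamo's implementation: derive a scale from the
    # current tensor's range, round onto the integer grid, then dequantize again.
    qmax = 2 ** (bits - 1) - 1
    scale = jnp.maximum(jnp.max(jnp.abs(x)) / qmax, jnp.finfo(x.dtype).tiny)
    return jnp.clip(jnp.round(x / scale), -qmax - 1, qmax) * scale
```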
@@ -381,7 +409,7 @@ class GroupQuantizedLinearConfig(LinearConfigBase):
         return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)
 
 
-class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](
+class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLinearBase[ConfigT]):
     weights: Float[Array, "*components total_out_channels in_channels"]
     scales: Float[Array, "*components total_out_channels groups"]
     zero_points: Float[Array, "*components total_out_channels groups"]
@@ -414,13 +442,27 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
 
     @property
     def int_weights(self) -> Int[Array, "*components in_channels out_channels"]:
-
-
+        quantized = quantize_weights(self.weights, self.config.weight_quantization_mode)
+        casted = quantized.astype(self.config.weight_quantization_mode.dtype)
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
 
     @property
     def int_zero_points(self) -> Int[Array, "*components groups out_channels"]:
-
-
+        quantized = quantize_weights(self.zero_points, self.config.weight_quantization_mode)
+        casted = quantized.astype(self.config.weight_quantization_mode.dtype)
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
 
     def __post_init__(self) -> None:  # noqa: PLR0912
         if self.weights.dtype != self.config.activation_precision:
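`int_weights` and `int_zero_points` now pack UINT4 tensors into `uint8` via `jax_uint4_to_packed_uint8` from `lalamo/utils.py`, whose body is outside this diff. As a rough illustration of the kind of nibble packing such helpers typically perform (the names, signatures, and low-nibble-first order below are assumptions, not the library's actual code):

```python
import jax.numpy as jnp
from jaxtyping import Array, Int


def pack_uint4_pairs(values: Int[Array, "... n"]) -> Int[Array, "... n_half"]:
    # Hypothetical sketch: pack two 4-bit values per byte, low nibble first.
    values = values.astype(jnp.uint8)
    low, high = values[..., 0::2], values[..., 1::2]
    return low | (high << 4)


def unpack_uint4_pairs(packed: Int[Array, "... n_half"]) -> Int[Array, "... n"]:
    # Inverse of the sketch above: interleave the low and high nibbles back out.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return jnp.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)
```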
@@ -520,28 +562,286 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
         )
         return result
 
-    def
-
-
-
+    def export_weights(self) -> ParameterTree:
+        result = dict(
+            weights=self.int_weights,
+            zero_points=self.int_zero_points,
+            scales=self.scales,
+        )
+        if self.biases is not None:
+            result["biases"] = self.biases
+        return result
 
-
-
-
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        assert isinstance(weights["zero_points"], Array)
+        unpacked_weights = weights["weights"]
+        unpacked_zero_points = weights["zero_points"]
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
+            unpacked_zero_points = jax_uint8_to_unpacked_uint4(weights["zero_points"])
+
+        return replace(
+            self,
+            weights=unpacked_weights.astype(self.weights.dtype),
+            scales=weights["scales"],
+            zero_points=unpacked_zero_points.astype(self.zero_points.dtype),
+            biases=weights["biases"] if self.has_biases else None,
+        )
+
+
+class GroupQuantizedLinear(GroupQuantizedLinearBase[GroupQuantizedLinearConfig]):
+    pass
+
+
+@dataclass(frozen=True)
+class MLXQuantizedLinearConfig(QuantizedLinearConfigBase):
+    def random_init(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        min_val, max_val = self.weight_quantization_mode.range
+        weights = jax.random.uniform(
+            key,
+            (sum(output_dims), input_dim),
+            minval=min_val - 1,
+            maxval=max_val + 1,
+            dtype=self.activation_precision,
+        )
+        num_groups = input_dim // self.group_size
+        scale = 1 / ((max_val - min_val) / 2 * math.sqrt(input_dim))
+        scales = scale * jnp.ones((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        if has_biases:
+            biases = jnp.zeros((sum(output_dims),), dtype=self.activation_precision)
+        else:
+            biases = None
+
+        deq_bias = min_val + 2 ** (self.weight_quantization_mode.bits - 1)
+        deq_biases = deq_bias * jnp.ones((sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        return MLXQuantizedLinear(
+            config=self,
+            output_dims=output_dims,
+            weights=weights,
+            scales=scales,
+            deq_biases=deq_biases,
+            biases=biases,
+        )
+
+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+        *,
+        key: PRNGKeyArray,
+    ) -> LinearBase:
+        subkeys = jax.random.split(key, mixture_size)
+        return eqx.filter_vmap(lambda key: self.random_init(input_dim, output_dims, has_biases, key=key))(subkeys)
+
+    def _empty_general(
+        self,
+        leading_dims: tuple[int, ...],
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        weights = dummy_array(
+            (*leading_dims, sum(output_dims), input_dim),
+            dtype=self.activation_precision,
+        )
+        num_groups = input_dim // self.group_size
+        scales = dummy_array((*leading_dims, sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        if has_biases:
+            biases = dummy_array((*leading_dims, sum(output_dims)), dtype=self.activation_precision)
+        else:
+            biases = None
+        deq_biases = dummy_array((*leading_dims, sum(output_dims), num_groups), dtype=self.activation_precision)
+
+        return MLXQuantizedLinear(
+            config=self,
+            output_dims=output_dims,
+            weights=weights,
+            scales=scales,
+            deq_biases=deq_biases,
+            biases=biases,
+        )
+
+    def empty(
+        self,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((), input_dim, output_dims, has_biases)
+
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        input_dim: int,
+        output_dims: tuple[int, ...],
+        has_biases: bool,
+    ) -> LinearBase:
+        return self._empty_general((mixture_size,), input_dim, output_dims, has_biases)
+
+
+class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearBase[ConfigT]):
+    weights: Float[Array, "*components total_out_channels in_channels"]
+    scales: Float[Array, "*components total_out_channels groups"]
+    deq_biases: Float[Array, "*components total_out_channels groups"]
+    biases: Float[Array, "*components total_out_channels"] | None
+
+    @property
+    def mixture_size(self) -> int | None:
+        match self.weights.shape:
+            case [num_components, _, _]:
+                return num_components
+            case _:
+                return None
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def input_dim(self) -> int:
+        *_, _, input_dim = self.weights.shape
+        return input_dim
+
+    @property
+    def has_biases(self) -> bool:
+        return self.biases is not None
+
+    @property
+    def num_groups(self) -> int:
+        return self.input_dim // self.config.group_size
+
+    @property
+    def int_weights(self) -> Int[Array, "*components in_channels out_channels"]:
+        quantized = quantize_weights(self.weights, self.config.weight_quantization_mode)
+        casted = quantized.astype(self.config.weight_quantization_mode.dtype)
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    def __post_init__(self) -> None:  # noqa: PLR0912
+        if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
-                "
-                "
+                f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
+                f" ({self.config.activation_precision}).",
+                " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-
+        *w_num_components, w_output_dim, _ = self.weights.shape
+        if w_output_dim != sum(self.output_dims):
+            raise ValueError(
+                f"Number of output channels in weights ({w_output_dim}) is not"
+                f" equal to sum of output dims ({sum(self.output_dims)}).",
+            )
+
+        if self.scales.dtype != self.config.activation_precision:
+            raise ValueError(
+                f"Scale dtype ({self.scales.dtype}) is not equal to specified activation precision"
+                f" ({self.config.activation_precision}).",
+                " Quantized layers require parameter dtypes to be equal to the activation precision.",
+            )
+        *s_num_components, s_output_dim, s_num_groups = self.scales.shape
+        if w_output_dim != s_output_dim:
+            raise ValueError(
+                f"Number of output channels in weights ({w_output_dim}) is not"
+                f" equal to number of output channels in scales ({s_output_dim}).",
+            )
+        if tuple(s_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in scales ({s_num_components}).",
+            )
+        if s_num_groups != self.num_groups:
+            raise ValueError(
+                f"Number of groups in scales ({s_num_groups}) is incompatible with"
+                f" the specified group size ({self.config.group_size}).",
+            )
+
+        if self.deq_biases.dtype != self.config.activation_precision:
+            raise ValueError(
+                f"Dequantization bias dtype ({self.deq_biases.dtype}) is not equal to specified activation precision"
+                f" ({self.config.activation_precision}).",
+                " Quantized layers require parameter dtypes to be equal to the activation precision.",
+            )
+        *zp_num_components, zp_output_dim, zp_num_groups = self.deq_biases.shape
+        if w_output_dim != zp_output_dim:
+            raise ValueError(
+                f"Number of output channels in weights ({w_output_dim}) is not"
+                f" equal to number of output channels in zero points ({zp_output_dim}).",
+            )
+        if tuple(zp_num_components) != tuple(w_num_components):
+            raise ValueError(
+                f"Number of mixture components in weights ({w_num_components}) is not"
+                f" equal to number of mixture components in zero points ({zp_num_components}).",
+            )
+        if self.num_groups != zp_num_groups:
+            raise ValueError(
+                f"Number of groups in zero points ({zp_num_groups}) is incompatible with"
+                f" the specified group size ({self.config.group_size}).",
+            )
+
         if self.biases is not None:
-
-
+            if self.biases.dtype != self.config.activation_precision:
+                raise ValueError(
+                    f"Bias dtype ({self.biases.dtype}) is not equal to specified activation precision"
+                    f" ({self.config.activation_precision}).",
+                    " Quantized layers require parameter dtypes to be equal to the activation precision.",
+                )
+            *b_num_components, b_output_dim = self.biases.shape
+            if w_output_dim != b_output_dim:
+                raise ValueError(
+                    f"Number of output channels in weights ({w_output_dim}) is not"
+                    f" equal to number of output channels in biases ({b_output_dim}).",
+                )
+            if tuple(b_num_components) != tuple(w_num_components):
+                raise ValueError(
+                    f"Number of mixture components in weights ({w_num_components}) is not"
+                    f" equal to number of mixture components in biases ({b_num_components}).",
+                )
+
+    def _prepare_scaled_weights(self) -> Float[Array, "*components in_channels total_out_channels"]:
+        quantized_weights = quantize_weights(self.weights, self.config.weight_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "... total_out_channels (groups group_channels) -> ... total_out_channels groups group_channels",
+            groups=self.num_groups,
+        )
+
+        scales = rearrange(self.scales, "... total_out_channels groups -> ... total_out_channels groups 1")
+        deq_biases = rearrange(self.deq_biases, "... total_out_channels groups -> ... total_out_channels groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + deq_biases
+        result = rearrange(
+            scaled_grouped_weights,
+            "... total_out_channels groups group_channels -> ... total_out_channels (groups group_channels)",
+        )
+        return result
 
     def export_weights(self) -> ParameterTree:
         result = dict(
             weights=self.int_weights,
-            zero_points=self.int_zero_points,
             scales=self.scales,
+            deq_biases=self.deq_biases,
         )
         if self.biases is not None:
             result["biases"] = self.biases
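`MLXQuantizedLinearBase._prepare_scaled_weights` above captures the MLX dequantization convention: each group of quantized weights is reconstructed as `q * scale + deq_bias`, with one scale and one additive dequantization bias per group, rather than the `(q - zero_point) * scale` form used by `GroupQuantizedLinear`. A standalone sketch of the same arithmetic with explicit reshapes (an illustration of the math, not the module itself):

```python
import jax.numpy as jnp
from jaxtyping import Array, Float, Int


def dequantize_mlx_groups(
    q: Int[Array, "out_channels in_channels"],
    scales: Float[Array, "out_channels groups"],
    deq_biases: Float[Array, "out_channels groups"],
    group_size: int,
) -> Float[Array, "out_channels in_channels"]:
    out_channels, in_channels = q.shape
    groups = in_channels // group_size
    # Split the input axis into (groups, group_size) so each group can be
    # rescaled by its own scale and shifted by its own dequantization bias.
    grouped = q.reshape(out_channels, groups, group_size).astype(scales.dtype)
    deq = grouped * scales[:, :, None] + deq_biases[:, :, None]
    return deq.reshape(out_channels, in_channels)
```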
@@ -553,17 +853,22 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](LinearBase[C
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["weights"], Array)
-
+
+        unpacked_weights = weights["weights"]
+
+        if self.config.weight_quantization_mode == QuantizationMode.UINT4:
+            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
+
         return replace(
             self,
-            weights=
+            weights=unpacked_weights.astype(self.weights.dtype),
             scales=weights["scales"],
-
+            deq_biases=weights["deq_biases"],
             biases=weights["biases"] if self.has_biases else None,
         )
 
 
-class
+class MLXQuantizedLinear(MLXQuantizedLinearBase[MLXQuantizedLinearConfig]):
     pass
 
 
@@ -714,7 +1019,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
 
     def _split_biases(self) -> tuple[Float[Array, "*components out_channels"] | None, ...]:
         if self.biases is not None:
-            return tuple(jnp.split(self.biases, self.
+            return tuple(jnp.split(self.biases, self.get_split_points(self.output_dims)))
         return (None,) * len(self.output_dims)
 
     def __post_init__(self) -> None:
@@ -778,10 +1083,10 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
                 "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
             )
         joint_q_out = self._apply_weights(inputs)
-        q_outs = jnp.split(joint_q_out, self.
+        q_outs = jnp.split(joint_q_out, self.get_split_points(self.output_dims))
 
         joint_lora_hidden = inputs @ self.lora_down_weights
-        lora_hiddens = jnp.split(joint_lora_hidden, self.
+        lora_hiddens = jnp.split(joint_lora_hidden, self.get_split_points([self.config.lora_rank] * self.num_outputs))
         lora_outs = [
             lora_hidden @ lora_up_weight
             for lora_up_weight, lora_hidden in zip(self.lora_up_weights, lora_hiddens, strict=True)
@@ -818,7 +1123,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
         )
 
 
-LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | QLoRALinearConfig
+LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | MLXQuantizedLinearConfig | QLoRALinearConfig
 
 
-register_config_union(LinearConfig)
+register_config_union(LinearConfig)  # type: ignore (pyright bug)
lalamo/modules/mlp.py
CHANGED
@@ -273,20 +273,19 @@ class SoftmaxRouting(RoutingFunctionBase):
 RoutingFunction = SoftmaxRouting | DummyUnionMember
 
 
-register_config_union(RoutingFunction)
+register_config_union(RoutingFunction)  # type: ignore (pyright bug)
 
 
 @dataclass(frozen=True)
 class MixtureOfExpertsConfig(ABC):
-
-
+    expert_config: DenseMLPConfig
+    router_config: LinearConfig
     routing_function: RoutingFunction
 
-
+    mixture_size: int
+    num_experts_per_token: int
     router_has_biases: bool
 
-    expert_config: DenseMLPConfig
-
     def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "MixtureOfExperts":
         experts_key, router_key = jax.random.split(key)
         router = self.router_config.random_init(
@@ -481,4 +480,4 @@ class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
 MLPConfig = DenseMLPConfig | MixtureOfExpertsConfig
 
 
-register_config_union(MLPConfig)
+register_config_union(MLPConfig)  # type: ignore (pyright bug)
lalamo/modules/mlx_interop.py
ADDED
@@ -0,0 +1,19 @@
+import jax.numpy as jnp
+import mlx.core as mx
+from jaxtyping import Array
+
+__all__ = ["jax_to_mlx", "mlx_to_jax"]
+
+
+def mlx_to_jax(a: mx.array) -> Array:
+    if a.dtype == mx.bfloat16:
+        return jnp.asarray(a.view(mx.uint16)).view(jnp.bfloat16)
+
+    return jnp.asarray(a)
+
+
+def jax_to_mlx(a: Array) -> mx.array:
+    if a.dtype == jnp.bfloat16:
+        return mx.array(a.view(jnp.uint16)).view(mx.bfloat16)  # type: ignore
+
+    return mx.array(a)  # type: ignore
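The new `mlx_interop` module converts between JAX and MLX arrays, reinterpreting `bfloat16` buffers as `uint16` so the round trip stays bit-exact instead of bouncing through another float dtype. A small usage sketch, assuming MLX is installed alongside lalamo 0.5.1:

```python
import jax.numpy as jnp

from lalamo.modules.mlx_interop import jax_to_mlx, mlx_to_jax

x = jnp.arange(8, dtype=jnp.bfloat16)
m = jax_to_mlx(x)   # mx.array with dtype mx.bfloat16
y = mlx_to_jax(m)   # back to a JAX array, still bfloat16

assert y.dtype == jnp.bfloat16
assert bool(jnp.all(x == y))
```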
lalamo/modules/token_mixers/__init__.py
ADDED
@@ -0,0 +1,30 @@
+from lalamo.modules.common import register_config_union
+
+from .attention import Attention, AttentionConfig, AttentionResult
+from .common import TokenMixerBase, TokenMixerResult
+from .mamba import Mamba2, Mamba2Config, Mamba2Result, SeparableCausalConv, SeparableCausalConvConfig
+from .state import DynamicKVCacheLayer, KVCacheLayer, Mamba2StateLayer, State, StateLayerBase, StaticKVCacheLayer
+
+TokenMixerConfig = AttentionConfig | Mamba2Config
+
+register_config_union(TokenMixerConfig)  # type: ignore (pyright bug)
+
+__all__ = [
+    "Attention",
+    "AttentionConfig",
+    "AttentionResult",
+    "DynamicKVCacheLayer",
+    "KVCacheLayer",
+    "Mamba2",
+    "Mamba2Config",
+    "Mamba2Result",
+    "Mamba2StateLayer",
+    "SeparableCausalConv",
+    "SeparableCausalConvConfig",
+    "State",
+    "StateLayerBase",
+    "StaticKVCacheLayer",
+    "TokenMixerBase",
+    "TokenMixerConfig",
+    "TokenMixerResult",
+]