lalamo 0.5.14__py3-none-any.whl → 0.5.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/model_import/decoder_configs/huggingface/llama.py +32 -21
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +33 -22
- lalamo/model_import/loaders/huggingface.py +49 -1
- lalamo/model_import/model_specs/qwen.py +14 -0
- lalamo/modules/__init__.py +4 -0
- lalamo/modules/embedding.py +169 -0
- {lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/METADATA +1 -1
- {lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/RECORD +13 -13
- {lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/WHEEL +0 -0
- {lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
@@ -13,6 +13,7 @@ from lalamo.modules import (
     LlamaRoPEConfig,
     MLXQuantizedLinearConfig,
     MLXQuantizedTiedEmbeddingConfig,
+    MLXQuantizedUntiedEmbeddingConfig,
     NormalizationConfig,
     SiLU,
     TiedEmbeddingConfig,
lalamo/model_import/decoder_configs/huggingface/llama.py
CHANGED
@@ -89,27 +90,37 @@ class HFLlamaConfig(HuggingFaceLMConfig):
     ) -> DecoderConfig:
         quantization = self.quantization or self.quantization_config
         if isinstance(quantization, MLXQuantizationConfig):
-            [21 lines removed here; content not shown]
+            if self.tie_word_embeddings:
+                embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    group_size=quantization.group_size,
+                    embedding_quantization_mode=QuantizationMode.from_num_bits(quantization.bits),
+                    activation_quantization_mode=None,
+                    activation_precision=activation_precision,
+                )
+            else:
+                embedding_config = MLXQuantizedUntiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    group_size=quantization.group_size,
+                    embedding_quantization_mode=QuantizationMode.from_num_bits(quantization.bits),
+                    activation_quantization_mode=None,
+                    activation_precision=activation_precision,
+                )
+        else:  # noqa: PLR5501
+            if self.tie_word_embeddings:
+                embedding_config = TiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    precision=activation_precision,
+                )
+            else:
+                embedding_config = UntiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    precision=activation_precision,
+                )
         if self.rope_scaling is None:
             rope_config = UnscaledRoPEConfig(
                 precision=activation_precision,
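For readers skimming the hunk above: the added branch picks one of four embedding configurations, depending on whether the checkpoint carries MLX quantization metadata and whether the input and output embeddings are tied. The same logic is mirrored in HFQwen3Config below. The following sketch restates it as a standalone helper; `select_embedding_config`, the `group_size`/`bits` parameters, and the bfloat16 default are illustrative and not part of the package.

```python
# Illustrative sketch of the selection logic added to HFLlamaConfig / HFQwen3Config.
# The config classes and their keyword arguments come from the diff above; the helper
# name and its parameters are hypothetical.
import jax.numpy as jnp

from lalamo.modules import (
    MLXQuantizedTiedEmbeddingConfig,
    MLXQuantizedUntiedEmbeddingConfig,
    TiedEmbeddingConfig,
    UntiedEmbeddingConfig,
)
from lalamo.quantization import QuantizationMode


def select_embedding_config(
    tie_word_embeddings: bool,
    group_size: int | None,  # MLX quantization group size, or None if unquantized
    bits: int | None,  # MLX quantization bit width, or None if unquantized
    activation_precision=jnp.bfloat16,
):
    if group_size is not None and bits is not None:
        # MLX-quantized checkpoint: choose the tied or untied MLX embedding config.
        config_cls = MLXQuantizedTiedEmbeddingConfig if tie_word_embeddings else MLXQuantizedUntiedEmbeddingConfig
        return config_cls(
            input_scale=None,
            logit_soft_cap=None,
            group_size=group_size,
            embedding_quantization_mode=QuantizationMode.from_num_bits(bits),
            activation_quantization_mode=None,
            activation_precision=activation_precision,
        )
    # Full-precision checkpoint: plain tied or untied embedding config.
    config_cls = TiedEmbeddingConfig if tie_word_embeddings else UntiedEmbeddingConfig
    return config_cls(
        input_scale=None,
        logit_soft_cap=None,
        precision=activation_precision,
    )
```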
lalamo/model_import/decoder_configs/huggingface/qwen3.py
CHANGED
@@ -10,6 +10,8 @@ from lalamo.modules import (
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
+    MLXQuantizedTiedEmbeddingConfig,
+    MLXQuantizedUntiedEmbeddingConfig,
     NormalizationConfig,
     TiedEmbeddingConfig,
     TransformerConfig,
@@ -19,7 +21,6 @@ from lalamo.modules import (
     UpcastMode,
 )
 from lalamo.modules.activations import SiLU
-from lalamo.modules.embedding import MLXQuantizedTiedEmbeddingConfig
 from lalamo.modules.linear import MLXQuantizedLinearConfig
 from lalamo.quantization import QuantizationMode

@@ -75,27 +76,37 @@ class HFQwen3Config(HuggingFaceLMConfig):
         metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         if isinstance(self.quantization_config, MLXQuantizationConfig):
-            [21 lines removed here; content not shown]
+            if self.tie_word_embeddings:
+                embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    group_size=self.quantization_config.group_size,
+                    embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                    activation_quantization_mode=None,
+                    activation_precision=activation_precision,
+                )
+            else:
+                embedding_config = MLXQuantizedUntiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    group_size=self.quantization_config.group_size,
+                    embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                    activation_quantization_mode=None,
+                    activation_precision=activation_precision,
+                )
+        else:  # noqa: PLR5501
+            if self.tie_word_embeddings:
+                embedding_config = TiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    precision=activation_precision,
+                )
+            else:
+                embedding_config = UntiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    precision=activation_precision,
+                )
         rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
lalamo/model_import/loaders/huggingface.py
CHANGED
@@ -29,6 +29,7 @@ from lalamo.modules import (
     UntiedEmbedding,
 )
 from lalamo.modules.classifier import Classifier
+from lalamo.modules.embedding import MLXQuantizedUntiedEmbedding
 from lalamo.modules.mlp import MixtureOfExperts, MLPBase
 from lalamo.quantization import QuantizationMode

|
@@ -625,6 +626,51 @@ def load_mlx_quantized_tied_embedding(
|
|
|
625
626
|
return load_parameters(lambda m: (m.weights, m.scales, m.biases), module, (weights, scales, biases))
|
|
626
627
|
|
|
627
628
|
|
|
629
|
+
def load_mlx_quantized_untied_embedding(
|
|
630
|
+
module: MLXQuantizedUntiedEmbedding,
|
|
631
|
+
weights_dict: Mapping[str, Array],
|
|
632
|
+
embedding_path: ParameterPath,
|
|
633
|
+
lm_head_path: ParameterPath,
|
|
634
|
+
) -> MLXQuantizedUntiedEmbedding:
|
|
635
|
+
input_qweights = weights_dict[embedding_path / "weight"]
|
|
636
|
+
input_qscales = weights_dict[embedding_path / "scales"]
|
|
637
|
+
input_qbiases = weights_dict[embedding_path / "biases"]
|
|
638
|
+
output_qweights = weights_dict[lm_head_path / "weight"]
|
|
639
|
+
output_qscales = weights_dict[lm_head_path / "scales"]
|
|
640
|
+
output_qbiases = weights_dict[lm_head_path / "biases"]
|
|
641
|
+
|
|
642
|
+
input_weights = _process_quantized_tensor(
|
|
643
|
+
input_qweights,
|
|
644
|
+
module.config.embedding_quantization_mode,
|
|
645
|
+
module.activation_precision,
|
|
646
|
+
None,
|
|
647
|
+
)
|
|
648
|
+
input_scales = input_qscales.astype(module.activation_precision)
|
|
649
|
+
input_biases = input_qbiases.astype(module.activation_precision)
|
|
650
|
+
|
|
651
|
+
output_weights = _process_quantized_tensor(
|
|
652
|
+
output_qweights,
|
|
653
|
+
module.config.embedding_quantization_mode,
|
|
654
|
+
module.activation_precision,
|
|
655
|
+
None,
|
|
656
|
+
)
|
|
657
|
+
output_scales = output_qscales.astype(module.activation_precision)
|
|
658
|
+
output_biases = output_qbiases.astype(module.activation_precision)
|
|
659
|
+
|
|
660
|
+
return load_parameters(
|
|
661
|
+
lambda m: (
|
|
662
|
+
m.input_weights,
|
|
663
|
+
m.input_scales,
|
|
664
|
+
m.input_biases,
|
|
665
|
+
m.output_weights,
|
|
666
|
+
m.output_scales,
|
|
667
|
+
m.output_biases,
|
|
668
|
+
),
|
|
669
|
+
module,
|
|
670
|
+
(input_weights, input_scales, input_biases, output_weights, output_scales, output_biases),
|
|
671
|
+
)
|
|
672
|
+
|
|
673
|
+
|
|
628
674
|
def load_mlx_semi_quantized_untied_embedding(
|
|
629
675
|
module: MLXSemiQuantizedUntiedEmbedding,
|
|
630
676
|
weights_dict: Mapping[str, Array],
|
|
@@ -741,6 +787,8 @@ def load_huggingface_decoder(
         embedding = load_tied_embedding(module.embedding, weights_dict, embedding_path)
     elif isinstance(module.embedding, MLXQuantizedTiedEmbedding):
         embedding = load_mlx_quantized_tied_embedding(module.embedding, weights_dict, embedding_path)
+    elif isinstance(module.embedding, MLXQuantizedUntiedEmbedding):
+        embedding = load_mlx_quantized_untied_embedding(module.embedding, weights_dict, embedding_path, lm_head_path)
     elif isinstance(module.embedding, MLXSemiQuantizedUntiedEmbedding):
         embedding = load_mlx_semi_quantized_untied_embedding(
             module.embedding,
@@ -759,7 +807,7 @@ def load_huggingface_decoder(
             weights_dict,
             decoder_path / "layers" / ((i * 2) if alternating_layers else i),
             decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-            mixer_key[type(layer.config.mixer_config)],
+            mixer_key[type(layer.config.mixer_config)],  # type: ignore
             mlp_key,
             pre_mixer_norm_key,
             pre_mlp_norm_key,
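The new `load_mlx_quantized_untied_embedding` pulls three tensors for the input embedding and three for the untied output head: a packed integer `weight` plus per-group `scales` and `biases`. The snippet below only enumerates that expected key layout; the prefixes `model.embed_tokens` and `lm_head` are assumptions for illustration, since the real loader receives them as `embedding_path` and `lm_head_path`.

```python
# Hypothetical key layout read by the new loader; prefixes are assumed, suffixes
# ("weight", "scales", "biases") come from the diff above.
EXPECTED_SUFFIXES = ("weight", "scales", "biases")


def expected_keys(embedding_prefix: str = "model.embed_tokens", lm_head_prefix: str = "lm_head") -> list[str]:
    # Both the input embedding and the untied output head store a packed integer
    # weight tensor plus per-group scales and biases in the MLX checkpoint.
    return [f"{prefix}.{suffix}" for prefix in (embedding_prefix, lm_head_prefix) for suffix in EXPECTED_SUFFIXES]


print(expected_keys())
# ['model.embed_tokens.weight', 'model.embed_tokens.scales', 'model.embed_tokens.biases',
#  'lm_head.weight', 'lm_head.scales', 'lm_head.biases']
```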
lalamo/model_import/model_specs/qwen.py
CHANGED
@@ -223,6 +223,20 @@ QWEN3 = [
         repo="Qwen/Qwen3-8B-AWQ",
         config_type=HFQwen3Config,
     ),
+    ModelSpec(
+        vendor="Alibaba",
+        family="Qwen3",
+        name="Qwen3-8B-MLX-4bit",
+        size="8B",
+        quantization=QuantizationMode.UINT4,
+        repo="Qwen/Qwen3-8B-MLX-4bit",
+        config_type=HFQwen3Config,
+        configs=ConfigMap(
+            tokenizer=FileSpec("tokenizer.json", "Qwen/Qwen3-8B"),
+            tokenizer_config=FileSpec("tokenizer_config.json", "Qwen/Qwen3-8B"),
+            generation_config=FileSpec("generation_config.json", "Qwen/Qwen3-8B"),
+        ),
+    ),
     ModelSpec(
         vendor="Alibaba",
         family="Qwen3",
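A minimal usage sketch for the new spec, assuming lalamo is installed: it looks the entry up in the `QWEN3` list and reads fields shown in the diff (`name`, `repo`, `quantization`). Note that the tokenizer and generation config are fetched from the base `Qwen/Qwen3-8B` repository via the `FileSpec` overrides, since the MLX repo is referenced only for the weights.

```python
# Minimal sketch: locate the new Qwen3-8B-MLX-4bit entry in the QWEN3 spec list.
# Import path and field names are taken from the diff above; downstream consumption
# of the spec (conversion, file download) is not shown here.
from lalamo.model_import.model_specs.qwen import QWEN3

spec = next(s for s in QWEN3 if s.name == "Qwen3-8B-MLX-4bit")
print(spec.repo)          # Qwen/Qwen3-8B-MLX-4bit
print(spec.quantization)  # QuantizationMode.UINT4
```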
lalamo/modules/__init__.py
CHANGED
@@ -18,6 +18,8 @@ from .embedding import (
     EmbeddingConfig,
     MLXQuantizedTiedEmbedding,
     MLXQuantizedTiedEmbeddingConfig,
+    MLXQuantizedUntiedEmbedding,
+    MLXQuantizedUntiedEmbeddingConfig,
     MLXSemiQuantizedUntiedEmbedding,
     MLXSemiQuantizedUntiedEmbeddingConfig,
     QuantizedTiedEmbedding,
@@ -120,6 +122,8 @@ __all__ = [
     "MLXQuantizedLinearConfig",
     "MLXQuantizedTiedEmbedding",
     "MLXQuantizedTiedEmbeddingConfig",
+    "MLXQuantizedUntiedEmbedding",
+    "MLXQuantizedUntiedEmbeddingConfig",
     "MLXSemiQuantizedUntiedEmbedding",
     "MLXSemiQuantizedUntiedEmbeddingConfig",
     "Mamba2",
lalamo/modules/embedding.py
CHANGED
@@ -24,6 +24,8 @@ __all__ = [
     "EmbeddingConfig",
     "MLXQuantizedTiedEmbedding",
     "MLXQuantizedTiedEmbeddingConfig",
+    "MLXQuantizedUntiedEmbedding",
+    "MLXQuantizedUntiedEmbeddingConfig",
     "MLXSemiQuantizedUntiedEmbedding",
     "MLXSemiQuantizedUntiedEmbeddingConfig",
     "QuantizedTiedEmbedding",
@@ -492,6 +494,172 @@ class MLXQuantizedTiedEmbedding(EmbeddingBase[MLXQuantizedTiedEmbeddingConfig]):
         )


+@dataclass(frozen=True)
+class MLXQuantizedUntiedEmbeddingConfig(EmbeddingConfigBase):
+    group_size: int
+    embedding_quantization_mode: QuantizationMode
+    activation_quantization_mode: QuantizationMode | None
+    activation_precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "MLXQuantizedUntiedEmbedding":
+        raise NotImplementedError
+
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "MLXQuantizedUntiedEmbedding":
+        assert model_dim % self.group_size == 0
+        model_groups = model_dim // self.group_size
+        return MLXQuantizedUntiedEmbedding(
+            config=self,
+            input_weights=dummy_array((vocab_size, model_dim), dtype=self.activation_precision),
+            input_scales=dummy_array((vocab_size, model_groups), dtype=self.activation_precision),
+            input_biases=dummy_array((vocab_size, model_groups), dtype=self.activation_precision),
+            output_weights=dummy_array((vocab_size, model_dim), dtype=self.activation_precision),
+            output_scales=dummy_array((vocab_size, model_groups), dtype=self.activation_precision),
+            output_biases=dummy_array((vocab_size, model_groups), dtype=self.activation_precision),
+        )
+
+
+class MLXQuantizedUntiedEmbedding(EmbeddingBase[MLXQuantizedUntiedEmbeddingConfig]):
+    input_weights: Float[Array, "vocabulary channels"]
+    input_scales: Float[Array, "vocabulary groups"]
+    input_biases: Float[Array, "vocabulary groups"]
+    output_weights: Float[Array, "vocabulary channels"]
+    output_scales: Float[Array, "vocabulary groups"]
+    output_biases: Float[Array, "vocabulary groups"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.input_weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.input_weights.shape
+        return vocab_size
+
+    @property
+    def int_input_weights(self) -> Int[Array, "vocabulary channels"]:
+        quantized = quantize_weights(self.input_weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    @property
+    def int_output_weights(self) -> Int[Array, "vocabulary channels"]:
+        quantized = quantize_weights(self.output_weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.input_weights, self.config.embedding_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "vocab (groups elements) -> vocab groups elements",
+            elements=self.config.group_size,
+        )
+
+        scales = rearrange(self.input_scales, "vocab groups -> vocab groups 1")
+
+        biases = rearrange(self.input_biases, "vocab groups -> vocab groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + biases
+
+        result = rearrange(
+            scaled_grouped_weights,
+            "vocab groups elements -> vocab (groups elements)",
+        )
+        return result
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.output_weights, self.config.embedding_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "vocab (groups elements) -> vocab groups elements",
+            elements=self.config.group_size,
+        )
+
+        scales = rearrange(self.output_scales, "vocab groups -> vocab groups 1")
+
+        biases = rearrange(self.output_biases, "vocab groups -> vocab groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + biases
+
+        result = rearrange(
+            scaled_grouped_weights,
+            "vocab groups elements -> vocab (groups elements)",
+        )
+        return result
+
+    @eqx.filter_jit
+    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
+        if self.config.activation_quantization_mode is not None:
+            x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
+        return super().readout(x)
+
+    def export_weights(self) -> ParameterTree:
+        return {
+            "input_weights": self.int_input_weights,
+            "input_scales": self.input_scales,
+            "input_biases": self.input_biases,
+            "output_weights": self.int_output_weights,
+            "output_scales": self.output_scales,
+            "output_biases": self.output_biases,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["input_weights"], Array)
+        assert isinstance(weights["input_scales"], Array)
+        assert isinstance(weights["input_biases"], Array)
+        assert isinstance(weights["output_weights"], Array)
+        assert isinstance(weights["output_scales"], Array)
+        assert isinstance(weights["output_biases"], Array)
+
+        unpacked_input_weights = weights["input_weights"]
+        unpacked_output_weights = weights["output_weights"]
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            unpacked_input_weights = jax_uint8_to_unpacked_uint4(weights["input_weights"])
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(weights["output_weights"])
+
+        return replace(
+            self,
+            input_weights=unpacked_input_weights.astype(self.input_weights.dtype),
+            input_scales=weights["input_scales"],
+            input_biases=weights["input_biases"],
+            output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
+            output_scales=weights["output_scales"],
+            output_biases=weights["output_biases"],
+        )
+
+
 @dataclass(frozen=True)
 class MLXSemiQuantizedUntiedEmbeddingConfig(EmbeddingConfigBase):
     group_size: int
@@ -626,6 +794,7 @@ EmbeddingConfig = (
     | UntiedEmbeddingConfig
     | QuantizedTiedEmbeddingConfig
     | MLXQuantizedTiedEmbeddingConfig
+    | MLXQuantizedUntiedEmbeddingConfig
     | MLXSemiQuantizedUntiedEmbeddingConfig
 )

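The dequantization implemented by `_prepare_input_weights` and `_prepare_output_weights` is a per-group affine transform: each vocabulary row is split into groups of `group_size` channels, and every group gets its own scale and bias. Below is a self-contained sketch of that reshaping with toy shapes; it skips the `quantize_weights` round-trip and the uint4 packing the module also performs.

```python
# Toy illustration of MLX-style per-group affine dequantization, mirroring the
# rearrange pattern in MLXQuantizedUntiedEmbedding._prepare_input_weights.
import jax.numpy as jnp
from einops import rearrange

vocab, channels, group_size = 8, 16, 4          # toy sizes; real models are much larger
weights = jnp.ones((vocab, channels))           # stand-in for the dequantized integer weights
scales = jnp.full((vocab, channels // group_size), 0.5)
biases = jnp.zeros((vocab, channels // group_size))

# Split each row into groups, apply that group's scale and bias, then flatten back.
grouped = rearrange(weights, "vocab (groups elements) -> vocab groups elements", elements=group_size)
dequantized = (
    grouped * rearrange(scales, "vocab groups -> vocab groups 1")
    + rearrange(biases, "vocab groups -> vocab groups 1")
)
restored = rearrange(dequantized, "vocab groups elements -> vocab (groups elements)")
print(restored.shape)  # (8, 16): every channel scaled by its group's scale and shifted by its bias
```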
{lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-lalamo/__init__.py,sha256=
+lalamo/__init__.py,sha256=FjfGsBVSl14mNsDoFJEwXMRUq1-Kg_lessRzlJNG3KM,815
 lalamo/common.py,sha256=5NUFD26yQgOnEEk3LaQnce8n-VwJxILkEpFesHZhtQU,3820
 lalamo/main.py,sha256=GgUT7lT48-XQuAEH7qzsDKG8Lx9iBf-sYBIRhZL9q7E,23978
 lalamo/message_processor.py,sha256=bSUAQg7CemLTnBV4LtPxJBicAalruDCA-JXjkTYPZ8U,5797
@@ -23,16 +23,16 @@ lalamo/model_import/decoder_configs/huggingface/gemma2.py,sha256=g8LH_GlSNyL04WW
 lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=UXiEyNqlD0Czc5Gj3n4hNqNDp9Ml5YzH1XZ6BXj0mgU,10223
 lalamo/model_import/decoder_configs/huggingface/gpt_oss.py,sha256=MBCoPbuWyzbJiBRtHOtpaPHJjQ1UVCAYcVrfIejTnlQ,7446
 lalamo/model_import/decoder_configs/huggingface/lfm2.py,sha256=vrBMxtiKEg0eHNDL_bWM9odlrsab7jlMXEY8vjEB7-c,7595
-lalamo/model_import/decoder_configs/huggingface/llama.py,sha256=
+lalamo/model_import/decoder_configs/huggingface/llama.py,sha256=pGuBQTY6qpx6CriWwdsLpuTSRS7ECoTP1kt5pSKRlNQ,8549
 lalamo/model_import/decoder_configs/huggingface/llamba.py,sha256=ANB-vQK8U-zVFubZSTDXXt2S70T5SVOGzf7eOVvPzIQ,5773
 lalamo/model_import/decoder_configs/huggingface/mistral.py,sha256=MDGC0ivzJuUpOC11n8vFdcVzqccUyaRw_hkL74mVlAg,4599
 lalamo/model_import/decoder_configs/huggingface/modern_bert.py,sha256=A8nNIMhPVumvPWIFR3RexRc6XkFyUd_3mmNpmvyPEGE,8816
 lalamo/model_import/decoder_configs/huggingface/qwen2.py,sha256=n3qIANMPbtQsTtk5QEWWFZ6R85eDxR_kaZd0NDlJ3T4,5786
-lalamo/model_import/decoder_configs/huggingface/qwen3.py,sha256=
+lalamo/model_import/decoder_configs/huggingface/qwen3.py,sha256=i99mfL2DbeJ0l5aFRV84MTT-PsWf6q-8B-SGPIVGe1w,7522
 lalamo/model_import/loaders/__init__.py,sha256=3THc1wQ4EPBzQkL_4EaKCa7Ev5Z7oczcvc4AHy9v5EI,228
 lalamo/model_import/loaders/common.py,sha256=kkugV-bMQlN1zvGHoj3uc7z0FbXKoMtXEBTvyu4KxK4,1844
 lalamo/model_import/loaders/executorch.py,sha256=t2Ey_mBMNC8bTSTdYWjuGXdPTRoohFlYrqtWyNkBU_8,9219
-lalamo/model_import/loaders/huggingface.py,sha256=
+lalamo/model_import/loaders/huggingface.py,sha256=qWdzoSvHvb_3prn2kwfxgnYPW2bVB0Q49m_wyRYha8Q,34677
 lalamo/model_import/loaders/utils.py,sha256=eiX3WKFRrAfBY-dugodscNInl5o5w3KmVcgma4atpGY,2456
 lalamo/model_import/model_specs/__init__.py,sha256=JISqwJkloQkGD2jvi1MakNEWapIwlNXXVi5giZyXB74,1275
 lalamo/model_import/model_specs/common.py,sha256=RLySCIkmGiA1IVZgLeemssMBMo4hMYMpmBjV0cRwBb4,6586
@@ -48,18 +48,18 @@ lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0v
 lalamo/model_import/model_specs/mistral.py,sha256=HAojorjOqsJn2DoMBzYRw8A70qCslhFEsE9AF5xumlg,1278
 lalamo/model_import/model_specs/pleias.py,sha256=5sRpZGYwLdsav6bLiW-459y1Cs9iJKgKkBIuGsOxtsQ,368
 lalamo/model_import/model_specs/polaris.py,sha256=Mw1-6bByjDmPIKlIUIV46CsmV5xUp_laI5Qquo5DmAQ,520
-lalamo/model_import/model_specs/qwen.py,sha256=
+lalamo/model_import/model_specs/qwen.py,sha256=HvN080ILpOwkqJbRLMqCa8Z8ImlLfTwiEIhWxUdTRfo,7563
 lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKCvnxLkNEDCw,423
 lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
 lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
 lalamo/models/common.py,sha256=PDteofGxjSBWYw_mPxbN1DTUba70aOURrAIjl13SSHc,2954
 lalamo/models/language_model.py,sha256=QPeVEyhutSze7fSNhvOvwSoYt24QMk-dtTJkos38amY,13465
-lalamo/modules/__init__.py,sha256=
+lalamo/modules/__init__.py,sha256=OHIQn08jx2c3L2KIQA-7SJ4yVb2E5m6T6FqTHFJTDdM,4006
 lalamo/modules/activations.py,sha256=U3qTQtZawPAUcoqbkIJnmTYcaNiQuSPMLcBeJ398GhI,1022
 lalamo/modules/classifier.py,sha256=_jtJ3INEq1dJP5HpUmcDk9YYzpRYlQ04zvFGaWBV6Lg,12101
 lalamo/modules/common.py,sha256=dqDEOi-C3H4U9iWUisU32RA-wRDCGuaUNGbObRBhyQM,3315
 lalamo/modules/decoder.py,sha256=Opd3QIq1mpGr9P7sLH-Fryitlfp6ESTpcX71vgm89t0,7129
-lalamo/modules/embedding.py,sha256=
+lalamo/modules/embedding.py,sha256=LLiH8mTu81JSpUTj-XhsrVIUfl_GhapnXxw1yGSUBgM,28428
 lalamo/modules/linear.py,sha256=XfIYhmpk-bwNHIzIgsL48ZUTclHD2KB4uXHMw9NTE-8,42991
 lalamo/modules/mlp.py,sha256=bL3sQ46vCNt1MBRwlzmXZx9nQfRe4axpGe5UOFVanBI,17959
 lalamo/modules/mlx_interop.py,sha256=FdfU_1iES-HQ9r4K0SkYwJTyvE0f-_T5ursNCjPLZKY,467
@@ -85,9 +85,9 @@ lalamo/speculator/estimator.py,sha256=4D8dPZCWsrpORb7y8pQ6VsiIg1Cblvvxe6gXCoYtcD
 lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
 lalamo/speculator/ngram.py,sha256=95mdfAWhx4d5XOnOwhyhElnvcy6nlUjYhcbJzqDs414,5875
 lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
+lalamo-0.5.16.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.5.16.dist-info/METADATA,sha256=dcs0vT9RULTxt4cxJJmfjP-4UJi7ZkrifXAaSMAgKeU,3147
+lalamo-0.5.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lalamo-0.5.16.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.5.16.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.5.16.dist-info/RECORD,,
{lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/WHEEL
File without changes
{lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/entry_points.txt
File without changes
{lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/licenses/LICENSE
File without changes
{lalamo-0.5.14.dist-info → lalamo-0.5.16.dist-info}/top_level.txt
File without changes