lalamo 0.5.10__tar.gz → 0.5.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.5.10 → lalamo-0.5.11}/PKG-INFO +1 -1
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/__init__.py +1 -1
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +63 -12
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/loaders/huggingface.py +16 -4
- lalamo-0.5.11/lalamo/model_import/model_specs/lfm2.py +31 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo.egg-info/PKG-INFO +1 -1
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_huggingface_model_conversion.py +3 -1
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_lfm2_models.py +2 -3
- lalamo-0.5.10/lalamo/model_import/model_specs/lfm2.py +0 -21
- {lalamo-0.5.10 → lalamo-0.5.11}/LICENSE +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/README.md +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/data/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/data/huggingface_message.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/data/utils.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/main.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/message_processor.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/executorch.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/loaders/executorch.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/essential_ai.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/huggingface.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/llamba.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/mirai.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/models/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/models/classifier.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/models/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/models/language_model.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/activations.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/classifier.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/decoder.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/embedding.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/linear.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/mlp.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/mlx_interop.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/normalization.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/rope.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/attention.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/mamba.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/short_conv.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/transformer.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/transformer_layer.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/modules/utils.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/quantization.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/registry_abc.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/sampling.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/speculator/__init__.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/speculator/common.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/speculator/estimator.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/speculator/inference.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/speculator/ngram.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/speculator/utils.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo/utils.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo.egg-info/SOURCES.txt +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/pyproject.toml +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/setup.cfg +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_cartesia_mlx_models.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_chat_template.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_generation.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_huggingface_models.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_mlx_models.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_model_spec.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_models.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_moe.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_parameter_tree.py +0 -0
- {lalamo-0.5.10 → lalamo-0.5.11}/tests/test_registry_abc.py +0 -0
|
@@ -9,6 +9,8 @@ from lalamo.modules import (
|
|
|
9
9
|
DecoderConfig,
|
|
10
10
|
DenseMLPConfig,
|
|
11
11
|
FullPrecisionLinearConfig,
|
|
12
|
+
MLXQuantizedLinearConfig,
|
|
13
|
+
MLXQuantizedTiedEmbeddingConfig,
|
|
12
14
|
NormalizationConfig,
|
|
13
15
|
SeparableCausalConvConfig,
|
|
14
16
|
ShortConvConfig,
|
|
@@ -20,14 +22,21 @@ from lalamo.modules import (
|
|
|
20
22
|
UntiedEmbeddingConfig,
|
|
21
23
|
UpcastMode,
|
|
22
24
|
)
|
|
25
|
+
from lalamo.quantization import QuantizationMode
|
|
23
26
|
|
|
24
27
|
from .common import HuggingFaceLMConfig
|
|
25
28
|
|
|
26
29
|
|
|
30
|
+
@dataclass(frozen=True)
|
|
31
|
+
class QuantizationConfig:
|
|
32
|
+
group_size: int
|
|
33
|
+
bits: int
|
|
34
|
+
|
|
35
|
+
|
|
27
36
|
@dataclass(frozen=True)
|
|
28
37
|
class HFLFM2Config(HuggingFaceLMConfig):
|
|
29
38
|
architectures: list[Literal["Lfm2ForCausalLM"]]
|
|
30
|
-
block_auto_adjust_ff_dim:
|
|
39
|
+
block_auto_adjust_ff_dim: bool
|
|
31
40
|
block_dim: int
|
|
32
41
|
block_ff_dim: int
|
|
33
42
|
block_ffn_dim_multiplier: float
|
|
@@ -38,16 +47,14 @@ class HFLFM2Config(HuggingFaceLMConfig):
|
|
|
38
47
|
block_use_swiglu: bool
|
|
39
48
|
block_use_xavier_init: bool
|
|
40
49
|
bos_token_id: int
|
|
41
|
-
conv_L_cache: int
|
|
42
|
-
conv_bias:
|
|
50
|
+
conv_L_cache: int # noqa: N815
|
|
51
|
+
conv_bias: bool
|
|
43
52
|
conv_dim: int
|
|
44
53
|
conv_dim_out: int
|
|
45
54
|
conv_use_xavier_init: bool
|
|
46
55
|
eos_token_id: int
|
|
47
56
|
hidden_size: int
|
|
48
57
|
initializer_range: float
|
|
49
|
-
intermediate_size: int
|
|
50
|
-
layer_types: list[Literal["conv", "full_attention"]]
|
|
51
58
|
max_position_embeddings: int
|
|
52
59
|
model_type: Literal["lfm2"]
|
|
53
60
|
norm_eps: float
|
|
@@ -57,14 +64,21 @@ class HFLFM2Config(HuggingFaceLMConfig):
|
|
|
57
64
|
num_key_value_heads: int
|
|
58
65
|
pad_token_id: int
|
|
59
66
|
rope_theta: float
|
|
60
|
-
theta: float
|
|
61
|
-
tie_embedding: bool
|
|
62
67
|
torch_dtype: Literal["bfloat16"]
|
|
63
68
|
transformers_version: str
|
|
64
69
|
use_cache: bool
|
|
65
70
|
use_pos_enc: bool
|
|
66
71
|
vocab_size: int
|
|
67
72
|
|
|
73
|
+
intermediate_size: int | None = None
|
|
74
|
+
layer_types: list[Literal["conv", "full_attention"]] | None = None
|
|
75
|
+
full_attn_idxs: list[int] | None = None
|
|
76
|
+
tie_embedding: bool = True
|
|
77
|
+
theta: float | None = None
|
|
78
|
+
|
|
79
|
+
quantization: QuantizationConfig | None = None
|
|
80
|
+
quantization_config: QuantizationConfig | None = None
|
|
81
|
+
|
|
68
82
|
def to_decoder_config(
|
|
69
83
|
self,
|
|
70
84
|
context_length: int | None,
|
|
@@ -74,7 +88,18 @@ class HFLFM2Config(HuggingFaceLMConfig):
|
|
|
74
88
|
) -> DecoderConfig:
|
|
75
89
|
assert self.num_attention_heads == self.num_heads
|
|
76
90
|
|
|
77
|
-
if self.
|
|
91
|
+
if self.quantization_config is not None:
|
|
92
|
+
assert self.tie_embedding
|
|
93
|
+
|
|
94
|
+
embedding_config = MLXQuantizedTiedEmbeddingConfig(
|
|
95
|
+
input_scale=None,
|
|
96
|
+
logit_soft_cap=None,
|
|
97
|
+
group_size=self.quantization_config.group_size,
|
|
98
|
+
embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
|
|
99
|
+
activation_quantization_mode=None,
|
|
100
|
+
activation_precision=activation_precision,
|
|
101
|
+
)
|
|
102
|
+
elif self.tie_embedding:
|
|
78
103
|
embedding_config = TiedEmbeddingConfig(
|
|
79
104
|
input_scale=None,
|
|
80
105
|
logit_soft_cap=None,
|
|
@@ -93,7 +118,15 @@ class HFLFM2Config(HuggingFaceLMConfig):
|
|
|
93
118
|
max_sequence_length=context_length or self.max_position_embeddings,
|
|
94
119
|
)
|
|
95
120
|
|
|
96
|
-
|
|
121
|
+
if self.quantization_config is None:
|
|
122
|
+
linear_config = FullPrecisionLinearConfig(activation_precision)
|
|
123
|
+
else:
|
|
124
|
+
linear_config = MLXQuantizedLinearConfig(
|
|
125
|
+
group_size=self.quantization_config.group_size,
|
|
126
|
+
weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
|
|
127
|
+
activation_quantization_mode=None,
|
|
128
|
+
activation_precision=activation_precision,
|
|
129
|
+
)
|
|
97
130
|
|
|
98
131
|
block_norm_config = NormalizationConfig(
|
|
99
132
|
scale_precision=activation_precision,
|
|
@@ -123,7 +156,7 @@ class HFLFM2Config(HuggingFaceLMConfig):
|
|
|
123
156
|
|
|
124
157
|
short_conv_config = ShortConvConfig(
|
|
125
158
|
in_projection_config=linear_config,
|
|
126
|
-
conv_config=SeparableCausalConvConfig(activation_precision, has_biases=
|
|
159
|
+
conv_config=SeparableCausalConvConfig(activation_precision, has_biases=self.conv_bias),
|
|
127
160
|
out_projection_config=linear_config,
|
|
128
161
|
kernel_size=self.conv_L_cache,
|
|
129
162
|
)
|
|
@@ -137,6 +170,15 @@ class HFLFM2Config(HuggingFaceLMConfig):
|
|
|
137
170
|
gate_clipping=None,
|
|
138
171
|
)
|
|
139
172
|
|
|
173
|
+
if self.layer_types is not None:
|
|
174
|
+
layer_types = self.layer_types
|
|
175
|
+
elif self.full_attn_idxs is not None:
|
|
176
|
+
layer_types = [
|
|
177
|
+
"full_attention" if i in self.full_attn_idxs else "conv" for i in range(self.num_hidden_layers)
|
|
178
|
+
]
|
|
179
|
+
else:
|
|
180
|
+
raise RuntimeError("Either layer_types or full_attn_idxs must be present.")
|
|
181
|
+
|
|
140
182
|
layer_configs = [
|
|
141
183
|
TransformerLayerConfig(
|
|
142
184
|
pre_mixer_norm_config=block_norm_config,
|
|
@@ -145,7 +187,8 @@ class HFLFM2Config(HuggingFaceLMConfig):
|
|
|
145
187
|
pre_mlp_norm_config=block_norm_config,
|
|
146
188
|
mlp_config=mlp_config,
|
|
147
189
|
post_mlp_norm_config=None,
|
|
148
|
-
)
|
|
190
|
+
)
|
|
191
|
+
for layer_type in layer_types
|
|
149
192
|
]
|
|
150
193
|
|
|
151
194
|
output_norm_config = NormalizationConfig(
|
|
@@ -157,13 +200,21 @@ class HFLFM2Config(HuggingFaceLMConfig):
|
|
|
157
200
|
subtract_mean=False,
|
|
158
201
|
)
|
|
159
202
|
|
|
203
|
+
if self.intermediate_size is not None:
|
|
204
|
+
hidden_dim = self.intermediate_size
|
|
205
|
+
else:
|
|
206
|
+
hidden_dim_adjusted = self.block_ff_dim * self.block_ffn_dim_multiplier * (2 / 3)
|
|
207
|
+
hidden_dim = int(
|
|
208
|
+
(hidden_dim_adjusted + self.block_multiple_of - 1) // self.block_multiple_of * self.block_multiple_of,
|
|
209
|
+
)
|
|
210
|
+
|
|
160
211
|
transformer_config = TransformerConfig(
|
|
161
212
|
global_rope_config=rope_config,
|
|
162
213
|
local_rope_config=None,
|
|
163
214
|
layer_configs=tuple(layer_configs),
|
|
164
215
|
output_norm_config=output_norm_config,
|
|
165
216
|
model_dim=self.hidden_size,
|
|
166
|
-
hidden_dim=
|
|
217
|
+
hidden_dim=hidden_dim,
|
|
167
218
|
context_length=context_length or self.max_position_embeddings,
|
|
168
219
|
)
|
|
169
220
|
|
|
@@ -18,6 +18,7 @@ from lalamo.modules import (
|
|
|
18
18
|
Mamba2Config,
|
|
19
19
|
MLXQuantizedLinear,
|
|
20
20
|
MLXQuantizedTiedEmbedding,
|
|
21
|
+
MLXQuantizedTiedEmbeddingConfig,
|
|
21
22
|
MLXSemiQuantizedUntiedEmbedding,
|
|
22
23
|
Normalization,
|
|
23
24
|
SeparableCausalConv,
|
|
@@ -411,6 +412,7 @@ def _load_conv(
|
|
|
411
412
|
conv_module: SeparableCausalConv,
|
|
412
413
|
weights_dict: Mapping[str, Array],
|
|
413
414
|
path: ParameterPath,
|
|
415
|
+
permute_conv: bool,
|
|
414
416
|
) -> SeparableCausalConv:
|
|
415
417
|
weight_path = path / "conv1d" / "weight"
|
|
416
418
|
if weight_path not in weights_dict:
|
|
@@ -422,6 +424,8 @@ def _load_conv(
|
|
|
422
424
|
|
|
423
425
|
if weight_path is not None:
|
|
424
426
|
raw = weights_dict[weight_path]
|
|
427
|
+
if permute_conv:
|
|
428
|
+
raw = jnp.matrix_transpose(raw)
|
|
425
429
|
conv_weight = raw.squeeze(1) if raw.ndim == 3 else raw
|
|
426
430
|
else:
|
|
427
431
|
conv_weight = conv_module.weights
|
|
@@ -450,10 +454,11 @@ def load_mamba2(
|
|
|
450
454
|
module: Mamba2,
|
|
451
455
|
weights_dict: Mapping[str, Array],
|
|
452
456
|
path: ParameterPath,
|
|
457
|
+
permute_conv: bool,
|
|
453
458
|
) -> Mamba2:
|
|
454
459
|
in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
|
|
455
460
|
out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
|
|
456
|
-
conv = _load_conv(module.conv, weights_dict, path)
|
|
461
|
+
conv = _load_conv(module.conv, weights_dict, path, permute_conv)
|
|
457
462
|
|
|
458
463
|
skip_connection_weight_path = path / "D"
|
|
459
464
|
if skip_connection_weight_path in weights_dict:
|
|
@@ -484,10 +489,11 @@ def load_short_conv(
|
|
|
484
489
|
module: ShortConv,
|
|
485
490
|
weights_dict: Mapping[str, Array],
|
|
486
491
|
path: ParameterPath,
|
|
492
|
+
permute_conv: bool,
|
|
487
493
|
) -> ShortConv:
|
|
488
494
|
in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
|
|
489
495
|
out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
|
|
490
|
-
conv = _load_conv(module.conv, weights_dict, path)
|
|
496
|
+
conv = _load_conv(module.conv, weights_dict, path, permute_conv)
|
|
491
497
|
|
|
492
498
|
return load_parameters(
|
|
493
499
|
lambda m: (m.in_projection, m.out_projection, m.conv),
|
|
@@ -508,6 +514,7 @@ def load_transformer_layer(
|
|
|
508
514
|
up_proj_key: str,
|
|
509
515
|
gate_proj_key: str,
|
|
510
516
|
down_proj_key: str,
|
|
517
|
+
permute_conv: bool,
|
|
511
518
|
) -> TransformerLayer:
|
|
512
519
|
if module.pre_mixer_norm is not None:
|
|
513
520
|
pre_attention_norm = load_rmsnorm(
|
|
@@ -522,9 +529,9 @@ def load_transformer_layer(
|
|
|
522
529
|
if isinstance(module.mixer, Attention):
|
|
523
530
|
mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
|
|
524
531
|
elif isinstance(module.mixer, Mamba2):
|
|
525
|
-
mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
|
|
532
|
+
mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key, permute_conv)
|
|
526
533
|
elif isinstance(module.mixer, ShortConv):
|
|
527
|
-
mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key)
|
|
534
|
+
mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key, permute_conv)
|
|
528
535
|
else:
|
|
529
536
|
mixer = module.mixer
|
|
530
537
|
|
|
@@ -678,6 +685,7 @@ def load_huggingface_decoder(
|
|
|
678
685
|
embedding_path = decoder_path / "embedding"
|
|
679
686
|
pre_mixer_norm_key = "input_layernorm"
|
|
680
687
|
mixer_key = {Mamba2Config: "mixer"}
|
|
688
|
+
permute_conv = False
|
|
681
689
|
pre_mlp_norm_key = "post_attention_layernorm"
|
|
682
690
|
mlp_key = "mlp"
|
|
683
691
|
up_proj_key = "up_proj"
|
|
@@ -691,6 +699,7 @@ def load_huggingface_decoder(
|
|
|
691
699
|
embedding_path = base_path / "embedding.encoder"
|
|
692
700
|
pre_mixer_norm_key = "norm"
|
|
693
701
|
mixer_key = {Mamba2Config: "layer"}
|
|
702
|
+
permute_conv = False
|
|
694
703
|
pre_mlp_norm_key = "norm"
|
|
695
704
|
mlp_key = "layer"
|
|
696
705
|
up_proj_key = "gate_proj"
|
|
@@ -704,6 +713,7 @@ def load_huggingface_decoder(
|
|
|
704
713
|
embedding_path = decoder_path / "embed_tokens"
|
|
705
714
|
pre_mixer_norm_key = "operator_norm"
|
|
706
715
|
mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}
|
|
716
|
+
permute_conv = isinstance(module.config.embedding_config, MLXQuantizedTiedEmbeddingConfig)
|
|
707
717
|
pre_mlp_norm_key = "ffn_norm"
|
|
708
718
|
mlp_key = "feed_forward"
|
|
709
719
|
up_proj_key = "w3"
|
|
@@ -717,6 +727,7 @@ def load_huggingface_decoder(
|
|
|
717
727
|
embedding_path = decoder_path / "embed_tokens"
|
|
718
728
|
pre_mixer_norm_key = "input_layernorm"
|
|
719
729
|
mixer_key = {AttentionConfig: "self_attn"}
|
|
730
|
+
permute_conv = False
|
|
720
731
|
pre_mlp_norm_key = "post_attention_layernorm"
|
|
721
732
|
mlp_key = "mlp"
|
|
722
733
|
up_proj_key = "up_proj"
|
|
@@ -755,6 +766,7 @@ def load_huggingface_decoder(
|
|
|
755
766
|
up_proj_key,
|
|
756
767
|
gate_proj_key,
|
|
757
768
|
down_proj_key,
|
|
769
|
+
permute_conv,
|
|
758
770
|
)
|
|
759
771
|
for i, layer in enumerate(module.transformer.layers)
|
|
760
772
|
)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from lalamo.model_import.decoder_configs import HFLFM2Config
|
|
2
|
+
from lalamo.quantization import QuantizationMode
|
|
3
|
+
|
|
4
|
+
from .common import ConfigMap, FileSpec, ModelSpec
|
|
5
|
+
|
|
6
|
+
__all__ = ["LFM2_MODELS"]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _lfm2_repo(size: str, quantization: QuantizationMode | None) -> tuple[str, str]:
|
|
10
|
+
organization = "LiquidAI" if quantization is None else "mlx-community"
|
|
11
|
+
name = f"LFM2-{size}{f'-{quantization.bits}bit' if quantization is not None else ''}"
|
|
12
|
+
return (organization, name)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
LFM2_MODELS = [
|
|
16
|
+
ModelSpec(
|
|
17
|
+
vendor="LiquidAI",
|
|
18
|
+
family="LFM2",
|
|
19
|
+
name=_lfm2_repo(size, quantization)[1],
|
|
20
|
+
size=size,
|
|
21
|
+
repo="/".join(_lfm2_repo(size, quantization)),
|
|
22
|
+
config_type=HFLFM2Config,
|
|
23
|
+
quantization=quantization,
|
|
24
|
+
configs=ConfigMap(
|
|
25
|
+
chat_template=FileSpec("chat_template.jinja"),
|
|
26
|
+
),
|
|
27
|
+
use_cases=tuple(),
|
|
28
|
+
)
|
|
29
|
+
for size in ["350M", "700M", "1.2B", "2.6B"]
|
|
30
|
+
for quantization in [None, *([QuantizationMode.UINT4, QuantizationMode.UINT8] if size != "2.6B" else [])]
|
|
31
|
+
]
|
|
@@ -14,6 +14,7 @@ from safetensors.flax import save_file
|
|
|
14
14
|
from lalamo.common import flatten_parameters
|
|
15
15
|
from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, import_model
|
|
16
16
|
from lalamo.model_import.model_specs import ModelType
|
|
17
|
+
from lalamo.model_import.model_specs.lfm2 import LFM2_MODELS
|
|
17
18
|
from lalamo.models import ClassifierModelConfig, LanguageModelConfig
|
|
18
19
|
from lalamo.modules import config_converter
|
|
19
20
|
from tests.test_models import DType, ModelTestSpec
|
|
@@ -27,7 +28,8 @@ MODEL_LIST: list[ModelTestSpec] = [
|
|
|
27
28
|
ModelTestSpec("meta-llama/Llama-3.2-1B-Instruct", DType.FLOAT32),
|
|
28
29
|
ModelTestSpec("cartesia-ai/Llamba-1B", DType.FLOAT32),
|
|
29
30
|
ModelTestSpec("cartesia-ai/Llamba-1B-4bit-mlx", DType.FLOAT32),
|
|
30
|
-
]
|
|
31
|
+
] + \
|
|
32
|
+
[ModelTestSpec(model.repo, DType.FLOAT32) for model in LFM2_MODELS]
|
|
31
33
|
|
|
32
34
|
MODEL_LIST += (
|
|
33
35
|
[
|
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
import pytest
|
|
2
2
|
|
|
3
|
+
from lalamo.model_import.model_specs.lfm2 import LFM2_MODELS
|
|
3
4
|
from tests.lfm2_tracer import LFM2DecoderTracer
|
|
4
5
|
from tests.test_models import DType, ModelTestSpec, _test_model
|
|
5
6
|
|
|
6
|
-
MODEL_LIST = [
|
|
7
|
-
ModelTestSpec("LiquidAI/LFM2-2.6B", DType.FLOAT32),
|
|
8
|
-
]
|
|
7
|
+
MODEL_LIST = [ModelTestSpec(model.repo, DType.FLOAT32) for model in LFM2_MODELS if model.quantization is None]
|
|
9
8
|
|
|
10
9
|
|
|
11
10
|
@pytest.mark.parametrize("test_spec", MODEL_LIST, ids=[m.model_repo for m in MODEL_LIST])
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
from lalamo.model_import.decoder_configs import HFLFM2Config
|
|
2
|
-
|
|
3
|
-
from .common import ConfigMap, FileSpec, ModelSpec
|
|
4
|
-
|
|
5
|
-
__all__ = ["LFM2_MODELS"]
|
|
6
|
-
|
|
7
|
-
LFM2_MODELS = [
|
|
8
|
-
ModelSpec(
|
|
9
|
-
vendor="LiquidAI",
|
|
10
|
-
family="LFM2",
|
|
11
|
-
name="LFM2-2.6B",
|
|
12
|
-
size="2.6B",
|
|
13
|
-
repo="LiquidAI/LFM2-2.6B",
|
|
14
|
-
config_type=HFLFM2Config,
|
|
15
|
-
quantization=None,
|
|
16
|
-
configs=ConfigMap(
|
|
17
|
-
chat_template=FileSpec("chat_template.jinja"),
|
|
18
|
-
),
|
|
19
|
-
use_cases=tuple(),
|
|
20
|
-
),
|
|
21
|
-
]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{lalamo-0.5.10 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|