lalamo 0.5.9__tar.gz → 0.5.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.5.9 → lalamo-0.5.11}/PKG-INFO +1 -1
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/__init__.py +1 -1
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/__init__.py +2 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo-0.5.11/lalamo/model_import/decoder_configs/huggingface/lfm2.py +225 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/huggingface.py +83 -10
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/__init__.py +2 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/common.py +1 -0
- lalamo-0.5.11/lalamo/model_import/model_specs/lfm2.py +31 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/__init__.py +6 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/__init__.py +15 -2
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/common.py +1 -1
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/mamba.py +2 -2
- lalamo-0.5.11/lalamo/modules/token_mixers/short_conv.py +168 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/__init__.py +2 -0
- lalamo-0.5.11/lalamo/modules/token_mixers/state/short_conv_state.py +33 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/transformer.py +18 -6
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/transformer_layer.py +1 -1
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/PKG-INFO +1 -1
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/SOURCES.txt +5 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_huggingface_model_conversion.py +3 -1
- lalamo-0.5.11/tests/test_lfm2_models.py +13 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/LICENSE +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/README.md +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/data/__init__.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/data/huggingface_message.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/data/utils.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/main.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/message_processor.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/executorch.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/executorch.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/essential_ai.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/huggingface.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/llamba.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/mirai.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/models/__init__.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/models/classifier.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/models/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/models/language_model.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/activations.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/classifier.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/decoder.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/embedding.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/linear.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/mlp.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/mlx_interop.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/normalization.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/rope.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/attention.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/utils.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/quantization.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/registry_abc.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/sampling.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/__init__.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/common.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/estimator.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/inference.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/ngram.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/utils.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/utils.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/pyproject.toml +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/setup.cfg +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_cartesia_mlx_models.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_chat_template.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_generation.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_huggingface_models.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_mlx_models.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_model_spec.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_models.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_moe.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_parameter_tree.py +0 -0
- {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_registry_abc.py +0 -0

lalamo/model_import/decoder_configs/__init__.py

@@ -6,6 +6,7 @@ from .huggingface import (
     HFGemma3Config,
     HFGemma3TextConfig,
     HFGPTOssConfig,
+    HFLFM2Config,
     HFLlamaConfig,
     HFLlambaConfig,
     HFMistralConfig,
@@ -22,6 +23,7 @@ __all__ = [
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",
+    "HFLFM2Config",
     "HFLlamaConfig",
     "HFLlambaConfig",
     "HFMistralConfig",

lalamo/model_import/decoder_configs/huggingface/__init__.py

@@ -2,6 +2,7 @@ from .common import HuggingFaceLMConfig
 from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
 from .gpt_oss import HFGPTOssConfig
+from .lfm2 import HFLFM2Config
 from .llama import HFLlamaConfig
 from .llamba import HFLlambaConfig
 from .mistral import HFMistralConfig
@@ -14,6 +15,7 @@ __all__ = [
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",
+    "HFLFM2Config",
     "HFLlamaConfig",
     "HFLlambaConfig",
     "HFMistralConfig",

lalamo/model_import/decoder_configs/huggingface/lfm2.py (new file)

@@ -0,0 +1,225 @@
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    AttentionConfig,
+    DecoderConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    MLXQuantizedLinearConfig,
+    MLXQuantizedTiedEmbeddingConfig,
+    NormalizationConfig,
+    SeparableCausalConvConfig,
+    ShortConvConfig,
+    SiLU,
+    TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
+    UnscaledRoPEConfig,
+    UntiedEmbeddingConfig,
+    UpcastMode,
+)
+from lalamo.quantization import QuantizationMode
+
+from .common import HuggingFaceLMConfig
+
+
+@dataclass(frozen=True)
+class QuantizationConfig:
+    group_size: int
+    bits: int
+
+
+@dataclass(frozen=True)
+class HFLFM2Config(HuggingFaceLMConfig):
+    architectures: list[Literal["Lfm2ForCausalLM"]]
+    block_auto_adjust_ff_dim: bool
+    block_dim: int
+    block_ff_dim: int
+    block_ffn_dim_multiplier: float
+    block_mlp_init_scale: float
+    block_multiple_of: int
+    block_norm_eps: float
+    block_out_init_scale: float
+    block_use_swiglu: bool
+    block_use_xavier_init: bool
+    bos_token_id: int
+    conv_L_cache: int  # noqa: N815
+    conv_bias: bool
+    conv_dim: int
+    conv_dim_out: int
+    conv_use_xavier_init: bool
+    eos_token_id: int
+    hidden_size: int
+    initializer_range: float
+    max_position_embeddings: int
+    model_type: Literal["lfm2"]
+    norm_eps: float
+    num_attention_heads: int
+    num_heads: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    pad_token_id: int
+    rope_theta: float
+    torch_dtype: Literal["bfloat16"]
+    transformers_version: str
+    use_cache: bool
+    use_pos_enc: bool
+    vocab_size: int
+
+    intermediate_size: int | None = None
+    layer_types: list[Literal["conv", "full_attention"]] | None = None
+    full_attn_idxs: list[int] | None = None
+    tie_embedding: bool = True
+    theta: float | None = None
+
+    quantization: QuantizationConfig | None = None
+    quantization_config: QuantizationConfig | None = None
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
+    ) -> DecoderConfig:
+        assert self.num_attention_heads == self.num_heads
+
+        if self.quantization_config is not None:
+            assert self.tie_embedding
+
+            embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                group_size=self.quantization_config.group_size,
+                embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        elif self.tie_embedding:
+            embedding_config = TiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+        else:
+            embedding_config = UntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+
+        rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.rope_theta,
+            max_sequence_length=context_length or self.max_position_embeddings,
+        )
+
+        if self.quantization_config is None:
+            linear_config = FullPrecisionLinearConfig(activation_precision)
+        else:
+            linear_config = MLXQuantizedLinearConfig(
+                group_size=self.quantization_config.group_size,
+                weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+
+        block_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.block_norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
+        )
+
+        attention_config = AttentionConfig(
+            qkv_projection_config=linear_config,
+            out_projection_config=linear_config,
+            query_norm_config=block_norm_config,
+            key_norm_config=block_norm_config,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=self.hidden_size // self.num_heads,
+            is_causal=True,
+            scale=None,
+            sliding_window_size=None,
+            logit_soft_cap=None,
+            has_sinks=False,
+            has_qkv_biases=False,
+            has_out_biases=False,
+        )
+
+        short_conv_config = ShortConvConfig(
+            in_projection_config=linear_config,
+            conv_config=SeparableCausalConvConfig(activation_precision, has_biases=self.conv_bias),
+            out_projection_config=linear_config,
+            kernel_size=self.conv_L_cache,
+        )
+
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
+        )
+
+        if self.layer_types is not None:
+            layer_types = self.layer_types
+        elif self.full_attn_idxs is not None:
+            layer_types = [
+                "full_attention" if i in self.full_attn_idxs else "conv" for i in range(self.num_hidden_layers)
+            ]
+        else:
+            raise RuntimeError("Either layer_types or full_attn_idxs must be present.")
+
+        layer_configs = [
+            TransformerLayerConfig(
+                pre_mixer_norm_config=block_norm_config,
+                mixer_config={"conv": short_conv_config, "full_attention": attention_config}[layer_type],
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=block_norm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            for layer_type in layer_types
+        ]
+
+        output_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
+        )
+
+        if self.intermediate_size is not None:
+            hidden_dim = self.intermediate_size
+        else:
+            hidden_dim_adjusted = self.block_ff_dim * self.block_ffn_dim_multiplier * (2 / 3)
+            hidden_dim = int(
+                (hidden_dim_adjusted + self.block_multiple_of - 1) // self.block_multiple_of * self.block_multiple_of,
+            )
+
+        transformer_config = TransformerConfig(
+            global_rope_config=rope_config,
+            local_rope_config=None,
+            layer_configs=tuple(layer_configs),
+            output_norm_config=output_norm_config,
+            model_dim=self.hidden_size,
+            hidden_dim=hidden_dim,
+            context_length=context_length or self.max_position_embeddings,
+        )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
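
The only non-obvious part of the new config mapping is the fallback MLP width used when `intermediate_size` is absent from the HuggingFace config: `block_ff_dim` is scaled by `block_ffn_dim_multiplier * 2/3` and rounded up to a multiple of `block_multiple_of`. A standalone sketch of that computation, using hypothetical values rather than ones taken from a released LFM2 checkpoint:

```python
# Fallback hidden_dim computation from HFLFM2Config.to_decoder_config,
# reproduced standalone. The concrete values below are hypothetical.
block_ff_dim = 7168
block_ffn_dim_multiplier = 1.0
block_multiple_of = 256

hidden_dim_adjusted = block_ff_dim * block_ffn_dim_multiplier * (2 / 3)  # ~4778.67
hidden_dim = int(
    (hidden_dim_adjusted + block_multiple_of - 1) // block_multiple_of * block_multiple_of,
)
print(hidden_dim)  # 4864: the next multiple of 256 at or above 4778.67
```
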

lalamo/model_import/loaders/huggingface.py

@@ -8,17 +8,22 @@ from jaxtyping import Array, DTypeLike
 from lalamo.common import ParameterPath
 from lalamo.modules import (
     Attention,
+    AttentionConfig,
     Decoder,
     DenseMLP,
     FullPrecisionLinear,
     GroupQuantizedLinear,
     LinearBase,
     Mamba2,
+    Mamba2Config,
     MLXQuantizedLinear,
     MLXQuantizedTiedEmbedding,
+    MLXQuantizedTiedEmbeddingConfig,
     MLXSemiQuantizedUntiedEmbedding,
     Normalization,
     SeparableCausalConv,
+    ShortConv,
+    ShortConvConfig,
     TiedEmbedding,
     TransformerLayer,
     UntiedEmbedding,
@@ -345,21 +350,42 @@ def load_attention(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> Attention:
+    if (path / "o_proj.weight") in weights_dict:
+        o_proj_name = "o_proj"
+    elif (path / "out_proj.weight") in weights_dict:
+        o_proj_name = "out_proj"
+    else:
+        raise NotImplementedError("Can't determine attention output projection name")
+
     qkv_projection = load_linear(
         module.qkv_projection,
         weights_dict,
         path,
         sublayers_to_fuse=["q_proj", "k_proj", "v_proj"],
     )
-    out_projection = load_linear(module.out_projection, weights_dict, path / "o_proj")
+    out_projection = load_linear(module.out_projection, weights_dict, path / o_proj_name)
 
     if module.query_norm is not None:
-        query_norm = load_rmsnorm(module.query_norm, weights_dict, path / "q_norm")
+        if (path / "q_norm.weight") in weights_dict:
+            q_norm_name = "q_norm"
+        elif (path / "q_layernorm.weight") in weights_dict:
+            q_norm_name = "q_layernorm"
+        else:
+            raise NotImplementedError("Can't determine attention query projection parameter name")
+
+        query_norm = load_rmsnorm(module.query_norm, weights_dict, path / q_norm_name)
     else:
         query_norm = None
 
     if module.key_norm is not None:
-        key_norm = load_rmsnorm(module.key_norm, weights_dict, path / "k_norm")
+        if (path / "k_norm.weight") in weights_dict:
+            k_norm_name = "k_norm"
+        elif (path / "k_layernorm.weight") in weights_dict:
+            k_norm_name = "k_layernorm"
+        else:
+            raise NotImplementedError("Can't determine attention key projection parameter name")
+
+        key_norm = load_rmsnorm(module.key_norm, weights_dict, path / k_norm_name)
     else:
         key_norm = None
 
@@ -382,19 +408,24 @@ def load_attention(
 )
 
 
-def _load_mamba_conv(
+def _load_conv(
     conv_module: SeparableCausalConv,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
+    permute_conv: bool,
 ) -> SeparableCausalConv:
     weight_path = path / "conv1d" / "weight"
     if weight_path not in weights_dict:
         weight_path = path / "conv_weight"
+    if weight_path not in weights_dict:
+        weight_path = path / "conv.weight"
     if weight_path not in weights_dict:
         weight_path = None
 
     if weight_path is not None:
         raw = weights_dict[weight_path]
+        if permute_conv:
+            raw = jnp.matrix_transpose(raw)
         conv_weight = raw.squeeze(1) if raw.ndim == 3 else raw
     else:
         conv_weight = conv_module.weights
@@ -402,6 +433,8 @@ def _load_mamba_conv(
     bias_path = path / "conv1d" / "bias"
     if bias_path not in weights_dict:
         bias_path = path / "conv_bias"
+    if bias_path not in weights_dict:
+        bias_path = path / "conv.bias"
     if bias_path not in weights_dict:
         bias_path = None
 
@@ -421,10 +454,11 @@ def load_mamba2(
     module: Mamba2,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
+    permute_conv: bool,
 ) -> Mamba2:
     in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
     out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
-    conv = _load_mamba_conv(module.conv, weights_dict, path)
+    conv = _load_conv(module.conv, weights_dict, path, permute_conv)
 
     skip_connection_weight_path = path / "D"
     if skip_connection_weight_path in weights_dict:
@@ -451,6 +485,23 @@ def load_mamba2(
     )
 
 
+def load_short_conv(
+    module: ShortConv,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+    permute_conv: bool,
+) -> ShortConv:
+    in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
+    out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
+    conv = _load_conv(module.conv, weights_dict, path, permute_conv)
+
+    return load_parameters(
+        lambda m: (m.in_projection, m.out_projection, m.conv),
+        module,
+        (in_projection, out_projection, conv),
+    )
+
+
 def load_transformer_layer(
     module: TransformerLayer,
     weights_dict: Mapping[str, Array],
@@ -463,6 +514,7 @@ def load_transformer_layer(
     up_proj_key: str,
     gate_proj_key: str,
     down_proj_key: str,
+    permute_conv: bool,
 ) -> TransformerLayer:
     if module.pre_mixer_norm is not None:
         pre_attention_norm = load_rmsnorm(
@@ -477,7 +529,9 @@ def load_transformer_layer(
     if isinstance(module.mixer, Attention):
         mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
     elif isinstance(module.mixer, Mamba2):
-        mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
+        mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key, permute_conv)
+    elif isinstance(module.mixer, ShortConv):
+        mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key, permute_conv)
     else:
         mixer = module.mixer
 
@@ -625,11 +679,13 @@ def load_huggingface_decoder(
 
     is_llamba_full_precision = any(key.startswith("backbone.") for key in weights_dict)
     is_llamba_mlx = any(key.startswith("embedding.encoder.") for key in weights_dict)
+    is_lfm2 = any(key.startswith("model.layers.0.operator_norm.weight") for key in weights_dict)
     if is_llamba_full_precision:
         decoder_path = base_path / "backbone"
         embedding_path = decoder_path / "embedding"
         pre_mixer_norm_key = "input_layernorm"
-        mixer_key = "mixer"
+        mixer_key = {Mamba2Config: "mixer"}
+        permute_conv = False
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -642,7 +698,8 @@ def load_huggingface_decoder(
         decoder_path = base_path / "model"
         embedding_path = base_path / "embedding.encoder"
         pre_mixer_norm_key = "norm"
-        mixer_key = "layer"
+        mixer_key = {Mamba2Config: "layer"}
+        permute_conv = False
         pre_mlp_norm_key = "norm"
         mlp_key = "layer"
         up_proj_key = "gate_proj"
@@ -651,11 +708,26 @@ def load_huggingface_decoder(
         alternating_layers = True
         norm_key = "norm"
         lm_head_path = base_path / "head.linear"
+    elif is_lfm2:
+        decoder_path = base_path / "model"
+        embedding_path = decoder_path / "embed_tokens"
+        pre_mixer_norm_key = "operator_norm"
+        mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}
+        permute_conv = isinstance(module.config.embedding_config, MLXQuantizedTiedEmbeddingConfig)
+        pre_mlp_norm_key = "ffn_norm"
+        mlp_key = "feed_forward"
+        up_proj_key = "w3"
+        gate_proj_key = "w1"
+        down_proj_key = "w2"
+        alternating_layers = False
+        norm_key = "embedding_norm"
+        lm_head_path = base_path / "lm_head"
     else:
         decoder_path = base_path / "model"
         embedding_path = decoder_path / "embed_tokens"
         pre_mixer_norm_key = "input_layernorm"
-        mixer_key = "self_attn"
+        mixer_key = {AttentionConfig: "self_attn"}
+        permute_conv = False
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -687,13 +759,14 @@ def load_huggingface_decoder(
             weights_dict,
             decoder_path / "layers" / ((i * 2) if alternating_layers else i),
             decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-            mixer_key,
+            mixer_key[type(layer.config.mixer_config)],  # type: ignore
             mlp_key,
             pre_mixer_norm_key,
             pre_mlp_norm_key,
             up_proj_key,
             gate_proj_key,
             down_proj_key,
+            permute_conv,
         )
         for i, layer in enumerate(module.transformer.layers)
     )
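
Because LFM2 interleaves short-convolution and attention blocks in the same checkpoint, the loader's `mixer_key` changes from a single string into a mapping from mixer config type to HuggingFace submodule name. A minimal sketch of how that per-layer lookup resolves, using stand-in classes rather than the real lalamo config types:

```python
# Stand-ins for the real lalamo config classes, for illustration only.
class AttentionConfig: ...
class ShortConvConfig: ...

# As in the LFM2 branch above: conv mixers live under "conv",
# attention mixers under "self_attn".
mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}

layer_mixer_configs = [ShortConvConfig(), ShortConvConfig(), AttentionConfig()]
print([mixer_key[type(cfg)] for cfg in layer_mixer_configs])
# ['conv', 'conv', 'self_attn']
```
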

lalamo/model_import/model_specs/__init__.py

@@ -4,6 +4,7 @@ from .essential_ai import RNJ_MODELS
 from .gemma import GEMMA_MODELS
 from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
+from .lfm2 import LFM2_MODELS
 from .llama import LLAMA_MODELS
 from .llamba import LLAMBA_MODELS
 from .mirai import MIRAI_CLASSIFIER_MODELS
@@ -25,6 +26,7 @@ __all__ = [
 
 
 ALL_MODEL_LISTS = [
+    LFM2_MODELS,
     LLAMA_MODELS,
     LLAMBA_MODELS,
     DEEPSEEK_MODELS,

lalamo/model_import/model_specs/common.py

@@ -56,6 +56,7 @@ class WeightsType(Enum):
             yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
+
             from lalamo.modules.torch_interop import torch_to_jax
 
             torch_weights = torch.load(filename, map_location="cpu", weights_only=True)

lalamo/model_import/model_specs/lfm2.py (new file)

@@ -0,0 +1,31 @@
+from lalamo.model_import.decoder_configs import HFLFM2Config
+from lalamo.quantization import QuantizationMode
+
+from .common import ConfigMap, FileSpec, ModelSpec
+
+__all__ = ["LFM2_MODELS"]
+
+
+def _lfm2_repo(size: str, quantization: QuantizationMode | None) -> tuple[str, str]:
+    organization = "LiquidAI" if quantization is None else "mlx-community"
+    name = f"LFM2-{size}{f'-{quantization.bits}bit' if quantization is not None else ''}"
+    return (organization, name)
+
+
+LFM2_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2",
+        name=_lfm2_repo(size, quantization)[1],
+        size=size,
+        repo="/".join(_lfm2_repo(size, quantization)),
+        config_type=HFLFM2Config,
+        quantization=quantization,
+        configs=ConfigMap(
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    )
+    for size in ["350M", "700M", "1.2B", "2.6B"]
+    for quantization in [None, *([QuantizationMode.UINT4, QuantizationMode.UINT8] if size != "2.6B" else [])]
+]
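
The comprehension above yields one full-precision spec per size plus 4-bit and 8-bit MLX variants for every size except 2.6B. A quick sketch of enumerating the resulting repos, assuming `LFM2_MODELS` is importable from `lalamo.model_import.model_specs` as the updated `__init__.py` suggests, and that `QuantizationMode.UINT4`/`UINT8` report `bits` values of 4 and 8:

```python
# Illustrative only: list the HuggingFace repos the new LFM2 specs point at.
from lalamo.model_import.model_specs import LFM2_MODELS

for spec in LFM2_MODELS:
    print(spec.repo)
# Expected pattern: "LiquidAI/LFM2-350M", "mlx-community/LFM2-350M-4bit",
# "mlx-community/LFM2-350M-8bit", ..., "LiquidAI/LFM2-2.6B"
```
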

lalamo/modules/__init__.py

@@ -69,6 +69,9 @@ from .token_mixers import (
     Mamba2Config,
     SeparableCausalConv,
     SeparableCausalConvConfig,
+    ShortConv,
+    ShortConvConfig,
+    ShortConvStateLayer,
     State,
     StaticKVCacheLayer,
 )
@@ -136,6 +139,9 @@ __all__ = [
     "RoutingFunction",
     "SeparableCausalConv",
     "SeparableCausalConvConfig",
+    "ShortConv",
+    "ShortConvConfig",
+    "ShortConvStateLayer",
     "SiLU",
     "SoftmaxRouting",
     "State",

lalamo/modules/token_mixers/__init__.py

@@ -3,9 +3,18 @@ from lalamo.modules.common import register_config_union
 from .attention import Attention, AttentionConfig, AttentionResult
 from .common import TokenMixerBase, TokenMixerResult
 from .mamba import Mamba2, Mamba2Config, Mamba2Result, SeparableCausalConv, SeparableCausalConvConfig
-from .
+from .short_conv import ShortConv, ShortConvConfig, ShortConvResult
+from .state import (
+    DynamicKVCacheLayer,
+    KVCacheLayer,
+    Mamba2StateLayer,
+    ShortConvStateLayer,
+    State,
+    StateLayerBase,
+    StaticKVCacheLayer,
+)
 
-TokenMixerConfig = AttentionConfig | Mamba2Config
+TokenMixerConfig = AttentionConfig | Mamba2Config | ShortConvConfig
 
 register_config_union(TokenMixerConfig)  # type: ignore (pyright bug)
 
@@ -21,6 +30,10 @@ __all__ = [
     "Mamba2StateLayer",
     "SeparableCausalConv",
     "SeparableCausalConvConfig",
+    "ShortConv",
+    "ShortConvConfig",
+    "ShortConvResult",
+    "ShortConvStateLayer",
     "State",
     "StateLayerBase",
     "StaticKVCacheLayer",