lalamo 0.5.9__tar.gz → 0.5.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. {lalamo-0.5.9 → lalamo-0.5.11}/PKG-INFO +1 -1
  2. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/__init__.py +1 -1
  3. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/__init__.py +2 -0
  4. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  5. lalamo-0.5.11/lalamo/model_import/decoder_configs/huggingface/lfm2.py +225 -0
  6. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/huggingface.py +83 -10
  7. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/__init__.py +2 -0
  8. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/common.py +1 -0
  9. lalamo-0.5.11/lalamo/model_import/model_specs/lfm2.py +31 -0
  10. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/__init__.py +6 -0
  11. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/__init__.py +15 -2
  12. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/common.py +1 -1
  13. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/mamba.py +2 -2
  14. lalamo-0.5.11/lalamo/modules/token_mixers/short_conv.py +168 -0
  15. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/__init__.py +2 -0
  16. lalamo-0.5.11/lalamo/modules/token_mixers/state/short_conv_state.py +33 -0
  17. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/transformer.py +18 -6
  18. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/transformer_layer.py +1 -1
  19. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/PKG-INFO +1 -1
  20. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/SOURCES.txt +5 -0
  21. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_huggingface_model_conversion.py +3 -1
  22. lalamo-0.5.11/tests/test_lfm2_models.py +13 -0
  23. {lalamo-0.5.9 → lalamo-0.5.11}/LICENSE +0 -0
  24. {lalamo-0.5.9 → lalamo-0.5.11}/README.md +0 -0
  25. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/common.py +0 -0
  26. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/data/__init__.py +0 -0
  27. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/data/huggingface_message.py +0 -0
  28. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/data/lalamo_completions.py +0 -0
  29. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/data/utils.py +0 -0
  30. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/main.py +0 -0
  31. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/message_processor.py +0 -0
  32. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/__init__.py +0 -0
  33. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/common.py +0 -0
  34. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/common.py +0 -0
  35. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  36. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  37. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  38. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
  39. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  40. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  41. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  42. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  43. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  44. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  45. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  46. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/huggingface_generation_config.py +0 -0
  47. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  48. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/__init__.py +0 -0
  49. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/common.py +0 -0
  50. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/executorch.py +0 -0
  51. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/utils.py +0 -0
  52. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/deepseek.py +0 -0
  53. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  54. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/gemma.py +0 -0
  55. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  56. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/huggingface.py +0 -0
  57. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/llama.py +0 -0
  58. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/llamba.py +0 -0
  59. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/mirai.py +0 -0
  60. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/mistral.py +0 -0
  61. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/pleias.py +0 -0
  62. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/polaris.py +0 -0
  63. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/qwen.py +0 -0
  64. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/reka.py +0 -0
  65. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/models/__init__.py +0 -0
  66. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/models/classifier.py +0 -0
  67. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/models/common.py +0 -0
  68. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/models/language_model.py +0 -0
  69. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/activations.py +0 -0
  70. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/classifier.py +0 -0
  71. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/common.py +0 -0
  72. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/decoder.py +0 -0
  73. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/embedding.py +0 -0
  74. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/linear.py +0 -0
  75. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/mlp.py +0 -0
  76. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/mlx_interop.py +0 -0
  77. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/normalization.py +0 -0
  78. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/rope.py +0 -0
  79. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/attention.py +0 -0
  80. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/common.py +0 -0
  81. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  82. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  83. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/torch_interop.py +0 -0
  84. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/utils.py +0 -0
  85. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/quantization.py +0 -0
  86. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/registry_abc.py +0 -0
  87. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/sampling.py +0 -0
  88. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/__init__.py +0 -0
  89. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/common.py +0 -0
  90. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/estimator.py +0 -0
  91. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/inference.py +0 -0
  92. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/ngram.py +0 -0
  93. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/speculator/utils.py +0 -0
  94. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo/utils.py +0 -0
  95. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/dependency_links.txt +0 -0
  96. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/entry_points.txt +0 -0
  97. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/requires.txt +0 -0
  98. {lalamo-0.5.9 → lalamo-0.5.11}/lalamo.egg-info/top_level.txt +0 -0
  99. {lalamo-0.5.9 → lalamo-0.5.11}/pyproject.toml +0 -0
  100. {lalamo-0.5.9 → lalamo-0.5.11}/setup.cfg +0 -0
  101. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_cartesia_mlx_models.py +0 -0
  102. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_chat_template.py +0 -0
  103. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_generation.py +0 -0
  104. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_huggingface_models.py +0 -0
  105. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_mlx_models.py +0 -0
  106. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_model_spec.py +0 -0
  107. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_models.py +0 -0
  108. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_moe.py +0 -0
  109. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_parameter_tree.py +0 -0
  110. {lalamo-0.5.9 → lalamo-0.5.11}/tests/test_registry_abc.py +0 -0

{lalamo-0.5.9 → lalamo-0.5.11}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.5.9
+Version: 0.5.11
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/__init__.py
@@ -15,7 +15,7 @@ from lalamo.speculator import (
     SpeculatorTrainingEvent,
 )

-__version__ = "0.5.9"
+__version__ = "0.5.11"

 __all__ = [
     "AssistantMessage",

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/__init__.py
@@ -6,6 +6,7 @@ from .huggingface import (
     HFGemma3Config,
     HFGemma3TextConfig,
     HFGPTOssConfig,
+    HFLFM2Config,
     HFLlamaConfig,
     HFLlambaConfig,
     HFMistralConfig,
@@ -22,6 +23,7 @@ __all__ = [
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",
+    "HFLFM2Config",
     "HFLlamaConfig",
     "HFLlambaConfig",
     "HFMistralConfig",

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/decoder_configs/huggingface/__init__.py
@@ -2,6 +2,7 @@ from .common import HuggingFaceLMConfig
 from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
 from .gpt_oss import HFGPTOssConfig
+from .lfm2 import HFLFM2Config
 from .llama import HFLlamaConfig
 from .llamba import HFLlambaConfig
 from .mistral import HFMistralConfig
@@ -14,6 +15,7 @@ __all__ = [
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",
+    "HFLFM2Config",
     "HFLlamaConfig",
     "HFLlambaConfig",
     "HFMistralConfig",

lalamo-0.5.11/lalamo/model_import/decoder_configs/huggingface/lfm2.py (new file)
@@ -0,0 +1,225 @@
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    AttentionConfig,
+    DecoderConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    MLXQuantizedLinearConfig,
+    MLXQuantizedTiedEmbeddingConfig,
+    NormalizationConfig,
+    SeparableCausalConvConfig,
+    ShortConvConfig,
+    SiLU,
+    TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
+    UnscaledRoPEConfig,
+    UntiedEmbeddingConfig,
+    UpcastMode,
+)
+from lalamo.quantization import QuantizationMode
+
+from .common import HuggingFaceLMConfig
+
+
+@dataclass(frozen=True)
+class QuantizationConfig:
+    group_size: int
+    bits: int
+
+
+@dataclass(frozen=True)
+class HFLFM2Config(HuggingFaceLMConfig):
+    architectures: list[Literal["Lfm2ForCausalLM"]]
+    block_auto_adjust_ff_dim: bool
+    block_dim: int
+    block_ff_dim: int
+    block_ffn_dim_multiplier: float
+    block_mlp_init_scale: float
+    block_multiple_of: int
+    block_norm_eps: float
+    block_out_init_scale: float
+    block_use_swiglu: bool
+    block_use_xavier_init: bool
+    bos_token_id: int
+    conv_L_cache: int  # noqa: N815
+    conv_bias: bool
+    conv_dim: int
+    conv_dim_out: int
+    conv_use_xavier_init: bool
+    eos_token_id: int
+    hidden_size: int
+    initializer_range: float
+    max_position_embeddings: int
+    model_type: Literal["lfm2"]
+    norm_eps: float
+    num_attention_heads: int
+    num_heads: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    pad_token_id: int
+    rope_theta: float
+    torch_dtype: Literal["bfloat16"]
+    transformers_version: str
+    use_cache: bool
+    use_pos_enc: bool
+    vocab_size: int
+
+    intermediate_size: int | None = None
+    layer_types: list[Literal["conv", "full_attention"]] | None = None
+    full_attn_idxs: list[int] | None = None
+    tie_embedding: bool = True
+    theta: float | None = None
+
+    quantization: QuantizationConfig | None = None
+    quantization_config: QuantizationConfig | None = None
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
+    ) -> DecoderConfig:
+        assert self.num_attention_heads == self.num_heads
+
+        if self.quantization_config is not None:
+            assert self.tie_embedding
+
+            embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                group_size=self.quantization_config.group_size,
+                embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        elif self.tie_embedding:
+            embedding_config = TiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+        else:
+            embedding_config = UntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+
+        rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.rope_theta,
+            max_sequence_length=context_length or self.max_position_embeddings,
+        )
+
+        if self.quantization_config is None:
+            linear_config = FullPrecisionLinearConfig(activation_precision)
+        else:
+            linear_config = MLXQuantizedLinearConfig(
+                group_size=self.quantization_config.group_size,
+                weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+
+        block_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.block_norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
+        )
+
+        attention_config = AttentionConfig(
+            qkv_projection_config=linear_config,
+            out_projection_config=linear_config,
+            query_norm_config=block_norm_config,
+            key_norm_config=block_norm_config,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=self.hidden_size // self.num_heads,
+            is_causal=True,
+            scale=None,
+            sliding_window_size=None,
+            logit_soft_cap=None,
+            has_sinks=False,
+            has_qkv_biases=False,
+            has_out_biases=False,
+        )
+
+        short_conv_config = ShortConvConfig(
+            in_projection_config=linear_config,
+            conv_config=SeparableCausalConvConfig(activation_precision, has_biases=self.conv_bias),
+            out_projection_config=linear_config,
+            kernel_size=self.conv_L_cache,
+        )
+
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
+        )
+
+        if self.layer_types is not None:
+            layer_types = self.layer_types
+        elif self.full_attn_idxs is not None:
+            layer_types = [
+                "full_attention" if i in self.full_attn_idxs else "conv" for i in range(self.num_hidden_layers)
+            ]
+        else:
+            raise RuntimeError("Either layer_types or full_attn_idxs must be present.")
+
+        layer_configs = [
+            TransformerLayerConfig(
+                pre_mixer_norm_config=block_norm_config,
+                mixer_config={"conv": short_conv_config, "full_attention": attention_config}[layer_type],
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=block_norm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            for layer_type in layer_types
+        ]
+
+        output_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
+        )
+
+        if self.intermediate_size is not None:
+            hidden_dim = self.intermediate_size
+        else:
+            hidden_dim_adjusted = self.block_ff_dim * self.block_ffn_dim_multiplier * (2 / 3)
+            hidden_dim = int(
+                (hidden_dim_adjusted + self.block_multiple_of - 1) // self.block_multiple_of * self.block_multiple_of,
+            )
+
+        transformer_config = TransformerConfig(
+            global_rope_config=rope_config,
+            local_rope_config=None,
+            layer_configs=tuple(layer_configs),
+            output_norm_config=output_norm_config,
+            model_dim=self.hidden_size,
+            hidden_dim=hidden_dim,
+            context_length=context_length or self.max_position_embeddings,
+        )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
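
The only non-obvious piece of the new config above is the FFN width computation used when a checkpoint omits intermediate_size. Below is a small standalone sketch of that rounding rule, not part of the package; the numeric values are made up for illustration:

# Standalone sketch of the hidden_dim rounding in HFLFM2Config.to_decoder_config above.
# The example numbers are illustrative placeholders, not taken from any real LFM2 checkpoint.

def lfm2_hidden_dim(block_ff_dim: int, block_ffn_dim_multiplier: float, block_multiple_of: int) -> int:
    # Scale the nominal FF width by the multiplier and the usual SwiGLU 2/3 adjustment...
    adjusted = block_ff_dim * block_ffn_dim_multiplier * (2 / 3)
    # ...then round up to the nearest multiple of block_multiple_of.
    return int((adjusted + block_multiple_of - 1) // block_multiple_of * block_multiple_of)

print(lfm2_hidden_dim(block_ff_dim=12288, block_ffn_dim_multiplier=1.0, block_multiple_of=256))  # 8192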

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/loaders/huggingface.py
@@ -8,17 +8,22 @@ from jaxtyping import Array, DTypeLike
 from lalamo.common import ParameterPath
 from lalamo.modules import (
     Attention,
+    AttentionConfig,
     Decoder,
     DenseMLP,
     FullPrecisionLinear,
     GroupQuantizedLinear,
     LinearBase,
     Mamba2,
+    Mamba2Config,
     MLXQuantizedLinear,
     MLXQuantizedTiedEmbedding,
+    MLXQuantizedTiedEmbeddingConfig,
     MLXSemiQuantizedUntiedEmbedding,
     Normalization,
     SeparableCausalConv,
+    ShortConv,
+    ShortConvConfig,
     TiedEmbedding,
     TransformerLayer,
     UntiedEmbedding,
@@ -345,21 +350,42 @@ def load_attention(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> Attention:
+    if (path / "o_proj.weight") in weights_dict:
+        o_proj_name = "o_proj"
+    elif (path / "out_proj.weight") in weights_dict:
+        o_proj_name = "out_proj"
+    else:
+        raise NotImplementedError("Can't determine attention output projection name")
+
     qkv_projection = load_linear(
         module.qkv_projection,
         weights_dict,
         path,
         sublayers_to_fuse=["q_proj", "k_proj", "v_proj"],
     )
-    out_projection = load_linear(module.out_projection, weights_dict, path / "o_proj")
+    out_projection = load_linear(module.out_projection, weights_dict, path / o_proj_name)

     if module.query_norm is not None:
-        query_norm = load_rmsnorm(module.query_norm, weights_dict, path / "q_norm")
+        if (path / "q_norm.weight") in weights_dict:
+            q_norm_name = "q_norm"
+        elif (path / "q_layernorm.weight") in weights_dict:
+            q_norm_name = "q_layernorm"
+        else:
+            raise NotImplementedError("Can't determine attention query projection parameter name")
+
+        query_norm = load_rmsnorm(module.query_norm, weights_dict, path / q_norm_name)
     else:
         query_norm = None

     if module.key_norm is not None:
-        key_norm = load_rmsnorm(module.key_norm, weights_dict, path / "k_norm")
+        if (path / "k_norm.weight") in weights_dict:
+            k_norm_name = "k_norm"
+        elif (path / "k_layernorm.weight") in weights_dict:
+            k_norm_name = "k_layernorm"
+        else:
+            raise NotImplementedError("Can't determine attention key projection parameter name")
+
+        key_norm = load_rmsnorm(module.key_norm, weights_dict, path / k_norm_name)
     else:
         key_norm = None

@@ -382,19 +408,24 @@ def load_attention(
     )


-def _load_mamba_conv(
+def _load_conv(
     conv_module: SeparableCausalConv,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
+    permute_conv: bool,
 ) -> SeparableCausalConv:
     weight_path = path / "conv1d" / "weight"
     if weight_path not in weights_dict:
         weight_path = path / "conv_weight"
+    if weight_path not in weights_dict:
+        weight_path = path / "conv.weight"
     if weight_path not in weights_dict:
         weight_path = None

     if weight_path is not None:
         raw = weights_dict[weight_path]
+        if permute_conv:
+            raw = jnp.matrix_transpose(raw)
         conv_weight = raw.squeeze(1) if raw.ndim == 3 else raw
     else:
         conv_weight = conv_module.weights
@@ -402,6 +433,8 @@ def _load_mamba_conv(
     bias_path = path / "conv1d" / "bias"
     if bias_path not in weights_dict:
         bias_path = path / "conv_bias"
+    if bias_path not in weights_dict:
+        bias_path = path / "conv.bias"
     if bias_path not in weights_dict:
         bias_path = None

@@ -421,10 +454,11 @@ def load_mamba2(
     module: Mamba2,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
+    permute_conv: bool,
 ) -> Mamba2:
     in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
     out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
-    conv = _load_mamba_conv(module.conv, weights_dict, path)
+    conv = _load_conv(module.conv, weights_dict, path, permute_conv)

     skip_connection_weight_path = path / "D"
     if skip_connection_weight_path in weights_dict:
@@ -451,6 +485,23 @@ def load_mamba2(
     )


+def load_short_conv(
+    module: ShortConv,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+    permute_conv: bool,
+) -> ShortConv:
+    in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
+    out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
+    conv = _load_conv(module.conv, weights_dict, path, permute_conv)
+
+    return load_parameters(
+        lambda m: (m.in_projection, m.out_projection, m.conv),
+        module,
+        (in_projection, out_projection, conv),
+    )
+
+
 def load_transformer_layer(
     module: TransformerLayer,
     weights_dict: Mapping[str, Array],
@@ -463,6 +514,7 @@ def load_transformer_layer(
     up_proj_key: str,
     gate_proj_key: str,
     down_proj_key: str,
+    permute_conv: bool,
 ) -> TransformerLayer:
     if module.pre_mixer_norm is not None:
         pre_attention_norm = load_rmsnorm(
@@ -477,7 +529,9 @@ def load_transformer_layer(
     if isinstance(module.mixer, Attention):
         mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
     elif isinstance(module.mixer, Mamba2):
-        mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
+        mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key, permute_conv)
+    elif isinstance(module.mixer, ShortConv):
+        mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key, permute_conv)
     else:
         mixer = module.mixer

@@ -625,11 +679,13 @@ def load_huggingface_decoder(

     is_llamba_full_precision = any(key.startswith("backbone.") for key in weights_dict)
     is_llamba_mlx = any(key.startswith("embedding.encoder.") for key in weights_dict)
+    is_lfm2 = any(key.startswith("model.layers.0.operator_norm.weight") for key in weights_dict)
     if is_llamba_full_precision:
         decoder_path = base_path / "backbone"
         embedding_path = decoder_path / "embedding"
         pre_mixer_norm_key = "input_layernorm"
-        mixer_key = "mixer"
+        mixer_key = {Mamba2Config: "mixer"}
+        permute_conv = False
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -642,7 +698,8 @@ def load_huggingface_decoder(
         decoder_path = base_path / "model"
         embedding_path = base_path / "embedding.encoder"
         pre_mixer_norm_key = "norm"
-        mixer_key = "layer"
+        mixer_key = {Mamba2Config: "layer"}
+        permute_conv = False
         pre_mlp_norm_key = "norm"
         mlp_key = "layer"
         up_proj_key = "gate_proj"
@@ -651,11 +708,26 @@ def load_huggingface_decoder(
         alternating_layers = True
         norm_key = "norm"
         lm_head_path = base_path / "head.linear"
+    elif is_lfm2:
+        decoder_path = base_path / "model"
+        embedding_path = decoder_path / "embed_tokens"
+        pre_mixer_norm_key = "operator_norm"
+        mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}
+        permute_conv = isinstance(module.config.embedding_config, MLXQuantizedTiedEmbeddingConfig)
+        pre_mlp_norm_key = "ffn_norm"
+        mlp_key = "feed_forward"
+        up_proj_key = "w3"
+        gate_proj_key = "w1"
+        down_proj_key = "w2"
+        alternating_layers = False
+        norm_key = "embedding_norm"
+        lm_head_path = base_path / "lm_head"
     else:
         decoder_path = base_path / "model"
         embedding_path = decoder_path / "embed_tokens"
         pre_mixer_norm_key = "input_layernorm"
-        mixer_key = "self_attn"
+        mixer_key = {AttentionConfig: "self_attn"}
+        permute_conv = False
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -687,13 +759,14 @@ def load_huggingface_decoder(
             weights_dict,
             decoder_path / "layers" / ((i * 2) if alternating_layers else i),
             decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-            mixer_key,
+            mixer_key[type(layer.config.mixer_config)],  # type: ignore
             mlp_key,
             pre_mixer_norm_key,
             pre_mlp_norm_key,
            up_proj_key,
            gate_proj_key,
            down_proj_key,
+            permute_conv,
        )
        for i, layer in enumerate(module.transformer.layers)
    )

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/__init__.py
@@ -4,6 +4,7 @@ from .essential_ai import RNJ_MODELS
 from .gemma import GEMMA_MODELS
 from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
+from .lfm2 import LFM2_MODELS
 from .llama import LLAMA_MODELS
 from .llamba import LLAMBA_MODELS
 from .mirai import MIRAI_CLASSIFIER_MODELS
@@ -25,6 +26,7 @@ __all__ = [


 ALL_MODEL_LISTS = [
+    LFM2_MODELS,
     LLAMA_MODELS,
     LLAMBA_MODELS,
     DEEPSEEK_MODELS,

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/model_import/model_specs/common.py
@@ -56,6 +56,7 @@ class WeightsType(Enum):
             yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
+
             from lalamo.modules.torch_interop import torch_to_jax

             torch_weights = torch.load(filename, map_location="cpu", weights_only=True)

lalamo-0.5.11/lalamo/model_import/model_specs/lfm2.py (new file)
@@ -0,0 +1,31 @@
+from lalamo.model_import.decoder_configs import HFLFM2Config
+from lalamo.quantization import QuantizationMode
+
+from .common import ConfigMap, FileSpec, ModelSpec
+
+__all__ = ["LFM2_MODELS"]
+
+
+def _lfm2_repo(size: str, quantization: QuantizationMode | None) -> tuple[str, str]:
+    organization = "LiquidAI" if quantization is None else "mlx-community"
+    name = f"LFM2-{size}{f'-{quantization.bits}bit' if quantization is not None else ''}"
+    return (organization, name)
+
+
+LFM2_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2",
+        name=_lfm2_repo(size, quantization)[1],
+        size=size,
+        repo="/".join(_lfm2_repo(size, quantization)),
+        config_type=HFLFM2Config,
+        quantization=quantization,
+        configs=ConfigMap(
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    )
+    for size in ["350M", "700M", "1.2B", "2.6B"]
+    for quantization in [None, *([QuantizationMode.UINT4, QuantizationMode.UINT8] if size != "2.6B" else [])]
+]
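
The comprehension above registers full-precision repos under LiquidAI and 4/8-bit MLX repos under mlx-community, with the 2.6B size registered unquantized only. A standalone sketch of the naming rule follows; the QuantizationMode enum here is a minimal stand-in carrying only the bit width (the real one lives in lalamo.quantization):

# Illustrative sketch of the repo naming convention encoded by _lfm2_repo above.
from enum import Enum


class QuantizationMode(Enum):  # stand-in, not the lalamo enum
    UINT4 = 4
    UINT8 = 8

    @property
    def bits(self) -> int:
        return self.value


def lfm2_repo(size: str, quantization: QuantizationMode | None) -> str:
    organization = "LiquidAI" if quantization is None else "mlx-community"
    name = f"LFM2-{size}{f'-{quantization.bits}bit' if quantization is not None else ''}"
    return f"{organization}/{name}"


print(lfm2_repo("700M", None))                    # LiquidAI/LFM2-700M
print(lfm2_repo("700M", QuantizationMode.UINT4))  # mlx-community/LFM2-700M-4bit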

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/__init__.py
@@ -69,6 +69,9 @@ from .token_mixers import (
     Mamba2Config,
     SeparableCausalConv,
     SeparableCausalConvConfig,
+    ShortConv,
+    ShortConvConfig,
+    ShortConvStateLayer,
     State,
     StaticKVCacheLayer,
 )
@@ -136,6 +139,9 @@ __all__ = [
     "RoutingFunction",
     "SeparableCausalConv",
     "SeparableCausalConvConfig",
+    "ShortConv",
+    "ShortConvConfig",
+    "ShortConvStateLayer",
     "SiLU",
     "SoftmaxRouting",
     "State",

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/__init__.py
@@ -3,9 +3,18 @@ from lalamo.modules.common import register_config_union
 from .attention import Attention, AttentionConfig, AttentionResult
 from .common import TokenMixerBase, TokenMixerResult
 from .mamba import Mamba2, Mamba2Config, Mamba2Result, SeparableCausalConv, SeparableCausalConvConfig
-from .state import DynamicKVCacheLayer, KVCacheLayer, Mamba2StateLayer, State, StateLayerBase, StaticKVCacheLayer
+from .short_conv import ShortConv, ShortConvConfig, ShortConvResult
+from .state import (
+    DynamicKVCacheLayer,
+    KVCacheLayer,
+    Mamba2StateLayer,
+    ShortConvStateLayer,
+    State,
+    StateLayerBase,
+    StaticKVCacheLayer,
+)

-TokenMixerConfig = AttentionConfig | Mamba2Config
+TokenMixerConfig = AttentionConfig | Mamba2Config | ShortConvConfig

 register_config_union(TokenMixerConfig)  # type: ignore (pyright bug)

@@ -21,6 +30,10 @@ __all__ = [
     "Mamba2StateLayer",
     "SeparableCausalConv",
     "SeparableCausalConvConfig",
+    "ShortConv",
+    "ShortConvConfig",
+    "ShortConvResult",
+    "ShortConvStateLayer",
     "State",
     "StateLayerBase",
     "StaticKVCacheLayer",

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/common.py
@@ -25,7 +25,7 @@ class TokenMixerResult[StateLayerT](NamedTuple):
 class TokenMixerConfigBase(ABC):
     @property
     @abstractmethod
-    def rope_dim(self) -> int: ...
+    def rope_dim(self) -> int | None: ...

     @abstractmethod
     def random_init(

{lalamo-0.5.9 → lalamo-0.5.11}/lalamo/modules/token_mixers/mamba.py
@@ -184,8 +184,8 @@ class Mamba2Config(TokenMixerConfigBase):
         return self.num_heads * self.head_dim

     @property
-    def rope_dim(self) -> int:
-        return self.head_dim
+    def rope_dim(self) -> None:
+        return None

     def random_init(
         self,