lalamo 0.5.10__tar.gz → 0.5.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. {lalamo-0.5.10 → lalamo-0.5.12}/PKG-INFO +1 -1
  2. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/__init__.py +1 -1
  3. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +1 -1
  4. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +63 -12
  5. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/loaders/huggingface.py +18 -6
  6. lalamo-0.5.12/lalamo/model_import/model_specs/lfm2.py +31 -0
  7. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo.egg-info/PKG-INFO +1 -1
  8. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_huggingface_model_conversion.py +5 -1
  9. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_lfm2_models.py +2 -3
  10. lalamo-0.5.10/lalamo/model_import/model_specs/lfm2.py +0 -21
  11. {lalamo-0.5.10 → lalamo-0.5.12}/LICENSE +0 -0
  12. {lalamo-0.5.10 → lalamo-0.5.12}/README.md +0 -0
  13. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/common.py +0 -0
  14. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/data/__init__.py +0 -0
  15. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/data/huggingface_message.py +0 -0
  16. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/data/lalamo_completions.py +0 -0
  17. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/data/utils.py +0 -0
  18. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/main.py +0 -0
  19. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/message_processor.py +0 -0
  20. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/__init__.py +0 -0
  21. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/common.py +0 -0
  22. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  23. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/common.py +0 -0
  24. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  25. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  26. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  27. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  28. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  29. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  30. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  31. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  32. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  33. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  34. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  35. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/huggingface_generation_config.py +0 -0
  36. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  37. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/loaders/__init__.py +0 -0
  38. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/loaders/common.py +0 -0
  39. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/loaders/executorch.py +0 -0
  40. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/loaders/utils.py +0 -0
  41. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/__init__.py +0 -0
  42. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/common.py +0 -0
  43. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/deepseek.py +0 -0
  44. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  45. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/gemma.py +0 -0
  46. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  47. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/huggingface.py +0 -0
  48. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/llama.py +0 -0
  49. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/llamba.py +0 -0
  50. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/mirai.py +0 -0
  51. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/mistral.py +0 -0
  52. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/pleias.py +0 -0
  53. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/polaris.py +0 -0
  54. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/qwen.py +0 -0
  55. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/model_import/model_specs/reka.py +0 -0
  56. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/models/__init__.py +0 -0
  57. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/models/classifier.py +0 -0
  58. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/models/common.py +0 -0
  59. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/models/language_model.py +0 -0
  60. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/__init__.py +0 -0
  61. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/activations.py +0 -0
  62. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/classifier.py +0 -0
  63. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/common.py +0 -0
  64. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/decoder.py +0 -0
  65. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/embedding.py +0 -0
  66. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/linear.py +0 -0
  67. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/mlp.py +0 -0
  68. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/mlx_interop.py +0 -0
  69. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/normalization.py +0 -0
  70. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/rope.py +0 -0
  71. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/__init__.py +0 -0
  72. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/attention.py +0 -0
  73. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/common.py +0 -0
  74. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/mamba.py +0 -0
  75. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/short_conv.py +0 -0
  76. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  77. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/state/common.py +0 -0
  78. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  79. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  80. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
  81. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/torch_interop.py +0 -0
  82. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/transformer.py +0 -0
  83. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/transformer_layer.py +0 -0
  84. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/modules/utils.py +0 -0
  85. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/quantization.py +0 -0
  86. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/registry_abc.py +0 -0
  87. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/sampling.py +0 -0
  88. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/speculator/__init__.py +0 -0
  89. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/speculator/common.py +0 -0
  90. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/speculator/estimator.py +0 -0
  91. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/speculator/inference.py +0 -0
  92. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/speculator/ngram.py +0 -0
  93. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/speculator/utils.py +0 -0
  94. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo/utils.py +0 -0
  95. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo.egg-info/SOURCES.txt +0 -0
  96. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo.egg-info/dependency_links.txt +0 -0
  97. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo.egg-info/entry_points.txt +0 -0
  98. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo.egg-info/requires.txt +0 -0
  99. {lalamo-0.5.10 → lalamo-0.5.12}/lalamo.egg-info/top_level.txt +0 -0
  100. {lalamo-0.5.10 → lalamo-0.5.12}/pyproject.toml +0 -0
  101. {lalamo-0.5.10 → lalamo-0.5.12}/setup.cfg +0 -0
  102. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_cartesia_mlx_models.py +0 -0
  103. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_chat_template.py +0 -0
  104. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_generation.py +0 -0
  105. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_huggingface_models.py +0 -0
  106. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_mlx_models.py +0 -0
  107. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_model_spec.py +0 -0
  108. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_models.py +0 -0
  109. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_moe.py +0 -0
  110. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_parameter_tree.py +0 -0
  111. {lalamo-0.5.10 → lalamo-0.5.12}/tests/test_registry_abc.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.5.10
+Version: 0.5.12
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
@@ -15,7 +15,7 @@ from lalamo.speculator import (
     SpeculatorTrainingEvent,
 )
 
-__version__ = "0.5.10"
+__version__ = "0.5.12"
 
 __all__ = [
     "AssistantMessage",
@@ -46,7 +46,6 @@ class HFGemma3TextConfigRaw:
     model_type: Literal["gemma3_text"]
     num_hidden_layers: int
     sliding_window: int
-    sliding_window_pattern: int
     rms_norm_eps: float = 1e-06
     query_pre_attn_scalar: float = 256.0
     attention_bias: bool = False
@@ -55,6 +54,7 @@ class HFGemma3TextConfigRaw:
     attn_logit_softcapping: float | None = None
     head_dim: int = 256
     max_position_embeddings: int = 131072
+    sliding_window_pattern: int = 6
     rope_theta: float = 1000000.0
     rope_local_base_freq: float = 10000.0
     rope_scaling: GemmaRoPEScalingConfig | YarnRopeScalingConfig | None = None
@@ -9,6 +9,8 @@ from lalamo.modules import (
     DecoderConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
+    MLXQuantizedLinearConfig,
+    MLXQuantizedTiedEmbeddingConfig,
     NormalizationConfig,
     SeparableCausalConvConfig,
     ShortConvConfig,
@@ -20,14 +22,21 @@ from lalamo.modules import (
     UntiedEmbeddingConfig,
     UpcastMode,
 )
+from lalamo.quantization import QuantizationMode
 
 from .common import HuggingFaceLMConfig
 
 
+@dataclass(frozen=True)
+class QuantizationConfig:
+    group_size: int
+    bits: int
+
+
 @dataclass(frozen=True)
 class HFLFM2Config(HuggingFaceLMConfig):
     architectures: list[Literal["Lfm2ForCausalLM"]]
-    block_auto_adjust_ff_dim: Literal[False]
+    block_auto_adjust_ff_dim: bool
     block_dim: int
     block_ff_dim: int
     block_ffn_dim_multiplier: float
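
Note on the new QuantizationConfig: it mirrors the group_size/bits block that MLX-quantized checkpoints typically carry in config.json (field names assumed here from mlx-community conventions; the actual JSON is not shown in this diff). A minimal sketch of the mapping:

# Hypothetical config.json fragment from an MLX-quantized LFM2 checkpoint;
# the values are illustrative, not taken from a real model.
raw = {"quantization_config": {"group_size": 64, "bits": 4}}
# What HFLFM2Config.quantization_config would then hold after parsing:
parsed = QuantizationConfig(group_size=64, bits=4)
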
@@ -38,16 +47,14 @@ class HFLFM2Config(HuggingFaceLMConfig):
     block_use_swiglu: bool
     block_use_xavier_init: bool
     bos_token_id: int
-    conv_L_cache: int  # noqa: N815
-    conv_bias: int
+    conv_L_cache: int  # noqa: N815
+    conv_bias: bool
     conv_dim: int
     conv_dim_out: int
     conv_use_xavier_init: bool
     eos_token_id: int
     hidden_size: int
     initializer_range: float
-    intermediate_size: int
-    layer_types: list[Literal["conv", "full_attention"]]
     max_position_embeddings: int
     model_type: Literal["lfm2"]
     norm_eps: float
@@ -57,14 +64,21 @@ class HFLFM2Config(HuggingFaceLMConfig):
     num_key_value_heads: int
     pad_token_id: int
     rope_theta: float
-    theta: float
-    tie_embedding: bool
     torch_dtype: Literal["bfloat16"]
     transformers_version: str
     use_cache: bool
     use_pos_enc: bool
     vocab_size: int
 
+    intermediate_size: int | None = None
+    layer_types: list[Literal["conv", "full_attention"]] | None = None
+    full_attn_idxs: list[int] | None = None
+    tie_embedding: bool = True
+    theta: float | None = None
+
+    quantization: QuantizationConfig | None = None
+    quantization_config: QuantizationConfig | None = None
+
     def to_decoder_config(
         self,
         context_length: int | None,
@@ -74,7 +88,18 @@ class HFLFM2Config(HuggingFaceLMConfig):
     ) -> DecoderConfig:
         assert self.num_attention_heads == self.num_heads
 
-        if self.tie_embedding:
+        if self.quantization_config is not None:
+            assert self.tie_embedding
+
+            embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                group_size=self.quantization_config.group_size,
+                embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        elif self.tie_embedding:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
                 logit_soft_cap=None,
@@ -93,7 +118,15 @@ class HFLFM2Config(HuggingFaceLMConfig):
             max_sequence_length=context_length or self.max_position_embeddings,
         )
 
-        linear_config = FullPrecisionLinearConfig(activation_precision)
+        if self.quantization_config is None:
+            linear_config = FullPrecisionLinearConfig(activation_precision)
+        else:
+            linear_config = MLXQuantizedLinearConfig(
+                group_size=self.quantization_config.group_size,
+                weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
 
         block_norm_config = NormalizationConfig(
             scale_precision=activation_precision,
@@ -123,7 +156,7 @@
 
         short_conv_config = ShortConvConfig(
             in_projection_config=linear_config,
-            conv_config=SeparableCausalConvConfig(activation_precision, has_biases=False),
+            conv_config=SeparableCausalConvConfig(activation_precision, has_biases=self.conv_bias),
             out_projection_config=linear_config,
             kernel_size=self.conv_L_cache,
         )
@@ -137,6 +170,15 @@ class HFLFM2Config(HuggingFaceLMConfig):
             gate_clipping=None,
         )
 
+        if self.layer_types is not None:
+            layer_types = self.layer_types
+        elif self.full_attn_idxs is not None:
+            layer_types = [
+                "full_attention" if i in self.full_attn_idxs else "conv" for i in range(self.num_hidden_layers)
+            ]
+        else:
+            raise RuntimeError("Either layer_types or full_attn_idxs must be present.")
+
         layer_configs = [
             TransformerLayerConfig(
                 pre_mixer_norm_config=block_norm_config,
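
Note on the layer_types fallback added above: some LFM2 configs expose full_attn_idxs instead of layer_types, and the comprehension expands that index list into one entry per layer. A small worked example with hypothetical values (not taken from a specific checkpoint):

num_hidden_layers = 6
full_attn_idxs = [2, 5]
layer_types = [
    "full_attention" if i in full_attn_idxs else "conv" for i in range(num_hidden_layers)
]
# -> ["conv", "conv", "full_attention", "conv", "conv", "full_attention"]
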
@@ -145,7 +187,8 @@ class HFLFM2Config(HuggingFaceLMConfig):
                 pre_mlp_norm_config=block_norm_config,
                 mlp_config=mlp_config,
                 post_mlp_norm_config=None,
-            ) for layer_type in self.layer_types
+            )
+            for layer_type in layer_types
         ]
 
         output_norm_config = NormalizationConfig(
@@ -157,13 +200,21 @@ class HFLFM2Config(HuggingFaceLMConfig):
             subtract_mean=False,
         )
 
+        if self.intermediate_size is not None:
+            hidden_dim = self.intermediate_size
+        else:
+            hidden_dim_adjusted = self.block_ff_dim * self.block_ffn_dim_multiplier * (2 / 3)
+            hidden_dim = int(
+                (hidden_dim_adjusted + self.block_multiple_of - 1) // self.block_multiple_of * self.block_multiple_of,
+            )
+
         transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=output_norm_config,
             model_dim=self.hidden_size,
-            hidden_dim=self.intermediate_size,
+            hidden_dim=hidden_dim,
             context_length=context_length or self.max_position_embeddings,
         )
 
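
Note on the hidden_dim fallback above: when intermediate_size is absent, the config applies the usual SwiGLU sizing rule of scaling block_ff_dim by block_ffn_dim_multiplier and 2/3, then rounding up to a multiple of block_multiple_of. A worked example with hypothetical values (not taken from a specific LFM2 checkpoint):

block_ff_dim = 7168
block_ffn_dim_multiplier = 1.0
block_multiple_of = 256

hidden_dim_adjusted = block_ff_dim * block_ffn_dim_multiplier * (2 / 3)  # 4778.67
hidden_dim = int((hidden_dim_adjusted + block_multiple_of - 1) // block_multiple_of * block_multiple_of)
# -> 4864, i.e. 4778.67 rounded up to the next multiple of 256
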
@@ -18,6 +18,7 @@ from lalamo.modules import (
     Mamba2Config,
     MLXQuantizedLinear,
     MLXQuantizedTiedEmbedding,
+    MLXQuantizedTiedEmbeddingConfig,
     MLXSemiQuantizedUntiedEmbedding,
     Normalization,
     SeparableCausalConv,
@@ -349,9 +350,9 @@ def load_attention(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> Attention:
-    if (path / "o_proj.weight") in weights_dict:
+    if (path / "o_proj.weight") in weights_dict or (path / "o_proj.qweight") in weights_dict:
         o_proj_name = "o_proj"
-    elif (path / "out_proj.weight") in weights_dict:
+    elif (path / "out_proj.weight") in weights_dict or (path / "out_proj.qweight") in weights_dict:
         o_proj_name = "out_proj"
     else:
         raise NotImplementedError("Can't determine attention output projection name")
@@ -411,6 +412,7 @@ def _load_conv(
     conv_module: SeparableCausalConv,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
+    permute_conv: bool,
 ) -> SeparableCausalConv:
     weight_path = path / "conv1d" / "weight"
     if weight_path not in weights_dict:
@@ -422,6 +424,8 @@
 
     if weight_path is not None:
         raw = weights_dict[weight_path]
+        if permute_conv:
+            raw = jnp.matrix_transpose(raw)
         conv_weight = raw.squeeze(1) if raw.ndim == 3 else raw
     else:
         conv_weight = conv_module.weights
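
Note on permute_conv above: jnp.matrix_transpose swaps only the last two axes, which is used here to flip the short-conv kernel layout of MLX-exported weights before the usual squeeze. A minimal sketch with hypothetical shapes (the actual checkpoint layouts are not visible in this diff):

import jax.numpy as jnp

raw = jnp.zeros((3, 1024))           # e.g. a kernel stored as (kernel_size, channels)
flipped = jnp.matrix_transpose(raw)  # shape (1024, 3): last two axes swapped
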
@@ -450,10 +454,11 @@ def load_mamba2(
     module: Mamba2,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
+    permute_conv: bool,
 ) -> Mamba2:
     in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
     out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
-    conv = _load_conv(module.conv, weights_dict, path)
+    conv = _load_conv(module.conv, weights_dict, path, permute_conv)
 
     skip_connection_weight_path = path / "D"
     if skip_connection_weight_path in weights_dict:
@@ -484,10 +489,11 @@ def load_short_conv(
     module: ShortConv,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
+    permute_conv: bool,
 ) -> ShortConv:
     in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
     out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
-    conv = _load_conv(module.conv, weights_dict, path)
+    conv = _load_conv(module.conv, weights_dict, path, permute_conv)
 
     return load_parameters(
         lambda m: (m.in_projection, m.out_projection, m.conv),
@@ -508,6 +514,7 @@ def load_transformer_layer(
     up_proj_key: str,
     gate_proj_key: str,
     down_proj_key: str,
+    permute_conv: bool,
 ) -> TransformerLayer:
     if module.pre_mixer_norm is not None:
         pre_attention_norm = load_rmsnorm(
@@ -522,9 +529,9 @@
     if isinstance(module.mixer, Attention):
         mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
     elif isinstance(module.mixer, Mamba2):
-        mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
+        mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key, permute_conv)
     elif isinstance(module.mixer, ShortConv):
-        mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key)
+        mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key, permute_conv)
     else:
         mixer = module.mixer
 
@@ -678,6 +685,7 @@ def load_huggingface_decoder(
         embedding_path = decoder_path / "embedding"
         pre_mixer_norm_key = "input_layernorm"
        mixer_key = {Mamba2Config: "mixer"}
+        permute_conv = False
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -691,6 +699,7 @@ def load_huggingface_decoder(
         embedding_path = base_path / "embedding.encoder"
         pre_mixer_norm_key = "norm"
         mixer_key = {Mamba2Config: "layer"}
+        permute_conv = False
         pre_mlp_norm_key = "norm"
         mlp_key = "layer"
         up_proj_key = "gate_proj"
@@ -704,6 +713,7 @@ def load_huggingface_decoder(
         embedding_path = decoder_path / "embed_tokens"
         pre_mixer_norm_key = "operator_norm"
         mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}
+        permute_conv = isinstance(module.config.embedding_config, MLXQuantizedTiedEmbeddingConfig)
         pre_mlp_norm_key = "ffn_norm"
         mlp_key = "feed_forward"
         up_proj_key = "w3"
@@ -717,6 +727,7 @@ def load_huggingface_decoder(
         embedding_path = decoder_path / "embed_tokens"
         pre_mixer_norm_key = "input_layernorm"
         mixer_key = {AttentionConfig: "self_attn"}
+        permute_conv = False
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -755,6 +766,7 @@ def load_huggingface_decoder(
             up_proj_key,
             gate_proj_key,
             down_proj_key,
+            permute_conv,
         )
         for i, layer in enumerate(module.transformer.layers)
     )
@@ -0,0 +1,31 @@
+from lalamo.model_import.decoder_configs import HFLFM2Config
+from lalamo.quantization import QuantizationMode
+
+from .common import ConfigMap, FileSpec, ModelSpec
+
+__all__ = ["LFM2_MODELS"]
+
+
+def _lfm2_repo(size: str, quantization: QuantizationMode | None) -> tuple[str, str]:
+    organization = "LiquidAI" if quantization is None else "mlx-community"
+    name = f"LFM2-{size}{f'-{quantization.bits}bit' if quantization is not None else ''}"
+    return (organization, name)
+
+
+LFM2_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2",
+        name=_lfm2_repo(size, quantization)[1],
+        size=size,
+        repo="/".join(_lfm2_repo(size, quantization)),
+        config_type=HFLFM2Config,
+        quantization=quantization,
+        configs=ConfigMap(
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    )
+    for size in ["350M", "700M", "1.2B", "2.6B"]
+    for quantization in [None, *([QuantizationMode.UINT4, QuantizationMode.UINT8] if size != "2.6B" else [])]
+]
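
Note on the new model spec module: the comprehension enumerates the full-precision LiquidAI checkpoints plus 4-bit and 8-bit mlx-community variants for every size except 2.6B, ten ModelSpec entries in total. Assuming QuantizationMode.UINT4.bits == 4 and QuantizationMode.UINT8.bits == 8, the helper resolves repos like this:

_lfm2_repo("350M", None)                    # -> ("LiquidAI", "LFM2-350M")
_lfm2_repo("700M", QuantizationMode.UINT4)  # -> ("mlx-community", "LFM2-700M-4bit")
_lfm2_repo("1.2B", QuantizationMode.UINT8)  # -> ("mlx-community", "LFM2-1.2B-8bit")
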
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.5.10
+Version: 0.5.12
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
@@ -14,6 +14,7 @@ from safetensors.flax import save_file
 from lalamo.common import flatten_parameters
 from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, import_model
 from lalamo.model_import.model_specs import ModelType
+from lalamo.model_import.model_specs.lfm2 import LFM2_MODELS
 from lalamo.models import ClassifierModelConfig, LanguageModelConfig
 from lalamo.modules import config_converter
 from tests.test_models import DType, ModelTestSpec
@@ -21,13 +22,16 @@ from tests.test_models import DType, ModelTestSpec
 MODEL_LIST: list[ModelTestSpec] = [
     ModelTestSpec("trymirai/chat-moderation-router", DType.FLOAT32),
     ModelTestSpec("Qwen/Qwen3-0.6B", DType.FLOAT32),
+    ModelTestSpec("Qwen/Qwen3-4B-AWQ", DType.FLOAT32),
     ModelTestSpec("Qwen/Qwen2.5-0.5B-Instruct", DType.FLOAT32),
     ModelTestSpec("google/gemma-3-1b-it", DType.FLOAT32),
+    ModelTestSpec("google/gemma-3-4b-it", DType.FLOAT32),
     ModelTestSpec("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", DType.FLOAT32),
     ModelTestSpec("meta-llama/Llama-3.2-1B-Instruct", DType.FLOAT32),
     ModelTestSpec("cartesia-ai/Llamba-1B", DType.FLOAT32),
     ModelTestSpec("cartesia-ai/Llamba-1B-4bit-mlx", DType.FLOAT32),
-]
+] + \
+    [ModelTestSpec(model.repo, DType.FLOAT32) for model in LFM2_MODELS]
 
 MODEL_LIST += (
     [
@@ -1,11 +1,10 @@
 import pytest
 
+from lalamo.model_import.model_specs.lfm2 import LFM2_MODELS
 from tests.lfm2_tracer import LFM2DecoderTracer
 from tests.test_models import DType, ModelTestSpec, _test_model
 
-MODEL_LIST = [
-    ModelTestSpec("LiquidAI/LFM2-2.6B", DType.FLOAT32),
-]
+MODEL_LIST = [ModelTestSpec(model.repo, DType.FLOAT32) for model in LFM2_MODELS if model.quantization is None]
 
 
 @pytest.mark.parametrize("test_spec", MODEL_LIST, ids=[m.model_repo for m in MODEL_LIST])
@@ -1,21 +0,0 @@
-from lalamo.model_import.decoder_configs import HFLFM2Config
-
-from .common import ConfigMap, FileSpec, ModelSpec
-
-__all__ = ["LFM2_MODELS"]
-
-LFM2_MODELS = [
-    ModelSpec(
-        vendor="LiquidAI",
-        family="LFM2",
-        name="LFM2-2.6B",
-        size="2.6B",
-        repo="LiquidAI/LFM2-2.6B",
-        config_type=HFLFM2Config,
-        quantization=None,
-        configs=ConfigMap(
-            chat_template=FileSpec("chat_template.jinja"),
-        ),
-        use_cases=tuple(),
-    ),
-]