lalamo-0.4.0-py3-none-any.whl → lalamo-0.5.0-py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/language_model.py +22 -23
  3. lalamo/main.py +4 -18
  4. lalamo/model_import/common.py +24 -6
  5. lalamo/model_import/decoder_configs/__init__.py +2 -0
  6. lalamo/model_import/decoder_configs/common.py +4 -4
  7. lalamo/model_import/decoder_configs/executorch.py +17 -10
  8. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  9. lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
  10. lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
  11. lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
  12. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
  13. lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
  14. lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
  15. lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
  16. lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
  17. lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
  18. lalamo/model_import/loaders/executorch.py +5 -4
  19. lalamo/model_import/loaders/huggingface.py +321 -69
  20. lalamo/model_import/model_specs/__init__.py +2 -0
  21. lalamo/model_import/model_specs/common.py +16 -5
  22. lalamo/model_import/model_specs/llamba.py +40 -0
  23. lalamo/model_import/model_specs/qwen.py +29 -1
  24. lalamo/modules/__init__.py +33 -6
  25. lalamo/modules/activations.py +9 -2
  26. lalamo/modules/common.py +10 -5
  27. lalamo/modules/decoder.py +93 -97
  28. lalamo/modules/decoder_layer.py +85 -103
  29. lalamo/modules/embedding.py +279 -5
  30. lalamo/modules/linear.py +335 -30
  31. lalamo/modules/mlp.py +6 -7
  32. lalamo/modules/mlx_interop.py +19 -0
  33. lalamo/modules/rope.py +1 -1
  34. lalamo/modules/token_mixers/__init__.py +30 -0
  35. lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
  36. lalamo/modules/token_mixers/common.py +78 -0
  37. lalamo/modules/token_mixers/mamba.py +553 -0
  38. lalamo/modules/token_mixers/state/__init__.py +12 -0
  39. lalamo/modules/token_mixers/state/common.py +26 -0
  40. lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
  41. lalamo/modules/token_mixers/state/mamba_state.py +51 -0
  42. lalamo/utils.py +24 -2
  43. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
  44. lalamo-0.5.0.dist-info/RECORD +80 -0
  45. lalamo-0.4.0.dist-info/RECORD +0 -71
  46. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
  47. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/loaders/executorch.py
@@ -180,17 +180,18 @@ def load_decoder_layer(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> DecoderLayer:
-    if module.post_attention_norm is not None:
+    if module.post_mixer_norm is not None:
         raise ValueError("Post attention normalization is not supported")
     if module.post_mlp_norm is not None:
         raise ValueError("Post MLP normalization is not supported")
-    attention_norm = load_rmsnorm(module.pre_attention_norm, weights_dict, path / "attention_norm")
-    attention = load_attention(module.attention, weights_dict, path / "attention")
+    attention_norm = load_rmsnorm(module.pre_mixer_norm, weights_dict, path / "attention_norm")
+    assert isinstance(module.mixer, Attention)
+    attention = load_attention(module.mixer, weights_dict, path / "attention")
     mlp_norm = load_rmsnorm(module.pre_mlp_norm, weights_dict, path / "ffn_norm")
     assert isinstance(module.mlp, DenseMLP)
     mlp = load_mlp(module.mlp, weights_dict, path / "feed_forward")
     return load_parameters(
-        lambda m: (m.pre_attention_norm, m.attention, m.pre_mlp_norm, m.mlp),
+        lambda m: (m.pre_mixer_norm, m.mixer, m.pre_mlp_norm, m.mlp),
         module,
         (attention_norm, attention, mlp_norm, mlp),
     )
lalamo/model_import/loaders/huggingface.py
@@ -1,8 +1,9 @@
 from collections.abc import Mapping
+from dataclasses import dataclass
 
 import jax.numpy as jnp
 from einops import rearrange
-from jaxtyping import Array
+from jaxtyping import Array, DTypeLike
 
 from lalamo.common import ParameterPath
 from lalamo.modules import (
@@ -13,7 +14,12 @@ from lalamo.modules import (
     FullPrecisionLinear,
     GroupQuantizedLinear,
     LinearBase,
+    Mamba2,
+    MLXQuantizedLinear,
+    MLXQuantizedTiedEmbedding,
+    MLXSemiQuantizedUntiedEmbedding,
     RMSNorm,
+    SeparableCausalConv,
     TiedEmbedding,
     UntiedEmbedding,
 )
@@ -26,10 +32,10 @@ from .utils import decode_mxfp4, deinterleave_pairwise_columns
 __all__ = ["load_huggingface"]
 
 
-AWQ_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)
+AWQ_UINT4_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)
 
 
-def _reverse_uint4_awq_order(array: Array) -> Array:
+def _reverse_uint4_order(array: Array, reverse_order: Array) -> Array:
     """Reverses the AWQ packing order to get the logical order of channels for INT4."""
     pack_factor = 32 // 4
     *_, last_dim = array.shape
@@ -37,13 +43,13 @@ def _reverse_uint4_awq_order(array: Array) -> Array:
         return array
 
     array_reshaped = rearrange(array, "... (group pack_factor) -> ... group pack_factor", pack_factor=pack_factor)
-    array_reordered = array_reshaped[..., AWQ_REVERSE_ORDER]
+    array_reordered = array_reshaped[..., reverse_order]
     return rearrange(array_reordered, "... group pack_factor -> ... (group pack_factor)")
 
 
 def unpack_int32(packed_weights: Array, mode: QuantizationMode) -> Array:
-    assert packed_weights.dtype == jnp.int32, (
-        f"Expected packed_weights to be of dtype jnp.int32, got {packed_weights.dtype}"
+    assert packed_weights.dtype in (jnp.int32, jnp.uint32), (
+        f"Expected packed_weights to be of dtype jnp.(u)int32, got {packed_weights.dtype}"
     )
     assert 32 % mode.bits == 0
 
@@ -58,29 +64,18 @@ def unpack_int32(packed_weights: Array, mode: QuantizationMode) -> Array:
     return unpacked
 
 
-def _process_quantized_tensors(
-    qweights: Array,
-    qzeros: Array,
-    scales: Array,
-    module: GroupQuantizedLinear,
-) -> tuple[Array, Array, Array]:
-    """Unpacks, recenters, transposes, and casts quantized tensors to the correct dtype."""
-    mode = module.config.weight_quantization_mode
-    assert qweights.dtype == jnp.int32
-    unpacked_weights = unpack_int32(qweights, mode)
-    if mode == QuantizationMode.UINT4:
-        unpacked_weights = _reverse_uint4_awq_order(unpacked_weights)
-
-    assert qzeros.dtype == jnp.int32
-    unpacked_zero_points = unpack_int32(qzeros, mode)
-    if mode == QuantizationMode.UINT4:
-        unpacked_zero_points = _reverse_uint4_awq_order(unpacked_zero_points)
-
-    weights = unpacked_weights.astype(module.config.activation_precision)
-    zero_points = unpacked_zero_points.astype(module.config.activation_precision)
-    processed_scales = scales.astype(module.config.activation_precision)
+def _process_quantized_tensor(
+    quantized: Array,
+    weight_quantization: QuantizationMode,
+    activation_precision: DTypeLike,
+    reverse_order: Array | None = None,
+) -> Array:
+    unpacked = unpack_int32(quantized, weight_quantization)
+    if reverse_order is not None:
+        assert weight_quantization == QuantizationMode.UINT4, "reverse order only supported on uint4 quant type"
+        unpacked = _reverse_uint4_order(unpacked, reverse_order)
 
-    return weights, zero_points, processed_scales
+    return unpacked.astype(activation_precision)
 
 
 def _fuse_full_precision_weights(
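For context, the arithmetic behind `unpack_int32` and `_reverse_uint4_order` is roughly the following. The shift-and-mask body is a sketch of the usual packing convention (the real `unpack_int32` body is not shown in this hunk), while the reorder step mirrors the diff; the `_sketch` names are illustrative and not part of the package.

import jax.numpy as jnp
from einops import rearrange

AWQ_UINT4_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)

def unpack_int32_sketch(packed, bits=4):
    # Each (u)int32 holds 32 // bits packed values; pull them out LSB-first
    # with shifts and a mask, expanding the last axis by the pack factor.
    pack_factor = 32 // bits  # 8 for uint4
    shifts = jnp.arange(pack_factor, dtype=jnp.uint32) * bits
    nibbles = (packed[..., None].astype(jnp.uint32) >> shifts) & ((1 << bits) - 1)
    return rearrange(nibbles, "... col pack -> ... (col pack)")

def reverse_uint4_order_sketch(array):
    # AWQ interleaves the 8 nibbles of every int32; indexing each group of 8
    # with [0, 4, 1, 5, 2, 6, 3, 7] restores the logical channel order.
    grouped = rearrange(array, "... (group pack) -> ... group pack", pack=8)
    return rearrange(grouped[..., AWQ_UINT4_REVERSE_ORDER], "... group pack -> ... (group pack)")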
@@ -95,26 +90,39 @@ def _fuse_full_precision_weights(
     return jnp.concatenate(weights, axis=0)
 
 
+@dataclass(frozen=True)
+class QuantizedParamLayout:
+    weight: str
+    scale: str
+    bias: str
+    transposed: bool
+
+
+AWQ_QUANTIZED_WEIGHT_LAYOUT = QuantizedParamLayout("qweight", "scales", "qzeros", transposed=True)
+MLX_QUANTIZED_WEIGHT_LAYOUT = QuantizedParamLayout("weight", "scales", "biases", transposed=False)
+
+
 def _fuse_quantized_weights(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
     sublayers_to_fuse: list[str] | None,
+    quantized_param_layout: QuantizedParamLayout,
 ) -> tuple[Array, Array, Array]:
     # Note that AWQ quantized weights are stored transposed relative to full-precision weights
 
     if sublayers_to_fuse is None:
-        qweights = weights_dict[path / "qweight"]
-        qzeros = weights_dict[path / "qzeros"]
-        scales = weights_dict[path / "scales"]
+        qweights = weights_dict[path / quantized_param_layout.weight]
+        qzeros = weights_dict[path / quantized_param_layout.bias]
+        scales = weights_dict[path / quantized_param_layout.scale]
         return qweights, qzeros, scales
 
-    qweights = [weights_dict[path / layer_name / "qweight"] for layer_name in sublayers_to_fuse]
-    qzeros = [weights_dict[path / layer_name / "qzeros"] for layer_name in sublayers_to_fuse]
-    scales = [weights_dict[path / layer_name / "scales"] for layer_name in sublayers_to_fuse]
+    qweights = [weights_dict[path / layer_name / quantized_param_layout.weight] for layer_name in sublayers_to_fuse]
+    qzeros = [weights_dict[path / layer_name / quantized_param_layout.bias] for layer_name in sublayers_to_fuse]
+    scales = [weights_dict[path / layer_name / quantized_param_layout.scale] for layer_name in sublayers_to_fuse]
 
-    fused_qweights = jnp.concatenate(qweights, axis=1)
-    fused_qzeros = jnp.concatenate(qzeros, axis=1)
-    fused_scales = jnp.concatenate(scales, axis=1)
+    fused_qweights = jnp.concatenate(qweights, axis=int(quantized_param_layout.transposed))
+    fused_qzeros = jnp.concatenate(qzeros, axis=int(quantized_param_layout.transposed))
+    fused_scales = jnp.concatenate(scales, axis=int(quantized_param_layout.transposed))
 
     return fused_qweights, fused_qzeros, fused_scales
 
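The `transposed` flag is what picks the concatenation axis above: `int(True)` is 1, so AWQ tensors (stored roughly as (in, out) blocks, transposed relative to full-precision weights) are fused along their output/column axis, while MLX tensors (stored (out, in)) are fused along axis 0. A standalone illustration with made-up shapes:

import jax.numpy as jnp

# Hypothetical packed shapes for two sublayers ("up_proj" and "gate_proj") being fused.
awq_up, awq_gate = jnp.zeros((512, 32), jnp.int32), jnp.zeros((512, 32), jnp.int32)    # transposed layout
mlx_up, mlx_gate = jnp.zeros((256, 64), jnp.uint32), jnp.zeros((256, 64), jnp.uint32)  # row-major layout

# axis = int(layout.transposed): output channels are columns for AWQ, rows for MLX.
fused_awq = jnp.concatenate([awq_up, awq_gate], axis=int(True))   # axis 1 -> shape (512, 64)
fused_mlx = jnp.concatenate([mlx_up, mlx_gate], axis=int(False))  # axis 0 -> shape (512, 64)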
@@ -148,34 +156,85 @@ def load_linear(
         return load_parameters(lambda m: (m.weights, m.biases), module, (weights, bias))
 
     if isinstance(module, GroupQuantizedLinear):
-        qweights, qzeros, scales = _fuse_quantized_weights(weights_dict, path, sublayers_to_fuse)
+        qweights, qzeros, scales = _fuse_quantized_weights(
+            weights_dict,
+            path,
+            sublayers_to_fuse,
+            AWQ_QUANTIZED_WEIGHT_LAYOUT,
+        )
+        weight_quantization = module.config.weight_quantization_mode
+        activation_precision = module.activation_precision
+
+        if weight_quantization == QuantizationMode.UINT4:
+            reverse_order = AWQ_UINT4_REVERSE_ORDER
+        else:
+            reverse_order = None
 
-        weights, zero_points, scales = _process_quantized_tensors(
+        weights = _process_quantized_tensor(
             qweights,
+            weight_quantization,
+            activation_precision,
+            reverse_order,
+        )
+        zeros = _process_quantized_tensor(
             qzeros,
-            scales,
-            module,
+            weight_quantization,
+            activation_precision,
+            reverse_order,
         )
+        scales = scales.astype(activation_precision)
 
         return load_parameters(
             lambda m: (m.weights, m.scales, m.zero_points, m.biases),
             module,
-            (weights.T, scales.T, zero_points.T, bias),
+            (weights.T, scales.T, zeros.T, bias),
+        )
+
+    if isinstance(module, MLXQuantizedLinear):
+        qweights, deq_biases, scales = _fuse_quantized_weights(
+            weights_dict,
+            path,
+            sublayers_to_fuse,
+            MLX_QUANTIZED_WEIGHT_LAYOUT,
+        )
+        weight_quantization = module.config.weight_quantization_mode
+        activation_precision = module.activation_precision
+
+        weights = _process_quantized_tensor(
+            qweights,
+            weight_quantization,
+            activation_precision,
+            None,
+        )
+        scales = scales.astype(activation_precision)
+        deq_biases = deq_biases.astype(activation_precision)
+
+        return load_parameters(
+            lambda m: (m.weights, m.scales, m.deq_biases, m.biases),
+            module,
+            (weights, scales, deq_biases, bias),
         )
 
     raise TypeError(f"Unsupported module type for loading: {type(module)}")
 
 
-def load_mlp(module: MLPBase, weights_dict: Mapping[str, Array], path: ParameterPath) -> MLPBase:
+def load_mlp(
+    module: MLPBase,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+    up_proj_key: str,
+    gate_proj_key: str,
+    down_proj_key: str,
+) -> MLPBase:
     if isinstance(module, DenseMLP):
         # Standard dense MLP with separate sublayers.
         up_projection = load_linear(
             module.up_projection,
             weights_dict,
             path,
-            sublayers_to_fuse=["up_proj", "gate_proj"],
+            sublayers_to_fuse=[up_proj_key, gate_proj_key],
         )
-        down_projection = load_linear(module.down_projection, weights_dict, path / "down_proj")
+        down_projection = load_linear(module.down_projection, weights_dict, path / down_proj_key)
         return load_parameters(
             lambda m: (m.up_projection, m.down_projection),
             module,
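The two branches above carry different dequantization parameters: the AWQ path keeps per-group zero points, while the MLX path keeps per-group dequantization biases. The diff does not show the dequantization math itself, but under the usual conventions for these formats it amounts to the following sketch; shapes, group handling, and the `_sketch` names are assumptions for illustration only.

from einops import repeat

def dequantize_awq_sketch(q, scales, zero_points, group_size):
    # Assumed AWQ/GroupQuantizedLinear convention: weight ~ (q - zero_point) * scale,
    # with one scale/zero point per group of `group_size` input channels.
    scales = repeat(scales, "out groups -> out (groups g)", g=group_size)
    zero_points = repeat(zero_points, "out groups -> out (groups g)", g=group_size)
    return (q - zero_points) * scales

def dequantize_mlx_sketch(q, scales, deq_biases, group_size):
    # Assumed MLX convention: weight ~ q * scale + bias, again per group.
    scales = repeat(scales, "out groups -> out (groups g)", g=group_size)
    deq_biases = repeat(deq_biases, "out groups -> out (groups g)", g=group_size)
    return q * scales + deq_biases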
@@ -250,7 +309,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
         )
     else:
         # Fallback: recursively load a standard DenseMLP experts module
-        experts = load_mlp(module.experts, weights_dict, experts_path)
+        experts = load_mlp(module.experts, weights_dict, experts_path, "up_proj", "gate_proj", "down_proj")
 
     return load_parameters(
         lambda m: (m.router, m.experts),
@@ -304,28 +363,107 @@ def load_attention(
     )
 
 
+def _load_mamba_conv(
+    conv_module: SeparableCausalConv,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+) -> SeparableCausalConv:
+    weight_path = path / "conv1d" / "weight"
+    if weight_path not in weights_dict:
+        weight_path = path / "conv_weight"
+    if weight_path not in weights_dict:
+        weight_path = None
+
+    if weight_path is not None:
+        raw = weights_dict[weight_path]
+        conv_weight = raw.squeeze(1) if raw.ndim == 3 else raw
+    else:
+        conv_weight = conv_module.weights
+
+    bias_path = path / "conv1d" / "bias"
+    if bias_path not in weights_dict:
+        bias_path = path / "conv_bias"
+    if bias_path not in weights_dict:
+        bias_path = None
+
+    if bias_path is not None and conv_module.biases is not None:
+        conv_bias = weights_dict[bias_path]
+    else:
+        conv_bias = conv_module.biases
+
+    return load_parameters(
+        lambda m: (m.weights, m.biases),
+        conv_module,
+        (conv_weight, conv_bias),
+    )
+
+
+def load_mamba2(
+    module: Mamba2,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+) -> Mamba2:
+    in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
+    out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
+    conv = _load_mamba_conv(module.conv, weights_dict, path)
+
+    skip_connection_weight_path = path / "D"
+    if skip_connection_weight_path in weights_dict:
+        skip_connection_weight = weights_dict[skip_connection_weight_path]
+    else:
+        skip_connection_weight = module.skip_connection_weight
+
+    gate_bias_path = path / "z_bias"
+    if gate_bias_path in weights_dict:
+        gate_bias = weights_dict[gate_bias_path]
+    else:
+        gate_bias = module.gate_bias
+
+    return load_parameters(
+        lambda m: (m.in_projection, m.out_projection, m.conv, m.skip_connection_weight, m.gate_bias),
+        module,
+        (in_projection, out_projection, conv, skip_connection_weight, gate_bias),
+    )
+
+
 def load_decoder_layer(
     module: DecoderLayer,
     weights_dict: Mapping[str, Array],
-    path: ParameterPath,
+    mixer_path: ParameterPath,
+    mlp_path: ParameterPath,
+    mixer_key: str,
+    mlp_key: str,
+    pre_mixer_norm_key: str,
+    pre_mlp_norm_key: str,
+    up_proj_key: str,
+    gate_proj_key: str,
+    down_proj_key: str,
 ) -> DecoderLayer:
     pre_attention_norm = load_rmsnorm(
-        module.pre_attention_norm,
+        module.pre_mixer_norm,
         weights_dict,
-        path / "input_layernorm",
+        mixer_path / pre_mixer_norm_key,
     )
-    attention = load_attention(module.attention, weights_dict, path / "self_attn")
-    if module.post_attention_norm is not None:
+
+    # Load mixer (attention or mamba)
+    if isinstance(module.mixer, Attention):
+        mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
+    elif isinstance(module.mixer, Mamba2):
+        mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
+    else:
+        mixer = module.mixer
+
+    if module.post_mixer_norm is not None:
         post_attention_norm = load_rmsnorm(
-            module.post_attention_norm,
+            module.post_mixer_norm,
             weights_dict,
-            path / "post_attention_layernorm",
+            mixer_path / "post_attention_layernorm",
         )
 
         pre_mlp_norm = load_rmsnorm(
             module.pre_mlp_norm,
             weights_dict,
-            path / "pre_feedforward_layernorm",
+            mlp_path / "pre_feedforward_layernorm",
         )
     else:
         post_attention_norm = None
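The key-probing fallbacks in `_load_mamba_conv` and `load_mamba2` above exist because Llamba checkpoints come in two flavours. As a rough illustration, the keys looked up under one full-precision layer's mixer would be laid out like this; the exact prefix depends on the checkpoint, so treat the paths below as hypothetical.

from lalamo.common import ParameterPath

# Hypothetical key layout for one full-precision Llamba layer's mixer.
mixer_path = ParameterPath() / "backbone" / "layers" / 0 / "mixer"
expected_keys = [
    mixer_path / "in_proj" / "weight",   # fused input projection
    mixer_path / "out_proj" / "weight",  # output projection
    mixer_path / "conv1d" / "weight",    # (channels, 1, kernel); squeezed to (channels, kernel) above
    mixer_path / "conv1d" / "bias",
    mixer_path / "D",                    # skip-connection weight; module default is kept if absent
    mixer_path / "z_bias",               # gate bias; module default is kept if absent
]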
@@ -333,41 +471,92 @@ def load_decoder_layer(
         pre_mlp_norm = load_rmsnorm(
             module.pre_mlp_norm,
             weights_dict,
-            path / "post_attention_layernorm",
+            mlp_path / pre_mlp_norm_key,
         )
 
-    mlp = load_mlp(module.mlp, weights_dict, path / "mlp")
+    mlp = load_mlp(module.mlp, weights_dict, mlp_path / mlp_key, up_proj_key, gate_proj_key, down_proj_key)
+
     if module.post_mlp_norm is not None:
         post_mlp_norm = load_rmsnorm(
             module.post_mlp_norm,
             weights_dict,
-            path / "post_feedforward_layernorm",
+            mlp_path / "post_feedforward_layernorm",
         )
     else:
         post_mlp_norm = None
+
     return load_parameters(
-        lambda m: (m.pre_attention_norm, m.attention, m.post_attention_norm, m.pre_mlp_norm, m.mlp, m.post_mlp_norm),
+        lambda m: (m.pre_mixer_norm, m.mixer, m.post_mixer_norm, m.pre_mlp_norm, m.mlp, m.post_mlp_norm),
         module,
-        (pre_attention_norm, attention, post_attention_norm, pre_mlp_norm, mlp, post_mlp_norm),
+        (pre_attention_norm, mixer, post_attention_norm, pre_mlp_norm, mlp, post_mlp_norm),
     )
 
 
 def load_tied_embedding(
     module: TiedEmbedding,
     weights_dict: Mapping[str, Array],
-    decoder_path: ParameterPath,
+    embedding_path: ParameterPath,
 ) -> TiedEmbedding:
-    weights = weights_dict[decoder_path / "embed_tokens" / "weight"]
+    weights = weights_dict[embedding_path / "weight"]
     return load_parameters(lambda m: (m.weights,), module, (weights,))
 
 
+def load_mlx_quantized_tied_embedding(
+    module: MLXQuantizedTiedEmbedding,
+    weights_dict: Mapping[str, Array],
+    embedding_path: ParameterPath,
+) -> MLXQuantizedTiedEmbedding:
+    qweights = weights_dict[embedding_path / "weight"]
+    qscales = weights_dict[embedding_path / "scales"]
+    qbiases = weights_dict[embedding_path / "biases"]
+
+    weights = _process_quantized_tensor(
+        qweights,
+        module.config.embedding_quantization_mode,
+        module.activation_precision,
+        None,
+    )
+    scales = qscales.astype(module.activation_precision)
+    biases = qbiases.astype(module.activation_precision)
+
+    return load_parameters(lambda m: (m.weights, m.scales, m.biases), module, (weights, scales, biases))
+
+
+def load_mlx_semi_quantized_untied_embedding(
+    module: MLXSemiQuantizedUntiedEmbedding,
+    weights_dict: Mapping[str, Array],
+    embedding_path: ParameterPath,
+    lm_head_path: ParameterPath,
+) -> MLXSemiQuantizedUntiedEmbedding:
+    input_weights = weights_dict[embedding_path / "weight"]
+
+    output_qweights = weights_dict[lm_head_path / "weight"]
+    output_qscales = weights_dict[lm_head_path / "scales"]
+    output_qbiases = weights_dict[lm_head_path / "biases"]
+
+    output_weights = _process_quantized_tensor(
+        output_qweights,
+        module.config.embedding_quantization_mode,
+        module.activation_precision,
+        None,
+    )
+    output_scales = output_qscales.astype(module.activation_precision)
+    output_biases = output_qbiases.astype(module.activation_precision)
+
+    return load_parameters(
+        lambda m: (m.input_weights, m.output_weights, m.output_scales, m.output_biases),
+        module,
+        (input_weights, output_weights, output_scales, output_biases),
+    )
+
+
 def load_untied_embedding(
     module: UntiedEmbedding,
     weights_dict: Mapping[str, Array],
-    decoder_path: ParameterPath,
+    embedding_path: ParameterPath,
     lm_head_path: ParameterPath,
 ) -> UntiedEmbedding:
-    input_weights = weights_dict[decoder_path / "embed_tokens" / "weight"]
+    input_weights = weights_dict[embedding_path / "weight"]
     output_weights = weights_dict[lm_head_path / "weight"]
     return load_parameters(lambda m: (m.input_weights, m.output_weights), module, (input_weights, output_weights))
 
@@ -381,19 +570,82 @@ def load_huggingface(
     else:
         base_path = ParameterPath()
 
-    decoder_path = base_path / "model"
-    lm_head_path = base_path / "lm_head"
+    is_llamba_full_precision = any(key.startswith("backbone.") for key in weights_dict)
+    is_llamba_mlx = any(key.startswith("embedding.encoder.") for key in weights_dict)
+    if is_llamba_full_precision:
+        decoder_path = base_path / "backbone"
+        embedding_path = decoder_path / "embedding"
+        pre_mixer_norm_key = "input_layernorm"
+        mixer_key = "mixer"
+        pre_mlp_norm_key = "post_attention_layernorm"
+        mlp_key = "mlp"
+        up_proj_key = "up_proj"
+        gate_proj_key = "gate_proj"
+        down_proj_key = "down_proj"
+        alternating_layers = False
+        norm_key = "final_layernorm"
+        lm_head_path = base_path / "lm_head"
+    elif is_llamba_mlx:
+        decoder_path = base_path / "model"
+        embedding_path = base_path / "embedding.encoder"
+        pre_mixer_norm_key = "norm"
+        mixer_key = "layer"
+        pre_mlp_norm_key = "norm"
+        mlp_key = "layer"
+        up_proj_key = "gate_proj"
+        gate_proj_key = "in_proj"
+        down_proj_key = "out_proj"
+        alternating_layers = True
+        norm_key = "norm"
+        lm_head_path = base_path / "head.linear"
+    else:
+        decoder_path = base_path / "model"
+        embedding_path = decoder_path / "embed_tokens"
+        pre_mixer_norm_key = "input_layernorm"
+        mixer_key = "self_attn"
+        pre_mlp_norm_key = "post_attention_layernorm"
+        mlp_key = "mlp"
+        up_proj_key = "up_proj"
+        gate_proj_key = "gate_proj"
+        down_proj_key = "down_proj"
+        alternating_layers = False
+        norm_key = "norm"
+        lm_head_path = base_path / "lm_head"
 
     if isinstance(module.embedding, TiedEmbedding):
-        embedding = load_tied_embedding(module.embedding, weights_dict, decoder_path)
+        embedding = load_tied_embedding(module.embedding, weights_dict, embedding_path)
+    elif isinstance(module.embedding, MLXQuantizedTiedEmbedding):
+        embedding = load_mlx_quantized_tied_embedding(module.embedding, weights_dict, embedding_path)
+    elif isinstance(module.embedding, MLXSemiQuantizedUntiedEmbedding):
+        embedding = load_mlx_semi_quantized_untied_embedding(
+            module.embedding,
+            weights_dict,
+            embedding_path,
+            lm_head_path,
+        )
     elif isinstance(module.embedding, UntiedEmbedding):
-        embedding = load_untied_embedding(module.embedding, weights_dict, decoder_path, lm_head_path)
+        embedding = load_untied_embedding(module.embedding, weights_dict, embedding_path, lm_head_path)
     else:
         raise TypeError(f"Unsupported embedding type: {type(module.embedding)}")
+
     decoder_layers = tuple(
-        load_decoder_layer(layer, weights_dict, decoder_path / "layers" / i) for i, layer in enumerate(module.layers)
+        load_decoder_layer(
+            layer,
+            weights_dict,
+            decoder_path / "layers" / ((i * 2) if alternating_layers else i),
+            decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
+            mixer_key,
+            mlp_key,
+            pre_mixer_norm_key,
+            pre_mlp_norm_key,
+            up_proj_key,
+            gate_proj_key,
+            down_proj_key,
+        )
+        for i, layer in enumerate(module.layers)
     )
-    output_norm = load_rmsnorm(module.output_norm, weights_dict, decoder_path / "norm")
+
+    output_norm = load_rmsnorm(module.output_norm, weights_dict, decoder_path / norm_key)
     return load_parameters(
         lambda m: (m.embedding, m.layers, m.output_norm),
         module,
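The `alternating_layers` flag encodes the MLX Llamba export, where the mixer and the MLP of one lalamo DecoderLayer live in two consecutive entries of the checkpoint's `layers` list. A small illustration of the resulting index mapping (dotted string keys are shown for readability; the loader builds the equivalent paths with ParameterPath):

def llamba_mlx_layer_paths(i: int) -> tuple[str, str]:
    # lalamo layer i reads its Mamba2 mixer from layers[2 * i] and its MLP from layers[2 * i + 1];
    # "layer" serves as both mixer_key and mlp_key for this checkpoint flavour.
    mixer = f"model.layers.{2 * i}.layer"
    mlp = f"model.layers.{2 * i + 1}.layer"
    return mixer, mlp

# llamba_mlx_layer_paths(0) == ("model.layers.0.layer", "model.layers.1.layer")
# llamba_mlx_layer_paths(1) == ("model.layers.2.layer", "model.layers.3.layer")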
lalamo/model_import/model_specs/__init__.py
@@ -4,6 +4,7 @@ from .gemma import GEMMA_MODELS
 from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
 from .llama import LLAMA_MODELS
+from .llamba import LLAMBA_MODELS
 from .mistral import MISTRAL_MODELS
 
 # from .pleias import PLEIAS_MODELS
@@ -22,6 +23,7 @@ __all__ = [
 
 ALL_MODEL_LISTS = [
     LLAMA_MODELS,
+    LLAMBA_MODELS,
     DEEPSEEK_MODELS,
     GEMMA_MODELS,
     HUGGINGFACE_MODELS,
lalamo/model_import/model_specs/common.py
@@ -20,6 +20,7 @@ from lalamo.utils import MapDictValues, open_safetensors
 __all__ = [
     "ConfigMap",
     "FileSpec",
+    "JSONFieldSpec",
     "ModelSpec",
     "UseCase",
     "WeightsType",
@@ -39,17 +40,21 @@ class WeightsType(Enum):
     TORCH = "torch"
 
     @contextmanager
-    def load(self, filename: Path | str, float_dtype: DTypeLike) -> Iterator[Mapping[str, jnp.ndarray]]:
+    def load(
+        self,
+        filename: Path | str,
+        float_dtype: DTypeLike,
+    ) -> Iterator[tuple[Mapping[str, jnp.ndarray], Mapping[str, str]]]:
         if self == WeightsType.SAFETENSORS:
-            with open_safetensors(filename) as weights_dict:
-                yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict)
+            with open_safetensors(filename) as (weights_dict, metadata_dict):
+                yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
 
             from lalamo.modules.torch_interop import torch_to_jax
 
             torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
-            yield MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights)
+            yield MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights), {}
 
 
 class UseCase(Enum):
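Callers of `WeightsType.load` now unpack a pair instead of a single mapping. A minimal usage sketch, not part of the diff (the filename and metadata key are placeholders):

import jax.numpy as jnp

weights_type = WeightsType.SAFETENSORS
with weights_type.load("model.safetensors", jnp.bfloat16) as (weights_dict, metadata_dict):
    # metadata_dict carries the safetensors header metadata (an empty dict for torch checkpoints);
    # an MLX export might record its quantization settings there, for example.
    quantization_meta = metadata_dict.get("quantization", None)  # placeholder key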
@@ -62,13 +67,19 @@ class FileSpec:
     repo: str | None = None
 
 
+@dataclass(frozen=True)
+class JSONFieldSpec:
+    file_spec: FileSpec
+    field_name: str
+
+
 @dataclass(frozen=True)
 class ConfigMap:
     model_config: FileSpec = field(default=FileSpec("config.json"))
     tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
     tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
     generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
-    chat_template: FileSpec | None = None
+    chat_template: FileSpec | JSONFieldSpec | None = None
 
 
 def _is_foreign_config_type(t: object) -> bool:
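With `JSONFieldSpec`, a model spec can point the chat template at a field inside an existing JSON file rather than at a standalone template file. A sketch of what such a configuration could look like; the field name follows the common HuggingFace tokenizer_config.json layout and is an assumption here:

# Hypothetical ConfigMap that pulls the chat template out of tokenizer_config.json.
config_map = ConfigMap(
    chat_template=JSONFieldSpec(
        file_spec=FileSpec("tokenizer_config.json"),
        field_name="chat_template",
    ),
)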