lalamo 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- lalamo/__init__.py +1 -1
- lalamo/language_model.py +22 -23
- lalamo/main.py +2 -16
- lalamo/model_import/common.py +24 -6
- lalamo/model_import/decoder_configs/__init__.py +2 -0
- lalamo/model_import/decoder_configs/common.py +4 -4
- lalamo/model_import/decoder_configs/executorch.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
- lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
- lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
- lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
- lalamo/model_import/loaders/executorch.py +5 -4
- lalamo/model_import/loaders/huggingface.py +321 -69
- lalamo/model_import/model_specs/__init__.py +2 -0
- lalamo/model_import/model_specs/common.py +16 -5
- lalamo/model_import/model_specs/llamba.py +40 -0
- lalamo/model_import/model_specs/qwen.py +29 -1
- lalamo/modules/__init__.py +33 -6
- lalamo/modules/activations.py +9 -2
- lalamo/modules/common.py +10 -5
- lalamo/modules/decoder.py +93 -97
- lalamo/modules/decoder_layer.py +85 -103
- lalamo/modules/embedding.py +279 -5
- lalamo/modules/linear.py +335 -30
- lalamo/modules/mlp.py +6 -7
- lalamo/modules/mlx_interop.py +19 -0
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +30 -0
- lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
- lalamo/modules/token_mixers/common.py +78 -0
- lalamo/modules/token_mixers/mamba.py +553 -0
- lalamo/modules/token_mixers/state/__init__.py +12 -0
- lalamo/modules/token_mixers/state/common.py +26 -0
- lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
- lalamo/modules/token_mixers/state/mamba_state.py +51 -0
- lalamo/utils.py +24 -2
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
- lalamo-0.5.0.dist-info/RECORD +80 -0
- lalamo-0.4.1.dist-info/RECORD +0 -71
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0

lalamo/model_import/loaders/executorch.py

@@ -180,17 +180,18 @@ def load_decoder_layer(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> DecoderLayer:
-    if module.
+    if module.post_mixer_norm is not None:
         raise ValueError("Post attention normalization is not supported")
     if module.post_mlp_norm is not None:
         raise ValueError("Post MLP normalization is not supported")
-    attention_norm = load_rmsnorm(module.
-
+    attention_norm = load_rmsnorm(module.pre_mixer_norm, weights_dict, path / "attention_norm")
+    assert isinstance(module.mixer, Attention)
+    attention = load_attention(module.mixer, weights_dict, path / "attention")
     mlp_norm = load_rmsnorm(module.pre_mlp_norm, weights_dict, path / "ffn_norm")
     assert isinstance(module.mlp, DenseMLP)
     mlp = load_mlp(module.mlp, weights_dict, path / "feed_forward")
     return load_parameters(
-        lambda m: (m.
+        lambda m: (m.pre_mixer_norm, m.mixer, m.pre_mlp_norm, m.mlp),
         module,
         (attention_norm, attention, mlp_norm, mlp),
     )
lalamo/model_import/loaders/huggingface.py

@@ -1,8 +1,9 @@
 from collections.abc import Mapping
+from dataclasses import dataclass
 
 import jax.numpy as jnp
 from einops import rearrange
-from jaxtyping import Array
+from jaxtyping import Array, DTypeLike
 
 from lalamo.common import ParameterPath
 from lalamo.modules import (
@@ -13,7 +14,12 @@ from lalamo.modules import (
     FullPrecisionLinear,
     GroupQuantizedLinear,
     LinearBase,
+    Mamba2,
+    MLXQuantizedLinear,
+    MLXQuantizedTiedEmbedding,
+    MLXSemiQuantizedUntiedEmbedding,
     RMSNorm,
+    SeparableCausalConv,
     TiedEmbedding,
     UntiedEmbedding,
 )
@@ -26,10 +32,10 @@ from .utils import decode_mxfp4, deinterleave_pairwise_columns
 __all__ = ["load_huggingface"]
 
 
-
+AWQ_UINT4_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)
 
 
-def
+def _reverse_uint4_order(array: Array, reverse_order: Array) -> Array:
     """Reverses the AWQ packing order to get the logical order of channels for INT4."""
     pack_factor = 32 // 4
     *_, last_dim = array.shape
@@ -37,13 +43,13 @@ def _reverse_uint4_awq_order(array: Array) -> Array:
         return array
 
     array_reshaped = rearrange(array, "... (group pack_factor) -> ... group pack_factor", pack_factor=pack_factor)
-    array_reordered = array_reshaped[...,
+    array_reordered = array_reshaped[..., reverse_order]
     return rearrange(array_reordered, "... group pack_factor -> ... (group pack_factor)")
 
 
 def unpack_int32(packed_weights: Array, mode: QuantizationMode) -> Array:
-    assert packed_weights.dtype
-        f"Expected packed_weights to be of dtype jnp.int32, got {packed_weights.dtype}"
+    assert packed_weights.dtype in (jnp.int32, jnp.uint32), (
+        f"Expected packed_weights to be of dtype jnp.(u)int32, got {packed_weights.dtype}"
     )
     assert 32 % mode.bits == 0
 
@@ -58,29 +64,18 @@ def unpack_int32(packed_weights: Array, mode: QuantizationMode) -> Array:
     return unpacked
 
 
-def
-
-
-
-
-) ->
-
-
-
-
-    if mode == QuantizationMode.UINT4:
-        unpacked_weights = _reverse_uint4_awq_order(unpacked_weights)
-
-    assert qzeros.dtype == jnp.int32
-    unpacked_zero_points = unpack_int32(qzeros, mode)
-    if mode == QuantizationMode.UINT4:
-        unpacked_zero_points = _reverse_uint4_awq_order(unpacked_zero_points)
-
-    weights = unpacked_weights.astype(module.config.activation_precision)
-    zero_points = unpacked_zero_points.astype(module.config.activation_precision)
-    processed_scales = scales.astype(module.config.activation_precision)
+def _process_quantized_tensor(
+    quantized: Array,
+    weight_quantization: QuantizationMode,
+    activation_precision: DTypeLike,
+    reverse_order: Array | None = None,
+) -> Array:
+    unpacked = unpack_int32(quantized, weight_quantization)
+    if reverse_order is not None:
+        assert weight_quantization == QuantizationMode.UINT4, "reverse order only supported on uint4 quant type"
+        unpacked = _reverse_uint4_order(unpacked, reverse_order)
 
-    return
+    return unpacked.astype(activation_precision)
 
 
 def _fuse_full_precision_weights(
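For reference, a self-contained sketch of the bit layout these helpers deal with: each int32 word packs 32 // 4 = 8 unsigned 4-bit values (assumed least-significant nibble first, the usual convention for this kind of loader), and AWQ additionally interleaves channels within each group of eight, which is what `AWQ_UINT4_REVERSE_ORDER` undoes. The snippet uses toy values and plain `jax.numpy` / `einops` only; it is illustrative, not code from the package.

```python
import jax.numpy as jnp
from einops import rearrange

AWQ_UINT4_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)


def unpack_uint4(packed: jnp.ndarray) -> jnp.ndarray:
    # Each int32 word holds eight 4-bit values; extract them LSB-first by shifting and masking.
    shifts = jnp.arange(8, dtype=jnp.int32) * 4
    nibbles = (packed[..., None] >> shifts) & 0xF
    return rearrange(nibbles, "... words nibbles -> ... (words nibbles)")


packed = jnp.array([[0x76543210]], dtype=jnp.int32)  # toy word whose nibbles are 0..7 from LSB to MSB
unpacked = unpack_uint4(packed)                      # [[0, 1, 2, 3, 4, 5, 6, 7]]

# Apply the reverse-order permutation within each group of eight channels,
# which is the indexing step _reverse_uint4_order performs on real AWQ data.
grouped = rearrange(unpacked, "... (g p) -> ... g p", p=8)
logical = rearrange(grouped[..., AWQ_UINT4_REVERSE_ORDER], "... g p -> ... (g p)")
# logical == [[0, 4, 1, 5, 2, 6, 3, 7]] for this toy input
```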
@@ -95,26 +90,39 @@ def _fuse_full_precision_weights(
     return jnp.concatenate(weights, axis=0)
 
 
+@dataclass(frozen=True)
+class QuantizedParamLayout:
+    weight: str
+    scale: str
+    bias: str
+    transposed: bool
+
+
+AWQ_QUANTIZED_WEIGHT_LAYOUT = QuantizedParamLayout("qweight", "scales", "qzeros", transposed=True)
+MLX_QUANTIZED_WEIGHT_LAYOUT = QuantizedParamLayout("weight", "scales", "biases", transposed=False)
+
+
 def _fuse_quantized_weights(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
     sublayers_to_fuse: list[str] | None,
+    quantized_param_layout: QuantizedParamLayout,
 ) -> tuple[Array, Array, Array]:
     # Note that AWQ quantized weights are stored transposed relative to full-precision weights
 
     if sublayers_to_fuse is None:
-        qweights = weights_dict[path /
-        qzeros = weights_dict[path /
-        scales = weights_dict[path /
+        qweights = weights_dict[path / quantized_param_layout.weight]
+        qzeros = weights_dict[path / quantized_param_layout.bias]
+        scales = weights_dict[path / quantized_param_layout.scale]
         return qweights, qzeros, scales
 
-    qweights = [weights_dict[path / layer_name /
-    qzeros = [weights_dict[path / layer_name /
-    scales = [weights_dict[path / layer_name /
+    qweights = [weights_dict[path / layer_name / quantized_param_layout.weight] for layer_name in sublayers_to_fuse]
+    qzeros = [weights_dict[path / layer_name / quantized_param_layout.bias] for layer_name in sublayers_to_fuse]
+    scales = [weights_dict[path / layer_name / quantized_param_layout.scale] for layer_name in sublayers_to_fuse]
 
-    fused_qweights = jnp.concatenate(qweights, axis=
-    fused_qzeros = jnp.concatenate(qzeros, axis=
-    fused_scales = jnp.concatenate(scales, axis=
+    fused_qweights = jnp.concatenate(qweights, axis=int(quantized_param_layout.transposed))
+    fused_qzeros = jnp.concatenate(qzeros, axis=int(quantized_param_layout.transposed))
+    fused_scales = jnp.concatenate(scales, axis=int(quantized_param_layout.transposed))
 
     return fused_qweights, fused_qzeros, fused_scales
 
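The `transposed` flag is what picks the concatenation axis when several projections are fused: AWQ checkpoints store quantized weights transposed relative to full-precision weights (roughly in_features by packed out_features, as the comment in the diff notes), so fused sublayers concatenate along axis 1, while the MLX layout concatenates along axis 0. A tiny standalone illustration of the `int(transposed)` trick, with made-up shapes:

```python
import jax.numpy as jnp

# Hypothetical packed weights for two projections that get fused (shapes are illustrative).
up_q = jnp.zeros((256, 128), dtype=jnp.int32)
gate_q = jnp.zeros((256, 128), dtype=jnp.int32)

fused_transposed = jnp.concatenate([up_q, gate_q], axis=int(True))   # AWQ-style: grow axis 1
assert fused_transposed.shape == (256, 256)

fused_row_major = jnp.concatenate([up_q, gate_q], axis=int(False))   # MLX-style: grow axis 0
assert fused_row_major.shape == (512, 128)
```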
@@ -148,34 +156,85 @@ def load_linear(
         return load_parameters(lambda m: (m.weights, m.biases), module, (weights, bias))
 
     if isinstance(module, GroupQuantizedLinear):
-        qweights, qzeros, scales = _fuse_quantized_weights(
+        qweights, qzeros, scales = _fuse_quantized_weights(
+            weights_dict,
+            path,
+            sublayers_to_fuse,
+            AWQ_QUANTIZED_WEIGHT_LAYOUT,
+        )
+        weight_quantization = module.config.weight_quantization_mode
+        activation_precision = module.activation_precision
+
+        if weight_quantization == QuantizationMode.UINT4:
+            reverse_order = AWQ_UINT4_REVERSE_ORDER
+        else:
+            reverse_order = None
 
-        weights
+        weights = _process_quantized_tensor(
             qweights,
+            weight_quantization,
+            activation_precision,
+            reverse_order,
+        )
+        zeros = _process_quantized_tensor(
             qzeros,
-
-
+            weight_quantization,
+            activation_precision,
+            reverse_order,
         )
+        scales = scales.astype(activation_precision)
 
         return load_parameters(
             lambda m: (m.weights, m.scales, m.zero_points, m.biases),
             module,
-            (weights.T, scales.T,
+            (weights.T, scales.T, zeros.T, bias),
+        )
+
+    if isinstance(module, MLXQuantizedLinear):
+        qweights, deq_biases, scales = _fuse_quantized_weights(
+            weights_dict,
+            path,
+            sublayers_to_fuse,
+            MLX_QUANTIZED_WEIGHT_LAYOUT,
+        )
+        weight_quantization = module.config.weight_quantization_mode
+        activation_precision = module.activation_precision
+
+        weights = _process_quantized_tensor(
+            qweights,
+            weight_quantization,
+            activation_precision,
+            None,
+        )
+        scales = scales.astype(activation_precision)
+        deq_biases = deq_biases.astype(activation_precision)
+
+        return load_parameters(
+            lambda m: (m.weights, m.scales, m.deq_biases, m.biases),
+            module,
+            (weights, scales, deq_biases, bias),
         )
 
     raise TypeError(f"Unsupported module type for loading: {type(module)}")
 
 
-def load_mlp(
+def load_mlp(
+    module: MLPBase,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+    up_proj_key: str,
+    gate_proj_key: str,
+    down_proj_key: str,
+) -> MLPBase:
     if isinstance(module, DenseMLP):
         # Standard dense MLP with separate sublayers.
         up_projection = load_linear(
             module.up_projection,
             weights_dict,
             path,
-            sublayers_to_fuse=[
+            sublayers_to_fuse=[up_proj_key, gate_proj_key],
         )
-        down_projection = load_linear(module.down_projection, weights_dict, path /
+        down_projection = load_linear(module.down_projection, weights_dict, path / down_proj_key)
         return load_parameters(
             lambda m: (m.up_projection, m.down_projection),
             module,
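The two quantized branches above mirror two dequantization conventions. As I read the parameter names, the AWQ-style `GroupQuantizedLinear` keeps zero points (roughly `w ≈ (q - zero_point) * scale`), while `MLXQuantizedLinear` keeps additive dequantization biases (`w ≈ q * scale + bias`, hence `deq_biases`). A standalone numeric sketch of how the two forms relate; this is my reading of the formats, not package code:

```python
import jax.numpy as jnp

q = jnp.array([0.0, 5.0, 15.0])          # toy 4-bit codes
scale, zero_point = 0.1, 8.0
bias = -zero_point * scale               # -0.8

awq_style = (q - zero_point) * scale     # [-0.8, -0.3, 0.7]
mlx_style = q * scale + bias             # [-0.8, -0.3, 0.7] -- identical when bias == -zero_point * scale
```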
@@ -250,7 +309,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
         )
     else:
         # Fallback: recursively load a standard DenseMLP experts module
-        experts = load_mlp(module.experts, weights_dict, experts_path)
+        experts = load_mlp(module.experts, weights_dict, experts_path, "up_proj", "gate_proj", "down_proj")
 
     return load_parameters(
         lambda m: (m.router, m.experts),
@@ -304,28 +363,107 @@ def load_attention(
     )
 
 
+def _load_mamba_conv(
+    conv_module: SeparableCausalConv,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+) -> SeparableCausalConv:
+    weight_path = path / "conv1d" / "weight"
+    if weight_path not in weights_dict:
+        weight_path = path / "conv_weight"
+    if weight_path not in weights_dict:
+        weight_path = None
+
+    if weight_path is not None:
+        raw = weights_dict[weight_path]
+        conv_weight = raw.squeeze(1) if raw.ndim == 3 else raw
+    else:
+        conv_weight = conv_module.weights
+
+    bias_path = path / "conv1d" / "bias"
+    if bias_path not in weights_dict:
+        bias_path = path / "conv_bias"
+    if bias_path not in weights_dict:
+        bias_path = None
+
+    if bias_path is not None and conv_module.biases is not None:
+        conv_bias = weights_dict[bias_path]
+    else:
+        conv_bias = conv_module.biases
+
+    return load_parameters(
+        lambda m: (m.weights, m.biases),
+        conv_module,
+        (conv_weight, conv_bias),
+    )
+
+
+def load_mamba2(
+    module: Mamba2,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+) -> Mamba2:
+    in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
+    out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
+    conv = _load_mamba_conv(module.conv, weights_dict, path)
+
+    skip_connection_weight_path = path / "D"
+    if skip_connection_weight_path in weights_dict:
+        skip_connection_weight = weights_dict[skip_connection_weight_path]
+    else:
+        skip_connection_weight = module.skip_connection_weight
+
+    gate_bias_path = path / "z_bias"
+    if gate_bias_path in weights_dict:
+        gate_bias = weights_dict[gate_bias_path]
+    else:
+        gate_bias = module.gate_bias
+
+    return load_parameters(
+        lambda m: (m.in_projection, m.out_projection, m.conv, m.skip_connection_weight, m.gate_bias),
+        module,
+        (in_projection, out_projection, conv, skip_connection_weight, gate_bias),
+    )
+
+
 def load_decoder_layer(
     module: DecoderLayer,
     weights_dict: Mapping[str, Array],
-
+    mixer_path: ParameterPath,
+    mlp_path: ParameterPath,
+    mixer_key: str,
+    mlp_key: str,
+    pre_mixer_norm_key: str,
+    pre_mlp_norm_key: str,
+    up_proj_key: str,
+    gate_proj_key: str,
+    down_proj_key: str,
 ) -> DecoderLayer:
     pre_attention_norm = load_rmsnorm(
-        module.
+        module.pre_mixer_norm,
         weights_dict,
-
+        mixer_path / pre_mixer_norm_key,
     )
-
-
+
+    # Load mixer (attention or mamba)
+    if isinstance(module.mixer, Attention):
+        mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
+    elif isinstance(module.mixer, Mamba2):
+        mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
+    else:
+        mixer = module.mixer
+
+    if module.post_mixer_norm is not None:
         post_attention_norm = load_rmsnorm(
-            module.
+            module.post_mixer_norm,
             weights_dict,
-
+            mixer_path / "post_attention_layernorm",
         )
 
         pre_mlp_norm = load_rmsnorm(
             module.pre_mlp_norm,
             weights_dict,
-
+            mlp_path / "pre_feedforward_layernorm",
         )
     else:
         post_attention_norm = None
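One detail worth noting in `_load_mamba_conv`: depthwise `Conv1d` checkpoints commonly store the kernel as `(channels, 1, kernel_size)`, and the `squeeze(1)` collapses the singleton axis into a 2-D layout, presumably the one `SeparableCausalConv` keeps internally. A standalone sketch with made-up shapes:

```python
import jax.numpy as jnp

# Hypothetical depthwise Conv1d weight as exported by torch: (channels, 1, kernel_size).
raw = jnp.ones((768, 1, 4))
conv_weight = raw.squeeze(1) if raw.ndim == 3 else raw  # drop the singleton input-channel axis
assert conv_weight.shape == (768, 4)
```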
@@ -333,41 +471,92 @@ def load_decoder_layer(
         pre_mlp_norm = load_rmsnorm(
             module.pre_mlp_norm,
             weights_dict,
-
+            mlp_path / pre_mlp_norm_key,
         )
 
-    mlp = load_mlp(module.mlp, weights_dict,
+    mlp = load_mlp(module.mlp, weights_dict, mlp_path / mlp_key, up_proj_key, gate_proj_key, down_proj_key)
+
     if module.post_mlp_norm is not None:
         post_mlp_norm = load_rmsnorm(
             module.post_mlp_norm,
             weights_dict,
-
+            mlp_path / "post_feedforward_layernorm",
         )
     else:
         post_mlp_norm = None
+
     return load_parameters(
-        lambda m: (m.
+        lambda m: (m.pre_mixer_norm, m.mixer, m.post_mixer_norm, m.pre_mlp_norm, m.mlp, m.post_mlp_norm),
         module,
-        (pre_attention_norm,
+        (pre_attention_norm, mixer, post_attention_norm, pre_mlp_norm, mlp, post_mlp_norm),
     )
 
 
 def load_tied_embedding(
     module: TiedEmbedding,
     weights_dict: Mapping[str, Array],
-
+    embedding_path: ParameterPath,
 ) -> TiedEmbedding:
-    weights = weights_dict[
+    weights = weights_dict[embedding_path / "weight"]
     return load_parameters(lambda m: (m.weights,), module, (weights,))
 
 
+def load_mlx_quantized_tied_embedding(
+    module: MLXQuantizedTiedEmbedding,
+    weights_dict: Mapping[str, Array],
+    embedding_path: ParameterPath,
+) -> MLXQuantizedTiedEmbedding:
+    qweights = weights_dict[embedding_path / "weight"]
+    qscales = weights_dict[embedding_path / "scales"]
+    qbiases = weights_dict[embedding_path / "biases"]
+
+    weights = _process_quantized_tensor(
+        qweights,
+        module.config.embedding_quantization_mode,
+        module.activation_precision,
+        None,
+    )
+    scales = qscales.astype(module.activation_precision)
+    biases = qbiases.astype(module.activation_precision)
+
+    return load_parameters(lambda m: (m.weights, m.scales, m.biases), module, (weights, scales, biases))
+
+
+def load_mlx_semi_quantized_untied_embedding(
+    module: MLXSemiQuantizedUntiedEmbedding,
+    weights_dict: Mapping[str, Array],
+    embedding_path: ParameterPath,
+    lm_head_path: ParameterPath,
+) -> MLXSemiQuantizedUntiedEmbedding:
+    input_weights = weights_dict[embedding_path / "weight"]
+
+    output_qweights = weights_dict[lm_head_path / "weight"]
+    output_qscales = weights_dict[lm_head_path / "scales"]
+    output_qbiases = weights_dict[lm_head_path / "biases"]
+
+    output_weights = _process_quantized_tensor(
+        output_qweights,
+        module.config.embedding_quantization_mode,
+        module.activation_precision,
+        None,
+    )
+    output_scales = output_qscales.astype(module.activation_precision)
+    output_biases = output_qbiases.astype(module.activation_precision)
+
+    return load_parameters(
+        lambda m: (m.input_weights, m.output_weights, m.output_scales, m.output_biases),
+        module,
+        (input_weights, output_weights, output_scales, output_biases),
+    )
+
+
 def load_untied_embedding(
     module: UntiedEmbedding,
     weights_dict: Mapping[str, Array],
-
+    embedding_path: ParameterPath,
     lm_head_path: ParameterPath,
 ) -> UntiedEmbedding:
-    input_weights = weights_dict[
+    input_weights = weights_dict[embedding_path / "weight"]
     output_weights = weights_dict[lm_head_path / "weight"]
     return load_parameters(lambda m: (m.input_weights, m.output_weights), module, (input_weights, output_weights))
 
@@ -381,19 +570,82 @@ def load_huggingface(
     else:
         base_path = ParameterPath()
 
-
-
+    is_llamba_full_precision = any(key.startswith("backbone.") for key in weights_dict)
+    is_llamba_mlx = any(key.startswith("embedding.encoder.") for key in weights_dict)
+    if is_llamba_full_precision:
+        decoder_path = base_path / "backbone"
+        embedding_path = decoder_path / "embedding"
+        pre_mixer_norm_key = "input_layernorm"
+        mixer_key = "mixer"
+        pre_mlp_norm_key = "post_attention_layernorm"
+        mlp_key = "mlp"
+        up_proj_key = "up_proj"
+        gate_proj_key = "gate_proj"
+        down_proj_key = "down_proj"
+        alternating_layers = False
+        norm_key = "final_layernorm"
+        lm_head_path = base_path / "lm_head"
+    elif is_llamba_mlx:
+        decoder_path = base_path / "model"
+        embedding_path = base_path / "embedding.encoder"
+        pre_mixer_norm_key = "norm"
+        mixer_key = "layer"
+        pre_mlp_norm_key = "norm"
+        mlp_key = "layer"
+        up_proj_key = "gate_proj"
+        gate_proj_key = "in_proj"
+        down_proj_key = "out_proj"
+        alternating_layers = True
+        norm_key = "norm"
+        lm_head_path = base_path / "head.linear"
+    else:
+        decoder_path = base_path / "model"
+        embedding_path = decoder_path / "embed_tokens"
+        pre_mixer_norm_key = "input_layernorm"
+        mixer_key = "self_attn"
+        pre_mlp_norm_key = "post_attention_layernorm"
+        mlp_key = "mlp"
+        up_proj_key = "up_proj"
+        gate_proj_key = "gate_proj"
+        down_proj_key = "down_proj"
+        alternating_layers = False
+        norm_key = "norm"
+        lm_head_path = base_path / "lm_head"
 
     if isinstance(module.embedding, TiedEmbedding):
-        embedding = load_tied_embedding(module.embedding, weights_dict,
+        embedding = load_tied_embedding(module.embedding, weights_dict, embedding_path)
+    elif isinstance(module.embedding, MLXQuantizedTiedEmbedding):
+        embedding = load_mlx_quantized_tied_embedding(module.embedding, weights_dict, embedding_path)
+    elif isinstance(module.embedding, MLXSemiQuantizedUntiedEmbedding):
+        embedding = load_mlx_semi_quantized_untied_embedding(
+            module.embedding,
+            weights_dict,
+            embedding_path,
+            lm_head_path,
+        )
     elif isinstance(module.embedding, UntiedEmbedding):
-        embedding = load_untied_embedding(module.embedding, weights_dict,
+        embedding = load_untied_embedding(module.embedding, weights_dict, embedding_path, lm_head_path)
     else:
         raise TypeError(f"Unsupported embedding type: {type(module.embedding)}")
+
     decoder_layers = tuple(
-        load_decoder_layer(
+        load_decoder_layer(
+            layer,
+            weights_dict,
+            decoder_path / "layers" / ((i * 2) if alternating_layers else i),
+            decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
+            mixer_key,
+            mlp_key,
+            pre_mixer_norm_key,
+            pre_mlp_norm_key,
+            up_proj_key,
+            gate_proj_key,
+            down_proj_key,
+        )
+        for i, layer in enumerate(module.layers)
     )
-
+
+    output_norm = load_rmsnorm(module.output_norm, weights_dict, decoder_path / norm_key)
     return load_parameters(
         lambda m: (m.embedding, m.layers, m.output_norm),
         module,
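The `alternating_layers` flag encodes how the MLX Llamba checkpoints appear to be laid out: each logical decoder layer is split across two consecutive `layers.*` entries, one read as the mixer and the next as the MLP, so layer `i` uses indices `2*i` and `2*i + 1`; the other formats use the same index for both. A standalone sketch of the index mapping only:

```python
alternating_layers = True
num_decoder_layers = 3  # arbitrary toy value

for i in range(num_decoder_layers):
    mixer_index = (i * 2) if alternating_layers else i
    mlp_index = (i * 2 + 1) if alternating_layers else i
    print(f"decoder layer {i}: mixer <- layers.{mixer_index}, mlp <- layers.{mlp_index}")
# decoder layer 0: mixer <- layers.0, mlp <- layers.1
# decoder layer 1: mixer <- layers.2, mlp <- layers.3
# decoder layer 2: mixer <- layers.4, mlp <- layers.5
```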
lalamo/model_import/model_specs/__init__.py

@@ -4,6 +4,7 @@ from .gemma import GEMMA_MODELS
 from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
 from .llama import LLAMA_MODELS
+from .llamba import LLAMBA_MODELS
 from .mistral import MISTRAL_MODELS
 
 # from .pleias import PLEIAS_MODELS
@@ -22,6 +23,7 @@ __all__ = [
 
 ALL_MODEL_LISTS = [
     LLAMA_MODELS,
+    LLAMBA_MODELS,
     DEEPSEEK_MODELS,
     GEMMA_MODELS,
     HUGGINGFACE_MODELS,
lalamo/model_import/model_specs/common.py

@@ -20,6 +20,7 @@ from lalamo.utils import MapDictValues, open_safetensors
 __all__ = [
     "ConfigMap",
     "FileSpec",
+    "JSONFieldSpec",
     "ModelSpec",
     "UseCase",
     "WeightsType",
@@ -39,17 +40,21 @@ class WeightsType(Enum):
     TORCH = "torch"
 
     @contextmanager
-    def load(
+    def load(
+        self,
+        filename: Path | str,
+        float_dtype: DTypeLike,
+    ) -> Iterator[tuple[Mapping[str, jnp.ndarray], Mapping[str, str]]]:
         if self == WeightsType.SAFETENSORS:
-            with open_safetensors(filename) as weights_dict:
-                yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict)
+            with open_safetensors(filename) as (weights_dict, metadata_dict):
+                yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
 
             from lalamo.modules.torch_interop import torch_to_jax
 
             torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
-            yield MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights)
+            yield MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights), {}
 
 
 class UseCase(Enum):
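With this change, `WeightsType.load` yields a `(weights, metadata)` pair instead of bare weights: safetensors metadata is passed through, and the torch branch supplies an empty mapping. A hypothetical call site, with placeholder file name and dtype, assuming a direct import from the module shown in the file list:

```python
import jax.numpy as jnp

from lalamo.model_import.model_specs.common import WeightsType

# Unpack the new (weights, metadata) tuple yielded by the context manager.
with WeightsType.SAFETENSORS.load("model.safetensors", jnp.bfloat16) as (weights_dict, metadata_dict):
    print(len(weights_dict), "tensors; metadata:", dict(metadata_dict))
```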
@@ -62,13 +67,19 @@ class FileSpec:
     repo: str | None = None
 
 
+@dataclass(frozen=True)
+class JSONFieldSpec:
+    file_spec: FileSpec
+    field_name: str
+
+
 @dataclass(frozen=True)
 class ConfigMap:
     model_config: FileSpec = field(default=FileSpec("config.json"))
     tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
     tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
     generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
-    chat_template: FileSpec | None = None
+    chat_template: FileSpec | JSONFieldSpec | None = None
 
 
 def _is_foreign_config_type(t: object) -> bool:
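`chat_template` can now point at a field inside a JSON config file rather than a standalone template file. A hypothetical spec using the new `JSONFieldSpec`; the file and field names are illustrative, chosen because HuggingFace tokenizers conventionally keep `chat_template` inside `tokenizer_config.json`:

```python
from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec

# Pull the chat template from a JSON field instead of a dedicated file.
config_map = ConfigMap(
    chat_template=JSONFieldSpec(FileSpec("tokenizer_config.json"), "chat_template"),
)
```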