lalamo-0.2.1-py3-none-any.whl → lalamo-0.2.3-py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +24 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +100 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/METADATA +1 -1
  44. lalamo-0.2.3.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/top_level.txt +0 -0
lalamo/model_import/configs/huggingface/qwen3.py
@@ -0,0 +1,142 @@
+ from dataclasses import dataclass
+ from typing import Literal
+
+ from jaxtyping import DTypeLike
+
+ from lalamo.modules import (
+     Activation,
+     AttentionConfig,
+     DecoderConfig,
+     DecoderLayerConfig,
+     FullPrecisionLinearConfig,
+     GroupQuantizedLinearConfig,
+     MLPConfig,
+     RMSNormConfig,
+     TiedEmbeddingConfig,
+     UnscaledRoPEConfig,
+     UntiedEmbeddingConfig,
+     UpcastMode,
+ )
+ from lalamo.quantization import QuantizationMode
+
+ from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
+
+ __all__ = ["HFQwen3Config"]
+
+
+ @dataclass(frozen=True)
+ class HFQwen3Config(HuggingFaceConfig):
+     attention_bias: bool
+     hidden_act: Literal["silu"]
+     hidden_size: int
+     intermediate_size: int
+     max_position_embeddings: int
+     max_window_layers: int
+     model_type: Literal["qwen3"]
+     num_attention_heads: int
+     num_hidden_layers: int
+     num_key_value_heads: int
+     rms_norm_eps: float
+     rope_theta: float
+     sliding_window: int | None
+     tie_word_embeddings: bool
+     use_sliding_window: bool
+     vocab_size: int
+     head_dim: int
+
+     quantization_config: AWQQuantizationConfig | GPTQQuantizationConfig | None = None
+
+     def _get_sliding_window_sizes(self) -> tuple[int | None, ...]:
+         if not self.use_sliding_window:
+             return tuple([None] * self.num_hidden_layers)
+
+         # The HuggingFace Qwen3 implementation's comment states that bottom layers use SWA,
+         # but the code (`configuration_qwen3.py`) implements it for the top layers.
+         # We are following the code.
+         sliding_window_sizes = []
+         for i in range(self.num_hidden_layers):
+             if i >= self.max_window_layers:
+                 sliding_window_sizes.append(self.sliding_window)
+             else:
+                 sliding_window_sizes.append(None)
+         return tuple(sliding_window_sizes)
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         if self.tie_word_embeddings:
+             embedding_config = TiedEmbeddingConfig(
+                 input_scale=None,
+                 logits_soft_cap=None,
+                 precision=activation_precision,
+             )
+         else:
+             embedding_config = UntiedEmbeddingConfig(
+                 input_scale=None,
+                 logits_soft_cap=None,
+                 precision=activation_precision,
+             )
+         rope_config = UnscaledRoPEConfig(
+             precision=activation_precision,
+             base=self.rope_theta,
+             max_sequence_length=self.max_position_embeddings,
+         )
+         rmsnorm_config = RMSNormConfig(
+             scale_precision=activation_precision,
+             accumulation_precision=accumulation_precision,
+             epsilon=self.rms_norm_eps,
+             scale_offset=None,
+             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+         )
+         if self.quantization_config is None:
+             linear_config = FullPrecisionLinearConfig(
+                 precision=activation_precision,
+             )
+         else:
+             linear_config = GroupQuantizedLinearConfig(
+                 group_size=self.quantization_config.group_size,
+                 weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                 activation_quantization_mode=None,
+                 activation_precision=activation_precision,
+             )
+         attention_config = AttentionConfig(
+             qkv_projection_config=linear_config,
+             out_projection_config=linear_config,
+             query_norm_config=rmsnorm_config,
+             key_norm_config=rmsnorm_config,
+             logit_soft_cap=None,
+             has_qkv_biases=self.attention_bias,
+             has_out_biases=self.attention_bias,
+         )
+         mlp_config = MLPConfig(
+             linear_config=linear_config,
+             activation=Activation.SILU,
+         )
+         decoder_layer_config = DecoderLayerConfig(
+             pre_attention_norm_config=rmsnorm_config,
+             attention_config=attention_config,
+             post_attention_norm_config=None,
+             pre_mlp_norm_config=rmsnorm_config,
+             mlp_config=mlp_config,
+             post_mlp_norm_config=None,
+         )
+         return DecoderConfig(
+             embedding_config=embedding_config,
+             global_rope_config=rope_config,
+             local_rope_config=None,
+             layer_config=decoder_layer_config,
+             output_norm_config=rmsnorm_config,
+             vocab_size=self.vocab_size,
+             model_dim=self.hidden_size,
+             hidden_dim=self.intermediate_size,
+             num_heads=self.num_attention_heads,
+             num_groups=self.num_key_value_heads,
+             head_dim=self.head_dim,
+             attention_scale=None,
+             num_layers=self.num_hidden_layers,
+             sliding_window_sizes=self._get_sliding_window_sizes(),
+             context_length=context_length or self.max_position_embeddings,
+         )
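For orientation only (not part of the package contents): a minimal sketch of how the new config class is used, assuming a `config: HFQwen3Config` has already been parsed from a Qwen3 `config.json` by the surrounding `lalamo.model_import` code. The dtypes and context length below are illustrative choices, not values taken from the diff.

    import jax.numpy as jnp

    # `config` is assumed to be an HFQwen3Config parsed elsewhere from a Qwen3 config.json.
    decoder_config = config.to_decoder_config(
        context_length=4096,                 # None falls back to max_position_embeddings
        activation_precision=jnp.bfloat16,   # illustrative precision choice
        accumulation_precision=jnp.float32,
    )
    # Per _get_sliding_window_sizes: layers with index >= max_window_layers get the
    # sliding window; earlier layers use full attention.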
lalamo/model_import/loaders/__init__.py
@@ -0,0 +1,7 @@
+ # from .executorch import load_executorch
+ from .huggingface import load_huggingface
+
+ __all__ = [
+     # "load_executorch",
+     "load_huggingface",
+ ]
lalamo/model_import/loaders/common.py
@@ -0,0 +1,45 @@
+ from collections.abc import Callable, Iterable
+
+ import equinox as eqx
+ from jax.tree import leaves_with_path
+ from jax.tree_util import keystr
+ from jaxtyping import Array, PyTree
+
+ __all__ = [
+     "load_parameters",
+ ]
+
+
+ def _get_name(leaf: PyTree, tree: PyTree) -> str:
+     for path, value in leaves_with_path(tree):
+         if value is leaf:
+             return f"~{keystr(path)}"
+     raise ValueError(f"Leaf {leaf} not found in tree {tree}")
+
+
+ def _check_compatible(old_value: PyTree, new_value: PyTree, module: eqx.Module) -> None:
+     if isinstance(old_value, Array) and isinstance(new_value, Array):
+         name = _get_name(old_value, module)
+         if old_value.shape != new_value.shape:
+             raise ValueError(f"Expected parameter {name} to have shape {old_value.shape}, got {new_value.shape}")
+         if old_value.dtype != new_value.dtype:
+             raise ValueError(f"Expected parameter {name} to have dtype {old_value.dtype}, got {new_value.dtype}")
+     elif type(old_value) is not type(new_value):
+         raise TypeError(f"Expected parameter of type {type(old_value)}, got {type(new_value)}")
+
+
+ def load_parameters[M: eqx.Module](
+     selector: Callable[[M], Iterable[PyTree]],
+     module: M,
+     new_values: Iterable[PyTree],
+ ) -> M:
+     old_values = list(selector(module))
+     new_values = list(new_values)
+     casted_new_values = []
+     for old_value, new_value in zip(old_values, new_values, strict=True):
+         _check_compatible(old_value, new_value, module)
+         if isinstance(old_value, Array) and isinstance(new_value, Array):
+             casted_new_values.append(new_value.astype(old_value.dtype))
+         else:
+             casted_new_values.append(new_value)
+     return eqx.tree_at(selector, module, casted_new_values, is_leaf=lambda x: x is None)
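A self-contained sketch of how `load_parameters` behaves (not part of the package; the import path is inferred from the file list above, and `Affine` is a toy stand-in module, not a lalamo class):

    import equinox as eqx
    import jax.numpy as jnp
    from jaxtyping import Array

    from lalamo.model_import.loaders.common import load_parameters  # path per the file list

    class Affine(eqx.Module):  # toy module for illustration
        weight: Array
        bias: Array

    module = Affine(weight=jnp.zeros((2, 3)), bias=jnp.zeros((2,)))

    # Replaces only the selected leaves; shapes and dtypes must match the existing parameters.
    loaded = load_parameters(
        lambda m: (m.weight, m.bias),
        module,
        (jnp.ones((2, 3)), jnp.ones((2,))),
    )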
lalamo/model_import/loaders/executorch.py
@@ -0,0 +1,223 @@
+ from collections.abc import Iterable, Iterator
+ from dataclasses import dataclass, replace
+
+ import jax.numpy as jnp
+ from einops import rearrange
+ from jaxtyping import Array, Float, Int
+
+ from lalamo.common import ParameterPath
+ from lalamo.modules import MLP, Attention, Decoder, DecoderLayer, QLoRALinear, QuantizedTiedEmbedding, RMSNorm
+
+ from .common import load_parameters
+
+ __all__ = ["load_executorch"]
+
+
+ @dataclass(frozen=True)
+ class QLoRALinearParams:
+     weights: Int[Array, "out_channels in_channels"]
+     zero_points: Int[Array, "out_channels groups"]
+     scales: Float[Array, "out_channels groups"]
+     lora_down_weights: Float[Array, "total_lora_channels in_channels"]
+     lora_up_weights: tuple[Float[Array, "..."], ...]
+
+     def __iter__(self) -> Iterator[Array]:
+         yield self.weights
+         yield self.zero_points
+         yield self.scales
+         yield self.lora_down_weights
+         yield from self.lora_up_weights
+
+     def __len__(self) -> int:
+         return 4 + len(self.lora_up_weights)
+
+
+ def params_selector(module: QLoRALinear) -> tuple:
+     return (
+         module.weights,
+         module.zero_points,
+         module.scales,
+         module.lora_down_weights,
+         *module.lora_up_weights,
+     )
+
+
+ def get_qlora_linear_params(
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+     weights_dtype: jnp.dtype,
+ ) -> QLoRALinearParams:
+     shift_to_unsigned = 8
+
+     weights = weights_dict[path / "weight"].astype(weights_dtype)
+     scales = weights_dict[path / "scales"]
+
+     # We don't support signed int4 on the inference side, so we map int4 to uint4 and add zero-points.
+     weights = weights + shift_to_unsigned
+     zero_points = jnp.ones_like(scales) * shift_to_unsigned
+
+     lora_down_weights = weights_dict[path / "adaptor" / "A" / "weight"]
+     lora_up_weights = (weights_dict[path / "adaptor" / "B" / "weight"],)
+     return QLoRALinearParams(weights, zero_points, scales, lora_down_weights, lora_up_weights)
+
+
+ def merge_linear_params(params_list: Iterable[QLoRALinearParams]) -> QLoRALinearParams:
+     params_list = list(params_list)
+     weights = jnp.concatenate([p.weights for p in params_list], axis=0)
+     scales = jnp.concatenate([p.scales for p in params_list], axis=0)
+     zero_points = jnp.concatenate([p.zero_points for p in params_list], axis=0)
+     lora_down_weights = jnp.concatenate([p.lora_down_weights for p in params_list], axis=0)
+     lora_up_weights = tuple(w for p in params_list for w in p.lora_up_weights)
+     return QLoRALinearParams(weights, zero_points, scales, lora_down_weights, lora_up_weights)
+
+
+ def load_linear(module: QLoRALinear, weights_dict: dict[str, Array], path: ParameterPath) -> QLoRALinear:
+     params = get_qlora_linear_params(weights_dict, path, module.weights.dtype)
+     return load_parameters(params_selector, module, params)
+
+
+ def load_mlp(module: MLP, weights_dict: dict[str, Array], path: ParameterPath) -> MLP:
+     if not isinstance(module.up_projection, QLoRALinear):
+         raise TypeError(f"Expected up_projection to be QLoRALinear, got {type(module.up_projection)}")
+     if not isinstance(module.down_projection, QLoRALinear):
+         raise TypeError(f"Expected down_projection to be QLoRALinear, got {type(module.down_projection)}")
+
+     up_proj_params = get_qlora_linear_params(weights_dict, path / "w3", module.up_projection.weights.dtype)
+     gate_proj_params = get_qlora_linear_params(weights_dict, path / "w1", module.down_projection.weights.dtype)
+     down_proj_params = get_qlora_linear_params(weights_dict, path / "w2", module.down_projection.weights.dtype)
+
+     fused_up_gate_params = merge_linear_params([up_proj_params, gate_proj_params])
+
+     return load_parameters(
+         lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),  # type: ignore
+         module,
+         (*fused_up_gate_params, *down_proj_params),
+     )
+
+
+ def load_rmsnorm(module: RMSNorm, weights_dict: dict[str, Array], path: ParameterPath) -> RMSNorm:
+     return load_parameters(lambda m: (m.scales,), module, (weights_dict[path / "weight"],))
+
+
+ def permute_qk_weights(weights: Array, input_dim: int, num_heads: int, head_dim: int) -> Array:
+     # Reference: https://github.com/huggingface/transformers/blob/15bd3e61f8d3680ca472c9314ad07584d20f7b81/src/transformers/models/llama/convert_llama_weights_to_hf.py#L222
+     return rearrange(
+         weights,
+         "(heads rotors reim) input_channels -> (heads reim rotors) input_channels",
+         heads=num_heads,
+         rotors=head_dim // 2,
+         reim=2,
+         input_channels=input_dim,
+     )
+
+
+ def permute_qk_params(
+     *,
+     params: QLoRALinearParams,
+     model_dim: int,
+     num_heads: int,
+     head_dim: int,
+     quantization_group_size: int,
+     lora_rank: int,
+ ) -> QLoRALinearParams:
+     # Read https://github.com/huggingface/transformers/issues/25199 to understand WTF is going on here
+     return replace(
+         params,
+         weights=permute_qk_weights(params.weights, model_dim, num_heads, head_dim),
+         scales=permute_qk_weights(params.scales, model_dim // quantization_group_size, num_heads, head_dim),
+         lora_up_weights=tuple(permute_qk_weights(w, lora_rank, num_heads, head_dim) for w in params.lora_up_weights),
+     )
+
+
+ def load_attention(
+     module: Attention,
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+ ) -> Attention:
+     if not isinstance(module.qkv_projection, QLoRALinear):
+         raise TypeError(f"Expected qkv_projection to be QLoRALinear, got {type(module.qkv_projection)}")
+
+     model_dim = module.model_dim
+     num_heads = module.num_heads
+     num_groups = module.num_groups
+     head_dim = module.head_dim
+     lora_rank = module.qkv_projection.config.lora_rank
+
+     q_params = get_qlora_linear_params(weights_dict, path / "wq", module.qkv_projection.weights.dtype)
+     q_params = permute_qk_params(
+         params=q_params,
+         model_dim=model_dim,
+         num_heads=num_heads,
+         head_dim=head_dim,
+         quantization_group_size=module.qkv_projection.config.group_size,
+         lora_rank=lora_rank,
+     )
+
+     k_params = get_qlora_linear_params(weights_dict, path / "wk", module.qkv_projection.weights.dtype)
+     k_params = permute_qk_params(
+         params=k_params,
+         model_dim=model_dim,
+         num_heads=num_groups,
+         head_dim=head_dim,
+         quantization_group_size=module.qkv_projection.config.group_size,
+         lora_rank=lora_rank,
+     )
+
+     v_params = get_qlora_linear_params(weights_dict, path / "wv", module.qkv_projection.weights.dtype)
+
+     out_params = get_qlora_linear_params(weights_dict, path / "wo", module.qkv_projection.weights.dtype)
+
+     qkv_params = merge_linear_params([q_params, k_params, v_params])
+     return load_parameters(
+         lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),  # type: ignore
+         module,
+         (*qkv_params, *out_params),
+     )
+
+
+ def load_decoder_layer(
+     module: DecoderLayer,
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+ ) -> DecoderLayer:
+     if module.post_attention_norm is not None:
+         raise ValueError("Post attention normalization is not supported")
+     if module.post_mlp_norm is not None:
+         raise ValueError("Post MLP normalization is not supported")
+     attention_norm = load_rmsnorm(module.pre_attention_norm, weights_dict, path / "attention_norm")
+     attention = load_attention(module.attention, weights_dict, path / "attention")
+     mlp_norm = load_rmsnorm(module.pre_mlp_norm, weights_dict, path / "ffn_norm")
+     mlp = load_mlp(module.mlp, weights_dict, path / "feed_forward")
+     return load_parameters(
+         lambda m: (m.pre_attention_norm, m.attention, m.pre_mlp_norm, m.mlp),
+         module,
+         (attention_norm, attention, mlp_norm, mlp),
+     )
+
+
+ def load_embedding(
+     module: QuantizedTiedEmbedding,
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+ ) -> QuantizedTiedEmbedding:
+     weights = weights_dict[path / "weight"].astype(module.weights.dtype)
+     scales = weights_dict[path / "scales"].squeeze(1)
+
+     return load_parameters(lambda m: (m.weights, m.scales), module, (weights, scales))
+
+
+ def load_executorch(module: Decoder, weights_dict: dict[str, Array]) -> Decoder:
+     root_path = ParameterPath()
+     if not isinstance(module.embedding, QuantizedTiedEmbedding):
+         raise TypeError(f"Expected embedding to be QuantizedTiedEmbedding, got {type(module.embedding)}")
+
+     embedding = load_embedding(module.embedding, weights_dict, root_path / "tok_embeddings")
+     decoder_layers = tuple(
+         load_decoder_layer(layer, weights_dict, root_path / f"layers.{i}") for i, layer in enumerate(module.layers)
+     )
+     output_norm = load_rmsnorm(module.output_norm, weights_dict, root_path / "norm")
+     return load_parameters(
+         lambda m: (m.embedding, m.layers, m.output_norm),
+         module,
+         (embedding, decoder_layers, output_norm),
+     )
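A rough sketch of how this loader would be invoked once its export (still commented out in loaders/__init__.py above) is enabled, assuming a `decoder: Decoder` has already been built with a matching quantized/QLoRA layout and that the ExecuTorch checkpoint tensors can be read into a flat name-to-array dict. The safetensors reader and file name are placeholders, not part of the package:

    import jax.numpy as jnp
    from safetensors.numpy import load_file  # placeholder checkpoint reader

    raw = load_file("consolidated.00.safetensors")  # placeholder file name
    weights_dict = {name: jnp.asarray(value) for name, value in raw.items()}

    # `decoder` is assumed to be a lalamo Decoder instance constructed elsewhere.
    loaded_decoder = load_executorch(decoder, weights_dict)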