lalamo 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/language_model.py +22 -23
  3. lalamo/main.py +2 -16
  4. lalamo/model_import/common.py +24 -6
  5. lalamo/model_import/decoder_configs/__init__.py +2 -0
  6. lalamo/model_import/decoder_configs/common.py +4 -4
  7. lalamo/model_import/decoder_configs/executorch.py +17 -10
  8. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  9. lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
  10. lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
  11. lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
  12. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
  13. lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
  14. lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
  15. lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
  16. lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
  17. lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
  18. lalamo/model_import/loaders/executorch.py +5 -4
  19. lalamo/model_import/loaders/huggingface.py +321 -69
  20. lalamo/model_import/model_specs/__init__.py +2 -0
  21. lalamo/model_import/model_specs/common.py +16 -5
  22. lalamo/model_import/model_specs/llamba.py +40 -0
  23. lalamo/model_import/model_specs/qwen.py +29 -1
  24. lalamo/modules/__init__.py +33 -6
  25. lalamo/modules/activations.py +9 -2
  26. lalamo/modules/common.py +10 -5
  27. lalamo/modules/decoder.py +93 -97
  28. lalamo/modules/decoder_layer.py +85 -103
  29. lalamo/modules/embedding.py +279 -5
  30. lalamo/modules/linear.py +335 -30
  31. lalamo/modules/mlp.py +6 -7
  32. lalamo/modules/mlx_interop.py +19 -0
  33. lalamo/modules/rope.py +1 -1
  34. lalamo/modules/token_mixers/__init__.py +30 -0
  35. lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
  36. lalamo/modules/token_mixers/common.py +78 -0
  37. lalamo/modules/token_mixers/mamba.py +553 -0
  38. lalamo/modules/token_mixers/state/__init__.py +12 -0
  39. lalamo/modules/token_mixers/state/common.py +26 -0
  40. lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
  41. lalamo/modules/token_mixers/state/mamba_state.py +51 -0
  42. lalamo/utils.py +24 -2
  43. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
  44. lalamo-0.5.0.dist-info/RECORD +80 -0
  45. lalamo-0.4.1.dist-info/RECORD +0 -71
  46. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
  47. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/decoder_configs/huggingface/gpt_oss.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -75,6 +76,7 @@ class HFGPTOssConfig(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         # Embedding
         if self.tie_word_embeddings:
@@ -124,17 +126,6 @@ class HFGPTOssConfig(HuggingFaceConfig):
         # Linear layers
         linear_config = FullPrecisionLinearConfig(precision=activation_precision)
 
-        attention_config = AttentionConfig(
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=None,
-            key_norm_config=None,
-            logit_soft_cap=None,
-            has_sinks=True,
-            has_qkv_biases=self.attention_bias,
-            has_out_biases=self.attention_bias,
-        )
-
         # Experts (MoE) scaffold
         # Router: linear with bias; Experts: DenseMLP with SiLU(alpha=1.702) and value/gate clipping
         experts_activation = SiLU(alpha=1.702)
@@ -154,42 +145,58 @@ class HFGPTOssConfig(HuggingFaceConfig):
             router_has_biases=True,
             expert_config=experts_config,
         )
-        decoder_layer_config = DecoderLayerConfig(
-            pre_attention_norm_config=rmsnorm_config,
-            attention_config=attention_config,
-            post_attention_norm_config=None,
-            pre_mlp_norm_config=rmsnorm_config,
-            mlp_config=moe_config,
-            post_mlp_norm_config=None,
-        )
 
         # Per-layer sliding-window
         if self.layer_types is not None and len(self.layer_types) == self.num_hidden_layers:
-            sliding_window_sizes = tuple(
+            sliding_window_sizes = [
                 self.sliding_window if layer_type == "sliding_attention" else None for layer_type in self.layer_types
-            )
+            ]
         else:
             # Fallback: apply the same sliding window to all layers if provided
             sliding_window_sizes = (
-                tuple([self.sliding_window] * self.num_hidden_layers) if self.sliding_window is not None else None
+                [self.sliding_window] * self.num_hidden_layers
+                if self.sliding_window is not None
+                else [None] * self.num_hidden_layers
             )
 
         head_dim = self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads
 
+        layer_configs = []
+        for sliding_window_size in sliding_window_sizes:
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=None,
+                has_sinks=True,
+                has_qkv_biases=self.attention_bias,
+                has_out_biases=self.attention_bias,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=head_dim,
+                is_causal=True,
+                scale=None,
+                sliding_window_size=sliding_window_size,
+            )
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=moe_config,
+                post_mlp_norm_config=None,
+            )
+            layer_configs.append(decoder_layer_config)
+
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_config=decoder_layer_config,
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=head_dim,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=sliding_window_sizes,
             context_length=context_length or self.max_position_embeddings,
         )
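
The GPT-OSS hunk above replaces the single shared AttentionConfig/DecoderLayerConfig with one DecoderLayerConfig per layer, so the sliding-window size can differ from layer to layer. The snippet below is a standalone sketch of just the window-selection logic; it does not import lalamo, and the layer_types and window size are invented for illustration.

# Standalone sketch of the per-layer sliding-window selection shown above.
layer_types = ["sliding_attention", "full_attention", "sliding_attention", "full_attention"]
sliding_window = 128
num_hidden_layers = len(layer_types)

if layer_types is not None and len(layer_types) == num_hidden_layers:
    # One entry per layer: a window for sliding layers, None for full-attention layers.
    sliding_window_sizes = [
        sliding_window if layer_type == "sliding_attention" else None for layer_type in layer_types
    ]
else:
    # Fallback: apply the same window to every layer if one is configured at all.
    sliding_window_sizes = (
        [sliding_window] * num_hidden_layers if sliding_window is not None else [None] * num_hidden_layers
    )

assert sliding_window_sizes == [128, None, 128, None]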
lalamo/model_import/decoder_configs/huggingface/llama.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -12,13 +13,13 @@ from lalamo.modules import (
     GroupQuantizedLinearConfig,
     LlamaRoPEConfig,
     RMSNormConfig,
+    SiLU,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
+    UntiedEmbeddingConfig,
     UpcastMode,
     YARNRoPEConfig,
 )
-from lalamo.modules.activations import SiLU
-from lalamo.modules.embedding import UntiedEmbeddingConfig
 from lalamo.quantization import QuantizationMode
 
 from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
@@ -80,6 +81,7 @@ class HFLlamaConfig(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
@@ -149,6 +151,12 @@ class HFLlamaConfig(HuggingFaceConfig):
             has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=False,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads,
+            is_causal=True,
+            scale=None,
+            sliding_window_size=None,
         )
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
@@ -159,9 +167,9 @@ class HFLlamaConfig(HuggingFaceConfig):
             gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
-            pre_attention_norm_config=rmsnorm_config,
-            attention_config=attention_config,
-            post_attention_norm_config=None,
+            pre_mixer_norm_config=rmsnorm_config,
+            mixer_config=attention_config,
+            post_mixer_norm_config=None,
             pre_mlp_norm_config=rmsnorm_config,
             mlp_config=mlp_config,
             post_mlp_norm_config=None,
@@ -170,16 +178,10 @@ class HFLlamaConfig(HuggingFaceConfig):
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_config=decoder_layer_config,
+            layer_configs=(decoder_layer_config,) * self.num_hidden_layers,
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=None,
             context_length=context_length or self.max_position_embeddings,
         )
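
For Llama the layers stay homogeneous, so the old scalar decoder fields (num_heads, num_layers, sliding_window_sizes, ...) collapse into a single frozen DecoderLayerConfig repeated num_hidden_layers times. A minimal sketch of that replication, using a placeholder dataclass rather than the real lalamo class:

# Stand-in sketch of the homogeneous-layer case: one frozen layer config reused for
# every layer. "LayerConfigSketch" is a placeholder, not lalamo's DecoderLayerConfig.
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerConfigSketch:
    sliding_window_size: int | None = None


num_hidden_layers = 4
layer_configs = (LayerConfigSketch(),) * num_hidden_layers

assert len(layer_configs) == num_hidden_layers
# Tuple replication copies the reference, not the (immutable) config itself.
assert all(cfg is layer_configs[0] for cfg in layer_configs)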
lalamo/model_import/decoder_configs/huggingface/llamba.py (new file)

@@ -0,0 +1,170 @@
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    DecoderConfig,
+    DecoderLayerConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    Identity,
+    Mamba2Config,
+    MLXQuantizedLinearConfig,
+    MLXSemiQuantizedUntiedEmbeddingConfig,
+    RMSNormConfig,
+    SeparableCausalConvConfig,
+    SiLU,
+    TiedEmbeddingConfig,
+    UntiedEmbeddingConfig,
+    UpcastMode,
+)
+from lalamo.quantization import QuantizationMode
+
+from .common import HuggingFaceConfig
+
+
+@dataclass(frozen=True)
+class HFLlambaMlpConfig:
+    intermediate_size: int
+    bias: bool
+    act_fn: Literal["silu"]
+
+
+@dataclass(frozen=True)
+class HFLlambaSsmConfig:
+    d_state: int
+    n_v_heads: int
+    n_qk_heads: int
+    expand: int
+    activation: Literal["identity"]
+    bias: bool
+    conv_bias: bool = True
+    d_conv: int = 4
+
+
+@dataclass(frozen=True)
+class HFLlambaConfig(HuggingFaceConfig):
+    model_type: Literal["llamba"]
+    vocab_size: int
+    tie_embeddings: bool
+    pad_vocab_size_multiple: int
+    lm_head_bias: bool
+    d_model: int
+    n_layer: int
+    resid_dropout: float
+    norm_epsilon: float
+    mlp_cfg: HFLlambaMlpConfig
+    ssm_cfg: HFLlambaSsmConfig
+
+    @property
+    def eos_token_ids(self) -> list[int]:
+        return [128001, 128008, 128009]
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
+    ) -> DecoderConfig:
+        if "quantization_kwargs.group_size" in metadata_dict:
+            embedding_config = MLXSemiQuantizedUntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                group_size=int(metadata_dict["quantization_kwargs.group_size"]),
+                embedding_quantization_mode=QuantizationMode.from_num_bits(int(metadata_dict["quantization_kwargs.bits"])),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        elif self.tie_embeddings:
+            embedding_config = TiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+        else:
+            embedding_config = UntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+
+        rmsnorm_config = RMSNormConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_epsilon,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+        )
+
+        if "quantization_kwargs.group_size" in metadata_dict:
+            linear_config = MLXQuantizedLinearConfig(
+                group_size=int(metadata_dict["quantization_kwargs.group_size"]),
+                weight_quantization_mode=QuantizationMode.from_num_bits(int(metadata_dict["quantization_kwargs.bits"])),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        else:
+            linear_config = FullPrecisionLinearConfig(
+                precision=activation_precision,
+            )
+
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=SiLU(),
+            has_up_biases=self.mlp_cfg.bias,
+            has_down_biases=self.mlp_cfg.bias,
+            up_clipping=None,
+            gate_clipping=None,
+        )
+
+        inner_dim = self.ssm_cfg.expand * self.d_model
+        head_dim = inner_dim // self.ssm_cfg.n_v_heads
+
+        if self.ssm_cfg.activation == "identity":
+            activation = Identity()
+        elif self.ssm_cfg.activation == "silu":
+            activation = SiLU()
+        else:
+            activation = SiLU()  # fallback
+
+        mamba_config = Mamba2Config(
+            in_projection_config=linear_config,
+            out_projection_config=linear_config,
+            conv_config=SeparableCausalConvConfig(
+                precision=activation_precision,
+                has_biases=self.ssm_cfg.conv_bias,
+            ),
+            activation=activation,
+            kernel_size=self.ssm_cfg.d_conv,
+            num_heads=self.ssm_cfg.n_v_heads,
+            num_groups=self.ssm_cfg.n_qk_heads,
+            head_dim=head_dim,
+            state_dim=self.ssm_cfg.d_state,
+            expansion_factor=self.ssm_cfg.expand,
+            has_in_biases=self.ssm_cfg.bias,
+            has_out_biases=self.ssm_cfg.bias,
+        )
+
+        decoder_layer_config = DecoderLayerConfig(
+            pre_mixer_norm_config=rmsnorm_config,
+            mixer_config=mamba_config,
+            post_mixer_norm_config=None,
+            pre_mlp_norm_config=rmsnorm_config,
+            mlp_config=mlp_config,
+            post_mlp_norm_config=None,
+        )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            global_rope_config=None,
+            local_rope_config=None,
+            layer_configs=(decoder_layer_config,) * self.n_layer,
+            output_norm_config=rmsnorm_config,
+            vocab_size=self.vocab_size,
+            model_dim=self.d_model,
+            hidden_dim=self.mlp_cfg.intermediate_size,
+            context_length=context_length or 4096,
+        )
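
The new Llamba importer chooses MLX-quantized embedding and linear configs when flattened quantization keys appear in the model's metadata mapping, and derives the Mamba head geometry from the SSM config (inner_dim = expand * d_model, head_dim = inner_dim // n_v_heads). Below is a hedged sketch of just that selection and arithmetic; describe_linear is a hypothetical helper (the real code builds MLXQuantizedLinearConfig or FullPrecisionLinearConfig instances), and the numbers are illustrative.

# Hedged sketch of the metadata-driven selection and head-dim arithmetic above.
from collections.abc import Mapping


def describe_linear(metadata_dict: Mapping[str, str]) -> str:
    # Quantized path is taken only when MLX-style keys are present in the metadata.
    if "quantization_kwargs.group_size" in metadata_dict:
        group_size = int(metadata_dict["quantization_kwargs.group_size"])
        bits = int(metadata_dict["quantization_kwargs.bits"])
        return f"mlx-quantized(group_size={group_size}, bits={bits})"
    return "full-precision"


assert describe_linear({}) == "full-precision"
assert describe_linear({"quantization_kwargs.group_size": "64", "quantization_kwargs.bits": "4"}) == (
    "mlx-quantized(group_size=64, bits=4)"
)

# Head geometry with invented Llamba-like numbers: expand=2, d_model=2048, n_v_heads=32.
expand, d_model, n_v_heads = 2, 2048, 32
inner_dim = expand * d_model  # 4096
head_dim = inner_dim // n_v_heads  # 128
assert (inner_dim, head_dim) == (4096, 128)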
lalamo/model_import/decoder_configs/huggingface/mistral.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -24,7 +25,6 @@ __all__ = ["HFMistralConfig"]
 
 @dataclass(frozen=True)
 class HFMistralConfig(HuggingFaceConfig):
-    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["MistralForCausalLM"]]
     attention_dropout: float
     bos_token_id: int
@@ -53,8 +53,8 @@ class HFMistralConfig(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
-        # Choose embedding config based on tie_word_embeddings flag
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
@@ -86,16 +86,7 @@ class HFMistralConfig(HuggingFaceConfig):
             precision=activation_precision,
         )
 
-        attention_config = AttentionConfig(
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=None,
-            key_norm_config=None,
-            logit_soft_cap=None,
-            has_sinks=False,
-            has_qkv_biases=False,
-            has_out_biases=False,
-        )
+        head_dim = self.head_dim or self.hidden_size // self.num_attention_heads
 
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
@@ -106,33 +97,43 @@ class HFMistralConfig(HuggingFaceConfig):
             gate_clipping=None,
         )
 
-        decoder_layer_config = DecoderLayerConfig(
-            pre_attention_norm_config=rmsnorm_config,
-            attention_config=attention_config,
-            post_attention_norm_config=None,
-            pre_mlp_norm_config=rmsnorm_config,
-            mlp_config=mlp_config,
-            post_mlp_norm_config=None,
-        )
+        layer_configs = []
+        for _ in range(self.num_hidden_layers):
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=None,
+                has_sinks=False,
+                has_qkv_biases=False,
+                has_out_biases=False,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=head_dim,
+                is_causal=True,
+                scale=None,
+                sliding_window_size=self.sliding_window,
+            )
 
-        head_dim = self.head_dim or self.hidden_size // self.num_attention_heads
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            layer_configs.append(decoder_layer_config)
 
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_config=decoder_layer_config,
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=head_dim,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=tuple([self.sliding_window] * self.num_hidden_layers)
-            if self.sliding_window is not None
-            else None,
             context_length=context_length or self.max_position_embeddings,
         )
lalamo/model_import/decoder_configs/huggingface/qwen2.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -69,6 +70,7 @@ class HFQwen2Config(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
@@ -105,16 +107,7 @@ class HFQwen2Config(HuggingFaceConfig):
                activation_quantization_mode=None,
                activation_precision=activation_precision,
            )
-        attention_config = AttentionConfig(
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=None,
-            key_norm_config=None,
-            logit_soft_cap=None,
-            has_sinks=False,
-            has_qkv_biases=True,
-            has_out_biases=False,
-        )
+        head_dim = self.hidden_size // self.num_attention_heads
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
             activation=SiLU(),
@@ -123,28 +116,43 @@ class HFQwen2Config(HuggingFaceConfig):
             up_clipping=None,
             gate_clipping=None,
         )
-        decoder_layer_config = DecoderLayerConfig(
-            pre_attention_norm_config=rmsnorm_config,
-            attention_config=attention_config,
-            post_attention_norm_config=None,
-            pre_mlp_norm_config=rmsnorm_config,
-            mlp_config=mlp_config,
-            post_mlp_norm_config=None,
-        )
+
+        sliding_window_sizes = self._get_sliding_window_sizes()
+        layer_configs = []
+        for sliding_window_size in sliding_window_sizes:
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=None,
+                has_sinks=False,
+                has_qkv_biases=True,
+                has_out_biases=False,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=head_dim,
+                is_causal=True,
+                scale=None,
+                sliding_window_size=sliding_window_size,
+            )
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            layer_configs.append(decoder_layer_config)
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_config=decoder_layer_config,
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=self.hidden_size // self.num_attention_heads,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=tuple(self._get_sliding_window_sizes()),
             context_length=context_length or self.max_position_embeddings,
         )
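
Qwen2 configs carry no explicit head_dim, so the importer now derives it once as hidden_size // num_attention_heads before building the per-layer attention configs. A quick arithmetic check with illustrative sizes, not tied to any specific checkpoint:

# Arithmetic check of the derived head_dim used above (illustrative numbers).
hidden_size, num_attention_heads = 3584, 28
head_dim = hidden_size // num_attention_heads
assert head_dim == 128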
lalamo/model_import/decoder_configs/huggingface/qwen3.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -17,15 +18,18 @@ from lalamo.modules import (
     UpcastMode,
 )
 from lalamo.modules.activations import SiLU
+from lalamo.modules.embedding import MLXQuantizedTiedEmbeddingConfig
+from lalamo.modules.linear import MLXQuantizedLinearConfig
 from lalamo.quantization import QuantizationMode
 
-from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
+from .common import HuggingFaceConfig, MLXQuantizationConfig, QuantizationConfigType
 
 __all__ = ["HFQwen3Config"]
 
 
 @dataclass(frozen=True)
 class HFQwen3Config(HuggingFaceConfig):
+    eos_token_id: int | list[int]
     torch_dtype: Literal["bfloat16", "float16", "float32"]
     attention_bias: bool
     hidden_act: Literal["silu"]
@@ -45,7 +49,7 @@ class HFQwen3Config(HuggingFaceConfig):
     vocab_size: int
     head_dim: int
 
-    quantization_config: AWQQuantizationConfig | GPTQQuantizationConfig | None = None
+    quantization_config: QuantizationConfigType = None
 
     def _get_sliding_window_sizes(self) -> tuple[int | None, ...]:
         if not self.use_sliding_window:
@@ -67,8 +71,19 @@ class HFQwen3Config(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
-        if self.tie_word_embeddings:
+        if isinstance(self.quantization_config, MLXQuantizationConfig):
+            assert self.tie_word_embeddings, "only tied embeddings are supported"
+            embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                group_size=self.quantization_config.group_size,
+                embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        elif self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
                 logit_soft_cap=None,
@@ -96,6 +111,13 @@ class HFQwen3Config(HuggingFaceConfig):
             linear_config = FullPrecisionLinearConfig(
                 precision=activation_precision,
             )
+        elif isinstance(self.quantization_config, MLXQuantizationConfig):
+            linear_config = MLXQuantizedLinearConfig(
+                group_size=self.quantization_config.group_size,
+                weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
         else:
             linear_config = GroupQuantizedLinearConfig(
                 group_size=self.quantization_config.group_size,
@@ -103,16 +125,6 @@ class HFQwen3Config(HuggingFaceConfig):
                 activation_quantization_mode=None,
                 activation_precision=activation_precision,
             )
-        attention_config = AttentionConfig(
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=rmsnorm_config,
-            key_norm_config=rmsnorm_config,
-            logit_soft_cap=None,
-            has_sinks=False,
-            has_qkv_biases=self.attention_bias,
-            has_out_biases=self.attention_bias,
-        )
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
             activation=SiLU(),
@@ -121,28 +133,43 @@ class HFQwen3Config(HuggingFaceConfig):
             up_clipping=None,
             gate_clipping=None,
         )
-        decoder_layer_config = DecoderLayerConfig(
-            pre_attention_norm_config=rmsnorm_config,
-            attention_config=attention_config,
-            post_attention_norm_config=None,
-            pre_mlp_norm_config=rmsnorm_config,
-            mlp_config=mlp_config,
-            post_mlp_norm_config=None,
-        )
+
+        sliding_window_sizes = self._get_sliding_window_sizes()
+        layer_configs = []
+        for sliding_window_size in sliding_window_sizes:
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=rmsnorm_config,
+                key_norm_config=rmsnorm_config,
+                logit_soft_cap=None,
+                has_sinks=False,
+                has_qkv_biases=self.attention_bias,
+                has_out_biases=self.attention_bias,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                is_causal=True,
+                scale=None,
+                sliding_window_size=sliding_window_size,
+            )
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            layer_configs.append(decoder_layer_config)
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_config=decoder_layer_config,
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=self.head_dim,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=self._get_sliding_window_sizes(),
             context_length=context_length or self.max_position_embeddings,
         )
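
Qwen3 now dispatches on the type of quantization_config: None keeps full-precision linears, an MLX config selects MLXQuantizedLinearConfig (plus the MLX tied embedding), and anything else falls through to the existing grouped-quantization path used for AWQ/GPTQ. Below is a stand-in sketch of that branching with placeholder classes rather than the real lalamo config types.

# Stand-in sketch of the quantization dispatch introduced above (placeholder classes).
from dataclasses import dataclass


@dataclass(frozen=True)
class MLXQuantizationConfigSketch:
    group_size: int
    bits: int


@dataclass(frozen=True)
class GroupQuantizationConfigSketch:  # stands in for the AWQ/GPTQ-style configs
    group_size: int
    bits: int


def pick_linear_kind(quantization_config: object | None) -> str:
    if quantization_config is None:
        return "FullPrecisionLinearConfig"
    if isinstance(quantization_config, MLXQuantizationConfigSketch):
        return "MLXQuantizedLinearConfig"
    return "GroupQuantizedLinearConfig"


assert pick_linear_kind(None) == "FullPrecisionLinearConfig"
assert pick_linear_kind(MLXQuantizationConfigSketch(group_size=64, bits=4)) == "MLXQuantizedLinearConfig"
assert pick_linear_kind(GroupQuantizationConfigSketch(group_size=128, bits=4)) == "GroupQuantizedLinearConfig"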