lalamo 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (53)
  1. lalamo/__init__.py +20 -5
  2. lalamo/data/__init__.py +8 -0
  3. lalamo/data/huggingface_message.py +38 -0
  4. lalamo/data/lalamo_completions.py +43 -0
  5. lalamo/data/utils.py +8 -0
  6. lalamo/language_model.py +152 -69
  7. lalamo/main.py +271 -43
  8. lalamo/message_processor.py +11 -1
  9. lalamo/model_import/common.py +17 -7
  10. lalamo/model_import/decoder_configs/__init__.py +3 -0
  11. lalamo/model_import/decoder_configs/executorch.py +12 -6
  12. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  13. lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  14. lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  15. lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  16. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  17. lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  18. lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  21. lalamo/model_import/huggingface_tokenizer_config.py +1 -4
  22. lalamo/model_import/loaders/executorch.py +10 -9
  23. lalamo/model_import/loaders/huggingface.py +104 -9
  24. lalamo/model_import/loaders/utils.py +92 -0
  25. lalamo/model_import/model_specs/__init__.py +4 -1
  26. lalamo/model_import/model_specs/common.py +15 -12
  27. lalamo/model_import/model_specs/gpt_oss.py +21 -0
  28. lalamo/modules/__init__.py +35 -7
  29. lalamo/modules/activations.py +24 -14
  30. lalamo/modules/attention.py +73 -20
  31. lalamo/modules/common.py +8 -57
  32. lalamo/modules/decoder.py +48 -34
  33. lalamo/modules/decoder_layer.py +57 -43
  34. lalamo/modules/embedding.py +13 -19
  35. lalamo/modules/kv_cache.py +53 -16
  36. lalamo/modules/linear.py +260 -79
  37. lalamo/modules/mlp.py +395 -23
  38. lalamo/modules/normalization.py +2 -3
  39. lalamo/modules/rope.py +32 -21
  40. lalamo/modules/utils.py +10 -0
  41. lalamo/speculator/__init__.py +11 -0
  42. lalamo/speculator/common.py +22 -0
  43. lalamo/speculator/inference.py +75 -0
  44. lalamo/speculator/ngram.py +154 -0
  45. lalamo/speculator/utils.py +52 -0
  46. lalamo/utils.py +27 -0
  47. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
  48. lalamo-0.4.0.dist-info/RECORD +71 -0
  49. lalamo-0.3.3.dist-info/RECORD +0 -59
  50. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
  51. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
  52. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
  53. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
@@ -8,11 +8,11 @@ from lalamo.modules import (
     DecoderConfig,
     TiedEmbeddingConfig,
 )
-from lalamo.modules.activations import Activation
+from lalamo.modules.activations import GELU
 from lalamo.modules.attention import AttentionConfig
 from lalamo.modules.decoder_layer import DecoderLayerConfig
 from lalamo.modules.linear import FullPrecisionLinearConfig
-from lalamo.modules.mlp import MLPConfig
+from lalamo.modules.mlp import DenseMLPConfig
 from lalamo.modules.normalization import RMSNormConfig, UpcastMode
 from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig

@@ -75,7 +75,7 @@ class HFGemma3TextConfigRaw:
         attention_scale = self.query_pre_attn_scalar**-0.5
         embedding_config = TiedEmbeddingConfig(
             input_scale=input_scale,
-            logits_soft_cap=None,
+            logit_soft_cap=None,
             precision=activation_precision,
         )
         rms_norm_config = RMSNormConfig(
@@ -106,13 +106,21 @@ class HFGemma3TextConfigRaw:
         )

         linear_config = FullPrecisionLinearConfig(precision=activation_precision)
-        mlp_config = MLPConfig(linear_config=linear_config, activation=Activation.GELU)
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=GELU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
+        )
         attention_config = AttentionConfig(
             qkv_projection_config=linear_config,
             out_projection_config=linear_config,
             query_norm_config=rms_norm_config,
             key_norm_config=rms_norm_config,
             logit_soft_cap=self.attn_logit_softcapping,
+            has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=self.attention_bias,
         )
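
For code that builds these module configs directly, the hunk above is representative of the API change across all model families in this release: the Activation enum is replaced by activation instances such as GELU() and SiLU(), and MLPConfig becomes DenseMLPConfig with explicit bias and clipping arguments. A minimal before/after sketch, using only names visible in this diff (the bfloat16 precision is illustrative):

import jax.numpy as jnp

from lalamo.modules.activations import GELU
from lalamo.modules.linear import FullPrecisionLinearConfig
from lalamo.modules.mlp import DenseMLPConfig

# lalamo 0.3.3:
#   mlp_config = MLPConfig(linear_config=linear_config, activation=Activation.GELU)
# lalamo 0.4.0:
linear_config = FullPrecisionLinearConfig(precision=jnp.bfloat16)  # illustrative precision
mlp_config = DenseMLPConfig(
    linear_config=linear_config,
    activation=GELU(),        # activation objects replace the Activation enum
    has_up_biases=False,      # False/None reproduces what MLPConfig did implicitly
    has_down_biases=False,
    up_clipping=None,
    gate_clipping=None,
)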
@@ -145,7 +153,7 @@ class HFGemma3TextConfigRaw:

 @dataclass(frozen=True)
 class HFGemma3TextConfig(HFGemma3TextConfigRaw, HuggingFaceConfig):
-    pass
+    torch_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"


 @dataclass(frozen=True)
@@ -162,6 +170,7 @@ class HFGemma3VisionConfig:

 @dataclass(frozen=True)
 class HFGemma3Config(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["Gemma3ForConditionalGeneration"]]
     boi_token_index: int
     eoi_token_index: int
@@ -0,0 +1,195 @@
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    AttentionConfig,
+    DecoderConfig,
+    DecoderLayerConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    MixtureOfExpertsConfig,
+    RMSNormConfig,
+    SoftmaxRouting,
+    TiedEmbeddingConfig,
+    UntiedEmbeddingConfig,
+    UpcastMode,
+    YARNRoPEConfig,
+)
+from lalamo.modules.activations import SiLU
+
+from .common import HuggingFaceConfig
+
+__all__ = ["HFGPTOssConfig"]
+
+
+@dataclass(frozen=True)
+class YarnRopeScalingConfig:
+    factor: float
+    beta_fast: float
+    beta_slow: float
+    original_max_position_embeddings: int
+    rope_type: Literal["yarn"]
+    truncate: bool
+
+
+@dataclass(frozen=True)
+class HFGPTOssConfig(HuggingFaceConfig):
+    # Core HF fields
+    architectures: list[Literal["GptOssForCausalLM"]]
+    attention_bias: bool
+    attention_dropout: float
+    eos_token_id: int | list[int]
+    hidden_act: Literal["silu"]
+    hidden_size: int
+    initializer_range: float
+    intermediate_size: int
+    max_position_embeddings: int
+    model_type: Literal["gpt_oss"]
+    num_attention_heads: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    pad_token_id: int
+    rms_norm_eps: float
+    rope_theta: float
+    tie_word_embeddings: bool
+    transformers_version: str
+    use_cache: bool
+    vocab_size: int
+
+    # GPT-OSS specifics
+    layer_types: list[Literal["sliding_attention", "full_attention"]] | None
+    sliding_window: int | None
+    swiglu_limit: float
+    head_dim: int | None
+    num_local_experts: int
+    num_experts_per_tok: int | None = None
+    experts_per_token: int | None = None  # some configs may use this alias
+    rope_scaling: YarnRopeScalingConfig | None = None
+    output_router_logits: bool | None = None
+    router_aux_loss_coef: float | None = None
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> DecoderConfig:
+        # Embedding
+        if self.tie_word_embeddings:
+            embedding_config = TiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+        else:
+            embedding_config = UntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+
+        if self.rope_scaling is not None and self.rope_scaling.rope_type == "yarn":
+            rope_config = YARNRoPEConfig(
+                precision=activation_precision,
+                base=self.rope_theta,
+                max_sequence_length=context_length or self.max_position_embeddings,
+                scaling_factor=self.rope_scaling.factor,
+                original_context_length=self.rope_scaling.original_max_position_embeddings,
+                beta_fast=self.rope_scaling.beta_fast,
+                beta_slow=self.rope_scaling.beta_slow,
+                truncate=self.rope_scaling.truncate,
+            )
+        else:
+            rope_config = YARNRoPEConfig(
+                precision=activation_precision,
+                base=self.rope_theta,
+                max_sequence_length=context_length or self.max_position_embeddings,
+                scaling_factor=1.0,
+                original_context_length=self.max_position_embeddings,
+                beta_fast=32.0,
+                beta_slow=1.0,
+                truncate=True,
+            )
+
+        rmsnorm_config = RMSNormConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.rms_norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.FULL_LAYER,
+        )
+
+        # Linear layers
+        linear_config = FullPrecisionLinearConfig(precision=activation_precision)
+
+        attention_config = AttentionConfig(
+            qkv_projection_config=linear_config,
+            out_projection_config=linear_config,
+            query_norm_config=None,
+            key_norm_config=None,
+            logit_soft_cap=None,
+            has_sinks=True,
+            has_qkv_biases=self.attention_bias,
+            has_out_biases=self.attention_bias,
+        )
+
+        # Experts (MoE) scaffold
+        # Router: linear with bias; Experts: DenseMLP with SiLU(alpha=1.702) and value/gate clipping
+        experts_activation = SiLU(alpha=1.702)
+        experts_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=experts_activation,
+            has_up_biases=True,
+            has_down_biases=True,
+            up_clipping=(-self.swiglu_limit + 1.0, self.swiglu_limit + 1.0),
+            gate_clipping=(None, self.swiglu_limit),
+        )
+        moe_config = MixtureOfExpertsConfig(
+            mixture_size=self.num_local_experts,
+            num_experts_per_token=(self.num_experts_per_tok or self.experts_per_token or 1),
+            routing_function=SoftmaxRouting(),
+            router_config=linear_config,
+            router_has_biases=True,
+            expert_config=experts_config,
+        )
+        decoder_layer_config = DecoderLayerConfig(
+            pre_attention_norm_config=rmsnorm_config,
+            attention_config=attention_config,
+            post_attention_norm_config=None,
+            pre_mlp_norm_config=rmsnorm_config,
+            mlp_config=moe_config,
+            post_mlp_norm_config=None,
+        )
+
+        # Per-layer sliding-window
+        if self.layer_types is not None and len(self.layer_types) == self.num_hidden_layers:
+            sliding_window_sizes = tuple(
+                self.sliding_window if layer_type == "sliding_attention" else None for layer_type in self.layer_types
+            )
+        else:
+            # Fallback: apply the same sliding window to all layers if provided
+            sliding_window_sizes = (
+                tuple([self.sliding_window] * self.num_hidden_layers) if self.sliding_window is not None else None
+            )
+
+        head_dim = self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            global_rope_config=rope_config,
+            local_rope_config=None,
+            layer_config=decoder_layer_config,
+            output_norm_config=rmsnorm_config,
+            vocab_size=self.vocab_size,
+            model_dim=self.hidden_size,
+            hidden_dim=self.intermediate_size,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=head_dim,
+            attention_scale=None,
+            num_layers=self.num_hidden_layers,
+            sliding_window_sizes=sliding_window_sizes,
+            context_length=context_length or self.max_position_embeddings,
+        )
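
Two pieces of logic in the new GPT-OSS config are easier to follow with concrete numbers: the per-layer sliding-window tuple and the expert-count fallback across the two HF field spellings. The values in this self-contained sketch are hypothetical (loosely GPT-OSS-shaped), not taken from a real checkpoint:

# Four layers alternating sliding and full attention, 128-token window.
layer_types = ["sliding_attention", "full_attention", "sliding_attention", "full_attention"]
sliding_window = 128
num_hidden_layers = 4

if layer_types is not None and len(layer_types) == num_hidden_layers:
    sliding_window_sizes = tuple(
        sliding_window if layer_type == "sliding_attention" else None for layer_type in layer_types
    )
else:
    # Same fallback as above: one window size for every layer, or no windows at all.
    sliding_window_sizes = (
        tuple([sliding_window] * num_hidden_layers) if sliding_window is not None else None
    )
print(sliding_window_sizes)  # (128, None, 128, None)

# Expert count: prefer num_experts_per_tok, then the experts_per_token alias, then 1.
num_experts_per_tok, experts_per_token = None, 4
print(num_experts_per_tok or experts_per_token or 1)  # 4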
@@ -4,19 +4,20 @@ from typing import Literal
 from jaxtyping import DTypeLike

 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
     LlamaRoPEConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UpcastMode,
+    YARNRoPEConfig,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.modules.embedding import UntiedEmbeddingConfig
 from lalamo.quantization import QuantizationMode

@@ -34,8 +35,19 @@ class LlamaRopeScalingConfig:
     rope_type: Literal["llama3"]


+@dataclass(frozen=True)
+class YarnRopeScalingConfig:
+    factor: float
+    beta_fast: float
+    beta_slow: float
+    original_max_position_embeddings: int
+    rope_type: Literal["yarn"]
+    truncate: bool
+
+
 @dataclass(frozen=True)
 class HFLlamaConfig(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["LlamaForCausalLM"]]
     attention_bias: bool
     attention_dropout: float
@@ -53,7 +65,7 @@ class HFLlamaConfig(HuggingFaceConfig):
     num_key_value_heads: int
     pretraining_tp: int
     rms_norm_eps: float
-    rope_scaling: LlamaRopeScalingConfig | None
+    rope_scaling: LlamaRopeScalingConfig | YarnRopeScalingConfig | None
     rope_theta: float
     tie_word_embeddings: bool
     transformers_version: str
@@ -72,13 +84,13 @@ class HFLlamaConfig(HuggingFaceConfig):
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
-                logits_soft_cap=None,
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         else:
             embedding_config = UntiedEmbeddingConfig(
                 input_scale=None,
-                logits_soft_cap=None,
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         if self.rope_scaling is None:
@@ -87,7 +99,18 @@ class HFLlamaConfig(HuggingFaceConfig):
                 base=self.rope_theta,
                 max_sequence_length=context_length or self.max_position_embeddings,
             )
-        else:
+        elif isinstance(self.rope_scaling, YarnRopeScalingConfig):
+            rope_config = YARNRoPEConfig(
+                precision=activation_precision,
+                base=self.rope_theta,
+                max_sequence_length=context_length or self.max_position_embeddings,
+                scaling_factor=self.rope_scaling.factor,
+                original_context_length=self.rope_scaling.original_max_position_embeddings,
+                beta_fast=self.rope_scaling.beta_fast,
+                beta_slow=self.rope_scaling.beta_slow,
+                truncate=self.rope_scaling.truncate,
+            )
+        elif isinstance(self.rope_scaling, LlamaRopeScalingConfig):
             rope_config = LlamaRoPEConfig(
                 precision=activation_precision,
                 base=self.rope_theta,
@@ -97,6 +120,8 @@ class HFLlamaConfig(HuggingFaceConfig):
                 low_frequency_factor=self.rope_scaling.low_freq_factor,
                 high_frequency_factor=self.rope_scaling.high_freq_factor,
             )
+        else:
+            raise ValueError("Unsupported rope_scaling configuration")
         rmsnorm_config = RMSNormConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
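
The new elif branch accepts a "yarn"-style rope_scaling block from a Hugging Face config.json. A short illustration of the field mapping; the numeric values are made up, and YarnRopeScalingConfig is the dataclass added in this file:

rope_scaling_json = {
    "rope_type": "yarn",
    "factor": 32.0,                            # -> scaling_factor
    "original_max_position_embeddings": 4096,  # -> original_context_length
    "beta_fast": 32.0,                         # -> beta_fast
    "beta_slow": 1.0,                          # -> beta_slow
    "truncate": False,                         # -> truncate
}
rope_scaling = YarnRopeScalingConfig(**rope_scaling_json)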
@@ -121,12 +146,17 @@ class HFLlamaConfig(HuggingFaceConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=False,
         )
-        mlp_config = MLPConfig(
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=Activation.SILU,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,
@@ -4,17 +4,17 @@ from typing import Literal
 from jaxtyping import DTypeLike

 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.modules.normalization import UpcastMode

 from .common import HuggingFaceConfig
@@ -24,6 +24,7 @@ __all__ = ["HFMistralConfig"]

 @dataclass(frozen=True)
 class HFMistralConfig(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["MistralForCausalLM"]]
     attention_dropout: float
     bos_token_id: int
@@ -57,13 +58,13 @@ class HFMistralConfig(HuggingFaceConfig):
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
-                logits_soft_cap=None,
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         else:
             embedding_config = UntiedEmbeddingConfig(
                 input_scale=None,
-                logits_soft_cap=None,
+                logit_soft_cap=None,
                 precision=activation_precision,
             )

@@ -91,13 +92,18 @@ class HFMistralConfig(HuggingFaceConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=False,
             has_out_biases=False,
         )

-        mlp_config = MLPConfig(
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=Activation.SILU,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )

         decoder_layer_config = DecoderLayerConfig(
@@ -4,19 +4,19 @@ from typing import Literal
 from jaxtyping import DTypeLike

 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.quantization import QuantizationMode

 from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
@@ -26,6 +26,7 @@ __all__ = ["HFQwen2Config"]

 @dataclass(frozen=True)
 class HFQwen2Config(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["Qwen2ForCausalLM"]]
     attention_dropout: float
     bos_token_id: int | list[int]
@@ -72,13 +73,13 @@ class HFQwen2Config(HuggingFaceConfig):
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
-                logits_soft_cap=None,
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         else:
             embedding_config = UntiedEmbeddingConfig(
                 input_scale=None,
-                logits_soft_cap=None,
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         rope_config = UnscaledRoPEConfig(
@@ -110,12 +111,17 @@ class HFQwen2Config(HuggingFaceConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=True,
             has_out_biases=False,
         )
-        mlp_config = MLPConfig(
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=Activation.SILU,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,
@@ -4,19 +4,19 @@ from typing import Literal
 from jaxtyping import DTypeLike

 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.quantization import QuantizationMode

 from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
@@ -26,6 +26,7 @@ __all__ = ["HFQwen3Config"]

 @dataclass(frozen=True)
 class HFQwen3Config(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     attention_bias: bool
     hidden_act: Literal["silu"]
     hidden_size: int
@@ -70,13 +71,13 @@ class HFQwen3Config(HuggingFaceConfig):
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
-                logits_soft_cap=None,
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         else:
             embedding_config = UntiedEmbeddingConfig(
                 input_scale=None,
-                logits_soft_cap=None,
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         rope_config = UnscaledRoPEConfig(
@@ -108,12 +109,17 @@ class HFQwen3Config(HuggingFaceConfig):
             query_norm_config=rmsnorm_config,
             key_norm_config=rmsnorm_config,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=self.attention_bias,
         )
-        mlp_config = MLPConfig(
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=Activation.SILU,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,
@@ -72,10 +72,7 @@ class HFTokenizerConfig:
     def added_tokens(self) -> list[AddedToken]:
         if self.added_tokens_decoder is None:
             return []
-        return [
-            AddedToken(content=token.content, single_word=token.single_word, normalized=token.normalized)
-            for token in self.added_tokens_decoder.values()
-        ]
+        return [token.to_added_token() for token in self.added_tokens_decoder.values()]

     @classmethod
     def from_json(cls, json_path: Path | str) -> "HFTokenizerConfig":
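
The removed comprehension hints at what the new to_added_token() helper on the added_tokens_decoder entries does. Its implementation is not part of this hunk, so the sketch below is a presumption based on the fields the old code passed through (HFAddedTokenEntry is a hypothetical name for the entry type):

from dataclasses import dataclass

from tokenizers import AddedToken  # assumption: the same AddedToken the removed comprehension built


@dataclass(frozen=True)
class HFAddedTokenEntry:  # hypothetical stand-in for the added_tokens_decoder value type
    content: str
    single_word: bool
    normalized: bool

    def to_added_token(self) -> AddedToken:
        # Mirrors the fields the removed comprehension constructed explicitly.
        return AddedToken(content=self.content, single_word=self.single_word, normalized=self.normalized)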
@@ -1,4 +1,4 @@
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable, Iterator, Mapping
 from dataclasses import dataclass, replace

 import jax.numpy as jnp
@@ -6,7 +6,7 @@ from einops import rearrange
 from jaxtyping import Array, Float, Int

 from lalamo.common import ParameterPath
-from lalamo.modules import MLP, Attention, Decoder, DecoderLayer, QLoRALinear, QuantizedTiedEmbedding, RMSNorm
+from lalamo.modules import Attention, Decoder, DecoderLayer, DenseMLP, QLoRALinear, QuantizedTiedEmbedding, RMSNorm

 from .common import load_parameters

@@ -43,7 +43,7 @@ def params_selector(module: QLoRALinear) -> tuple:


 def get_qlora_linear_params(
-    weights_dict: dict[str, Array],
+    weights_dict: Mapping[str, Array],
     path: ParameterPath,
     weights_dtype: jnp.dtype,
 ) -> QLoRALinearParams:
@@ -76,7 +76,7 @@ def load_linear(module: QLoRALinear, weights_dict: dict[str, Array], path: Param
     return load_parameters(params_selector, module, params)


-def load_mlp(module: MLP, weights_dict: dict[str, Array], path: ParameterPath) -> MLP:
+def load_mlp(module: DenseMLP, weights_dict: Mapping[str, Array], path: ParameterPath) -> DenseMLP:
     if not isinstance(module.up_projection, QLoRALinear):
         raise TypeError(f"Expected up_projection to be QLoRALinear, got {type(module.up_projection)}")
     if not isinstance(module.down_projection, QLoRALinear):
@@ -95,7 +95,7 @@ def load_mlp(module: MLP, weights_dict: dict[str, Array], path: ParameterPath) -
     )


-def load_rmsnorm(module: RMSNorm, weights_dict: dict[str, Array], path: ParameterPath) -> RMSNorm:
+def load_rmsnorm(module: RMSNorm, weights_dict: Mapping[str, Array], path: ParameterPath) -> RMSNorm:
     return load_parameters(lambda m: (m.scales,), module, (weights_dict[path / "weight"],))


@@ -131,7 +131,7 @@ def permute_qk_params(


 def load_attention(
     module: Attention,
-    weights_dict: dict[str, Array],
+    weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> Attention:
@@ -177,7 +177,7 @@ def load_attention(


 def load_decoder_layer(
     module: DecoderLayer,
-    weights_dict: dict[str, Array],
+    weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> DecoderLayer:
@@ -187,6 +187,7 @@ def load_decoder_layer(
     attention_norm = load_rmsnorm(module.pre_attention_norm, weights_dict, path / "attention_norm")
     attention = load_attention(module.attention, weights_dict, path / "attention")
     mlp_norm = load_rmsnorm(module.pre_mlp_norm, weights_dict, path / "ffn_norm")
+    assert isinstance(module.mlp, DenseMLP)
     mlp = load_mlp(module.mlp, weights_dict, path / "feed_forward")
     return load_parameters(
         lambda m: (m.pre_attention_norm, m.attention, m.pre_mlp_norm, m.mlp),
@@ -197,7 +198,7 @@ def load_decoder_layer(

 def load_embedding(
     module: QuantizedTiedEmbedding,
-    weights_dict: dict[str, Array],
+    weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> QuantizedTiedEmbedding:
     weights = weights_dict[path / "weight"].astype(module.weights.dtype)
@@ -206,7 +207,7 @@ def load_embedding(
     return load_parameters(lambda m: (m.weights, m.scales), module, (weights, scales))


-def load_executorch(module: Decoder, weights_dict: dict[str, Array]) -> Decoder:
+def load_executorch(module: Decoder, weights_dict: Mapping[str, Array]) -> Decoder:
     root_path = ParameterPath()
     if not isinstance(module.embedding, QuantizedTiedEmbedding):
         raise TypeError(f"Expected embedding to be QuantizedTiedEmbedding, got {type(module.embedding)}")