lalamo 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. lalamo/__init__.py +3 -2
  2. lalamo/data/__init__.py +0 -1
  3. lalamo/data/huggingface_message.py +1 -0
  4. lalamo/main.py +167 -18
  5. lalamo/message_processor.py +2 -3
  6. lalamo/model_import/common.py +120 -27
  7. lalamo/model_import/decoder_configs/__init__.py +4 -2
  8. lalamo/model_import/decoder_configs/common.py +62 -21
  9. lalamo/model_import/decoder_configs/executorch.py +14 -9
  10. lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
  11. lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
  12. lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
  13. lalamo/model_import/decoder_configs/huggingface/gemma3.py +21 -17
  14. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
  15. lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
  16. lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
  17. lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
  18. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
  21. lalamo/model_import/loaders/__init__.py +3 -2
  22. lalamo/model_import/loaders/executorch.py +24 -12
  23. lalamo/model_import/loaders/huggingface.py +258 -30
  24. lalamo/model_import/model_specs/__init__.py +4 -2
  25. lalamo/model_import/model_specs/common.py +8 -2
  26. lalamo/model_import/model_specs/gemma.py +5 -1
  27. lalamo/model_import/model_specs/huggingface.py +1 -1
  28. lalamo/model_import/model_specs/mirai.py +20 -0
  29. lalamo/models/__init__.py +10 -0
  30. lalamo/models/common.py +81 -0
  31. lalamo/{language_model.py → models/language_model.py} +32 -49
  32. lalamo/models/router.py +59 -0
  33. lalamo/modules/__init__.py +33 -16
  34. lalamo/modules/classifier.py +339 -0
  35. lalamo/modules/common.py +6 -3
  36. lalamo/modules/decoder.py +52 -180
  37. lalamo/modules/mlp.py +28 -5
  38. lalamo/modules/normalization.py +13 -8
  39. lalamo/modules/token_mixers/attention.py +10 -6
  40. lalamo/modules/token_mixers/state/kv_cache.py +14 -4
  41. lalamo/modules/transformer.py +273 -0
  42. lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
  43. lalamo/speculator/__init__.py +2 -0
  44. lalamo/speculator/estimator.py +91 -0
  45. lalamo/speculator/inference.py +28 -9
  46. lalamo/speculator/ngram.py +7 -3
  47. lalamo/speculator/utils.py +4 -2
  48. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/METADATA +1 -1
  49. lalamo-0.5.3.dist-info/RECORD +88 -0
  50. lalamo-0.5.1.dist-info/RECORD +0 -80
  51. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/WHEEL +0 -0
  52. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/entry_points.txt +0 -0
  53. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/licenses/LICENSE +0 -0
  54. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/top_level.txt +0 -0
lalamo/model_import/decoder_configs/executorch.py

@@ -9,12 +9,13 @@ from lalamo.modules import (
  AttentionConfig,
  Decoder,
  DecoderConfig,
- DecoderLayerConfig,
  DenseMLPConfig,
  LlamaRoPEConfig,
+ NormalizationConfig,
  QLoRALinearConfig,
  QuantizedTiedEmbeddingConfig,
- RMSNormConfig,
+ TransformerConfig,
+ TransformerLayerConfig,
  UpcastMode,
  )
  from lalamo.modules.activations import SiLU
@@ -62,7 +63,7 @@ class ExecutorchConfig(ForeignConfig):
  return jnp.bfloat16

  @classmethod
- def _load_weights(
+ def _load_decoder_weights(
  cls,
  model: Decoder,
  weights_dict: Mapping[str, Array],
@@ -119,12 +120,13 @@ class ETLlamaConfig(ExecutorchConfig):
  low_frequency_factor=LOW_FREQ_FACTOR,
  high_frequency_factor=HIGH_FREQ_FACTOR,
  )
- rmsnorm_config = RMSNormConfig(
+ rmsnorm_config = NormalizationConfig(
  scale_precision=activation_precision,
  accumulation_precision=accumulation_precision,
  epsilon=self.norm_eps,
  scale_offset=None,
  upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+ subtract_mean=False,
  )
  linear_config = QLoRALinearConfig(
  group_size=self.quantization_args.group_size,
@@ -158,7 +160,7 @@ class ETLlamaConfig(ExecutorchConfig):
  up_clipping=None,
  gate_clipping=None,
  )
- decoder_layer_config = DecoderLayerConfig(
+ tranformer_layer_config = TransformerLayerConfig(
  pre_mixer_norm_config=rmsnorm_config,
  mixer_config=attention_config,
  post_mixer_norm_config=None,
@@ -166,14 +168,17 @@ class ETLlamaConfig(ExecutorchConfig):
  mlp_config=mlp_config,
  post_mlp_norm_config=None,
  )
- return DecoderConfig(
- embedding_config=embedding_config,
+ transformer_config = TransformerConfig(
  global_rope_config=rope_config,
  local_rope_config=None,
- layer_configs=(decoder_layer_config,) * self.n_layers,
+ layer_configs=(tranformer_layer_config,) * self.n_layers,
  output_norm_config=rmsnorm_config,
- vocab_size=self.vocab_size,
  model_dim=self.dim,
  hidden_dim=self._find_hidden_size(),
  context_length=context_length or MAX_SEQUENCE_LENGTH,
  )
+ return DecoderConfig(
+ embedding_config=embedding_config,
+ transformer_config=transformer_config,
+ vocab_size=self.vocab_size,
+ )
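Across executorch.py and the HuggingFace configs below, the same refactor repeats: RMSNormConfig becomes NormalizationConfig (gaining a subtract_mean field, set to False to keep RMS behaviour), DecoderLayerConfig becomes TransformerLayerConfig, and DecoderConfig no longer takes the RoPE, layer, norm, and dimension arguments directly; those now live on a TransformerConfig, which DecoderConfig wraps together with the embedding config and vocabulary size. The sketch below restates that nesting using only the keyword arguments visible in these hunks; attention_config, mlp_config, rope_config, embedding_config, and the scalar values are placeholders, and any field not shown in the diff is deliberately omitted, so treat this as an illustration rather than the package's full API.

    from lalamo.modules import (
        DecoderConfig,
        NormalizationConfig,
        TransformerConfig,
        TransformerLayerConfig,
        UpcastMode,
    )

    # Normalization: the former RMSNormConfig arguments plus the new subtract_mean flag.
    norm_config = NormalizationConfig(
        scale_precision=activation_precision,
        accumulation_precision=accumulation_precision,
        epsilon=1e-5,
        scale_offset=None,
        upcast_mode=UpcastMode.ONLY_NORMALIZATION,
        subtract_mean=False,  # False preserves the old RMSNorm behaviour
    )

    # Per-layer wiring, formerly DecoderLayerConfig.
    layer_config = TransformerLayerConfig(
        pre_mixer_norm_config=norm_config,
        mixer_config=attention_config,  # placeholder: built as before
        post_mixer_norm_config=None,
        mlp_config=mlp_config,          # placeholder: built as before
        post_mlp_norm_config=None,
    )

    # New intermediate level: the transformer trunk.
    transformer_config = TransformerConfig(
        global_rope_config=rope_config,  # placeholder
        local_rope_config=None,
        layer_configs=(layer_config,) * num_layers,
        output_norm_config=norm_config,
        model_dim=model_dim,
        hidden_dim=hidden_dim,
        context_length=context_length,
    )

    # DecoderConfig now only wraps the embeddings, the trunk, and the vocab size.
    decoder_config = DecoderConfig(
        embedding_config=embedding_config,  # placeholder
        transformer_config=transformer_config,
        vocab_size=vocab_size,
    )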
lalamo/model_import/decoder_configs/huggingface/__init__.py

@@ -1,10 +1,11 @@
- from .common import HuggingFaceConfig
+ from .common import HuggingFaceLMConfig
  from .gemma2 import HFGemma2Config
  from .gemma3 import HFGemma3Config, HFGemma3TextConfig
  from .gpt_oss import HFGPTOssConfig
  from .llama import HFLlamaConfig
  from .llamba import HFLlambaConfig
  from .mistral import HFMistralConfig
+ from .modern_bert import ModernBERTConfig
  from .qwen2 import HFQwen2Config
  from .qwen3 import HFQwen3Config

@@ -18,5 +19,6 @@ __all__ = [
  "HFMistralConfig",
  "HFQwen2Config",
  "HFQwen3Config",
- "HuggingFaceConfig",
+ "HuggingFaceLMConfig",
+ "ModernBERTConfig",
  ]
lalamo/model_import/decoder_configs/huggingface/common.py

@@ -6,15 +6,22 @@ import cattrs
  import jax.numpy as jnp
  from jaxtyping import Array, DTypeLike

- from lalamo.model_import.decoder_configs import ForeignConfig
- from lalamo.model_import.loaders import load_huggingface
+ from lalamo.model_import.decoder_configs import ForeignLMConfig
+ from lalamo.model_import.decoder_configs.common import ForeignClassifierConfig
+ from lalamo.model_import.loaders import (
+ load_huggingface_classifier,
+ load_huggingface_decoder,
+ )
  from lalamo.modules import Decoder
+ from lalamo.modules.classifier import Classifier
+ from lalamo.modules.common import LalamoModule

  __all__ = [
  "AWQQuantizationConfig",
  "GPTQMetaConfig",
  "GPTQQuantizationConfig",
- "HuggingFaceConfig",
+ "HuggingFaceClassifierConfig",
+ "HuggingFaceLMConfig",
  ]


@@ -85,26 +92,45 @@ def _structure_quantization_config(v: object, _: object) -> QuantizationConfigType


  @dataclass(frozen=True)
- class HuggingFaceConfig(ForeignConfig):
+ class HuggingFaceLMConfig(ForeignLMConfig):
  _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
  _converter.register_structure_hook(int | list[int], lambda v, _: v)
  _converter.register_structure_hook(QuantizationConfigType, _structure_quantization_config)

  @property
  def eos_token_ids(self) -> list[int]:
- if not hasattr(self, "eos_token_id"):
- raise RuntimeError("model doesn't havve eos_token_id, override eos_token_ids in model config")
+ result = getattr(self, "eos_token_id", None)
+ if result is None:
+ raise RuntimeError("model doesn't have eos_token_id, override eos_token_ids in model config")

- return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id # type: ignore (This is a bug in pyright)
+ if isinstance(result, int):
+ result = [result]

+ return result
+
+ @property
+ def default_precision(self) -> DTypeLike:
+ return jnp.dtype(getattr(self, "torch_dtype", "bfloat16"))
+
+ def _load_weights(
+ self,
+ model: LalamoModule,
+ weights_dict: Mapping[str, Array],
+ ) -> LalamoModule:
+ assert isinstance(model, Decoder)
+ return load_huggingface_decoder(model, weights_dict)
+
+
+ @dataclass(frozen=True)
+ class HuggingFaceClassifierConfig(ForeignClassifierConfig):
  @property
  def default_precision(self) -> DTypeLike:
  return jnp.dtype(getattr(self, "torch_dtype", "bfloat16"))

- @classmethod
  def _load_weights(
- cls,
- model: Decoder,
+ self,
+ model: LalamoModule,
  weights_dict: Mapping[str, Array],
- ) -> Decoder:
- return load_huggingface(model, weights_dict)
+ ) -> LalamoModule:
+ assert isinstance(model, Classifier)
+ return load_huggingface_classifier(model, weights_dict)
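Two changes in common.py stand out: the former HuggingFaceConfig splits into HuggingFaceLMConfig and a new HuggingFaceClassifierConfig, each with an instance-level _load_weights that narrows the generic LalamoModule to Decoder or Classifier before delegating to load_huggingface_decoder or load_huggingface_classifier; and eos_token_ids now reads eos_token_id via getattr and normalises a bare int to a list instead of relying on hasattr. A standalone sketch of that normalisation logic (the helper name is illustrative; in the package this is the eos_token_ids property shown above):

    def normalize_eos_token_ids(eos_token_id: int | list[int] | None) -> list[int]:
        # Mirrors HuggingFaceLMConfig.eos_token_ids: missing -> error, int -> [int], list -> list.
        if eos_token_id is None:
            raise RuntimeError("model doesn't have eos_token_id, override eos_token_ids in model config")
        if isinstance(eos_token_id, int):
            return [eos_token_id]
        return eos_token_id

    assert normalize_eos_token_ids(2) == [2]
    assert normalize_eos_token_ids([1, 106]) == [1, 106]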
lalamo/model_import/decoder_configs/huggingface/gemma2.py

@@ -7,23 +7,24 @@ from jaxtyping import DTypeLike
  from lalamo.modules import (
  AttentionConfig,
  DecoderConfig,
- DecoderLayerConfig,
  DenseMLPConfig,
  FullPrecisionLinearConfig,
- RMSNormConfig,
+ NormalizationConfig,
  TiedEmbeddingConfig,
+ TransformerConfig,
+ TransformerLayerConfig,
  UnscaledRoPEConfig,
  UpcastMode,
  )
  from lalamo.modules.activations import GELU

- from .common import HuggingFaceConfig
+ from .common import HuggingFaceLMConfig

  __all__ = ["HFGemma2Config"]


  @dataclass(frozen=True)
- class HFGemma2Config(HuggingFaceConfig):
+ class HFGemma2Config(HuggingFaceLMConfig):
  architectures: list[Literal["Gemma2ForCausalLM"]]
  attention_bias: bool
  attention_dropout: float
@@ -72,12 +73,13 @@ class HFGemma2Config(HuggingFaceConfig):
  base=self.rope_theta,
  max_sequence_length=self.max_position_embeddings,
  )
- rmsnorm_config = RMSNormConfig(
+ rmsnorm_config = NormalizationConfig(
  scale_precision=activation_precision,
  accumulation_precision=accumulation_precision,
  epsilon=self.rms_norm_eps,
  scale_offset=1.0,
  upcast_mode=UpcastMode.FULL_LAYER,
+ subtract_mean=False,
  )
  linear_config = FullPrecisionLinearConfig(
  precision=activation_precision,
@@ -110,7 +112,7 @@ class HFGemma2Config(HuggingFaceConfig):
  scale=attention_scale,
  sliding_window_size=sliding_window_size,
  )
- decoder_layer_config = DecoderLayerConfig(
+ transformer_layer_config = TransformerLayerConfig(
  pre_mixer_norm_config=rmsnorm_config,
  mixer_config=attention_config,
  post_mixer_norm_config=rmsnorm_config,
@@ -118,16 +120,19 @@ class HFGemma2Config(HuggingFaceConfig):
  mlp_config=mlp_config,
  post_mlp_norm_config=rmsnorm_config,
  )
- layer_configs.append(decoder_layer_config)
+ layer_configs.append(transformer_layer_config)

- return DecoderConfig(
- embedding_config=embedding_config,
+ transformer_config = TransformerConfig(
  global_rope_config=rope_config,
  local_rope_config=None,
  layer_configs=tuple(layer_configs),
  output_norm_config=rmsnorm_config,
- vocab_size=self.vocab_size,
  model_dim=self.hidden_size,
  hidden_dim=self.intermediate_size,
  context_length=context_length or self.max_position_embeddings,
  )
+ return DecoderConfig(
+ embedding_config=embedding_config,
+ transformer_config=transformer_config,
+ vocab_size=self.vocab_size,
+ )
lalamo/model_import/decoder_configs/huggingface/gemma3.py

@@ -1,23 +1,20 @@
  from collections.abc import Mapping
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from typing import Literal

  import jax.numpy as jnp
  from jaxtyping import DTypeLike

- from lalamo.modules import (
- DecoderConfig,
- TiedEmbeddingConfig,
- )
+ from lalamo.modules import DecoderConfig, TiedEmbeddingConfig, TransformerConfig
  from lalamo.modules.activations import GELU
- from lalamo.modules.decoder_layer import DecoderLayerConfig
  from lalamo.modules.linear import FullPrecisionLinearConfig
  from lalamo.modules.mlp import DenseMLPConfig
- from lalamo.modules.normalization import RMSNormConfig, UpcastMode
+ from lalamo.modules.normalization import NormalizationConfig, UpcastMode
  from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
- from lalamo.modules.token_mixers import AttentionConfig
+ from lalamo.modules.token_mixers.attention import AttentionConfig
+ from lalamo.modules.transformer_layer import TransformerLayerConfig

- from .common import HuggingFaceConfig
+ from .common import HuggingFaceLMConfig

  __all__ = ["HFGemma3Config", "HFGemma3TextConfig"]

@@ -80,12 +77,13 @@ class HFGemma3TextConfigRaw:
  logit_soft_cap=None,
  precision=activation_precision,
  )
- rms_norm_config = RMSNormConfig(
+ rms_norm_config = NormalizationConfig(
  scale_precision=activation_precision,
  accumulation_precision=accumulation_precision,
  epsilon=self.rms_norm_eps,
  scale_offset=1.0,
  upcast_mode=UpcastMode.FULL_LAYER,
+ subtract_mean=False,
  )

  if self.rope_scaling is not None:
@@ -134,7 +132,7 @@ class HFGemma3TextConfigRaw:
  scale=attention_scale,
  sliding_window_size=sliding_window_size,
  )
- decoder_layer_config = DecoderLayerConfig(
+ transformer_layer_config = TransformerLayerConfig(
  pre_mixer_norm_config=rms_norm_config,
  mixer_config=attention_config,
  post_mixer_norm_config=rms_norm_config,
@@ -142,23 +140,29 @@ class HFGemma3TextConfigRaw:
  mlp_config=mlp_config,
  post_mlp_norm_config=rms_norm_config,
  )
- layer_configs.append(decoder_layer_config)
- return DecoderConfig(
- embedding_config=embedding_config,
+ layer_configs.append(transformer_layer_config)
+
+ transformer_config = TransformerConfig(
  global_rope_config=global_rope_config,
  local_rope_config=local_rope_config,
  layer_configs=tuple(layer_configs),
  output_norm_config=rms_norm_config,
- vocab_size=self.vocab_size,
  model_dim=self.hidden_size,
  hidden_dim=self.intermediate_size,
  context_length=context_length or self.max_position_embeddings,
  )

+ return DecoderConfig(
+ embedding_config=embedding_config,
+ transformer_config=transformer_config,
+ vocab_size=self.vocab_size,
+ )
+

  @dataclass(frozen=True)
- class HFGemma3TextConfig(HFGemma3TextConfigRaw, HuggingFaceConfig):
+ class HFGemma3TextConfig(HFGemma3TextConfigRaw, HuggingFaceLMConfig):
  torch_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+ eos_token_id: int | list[int] = field(default_factory=list)



@@ -174,7 +178,7 @@ class HFGemma3VisionConfig:


  @dataclass(frozen=True)
- class HFGemma3Config(HuggingFaceConfig):
+ class HFGemma3Config(HuggingFaceLMConfig):
  torch_dtype: Literal["bfloat16", "float16", "float32"]
  architectures: list[Literal["Gemma3ForConditionalGeneration"]]
  boi_token_index: int
lalamo/model_import/decoder_configs/huggingface/gpt_oss.py

@@ -7,20 +7,21 @@ from jaxtyping import DTypeLike
  from lalamo.modules import (
  AttentionConfig,
  DecoderConfig,
- DecoderLayerConfig,
  DenseMLPConfig,
  FullPrecisionLinearConfig,
  MixtureOfExpertsConfig,
- RMSNormConfig,
+ NormalizationConfig,
  SoftmaxRouting,
  TiedEmbeddingConfig,
+ TransformerConfig,
+ TransformerLayerConfig,
  UntiedEmbeddingConfig,
  UpcastMode,
  YARNRoPEConfig,
  )
  from lalamo.modules.activations import SiLU

- from .common import HuggingFaceConfig
+ from .common import HuggingFaceLMConfig

  __all__ = ["HFGPTOssConfig"]

@@ -36,7 +37,7 @@ class YarnRopeScalingConfig:


  @dataclass(frozen=True)
- class HFGPTOssConfig(HuggingFaceConfig):
+ class HFGPTOssConfig(HuggingFaceLMConfig):
  # Core HF fields
  architectures: list[Literal["GptOssForCausalLM"]]
  attention_bias: bool
@@ -115,12 +116,13 @@ class HFGPTOssConfig(HuggingFaceConfig):
  truncate=True,
  )

- rmsnorm_config = RMSNormConfig(
+ rmsnorm_config = NormalizationConfig(
  scale_precision=activation_precision,
  accumulation_precision=accumulation_precision,
  epsilon=self.rms_norm_eps,
  scale_offset=None,
  upcast_mode=UpcastMode.FULL_LAYER,
+ subtract_mean=False,
  )

  # Linear layers
@@ -179,7 +181,7 @@ class HFGPTOssConfig(HuggingFaceConfig):
  scale=None,
  sliding_window_size=sliding_window_size,
  )
- decoder_layer_config = DecoderLayerConfig(
+ transformer_layer_config = TransformerLayerConfig(
  pre_mixer_norm_config=rmsnorm_config,
  mixer_config=attention_config,
  post_mixer_norm_config=None,
@@ -187,16 +189,20 @@ class HFGPTOssConfig(HuggingFaceConfig):
  mlp_config=moe_config,
  post_mlp_norm_config=None,
  )
- layer_configs.append(decoder_layer_config)
+ layer_configs.append(transformer_layer_config)

- return DecoderConfig(
- embedding_config=embedding_config,
+ transformer_config = TransformerConfig(
  global_rope_config=rope_config,
  local_rope_config=None,
  layer_configs=tuple(layer_configs),
  output_norm_config=rmsnorm_config,
- vocab_size=self.vocab_size,
  model_dim=self.hidden_size,
  hidden_dim=self.intermediate_size,
  context_length=context_length or self.max_position_embeddings,
  )
+
+ return DecoderConfig(
+ embedding_config=embedding_config,
+ transformer_config=transformer_config,
+ vocab_size=self.vocab_size,
+ )
lalamo/model_import/decoder_configs/huggingface/llama.py

@@ -7,14 +7,15 @@ from jaxtyping import DTypeLike
  from lalamo.modules import (
  AttentionConfig,
  DecoderConfig,
- DecoderLayerConfig,
  DenseMLPConfig,
  FullPrecisionLinearConfig,
  GroupQuantizedLinearConfig,
  LlamaRoPEConfig,
- RMSNormConfig,
+ NormalizationConfig,
  SiLU,
  TiedEmbeddingConfig,
+ TransformerConfig,
+ TransformerLayerConfig,
  UnscaledRoPEConfig,
  UntiedEmbeddingConfig,
  UpcastMode,
@@ -22,7 +23,7 @@ from lalamo.modules import (
  )
  from lalamo.quantization import QuantizationMode

- from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
+ from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceLMConfig

  __all__ = ["HFLlamaConfig"]

@@ -47,7 +48,7 @@ class YarnRopeScalingConfig:


  @dataclass(frozen=True)
- class HFLlamaConfig(HuggingFaceConfig):
+ class HFLlamaConfig(HuggingFaceLMConfig):
  torch_dtype: Literal["bfloat16", "float16", "float32"]
  architectures: list[Literal["LlamaForCausalLM"]]
  attention_bias: bool
@@ -124,12 +125,13 @@ class HFLlamaConfig(HuggingFaceConfig):
  )
  else:
  raise ValueError("Unsupported rope_scaling configuration")
- rmsnorm_config = RMSNormConfig(
+ rmsnorm_config = NormalizationConfig(
  scale_precision=activation_precision,
  accumulation_precision=accumulation_precision,
  epsilon=self.rms_norm_eps,
  scale_offset=None,
  upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+ subtract_mean=False,
  )
  if self.quantization_config is None:
  linear_config = FullPrecisionLinearConfig(
@@ -153,7 +155,7 @@ class HFLlamaConfig(HuggingFaceConfig):
  has_out_biases=False,
  num_heads=self.num_attention_heads,
  num_groups=self.num_key_value_heads,
- head_dim=self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads,
+ head_dim=(self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads),
  is_causal=True,
  scale=None,
  sliding_window_size=None,
@@ -166,7 +168,7 @@ class HFLlamaConfig(HuggingFaceConfig):
  up_clipping=None,
  gate_clipping=None,
  )
- decoder_layer_config = DecoderLayerConfig(
+ transformer_layer_config = TransformerLayerConfig(
  pre_mixer_norm_config=rmsnorm_config,
  mixer_config=attention_config,
  post_mixer_norm_config=None,
@@ -174,14 +176,17 @@ class HFLlamaConfig(HuggingFaceConfig):
  mlp_config=mlp_config,
  post_mlp_norm_config=None,
  )
- return DecoderConfig(
- embedding_config=embedding_config,
+ transformer_config = TransformerConfig(
  global_rope_config=rope_config,
  local_rope_config=None,
- layer_configs=(decoder_layer_config,) * self.num_hidden_layers,
+ layer_configs=(transformer_layer_config,) * self.num_hidden_layers,
  output_norm_config=rmsnorm_config,
- vocab_size=self.vocab_size,
  model_dim=self.hidden_size,
  hidden_dim=self.intermediate_size,
  context_length=context_length or self.max_position_embeddings,
  )
+ return DecoderConfig(
+ embedding_config=embedding_config,
+ transformer_config=transformer_config,
+ vocab_size=self.vocab_size,
+ )
lalamo/model_import/decoder_configs/huggingface/llamba.py

@@ -6,23 +6,24 @@ from jaxtyping import DTypeLike

  from lalamo.modules import (
  DecoderConfig,
- DecoderLayerConfig,
  DenseMLPConfig,
  FullPrecisionLinearConfig,
  Identity,
  Mamba2Config,
  MLXQuantizedLinearConfig,
  MLXSemiQuantizedUntiedEmbeddingConfig,
- RMSNormConfig,
+ NormalizationConfig,
  SeparableCausalConvConfig,
  SiLU,
  TiedEmbeddingConfig,
+ TransformerConfig,
+ TransformerLayerConfig,
  UntiedEmbeddingConfig,
  UpcastMode,
  )
  from lalamo.quantization import QuantizationMode

- from .common import HuggingFaceConfig
+ from .common import HuggingFaceLMConfig


  @dataclass(frozen=True)
@@ -45,7 +46,7 @@ class HFLlambaSsmConfig:


  @dataclass(frozen=True)
- class HFLlambaConfig(HuggingFaceConfig):
+ class HFLlambaConfig(HuggingFaceLMConfig):
  model_type: Literal["llamba"]
  vocab_size: int
  tie_embeddings: bool
@@ -74,7 +75,9 @@ class HFLlambaConfig(HuggingFaceConfig):
  input_scale=None,
  logit_soft_cap=None,
  group_size=int(metadata_dict["quantization_kwargs.group_size"]),
- embedding_quantization_mode=QuantizationMode.from_num_bits(int(metadata_dict["quantization_kwargs.bits"])),
+ embedding_quantization_mode=QuantizationMode.from_num_bits(
+ int(metadata_dict["quantization_kwargs.bits"])
+ ),
  activation_quantization_mode=None,
  activation_precision=activation_precision,
  )
@@ -91,18 +94,21 @@ class HFLlambaConfig(HuggingFaceConfig):
  precision=activation_precision,
  )

- rmsnorm_config = RMSNormConfig(
+ rmsnorm_config = NormalizationConfig(
  scale_precision=activation_precision,
  accumulation_precision=accumulation_precision,
  epsilon=self.norm_epsilon,
  scale_offset=None,
  upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+ subtract_mean=False,
  )

- if "quantization_kwargs.group_size" in metadata_dict:
+ if metadata_dict and "quantization_kwargs.group_size" in metadata_dict:
  linear_config = MLXQuantizedLinearConfig(
  group_size=int(metadata_dict["quantization_kwargs.group_size"]),
- weight_quantization_mode=QuantizationMode.from_num_bits(int(metadata_dict["quantization_kwargs.bits"])),
+ weight_quantization_mode=QuantizationMode.from_num_bits(
+ int(metadata_dict["quantization_kwargs.bits"])
+ ),
  activation_quantization_mode=None,
  activation_precision=activation_precision,
  )
@@ -148,7 +154,7 @@ class HFLlambaConfig(HuggingFaceConfig):
  has_out_biases=self.ssm_cfg.bias,
  )

- decoder_layer_config = DecoderLayerConfig(
+ transformer_layer_config = TransformerLayerConfig(
  pre_mixer_norm_config=rmsnorm_config,
  mixer_config=mamba_config,
  post_mixer_norm_config=None,
@@ -156,15 +162,18 @@ class HFLlambaConfig(HuggingFaceConfig):
  mlp_config=mlp_config,
  post_mlp_norm_config=None,
  )
-
- return DecoderConfig(
- embedding_config=embedding_config,
+ transformer_config = TransformerConfig(
  global_rope_config=None,
  local_rope_config=None,
- layer_configs=(decoder_layer_config,) * self.n_layer,
+ layer_configs=(transformer_layer_config,) * self.n_layer,
  output_norm_config=rmsnorm_config,
- vocab_size=self.vocab_size,
  model_dim=self.d_model,
  hidden_dim=self.mlp_cfg.intermediate_size,
  context_length=context_length or 4096,
  )
+
+ return DecoderConfig(
+ embedding_config=embedding_config,
+ transformer_config=transformer_config,
+ vocab_size=self.vocab_size,
+ )