lalamo 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. lalamo/__init__.py +3 -2
  2. lalamo/data/__init__.py +0 -1
  3. lalamo/data/huggingface_message.py +1 -0
  4. lalamo/main.py +167 -18
  5. lalamo/message_processor.py +2 -3
  6. lalamo/model_import/common.py +120 -27
  7. lalamo/model_import/decoder_configs/__init__.py +4 -2
  8. lalamo/model_import/decoder_configs/common.py +62 -21
  9. lalamo/model_import/decoder_configs/executorch.py +14 -9
  10. lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
  11. lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
  12. lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
  13. lalamo/model_import/decoder_configs/huggingface/gemma3.py +21 -17
  14. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
  15. lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
  16. lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
  17. lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
  18. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
  21. lalamo/model_import/loaders/__init__.py +3 -2
  22. lalamo/model_import/loaders/executorch.py +24 -12
  23. lalamo/model_import/loaders/huggingface.py +258 -30
  24. lalamo/model_import/model_specs/__init__.py +4 -2
  25. lalamo/model_import/model_specs/common.py +8 -2
  26. lalamo/model_import/model_specs/gemma.py +5 -1
  27. lalamo/model_import/model_specs/huggingface.py +1 -1
  28. lalamo/model_import/model_specs/mirai.py +20 -0
  29. lalamo/models/__init__.py +10 -0
  30. lalamo/models/common.py +81 -0
  31. lalamo/{language_model.py → models/language_model.py} +32 -49
  32. lalamo/models/router.py +59 -0
  33. lalamo/modules/__init__.py +33 -16
  34. lalamo/modules/classifier.py +339 -0
  35. lalamo/modules/common.py +6 -3
  36. lalamo/modules/decoder.py +52 -180
  37. lalamo/modules/mlp.py +28 -5
  38. lalamo/modules/normalization.py +13 -8
  39. lalamo/modules/token_mixers/attention.py +10 -6
  40. lalamo/modules/token_mixers/state/kv_cache.py +14 -4
  41. lalamo/modules/transformer.py +273 -0
  42. lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
  43. lalamo/speculator/__init__.py +2 -0
  44. lalamo/speculator/estimator.py +91 -0
  45. lalamo/speculator/inference.py +28 -9
  46. lalamo/speculator/ngram.py +7 -3
  47. lalamo/speculator/utils.py +4 -2
  48. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/METADATA +1 -1
  49. lalamo-0.5.3.dist-info/RECORD +88 -0
  50. lalamo-0.5.1.dist-info/RECORD +0 -80
  51. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/WHEEL +0 -0
  52. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/entry_points.txt +0 -0
  53. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/licenses/LICENSE +0 -0
  54. {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/top_level.txt +0 -0
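Taken together, the file list describes a refactor from a decoder-only layout toward a shared transformer core with separate decoder and classifier front-ends: decoder_layer.py becomes transformer_layer.py, new modules/transformer.py and modules/classifier.py appear, language_model.py moves into a models/ package, and the HuggingFace loader is split into decoder and classifier entry points. The renamed config names that the hunks below rely on can be summarized as an import sketch (names are taken directly from the hunks; the grouping and comments are illustrative, not an exhaustive API listing):

# Config-side renames visible in the hunks below.
from lalamo.modules import (
    NormalizationConfig,     # replaces RMSNormConfig and adds a subtract_mean flag
    TransformerConfig,       # new: holds the layer stack that DecoderConfig used to inline
    TransformerLayerConfig,  # replaces DecoderLayerConfig
)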
lalamo/model_import/decoder_configs/huggingface/mistral.py
@@ -7,24 +7,25 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
-    RMSNormConfig,
+    NormalizationConfig,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
 )
 from lalamo.modules.activations import SiLU
 from lalamo.modules.normalization import UpcastMode

-from .common import HuggingFaceConfig
+from .common import HuggingFaceLMConfig

 __all__ = ["HFMistralConfig"]


 @dataclass(frozen=True)
-class HFMistralConfig(HuggingFaceConfig):
+class HFMistralConfig(HuggingFaceLMConfig):
     architectures: list[Literal["MistralForCausalLM"]]
     attention_dropout: float
     bos_token_id: int
@@ -42,7 +43,6 @@ class HFMistralConfig(HuggingFaceConfig):
     rope_theta: float
     sliding_window: int | None
     tie_word_embeddings: bool
-    torch_dtype: Literal["bfloat16", "float16", "float32"]
     transformers_version: str
     use_cache: bool
     vocab_size: int
@@ -74,12 +74,13 @@ class HFMistralConfig(HuggingFaceConfig):
             max_sequence_length=context_length or self.max_position_embeddings,
         )

-        rmsnorm_config = RMSNormConfig(
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )

         linear_config = FullPrecisionLinearConfig(
@@ -116,7 +117,7 @@ class HFMistralConfig(HuggingFaceConfig):
                 sliding_window_size=self.sliding_window,
             )

-            decoder_layer_config = DecoderLayerConfig(
+            transformer_layer_config = TransformerLayerConfig(
                 pre_mixer_norm_config=rmsnorm_config,
                 mixer_config=attention_config,
                 post_mixer_norm_config=None,
@@ -124,16 +125,20 @@ class HFMistralConfig(HuggingFaceConfig):
                 mlp_config=mlp_config,
                 post_mlp_norm_config=None,
             )
-            layer_configs.append(decoder_layer_config)
+            layer_configs.append(transformer_layer_config)

-        return DecoderConfig(
-            embedding_config=embedding_config,
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
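The RMSNormConfig → NormalizationConfig rename adds a subtract_mean flag: subtract_mean=False, as above, keeps the old RMSNorm behaviour, while subtract_mean=True (used by the ModernBERT classifier config below) centres the activations first, i.e. LayerNorm-style normalization without a bias term. A minimal sketch of that reading follows; it is illustrative only, not lalamo's Normalization module, and omits the learned scale, scale_offset, and upcast_mode handling:

import jax.numpy as jnp
from jaxtyping import Array, Float


def normalize_sketch(
    x: Float[Array, "... dim"],
    epsilon: float,
    subtract_mean: bool,
) -> Float[Array, "... dim"]:
    # subtract_mean=True: centre first, so the RMS of the centred values
    # equals the standard deviation (LayerNorm-style normalization).
    if subtract_mean:
        x = x - x.mean(axis=-1, keepdims=True)
    # subtract_mean=False: plain RMSNorm, matching the old RMSNormConfig.
    return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + epsilon)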
lalamo/model_import/decoder_configs/huggingface/modern_bert.py
@@ -0,0 +1,241 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import jax.numpy as jnp
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    Activation,
+    AttentionConfig,
+    ClassifierConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    NormalizationConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
+    UnscaledRoPEConfig,
+    UpcastMode,
+)
+from lalamo.modules.activations import GELU, SiLU
+from lalamo.modules.classifier import (
+    PoolingType,
+    PredictionHeadConfig,
+)
+from lalamo.modules.embedding import TiedEmbeddingConfig
+
+from .common import (
+    AWQQuantizationConfig,
+    GPTQQuantizationConfig,
+    HuggingFaceClassifierConfig,
+)
+
+__all__ = ["ModernBERTConfig"]
+
+
+def activation_from_str(activation: str) -> type[Activation]:
+    supported_activations = {
+        "silu": SiLU,
+        "gelu": GELU,
+    }
+    if activation in supported_activations:
+        return supported_activations[activation]
+
+    raise ValueError(
+        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}"
+    )
+
+
+@dataclass(frozen=True)
+class ModernBERTConfig(HuggingFaceClassifierConfig):
+    architectures: list[Literal["ModernBertForSequenceClassification"]]
+    attention_bias: bool
+    attention_dropout: float
+    bos_token_id: int | list[int]
+    classifier_activation: Literal["gelu"]
+    classifier_bias: bool
+    classifier_dropout: float
+    classifier_pooling: Literal["mean"]
+    cls_token_id: int | list[int]
+    decoder_bias: bool
+    deterministic_flash_attn: bool
+    embedding_dropout: float
+    eos_token_id: int | list[int]
+    global_attn_every_n_layers: int
+    global_rope_theta: float
+    gradient_checkpointing: bool
+    hidden_activation: Literal["gelu"]
+    hidden_size: int
+    initializer_cutoff_factor: float
+    initializer_range: float
+    intermediate_size: int
+    layer_norm_eps: float
+    local_attention: int
+    local_rope_theta: float
+    max_position_embeddings: int
+    mlp_bias: bool
+    mlp_dropout: float
+    model_type: Literal["modernbert"]
+    norm_bias: bool
+    norm_eps: float
+    num_attention_heads: int
+    num_hidden_layers: int
+    pad_token_id: int | list[int]
+    position_embedding_type: Literal["absolute"]
+    sep_token_id: int | list[int]
+    transformers_version: str
+    vocab_size: int
+    id2label: dict[int, str]
+    label2id: dict[str, int]
+
+    quantization_config: AWQQuantizationConfig | GPTQQuantizationConfig | None = None
+
+    def __post_init__(self) -> None:
+        if len(self.label2id) != len(self.id2label):
+            raise ValueError("Legnth of label2id and id2label is expected to be the same")
+
+    def calculate_sliding_windows(self, num_layers: int, global_attn_every_n_layers: int) -> tuple[None, ...]:
+        result = [None] * num_layers
+        for index in range(len(result)):
+            if index % global_attn_every_n_layers != 0:
+                result[index] = self.local_attention  # type: ignore
+            else:
+                pass
+        return tuple(result)
+
+    def to_classifier_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> ClassifierConfig:
+        embedding_config = TiedEmbeddingConfig(
+            input_scale=None,
+            logit_soft_cap=None,
+            precision=activation_precision,
+        )
+        embedding_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=True,
+        )
+
+        global_rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.global_rope_theta,
+            max_sequence_length=context_length or self.max_position_embeddings,
+        )
+        local_rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.local_rope_theta,
+            max_sequence_length=context_length or self.max_position_embeddings,
+        )
+
+        sliding_window_sizes = self.calculate_sliding_windows(self.num_hidden_layers, self.global_attn_every_n_layers)
+
+        transformer_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=True,
+        )
+        linear_config = FullPrecisionLinearConfig(
+            precision=activation_precision,
+        )
+        activation = activation_from_str(self.hidden_activation)
+        assert activation is SiLU or activation is GELU
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=activation(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
+        )
+
+        # In ModernBERT architecture first Transformer layer has no pre-attention normalization
+        pre_attn_configs = [transformer_norm_config if i > 0 else None for i in range(self.num_hidden_layers)]
+
+        transformer_layer_configs = []
+        for sliding_window_size, pre_attn_config in zip(sliding_window_sizes, pre_attn_configs, strict=True):
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=None,
+                has_sinks=False,
+                has_qkv_biases=self.attention_bias,
+                has_out_biases=False,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_attention_heads,
+                head_dim=self.hidden_size // self.num_attention_heads,
+                scale=None,
+                is_causal=False,
+                sliding_window_size=sliding_window_size,
+            )
+            layer_config = TransformerLayerConfig(
+                pre_mixer_norm_config=pre_attn_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=transformer_norm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            transformer_layer_configs.append(layer_config)
+
+        transformer_config = TransformerConfig(
+            global_rope_config=global_rope_config,
+            local_rope_config=local_rope_config,
+            layer_configs=tuple(transformer_layer_configs),
+            output_norm_config=transformer_norm_config,
+            model_dim=self.hidden_size,
+            hidden_dim=self.intermediate_size,
+            context_length=context_length or self.max_position_embeddings,
+        )
+
+        prediction_head_dense_config = FullPrecisionLinearConfig(
+            precision=activation_precision,
+        )
+        prediction_head_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=jnp.float32,
+            epsilon=self.norm_eps,
+            scale_offset=0.0,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=True,
+        )
+        prediction_head_activation = activation_from_str(self.classifier_activation)
+        prediction_head_readout_config = FullPrecisionLinearConfig(
+            precision=activation_precision,
+        )
+        prediction_head_config = PredictionHeadConfig(
+            dense_config=prediction_head_dense_config,
+            activation=prediction_head_activation(),
+            normalization_config=prediction_head_norm_config,
+            readout_config=prediction_head_readout_config,
+            use_dense_bias=self.classifier_bias,
+        )
+
+        output_labels = [self.id2label[idx] for idx in range(len(self.id2label))]
+
+        return ClassifierConfig(
+            embedding_config=embedding_config,
+            embedding_norm_config=embedding_norm_config,
+            transformer_config=transformer_config,
+            prediction_head_config=prediction_head_config,
+            readout_config=prediction_head_readout_config,
+            vocab_size=self.vocab_size,
+            model_dim=self.hidden_size,
+            hidden_dim=self.hidden_size,
+            attention_scale=None,
+            num_layers=self.num_hidden_layers,
+            context_length=self.max_position_embeddings,
+            num_labels=len(self.id2label),
+            classifier_pooling=PoolingType(self.classifier_pooling),
+            output_labels=tuple(output_labels),
+        )
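calculate_sliding_windows above encodes ModernBERT's alternating attention pattern: every global_attn_every_n_layers-th layer (indices 0, n, 2n, ...) keeps full global attention (sliding window None), and every other layer uses the local_attention window. A small worked example of the resulting tuple, with numbers made up purely for illustration:

# Hypothetical values chosen only to show the pattern:
# 6 layers, a global-attention layer every 3 layers, local window of 128 tokens.
num_hidden_layers = 6
global_attn_every_n_layers = 3
local_attention = 128

windows = tuple(
    None if index % global_attn_every_n_layers == 0 else local_attention
    for index in range(num_hidden_layers)
)
print(windows)  # (None, 128, 128, None, 128, 128)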
lalamo/model_import/decoder_configs/huggingface/qwen2.py
@@ -7,12 +7,13 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
-    RMSNormConfig,
+    NormalizationConfig,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
@@ -20,13 +21,13 @@ from lalamo.modules import (
 from lalamo.modules.activations import SiLU
 from lalamo.quantization import QuantizationMode

-from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
+from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceLMConfig

 __all__ = ["HFQwen2Config"]


 @dataclass(frozen=True)
-class HFQwen2Config(HuggingFaceConfig):
+class HFQwen2Config(HuggingFaceLMConfig):
     torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["Qwen2ForCausalLM"]]
     attention_dropout: float
@@ -89,12 +90,13 @@ class HFQwen2Config(HuggingFaceConfig):
             base=self.rope_theta,
             max_sequence_length=context_length or self.max_position_embeddings,
         )
-        rmsnorm_config = RMSNormConfig(
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )
         if self.quantization_config is None:
             linear_config = FullPrecisionLinearConfig(
@@ -136,7 +138,7 @@
                 scale=None,
                 sliding_window_size=sliding_window_size,
             )
-            decoder_layer_config = DecoderLayerConfig(
+            transformer_layer_config = TransformerLayerConfig(
                 pre_mixer_norm_config=rmsnorm_config,
                 mixer_config=attention_config,
                 post_mixer_norm_config=None,
@@ -144,15 +146,20 @@
                 mlp_config=mlp_config,
                 post_mlp_norm_config=None,
             )
-            layer_configs.append(decoder_layer_config)
-        return DecoderConfig(
-            embedding_config=embedding_config,
+            layer_configs.append(transformer_layer_config)
+
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
lalamo/model_import/decoder_configs/huggingface/qwen3.py
@@ -7,12 +7,13 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
-    RMSNormConfig,
+    NormalizationConfig,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
@@ -22,13 +23,13 @@ from lalamo.modules.embedding import MLXQuantizedTiedEmbeddingConfig
 from lalamo.modules.linear import MLXQuantizedLinearConfig
 from lalamo.quantization import QuantizationMode

-from .common import HuggingFaceConfig, MLXQuantizationConfig, QuantizationConfigType
+from .common import HuggingFaceLMConfig, MLXQuantizationConfig, QuantizationConfigType

 __all__ = ["HFQwen3Config"]


 @dataclass(frozen=True)
-class HFQwen3Config(HuggingFaceConfig):
+class HFQwen3Config(HuggingFaceLMConfig):
     eos_token_id: int | list[int]
     torch_dtype: Literal["bfloat16", "float16", "float32"]
     attention_bias: bool
@@ -100,12 +101,13 @@ class HFQwen3Config(HuggingFaceConfig):
             base=self.rope_theta,
             max_sequence_length=context_length or self.max_position_embeddings,
         )
-        rmsnorm_config = RMSNormConfig(
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )
         if self.quantization_config is None:
             linear_config = FullPrecisionLinearConfig(
@@ -153,7 +155,7 @@
                 scale=None,
                 sliding_window_size=sliding_window_size,
             )
-            decoder_layer_config = DecoderLayerConfig(
+            transformer_layer_config = TransformerLayerConfig(
                 pre_mixer_norm_config=rmsnorm_config,
                 mixer_config=attention_config,
                 post_mixer_norm_config=None,
@@ -161,15 +163,18 @@
                 mlp_config=mlp_config,
                 post_mlp_norm_config=None,
             )
-            layer_configs.append(decoder_layer_config)
-        return DecoderConfig(
-            embedding_config=embedding_config,
+            layer_configs.append(transformer_layer_config)
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
lalamo/model_import/loaders/__init__.py
@@ -1,7 +1,8 @@
 # from .executorch import load_executorch
-from .huggingface import load_huggingface
+from .huggingface import load_huggingface_classifier, load_huggingface_decoder

 __all__ = [
+    "load_huggingface_classifier",
     # "load_executorch",
-    "load_huggingface",
+    "load_huggingface_decoder",
 ]
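The single load_huggingface entry point is split into separate decoder and classifier loaders. Their signatures live in loaders/huggingface.py (+258 -30 in this diff) and are not shown here, so only the imports below are confirmed; the commented calls are placeholders, not documented usage:

from lalamo.model_import.loaders import (
    load_huggingface_classifier,
    load_huggingface_decoder,
)

# Hypothetical call shapes; argument names and types are assumptions.
# decoder = load_huggingface_decoder(...)
# classifier = load_huggingface_classifier(...)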
lalamo/model_import/loaders/executorch.py
@@ -6,7 +6,15 @@ from einops import rearrange
 from jaxtyping import Array, Float, Int

 from lalamo.common import ParameterPath
-from lalamo.modules import Attention, Decoder, DecoderLayer, DenseMLP, QLoRALinear, QuantizedTiedEmbedding, RMSNorm
+from lalamo.modules import (
+    Attention,
+    Decoder,
+    DenseMLP,
+    Normalization,
+    QLoRALinear,
+    QuantizedTiedEmbedding,
+    TransformerLayer,
+)

 from .common import load_parameters

@@ -95,7 +103,7 @@ def load_mlp(module: DenseMLP, weights_dict: Mapping[str, Array], path: Paramete
     )


-def load_rmsnorm(module: RMSNorm, weights_dict: Mapping[str, Array], path: ParameterPath) -> RMSNorm:
+def load_rmsnorm(module: Normalization, weights_dict: Mapping[str, Array], path: ParameterPath) -> Normalization:
     return load_parameters(lambda m: (m.scales,), module, (weights_dict[path / "weight"],))


@@ -175,18 +183,21 @@ def load_attention(
     )


-def load_decoder_layer(
-    module: DecoderLayer,
+def load_transformer_layer(
+    module: TransformerLayer,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
-) -> DecoderLayer:
+) -> TransformerLayer:
     if module.post_mixer_norm is not None:
         raise ValueError("Post attention normalization is not supported")
     if module.post_mlp_norm is not None:
         raise ValueError("Post MLP normalization is not supported")
-    attention_norm = load_rmsnorm(module.pre_mixer_norm, weights_dict, path / "attention_norm")
+    if module.pre_mixer_norm is not None:
+        attention_norm = load_rmsnorm(module.pre_mixer_norm, weights_dict, path / "attention_norm")
+    else:
+        attention_norm = None
     assert isinstance(module.mixer, Attention)
-    attention = load_attention(module.mixer, weights_dict, path / "attention")
+    attention = load_attention(module.mixer, weights_dict, path / "mixer")
     mlp_norm = load_rmsnorm(module.pre_mlp_norm, weights_dict, path / "ffn_norm")
     assert isinstance(module.mlp, DenseMLP)
     mlp = load_mlp(module.mlp, weights_dict, path / "feed_forward")
@@ -214,12 +225,13 @@ def load_executorch(module: Decoder, weights_dict: Mapping[str, Array]) -> Decod
         raise TypeError(f"Expected embedding to be QuantizedTiedEmbedding, got {type(module.embedding)}")

     embedding = load_embedding(module.embedding, weights_dict, root_path / "tok_embeddings")
-    decoder_layers = tuple(
-        load_decoder_layer(layer, weights_dict, root_path / f"layers.{i}") for i, layer in enumerate(module.layers)
+    transformer_layers = tuple(
+        load_transformer_layer(layer, weights_dict, root_path / f"layers.{i}")
+        for i, layer in enumerate(module.transformer.layers)
     )
-    output_norm = load_rmsnorm(module.output_norm, weights_dict, root_path / "norm")
+    output_norm = load_rmsnorm(module.transformer.output_norm, weights_dict, root_path / "norm")
     return load_parameters(
-        lambda m: (m.embedding, m.layers, m.output_norm),
+        lambda m: (m.embedding, m.transformer.layers, m.transformer.output_norm),
         module,
-        (embedding, decoder_layers, output_norm),
+        (embedding, transformer_layers, output_norm),
     )
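Throughout this loader, weights are installed functionally: a selector lambda picks sub-modules, and load_parameters (defined in loaders/common.py, not shown in this diff) returns a copy of the module with those positions replaced. A minimal stand-in with the same call shape, sketched here with equinox.tree_at; whether lalamo actually uses equinox is an assumption, not something this diff confirms:

from collections.abc import Callable, Sequence
from typing import Any, TypeVar

import equinox as eqx

ModuleT = TypeVar("ModuleT")


def load_parameters_sketch(
    where: Callable[[ModuleT], Sequence[Any]],
    module: ModuleT,
    new_values: Sequence[Any],
) -> ModuleT:
    # Functionally replace the sub-trees selected by `where`; e.g. the
    # load_rmsnorm call above would become:
    #   load_parameters_sketch(lambda m: (m.scales,), module, (weights_dict[path / "weight"],))
    return eqx.tree_at(where, module, replace=tuple(new_values))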