lalamo 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +15 -2
- lalamo/data/__init__.py +0 -1
- lalamo/data/huggingface_message.py +1 -0
- lalamo/main.py +167 -18
- lalamo/message_processor.py +2 -3
- lalamo/model_import/common.py +120 -27
- lalamo/model_import/decoder_configs/__init__.py +4 -2
- lalamo/model_import/decoder_configs/common.py +62 -21
- lalamo/model_import/decoder_configs/executorch.py +14 -9
- lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
- lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
- lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
- lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
- lalamo/model_import/loaders/__init__.py +3 -2
- lalamo/model_import/loaders/executorch.py +24 -12
- lalamo/model_import/loaders/huggingface.py +258 -30
- lalamo/model_import/model_specs/__init__.py +4 -2
- lalamo/model_import/model_specs/common.py +8 -2
- lalamo/model_import/model_specs/gemma.py +5 -1
- lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo/model_import/model_specs/mirai.py +20 -0
- lalamo/models/__init__.py +10 -0
- lalamo/models/common.py +81 -0
- lalamo/{language_model.py → models/language_model.py} +32 -49
- lalamo/models/router.py +59 -0
- lalamo/modules/__init__.py +33 -16
- lalamo/modules/classifier.py +339 -0
- lalamo/modules/common.py +6 -3
- lalamo/modules/decoder.py +52 -180
- lalamo/modules/mlp.py +28 -5
- lalamo/modules/normalization.py +13 -8
- lalamo/modules/token_mixers/attention.py +10 -6
- lalamo/modules/token_mixers/state/kv_cache.py +14 -4
- lalamo/modules/transformer.py +273 -0
- lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
- lalamo/speculator/__init__.py +6 -2
- lalamo/speculator/estimator.py +91 -0
- lalamo/speculator/inference.py +28 -9
- lalamo/speculator/ngram.py +7 -3
- lalamo/speculator/utils.py +4 -2
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/METADATA +1 -1
- lalamo-0.5.4.dist-info/RECORD +88 -0
- lalamo-0.5.2.dist-info/RECORD +0 -80
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/WHEEL +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/top_level.txt +0 -0
lalamo/model_import/decoder_configs/huggingface/mistral.py:

@@ -7,24 +7,25 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
-
+    NormalizationConfig,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
 )
 from lalamo.modules.activations import SiLU
 from lalamo.modules.normalization import UpcastMode

-from .common import
+from .common import HuggingFaceLMConfig

 __all__ = ["HFMistralConfig"]


 @dataclass(frozen=True)
-class HFMistralConfig(
+class HFMistralConfig(HuggingFaceLMConfig):
     architectures: list[Literal["MistralForCausalLM"]]
     attention_dropout: float
     bos_token_id: int
@@ -42,7 +43,6 @@ class HFMistralConfig(HuggingFaceConfig):
     rope_theta: float
     sliding_window: int | None
     tie_word_embeddings: bool
-    torch_dtype: Literal["bfloat16", "float16", "float32"]
     transformers_version: str
     use_cache: bool
     vocab_size: int
@@ -74,12 +74,13 @@ class HFMistralConfig(HuggingFaceConfig):
             max_sequence_length=context_length or self.max_position_embeddings,
         )

-        rmsnorm_config =
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )

         linear_config = FullPrecisionLinearConfig(
@@ -116,7 +117,7 @@ class HFMistralConfig(HuggingFaceConfig):
                 sliding_window_size=self.sliding_window,
             )

-
+            transformer_layer_config = TransformerLayerConfig(
                 pre_mixer_norm_config=rmsnorm_config,
                 mixer_config=attention_config,
                 post_mixer_norm_config=None,
@@ -124,16 +125,20 @@ class HFMistralConfig(HuggingFaceConfig):
                 mlp_config=mlp_config,
                 post_mlp_norm_config=None,
             )
-            layer_configs.append(
+            layer_configs.append(transformer_layer_config)

-
-            embedding_config=embedding_config,
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
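The recurring change in the Mistral, Qwen2, and Qwen3 config diffs is structural: `DecoderConfig` no longer receives the RoPE, per-layer, output-norm, and dimension fields directly; they move into a nested `TransformerConfig` built from `TransformerLayerConfig` entries, and `DecoderConfig` keeps only the embedding, the transformer, and the vocabulary size. A minimal sketch of that nesting with simplified stand-in dataclasses and illustrative Mistral-7B-style numbers (not the real lalamo class definitions):

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class TransformerConfigSketch:
    # In 0.5.4 these fields live on TransformerConfig instead of DecoderConfig.
    layer_configs: tuple[Any, ...]
    model_dim: int
    hidden_dim: int
    context_length: int


@dataclass(frozen=True)
class DecoderConfigSketch:
    embedding_config: Any
    transformer_config: TransformerConfigSketch  # new indirection introduced by this release
    vocab_size: int


transformer = TransformerConfigSketch(layer_configs=(), model_dim=4096, hidden_dim=14336, context_length=32768)
decoder = DecoderConfigSketch(embedding_config=None, transformer_config=transformer, vocab_size=32000)
assert decoder.transformer_config.context_length == 32768
```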
lalamo/model_import/decoder_configs/huggingface/modern_bert.py (new file):

@@ -0,0 +1,241 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import jax.numpy as jnp
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    Activation,
+    AttentionConfig,
+    ClassifierConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    NormalizationConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
+    UnscaledRoPEConfig,
+    UpcastMode,
+)
+from lalamo.modules.activations import GELU, SiLU
+from lalamo.modules.classifier import (
+    PoolingType,
+    PredictionHeadConfig,
+)
+from lalamo.modules.embedding import TiedEmbeddingConfig
+
+from .common import (
+    AWQQuantizationConfig,
+    GPTQQuantizationConfig,
+    HuggingFaceClassifierConfig,
+)
+
+__all__ = ["ModernBERTConfig"]
+
+
+def activation_from_str(activation: str) -> type[Activation]:
+    supported_activations = {
+        "silu": SiLU,
+        "gelu": GELU,
+    }
+    if activation in supported_activations:
+        return supported_activations[activation]
+
+    raise ValueError(
+        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}"
+    )
+
+
+@dataclass(frozen=True)
+class ModernBERTConfig(HuggingFaceClassifierConfig):
+    architectures: list[Literal["ModernBertForSequenceClassification"]]
+    attention_bias: bool
+    attention_dropout: float
+    bos_token_id: int | list[int]
+    classifier_activation: Literal["gelu"]
+    classifier_bias: bool
+    classifier_dropout: float
+    classifier_pooling: Literal["mean"]
+    cls_token_id: int | list[int]
+    decoder_bias: bool
+    deterministic_flash_attn: bool
+    embedding_dropout: float
+    eos_token_id: int | list[int]
+    global_attn_every_n_layers: int
+    global_rope_theta: float
+    gradient_checkpointing: bool
+    hidden_activation: Literal["gelu"]
+    hidden_size: int
+    initializer_cutoff_factor: float
+    initializer_range: float
+    intermediate_size: int
+    layer_norm_eps: float
+    local_attention: int
+    local_rope_theta: float
+    max_position_embeddings: int
+    mlp_bias: bool
+    mlp_dropout: float
+    model_type: Literal["modernbert"]
+    norm_bias: bool
+    norm_eps: float
+    num_attention_heads: int
+    num_hidden_layers: int
+    pad_token_id: int | list[int]
+    position_embedding_type: Literal["absolute"]
+    sep_token_id: int | list[int]
+    transformers_version: str
+    vocab_size: int
+    id2label: dict[int, str]
+    label2id: dict[str, int]
+
+    quantization_config: AWQQuantizationConfig | GPTQQuantizationConfig | None = None
+
+    def __post_init__(self) -> None:
+        if len(self.label2id) != len(self.id2label):
+            raise ValueError("Legnth of label2id and id2label is expected to be the same")
+
+    def calculate_sliding_windows(self, num_layers: int, global_attn_every_n_layers: int) -> tuple[None, ...]:
+        result = [None] * num_layers
+        for index in range(len(result)):
+            if index % global_attn_every_n_layers != 0:
+                result[index] = self.local_attention  # type: ignore
+            else:
+                pass
+        return tuple(result)
+
+    def to_classifier_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> ClassifierConfig:
+        embedding_config = TiedEmbeddingConfig(
+            input_scale=None,
+            logit_soft_cap=None,
+            precision=activation_precision,
+        )
+        embedding_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=True,
+        )
+
+        global_rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.global_rope_theta,
+            max_sequence_length=context_length or self.max_position_embeddings,
+        )
+        local_rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.local_rope_theta,
+            max_sequence_length=context_length or self.max_position_embeddings,
+        )
+
+        sliding_window_sizes = self.calculate_sliding_windows(self.num_hidden_layers, self.global_attn_every_n_layers)
+
+        transformer_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=True,
+        )
+        linear_config = FullPrecisionLinearConfig(
+            precision=activation_precision,
+        )
+        activation = activation_from_str(self.hidden_activation)
+        assert activation is SiLU or activation is GELU
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=activation(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
+        )
+
+        # In ModernBERT architecture first Transformer layer has no pre-attention normalization
+        pre_attn_configs = [transformer_norm_config if i > 0 else None for i in range(self.num_hidden_layers)]
+
+        transformer_layer_configs = []
+        for sliding_window_size, pre_attn_config in zip(sliding_window_sizes, pre_attn_configs, strict=True):
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=None,
+                has_sinks=False,
+                has_qkv_biases=self.attention_bias,
+                has_out_biases=False,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_attention_heads,
+                head_dim=self.hidden_size // self.num_attention_heads,
+                scale=None,
+                is_causal=False,
+                sliding_window_size=sliding_window_size,
+            )
+            layer_config = TransformerLayerConfig(
+                pre_mixer_norm_config=pre_attn_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=transformer_norm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            transformer_layer_configs.append(layer_config)
+
+        transformer_config = TransformerConfig(
+            global_rope_config=global_rope_config,
+            local_rope_config=local_rope_config,
+            layer_configs=tuple(transformer_layer_configs),
+            output_norm_config=transformer_norm_config,
+            model_dim=self.hidden_size,
+            hidden_dim=self.intermediate_size,
+            context_length=context_length or self.max_position_embeddings,
+        )
+
+        prediction_head_dense_config = FullPrecisionLinearConfig(
+            precision=activation_precision,
+        )
+        prediction_head_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=jnp.float32,
+            epsilon=self.norm_eps,
+            scale_offset=0.0,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=True,
+        )
+        prediction_head_activation = activation_from_str(self.classifier_activation)
+        prediction_head_readout_config = FullPrecisionLinearConfig(
+            precision=activation_precision,
+        )
+        prediction_head_config = PredictionHeadConfig(
+            dense_config=prediction_head_dense_config,
+            activation=prediction_head_activation(),
+            normalization_config=prediction_head_norm_config,
+            readout_config=prediction_head_readout_config,
+            use_dense_bias=self.classifier_bias,
+        )
+
+        output_labels = [self.id2label[idx] for idx in range(len(self.id2label))]
+
+        return ClassifierConfig(
+            embedding_config=embedding_config,
+            embedding_norm_config=embedding_norm_config,
+            transformer_config=transformer_config,
+            prediction_head_config=prediction_head_config,
+            readout_config=prediction_head_readout_config,
+            vocab_size=self.vocab_size,
+            model_dim=self.hidden_size,
+            hidden_dim=self.hidden_size,
+            attention_scale=None,
+            num_layers=self.num_hidden_layers,
+            context_length=self.max_position_embeddings,
+            num_labels=len(self.id2label),
+            classifier_pooling=PoolingType(self.classifier_pooling),
+            output_labels=tuple(output_labels),
+        )
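`ModernBERTConfig.calculate_sliding_windows` above encodes ModernBERT's attention interleaving: every `global_attn_every_n_layers`-th layer (index 0, n, 2n, ...) keeps full global attention, expressed as a `None` window, while all other layers attend within the `local_attention` window. The same rule as a standalone sketch; the layer count and window size below are illustrative, not taken from the diff:

```python
def sliding_windows(num_layers: int, global_attn_every_n_layers: int, local_attention: int) -> tuple[int | None, ...]:
    # None marks a global-attention layer, an int marks a local sliding-window layer.
    return tuple(
        None if index % global_attn_every_n_layers == 0 else local_attention
        for index in range(num_layers)
    )


print(sliding_windows(num_layers=6, global_attn_every_n_layers=3, local_attention=128))
# (None, 128, 128, None, 128, 128)
```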
lalamo/model_import/decoder_configs/huggingface/qwen2.py:

@@ -7,12 +7,13 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
-
+    NormalizationConfig,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
@@ -20,13 +21,13 @@ from lalamo.modules import (
 from lalamo.modules.activations import SiLU
 from lalamo.quantization import QuantizationMode

-from .common import AWQQuantizationConfig, GPTQQuantizationConfig,
+from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceLMConfig

 __all__ = ["HFQwen2Config"]


 @dataclass(frozen=True)
-class HFQwen2Config(
+class HFQwen2Config(HuggingFaceLMConfig):
     torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["Qwen2ForCausalLM"]]
     attention_dropout: float
@@ -89,12 +90,13 @@ class HFQwen2Config(HuggingFaceConfig):
             base=self.rope_theta,
             max_sequence_length=context_length or self.max_position_embeddings,
         )
-        rmsnorm_config =
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )
         if self.quantization_config is None:
             linear_config = FullPrecisionLinearConfig(
@@ -136,7 +138,7 @@ class HFQwen2Config(HuggingFaceConfig):
                 scale=None,
                 sliding_window_size=sliding_window_size,
             )
-
+            transformer_layer_config = TransformerLayerConfig(
                 pre_mixer_norm_config=rmsnorm_config,
                 mixer_config=attention_config,
                 post_mixer_norm_config=None,
@@ -144,15 +146,20 @@ class HFQwen2Config(HuggingFaceConfig):
                 mlp_config=mlp_config,
                 post_mlp_norm_config=None,
             )
-            layer_configs.append(
-
-
+            layer_configs.append(transformer_layer_config)
+
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
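The `subtract_mean` flag now passed to every `NormalizationConfig` reads as the switch between RMSNorm-style normalization (`subtract_mean=False`, as the Mistral/Qwen decoder configs request) and LayerNorm-style centering (`subtract_mean=True`, as the ModernBERT config requests). A minimal reference for what the flag means; the learned scale, `scale_offset`, and `upcast_mode` handling of the real lalamo `Normalization` module are omitted:

```python
import jax.numpy as jnp


def normalize(x: jnp.ndarray, epsilon: float, subtract_mean: bool) -> jnp.ndarray:
    if subtract_mean:
        # LayerNorm-style: center the activations before rescaling.
        x = x - jnp.mean(x, axis=-1, keepdims=True)
    # Root-mean-square rescaling, shared by both branches.
    return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + epsilon)


x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
print(normalize(x, epsilon=1e-6, subtract_mean=False))  # RMSNorm-like behaviour
print(normalize(x, epsilon=1e-6, subtract_mean=True))   # LayerNorm-like behaviour
```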
lalamo/model_import/decoder_configs/huggingface/qwen3.py:

@@ -7,12 +7,13 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
-
+    NormalizationConfig,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
@@ -22,13 +23,13 @@ from lalamo.modules.embedding import MLXQuantizedTiedEmbeddingConfig
 from lalamo.modules.linear import MLXQuantizedLinearConfig
 from lalamo.quantization import QuantizationMode

-from .common import
+from .common import HuggingFaceLMConfig, MLXQuantizationConfig, QuantizationConfigType

 __all__ = ["HFQwen3Config"]


 @dataclass(frozen=True)
-class HFQwen3Config(
+class HFQwen3Config(HuggingFaceLMConfig):
     eos_token_id: int | list[int]
     torch_dtype: Literal["bfloat16", "float16", "float32"]
     attention_bias: bool
@@ -100,12 +101,13 @@ class HFQwen3Config(HuggingFaceConfig):
             base=self.rope_theta,
             max_sequence_length=context_length or self.max_position_embeddings,
         )
-        rmsnorm_config =
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )
         if self.quantization_config is None:
             linear_config = FullPrecisionLinearConfig(
@@ -153,7 +155,7 @@ class HFQwen3Config(HuggingFaceConfig):
                 scale=None,
                 sliding_window_size=sliding_window_size,
             )
-
+            transformer_layer_config = TransformerLayerConfig(
                 pre_mixer_norm_config=rmsnorm_config,
                 mixer_config=attention_config,
                 post_mixer_norm_config=None,
@@ -161,15 +163,18 @@ class HFQwen3Config(HuggingFaceConfig):
                 mlp_config=mlp_config,
                 post_mlp_norm_config=None,
             )
-            layer_configs.append(
-
-            embedding_config=embedding_config,
+            layer_configs.append(transformer_layer_config)
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
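Like the other HF-side classes in this package, `HFQwen3Config` is a frozen dataclass whose fields mirror the keys of the model's Hugging Face `config.json` one-to-one, which is what makes the `to_decoder_config` translation above purely declarative. A simplified stand-in with a handful of the fields used in the diff; the values are illustrative (roughly Qwen3-0.6B-like), not read from the diff:

```python
import json
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class Qwen3ConfigSketch:
    torch_dtype: Literal["bfloat16", "float16", "float32"]
    hidden_size: int
    intermediate_size: int
    rms_norm_eps: float
    rope_theta: float
    max_position_embeddings: int
    vocab_size: int


raw = json.loads("""
{
  "torch_dtype": "bfloat16", "hidden_size": 1024, "intermediate_size": 3072,
  "rms_norm_eps": 1e-06, "rope_theta": 1000000.0,
  "max_position_embeddings": 40960, "vocab_size": 151936
}
""")
config = Qwen3ConfigSketch(**raw)
print(config.rope_theta)  # 1000000.0
```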
lalamo/model_import/loaders/__init__.py:

@@ -1,7 +1,8 @@
 # from .executorch import load_executorch
-from .huggingface import
+from .huggingface import load_huggingface_classifier, load_huggingface_decoder

 __all__ = [
+    "load_huggingface_classifier",
     # "load_executorch",
-    "
+    "load_huggingface_decoder",
 ]
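With this change the loaders package publicly exposes a classifier loader next to the existing decoder loader. Import-only sketch; the call signatures live in `loaders/huggingface.py` and are not shown in this diff, so none are assumed here:

```python
from lalamo.model_import.loaders import (
    load_huggingface_classifier,
    load_huggingface_decoder,
)

# Both names are re-exported from the huggingface submodule.
print(load_huggingface_decoder.__module__)  # lalamo.model_import.loaders.huggingface
```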
lalamo/model_import/loaders/executorch.py:

@@ -6,7 +6,15 @@ from einops import rearrange
 from jaxtyping import Array, Float, Int

 from lalamo.common import ParameterPath
-from lalamo.modules import
+from lalamo.modules import (
+    Attention,
+    Decoder,
+    DenseMLP,
+    Normalization,
+    QLoRALinear,
+    QuantizedTiedEmbedding,
+    TransformerLayer,
+)

 from .common import load_parameters

@@ -95,7 +103,7 @@ def load_mlp(module: DenseMLP, weights_dict: Mapping[str, Array], path: Paramete
     )


-def load_rmsnorm(module:
+def load_rmsnorm(module: Normalization, weights_dict: Mapping[str, Array], path: ParameterPath) -> Normalization:
     return load_parameters(lambda m: (m.scales,), module, (weights_dict[path / "weight"],))


@@ -175,18 +183,21 @@ def load_attention(
     )


-def
-    module:
+def load_transformer_layer(
+    module: TransformerLayer,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
-) ->
+) -> TransformerLayer:
     if module.post_mixer_norm is not None:
         raise ValueError("Post attention normalization is not supported")
     if module.post_mlp_norm is not None:
         raise ValueError("Post MLP normalization is not supported")
-
+    if module.pre_mixer_norm is not None:
+        attention_norm = load_rmsnorm(module.pre_mixer_norm, weights_dict, path / "attention_norm")
+    else:
+        attention_norm = None
     assert isinstance(module.mixer, Attention)
-    attention = load_attention(module.mixer, weights_dict, path / "
+    attention = load_attention(module.mixer, weights_dict, path / "mixer")
     mlp_norm = load_rmsnorm(module.pre_mlp_norm, weights_dict, path / "ffn_norm")
     assert isinstance(module.mlp, DenseMLP)
     mlp = load_mlp(module.mlp, weights_dict, path / "feed_forward")
@@ -214,12 +225,13 @@ def load_executorch(module: Decoder, weights_dict: Mapping[str, Array]) -> Decod
         raise TypeError(f"Expected embedding to be QuantizedTiedEmbedding, got {type(module.embedding)}")

     embedding = load_embedding(module.embedding, weights_dict, root_path / "tok_embeddings")
-
-
+    transformer_layers = tuple(
+        load_transformer_layer(layer, weights_dict, root_path / f"layers.{i}")
+        for i, layer in enumerate(module.transformer.layers)
     )
-    output_norm = load_rmsnorm(module.output_norm, weights_dict, root_path / "norm")
+    output_norm = load_rmsnorm(module.transformer.output_norm, weights_dict, root_path / "norm")
     return load_parameters(
-        lambda m: (m.embedding, m.layers, m.output_norm),
+        lambda m: (m.embedding, m.transformer.layers, m.transformer.output_norm),
         module,
-        (embedding,
+        (embedding, transformer_layers, output_norm),
     )
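The executorch loader addresses tensors in `weights_dict` through `ParameterPath` objects composed with `/`, for example `root_path / f"layers.{i}" / "attention_norm"`, against a flat checkpoint dictionary. A hypothetical stand-in showing how such composed paths can line up with dot-separated checkpoint keys; the real `ParameterPath` may join, hash, and index differently:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DotPath:
    # Hypothetical stand-in for lalamo.common.ParameterPath, assuming dot-joined keys.
    key: str = ""

    def __truediv__(self, part: object) -> "DotPath":
        return DotPath(f"{self.key}.{part}" if self.key else str(part))

    def __str__(self) -> str:
        return self.key


weights_dict = {"layers.0.ffn_norm.weight": [1.0, 1.0]}
root_path = DotPath()
mlp_norm_key = root_path / "layers.0" / "ffn_norm" / "weight"
print(weights_dict[str(mlp_norm_key)])  # [1.0, 1.0]
```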