lalamo 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/language_model.py +22 -23
- lalamo/main.py +2 -16
- lalamo/model_import/common.py +24 -6
- lalamo/model_import/decoder_configs/__init__.py +2 -0
- lalamo/model_import/decoder_configs/common.py +4 -4
- lalamo/model_import/decoder_configs/executorch.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
- lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
- lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
- lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
- lalamo/model_import/loaders/executorch.py +5 -4
- lalamo/model_import/loaders/huggingface.py +321 -69
- lalamo/model_import/model_specs/__init__.py +2 -0
- lalamo/model_import/model_specs/common.py +16 -5
- lalamo/model_import/model_specs/llamba.py +40 -0
- lalamo/model_import/model_specs/qwen.py +29 -1
- lalamo/modules/__init__.py +33 -6
- lalamo/modules/activations.py +9 -2
- lalamo/modules/common.py +10 -5
- lalamo/modules/decoder.py +93 -97
- lalamo/modules/decoder_layer.py +85 -103
- lalamo/modules/embedding.py +279 -5
- lalamo/modules/linear.py +335 -30
- lalamo/modules/mlp.py +6 -7
- lalamo/modules/mlx_interop.py +19 -0
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +30 -0
- lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
- lalamo/modules/token_mixers/common.py +78 -0
- lalamo/modules/token_mixers/mamba.py +553 -0
- lalamo/modules/token_mixers/state/__init__.py +12 -0
- lalamo/modules/token_mixers/state/common.py +26 -0
- lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
- lalamo/modules/token_mixers/state/mamba_state.py +51 -0
- lalamo/utils.py +24 -2
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
- lalamo-0.5.0.dist-info/RECORD +80 -0
- lalamo-0.4.1.dist-info/RECORD +0 -71
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/decoder_configs/huggingface/gpt_oss.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -75,6 +76,7 @@ class HFGPTOssConfig(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         # Embedding
         if self.tie_word_embeddings:
@@ -124,17 +126,6 @@ class HFGPTOssConfig(HuggingFaceConfig):
         # Linear layers
         linear_config = FullPrecisionLinearConfig(precision=activation_precision)
 
-        attention_config = AttentionConfig(
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=None,
-            key_norm_config=None,
-            logit_soft_cap=None,
-            has_sinks=True,
-            has_qkv_biases=self.attention_bias,
-            has_out_biases=self.attention_bias,
-        )
-
         # Experts (MoE) scaffold
         # Router: linear with bias; Experts: DenseMLP with SiLU(alpha=1.702) and value/gate clipping
         experts_activation = SiLU(alpha=1.702)
@@ -154,42 +145,58 @@ class HFGPTOssConfig(HuggingFaceConfig):
             router_has_biases=True,
             expert_config=experts_config,
         )
-        decoder_layer_config = DecoderLayerConfig(
-            pre_attention_norm_config=rmsnorm_config,
-            attention_config=attention_config,
-            post_attention_norm_config=None,
-            pre_mlp_norm_config=rmsnorm_config,
-            mlp_config=moe_config,
-            post_mlp_norm_config=None,
-        )
 
         # Per-layer sliding-window
         if self.layer_types is not None and len(self.layer_types) == self.num_hidden_layers:
-            sliding_window_sizes =
+            sliding_window_sizes = [
                 self.sliding_window if layer_type == "sliding_attention" else None for layer_type in self.layer_types
-
+            ]
         else:
             # Fallback: apply the same sliding window to all layers if provided
             sliding_window_sizes = (
-
+                [self.sliding_window] * self.num_hidden_layers
+                if self.sliding_window is not None
+                else [None] * self.num_hidden_layers
             )
 
         head_dim = self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads
 
+        layer_configs = []
+        for sliding_window_size in sliding_window_sizes:
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=None,
+                has_sinks=True,
+                has_qkv_biases=self.attention_bias,
+                has_out_biases=self.attention_bias,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=head_dim,
+                is_causal=True,
+                scale=None,
+                sliding_window_size=sliding_window_size,
+            )
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=moe_config,
+                post_mlp_norm_config=None,
+            )
+            layer_configs.append(decoder_layer_config)
+
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=head_dim,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=sliding_window_sizes,
             context_length=context_length or self.max_position_embeddings,
         )
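The hunks above show the core 0.5.0 refactor that repeats across every HuggingFace adapter in this release: the decoder-level `num_heads`/`head_dim`/`num_layers`/`sliding_window_sizes` arguments are gone, and `DecoderConfig` now takes a `layer_configs` tuple whose entries each carry their own mixer config. Below is a minimal sketch of that pattern in isolation; the dimensions and the pre-built `linear_config`, `rmsnorm_config`, and `mlp_config` arguments are placeholders, and only keyword arguments visible in this diff are used.

```python
from lalamo.modules import AttentionConfig, DecoderLayerConfig


def build_layer_configs(sliding_window_sizes, linear_config, rmsnorm_config, mlp_config):
    """Sketch of the per-layer pattern used above; head geometry values are illustrative."""
    layer_configs = []
    for sliding_window_size in sliding_window_sizes:
        # Head geometry and windowing now live on the per-layer mixer config.
        attention_config = AttentionConfig(
            qkv_projection_config=linear_config,
            out_projection_config=linear_config,
            query_norm_config=None,
            key_norm_config=None,
            logit_soft_cap=None,
            has_sinks=False,
            has_qkv_biases=False,
            has_out_biases=False,
            num_heads=32,      # placeholder
            num_groups=8,      # placeholder
            head_dim=128,      # placeholder
            is_causal=True,
            scale=None,
            sliding_window_size=sliding_window_size,
        )
        layer_configs.append(
            DecoderLayerConfig(
                pre_mixer_norm_config=rmsnorm_config,
                mixer_config=attention_config,
                post_mixer_norm_config=None,
                pre_mlp_norm_config=rmsnorm_config,
                mlp_config=mlp_config,
                post_mlp_norm_config=None,
            )
        )
    return tuple(layer_configs)
```

Homogeneous models (llama.py below) simply repeat a single `DecoderLayerConfig`, while gpt_oss.py and the Qwen adapters build one entry per sliding-window value.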
lalamo/model_import/decoder_configs/huggingface/llama.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -12,13 +13,13 @@ from lalamo.modules import (
     GroupQuantizedLinearConfig,
     LlamaRoPEConfig,
     RMSNormConfig,
+    SiLU,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
+    UntiedEmbeddingConfig,
     UpcastMode,
     YARNRoPEConfig,
 )
-from lalamo.modules.activations import SiLU
-from lalamo.modules.embedding import UntiedEmbeddingConfig
 from lalamo.quantization import QuantizationMode
 
 from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
@@ -80,6 +81,7 @@ class HFLlamaConfig(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
@@ -149,6 +151,12 @@ class HFLlamaConfig(HuggingFaceConfig):
             has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=False,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads,
+            is_causal=True,
+            scale=None,
+            sliding_window_size=None,
         )
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
@@ -159,9 +167,9 @@ class HFLlamaConfig(HuggingFaceConfig):
             gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
-
-
-
+            pre_mixer_norm_config=rmsnorm_config,
+            mixer_config=attention_config,
+            post_mixer_norm_config=None,
             pre_mlp_norm_config=rmsnorm_config,
             mlp_config=mlp_config,
             post_mlp_norm_config=None,
@@ -170,16 +178,10 @@ class HFLlamaConfig(HuggingFaceConfig):
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-
+            layer_configs=(decoder_layer_config,) * self.num_hidden_layers,
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=None,
             context_length=context_length or self.max_position_embeddings,
         )
lalamo/model_import/decoder_configs/huggingface/llamba.py (new file)

@@ -0,0 +1,170 @@
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    DecoderConfig,
+    DecoderLayerConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    Identity,
+    Mamba2Config,
+    MLXQuantizedLinearConfig,
+    MLXSemiQuantizedUntiedEmbeddingConfig,
+    RMSNormConfig,
+    SeparableCausalConvConfig,
+    SiLU,
+    TiedEmbeddingConfig,
+    UntiedEmbeddingConfig,
+    UpcastMode,
+)
+from lalamo.quantization import QuantizationMode
+
+from .common import HuggingFaceConfig
+
+
+@dataclass(frozen=True)
+class HFLlambaMlpConfig:
+    intermediate_size: int
+    bias: bool
+    act_fn: Literal["silu"]
+
+
+@dataclass(frozen=True)
+class HFLlambaSsmConfig:
+    d_state: int
+    n_v_heads: int
+    n_qk_heads: int
+    expand: int
+    activation: Literal["identity"]
+    bias: bool
+    conv_bias: bool = True
+    d_conv: int = 4
+
+
+@dataclass(frozen=True)
+class HFLlambaConfig(HuggingFaceConfig):
+    model_type: Literal["llamba"]
+    vocab_size: int
+    tie_embeddings: bool
+    pad_vocab_size_multiple: int
+    lm_head_bias: bool
+    d_model: int
+    n_layer: int
+    resid_dropout: float
+    norm_epsilon: float
+    mlp_cfg: HFLlambaMlpConfig
+    ssm_cfg: HFLlambaSsmConfig
+
+    @property
+    def eos_token_ids(self) -> list[int]:
+        return [128001, 128008, 128009]
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
+    ) -> DecoderConfig:
+        if "quantization_kwargs.group_size" in metadata_dict:
+            embedding_config = MLXSemiQuantizedUntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                group_size=int(metadata_dict["quantization_kwargs.group_size"]),
+                embedding_quantization_mode=QuantizationMode.from_num_bits(int(metadata_dict["quantization_kwargs.bits"])),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        elif self.tie_embeddings:
+            embedding_config = TiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+        else:
+            embedding_config = UntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+
+        rmsnorm_config = RMSNormConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_epsilon,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+        )
+
+        if "quantization_kwargs.group_size" in metadata_dict:
+            linear_config = MLXQuantizedLinearConfig(
+                group_size=int(metadata_dict["quantization_kwargs.group_size"]),
+                weight_quantization_mode=QuantizationMode.from_num_bits(int(metadata_dict["quantization_kwargs.bits"])),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        else:
+            linear_config = FullPrecisionLinearConfig(
+                precision=activation_precision,
+            )
+
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=SiLU(),
+            has_up_biases=self.mlp_cfg.bias,
+            has_down_biases=self.mlp_cfg.bias,
+            up_clipping=None,
+            gate_clipping=None,
+        )
+
+        inner_dim = self.ssm_cfg.expand * self.d_model
+        head_dim = inner_dim // self.ssm_cfg.n_v_heads
+
+        if self.ssm_cfg.activation == "identity":
+            activation = Identity()
+        elif self.ssm_cfg.activation == "silu":
+            activation = SiLU()
+        else:
+            activation = SiLU()  # fallback
+
+        mamba_config = Mamba2Config(
+            in_projection_config=linear_config,
+            out_projection_config=linear_config,
+            conv_config=SeparableCausalConvConfig(
+                precision=activation_precision,
+                has_biases=self.ssm_cfg.conv_bias,
+            ),
+            activation=activation,
+            kernel_size=self.ssm_cfg.d_conv,
+            num_heads=self.ssm_cfg.n_v_heads,
+            num_groups=self.ssm_cfg.n_qk_heads,
+            head_dim=head_dim,
+            state_dim=self.ssm_cfg.d_state,
+            expansion_factor=self.ssm_cfg.expand,
+            has_in_biases=self.ssm_cfg.bias,
+            has_out_biases=self.ssm_cfg.bias,
+        )
+
+        decoder_layer_config = DecoderLayerConfig(
+            pre_mixer_norm_config=rmsnorm_config,
+            mixer_config=mamba_config,
+            post_mixer_norm_config=None,
+            pre_mlp_norm_config=rmsnorm_config,
+            mlp_config=mlp_config,
+            post_mlp_norm_config=None,
+        )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            global_rope_config=None,
+            local_rope_config=None,
+            layer_configs=(decoder_layer_config,) * self.n_layer,
+            output_norm_config=rmsnorm_config,
+            vocab_size=self.vocab_size,
+            model_dim=self.d_model,
+            hidden_dim=self.mlp_cfg.intermediate_size,
+            context_length=context_length or 4096,
+        )
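llamba.py is the first adapter to put a non-attention token mixer into the same `DecoderLayerConfig` slot: `mixer_config` receives a `Mamba2Config` built from the checkpoint's `ssm_cfg`. Below is a minimal sketch of a single such layer; the dimensions are placeholders (roughly a d_model=2048, expand=2 checkpoint), the pre-built sub-configs are passed in as parameters, and only keyword arguments visible in the new file above are used.

```python
from lalamo.modules import DecoderLayerConfig, Identity, Mamba2Config, SeparableCausalConvConfig


def build_mamba_layer_config(linear_config, rmsnorm_config, mlp_config, activation_precision):
    """Sketch of a Mamba2-mixer layer mirroring the new llamba.py adapter; values are illustrative."""
    mamba_config = Mamba2Config(
        in_projection_config=linear_config,
        out_projection_config=linear_config,
        conv_config=SeparableCausalConvConfig(
            precision=activation_precision,
            has_biases=True,        # ssm_cfg.conv_bias
        ),
        activation=Identity(),      # ssm_cfg.activation == "identity"
        kernel_size=4,              # ssm_cfg.d_conv
        num_heads=32,               # ssm_cfg.n_v_heads (placeholder)
        num_groups=32,              # ssm_cfg.n_qk_heads (placeholder)
        head_dim=128,               # (expand * d_model) // n_v_heads
        state_dim=64,               # ssm_cfg.d_state (placeholder)
        expansion_factor=2,         # ssm_cfg.expand
        has_in_biases=False,
        has_out_biases=False,
    )
    # The same DecoderLayerConfig slot that holds an AttentionConfig in the other adapters.
    return DecoderLayerConfig(
        pre_mixer_norm_config=rmsnorm_config,
        mixer_config=mamba_config,
        post_mixer_norm_config=None,
        pre_mlp_norm_config=rmsnorm_config,
        mlp_config=mlp_config,
        post_mlp_norm_config=None,
    )
```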
lalamo/model_import/decoder_configs/huggingface/mistral.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -24,7 +25,6 @@ __all__ = ["HFMistralConfig"]
 
 @dataclass(frozen=True)
 class HFMistralConfig(HuggingFaceConfig):
-    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["MistralForCausalLM"]]
     attention_dropout: float
     bos_token_id: int
@@ -53,8 +53,8 @@ class HFMistralConfig(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
-        # Choose embedding config based on tie_word_embeddings flag
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
@@ -86,16 +86,7 @@ class HFMistralConfig(HuggingFaceConfig):
             precision=activation_precision,
         )
 
-
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=None,
-            key_norm_config=None,
-            logit_soft_cap=None,
-            has_sinks=False,
-            has_qkv_biases=False,
-            has_out_biases=False,
-        )
+        head_dim = self.head_dim or self.hidden_size // self.num_attention_heads
 
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
@@ -106,33 +97,43 @@ class HFMistralConfig(HuggingFaceConfig):
             gate_clipping=None,
         )
 
-
-
-            attention_config=
-
-
-
-
-
+        layer_configs = []
+        for _ in range(self.num_hidden_layers):
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=None,
+                has_sinks=False,
+                has_qkv_biases=False,
+                has_out_biases=False,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=head_dim,
+                is_causal=True,
+                scale=None,
+                sliding_window_size=self.sliding_window,
+            )
 
-
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            layer_configs.append(decoder_layer_config)
 
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=head_dim,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=tuple([self.sliding_window] * self.num_hidden_layers)
-            if self.sliding_window is not None
-            else None,
             context_length=context_length or self.max_position_embeddings,
         )
lalamo/model_import/decoder_configs/huggingface/qwen2.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -69,6 +70,7 @@ class HFQwen2Config(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
@@ -105,16 +107,7 @@ class HFQwen2Config(HuggingFaceConfig):
                 activation_quantization_mode=None,
                 activation_precision=activation_precision,
             )
-
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=None,
-            key_norm_config=None,
-            logit_soft_cap=None,
-            has_sinks=False,
-            has_qkv_biases=True,
-            has_out_biases=False,
-        )
+        head_dim = self.hidden_size // self.num_attention_heads
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
             activation=SiLU(),
@@ -123,28 +116,43 @@ class HFQwen2Config(HuggingFaceConfig):
             up_clipping=None,
             gate_clipping=None,
         )
-
-
-
-
-
-
-
-
+
+        sliding_window_sizes = self._get_sliding_window_sizes()
+        layer_configs = []
+        for sliding_window_size in sliding_window_sizes:
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=None,
+                has_sinks=False,
+                has_qkv_biases=True,
+                has_out_biases=False,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=head_dim,
+                is_causal=True,
+                scale=None,
+                sliding_window_size=sliding_window_size,
+            )
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            layer_configs.append(decoder_layer_config)
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=self.hidden_size // self.num_attention_heads,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=tuple(self._get_sliding_window_sizes()),
             context_length=context_length or self.max_position_embeddings,
         )
lalamo/model_import/decoder_configs/huggingface/qwen3.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -17,15 +18,18 @@ from lalamo.modules import (
     UpcastMode,
 )
 from lalamo.modules.activations import SiLU
+from lalamo.modules.embedding import MLXQuantizedTiedEmbeddingConfig
+from lalamo.modules.linear import MLXQuantizedLinearConfig
 from lalamo.quantization import QuantizationMode
 
-from .common import
+from .common import HuggingFaceConfig, MLXQuantizationConfig, QuantizationConfigType
 
 __all__ = ["HFQwen3Config"]
 
 
 @dataclass(frozen=True)
 class HFQwen3Config(HuggingFaceConfig):
+    eos_token_id: int | list[int]
     torch_dtype: Literal["bfloat16", "float16", "float32"]
     attention_bias: bool
     hidden_act: Literal["silu"]
@@ -45,7 +49,7 @@ class HFQwen3Config(HuggingFaceConfig):
     vocab_size: int
     head_dim: int
 
-    quantization_config:
+    quantization_config: QuantizationConfigType = None
 
     def _get_sliding_window_sizes(self) -> tuple[int | None, ...]:
         if not self.use_sliding_window:
@@ -67,8 +71,19 @@ class HFQwen3Config(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
-        if self.
+        if isinstance(self.quantization_config, MLXQuantizationConfig):
+            assert self.tie_word_embeddings, "only tied embeddings are supported"
+            embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                group_size=self.quantization_config.group_size,
+                embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
+        elif self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
                 logit_soft_cap=None,
@@ -96,6 +111,13 @@ class HFQwen3Config(HuggingFaceConfig):
             linear_config = FullPrecisionLinearConfig(
                 precision=activation_precision,
             )
+        elif isinstance(self.quantization_config, MLXQuantizationConfig):
+            linear_config = MLXQuantizedLinearConfig(
+                group_size=self.quantization_config.group_size,
+                weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                activation_quantization_mode=None,
+                activation_precision=activation_precision,
+            )
         else:
             linear_config = GroupQuantizedLinearConfig(
                 group_size=self.quantization_config.group_size,
@@ -103,16 +125,6 @@ class HFQwen3Config(HuggingFaceConfig):
                 activation_quantization_mode=None,
                 activation_precision=activation_precision,
             )
-        attention_config = AttentionConfig(
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=rmsnorm_config,
-            key_norm_config=rmsnorm_config,
-            logit_soft_cap=None,
-            has_sinks=False,
-            has_qkv_biases=self.attention_bias,
-            has_out_biases=self.attention_bias,
-        )
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
             activation=SiLU(),
@@ -121,28 +133,43 @@ class HFQwen3Config(HuggingFaceConfig):
             up_clipping=None,
             gate_clipping=None,
         )
-
-
-
-
-
-
-
-
+
+        sliding_window_sizes = self._get_sliding_window_sizes()
+        layer_configs = []
+        for sliding_window_size in sliding_window_sizes:
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=rmsnorm_config,
+                key_norm_config=rmsnorm_config,
+                logit_soft_cap=None,
+                has_sinks=False,
+                has_qkv_biases=self.attention_bias,
+                has_out_biases=self.attention_bias,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                is_causal=True,
+                scale=None,
+                sliding_window_size=sliding_window_size,
+            )
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            )
+            layer_configs.append(decoder_layer_config)
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=self.head_dim,
-            attention_scale=None,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=self._get_sliding_window_sizes(),
             context_length=context_length or self.max_position_embeddings,
         )
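Across all of these adapters, `to_decoder_config` also gains a `metadata_dict: Mapping[str, str]` parameter; most mark it `# noqa: ARG002` and ignore it, while the Llamba adapter reads MLX quantization settings from it. Below is a hypothetical call-site sketch under that assumption; the wrapper function, dtype choices, and metadata values are illustrative and not part of the package.

```python
from collections.abc import Mapping

import jax.numpy as jnp


def load_decoder_config(hf_config, metadata_dict: Mapping[str, str]):
    """Illustrative wrapper; `hf_config` stands in for any of the HF*Config adapters above."""
    return hf_config.to_decoder_config(
        context_length=None,  # adapters fall back to the checkpoint's own context limit
        activation_precision=jnp.bfloat16,
        accumulation_precision=jnp.float32,
        metadata_dict=metadata_dict,
    )


# Example metadata using the keys the Llamba adapter looks for (values are made up):
example_metadata = {
    "quantization_kwargs.group_size": "64",
    "quantization_kwargs.bits": "4",
}
```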