lalamo 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +20 -5
- lalamo/data/__init__.py +8 -0
- lalamo/data/huggingface_message.py +38 -0
- lalamo/data/lalamo_completions.py +43 -0
- lalamo/data/utils.py +8 -0
- lalamo/language_model.py +152 -69
- lalamo/main.py +271 -43
- lalamo/message_processor.py +11 -1
- lalamo/model_import/common.py +10 -6
- lalamo/model_import/decoder_configs/__init__.py +3 -0
- lalamo/model_import/decoder_configs/executorch.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- lalamo/model_import/huggingface_tokenizer_config.py +1 -3
- lalamo/model_import/loaders/executorch.py +10 -9
- lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo/model_import/loaders/utils.py +92 -0
- lalamo/model_import/model_specs/__init__.py +4 -1
- lalamo/model_import/model_specs/common.py +15 -12
- lalamo/model_import/model_specs/gpt_oss.py +21 -0
- lalamo/modules/__init__.py +35 -7
- lalamo/modules/activations.py +24 -14
- lalamo/modules/attention.py +73 -20
- lalamo/modules/common.py +8 -57
- lalamo/modules/decoder.py +48 -34
- lalamo/modules/decoder_layer.py +57 -43
- lalamo/modules/embedding.py +13 -19
- lalamo/modules/kv_cache.py +53 -16
- lalamo/modules/linear.py +260 -79
- lalamo/modules/mlp.py +395 -23
- lalamo/modules/normalization.py +2 -3
- lalamo/modules/rope.py +32 -21
- lalamo/modules/utils.py +10 -0
- lalamo/speculator/__init__.py +11 -0
- lalamo/speculator/common.py +22 -0
- lalamo/speculator/inference.py +75 -0
- lalamo/speculator/ngram.py +154 -0
- lalamo/speculator/utils.py +52 -0
- lalamo/utils.py +27 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
- lalamo-0.4.0.dist-info/RECORD +71 -0
- lalamo-0.3.4.dist-info/RECORD +0 -59
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0

lalamo/model_import/decoder_configs/huggingface/gemma3.py

```diff
@@ -8,11 +8,11 @@ from lalamo.modules import (
     DecoderConfig,
     TiedEmbeddingConfig,
 )
-from lalamo.modules.activations import
+from lalamo.modules.activations import GELU
 from lalamo.modules.attention import AttentionConfig
 from lalamo.modules.decoder_layer import DecoderLayerConfig
 from lalamo.modules.linear import FullPrecisionLinearConfig
-from lalamo.modules.mlp import
+from lalamo.modules.mlp import DenseMLPConfig
 from lalamo.modules.normalization import RMSNormConfig, UpcastMode
 from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
 
@@ -75,7 +75,7 @@ class HFGemma3TextConfigRaw:
         attention_scale = self.query_pre_attn_scalar**-0.5
         embedding_config = TiedEmbeddingConfig(
             input_scale=input_scale,
-
+            logit_soft_cap=None,
             precision=activation_precision,
         )
         rms_norm_config = RMSNormConfig(
@@ -106,13 +106,21 @@
         )
 
         linear_config = FullPrecisionLinearConfig(precision=activation_precision)
-        mlp_config =
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=GELU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
+        )
         attention_config = AttentionConfig(
             qkv_projection_config=linear_config,
             out_projection_config=linear_config,
             query_norm_config=rms_norm_config,
             key_norm_config=rms_norm_config,
             logit_soft_cap=self.attn_logit_softcapping,
+            has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=self.attention_bias,
         )
@@ -145,7 +153,7 @@ class HFGemma3TextConfigRaw:
 
 @dataclass(frozen=True)
 class HFGemma3TextConfig(HFGemma3TextConfigRaw, HuggingFaceConfig):
-
+    torch_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
 
 
 @dataclass(frozen=True)
@@ -162,6 +170,7 @@ class HFGemma3VisionConfig:
 
 @dataclass(frozen=True)
 class HFGemma3Config(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
    architectures: list[Literal["Gemma3ForConditionalGeneration"]]
     boi_token_index: int
     eoi_token_index: int
```
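
The embedding and attention configs above now spell out soft-capping explicitly (`logit_soft_cap=None` on the embeddings, `logit_soft_cap=self.attn_logit_softcapping` on attention). As a reading aid, here is a minimal sketch of the usual Gemma-style soft-cap transform these fields parameterize; lalamo's exact implementation is not shown in this diff, so treat it purely as an illustration:

```python
import jax.numpy as jnp


def soft_cap(logits, cap):
    """Gemma-style logit soft-capping: smoothly bound logits to (-cap, cap)."""
    if cap is None:  # logit_soft_cap=None disables capping
        return logits
    return cap * jnp.tanh(logits / cap)


soft_cap(jnp.array([-100.0, 0.0, 100.0]), cap=50.0)  # -> roughly [-48.2, 0.0, 48.2]
```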

lalamo/model_import/decoder_configs/huggingface/gpt_oss.py (new file)

```diff
@@ -0,0 +1,195 @@
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    AttentionConfig,
+    DecoderConfig,
+    DecoderLayerConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    MixtureOfExpertsConfig,
+    RMSNormConfig,
+    SoftmaxRouting,
+    TiedEmbeddingConfig,
+    UntiedEmbeddingConfig,
+    UpcastMode,
+    YARNRoPEConfig,
+)
+from lalamo.modules.activations import SiLU
+
+from .common import HuggingFaceConfig
+
+__all__ = ["HFGPTOssConfig"]
+
+
+@dataclass(frozen=True)
+class YarnRopeScalingConfig:
+    factor: float
+    beta_fast: float
+    beta_slow: float
+    original_max_position_embeddings: int
+    rope_type: Literal["yarn"]
+    truncate: bool
+
+
+@dataclass(frozen=True)
+class HFGPTOssConfig(HuggingFaceConfig):
+    # Core HF fields
+    architectures: list[Literal["GptOssForCausalLM"]]
+    attention_bias: bool
+    attention_dropout: float
+    eos_token_id: int | list[int]
+    hidden_act: Literal["silu"]
+    hidden_size: int
+    initializer_range: float
+    intermediate_size: int
+    max_position_embeddings: int
+    model_type: Literal["gpt_oss"]
+    num_attention_heads: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    pad_token_id: int
+    rms_norm_eps: float
+    rope_theta: float
+    tie_word_embeddings: bool
+    transformers_version: str
+    use_cache: bool
+    vocab_size: int
+
+    # GPT-OSS specifics
+    layer_types: list[Literal["sliding_attention", "full_attention"]] | None
+    sliding_window: int | None
+    swiglu_limit: float
+    head_dim: int | None
+    num_local_experts: int
+    num_experts_per_tok: int | None = None
+    experts_per_token: int | None = None  # some configs may use this alias
+    rope_scaling: YarnRopeScalingConfig | None = None
+    output_router_logits: bool | None = None
+    router_aux_loss_coef: float | None = None
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> DecoderConfig:
+        # Embedding
+        if self.tie_word_embeddings:
+            embedding_config = TiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+        else:
+            embedding_config = UntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+
+        if self.rope_scaling is not None and self.rope_scaling.rope_type == "yarn":
+            rope_config = YARNRoPEConfig(
+                precision=activation_precision,
+                base=self.rope_theta,
+                max_sequence_length=context_length or self.max_position_embeddings,
+                scaling_factor=self.rope_scaling.factor,
+                original_context_length=self.rope_scaling.original_max_position_embeddings,
+                beta_fast=self.rope_scaling.beta_fast,
+                beta_slow=self.rope_scaling.beta_slow,
+                truncate=self.rope_scaling.truncate,
+            )
+        else:
+            rope_config = YARNRoPEConfig(
+                precision=activation_precision,
+                base=self.rope_theta,
+                max_sequence_length=context_length or self.max_position_embeddings,
+                scaling_factor=1.0,
+                original_context_length=self.max_position_embeddings,
+                beta_fast=32.0,
+                beta_slow=1.0,
+                truncate=True,
+            )
+
+        rmsnorm_config = RMSNormConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.rms_norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.FULL_LAYER,
+        )
+
+        # Linear layers
+        linear_config = FullPrecisionLinearConfig(precision=activation_precision)
+
+        attention_config = AttentionConfig(
+            qkv_projection_config=linear_config,
+            out_projection_config=linear_config,
+            query_norm_config=None,
+            key_norm_config=None,
+            logit_soft_cap=None,
+            has_sinks=True,
+            has_qkv_biases=self.attention_bias,
+            has_out_biases=self.attention_bias,
+        )
+
+        # Experts (MoE) scaffold
+        # Router: linear with bias; Experts: DenseMLP with SiLU(alpha=1.702) and value/gate clipping
+        experts_activation = SiLU(alpha=1.702)
+        experts_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=experts_activation,
+            has_up_biases=True,
+            has_down_biases=True,
+            up_clipping=(-self.swiglu_limit + 1.0, self.swiglu_limit + 1.0),
+            gate_clipping=(None, self.swiglu_limit),
+        )
+        moe_config = MixtureOfExpertsConfig(
+            mixture_size=self.num_local_experts,
+            num_experts_per_token=(self.num_experts_per_tok or self.experts_per_token or 1),
+            routing_function=SoftmaxRouting(),
+            router_config=linear_config,
+            router_has_biases=True,
+            expert_config=experts_config,
+        )
+        decoder_layer_config = DecoderLayerConfig(
+            pre_attention_norm_config=rmsnorm_config,
+            attention_config=attention_config,
+            post_attention_norm_config=None,
+            pre_mlp_norm_config=rmsnorm_config,
+            mlp_config=moe_config,
+            post_mlp_norm_config=None,
+        )
+
+        # Per-layer sliding-window
+        if self.layer_types is not None and len(self.layer_types) == self.num_hidden_layers:
+            sliding_window_sizes = tuple(
+                self.sliding_window if layer_type == "sliding_attention" else None for layer_type in self.layer_types
+            )
+        else:
+            # Fallback: apply the same sliding window to all layers if provided
+            sliding_window_sizes = (
+                tuple([self.sliding_window] * self.num_hidden_layers) if self.sliding_window is not None else None
+            )
+
+        head_dim = self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            global_rope_config=rope_config,
+            local_rope_config=None,
+            layer_config=decoder_layer_config,
+            output_norm_config=rmsnorm_config,
+            vocab_size=self.vocab_size,
+            model_dim=self.hidden_size,
+            hidden_dim=self.intermediate_size,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=head_dim,
+            attention_scale=None,
+            num_layers=self.num_hidden_layers,
+            sliding_window_sizes=sliding_window_sizes,
+            context_length=context_length or self.max_position_embeddings,
+        )
```
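
The new `DenseMLPConfig` fields are easiest to read against the GPT-OSS expert settings above: biased up/gate/down projections, `SiLU(alpha=1.702)`, and clipping bounds derived from `swiglu_limit`. Below is a minimal sketch of a clipped gated MLP of that shape, assuming the conventional `down(act(gate) * up)` layout and that `SiLU(alpha)` denotes `x * sigmoid(alpha * x)`; it illustrates what the clipping tuples mean and is not lalamo's implementation:

```python
import jax
import jax.numpy as jnp


def clipped_gated_mlp(x, w_gate, b_gate, w_up, b_up, w_down, b_down, limit, alpha=1.702):
    """Illustrative clipped gated MLP (not lalamo's code).

    gate_clipping=(None, limit): clamp the gate pre-activation from above only.
    up_clipping=(-limit + 1, limit + 1): clamp the up branch to a shifted interval.
    """
    gate = jnp.clip(x @ w_gate + b_gate, None, limit)
    up = jnp.clip(x @ w_up + b_up, -limit + 1.0, limit + 1.0)
    activated = gate * jax.nn.sigmoid(alpha * gate)  # SiLU with alpha = 1.702
    return (activated * up) @ w_down + b_down


# Example shapes: model_dim=4, hidden_dim=8, swiglu_limit=7.0
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 4))
shapes = [(4, 8), (8,), (4, 8), (8,), (8, 4), (4,)]
params = [jax.random.normal(jax.random.fold_in(key, i), shape) for i, shape in enumerate(shapes)]
y = clipped_gated_mlp(x, *params, limit=7.0)
assert y.shape == (2, 4)
```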

lalamo/model_import/decoder_configs/huggingface/llama.py

```diff
@@ -4,19 +4,20 @@ from typing import Literal
 from jaxtyping import DTypeLike
 
 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
     LlamaRoPEConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UpcastMode,
+    YARNRoPEConfig,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.modules.embedding import UntiedEmbeddingConfig
 from lalamo.quantization import QuantizationMode
 
@@ -34,8 +35,19 @@ class LlamaRopeScalingConfig:
     rope_type: Literal["llama3"]
 
 
+@dataclass(frozen=True)
+class YarnRopeScalingConfig:
+    factor: float
+    beta_fast: float
+    beta_slow: float
+    original_max_position_embeddings: int
+    rope_type: Literal["yarn"]
+    truncate: bool
+
+
 @dataclass(frozen=True)
 class HFLlamaConfig(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["LlamaForCausalLM"]]
     attention_bias: bool
     attention_dropout: float
@@ -53,7 +65,7 @@ class HFLlamaConfig(HuggingFaceConfig):
     num_key_value_heads: int
     pretraining_tp: int
     rms_norm_eps: float
-    rope_scaling: LlamaRopeScalingConfig | None
+    rope_scaling: LlamaRopeScalingConfig | YarnRopeScalingConfig | None
     rope_theta: float
     tie_word_embeddings: bool
     transformers_version: str
@@ -72,13 +84,13 @@ class HFLlamaConfig(HuggingFaceConfig):
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
-
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         else:
             embedding_config = UntiedEmbeddingConfig(
                 input_scale=None,
-
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         if self.rope_scaling is None:
@@ -87,7 +99,18 @@ class HFLlamaConfig(HuggingFaceConfig):
                 base=self.rope_theta,
                 max_sequence_length=context_length or self.max_position_embeddings,
             )
-
+        elif isinstance(self.rope_scaling, YarnRopeScalingConfig):
+            rope_config = YARNRoPEConfig(
+                precision=activation_precision,
+                base=self.rope_theta,
+                max_sequence_length=context_length or self.max_position_embeddings,
+                scaling_factor=self.rope_scaling.factor,
+                original_context_length=self.rope_scaling.original_max_position_embeddings,
+                beta_fast=self.rope_scaling.beta_fast,
+                beta_slow=self.rope_scaling.beta_slow,
+                truncate=self.rope_scaling.truncate,
+            )
+        elif isinstance(self.rope_scaling, LlamaRopeScalingConfig):
             rope_config = LlamaRoPEConfig(
                 precision=activation_precision,
                 base=self.rope_theta,
@@ -97,6 +120,8 @@ class HFLlamaConfig(HuggingFaceConfig):
                 low_frequency_factor=self.rope_scaling.low_freq_factor,
                 high_frequency_factor=self.rope_scaling.high_freq_factor,
             )
+        else:
+            raise ValueError("Unsupported rope_scaling configuration")
         rmsnorm_config = RMSNormConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
@@ -121,12 +146,17 @@ class HFLlamaConfig(HuggingFaceConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=False,
         )
-        mlp_config =
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,
```
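
`HFLlamaConfig.rope_scaling` now accepts either scaling variant and branches on the concrete type, raising for anything else. Below is a standalone sketch of that dispatch using local stand-in dataclasses with the field names visible in the diff (the `parse_rope_scaling` helper and the example values are illustrative, not lalamo's API; the llama3 stand-in lists only the fields that appear above):

```python
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class LlamaRopeScalingConfig:  # stand-in, trimmed to fields referenced in the diff
    low_freq_factor: float
    high_freq_factor: float
    rope_type: Literal["llama3"]


@dataclass(frozen=True)
class YarnRopeScalingConfig:  # stand-in mirroring the dataclass added above
    factor: float
    beta_fast: float
    beta_slow: float
    original_max_position_embeddings: int
    rope_type: Literal["yarn"]
    truncate: bool


def parse_rope_scaling(raw: dict | None) -> LlamaRopeScalingConfig | YarnRopeScalingConfig | None:
    # Pick the dataclass by the "rope_type" discriminator, rejecting unknown variants,
    # mirroring the union now allowed on HFLlamaConfig.rope_scaling.
    if raw is None:
        return None
    if raw["rope_type"] == "llama3":
        return LlamaRopeScalingConfig(**raw)
    if raw["rope_type"] == "yarn":
        return YarnRopeScalingConfig(**raw)
    raise ValueError("Unsupported rope_scaling configuration")


parsed = parse_rope_scaling(
    {
        "rope_type": "yarn",
        "factor": 4.0,  # example values only
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "original_max_position_embeddings": 8192,
        "truncate": True,
    }
)
assert isinstance(parsed, YarnRopeScalingConfig)
```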

lalamo/model_import/decoder_configs/huggingface/mistral.py

```diff
@@ -4,17 +4,17 @@ from typing import Literal
 from jaxtyping import DTypeLike
 
 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.modules.normalization import UpcastMode
 
 from .common import HuggingFaceConfig
@@ -24,6 +24,7 @@ __all__ = ["HFMistralConfig"]
 
 @dataclass(frozen=True)
 class HFMistralConfig(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["MistralForCausalLM"]]
     attention_dropout: float
     bos_token_id: int
@@ -57,13 +58,13 @@ class HFMistralConfig(HuggingFaceConfig):
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
-
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         else:
             embedding_config = UntiedEmbeddingConfig(
                 input_scale=None,
-
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
 
@@ -91,13 +92,18 @@ class HFMistralConfig(HuggingFaceConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=False,
             has_out_biases=False,
         )
 
-        mlp_config =
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
 
         decoder_layer_config = DecoderLayerConfig(
```

lalamo/model_import/decoder_configs/huggingface/qwen2.py

```diff
@@ -4,19 +4,19 @@ from typing import Literal
 from jaxtyping import DTypeLike
 
 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.quantization import QuantizationMode
 
 from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
@@ -26,6 +26,7 @@ __all__ = ["HFQwen2Config"]
 
 @dataclass(frozen=True)
 class HFQwen2Config(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["Qwen2ForCausalLM"]]
     attention_dropout: float
     bos_token_id: int | list[int]
@@ -72,13 +73,13 @@ class HFQwen2Config(HuggingFaceConfig):
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
-
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         else:
             embedding_config = UntiedEmbeddingConfig(
                 input_scale=None,
-
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         rope_config = UnscaledRoPEConfig(
@@ -110,12 +111,17 @@ class HFQwen2Config(HuggingFaceConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=True,
             has_out_biases=False,
         )
-        mlp_config =
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,
```

lalamo/model_import/decoder_configs/huggingface/qwen3.py

```diff
@@ -4,19 +4,19 @@ from typing import Literal
 from jaxtyping import DTypeLike
 
 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.quantization import QuantizationMode
 
 from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
@@ -26,6 +26,7 @@ __all__ = ["HFQwen3Config"]
 
 @dataclass(frozen=True)
 class HFQwen3Config(HuggingFaceConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
     attention_bias: bool
     hidden_act: Literal["silu"]
     hidden_size: int
@@ -70,13 +71,13 @@ class HFQwen3Config(HuggingFaceConfig):
         if self.tie_word_embeddings:
             embedding_config = TiedEmbeddingConfig(
                 input_scale=None,
-
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         else:
             embedding_config = UntiedEmbeddingConfig(
                 input_scale=None,
-
+                logit_soft_cap=None,
                 precision=activation_precision,
             )
         rope_config = UnscaledRoPEConfig(
@@ -108,12 +109,17 @@ class HFQwen3Config(HuggingFaceConfig):
             query_norm_config=rmsnorm_config,
             key_norm_config=rmsnorm_config,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=self.attention_bias,
         )
-        mlp_config =
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,
```

lalamo/model_import/huggingface_tokenizer_config.py

```diff
@@ -72,9 +72,7 @@ class HFTokenizerConfig:
     def added_tokens(self) -> list[AddedToken]:
         if self.added_tokens_decoder is None:
             return []
-        return [
-            token.to_added_token() for token in self.added_tokens_decoder.values()
-        ]
+        return [token.to_added_token() for token in self.added_tokens_decoder.values()]
 
     @classmethod
     def from_json(cls, json_path: Path | str) -> "HFTokenizerConfig":
```

lalamo/model_import/loaders/executorch.py

```diff
@@ -1,4 +1,4 @@
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable, Iterator, Mapping
 from dataclasses import dataclass, replace
 
 import jax.numpy as jnp
@@ -6,7 +6,7 @@ from einops import rearrange
 from jaxtyping import Array, Float, Int
 
 from lalamo.common import ParameterPath
-from lalamo.modules import
+from lalamo.modules import Attention, Decoder, DecoderLayer, DenseMLP, QLoRALinear, QuantizedTiedEmbedding, RMSNorm
 
 from .common import load_parameters
 
@@ -43,7 +43,7 @@ def params_selector(module: QLoRALinear) -> tuple:
 
 
 def get_qlora_linear_params(
-    weights_dict:
+    weights_dict: Mapping[str, Array],
     path: ParameterPath,
     weights_dtype: jnp.dtype,
 ) -> QLoRALinearParams:
@@ -76,7 +76,7 @@ def load_linear(module: QLoRALinear, weights_dict: dict[str, Array], path: Param
     return load_parameters(params_selector, module, params)
 
 
-def load_mlp(module:
+def load_mlp(module: DenseMLP, weights_dict: Mapping[str, Array], path: ParameterPath) -> DenseMLP:
     if not isinstance(module.up_projection, QLoRALinear):
         raise TypeError(f"Expected up_projection to be QLoRALinear, got {type(module.up_projection)}")
     if not isinstance(module.down_projection, QLoRALinear):
@@ -95,7 +95,7 @@ def load_mlp(module: MLP, weights_dict: dict[str, Array], path: ParameterPath) -
     )
 
 
-def load_rmsnorm(module: RMSNorm, weights_dict:
+def load_rmsnorm(module: RMSNorm, weights_dict: Mapping[str, Array], path: ParameterPath) -> RMSNorm:
     return load_parameters(lambda m: (m.scales,), module, (weights_dict[path / "weight"],))
 
 
@@ -131,7 +131,7 @@ def permute_qk_params(
 
 def load_attention(
     module: Attention,
-    weights_dict:
+    weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> Attention:
     if not isinstance(module.qkv_projection, QLoRALinear):
@@ -177,7 +177,7 @@ def load_attention(
 
 def load_decoder_layer(
     module: DecoderLayer,
-    weights_dict:
+    weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> DecoderLayer:
     if module.post_attention_norm is not None:
@@ -187,6 +187,7 @@ def load_decoder_layer(
     attention_norm = load_rmsnorm(module.pre_attention_norm, weights_dict, path / "attention_norm")
     attention = load_attention(module.attention, weights_dict, path / "attention")
     mlp_norm = load_rmsnorm(module.pre_mlp_norm, weights_dict, path / "ffn_norm")
+    assert isinstance(module.mlp, DenseMLP)
     mlp = load_mlp(module.mlp, weights_dict, path / "feed_forward")
     return load_parameters(
         lambda m: (m.pre_attention_norm, m.attention, m.pre_mlp_norm, m.mlp),
@@ -197,7 +198,7 @@ def load_decoder_layer(
 
 def load_embedding(
     module: QuantizedTiedEmbedding,
-    weights_dict:
+    weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> QuantizedTiedEmbedding:
     weights = weights_dict[path / "weight"].astype(module.weights.dtype)
@@ -206,7 +207,7 @@ def load_embedding(
     return load_parameters(lambda m: (m.weights, m.scales), module, (weights, scales))
 
 
-def load_executorch(module: Decoder, weights_dict:
+def load_executorch(module: Decoder, weights_dict: Mapping[str, Array]) -> Decoder:
     root_path = ParameterPath()
     if not isinstance(module.embedding, QuantizedTiedEmbedding):
         raise TypeError(f"Expected embedding to be QuantizedTiedEmbedding, got {type(module.embedding)}")
```
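
The loader signatures above widen `weights_dict` from `dict[str, Array]` to `Mapping[str, Array]`, so callers are no longer forced to materialize all weights in a plain dict. A minimal sketch of why that matters, assuming the caller may hold weights in a read-only or lazily materialized mapping; the `LazyWeights` class and `load_scale` helper here are hypothetical and only illustrate the typing contract:

```python
from collections.abc import Mapping

import jax.numpy as jnp
from jaxtyping import Array


class LazyWeights(Mapping[str, Array]):
    """Hypothetical read-only mapping that materializes arrays on access."""

    def __init__(self, shapes: dict[str, tuple[int, ...]]) -> None:
        self._shapes = shapes

    def __getitem__(self, key: str) -> Array:
        return jnp.zeros(self._shapes[key])  # stand-in for loading from disk

    def __iter__(self):
        return iter(self._shapes)

    def __len__(self) -> int:
        return len(self._shapes)


def load_scale(weights_dict: Mapping[str, Array], key: str) -> Array:
    # Accepts plain dicts and custom mappings alike, mirroring the widened signatures.
    return weights_dict[key]


load_scale({"norm.weight": jnp.ones(8)}, "norm.weight")
load_scale(LazyWeights({"norm.weight": (8,)}), "norm.weight")
```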