lalamo 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +3 -2
- lalamo/data/__init__.py +0 -1
- lalamo/data/huggingface_message.py +1 -0
- lalamo/main.py +167 -18
- lalamo/message_processor.py +2 -3
- lalamo/model_import/common.py +120 -27
- lalamo/model_import/decoder_configs/__init__.py +4 -2
- lalamo/model_import/decoder_configs/common.py +62 -21
- lalamo/model_import/decoder_configs/executorch.py +14 -9
- lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
- lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +21 -17
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
- lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
- lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
- lalamo/model_import/loaders/__init__.py +3 -2
- lalamo/model_import/loaders/executorch.py +24 -12
- lalamo/model_import/loaders/huggingface.py +258 -30
- lalamo/model_import/model_specs/__init__.py +4 -2
- lalamo/model_import/model_specs/common.py +8 -2
- lalamo/model_import/model_specs/gemma.py +5 -1
- lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo/model_import/model_specs/mirai.py +20 -0
- lalamo/models/__init__.py +10 -0
- lalamo/models/common.py +81 -0
- lalamo/{language_model.py → models/language_model.py} +32 -49
- lalamo/models/router.py +59 -0
- lalamo/modules/__init__.py +33 -16
- lalamo/modules/classifier.py +339 -0
- lalamo/modules/common.py +6 -3
- lalamo/modules/decoder.py +52 -180
- lalamo/modules/mlp.py +28 -5
- lalamo/modules/normalization.py +13 -8
- lalamo/modules/token_mixers/attention.py +10 -6
- lalamo/modules/token_mixers/state/kv_cache.py +14 -4
- lalamo/modules/transformer.py +273 -0
- lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
- lalamo/speculator/__init__.py +2 -0
- lalamo/speculator/estimator.py +91 -0
- lalamo/speculator/inference.py +28 -9
- lalamo/speculator/ngram.py +7 -3
- lalamo/speculator/utils.py +4 -2
- {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/METADATA +1 -1
- lalamo-0.5.3.dist-info/RECORD +88 -0
- lalamo-0.5.1.dist-info/RECORD +0 -80
- {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/WHEEL +0 -0
- {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.1.dist-info → lalamo-0.5.3.dist-info}/top_level.txt +0 -0

@@ -9,12 +9,13 @@ from lalamo.modules import (
     AttentionConfig,
     Decoder,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     LlamaRoPEConfig,
+    NormalizationConfig,
     QLoRALinearConfig,
     QuantizedTiedEmbeddingConfig,
-
+    TransformerConfig,
+    TransformerLayerConfig,
     UpcastMode,
 )
 from lalamo.modules.activations import SiLU

@@ -62,7 +63,7 @@ class ExecutorchConfig(ForeignConfig):
         return jnp.bfloat16

     @classmethod
-    def
+    def _load_decoder_weights(
         cls,
         model: Decoder,
         weights_dict: Mapping[str, Array],

@@ -119,12 +120,13 @@ class ETLlamaConfig(ExecutorchConfig):
             low_frequency_factor=LOW_FREQ_FACTOR,
             high_frequency_factor=HIGH_FREQ_FACTOR,
         )
-        rmsnorm_config =
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )
         linear_config = QLoRALinearConfig(
             group_size=self.quantization_args.group_size,

@@ -158,7 +160,7 @@ class ETLlamaConfig(ExecutorchConfig):
             up_clipping=None,
             gate_clipping=None,
         )
-
+        tranformer_layer_config = TransformerLayerConfig(
             pre_mixer_norm_config=rmsnorm_config,
             mixer_config=attention_config,
             post_mixer_norm_config=None,

@@ -166,14 +168,17 @@ class ETLlamaConfig(ExecutorchConfig):
             mlp_config=mlp_config,
             post_mlp_norm_config=None,
         )
-
-            embedding_config=embedding_config,
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_configs=(
+            layer_configs=(tranformer_layer_config,) * self.n_layers,
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.dim,
             hidden_dim=self._find_hidden_size(),
             context_length=context_length or MAX_SEQUENCE_LENGTH,
         )
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )

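The pattern above repeats for every architecture touched in this release: the layer stack, RoPE, and output norm move out of `DecoderConfig` into the new `TransformerConfig`, while `DecoderConfig` keeps only the embedding, the transformer, and the vocabulary size. A minimal sketch of the new composition, using only keyword arguments that appear in these hunks; the helper function and its parameters are illustrative placeholders, not part of lalamo:

```python
# Illustrative sketch of the 0.5.3 config layout inferred from the hunks above.
# Only the keyword argument names come from the diff; build_decoder_config and
# its parameters are hypothetical.
from lalamo.modules import DecoderConfig, TransformerConfig, TransformerLayerConfig


def build_decoder_config(
    embedding_config,
    layer_config: TransformerLayerConfig,
    rope_config,
    norm_config,
    num_layers: int,
    model_dim: int,
    hidden_dim: int,
    context_length: int,
    vocab_size: int,
) -> DecoderConfig:
    # In 0.5.1 these keyword arguments were passed to DecoderConfig directly.
    transformer_config = TransformerConfig(
        global_rope_config=rope_config,
        local_rope_config=None,
        layer_configs=(layer_config,) * num_layers,
        output_norm_config=norm_config,
        model_dim=model_dim,
        hidden_dim=hidden_dim,
        context_length=context_length,
    )
    # DecoderConfig is now just embedding + transformer + vocabulary size.
    return DecoderConfig(
        embedding_config=embedding_config,
        transformer_config=transformer_config,
        vocab_size=vocab_size,
    )
```
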
@@ -1,10 +1,11 @@
-from .common import
+from .common import HuggingFaceLMConfig
 from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
 from .gpt_oss import HFGPTOssConfig
 from .llama import HFLlamaConfig
 from .llamba import HFLlambaConfig
 from .mistral import HFMistralConfig
+from .modern_bert import ModernBERTConfig
 from .qwen2 import HFQwen2Config
 from .qwen3 import HFQwen3Config

@@ -18,5 +19,6 @@ __all__ = [
     "HFMistralConfig",
     "HFQwen2Config",
     "HFQwen3Config",
-    "
+    "HuggingFaceLMConfig",
+    "ModernBERTConfig",
 ]

@@ -6,15 +6,22 @@ import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike

-from lalamo.model_import.decoder_configs import
-from lalamo.model_import.
+from lalamo.model_import.decoder_configs import ForeignLMConfig
+from lalamo.model_import.decoder_configs.common import ForeignClassifierConfig
+from lalamo.model_import.loaders import (
+    load_huggingface_classifier,
+    load_huggingface_decoder,
+)
 from lalamo.modules import Decoder
+from lalamo.modules.classifier import Classifier
+from lalamo.modules.common import LalamoModule

 __all__ = [
     "AWQQuantizationConfig",
     "GPTQMetaConfig",
     "GPTQQuantizationConfig",
-    "
+    "HuggingFaceClassifierConfig",
+    "HuggingFaceLMConfig",
 ]


@@ -85,26 +92,45 @@ def _structure_quantization_config(v: object, _: object) -> QuantizationConfigTy


 @dataclass(frozen=True)
-class
+class HuggingFaceLMConfig(ForeignLMConfig):
     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
     _converter.register_structure_hook(int | list[int], lambda v, _: v)
     _converter.register_structure_hook(QuantizationConfigType, _structure_quantization_config)

     @property
     def eos_token_ids(self) -> list[int]:
-
-
+        result = getattr(self, "eos_token_id", None)
+        if result is None:
+            raise RuntimeError("model doesn't have eos_token_id, override eos_token_ids in model config")

-
+        if isinstance(result, int):
+            result = [result]

+        return result
+
+    @property
+    def default_precision(self) -> DTypeLike:
+        return jnp.dtype(getattr(self, "torch_dtype", "bfloat16"))
+
+    def _load_weights(
+        self,
+        model: LalamoModule,
+        weights_dict: Mapping[str, Array],
+    ) -> LalamoModule:
+        assert isinstance(model, Decoder)
+        return load_huggingface_decoder(model, weights_dict)
+
+
+@dataclass(frozen=True)
+class HuggingFaceClassifierConfig(ForeignClassifierConfig):
     @property
     def default_precision(self) -> DTypeLike:
         return jnp.dtype(getattr(self, "torch_dtype", "bfloat16"))

-    @classmethod
     def _load_weights(
-
-        model:
+        self,
+        model: LalamoModule,
         weights_dict: Mapping[str, Array],
-    ) ->
-
+    ) -> LalamoModule:
+        assert isinstance(model, Classifier)
+        return load_huggingface_classifier(model, weights_dict)

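For reference, the `eos_token_ids` normalization introduced above can be exercised in isolation; the helper below simply mirrors the property body shown in the hunk and is not part of the lalamo API:

```python
# Hypothetical standalone helper mirroring HuggingFaceLMConfig.eos_token_ids above.
def normalize_eos_token_ids(eos_token_id: int | list[int] | None) -> list[int]:
    if eos_token_id is None:
        raise RuntimeError("model doesn't have eos_token_id, override eos_token_ids in model config")
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    return eos_token_id


assert normalize_eos_token_ids(2) == [2]
assert normalize_eos_token_ids([1, 107]) == [1, 107]
```
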
@@ -7,23 +7,24 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
-
+    NormalizationConfig,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UnscaledRoPEConfig,
     UpcastMode,
 )
 from lalamo.modules.activations import GELU

-from .common import
+from .common import HuggingFaceLMConfig

 __all__ = ["HFGemma2Config"]


 @dataclass(frozen=True)
-class HFGemma2Config(
+class HFGemma2Config(HuggingFaceLMConfig):
     architectures: list[Literal["Gemma2ForCausalLM"]]
     attention_bias: bool
     attention_dropout: float

@@ -72,12 +73,13 @@ class HFGemma2Config(HuggingFaceConfig):
             base=self.rope_theta,
             max_sequence_length=self.max_position_embeddings,
         )
-        rmsnorm_config =
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=1.0,
             upcast_mode=UpcastMode.FULL_LAYER,
+            subtract_mean=False,
         )
         linear_config = FullPrecisionLinearConfig(
             precision=activation_precision,

@@ -110,7 +112,7 @@ class HFGemma2Config(HuggingFaceConfig):
             scale=attention_scale,
             sliding_window_size=sliding_window_size,
         )
-
+        transformer_layer_config = TransformerLayerConfig(
             pre_mixer_norm_config=rmsnorm_config,
             mixer_config=attention_config,
             post_mixer_norm_config=rmsnorm_config,

@@ -118,16 +120,19 @@ class HFGemma2Config(HuggingFaceConfig):
             mlp_config=mlp_config,
             post_mlp_norm_config=rmsnorm_config,
         )
-        layer_configs.append(
+        layer_configs.append(transformer_layer_config)

-
-            embedding_config=embedding_config,
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )

@@ -1,23 +1,20 @@
 from collections.abc import Mapping
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Literal

 import jax.numpy as jnp
 from jaxtyping import DTypeLike

-from lalamo.modules import
-    DecoderConfig,
-    TiedEmbeddingConfig,
-)
+from lalamo.modules import DecoderConfig, TiedEmbeddingConfig, TransformerConfig
 from lalamo.modules.activations import GELU
-from lalamo.modules.decoder_layer import DecoderLayerConfig
 from lalamo.modules.linear import FullPrecisionLinearConfig
 from lalamo.modules.mlp import DenseMLPConfig
-from lalamo.modules.normalization import
+from lalamo.modules.normalization import NormalizationConfig, UpcastMode
 from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
-from lalamo.modules.token_mixers import AttentionConfig
+from lalamo.modules.token_mixers.attention import AttentionConfig
+from lalamo.modules.transformer_layer import TransformerLayerConfig

-from .common import
+from .common import HuggingFaceLMConfig

 __all__ = ["HFGemma3Config", "HFGemma3TextConfig"]

@@ -80,12 +77,13 @@ class HFGemma3TextConfigRaw:
             logit_soft_cap=None,
             precision=activation_precision,
         )
-        rms_norm_config =
+        rms_norm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=1.0,
             upcast_mode=UpcastMode.FULL_LAYER,
+            subtract_mean=False,
         )

         if self.rope_scaling is not None:

@@ -134,7 +132,7 @@ class HFGemma3TextConfigRaw:
             scale=attention_scale,
             sliding_window_size=sliding_window_size,
         )
-
+        transformer_layer_config = TransformerLayerConfig(
             pre_mixer_norm_config=rms_norm_config,
             mixer_config=attention_config,
             post_mixer_norm_config=rms_norm_config,

@@ -142,23 +140,29 @@ class HFGemma3TextConfigRaw:
             mlp_config=mlp_config,
             post_mlp_norm_config=rms_norm_config,
         )
-        layer_configs.append(
-
-
+        layer_configs.append(transformer_layer_config)
+
+        transformer_config = TransformerConfig(
             global_rope_config=global_rope_config,
             local_rope_config=local_rope_config,
             layer_configs=tuple(layer_configs),
             output_norm_config=rms_norm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )

+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
+

 @dataclass(frozen=True)
-class HFGemma3TextConfig(HFGemma3TextConfigRaw,
+class HFGemma3TextConfig(HFGemma3TextConfigRaw, HuggingFaceLMConfig):
     torch_dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
+    eos_token_id: int | list[int] = field(default_factory=list)


 @dataclass(frozen=True)

|
|
|
174
178
|
|
|
175
179
|
|
|
176
180
|
@dataclass(frozen=True)
|
|
177
|
-
class HFGemma3Config(
|
|
181
|
+
class HFGemma3Config(HuggingFaceLMConfig):
|
|
178
182
|
torch_dtype: Literal["bfloat16", "float16", "float32"]
|
|
179
183
|
architectures: list[Literal["Gemma3ForConditionalGeneration"]]
|
|
180
184
|
boi_token_index: int
|
|
@@ -7,20 +7,21 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     MixtureOfExpertsConfig,
-
+    NormalizationConfig,
     SoftmaxRouting,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
     YARNRoPEConfig,
 )
 from lalamo.modules.activations import SiLU

-from .common import
+from .common import HuggingFaceLMConfig

 __all__ = ["HFGPTOssConfig"]

@@ -36,7 +37,7 @@ class YarnRopeScalingConfig:


 @dataclass(frozen=True)
-class HFGPTOssConfig(
+class HFGPTOssConfig(HuggingFaceLMConfig):
     # Core HF fields
     architectures: list[Literal["GptOssForCausalLM"]]
     attention_bias: bool

@@ -115,12 +116,13 @@ class HFGPTOssConfig(HuggingFaceConfig):
             truncate=True,
         )

-        rmsnorm_config =
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.FULL_LAYER,
+            subtract_mean=False,
         )

         # Linear layers

@@ -179,7 +181,7 @@ class HFGPTOssConfig(HuggingFaceConfig):
             scale=None,
             sliding_window_size=sliding_window_size,
         )
-
+        transformer_layer_config = TransformerLayerConfig(
             pre_mixer_norm_config=rmsnorm_config,
             mixer_config=attention_config,
             post_mixer_norm_config=None,

@@ -187,16 +189,20 @@ class HFGPTOssConfig(HuggingFaceConfig):
             mlp_config=moe_config,
             post_mlp_norm_config=None,
         )
-        layer_configs.append(
+        layer_configs.append(transformer_layer_config)

-
-            embedding_config=embedding_config,
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
             layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )

@@ -7,14 +7,15 @@ from jaxtyping import DTypeLike
 from lalamo.modules import (
     AttentionConfig,
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
     LlamaRoPEConfig,
-
+    NormalizationConfig,
     SiLU,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UnscaledRoPEConfig,
     UntiedEmbeddingConfig,
     UpcastMode,

@@ -22,7 +23,7 @@ from lalamo.modules import (
 )
 from lalamo.quantization import QuantizationMode

-from .common import AWQQuantizationConfig, GPTQQuantizationConfig,
+from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceLMConfig

 __all__ = ["HFLlamaConfig"]

@@ -47,7 +48,7 @@ class YarnRopeScalingConfig:


 @dataclass(frozen=True)
-class HFLlamaConfig(
+class HFLlamaConfig(HuggingFaceLMConfig):
     torch_dtype: Literal["bfloat16", "float16", "float32"]
     architectures: list[Literal["LlamaForCausalLM"]]
     attention_bias: bool

@@ -124,12 +125,13 @@ class HFLlamaConfig(HuggingFaceConfig):
             )
         else:
             raise ValueError("Unsupported rope_scaling configuration")
-        rmsnorm_config =
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.rms_norm_eps,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )
         if self.quantization_config is None:
             linear_config = FullPrecisionLinearConfig(

@@ -153,7 +155,7 @@ class HFLlamaConfig(HuggingFaceConfig):
             has_out_biases=False,
             num_heads=self.num_attention_heads,
             num_groups=self.num_key_value_heads,
-            head_dim=self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads,
+            head_dim=(self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads),
             is_causal=True,
             scale=None,
             sliding_window_size=None,

@@ -166,7 +168,7 @@ class HFLlamaConfig(HuggingFaceConfig):
             up_clipping=None,
             gate_clipping=None,
         )
-
+        transformer_layer_config = TransformerLayerConfig(
             pre_mixer_norm_config=rmsnorm_config,
             mixer_config=attention_config,
             post_mixer_norm_config=None,

@@ -174,14 +176,17 @@ class HFLlamaConfig(HuggingFaceConfig):
             mlp_config=mlp_config,
             post_mlp_norm_config=None,
         )
-
-            embedding_config=embedding_config,
+        transformer_config = TransformerConfig(
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_configs=(
+            layer_configs=(transformer_layer_config,) * self.num_hidden_layers,
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
             context_length=context_length or self.max_position_embeddings,
         )
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )

@@ -6,23 +6,24 @@ from jaxtyping import DTypeLike

 from lalamo.modules import (
     DecoderConfig,
-    DecoderLayerConfig,
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     Identity,
     Mamba2Config,
     MLXQuantizedLinearConfig,
     MLXSemiQuantizedUntiedEmbeddingConfig,
-
+    NormalizationConfig,
     SeparableCausalConvConfig,
     SiLU,
     TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
     UntiedEmbeddingConfig,
     UpcastMode,
 )
 from lalamo.quantization import QuantizationMode

-from .common import
+from .common import HuggingFaceLMConfig


 @dataclass(frozen=True)

@@ -45,7 +46,7 @@ class HFLlambaSsmConfig:


 @dataclass(frozen=True)
-class HFLlambaConfig(
+class HFLlambaConfig(HuggingFaceLMConfig):
     model_type: Literal["llamba"]
     vocab_size: int
     tie_embeddings: bool

@@ -74,7 +75,9 @@ class HFLlambaConfig(HuggingFaceConfig):
             input_scale=None,
             logit_soft_cap=None,
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
-            embedding_quantization_mode=QuantizationMode.from_num_bits(
+            embedding_quantization_mode=QuantizationMode.from_num_bits(
+                int(metadata_dict["quantization_kwargs.bits"])
+            ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
         )

@@ -91,18 +94,21 @@ class HFLlambaConfig(HuggingFaceConfig):
             precision=activation_precision,
         )

-        rmsnorm_config =
+        rmsnorm_config = NormalizationConfig(
             scale_precision=activation_precision,
             accumulation_precision=accumulation_precision,
             epsilon=self.norm_epsilon,
             scale_offset=None,
             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
         )

-        if "quantization_kwargs.group_size" in metadata_dict:
+        if metadata_dict and "quantization_kwargs.group_size" in metadata_dict:
             linear_config = MLXQuantizedLinearConfig(
                 group_size=int(metadata_dict["quantization_kwargs.group_size"]),
-                weight_quantization_mode=QuantizationMode.from_num_bits(
+                weight_quantization_mode=QuantizationMode.from_num_bits(
+                    int(metadata_dict["quantization_kwargs.bits"])
+                ),
                 activation_quantization_mode=None,
                 activation_precision=activation_precision,
             )

@@ -148,7 +154,7 @@ class HFLlambaConfig(HuggingFaceConfig):
             has_out_biases=self.ssm_cfg.bias,
         )

-
+        transformer_layer_config = TransformerLayerConfig(
             pre_mixer_norm_config=rmsnorm_config,
             mixer_config=mamba_config,
             post_mixer_norm_config=None,

@@ -156,15 +162,18 @@ class HFLlambaConfig(HuggingFaceConfig):
             mlp_config=mlp_config,
             post_mlp_norm_config=None,
         )
-
-        return DecoderConfig(
-            embedding_config=embedding_config,
+        transformer_config = TransformerConfig(
             global_rope_config=None,
             local_rope_config=None,
-            layer_configs=(
+            layer_configs=(transformer_layer_config,) * self.n_layer,
             output_norm_config=rmsnorm_config,
-            vocab_size=self.vocab_size,
             model_dim=self.d_model,
             hidden_dim=self.mlp_cfg.intermediate_size,
             context_length=context_length or 4096,
         )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )