lalamo 0.4.1__tar.gz → 0.5.0__tar.gz
This diff shows the changes between publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- {lalamo-0.4.1 → lalamo-0.5.0}/PKG-INFO +3 -2
- {lalamo-0.4.1 → lalamo-0.5.0}/README.md +2 -1
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/__init__.py +1 -1
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/language_model.py +22 -23
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/main.py +2 -16
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/common.py +24 -6
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/__init__.py +2 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/common.py +4 -4
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/executorch.py +17 -10
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
- lalamo-0.5.0/lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/loaders/executorch.py +5 -4
- lalamo-0.5.0/lalamo/model_import/loaders/huggingface.py +653 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/__init__.py +2 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/common.py +16 -5
- lalamo-0.5.0/lalamo/model_import/model_specs/llamba.py +40 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/qwen.py +29 -1
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/__init__.py +33 -6
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/activations.py +9 -2
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/common.py +10 -5
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/decoder.py +93 -97
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/decoder_layer.py +85 -103
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/embedding.py +279 -5
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/linear.py +335 -30
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/mlp.py +6 -7
- lalamo-0.5.0/lalamo/modules/mlx_interop.py +19 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/rope.py +1 -1
- lalamo-0.5.0/lalamo/modules/token_mixers/__init__.py +30 -0
- {lalamo-0.4.1/lalamo/modules → lalamo-0.5.0/lalamo/modules/token_mixers}/attention.py +72 -70
- lalamo-0.5.0/lalamo/modules/token_mixers/common.py +78 -0
- lalamo-0.5.0/lalamo/modules/token_mixers/mamba.py +553 -0
- lalamo-0.5.0/lalamo/modules/token_mixers/state/__init__.py +12 -0
- lalamo-0.5.0/lalamo/modules/token_mixers/state/common.py +26 -0
- {lalamo-0.4.1/lalamo/modules → lalamo-0.5.0/lalamo/modules/token_mixers/state}/kv_cache.py +5 -16
- lalamo-0.5.0/lalamo/modules/token_mixers/state/mamba_state.py +51 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/utils.py +24 -2
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo.egg-info/PKG-INFO +3 -2
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo.egg-info/SOURCES.txt +13 -2
- {lalamo-0.4.1 → lalamo-0.5.0}/pyproject.toml +1 -1
- {lalamo-0.4.1 → lalamo-0.5.0}/tests/test_generation.py +4 -4
- lalamo-0.5.0/tests/test_huggingface_models.py +24 -0
- lalamo-0.5.0/tests/test_mlx_models.py +20 -0
- lalamo-0.5.0/tests/test_models.py +456 -0
- lalamo-0.4.1/lalamo/model_import/loaders/huggingface.py +0 -401
- lalamo-0.4.1/tests/test_huggingface_models.py +0 -87
- {lalamo-0.4.1 → lalamo-0.5.0}/LICENSE +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/common.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/data/__init__.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/data/huggingface_message.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/data/utils.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/message_processor.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/huggingface.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/normalization.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/modules/utils.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/quantization.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/registry_abc.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/sampling.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/speculator/__init__.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/speculator/common.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/speculator/inference.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/speculator/ngram.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo/speculator/utils.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/setup.cfg +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/tests/test_model_spec.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/tests/test_moe.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/tests/test_parameter_tree.py +0 -0
- {lalamo-0.4.1 → lalamo-0.5.0}/tests/test_registry_abc.py +0 -0
{lalamo-0.4.1 → lalamo-0.5.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.4.1
+Version: 0.5.0
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown

@@ -38,7 +38,8 @@ Dynamic: license-file
 
 <a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
 <a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
-<a href="
+<a href="https://discord.com/invite/trymirai"><img src="https://img.shields.io/discord/1377764166764462120?label=Discord" alt="Discord"></a>
+<a href="mailto:contact@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
 <a href="https://docs.trymirai.com/overview/lalamo"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
 [](LICENSE)
 
{lalamo-0.4.1 → lalamo-0.5.0}/README.md

@@ -6,7 +6,8 @@
 
 <a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
 <a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
-<a href="
+<a href="https://discord.com/invite/trymirai"><img src="https://img.shields.io/discord/1377764166764462120?label=Discord" alt="Discord"></a>
+<a href="mailto:contact@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
 <a href="https://docs.trymirai.com/overview/lalamo"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
 [](LICENSE)
 
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/language_model.py

@@ -14,8 +14,7 @@ from tokenizers import Tokenizer
 
 from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
 from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
-from lalamo.modules import Decoder, DecoderConfig,
-from lalamo.modules.common import ForwardPassMode
+from lalamo.modules import Decoder, DecoderConfig, ForwardPassMode, LalamoModule, State, config_converter
 from lalamo.modules.decoder import DecoderForwardPassConfig
 from lalamo.sampling import SamplingPolicy, make_policy
 from lalamo.utils import open_safetensors

@@ -37,13 +36,13 @@ type ForwardPassConfig = DecoderForwardPassConfig
 class PrefillResults(NamedTuple):
     last_token_logits: Float[Array, "batch vocabulary"]
     last_token_indices: Int[Array, " batch"]
-
+    state: State
 
 
 class DecodingState(NamedTuple):
     last_token_logits: Float[Array, "batch vocabulary"]
     last_token_indices: Int[Array, " batch"]
-
+    state: State
     stop_flags: Bool[Array, " batch"]
 
 

@@ -89,7 +88,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         with open(path / "config.json") as config_file:
             config_json = json.load(config_file)
         config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
-        with open_safetensors(path / "model.safetensors") as weights_dict:
+        with open_safetensors(path / "model.safetensors") as (weights_dict, _):
             weights = unflatten_parameters(weights_dict)
             decoder = config.decoder_config.empty().import_weights(weights)
         tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))

@@ -124,21 +123,21 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         self,
         token_ids: Int[Array, "batch tokens"],
         lengths_without_padding: Int[Array, " batch"] | None = None,
-
+        state_capacity: int | None = None,
         forward_pass_config: ForwardPassConfig | None = None,
     ) -> PrefillResults:
         batch_size, sequence_length = token_ids.shape
         token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
-        if
-
+        if state_capacity is not None:
+            state = self.decoder.init_static_state(batch_size, state_capacity)
         else:
-
+            state = None
 
         decoder_outputs = self.decoder(
             token_ids,
             token_positions,
-
-
+            state,
+            return_updated_state=True,
             lengths_without_padding=lengths_without_padding,
             forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
             forward_pass_config=forward_pass_config,

@@ -151,11 +150,11 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
 
         last_token_logits = vmap(lambda logits, index: logits[index])(decoder_outputs.logits, last_logits_indices)
 
-        assert decoder_outputs.
+        assert decoder_outputs.updated_state is not None
         return PrefillResults(
             last_token_logits=last_token_logits,
             last_token_indices=last_logits_indices,
-
+            state=decoder_outputs.updated_state,
         )
 
     @eqx.filter_jit

@@ -187,7 +186,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         initial_state = DecodingState(
             prefill_results.last_token_logits,
             prefill_results.last_token_indices,
-            prefill_results.
+            prefill_results.state,
             jnp.zeros(batch_size, dtype=jnp.bool),
         )
 

@@ -224,16 +223,16 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         decoder_outputs = self.decoder(
             next_token_ids[:, None],
             next_token_indices[:, None],
-            state.
-
+            state.state,
+            return_updated_state=True,
             forward_pass_mode=forward_pass_mode,
             forward_pass_config=forward_pass_config,
         )
-        assert decoder_outputs.
+        assert decoder_outputs.updated_state is not None, "updated_state should not be None"
         new_state = DecodingState(
             decoder_outputs.logits.squeeze(1),
             next_token_indices,
-            decoder_outputs.
+            decoder_outputs.updated_state,
             stop_flags,
         )
         return new_state, GenerationStepResults(next_token_ids, next_top_k_token_ids, next_top_k_token_logits)

@@ -338,7 +337,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         state = DecodingState(
             prefill_results.last_token_logits,
             prefill_results.last_token_indices,
-            prefill_results.
+            prefill_results.state,
             jnp.array([0], dtype=jnp.bool),
         )
 

@@ -356,14 +355,14 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             decoder_outputs = self.decoder(
                 next_token_id.reshape(1, 1),
                 next_token_indices.reshape(1, 1),
-                state.
-
+                state.state,
+                return_updated_state=True,
                 forward_pass_config=forward_pass_config,
             )
-            assert decoder_outputs.
+            assert decoder_outputs.updated_state is not None, "updated_state should not be None"
            state = DecodingState(
                 decoder_outputs.logits.squeeze(1),
                 next_token_indices,
-                decoder_outputs.
+                decoder_outputs.updated_state,
                 state.stop_flags,
             )
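Taken together, the `language_model.py` hunks replace KV-cache-specific plumbing with a generic `State` that is threaded from prefill into the decode loop. A minimal usage sketch of the new call shape; the model path and the method names `load` and `prefill` are assumptions, since the diff only shows the changed bodies, not the surrounding signatures:

```python
import jax.numpy as jnp

from lalamo.language_model import LanguageModel

# Hypothetical converted-model directory and loader name.
model = LanguageModel.load("path/to/converted-model")

token_ids = jnp.array([[1, 2, 3, 4]], dtype=jnp.int32)

# `state_capacity` preallocates a static decoding state via
# decoder.init_static_state(batch_size, state_capacity); passing None
# runs prefill without a preallocated state.
prefill_results = model.prefill(token_ids, state_capacity=512)

# The generic `state` field replaces the old KV-cache field and is fed
# back into each decode step together with return_updated_state=True.
print(prefill_results.last_token_logits.shape)
```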
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/main.py

@@ -27,7 +27,6 @@ from rich.progress import (
     TextColumn,
     TimeElapsedColumn,
     TimeRemainingColumn,
-    track,
 )
 from rich.table import Table
 from safetensors.flax import save_file

@@ -50,7 +49,6 @@ from lalamo.modules import config_converter
 from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
 from lalamo.speculator.ngram import NGramSpeculator
 from lalamo.speculator.utils import SpeculatorTrainingEvent, test_speculator, train_speculator
-from lalamo.utils import jax_uint4_to_packed_uint8
 
 SCRIPT_NAME = Path(sys.argv[0]).name
 

@@ -109,16 +107,6 @@ def _error(message: str) -> None:
     raise Exit(1)
 
 
-def _pack_uint4_weights(weights: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
-    packed_weights = {}
-    for key, value in weights.items():
-        if value.dtype == jnp.uint4:
-            packed_weights[key] = jax_uint4_to_packed_uint8(value)
-        else:
-            packed_weights[key] = value
-    return packed_weights
-
-
 @app.command(help="Chat with a converted model.")
 def chat(
     model_path: Annotated[

@@ -274,7 +262,7 @@ def convert(
     result = model.decoder(
         token_ids,
         token_positions,
-
+        return_updated_state=True,
         return_activation_trace=True,
     )
     traces = flatten_parameters(result.export())

@@ -286,8 +274,7 @@ def convert(
     weights = flatten_parameters(model.export_weights())
     del model
 
-
-    save_file(packed_weights, output_dir / "model.safetensors")
+    save_file(weights, output_dir / "model.safetensors")
 
     config_json = config_converter.unstructure(metadata, ModelMetadata)
     with open(output_dir / "config.json", "w") as file:

@@ -511,7 +498,6 @@ def train(
     ) as progress:
         inference_task = progress.add_task("🔮 [cyan]Training speculator...[/cyan]", total=subsample_size)
 
-
         def progress_callback(event: SpeculatorTrainingEvent) -> None:
             progress.update(inference_task, completed=event.trained_tokens)
 
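The `convert` command no longer packs uint4 tensors before saving: the `_pack_uint4_weights` helper and its `jax_uint4_to_packed_uint8` import are removed, and the exported weights are written as-is. For reference, a self-contained numpy sketch of the kind of nibble packing the deleted helper performed; the low-nibble-first byte order here is an assumption, not taken from lalamo:

```python
import numpy as np

def pack_uint4_to_uint8(values: np.ndarray) -> np.ndarray:
    """Pack pairs of 4-bit values (0..15) into single bytes, low nibble first."""
    assert values.size % 2 == 0
    lo, hi = values[0::2], values[1::2]
    return ((hi << 4) | lo).astype(np.uint8)

print(pack_uint4_to_uint8(np.array([1, 2, 3, 4])))  # -> [33 67]
```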
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/common.py

@@ -1,4 +1,5 @@
 import importlib.metadata
+import json
 from collections import ChainMap
 from collections.abc import Callable
 from contextlib import ExitStack

@@ -14,6 +15,7 @@ from tokenizers import Tokenizer
 
 from lalamo.language_model import GenerationConfig, LanguageModel, LanguageModelConfig
 from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
+from lalamo.model_import.model_specs.common import JSONFieldSpec
 from lalamo.quantization import QuantizationMode
 
 from .huggingface_generation_config import HFGenerationConfig

@@ -130,10 +132,17 @@ def import_message_processor(
     )
     tokenizer_config = HFTokenizerConfig.from_json(tokenizer_config_file)
     if tokenizer_config.chat_template is None:
-
-
-
-
+        match model_spec.configs.chat_template:
+            case JSONFieldSpec(file_spec, field_name):
+                json_file = download_file(file_spec, model_spec.repo, output_dir)
+                with open(json_file) as file:
+                    json_dict = json.load(file)
+                prompt_template = json_dict[field_name]
+            case FileSpec(_) as file_spec:
+                chat_template_file = download_file(file_spec, model_spec.repo, output_dir)
+                prompt_template = chat_template_file.read_text()
+            case None:
+                raise ValueError("No chat template specified.")
     else:
         if model_spec.configs.chat_template is not None:
             raise ValueError("Conflicting chat template specifications.")
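The new chat-template resolution dispatches on the spec type with structural pattern matching: a `JSONFieldSpec` points at a field inside a downloaded JSON file, while a plain `FileSpec` is read verbatim. A self-contained sketch of the same dispatch pattern; the dataclasses below are simplified stand-ins for lalamo's real spec types:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FileSpec:
    filename: str

@dataclass(frozen=True)
class JSONFieldSpec:
    file_spec: FileSpec
    field_name: str

def resolve_template(spec: JSONFieldSpec | FileSpec | None) -> str:
    match spec:
        case JSONFieldSpec(file_spec, field_name):
            return f"field {field_name!r} of {file_spec.filename}"
        case FileSpec(filename):
            return f"full contents of {filename}"
        case None:
            raise ValueError("No chat template specified.")

print(resolve_template(JSONFieldSpec(FileSpec("tokenizer_config.json"), "chat_template")))
```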
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/common.py

@@ -180,15 +189,24 @@ def import_model(
     weights_paths = download_weights(model_spec, progress_callback=progress_callback)
     with ExitStack() as stack:
         weights_shards = []
+        metadata_shards = []
         for weights_path in weights_paths:
-            weights_shard = stack.enter_context(model_spec.weights_type.load(weights_path, precision))
+            weights_shard, metadata_shard = stack.enter_context(model_spec.weights_type.load(weights_path, precision))
             weights_shards.append(weights_shard)
+            metadata_shards.append(metadata_shard)
         weights_dict: ChainMap[str, Array] = ChainMap(*weights_shards)
+        metadata_dict: ChainMap[str, str] = ChainMap(*metadata_shards)
 
         if progress_callback is not None:
             progress_callback(InitializingModelEvent())
 
-        decoder = foreign_decoder_config.load_decoder(
+        decoder = foreign_decoder_config.load_decoder(
+            context_length,
+            precision,
+            accumulation_precision,
+            weights_dict,
+            metadata_dict,
+        )
 
         if progress_callback is not None:
             progress_callback(FinishedInitializingModelEvent())
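Weight loaders now return a `(weights, metadata)` pair per shard, and the merged `metadata_dict` travels all the way into `to_decoder_config`. Safetensors files carry exactly this kind of string-to-string header metadata; a sketch of where such a pair could come from, using the public `safetensors` API (the framework choice and path are illustrative, and this is not necessarily how lalamo's loaders are implemented):

```python
from safetensors import safe_open

def load_shard(path: str) -> tuple[dict, dict[str, str]]:
    """Load tensors plus the file-level string metadata from one shard."""
    with safe_open(path, framework="flax") as f:
        tensors = {name: f.get_tensor(name) for name in f.keys()}
        metadata = f.metadata() or {}  # e.g. quantization info written at export time
    return tensors, metadata

weights, metadata = load_shard("model.safetensors")
print(metadata)
```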
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/__init__.py

@@ -7,6 +7,7 @@ from .huggingface import (
     HFGemma3TextConfig,
     HFGPTOssConfig,
     HFLlamaConfig,
+    HFLlambaConfig,
     HFMistralConfig,
     HFQwen2Config,
     HFQwen3Config,

@@ -20,6 +21,7 @@ __all__ = [
     "HFGemma3Config",
     "HFGemma3TextConfig",
     "HFLlamaConfig",
+    "HFLlambaConfig",
     "HFMistralConfig",
     "HFQwen2Config",
     "HFQwen3Config",
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/common.py

@@ -19,11 +19,9 @@ class ForeignConfig(RegistryABC):
     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
     _converter.register_structure_hook(int | list[int], lambda v, _: v)
 
-    eos_token_id: int | list[int]
-
     @property
     def eos_token_ids(self) -> list[int]:
-
+        raise NotImplementedError
 
     @property
     @abstractmethod

@@ -41,6 +39,7 @@ class ForeignConfig(RegistryABC):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
     ) -> DecoderConfig:
         raise NotImplementedError
 

@@ -58,7 +57,8 @@ class ForeignConfig(RegistryABC):
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
         weights_dict: Mapping[str, Array],
+        metadata_dict: Mapping[str, str],
     ) -> Decoder:
-        config = self.to_decoder_config(context_length, activation_precision, accumulation_precision)
+        config = self.to_decoder_config(context_length, activation_precision, accumulation_precision, metadata_dict)
         model = config.empty()
         return self._load_weights(model, weights_dict)
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/executorch.py

@@ -51,6 +51,12 @@ class LoraConfig:
 
 @dataclass(frozen=True)
 class ExecutorchConfig(ForeignConfig):
+    eos_token_id: int | list[int]
+
+    @property
+    def eos_token_ids(self) -> list[int]:
+        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
+
     @property
     def default_precision(self) -> DTypeLike:
         return jnp.bfloat16

@@ -89,6 +95,7 @@ class ETLlamaConfig(ExecutorchConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         if self.lora_args is None:
             raise ValueError("We only support QLoRA models for now.")

@@ -136,6 +143,12 @@ class ETLlamaConfig(ExecutorchConfig):
             has_sinks=False,
             has_qkv_biases=False,
             has_out_biases=False,
+            num_heads=self.n_heads,
+            num_groups=self.n_kv_heads,
+            head_dim=self.dim // self.n_heads,
+            is_causal=True,
+            scale=None,
+            sliding_window_size=None,
         )
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,

@@ -146,9 +159,9 @@ class ETLlamaConfig(ExecutorchConfig):
             gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
-
-
-
+            pre_mixer_norm_config=rmsnorm_config,
+            mixer_config=attention_config,
+            post_mixer_norm_config=None,
             pre_mlp_norm_config=rmsnorm_config,
             mlp_config=mlp_config,
             post_mlp_norm_config=None,

@@ -157,16 +170,10 @@ class ETLlamaConfig(ExecutorchConfig):
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-
+            layer_configs=(decoder_layer_config,) * self.n_layers,
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.dim,
             hidden_dim=self._find_hidden_size(),
-            num_heads=self.n_heads,
-            num_groups=self.n_kv_heads,
-            head_dim=self.dim // self.n_heads,
-            attention_scale=None,
-            num_layers=self.n_layers,
-            sliding_window_sizes=None,
             context_length=context_length or MAX_SEQUENCE_LENGTH,
         )
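These hunks show the recurring 0.5.0 refactor: attention geometry (`num_heads`, `num_groups`, `head_dim`, scale, sliding window) moves from `DecoderConfig` into each `AttentionConfig`, and `DecoderLayerConfig` now wraps a generic token mixer (`pre_mixer_norm_config`/`mixer_config`), so attention and Mamba layers (see `modules/token_mixers/` in the file list) can share one layer type. A toy sketch of that shape, using stand-in types rather than lalamo's real classes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AttentionMixer:  # stand-in for AttentionConfig
    num_heads: int
    num_groups: int
    head_dim: int
    sliding_window_size: int | None

@dataclass(frozen=True)
class MambaMixer:  # stand-in for the new Mamba token-mixer config
    state_dim: int

@dataclass(frozen=True)
class LayerConfig:  # stand-in for DecoderLayerConfig
    mixer: AttentionMixer | MambaMixer

# Homogeneous attention stack, as in ETLlamaConfig above:
layers = (LayerConfig(AttentionMixer(32, 8, 128, None)),) * 4
# A hybrid stack is just a different tuple of per-layer configs:
hybrid = (LayerConfig(MambaMixer(16)), LayerConfig(AttentionMixer(32, 8, 128, None)))
print(len(layers), len(hybrid))
```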
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/__init__.py

@@ -3,6 +3,7 @@ from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
 from .gpt_oss import HFGPTOssConfig
 from .llama import HFLlamaConfig
+from .llamba import HFLlambaConfig
 from .mistral import HFMistralConfig
 from .qwen2 import HFQwen2Config
 from .qwen3 import HFQwen3Config

@@ -13,6 +14,7 @@ __all__ = [
     "HFGemma3Config",
     "HFGemma3TextConfig",
     "HFLlamaConfig",
+    "HFLlambaConfig",
     "HFMistralConfig",
     "HFQwen2Config",
     "HFQwen3Config",
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/common.py

@@ -1,7 +1,8 @@
 from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Literal
+from typing import ClassVar, Literal
 
+import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
 

@@ -56,11 +57,45 @@ class GPTQQuantizationConfig:
     sym: bool
 
 
+@dataclass(frozen=True)
+class MLXQuantizationConfig:
+    group_size: int
+    bits: int
+
+
+QuantizationConfigType = AWQQuantizationConfig | GPTQQuantizationConfig | MLXQuantizationConfig | None
+
+
+def _structure_quantization_config(v: object, _: object) -> QuantizationConfigType:
+    match v:
+        case None:
+            return None
+
+        case {"quant_method": "awq", **_other}:
+            return cattrs.structure(v, AWQQuantizationConfig)
+
+        case {"quant_method": "gptq", **_other}:
+            return cattrs.structure(v, GPTQQuantizationConfig)
+
+        case {**_other}:
+            return cattrs.structure(v, MLXQuantizationConfig)
+
+        case _:
+            raise RuntimeError(f"Cannot structure {v}field")
+
+
 @dataclass(frozen=True)
 class HuggingFaceConfig(ForeignConfig):
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+    _converter.register_structure_hook(int | list[int], lambda v, _: v)
+    _converter.register_structure_hook(QuantizationConfigType, _structure_quantization_config)
+
     @property
     def eos_token_ids(self) -> list[int]:
-
+        if not hasattr(self, "eos_token_id"):
+            raise RuntimeError("model doesn't havve eos_token_id, override eos_token_ids in model config")
+
+        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id  # type: ignore (This is a bug in pyright)
 
     @property
     def default_precision(self) -> DTypeLike:
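HuggingFace configs now parse the `quantization_config` block by dispatching on `quant_method`, treating AWQ and GPTQ explicitly and assuming any other mapping is MLX-style. A self-contained sketch of the same cattrs hook pattern, with field sets trimmed for brevity:

```python
from dataclasses import dataclass

import cattrs

@dataclass(frozen=True)
class AWQConfig:
    bits: int

@dataclass(frozen=True)
class MLXConfig:
    group_size: int
    bits: int

QuantConfig = AWQConfig | MLXConfig | None

def structure_quant(v, _type) -> QuantConfig:
    match v:
        case None:
            return None
        case {"quant_method": "awq"}:
            return cattrs.structure(v, AWQConfig)
        case dict():
            # No quant_method key: assume an MLX-style config block.
            return cattrs.structure(v, MLXConfig)
        case _:
            raise RuntimeError(f"Cannot structure {v!r}")

converter = cattrs.Converter()
converter.register_structure_hook(QuantConfig, structure_quant)

print(converter.structure({"group_size": 64, "bits": 4}, QuantConfig))
# -> MLXConfig(group_size=64, bits=4)
```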
{lalamo-0.4.1 → lalamo-0.5.0}/lalamo/model_import/decoder_configs/huggingface/gemma2.py

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 

@@ -57,10 +58,8 @@ class HFGemma2Config(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
-        sliding_window_sizes = tuple(
-            self.sliding_window if not bool(i % 2) else None for i in range(self.num_hidden_layers)
-        )
         embedding_input_scale = self.hidden_size**0.5
         attention_scale = self.query_pre_attn_scalar**-0.5
         embedding_config = TiedEmbeddingConfig(

@@ -83,16 +82,6 @@ class HFGemma2Config(HuggingFaceConfig):
         linear_config = FullPrecisionLinearConfig(
             precision=activation_precision,
         )
-        attention_config = AttentionConfig(
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=None,
-            key_norm_config=None,
-            logit_soft_cap=self.attn_logit_softcapping,
-            has_sinks=False,
-            has_qkv_biases=self.attention_bias,
-            has_out_biases=False,
-        )
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
             activation=GELU(),

@@ -101,28 +90,44 @@ class HFGemma2Config(HuggingFaceConfig):
             up_clipping=None,
             gate_clipping=None,
         )
-
-
-
-
-
-
-
-
+
+        layer_configs = []
+        for i in range(self.num_hidden_layers):
+            sliding_window_size = self.sliding_window if not bool(i % 2) else None
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=self.attn_logit_softcapping,
+                has_sinks=False,
+                has_qkv_biases=self.attention_bias,
+                has_out_biases=False,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                is_causal=True,
+                scale=attention_scale,
+                sliding_window_size=sliding_window_size,
+            )
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=rmsnorm_config,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=rmsnorm_config,
+            )
+            layer_configs.append(decoder_layer_config)
+
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=self.head_dim,
-            attention_scale=attention_scale,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=sliding_window_sizes,
             context_length=context_length or self.max_position_embeddings,
         )