lalamo 0.5.2.tar.gz → 0.5.3.tar.gz
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {lalamo-0.5.2 → lalamo-0.5.3}/PKG-INFO +1 -1
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/__init__.py +3 -2
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/data/__init__.py +0 -1
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/data/huggingface_message.py +1 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/main.py +167 -18
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/message_processor.py +2 -3
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/common.py +120 -27
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/__init__.py +4 -2
- lalamo-0.5.3/lalamo/model_import/decoder_configs/common.py +105 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/executorch.py +14 -9
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
- lalamo-0.5.3/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
- lalamo-0.5.3/lalamo/model_import/loaders/__init__.py +8 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/loaders/executorch.py +24 -12
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/loaders/huggingface.py +258 -30
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/__init__.py +4 -2
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/common.py +8 -2
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/gemma.py +5 -1
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo-0.5.3/lalamo/model_import/model_specs/mirai.py +20 -0
- lalamo-0.5.3/lalamo/models/__init__.py +10 -0
- lalamo-0.5.3/lalamo/models/common.py +81 -0
- {lalamo-0.5.2/lalamo → lalamo-0.5.3/lalamo/models}/language_model.py +32 -49
- lalamo-0.5.3/lalamo/models/router.py +59 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/__init__.py +33 -16
- lalamo-0.5.3/lalamo/modules/classifier.py +339 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/common.py +6 -3
- lalamo-0.5.3/lalamo/modules/decoder.py +208 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/mlp.py +28 -5
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/normalization.py +13 -8
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/token_mixers/attention.py +10 -6
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/token_mixers/state/kv_cache.py +14 -4
- lalamo-0.5.2/lalamo/modules/decoder.py → lalamo-0.5.3/lalamo/modules/transformer.py +75 -138
- lalamo-0.5.2/lalamo/modules/decoder_layer.py → lalamo-0.5.3/lalamo/modules/transformer_layer.py +62 -45
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/speculator/__init__.py +2 -0
- lalamo-0.5.3/lalamo/speculator/estimator.py +91 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/speculator/inference.py +28 -9
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/speculator/ngram.py +7 -3
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/speculator/utils.py +4 -2
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo.egg-info/PKG-INFO +1 -1
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo.egg-info/SOURCES.txt +12 -2
- {lalamo-0.5.2 → lalamo-0.5.3}/pyproject.toml +5 -0
- lalamo-0.5.3/tests/test_cartesia_mlx_models.py +22 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/tests/test_generation.py +6 -4
- lalamo-0.5.3/tests/test_huggingface_model_conversion.py +97 -0
- lalamo-0.5.3/tests/test_huggingface_models.py +41 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/tests/test_mlx_models.py +1 -1
- {lalamo-0.5.2 → lalamo-0.5.3}/tests/test_models.py +101 -65
- lalamo-0.5.2/lalamo/model_import/decoder_configs/common.py +0 -64
- lalamo-0.5.2/lalamo/model_import/loaders/__init__.py +0 -7
- lalamo-0.5.2/tests/test_huggingface_models.py +0 -24
- {lalamo-0.5.2 → lalamo-0.5.3}/LICENSE +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/README.md +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/common.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/data/utils.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/llamba.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/activations.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/embedding.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/linear.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/mlx_interop.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/rope.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/token_mixers/__init__.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/token_mixers/common.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/token_mixers/mamba.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/token_mixers/state/__init__.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/token_mixers/state/common.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/modules/utils.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/quantization.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/registry_abc.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/sampling.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/speculator/common.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo/utils.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/setup.cfg +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/tests/test_model_spec.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/tests/test_moe.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/tests/test_parameter_tree.py +0 -0
- {lalamo-0.5.2 → lalamo-0.5.3}/tests/test_registry_abc.py +0 -0
```diff
--- lalamo-0.5.2/lalamo/__init__.py
+++ lalamo-0.5.3/lalamo/__init__.py
@@ -1,4 +1,3 @@
-from lalamo.language_model import LanguageModel
 from lalamo.message_processor import (
     AssistantMessage,
     ContentBlock,
@@ -9,8 +8,9 @@ from lalamo.message_processor import (
     UserMessage,
 )
 from lalamo.model_import import ModelSpec, import_model
+from lalamo.models import LanguageModel, Router

-__version__ = "0.5.2"
+__version__ = "0.5.3"

 __all__ = [
     "AssistantMessage",
@@ -19,6 +19,7 @@ __all__ = [
     "LanguageModel",
     "Message",
     "ModelSpec",
+    "Router",
     "SystemMessage",
     "ToolSchema",
     "UserMessage",
```
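In 0.5.3 the top-level package re-exports `LanguageModel` from the new `lalamo.models` subpackage and adds `Router`. A minimal sketch of the updated import surface, using only the exports visible above:

```python
# Minimal sketch of the 0.5.3 top-level API, based on the exports above.
from lalamo import LanguageModel, Router, __version__

assert __version__ == "0.5.3"
# Both classes now live in (and are re-exported from) lalamo.models.
```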
```diff
--- lalamo-0.5.2/lalamo/main.py
+++ lalamo-0.5.3/lalamo/main.py
@@ -4,12 +4,11 @@ import re
 import shutil
 import sys
 from enum import Enum
-from itertools import chain
+from itertools import chain, islice
 from pathlib import Path
 from typing import Annotated

 import jax
-import jax.numpy as jnp
 import jax.profiler
 import thefuzz.process
 from click import Context as ClickContext
@@ -35,7 +34,6 @@ from typer import Argument, Context, Exit, Option, Typer
 from lalamo.common import flatten_parameters
 from lalamo.data import import_hf_parquet
 from lalamo.data.lalamo_completions import LalamoCompletion
-from lalamo.language_model import LanguageModel
 from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
 from lalamo.model_import.common import (
@@ -45,10 +43,16 @@ from lalamo.model_import.common import (
     InitializingModelEvent,
     StatusEvent,
 )
+from lalamo.models import LanguageModelConfig, RouterConfig
 from lalamo.modules import config_converter
+from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
 from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
 from lalamo.speculator.ngram import NGramSpeculator
-from lalamo.speculator.utils import
+from lalamo.speculator.utils import (
+    SpeculatorTrainingEvent,
+    test_speculator,
+    train_speculator,
+)

 SCRIPT_NAME = Path(sys.argv[0]).name

@@ -123,7 +127,7 @@ def chat(
         transient=True,
     ) as progress:
         loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
-        model =
+        model = LanguageModelConfig.load_model(model_path)
         progress.remove_task(loading_task)
         warmup_task = progress.add_task("🔥 Warming up compilation cache...")
         list(model.stream_reply_text([UserMessage("")], max_output_length=1))
@@ -145,6 +149,39 @@
         messages.append(model.message_processor.parse_response(model_response_text))


+@app.command(help="Classify given message with a Router type of model.")
+def classify(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+) -> None:
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        transient=True,
+    ) as progress:
+        loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
+        model = RouterConfig.load_model(model_path)
+        progress.remove_task(loading_task)
+        warmup_task = progress.add_task("🔥 Warming up...")
+        model.classify_chat([UserMessage(content="warmup message")])
+        progress.remove_task(warmup_task)
+    console.print(f"🤖 Classifying input with [blue]{model_path}[/blue]:")
+    while True:
+        user_text = console.input("[cyan]user> [/cyan]")
+        user_message = UserMessage(user_text)
+
+        console.print("[red]assistant> [/red]", end="")
+        result = model.classify_chat([user_message])
+        for label, confidence in result.items():
+            console.print(f"{label} : {confidence}", end="")
+        console.print()
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
@@ -194,6 +231,12 @@ def convert(
             help="Overwrite existing model files.",
         ),
     ] = False,
+    message_for_trace: Annotated[
+        str | None,
+        Option(
+            help="Text message to use as prompt when recording trace",
+        ),
+    ] = None,
 ) -> None:
     if precision is not None:
         precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
@@ -224,6 +267,8 @@
         console.print("Exiting...")
         raise Exit

+    message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
+
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
@@ -254,17 +299,7 @@

         if include_traces:
             trace_task = progress.add_task("🚁 Generating traces...")
-
-            num_tokens = 512
-            token_stride = 8
-            token_ids = jnp.arange(0, num_tokens, dtype=jnp.int32)[None, :]
-            token_positions = jnp.arange(0, num_tokens * token_stride, token_stride, dtype=jnp.int32)[None, :]
-            result = model.decoder(
-                token_ids,
-                token_positions,
-                return_updated_state=True,
-                return_activation_trace=True,
-            )
+            result = model.record_trace(message)
             traces = flatten_parameters(result.export())
             save_file(traces, output_dir / "traces.safetensors")
             progress.remove_task(trace_task)
@@ -350,6 +385,77 @@ speculator_app = Typer()
 app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")


+@speculator_app.command(help="Estimate maximum batch size at which a model can be run.")
+def estimate_batchsize(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    max_input_length: Annotated[
+        int,
+        Option(help="Max input length of a model."),
+    ] = 1024,
+    max_output_length: Annotated[
+        int,
+        Option(help="Max output length of a model."),
+    ] = 1024,
+    num_logits_per_token: Annotated[
+        int,
+        Option(help="Number of top logits that will be recorded."),
+    ] = 8,
+    vram_gb: Annotated[
+        int | None,
+        Option(
+            help="Maximum vram size in gb allowed.",
+            show_default="max on default device",
+        ),
+    ] = None,
+) -> None:
+    if vram_gb is not None:
+        mem = vram_gb * 1024 * 1024 * 1024
+    else:
+        memory_stats = jax.local_devices()[0].memory_stats()
+        if memory_stats is None:
+            err_console.print("Cannot get the default device's memory stats, use --vram-gb")
+            raise Exit(1)
+        if "bytes_limit" not in memory_stats:
+            err_console.print("Cannot get the default device's bytes limit, use --vram-gb")
+            raise Exit(1)
+        mem = memory_stats["bytes_limit"]
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        transient=True,
+    ) as progress:
+        loading_model_task = progress.add_task("[cyan]Loading model...[/cyan]")
+        model = LanguageModelConfig.load_model(model_path)
+        progress.remove_task(loading_model_task)
+
+        estimating_batchsize_task = progress.add_task("[cyan]Estimating batch size...[/cyan]")
+
+        def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
+            lo = str(event.lo)
+            hi = str(event.hi) if event.hi is not None else "?"
+            description = f"[cyan]Estimating batch size... ({lo}..{hi})[/cyan]"
+            progress.update(estimating_batchsize_task, description=description)
+
+        bs = estimate_batchsize_from_memory(
+            model,
+            max_input_length,
+            max_output_length,
+            num_logits_per_token,
+            mem,
+            progress_callback,
+        )
+        progress.remove_task(estimating_batchsize_task)
+
+    console.print(f"Found maximum batch size: [cyan]{bs}[/cyan]")
+
+
 @speculator_app.command(help="Run model inference and collect traces for speculator training")
 def collect_traces(
     model_path: Annotated[
@@ -406,7 +512,7 @@ def collect_traces(
     ) as progress:
         live.update(progress, refresh=True)
         loading_model_task = progress.add_task("🧠 [cyan]Loading model...[/cyan]")
-        model =
+        model = LanguageModelConfig.load_model(model_path)
         progress.remove_task(loading_model_task)

         loading_dataset_task = progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
@@ -448,6 +554,49 @@ def collect_traces(
         progress.update(inference_task, description="✅ Completed")


+@speculator_app.command(help="View model inference traces")
+def view_traces(
+    trace_path: Annotated[
+        Path,
+        Argument(
+            help="File of inference traces to view.",
+            metavar="TRACE_PATH",
+        ),
+    ],
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory for detokenization.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    num_completions: Annotated[
+        int | None,
+        Option(
+            help="Number of completions to show.",
+        ),
+    ] = None,
+) -> None:
+    model = LanguageModelConfig.load_model(model_path)
+
+    with open(trace_path, "rb") as trace_fd:
+        traces = LalamoCompletion.deserialize_many(trace_fd)
+
+    table = Table(
+        show_lines=True,
+        box=box.ROUNDED,
+    )
+    table.add_column("Prefix")
+    table.add_column("Completion")
+
+    for completion in islice(traces, num_completions):
+        detokenized_prefix = model.message_processor.detokenize(completion.prefix_token_ids)
+        detokenized_completion = model.message_processor.detokenize(completion.completion_token_ids)
+        table.add_row(detokenized_prefix, detokenized_completion)
+
+    console.print(table)
+
+
 @speculator_app.command(help="Train a speculator from inference traces")
 def train(
     trace_path: Annotated[
@@ -535,7 +684,7 @@ def test(
         Option(help="Number of sequences to generate"),
     ] = 8,
 ) -> None:
-    model =
+    model = LanguageModelConfig.load_model(model_path)

     with open(speculator_path, "rb") as fd:
         speculator = NGramSpeculator.deserialize(fd.read())
```
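The new `classify` command is a thin interactive wrapper over the `Router` API. A hedged sketch of the equivalent library calls, using only names that appear in this diff (`RouterConfig.load_model`, `classify_chat`, `UserMessage`); the model directory path is hypothetical:

```python
# Hedged sketch: programmatic equivalent of the new `classify` command.
# "models/my-router" is a hypothetical path, not part of this release.
from pathlib import Path

from lalamo.message_processor import UserMessage
from lalamo.models import RouterConfig

model = RouterConfig.load_model(Path("models/my-router"))
scores = model.classify_chat([UserMessage("Where is my order?")])
for label, confidence in scores.items():  # label-to-confidence mapping
    print(f"{label}: {confidence}")
```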
```diff
--- lalamo-0.5.2/lalamo/message_processor.py
+++ lalamo-0.5.3/lalamo/message_processor.py
@@ -156,13 +156,12 @@ class MessageProcessor:
             raise ValueError(f"Invalid response format: {response}")
         return AssistantMessage(**match.groupdict())

-    def
+    def tokenize_text(self, text: str) -> list[int]:
         return self.tokenizer.encode(text, add_special_tokens=False).ids

     def tokenize_request(self, messages: Iterable[Message]) -> list[int]:
         rendered = self.render_request(messages)
-
-        return tokenized
+        return self.tokenize_text(rendered)

     def detokenize(self, tokens: list[int]) -> str:
         return self.tokenizer.decode(tokens, skip_special_tokens=False)
```
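The `MessageProcessor` change splits plain-text tokenization into a public `tokenize_text` helper, so `tokenize_request` is now just `render_request` followed by `tokenize_text`. A hedged sketch of that relationship, assuming `processor` is an already-constructed `MessageProcessor` (building one requires a tokenizer and config not shown here):

```python
# Hedged sketch; `processor` is assumed to be a constructed MessageProcessor.
from lalamo.message_processor import UserMessage

messages = [UserMessage("hi")]
request_ids = processor.tokenize_request(messages)
# tokenize_request == render_request piped through the new tokenize_text:
assert request_ids == processor.tokenize_text(processor.render_request(messages))
```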
```diff
--- lalamo-0.5.2/lalamo/model_import/common.py
+++ lalamo-0.5.3/lalamo/model_import/common.py
@@ -13,14 +13,16 @@ from jax import Array
 from jaxtyping import DTypeLike
 from tokenizers import Tokenizer

-from lalamo.language_model import GenerationConfig, LanguageModel, LanguageModelConfig
 from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
-from lalamo.
+from lalamo.models import GenerationConfig, LanguageModel, LanguageModelConfig, Router, RouterConfig
+from lalamo.modules import Classifier, Decoder, LalamoModule
 from lalamo.quantization import QuantizationMode

+from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
 from .huggingface_generation_config import HFGenerationConfig
 from .huggingface_tokenizer_config import HFTokenizerConfig
-from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, UseCase
+from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, ModelType, UseCase
+from .model_specs.common import JSONFieldSpec

 __all__ = [
     "REPO_TO_MODEL",
@@ -29,6 +31,7 @@ __all__ = [
     "InitializingModelEvent",
     "ModelMetadata",
     "ModelSpec",
+    "ModelType",
     "StatusEvent",
     "import_model",
 ]
@@ -68,7 +71,8 @@ class ModelMetadata:
     quantization: QuantizationMode | None
     repo: str
     use_cases: tuple[UseCase, ...]
-
+    model_type: ModelType
+    model_config: LanguageModelConfig | RouterConfig


 def download_file(
@@ -114,7 +118,7 @@ def download_config_file(


 class ImportResults(NamedTuple):
-    model: LanguageModel
+    model: LanguageModel | Router
     metadata: ModelMetadata


@@ -166,26 +170,14 @@ def import_message_processor(
     return MessageProcessor(config=message_processor_config, tokenizer=tokenizer)


-def
-    model_spec: ModelSpec
-
+def _load_main_processing_module(
+    model_spec: ModelSpec,
+    precision: DTypeLike,
+    foreign_config: ForeignConfig,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
     context_length: int | None = None,
-    precision: DTypeLike | None = None,
     accumulation_precision: DTypeLike = jnp.float32,
-
-) -> ImportResults:
-    if isinstance(model_spec, str):
-        try:
-            model_spec = REPO_TO_MODEL[model_spec]
-        except KeyError as e:
-            raise ValueError(f"Unknown model: {model_spec}") from e
-
-    foreign_decoder_config_file = download_config_file(model_spec)
-    foreign_decoder_config = model_spec.config_type.from_json(foreign_decoder_config_file)
-
-    if precision is None:
-        precision = foreign_decoder_config.default_precision
-
+) -> LalamoModule:
     weights_paths = download_weights(model_spec, progress_callback=progress_callback)
     with ExitStack() as stack:
         weights_shards = []
@@ -200,7 +192,7 @@ def import_model(
         if progress_callback is not None:
             progress_callback(InitializingModelEvent())

-
+        processing_module = foreign_config.load(
             context_length,
             precision,
             accumulation_precision,
@@ -208,6 +200,33 @@
             metadata_dict,
         )

+    return processing_module
+
+
+def _import_language_model(
+    model_spec: ModelSpec,
+    *,
+    context_length: int | None = None,
+    precision: DTypeLike | None = None,
+    accumulation_precision: DTypeLike = jnp.float32,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> tuple[LanguageModel, LanguageModelConfig]:
+    foreign_decoder_config_file = download_config_file(model_spec)
+    foreign_decoder_config = model_spec.config_type.from_json(foreign_decoder_config_file)
+    assert isinstance(foreign_decoder_config, ForeignLMConfig)
+
+    if precision is None:
+        precision = foreign_decoder_config.default_precision
+    decoder = _load_main_processing_module(
+        model_spec,
+        precision,
+        foreign_decoder_config,
+        progress_callback,
+        context_length,
+        accumulation_precision,
+    )
+    assert isinstance(decoder, Decoder)
+
     if progress_callback is not None:
         progress_callback(FinishedInitializingModelEvent())

@@ -235,12 +254,85 @@
     )

     language_model_config = LanguageModelConfig(
-
+        model_config=decoder.config,
         message_processor_config=message_processor.config,
         generation_config=generation_config,
     )

     language_model = LanguageModel(language_model_config, decoder, message_processor)
+    return language_model, language_model_config
+
+
+def _import_router(
+    model_spec: ModelSpec,
+    *,
+    context_length: int | None = None,
+    precision: DTypeLike | None = None,
+    accumulation_precision: DTypeLike = jnp.float32,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> tuple[Router, RouterConfig]:
+    foreign_classifier_config_file = download_config_file(model_spec)
+    foreign_classifier_config = model_spec.config_type.from_json(foreign_classifier_config_file)
+    assert isinstance(foreign_classifier_config, ForeignClassifierConfig)
+
+    if precision is None:
+        precision = foreign_classifier_config.default_precision
+
+    classifier = _load_main_processing_module(
+        model_spec,
+        precision,
+        foreign_classifier_config,
+        progress_callback,
+        context_length,
+        accumulation_precision,
+    )
+    assert isinstance(classifier, Classifier)
+
+    if progress_callback is not None:
+        progress_callback(FinishedInitializingModelEvent())
+
+    message_processor = import_message_processor(model_spec)
+
+    router_config = RouterConfig(
+        model_config=classifier.config,
+        message_processor_config=message_processor.config,
+    )
+    router_model = Router(router_config, classifier, message_processor)
+    return router_model, router_config
+
+
+def import_model(
+    model_spec: ModelSpec | str,
+    *,
+    context_length: int | None = None,
+    precision: DTypeLike | None = None,
+    accumulation_precision: DTypeLike = jnp.float32,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> ImportResults:
+    if isinstance(model_spec, str):
+        try:
+            model_spec = REPO_TO_MODEL[model_spec]
+        except KeyError as e:
+            raise ValueError(f"Unknown model: {model_spec}") from e
+
+    match model_spec.model_type:
+        case ModelType.LANGUAGE_MODEL:
+            model, config = _import_language_model(
+                model_spec,
+                context_length=context_length,
+                precision=precision,
+                accumulation_precision=accumulation_precision,
+                progress_callback=progress_callback,
+            )
+        case ModelType.ROUTER_MODEL:
+            model, config = _import_router(
+                model_spec,
+                context_length=context_length,
+                precision=precision,
+                accumulation_precision=accumulation_precision,
+                progress_callback=progress_callback,
+            )
+
     metadata = ModelMetadata(
         toolchain_version=LALAMO_VERSION,
         vendor=model_spec.vendor,
@@ -250,6 +342,7 @@ def import_model(
         quantization=model_spec.quantization,
         repo=model_spec.repo,
         use_cases=model_spec.use_cases,
-
+        model_type=model_spec.model_type,
+        model_config=config,
     )
-    return ImportResults(
+    return ImportResults(model, metadata)
```
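`import_model` now dispatches on `ModelSpec.model_type`, so `ImportResults.model` may be either a `LanguageModel` or a `Router`, with the resolved config recorded in the metadata. A hedged sketch of a call site (the repo id is hypothetical):

```python
# Hedged sketch of the reworked entry point; the repo id is hypothetical.
from lalamo.model_import import import_model

model, metadata = import_model("vendor/some-model")
# model is a LanguageModel or a Router depending on the spec's model_type:
print(type(model).__name__, metadata.model_type)
```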
```diff
--- lalamo-0.5.2/lalamo/model_import/decoder_configs/__init__.py
+++ lalamo-0.5.3/lalamo/model_import/decoder_configs/__init__.py
@@ -1,4 +1,4 @@
-from .common import ForeignConfig
+from .common import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig

 # from .executorch import ETLlamaConfig
 from .huggingface import (
@@ -14,8 +14,10 @@ from .huggingface import (
 )

 __all__ = [
-
+    "ForeignClassifierConfig",
     "ForeignConfig",
+    # "ETLlamaConfig",
+    "ForeignLMConfig",
     "HFGPTOssConfig",
     "HFGemma2Config",
     "HFGemma3Config",
```
```diff
--- /dev/null
+++ lalamo-0.5.3/lalamo/model_import/decoder_configs/common.py
@@ -0,0 +1,105 @@
+import json
+from abc import abstractmethod
+from collections.abc import Mapping
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, Self
+
+import cattrs
+from jaxtyping import Array, DTypeLike
+
+from lalamo.modules import ClassifierConfig, DecoderConfig
+from lalamo.modules.common import LalamoModule
+from lalamo.registry_abc import RegistryABC
+
+__all__ = ["ForeignClassifierConfig", "ForeignLMConfig"]
+
+
+@dataclass(frozen=True)
+class ForeignConfig[ConfigT: DecoderConfig | ClassifierConfig](RegistryABC):
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+    _converter.register_structure_hook(int | list[int], lambda v, _: v)
+
+    @property
+    @abstractmethod
+    def default_precision(self) -> DTypeLike: ...
+
+    @classmethod
+    def from_json(cls, json_path: Path | str) -> Self:
+        json_path = Path(json_path)
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls._converter.structure(config, cls)
+
+    @abstractmethod
+    def _load_weights(
+        self,
+        model: LalamoModule,
+        weights_dict: Mapping[str, Array],
+    ) -> LalamoModule: ...
+
+    @abstractmethod
+    def to_lalamo_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
+    ) -> ConfigT: ...
+
+    def load(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        weights_dict: Mapping[str, Array],
+        metadata_dict: Mapping[str, str],
+    ) -> LalamoModule[ConfigT]:
+        config = self.to_lalamo_config(context_length, activation_precision, accumulation_precision, metadata_dict)
+        model = config.empty()
+        return self._load_weights(model, weights_dict)
+
+
+@dataclass(frozen=True)
+class ForeignLMConfig(ForeignConfig, RegistryABC):
+    @abstractmethod
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
+    ) -> DecoderConfig: ...
+
+    @property
+    @abstractmethod
+    def eos_token_ids(self) -> list[int]: ...
+
+    def to_lalamo_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
+    ) -> DecoderConfig:
+        return self.to_decoder_config(context_length, activation_precision, accumulation_precision, metadata_dict)
+
+
+@dataclass(frozen=True)
+class ForeignClassifierConfig(ForeignConfig, RegistryABC):
+    @abstractmethod
+    def to_classifier_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> ClassifierConfig: ...
+
+    def to_lalamo_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
+    ) -> ClassifierConfig:
+        return self.to_classifier_config(context_length, activation_precision, accumulation_precision)
```
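This new module is the extension point for foreign checkpoint formats: a concrete subclass maps a foreign JSON config onto a lalamo config (`to_decoder_config` / `to_classifier_config`) and loads foreign weights into the empty module (`_load_weights`). A hedged skeleton of a hypothetical classifier-side subclass; the class name, fields, and bodies are illustrative only, not part of this release:

```python
# Hedged skeleton of a hypothetical ForeignClassifierConfig subclass.
from collections.abc import Mapping
from dataclasses import dataclass

import jax.numpy as jnp
from jaxtyping import Array, DTypeLike

from lalamo.model_import.decoder_configs import ForeignClassifierConfig
from lalamo.modules import ClassifierConfig
from lalamo.modules.common import LalamoModule


@dataclass(frozen=True)
class MyClassifierConfig(ForeignClassifierConfig):  # hypothetical
    hidden_size: int
    num_labels: int

    @property
    def default_precision(self) -> DTypeLike:
        return jnp.float32

    def to_classifier_config(
        self,
        context_length: int | None,
        activation_precision: DTypeLike,
        accumulation_precision: DTypeLike,
    ) -> ClassifierConfig:
        # Would translate the foreign fields into a lalamo ClassifierConfig.
        raise NotImplementedError

    def _load_weights(
        self,
        model: LalamoModule,
        weights_dict: Mapping[str, Array],
    ) -> LalamoModule:
        # Would rename/copy the foreign weights into the freshly built module.
        raise NotImplementedError
```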