lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- lalamo/__init__.py +1 -1
- lalamo/commands.py +247 -14
- lalamo/common.py +33 -0
- lalamo/data/__init__.py +3 -2
- lalamo/data/huggingface_message.py +4 -5
- lalamo/main.py +274 -9
- lalamo/message_processor.py +19 -1
- lalamo/model_import/common.py +17 -1
- lalamo/model_import/model_specs/mistral.py +5 -0
- lalamo/model_import/remote_registry.py +44 -0
- lalamo/models/__init__.py +3 -0
- lalamo/models/common.py +22 -0
- lalamo/models/compile_helpers.py +58 -0
- lalamo/models/language_model.py +342 -56
- lalamo/models/lm_helpers.py +198 -0
- lalamo/modules/decoder.py +4 -0
- lalamo/modules/token_mixers/mamba.py +345 -105
- lalamo/speculator/__init__.py +0 -2
- lalamo/speculator/inference.py +35 -61
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/METADATA +1 -1
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/RECORD +25 -23
- lalamo/speculator/estimator.py +0 -127
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/WHEEL +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/top_level.txt +0 -0
lalamo/main.py
CHANGED

@@ -10,6 +10,7 @@ from pathlib import Path
 from typing import Annotated
 
 import jax.profiler
+import requests
 import thefuzz.process
 from click import Context as ClickContext
 from click import Parameter as ClickParameter
@@ -35,24 +36,31 @@ from lalamo.commands import (
     CollectTracesCallbacks,
     ConversionCallbacks,
     EstimateBatchsizeCallbacks,
+    GenerateRepliesCallbacks,
     Precision,
+    PullCallbacks,
     TraceCallbacks,
     TrainCallbacks,
+    _suggest_similar_models,
 )
 from lalamo.commands import collect_traces as _collect_traces
 from lalamo.commands import convert as _convert
 from lalamo.commands import estimate_batchsize as _estimate_batchsize
+from lalamo.commands import generate_replies as _generate_replies
+from lalamo.commands import pull as _pull
 from lalamo.commands import trace as _trace
 from lalamo.commands import train as _train
+from lalamo.common import (
+    get_default_device_bytes,
+    get_usable_memory_from_bytes,
+)
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelSpec
 from lalamo.model_import.common import FileSpec
+from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile, fetch_available_models
 from lalamo.models import ClassifierModelConfig, LanguageModelConfig
-from lalamo.
-    get_default_device_bytes,
-    get_usable_memory_from_bytes,
-)
+from lalamo.models.common import BatchSizesComputedEvent
 from lalamo.speculator.ngram import NGramSpeculator
 from lalamo.speculator.utils import test_speculator
 
@@ -76,7 +84,7 @@ class ModelParser(ParamType):
     def convert(self, value: str, param: ClickParameter | None, ctx: ClickContext | None) -> ModelSpec:
         result = REPO_TO_MODEL.get(value)
         if result is None:
-            closest_repo = _closest_repo(value)
+            closest_repo = _closest_repo(value, list(REPO_TO_MODEL))
             error_message_parts = [
                 f'"{value}".',
             ]
@@ -92,10 +100,37 @@ class ModelParser(ParamType):
         return result
 
 
-
-
+class RemoteModelParser(ParamType):
+    name: str = "Pre-converted Model"
+
+    def convert(self, value: str, param: ClickParameter | None, ctx: ClickContext | None) -> "RegistryModel":
+        try:
+            available_models = fetch_available_models()
+        except (requests.RequestException, ValueError) as e:
+            error_message = f"Failed to fetch model list from SDK. Check your internet connection.\n\nError: {e}"
+            return self.fail(error_message, param, ctx)
+
+        repo_to_model = {m.repo_id: m for m in available_models}
+        model_spec = repo_to_model.get(value)
+        if model_spec is None:
+            closest_repo = _closest_repo(value, list(repo_to_model))
+            if closest_repo:
+                model_spec = repo_to_model[closest_repo]
+
+        if model_spec is None:
+            suggestions = _suggest_similar_models(value, available_models)
+            error_message = f'Model "{value}" not found.'
+            if suggestions:
+                error_message += "\n\nDid you mean one of these?\n" + "\n".join(f" - {s}" for s in suggestions)
+            return self.fail(error_message, param, ctx)
+
+        return model_spec
+
+
+def _closest_repo(query: str, repo_ids: list[str], min_score: float = 80) -> str | None:
+    if not repo_ids:
         return None
-    (closest_match, score), *_ = thefuzz.process.extract(query,
+    (closest_match, score), *_ = thefuzz.process.extract(query, repo_ids)
     if closest_match and score >= min_score:
         return closest_match
     return None
@@ -266,6 +301,49 @@ class CliConversionCallbacks(ConversionCallbacks):
         console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{self.output_dir}`[/cyan]!")
 
 
+@dataclass
+class CliPullCallbacks(PullCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    downloading_tasks: dict[RegistryModelFile, TaskID] = field(default_factory=dict)
+
+    def started(self) -> None:
+        console.print(f"📦 Pulling [cyan]{self.model_spec.name}[/cyan] by [cyan]{self.model_spec.vendor}[/cyan]")
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+
+    def output_dir_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output directory [cyan]{self.output_dir}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
+            raise Exit
+
+        shutil.rmtree(self.output_dir)
+
+    def downloading(self, file_spec: RegistryModelFile) -> None:
+        assert self.progress is not None
+
+        self.downloading_tasks[file_spec] = self.progress.add_task(f"⬇️ Downloading {file_spec.name}...")
+
+    def finished_downloading(self, file_spec: RegistryModelFile) -> None:
+        assert self.progress is not None
+
+        self.progress.remove_task(self.downloading_tasks[file_spec])
+
+    def finished(self) -> None:
+        assert self.progress is not None
+
+        self.stack.close()
+        console.print(f"🎉 Model successfully pulled to [cyan]{self.output_dir}[/cyan]!")
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
@@ -322,6 +400,46 @@ def convert(
     )
 
 
+@app.command(help="Pull a pre-converted model from the SDK repository.")
+def pull(
+    model_spec: Annotated[
+        RegistryModel,
+        Argument(
+            help=(
+                "Model repository ID from the pre-converted catalog. "
+                "Example: [cyan]'meta-llama/Llama-3.2-1B-Instruct'[/cyan]. "
+                "Fuzzy matching is supported for typos and partial names."
+            ),
+            click_type=RemoteModelParser(),
+            show_default=False,
+            metavar="MODEL_IDENTIFIER",
+        ),
+    ],
+    output_dir: Annotated[
+        Path | None,
+        Option(
+            help="Directory to save the pulled model to.",
+            show_default="Saves the pulled model in the `models/<model_name>` directory",
+        ),
+    ] = None,
+    overwrite: Annotated[
+        bool,
+        Option(
+            help="Overwrite existing model files without prompting.",
+        ),
+    ] = False,
+) -> None:
+    if output_dir is None:
+        output_dir = DEFAULT_OUTPUT_DIR / model_spec.name
+
+    _pull(
+        model_spec,
+        output_dir,
+        partial(CliPullCallbacks),
+        overwrite=overwrite,
+    )
+
+
 @dataclass
 class CliTraceCallbacks(TraceCallbacks):
     overwrite: bool = False
@@ -492,6 +610,151 @@ def list_models(
     console.print(table)
 
 
+@dataclass
+class CliGenerateRepliesCallbacks(GenerateRepliesCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    loading_task: TaskID | None = None
+    estimating_task: TaskID | None = None
+    generation_task: TaskID | None = None
+
+    def loading_model(self) -> None:
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                MofNCompleteColumn(),
+                TimeElapsedColumn(),
+                transient=True,
+            ),
+        )
+        self.loading_task = self.progress.add_task("🧠 [cyan]Loading model...[/cyan]", total=None)
+
+    def finished_loading_model(self) -> None:
+        assert self.progress is not None
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def loading_dataset(self) -> None:
+        assert self.progress is not None
+        self.loading_task = self.progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]", total=None)
+
+    def finished_loading_dataset(self) -> None:
+        assert self.progress is not None
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def estimating_batchsize(self, sequence_length: int, lo: int, hi: int | None) -> None:
+        assert self.progress is not None
+        hi_str = str(hi) if hi is not None else "?"
+        description = (
+            f"📐 [cyan]Computing batch size for the prompt length of {sequence_length}... ({lo}..{hi_str})[/cyan]"
+        )
+        if self.estimating_task is None:
+            self.estimating_task = self.progress.add_task(description)
+        else:
+            self.progress.update(self.estimating_task, description=description)
+
+    def batch_sizes_estimated(self) -> None:
+        assert self.progress is not None
+        if self.estimating_task is None:
+            self.estimating_task = self.progress.add_task(
+                "📐 [cyan]Estimating the best batch sizes...[/cyan]",
+                total=None,
+            )
+
+    def batch_sizes_computed(self, event: BatchSizesComputedEvent) -> None:
+        assert self.progress is not None
+        if self.estimating_task is not None:
+            self.progress.remove_task(self.estimating_task)
+            self.estimating_task = None
+        output_console = self.progress.console if self.progress is not None else console
+        for info in event.batch_sizes:
+            output_console.print(
+                f"Prefix length {info.prefix_length} has {info.num_elements} elements, "
+                f"with batchsize of {info.batch_size}",
+            )
+        self.generation_task = self.progress.add_task(
+            "🔮 [cyan]Generating replies...[/cyan]",
+            total=self.total_rows,
+        )
+
+    def generation_progress(self, rows_processed: int) -> None:
+        assert self.progress is not None
+        assert self.generation_task is not None
+        self.progress.update(self.generation_task, completed=rows_processed + 1)
+
+    def finished_generation(self) -> None:
+        assert self.progress is not None
+        assert self.generation_task is not None
+        self.progress.update(self.generation_task, description="✅ Completed")
+        self.stack.close()
+        console.print(f"💾 Replies saved to [cyan]{self.output_path}[/cyan]")
+
+
+@app.command(help="Generate replies for conversations in a parquet file.")
+def generate_replies(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    dataset_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the input parquet file with conversations.",
+            metavar="DATASET_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path,
+        Option(
+            help="Path to save the output parquet file.",
+        ),
+    ],
+    vram_gb: Annotated[
+        int | None,
+        Option(
+            help="Maximum VRAM in GB. Batch sizes are estimated automatically.",
+            show_default="max on default device",
+        ),
+    ] = None,
+    max_output_length: Annotated[
+        int,
+        Option(help="Maximum number of tokens to generate per reply."),
+    ] = 8192,
+    batch_size: Annotated[
+        int | None,
+        Option(help="Fixed batch size to use, skipping automatic estimation."),
+    ] = None,
+) -> None:
+    if batch_size is not None and vram_gb is not None:
+        err_console.print("Cannot use both --batch-size and --vram-gb")
+        raise Exit(1)
+
+    max_vram: int | None = None
+    if batch_size is None:
+        if vram_gb is not None:
+            mem_bytes = vram_gb * 1000 * 1000 * 1000
+        elif (mem_bytes := get_default_device_bytes()) is None:
+            err_console.print("Cannot get the default device's memory stats, use --vram-gb or --batch-size")
+            raise Exit(1)
+
+        max_vram = mem_bytes
+
+    _generate_replies(
+        model_path,
+        dataset_path,
+        output_path,
+        max_vram,
+        max_output_length,
+        batch_size,
+        CliGenerateRepliesCallbacks,
+    )
+
+
 speculator_app = Typer()
 app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")
 
@@ -727,10 +990,12 @@ def view_traces(
     table.add_column("Prefix")
     table.add_column("Completion")
 
+    from rich.text import Text
+
     for completion in islice(traces, num_completions):
         detokenized_prefix = model.message_processor.detokenize(completion.prefix_token_ids)
         detokenized_completion = model.message_processor.detokenize(completion.completion_token_ids)
-        table.add_row(detokenized_prefix, detokenized_completion)
+        table.add_row(Text(detokenized_prefix), Text(detokenized_completion))
 
     console.print(table)
 
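For context, the fuzzy lookup that both `ModelParser` and the new `RemoteModelParser` delegate to `_closest_repo` can be exercised in isolation. A minimal standalone sketch of the underlying `thefuzz` call, using illustrative repo ids (the query string and the score handling below are not part of the package):

    import thefuzz.process

    repo_ids = ["meta-llama/Llama-3.2-1B-Instruct", "mistralai/Codestral-22B-v0.1"]
    # thefuzz returns (candidate, score) pairs, best match first, with scores in 0..100
    (closest_match, score), *_ = thefuzz.process.extract("lama 3.2 1b instruct", repo_ids)
    if score >= 80:  # same threshold as _closest_repo's default min_score
        print(closest_match)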
lalamo/message_processor.py
CHANGED

@@ -39,6 +39,7 @@ class HuggingFaceMessage(TypedDict):
 class HuggingFaceRequest(TypedDict):
     add_generation_prompt: bool
     bos_token: str | None
+    eos_token: str | None
     messages: list[HuggingFaceMessage]
     enable_thinking: NotRequired[bool]
     tools: NotRequired[dict]
@@ -75,6 +76,7 @@ class MessageProcessorConfig:
     system_role_name: str
     user_role_name: str
     assistant_role_name: str
+    eos_token: str | None
     bos_token: str | None
 
     def init(self, tokenizer: Tokenizer) -> "MessageProcessor":
@@ -115,6 +117,10 @@ class MessageProcessor:
     def bos_token(self) -> str | None:
         return self.config.bos_token
 
+    @property
+    def eos_token(self) -> str | None:
+        return self.config.eos_token
+
     def message_to_dict(self, message: Message) -> HuggingFaceMessage:
         match message:
             case UserMessage(content=content):
@@ -137,7 +143,12 @@ class MessageProcessor:
         enable_thinking: bool | None = None,
     ) -> HuggingFaceRequest:
         converted_messages = [self.message_to_dict(message) for message in messages]
-        result = HuggingFaceRequest(
+        result = HuggingFaceRequest(
+            add_generation_prompt=True,
+            messages=converted_messages,
+            bos_token=self.bos_token,
+            eos_token=self.eos_token,
+        )
         if enable_thinking is not None:
             result["enable_thinking"] = enable_thinking
         if tools is not None:
@@ -163,9 +174,16 @@ class MessageProcessor:
         rendered = self.render_request(messages)
         return self.tokenize_text(rendered)
 
+    def tokenize_requests(self, dataset: Iterable[Iterable[Message]]) -> list[list[int]]:
+        return [self.tokenize_request(messages) for messages in dataset]
+
     def detokenize(self, tokens: list[int]) -> str:
         return self.tokenizer.decode(tokens, skip_special_tokens=False)
 
+    def parse_tokenized_response(self, tokens: list[int]) -> AssistantMessage:
+        detokenized = self.detokenize(tokens)
+        return self.parse_response(detokenized)
+
     def __post_init__(self) -> None:
         if self.output_parser_regex is not None:
             all_fields = AssistantMessage.__dataclass_fields__
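The two new helpers are thin conveniences for batch workflows. A hypothetical sketch of how they compose, assuming `processor` is an already-initialized `MessageProcessor` and `generate` is some user-supplied function turning prompt token ids into completion token ids (neither is defined in this diff):

    from lalamo.message_processor import UserMessage

    conversations = [[UserMessage(content="Hello!")], [UserMessage(content="Write a haiku.")]]
    prompt_token_ids = processor.tokenize_requests(conversations)   # one token-id list per conversation
    for completion_ids in generate(prompt_token_ids):               # hypothetical generation step
        reply = processor.parse_tokenized_response(completion_ids)  # detokenize, then parse into an AssistantMessage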
lalamo/model_import/common.py
CHANGED

@@ -138,6 +138,7 @@ def import_message_processor(
         progress_callback,
     )
     tokenizer_config = HFTokenizerConfig.from_json(tokenizer_config_file)
+
     if tokenizer_config.chat_template is None:
         match model_spec.configs.chat_template:
             case JSONFieldSpec(file_spec, field_name):
@@ -165,13 +166,28 @@ def import_message_processor(
     tokenizer.add_special_tokens(added_special_tokens)
     tokenizer.add_tokens(added_not_special_tokens)
 
+    bos_token = tokenizer_config.bos_token
+    eos_token = tokenizer_config.eos_token
+
+    # If we were not able to identify bos/eos - they are probably somewhere else, so we check config.json
+    if eos_token is None or bos_token is None:
+        foreign_decoder_config_file = download_config_file(model_spec, output_dir, progress_callback)
+        with open(foreign_decoder_config_file) as foreign_decoder_file:
+            foreign_decoder_json = json.load(foreign_decoder_file)
+
+        if bos_token is None:
+            bos_token = foreign_decoder_json.get("bos_token_id")
+        if eos_token is None:
+            eos_token = foreign_decoder_json.get("eos_token_id")
+
     message_processor_config = MessageProcessorConfig(
         prompt_template=prompt_template,
         output_parser_regex=model_spec.output_parser_regex,
         system_role_name=model_spec.system_role_name,
         user_role_name=model_spec.user_role_name,
         assistant_role_name=model_spec.assistant_role_name,
-        bos_token=
+        bos_token=bos_token,
+        eos_token=eos_token,
     )
     return MessageProcessor(config=message_processor_config, tokenizer=tokenizer)
 
lalamo/model_import/model_specs/mistral.py
CHANGED

@@ -10,6 +10,8 @@ from .common import (
 
 __all__ = ["MISTRAL_MODELS"]
 
+CODESTRAL_TOKENIZER_REPO = "mistralai/Codestral-22B-v0.1"
+
 CODESTRAL = [
     ModelSpec(
         vendor="Mistral",
@@ -21,6 +23,9 @@ CODESTRAL = [
         config_type=HFMistralConfig,
         weights_type=WeightsType.SAFETENSORS,
         use_cases=(UseCase.CODE,),
+        configs=ConfigMap(
+            tokenizer_config=FileSpec(repo=CODESTRAL_TOKENIZER_REPO, filename="tokenizer_config.json"),
+        ),
     ),
 ]
 

lalamo/model_import/remote_registry.py
ADDED

@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+from typing import Any, ClassVar
+
+import cattrs
+import requests
+
+
+@dataclass(frozen=True)
+class RegistryModelFile:
+    name: str
+    url: str
+    size: int
+    crc32c: str
+
+
+@dataclass(frozen=True)
+class RegistryModel:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+
+    id: str
+    vendor: str
+    name: str
+    family: str
+    size: str
+    repo_id: str
+    quantization: str | None
+    files: list[RegistryModelFile]
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "RegistryModel":
+        if "repoId" in data:
+            data = {**data, "repo_id": data.pop("repoId")}
+        return cls._converter.structure(data, cls)
+
+
+def fetch_available_models() -> list[RegistryModel]:
+    api_url = "https://sdk.trymirai.com/api/v1/models/list/lalamo"
+    response = requests.get(api_url, timeout=30)
+    response.raise_for_status()
+
+    data = response.json()
+    models_data = data.get("models", [])
+
+    return [RegistryModel.from_dict(model_data) for model_data in models_data]
lalamo/models/__init__.py
CHANGED

@@ -1,7 +1,10 @@
 from .classifier import ClassifierModel, ClassifierModelConfig
+from .common import BatchSizeInfo, BatchSizesComputedEvent
 from .language_model import GenerationConfig, LanguageModel, LanguageModelConfig
 
 __all__ = [
+    "BatchSizeInfo",
+    "BatchSizesComputedEvent",
     "ClassifierModel",
     "ClassifierModelConfig",
     "GenerationConfig",
lalamo/models/common.py
CHANGED

@@ -18,11 +18,33 @@ from lalamo.modules.decoder import DecoderConfig, DecoderResult
 from lalamo.safetensors import safe_read
 
 __all__ = [
+    "BatchSizeInfo",
+    "BatchSizesComputedEvent",
     "TextModel",
     "TextModelConfig",
 ]
 
 
+@dataclass(frozen=True)
+class InferenceConfig:
+    max_output_length: int = 8192
+    padded_length: int = 8192
+    num_top_logits_to_return: int | None = None
+    batch_size: int | None = None
+
+
+@dataclass(frozen=True)
+class BatchSizeInfo:
+    prefix_length: int
+    num_elements: int
+    batch_size: int
+
+
+@dataclass(frozen=True)
+class BatchSizesComputedEvent:
+    batch_sizes: tuple[BatchSizeInfo, ...]
+
+
 @dataclass(frozen=True)
 class TextModelConfig[ConfigT: ClassifierConfig | DecoderConfig](ABC):
     model_config: ConfigT
lalamo/models/compile_helpers.py
ADDED

@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import functools
+from typing import TYPE_CHECKING
+
+import jax
+import jax.numpy as jnp
+
+from .common import InferenceConfig
+
+if TYPE_CHECKING:
+    from jax._src.stages import Compiled
+
+    from .language_model import ForwardPassConfig, GenerationConfig, LanguageModel
+
+_compile_cache: dict[
+    tuple[int, GenerationConfig | None, InferenceConfig | None, ForwardPassConfig | None],
+    Compiled,
+] = {}
+
+
+def compile_generate_tokens(
+    model: LanguageModel,
+    generation_config: GenerationConfig | None = None,
+    inference_config: InferenceConfig = InferenceConfig(),  # noqa: B008
+    *,
+    forward_pass_config: ForwardPassConfig | None = None,
+) -> Compiled:
+    from .language_model import LanguageModel
+
+    key = (id(model), generation_config, inference_config, forward_pass_config)
+    if key not in _compile_cache:
+        generate_tokens_fn = functools.partial(
+            LanguageModel.generate_tokens,
+            generation_config=generation_config,
+            max_output_length=inference_config.max_output_length,
+            num_top_logits_to_return=inference_config.num_top_logits_to_return,
+            forward_pass_config=forward_pass_config,
+        )
+        _compile_cache[key] = (
+            jax.jit(generate_tokens_fn)
+            .lower(
+                model,
+                prompt_token_ids=jax.ShapeDtypeStruct(
+                    (inference_config.batch_size, inference_config.padded_length),
+                    jnp.int32,
+                ),
+                prompt_lengths_without_padding=jax.ShapeDtypeStruct((inference_config.batch_size,), jnp.int32),
+                keys=jax.ShapeDtypeStruct((inference_config.batch_size,), jax.random.key(0).dtype),
+            )
+            # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
+            # 0 - no autotune, gpu shouldn't be touched
+            # 1 - basic level, gpu should be touched veeery little
+            # 2,3 - gpu touched more and more
+            # 4 (default) - gpu might allocate more memory than the run would require!
+            .compile(compiler_options={"xla_gpu_autotune_level": "0"})
+        )
+    return _compile_cache[key]