lalamo 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +20 -5
- lalamo/data/__init__.py +8 -0
- lalamo/data/huggingface_message.py +38 -0
- lalamo/data/lalamo_completions.py +43 -0
- lalamo/data/utils.py +8 -0
- lalamo/language_model.py +152 -69
- lalamo/main.py +271 -43
- lalamo/message_processor.py +11 -1
- lalamo/model_import/common.py +17 -7
- lalamo/model_import/decoder_configs/__init__.py +3 -0
- lalamo/model_import/decoder_configs/executorch.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- lalamo/model_import/huggingface_tokenizer_config.py +1 -4
- lalamo/model_import/loaders/executorch.py +10 -9
- lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo/model_import/loaders/utils.py +92 -0
- lalamo/model_import/model_specs/__init__.py +4 -1
- lalamo/model_import/model_specs/common.py +15 -12
- lalamo/model_import/model_specs/gpt_oss.py +21 -0
- lalamo/modules/__init__.py +35 -7
- lalamo/modules/activations.py +24 -14
- lalamo/modules/attention.py +73 -20
- lalamo/modules/common.py +8 -57
- lalamo/modules/decoder.py +48 -34
- lalamo/modules/decoder_layer.py +57 -43
- lalamo/modules/embedding.py +13 -19
- lalamo/modules/kv_cache.py +53 -16
- lalamo/modules/linear.py +260 -79
- lalamo/modules/mlp.py +395 -23
- lalamo/modules/normalization.py +2 -3
- lalamo/modules/rope.py +32 -21
- lalamo/modules/utils.py +10 -0
- lalamo/speculator/__init__.py +11 -0
- lalamo/speculator/common.py +22 -0
- lalamo/speculator/inference.py +75 -0
- lalamo/speculator/ngram.py +154 -0
- lalamo/speculator/utils.py +52 -0
- lalamo/utils.py +27 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
- lalamo-0.4.0.dist-info/RECORD +71 -0
- lalamo-0.3.3.dist-info/RECORD +0 -59
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/main.py
CHANGED
@@ -1,12 +1,16 @@
 import json
+import random
 import re
 import shutil
 import sys
 from enum import Enum
+from itertools import chain
 from pathlib import Path
 from typing import Annotated

+import jax
 import jax.numpy as jnp
+import jax.profiler
 import thefuzz.process
 from click import Context as ClickContext
 from click import Parameter as ClickParameter
@@ -14,13 +18,24 @@ from click import ParamType
 from jaxtyping import DTypeLike
 from rich import box
 from rich.console import Console
+from rich.live import Live
 from rich.panel import Panel
-from rich.progress import
+from rich.progress import (
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    TimeElapsedColumn,
+    TimeRemainingColumn,
+    track,
+)
 from rich.table import Table
 from safetensors.flax import save_file
-from typer import Argument, Exit, Option, Typer
+from typer import Argument, Context, Exit, Option, Typer

 from lalamo.common import flatten_parameters
+from lalamo.data import import_hf_parquet
+from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.language_model import LanguageModel
 from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
@@ -31,7 +46,10 @@ from lalamo.model_import.common import (
     InitializingModelEvent,
     StatusEvent,
 )
-from lalamo.modules import
+from lalamo.modules import config_converter
+from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+from lalamo.speculator.ngram import NGramSpeculator
+from lalamo.speculator.utils import SpeculatorTrainingEvent, test_speculator, train_speculator
 from lalamo.utils import jax_uint4_to_packed_uint8

 SCRIPT_NAME = Path(sys.argv[0]).name
@@ -110,27 +128,19 @@ def chat(
             metavar="MODEL_PATH",
         ),
     ],
-    weight_layout: Annotated[
-        WeightLayout | None,
-        Option(
-            help=(
-                "(EXPERIMENTAL) Order of dimensions in the weights of linear layers."
-                "\n\n\n\n"
-                "If set to AUTO, the layout will depend on the model."
-            ),
-            show_default="auto",
-        ),
-    ] = None,
 ) -> None:
-    if weight_layout is None:
-        weight_layout = WeightLayout.AUTO
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
         transient=True,
     ) as progress:
-        progress.add_task("🚀 [cyan]Loading model...[/cyan]")
-        model = LanguageModel.load(model_path
+        loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
+        model = LanguageModel.load(model_path)
+        progress.remove_task(loading_task)
+        warmup_task = progress.add_task("🔥 Warming up compilation cache...")
+        list(model.stream_reply_text([UserMessage("")], max_output_length=1))
+        progress.remove_task(warmup_task)
+    console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
     messages = []
     while True:
         user_text = console.input("[cyan]user> [/cyan]")
@@ -170,17 +180,6 @@ def convert(
             show_default="Native precision of the model",
         ),
     ] = None,
-    weight_layout: Annotated[
-        WeightLayout | None,
-        Option(
-            help=(
-                "(EXPERIMENTAL) Order of dimensions in the weights of linear layers."
-                "\n\n\n\n"
-                "If set to AUTO, the layout will depend on the model."
-            ),
-            show_default="auto",
-        ),
-    ] = None,
     output_dir: Annotated[
         Path | None,
         Option(
@@ -213,18 +212,10 @@ def convert(
     else:
         precision_dtype = None

-    if weight_layout is not None:
-        weight_layout = WeightLayout(weight_layout)
-    else:
-        weight_layout = WeightLayout.AUTO
-
     if output_dir is None:
         output_dir = DEFAULT_OUTPUT_DIR / model_repo.name

-
-    conversion_strs = [
-        f"⚙️ Using weight layout [cyan]{weight_layout}[/cyan]",
-    ]
+    conversion_strs = [f"🚀 Converting [cyan]{model_repo.name}[/cyan] by [cyan]{model_repo.vendor}[/cyan]"]
     if precision is not None:
         conversion_strs.append(
             f" and converting floating-point weights into [cyan]{precision.name.lower()}[/cyan] precision",
@@ -292,7 +283,7 @@ def convert(
         progress.remove_task(main_task)

         model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
-        weights = flatten_parameters(model.export_weights(
+        weights = flatten_parameters(model.export_weights())
         del model

         packed_weights = _pack_uint4_weights(weights)
@@ -312,10 +303,10 @@ def _model_size_string_to_int(
 ) -> float:
     match = _regex.match(size_str)
     factors = {
-        "K":
-        "M":
-        "B":
-        "T":
+        "K": 1000**1,
+        "M": 1000**2,
+        "B": 1000**3,
+        "T": 1000**4,
     }
     if match:
         return float(match.group("number")) * factors[match.group("suffix")]
@@ -368,5 +359,242 @@ def list_models(
     console.print(table)


+speculator_app = Typer()
+app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")
+
+
+@speculator_app.command(help="Run model inference and collect traces for speculator training")
+def collect_traces(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    dataset_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the dataset with prompts",
+            metavar="DATASET_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path,
+        Option(
+            help="File to save the trace to",
+            metavar="OUTPUT_PATH",
+        ),
+    ],
+    num_logits_per_token: Annotated[
+        int,
+        Option(help="Record logits for this number of most probable tokens"),
+    ] = 8,
+    max_input_length: Annotated[
+        int,
+        Option(help="Filter prompts that have more than this number of tokens in context"),
+    ] = 1024,
+    max_output_length: Annotated[
+        int,
+        Option(help="Maximum number of tokens to generate in one completion"),
+    ] = 1024,
+    batch_size: Annotated[
+        int,
+        Option(help="Number of sequences in one batch"),
+    ] = 1,
+    num_tokens_to_generate: Annotated[
+        int | None,
+        Option(
+            help="Exit early after generating this number of output tokens",
+            show_default="all",
+        ),
+    ] = None,
+) -> None:
+    with Live(refresh_per_second=10) as live:
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            transient=True,
+            disable=True,
+        ) as progress:
+            live.update(progress, refresh=True)
+            loading_model_task = progress.add_task("🧠 [cyan]Loading model...[/cyan]")
+            model = LanguageModel.load(model_path)
+            progress.remove_task(loading_model_task)
+
+            loading_dataset_task = progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
+            dataset = iter(import_hf_parquet(dataset_path))
+            dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
+            progress.remove_task(loading_dataset_task)
+
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            MofNCompleteColumn(),
+            TimeElapsedColumn(),
+            TimeRemainingColumn(),
+            disable=True,
+        ) as progress:
+            live.update(progress, refresh=True)
+            inference_task = progress.add_task("🔮 [cyan]Running inference...[/cyan]", total=num_tokens_to_generate)
+
+            def progress_callback(event: CollectTracesEvent) -> None:
+                progress.update(inference_task, completed=event.tokens_generated)
+
+            traces = inference_collect_traces(
+                model,
+                dataset,
+                num_logits_per_token,
+                batch_size,
+                max_input_length,
+                max_output_length,
+                num_tokens_to_generate,
+                progress_callback,
+            )
+
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            with open(output_path, "wb+") as output_fd:
+                for trace in traces:
+                    blob = trace.serialize()
+                    output_fd.write(blob)
+
+            progress.update(inference_task, description="✅ Completed")
+
+
+@speculator_app.command(help="Train a speculator from inference traces")
+def train(
+    trace_path: Annotated[
+        Path,
+        Argument(
+            help="File of llm inference traces to train the speculator on",
+            metavar="TRACE_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path,
+        Option(
+            help="File to save the output to",
+            metavar="OUTPUT_PATH",
+        ),
+    ],
+    hashtable_size: Annotated[
+        int,
+        Option(help="Size of ngram hashtable"),
+    ] = 65536,
+    num_logits_per_token: Annotated[
+        int,
+        Option(help="Top K tokens to keep in ngram hashtable"),
+    ] = 8,
+    ngram_size: Annotated[
+        int,
+        Option(help="Length of ngrams"),
+    ] = 2,
+    subsample_size: Annotated[
+        int | None,
+        Option(
+            help="Exit early after training the model on this number of tokens",
+            show_default="all",
+        ),
+    ] = None,
+) -> None:
+    with open(trace_path, "rb") as trace_fd:
+        traces = LalamoCompletion.deserialize_many(trace_fd)
+
+        speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
+
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            MofNCompleteColumn(),
+            TimeElapsedColumn(),
+            TimeRemainingColumn(),
+        ) as progress:
+            inference_task = progress.add_task("🔮 [cyan]Training speculator...[/cyan]", total=subsample_size)
+
+            def progress_callback(event: SpeculatorTrainingEvent) -> None:
+                progress.update(inference_task, completed=event.trained_tokens)
+
+            train_speculator(speculator, traces, subsample_size, progress_callback)
+
+            progress.update(inference_task, description="✅ Completed")
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "wb+") as fd:
+        fd.write(speculator.serialize())
+
+
+@speculator_app.command(help="Run speculator as an autoregressive llm")
+def test(
+    speculator_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the speculator file.",
+            metavar="SPECULATOR_PATH",
+        ),
+    ],
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory for detokenization.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    seed: Annotated[
+        int | None,
+        Option(help="Set seed for deterministic sampling"),
+    ] = None,
+    num_sequences: Annotated[
+        int,
+        Option(help="Number of sequences to generate"),
+    ] = 8,
+) -> None:
+    model = LanguageModel.load(model_path)
+
+    with open(speculator_path, "rb") as fd:
+        speculator = NGramSpeculator.deserialize(fd.read())
+
+    table = Table(
+        show_header=False,
+        show_lines=True,
+        box=box.ROUNDED,
+    )
+
+    if seed is not None:
+        random.seed(seed)
+
+    for _ in range(num_sequences):
+        sequence = test_speculator(speculator)
+        detokenized = model.message_processor.detokenize(sequence)
+        table.add_row(detokenized)
+
+    console.print(table)
+
+
+@app.callback()
+def _profile_memory(
+    ctx: Context,
+    profile_memory: Annotated[
+        Path | None,
+        Option(
+            help="Record and save the XLA memory profile to specified path",
+            show_default="Don't save the XLA memory profile",
+            envvar="LALAMO_PROFILE_MEMORY",
+        ),
+    ] = None,
+) -> None:
+    if profile_memory is None:
+        return
+
+    if profile_memory.is_dir():
+        profile_memory /= "lalamo-memory.prof"
+
+    def _save_memory_profile() -> None:
+        console.print(f"Saving XLA memory profile to {profile_memory}")
+        jax.profiler.save_device_memory_profile(profile_memory)
+
+    ctx.call_on_close(_save_memory_profile)
+
+
 if __name__ == "__main__":
     app()
lalamo/message_processor.py
CHANGED
@@ -1,6 +1,7 @@
 import re
 from collections.abc import Iterable
 from dataclasses import dataclass
+from datetime import datetime
 from functools import cached_property
 from re import Pattern
 from typing import NotRequired, TypedDict
@@ -24,6 +25,10 @@ type ToolSchema = None  # WIP
 type Image = None  # WIP


+def _strftime_now(format_string: str) -> str:
+    return datetime.now().strftime(format_string)  # noqa: DTZ005
+
+
 class HuggingFaceMessage(TypedDict):
     role: str
     content: str
@@ -141,7 +146,7 @@ class MessageProcessor:

     def render_request(self, messages: Iterable[Message]) -> str:
         request_dict = self.request_to_dict(messages)
-        return self.prompt_template.render(request_dict)
+        return self.prompt_template.render({**request_dict, "strftime_now": _strftime_now})

     def parse_response(self, response: str) -> AssistantMessage:
         if self.output_parser_regex is None:
@@ -154,6 +159,11 @@ class MessageProcessor:
     def tokenize(self, text: str) -> list[int]:
         return self.tokenizer.encode(text, add_special_tokens=False).ids

+    def tokenize_request(self, messages: Iterable[Message]) -> list[int]:
+        rendered = self.render_request(messages)
+        tokenized = self.tokenize(rendered)
+        return tokenized
+
     def detokenize(self, tokens: list[int]) -> str:
         return self.tokenizer.decode(tokens, skip_special_tokens=False)
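The two additions above go together: newer HuggingFace chat templates can call strftime_now(...) to stamp the current date into the prompt, so render_request now injects that helper into the Jinja render context, and tokenize_request is a convenience that chains rendering and tokenization. A standalone sketch of the injection pattern using plain Jinja2 follows; the template string is illustrative and not taken from any lalamo model.

from datetime import datetime

from jinja2 import Template  # plain Jinja2, used only to illustrate the pattern


def _strftime_now(format_string: str) -> str:
    return datetime.now().strftime(format_string)


# A toy template in the style of chat templates that reference strftime_now.
template = Template("Today Date: {{ strftime_now('%d %b %Y') }}\n{{ text }}")

# Pass the helper alongside the request fields, exactly as render_request now does.
print(template.render({"text": "user: hello", "strftime_now": _strftime_now}))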
lalamo/model_import/common.py
CHANGED
@@ -1,6 +1,7 @@
 import importlib.metadata
 from collections import ChainMap
 from collections.abc import Callable
+from contextlib import ExitStack
 from dataclasses import dataclass
 from pathlib import Path
 from typing import NamedTuple
@@ -138,7 +139,13 @@ def import_message_processor(
         raise ValueError("Conflicting chat template specifications.")
     prompt_template = tokenizer_config.chat_template
     tokenizer = Tokenizer.from_file(str(tokenizer_file))
-
+
+    added_tokens = tokenizer_config.added_tokens()
+    added_special_tokens = [token for token in added_tokens if token.special]
+    added_not_special_tokens = [token for token in added_tokens if not token.special]
+    tokenizer.add_special_tokens(added_special_tokens)
+    tokenizer.add_tokens(added_not_special_tokens)
+
     message_processor_config = MessageProcessorConfig(
         prompt_template=prompt_template,
         output_parser_regex=model_spec.output_parser_regex,
@@ -171,14 +178,17 @@ def import_model(
         precision = foreign_decoder_config.default_precision

     weights_paths = download_weights(model_spec, progress_callback=progress_callback)
-
-
-
+    with ExitStack() as stack:
+        weights_shards = []
+        for weights_path in weights_paths:
+            weights_shard = stack.enter_context(model_spec.weights_type.load(weights_path, precision))
+            weights_shards.append(weights_shard)
+        weights_dict: ChainMap[str, Array] = ChainMap(*weights_shards)

-
-
+        if progress_callback is not None:
+            progress_callback(InitializingModelEvent())

-
+        decoder = foreign_decoder_config.load_decoder(context_length, precision, accumulation_precision, weights_dict)

     if progress_callback is not None:
         progress_callback(FinishedInitializingModelEvent())
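The rewritten loading path opens every weights shard inside a single ExitStack, so all shards stay open while the decoder is initialized and are closed together afterwards, and ChainMap presents them as one read-only mapping without copying any tensors. A self-contained sketch of that pattern with toy shards follows; the Shard class is illustrative, not a lalamo type.

from collections import ChainMap
from collections.abc import Mapping
from contextlib import ExitStack


class Shard:
    # Toy stand-in for a lazily opened weights shard (illustrative only).
    def __init__(self, name: str, tensors: Mapping[str, int]) -> None:
        self.name = name
        self.tensors = tensors

    def __enter__(self) -> Mapping[str, int]:
        print(f"opening {self.name}")
        return self.tensors

    def __exit__(self, *exc: object) -> None:
        print(f"closing {self.name}")


shard_files = [Shard("model-00001", {"a": 1}), Shard("model-00002", {"b": 2})]

with ExitStack() as stack:
    # Every shard stays open for the whole block; ExitStack closes them all on exit.
    shards = [stack.enter_context(shard) for shard in shard_files]
    # ChainMap searches the shards in order, without merging or copying them.
    weights = ChainMap(*shards)
    print(weights["a"], weights["b"])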

lalamo/model_import/decoder_configs/__init__.py
CHANGED

@@ -1,9 +1,11 @@
 from .common import ForeignConfig
+
 # from .executorch import ETLlamaConfig
 from .huggingface import (
     HFGemma2Config,
     HFGemma3Config,
     HFGemma3TextConfig,
+    HFGPTOssConfig,
     HFLlamaConfig,
     HFMistralConfig,
     HFQwen2Config,
@@ -13,6 +15,7 @@ from .huggingface import (
 __all__ = [
     # "ETLlamaConfig",
     "ForeignConfig",
+    "HFGPTOssConfig",
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",

lalamo/model_import/decoder_configs/executorch.py
CHANGED

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass

 import jax.numpy as jnp
@@ -5,18 +6,18 @@ from jaxtyping import Array, DTypeLike

 from lalamo.model_import.loaders.executorch import load_executorch
 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     Decoder,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     LlamaRoPEConfig,
-    MLPConfig,
     QLoRALinearConfig,
     QuantizedTiedEmbeddingConfig,
     RMSNormConfig,
     UpcastMode,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.quantization import QuantizationMode

 from .common import ForeignConfig
@@ -58,7 +59,7 @@ class ExecutorchConfig(ForeignConfig):
     def _load_weights(
         cls,
         model: Decoder,
-        weights_dict:
+        weights_dict: Mapping[str, Array],
     ) -> Decoder:
         return load_executorch(model, weights_dict)

@@ -97,7 +98,7 @@ class ETLlamaConfig(ExecutorchConfig):

         embedding_config = QuantizedTiedEmbeddingConfig(
             input_scale=None,
-
+            logit_soft_cap=None,
             embedding_quantization_mode=EMBEDDING_QUANTIZATION_MODE,
             activation_quantization_mode=ACTIVATION_QUANTIZATION_MODE,
             activation_precision=activation_precision,
@@ -132,12 +133,17 @@ class ETLlamaConfig(ExecutorchConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=None,
+            has_sinks=False,
             has_qkv_biases=False,
             has_out_biases=False,
         )
-        mlp_config =
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,

lalamo/model_import/decoder_configs/huggingface/__init__.py
CHANGED

@@ -1,12 +1,14 @@
 from .common import HuggingFaceConfig
 from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
+from .gpt_oss import HFGPTOssConfig
 from .llama import HFLlamaConfig
 from .mistral import HFMistralConfig
 from .qwen2 import HFQwen2Config
 from .qwen3 import HFQwen3Config

 __all__ = [
+    "HFGPTOssConfig",
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",

lalamo/model_import/decoder_configs/huggingface/common.py
CHANGED

@@ -58,15 +58,13 @@ class GPTQQuantizationConfig:

 @dataclass(frozen=True)
 class HuggingFaceConfig(ForeignConfig):
-    torch_dtype: Literal["bfloat16", "float16", "float32"]
-
     @property
     def eos_token_ids(self) -> list[int]:
         return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id

     @property
     def default_precision(self) -> DTypeLike:
-        return jnp.dtype(self
+        return jnp.dtype(getattr(self, "torch_dtype", "bfloat16"))

     @classmethod
     def _load_weights(

lalamo/model_import/decoder_configs/huggingface/gemma2.py
CHANGED

@@ -4,17 +4,17 @@ from typing import Literal
 from jaxtyping import DTypeLike

 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UpcastMode,
 )
+from lalamo.modules.activations import GELU

 from .common import HuggingFaceConfig

@@ -50,6 +50,7 @@ class HFGemma2Config(HuggingFaceConfig):
     transformers_version: str
     use_cache: bool
     vocab_size: int
+    torch_dtype: Literal["bfloat16", "float16", "float32"]

     def to_decoder_config(
         self,
@@ -64,7 +65,7 @@ class HFGemma2Config(HuggingFaceConfig):
         attention_scale = self.query_pre_attn_scalar**-0.5
         embedding_config = TiedEmbeddingConfig(
             input_scale=embedding_input_scale,
-
+            logit_soft_cap=self.final_logit_softcapping,
             precision=activation_precision,
         )
         rope_config = UnscaledRoPEConfig(
@@ -88,12 +89,17 @@ class HFGemma2Config(HuggingFaceConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=self.attn_logit_softcapping,
+            has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=False,
         )
-        mlp_config =
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=
+            activation=GELU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,