lalamo 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. lalamo/__init__.py +20 -5
  2. lalamo/data/__init__.py +8 -0
  3. lalamo/data/huggingface_message.py +38 -0
  4. lalamo/data/lalamo_completions.py +43 -0
  5. lalamo/data/utils.py +8 -0
  6. lalamo/language_model.py +152 -69
  7. lalamo/main.py +271 -43
  8. lalamo/message_processor.py +11 -1
  9. lalamo/model_import/common.py +17 -7
  10. lalamo/model_import/decoder_configs/__init__.py +3 -0
  11. lalamo/model_import/decoder_configs/executorch.py +12 -6
  12. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  13. lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  14. lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  15. lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  16. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  17. lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  18. lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  21. lalamo/model_import/huggingface_tokenizer_config.py +1 -4
  22. lalamo/model_import/loaders/executorch.py +10 -9
  23. lalamo/model_import/loaders/huggingface.py +104 -9
  24. lalamo/model_import/loaders/utils.py +92 -0
  25. lalamo/model_import/model_specs/__init__.py +4 -1
  26. lalamo/model_import/model_specs/common.py +15 -12
  27. lalamo/model_import/model_specs/gpt_oss.py +21 -0
  28. lalamo/modules/__init__.py +35 -7
  29. lalamo/modules/activations.py +24 -14
  30. lalamo/modules/attention.py +73 -20
  31. lalamo/modules/common.py +8 -57
  32. lalamo/modules/decoder.py +48 -34
  33. lalamo/modules/decoder_layer.py +57 -43
  34. lalamo/modules/embedding.py +13 -19
  35. lalamo/modules/kv_cache.py +53 -16
  36. lalamo/modules/linear.py +260 -79
  37. lalamo/modules/mlp.py +395 -23
  38. lalamo/modules/normalization.py +2 -3
  39. lalamo/modules/rope.py +32 -21
  40. lalamo/modules/utils.py +10 -0
  41. lalamo/speculator/__init__.py +11 -0
  42. lalamo/speculator/common.py +22 -0
  43. lalamo/speculator/inference.py +75 -0
  44. lalamo/speculator/ngram.py +154 -0
  45. lalamo/speculator/utils.py +52 -0
  46. lalamo/utils.py +27 -0
  47. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
  48. lalamo-0.4.0.dist-info/RECORD +71 -0
  49. lalamo-0.3.3.dist-info/RECORD +0 -59
  50. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
  51. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
  52. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
  53. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/main.py CHANGED
@@ -1,12 +1,16 @@
 import json
+import random
 import re
 import shutil
 import sys
 from enum import Enum
+from itertools import chain
 from pathlib import Path
 from typing import Annotated

+import jax
 import jax.numpy as jnp
+import jax.profiler
 import thefuzz.process
 from click import Context as ClickContext
 from click import Parameter as ClickParameter
@@ -14,13 +18,24 @@ from click import ParamType
 from jaxtyping import DTypeLike
 from rich import box
 from rich.console import Console
+from rich.live import Live
 from rich.panel import Panel
-from rich.progress import Progress, SpinnerColumn, TextColumn
+from rich.progress import (
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    TimeElapsedColumn,
+    TimeRemainingColumn,
+    track,
+)
 from rich.table import Table
 from safetensors.flax import save_file
-from typer import Argument, Exit, Option, Typer
+from typer import Argument, Context, Exit, Option, Typer

 from lalamo.common import flatten_parameters
+from lalamo.data import import_hf_parquet
+from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.language_model import LanguageModel
 from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
@@ -31,7 +46,10 @@ from lalamo.model_import.common import (
     InitializingModelEvent,
     StatusEvent,
 )
-from lalamo.modules import WeightLayout, config_converter
+from lalamo.modules import config_converter
+from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+from lalamo.speculator.ngram import NGramSpeculator
+from lalamo.speculator.utils import SpeculatorTrainingEvent, test_speculator, train_speculator
 from lalamo.utils import jax_uint4_to_packed_uint8

 SCRIPT_NAME = Path(sys.argv[0]).name
@@ -110,27 +128,19 @@ def chat(
             metavar="MODEL_PATH",
         ),
     ],
-    weight_layout: Annotated[
-        WeightLayout | None,
-        Option(
-            help=(
-                "(EXPERIMENTAL) Order of dimensions in the weights of linear layers."
-                "\n\n\n\n"
-                "If set to AUTO, the layout will depend on the model."
-            ),
-            show_default="auto",
-        ),
-    ] = None,
 ) -> None:
-    if weight_layout is None:
-        weight_layout = WeightLayout.AUTO
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
         transient=True,
     ) as progress:
-        progress.add_task("🚀 [cyan]Loading model...[/cyan]")
-        model = LanguageModel.load(model_path, weight_layout)
+        loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
+        model = LanguageModel.load(model_path)
+        progress.remove_task(loading_task)
+        warmup_task = progress.add_task("🔥 Warming up compilation cache...")
+        list(model.stream_reply_text([UserMessage("")], max_output_length=1))
+        progress.remove_task(warmup_task)
+    console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
     messages = []
     while True:
         user_text = console.input("[cyan]user> [/cyan]")
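Note on the warm-up task above: with JAX, the first call into a jitted decode path pays a one-time tracing and compilation cost, so streaming one throwaway token before the first user turn keeps the first real reply fast. A toy illustration of the caching effect (the function below is a stand-in, not lalamo's decode step):

    # Toy demonstration of JAX compilation caching; decode_step is a
    # stand-in, not lalamo's actual generation function.
    import time

    import jax
    import jax.numpy as jnp

    @jax.jit
    def decode_step(x):
        return jnp.tanh(x @ x.T)

    x = jnp.ones((512, 512))

    start = time.perf_counter()
    decode_step(x).block_until_ready()  # traced + compiled on first call
    print(f"first call:  {time.perf_counter() - start:.4f}s")

    start = time.perf_counter()
    decode_step(x).block_until_ready()  # reuses the compiled executable
    print(f"second call: {time.perf_counter() - start:.4f}s")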
@@ -170,17 +180,6 @@ def convert(
             show_default="Native precision of the model",
         ),
     ] = None,
-    weight_layout: Annotated[
-        WeightLayout | None,
-        Option(
-            help=(
-                "(EXPERIMENTAL) Order of dimensions in the weights of linear layers."
-                "\n\n\n\n"
-                "If set to AUTO, the layout will depend on the model."
-            ),
-            show_default="auto",
-        ),
-    ] = None,
     output_dir: Annotated[
         Path | None,
         Option(
@@ -213,18 +212,10 @@ def convert(
     else:
         precision_dtype = None

-    if weight_layout is not None:
-        weight_layout = WeightLayout(weight_layout)
-    else:
-        weight_layout = WeightLayout.AUTO
-
     if output_dir is None:
         output_dir = DEFAULT_OUTPUT_DIR / model_repo.name

-    console.print(f"🚀 Converting [cyan]{model_repo.name}[/cyan] by [cyan]{model_repo.vendor}[/cyan].")
-    conversion_strs = [
-        f"⚙️ Using weight layout [cyan]{weight_layout}[/cyan]",
-    ]
+    conversion_strs = [f"🚀 Converting [cyan]{model_repo.name}[/cyan] by [cyan]{model_repo.vendor}[/cyan]"]
     if precision is not None:
         conversion_strs.append(
             f" and converting floating-point weights into [cyan]{precision.name.lower()}[/cyan] precision",
@@ -292,7 +283,7 @@ def convert(
         progress.remove_task(main_task)

     model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
-    weights = flatten_parameters(model.export_weights(weight_layout))
+    weights = flatten_parameters(model.export_weights())
     del model

     packed_weights = _pack_uint4_weights(weights)
@@ -312,10 +303,10 @@ def _model_size_string_to_int(
 ) -> float:
     match = _regex.match(size_str)
     factors = {
-        "K": 1024**1,
-        "M": 1024**2,
-        "B": 1024**3,
-        "T": 1024**4,
+        "K": 1000**1,
+        "M": 1000**2,
+        "B": 1000**3,
+        "T": 1000**4,
     }
     if match:
         return float(match.group("number")) * factors[match.group("suffix")]
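The factor change above switches size suffixes from binary to decimal multiples, matching how parameter counts are conventionally quoted: "7B" now parses as 7,000,000,000 instead of 7 * 1024**3 = 7,516,192,768. A self-contained sketch of the new behavior, where the regex is an assumed stand-in exposing the "number" and "suffix" groups the hunk implies:

    # Worked example of the new decimal factors. The regex is an assumed
    # stand-in, not the _regex defined elsewhere in main.py.
    import re

    _regex = re.compile(r"(?P<number>\d+(?:\.\d+)?)(?P<suffix>[KMBT])")
    factors = {"K": 1000**1, "M": 1000**2, "B": 1000**3, "T": 1000**4}

    def model_size_to_float(size_str: str) -> float:
        match = _regex.match(size_str)
        assert match is not None, size_str
        return float(match.group("number")) * factors[match.group("suffix")]

    print(model_size_to_float("7B"))    # 7000000000.0 (was 7 * 1024**3 = 7516192768.0)
    print(model_size_to_float("0.5B"))  # 500000000.0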
@@ -368,5 +359,242 @@ def list_models(
     console.print(table)


+speculator_app = Typer()
+app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")
+
+
+@speculator_app.command(help="Run model inference and collect traces for speculator training")
+def collect_traces(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    dataset_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the dataset with prompts",
+            metavar="DATASET_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path,
+        Option(
+            help="File to save the trace to",
+            metavar="OUTPUT_PATH",
+        ),
+    ],
+    num_logits_per_token: Annotated[
+        int,
+        Option(help="Record logits for this number of most probable tokens"),
+    ] = 8,
+    max_input_length: Annotated[
+        int,
+        Option(help="Filter prompts that have more than this number of tokens in context"),
+    ] = 1024,
+    max_output_length: Annotated[
+        int,
+        Option(help="Maximum number of tokens to generate in one completion"),
+    ] = 1024,
+    batch_size: Annotated[
+        int,
+        Option(help="Number of sequences in one batch"),
+    ] = 1,
+    num_tokens_to_generate: Annotated[
+        int | None,
+        Option(
+            help="Exit early after generating this number of output tokens",
+            show_default="all",
+        ),
+    ] = None,
+) -> None:
+    with Live(refresh_per_second=10) as live:
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            transient=True,
+            disable=True,
+        ) as progress:
+            live.update(progress, refresh=True)
+            loading_model_task = progress.add_task("🧠 [cyan]Loading model...[/cyan]")
+            model = LanguageModel.load(model_path)
+            progress.remove_task(loading_model_task)
+
+            loading_dataset_task = progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
+            dataset = iter(import_hf_parquet(dataset_path))
+            dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
+            progress.remove_task(loading_dataset_task)
+
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            MofNCompleteColumn(),
+            TimeElapsedColumn(),
+            TimeRemainingColumn(),
+            disable=True,
+        ) as progress:
+            live.update(progress, refresh=True)
+            inference_task = progress.add_task("🔮 [cyan]Running inference...[/cyan]", total=num_tokens_to_generate)
+
+            def progress_callback(event: CollectTracesEvent) -> None:
+                progress.update(inference_task, completed=event.tokens_generated)
+
+            traces = inference_collect_traces(
+                model,
+                dataset,
+                num_logits_per_token,
+                batch_size,
+                max_input_length,
+                max_output_length,
+                num_tokens_to_generate,
+                progress_callback,
+            )
+
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            with open(output_path, "wb+") as output_fd:
+                for trace in traces:
+                    blob = trace.serialize()
+                    output_fd.write(blob)
+
+            progress.update(inference_task, description="✅ Completed")
+
+
+@speculator_app.command(help="Train a speculator from inference traces")
+def train(
+    trace_path: Annotated[
+        Path,
+        Argument(
+            help="File of llm inference traces to train the speculator on",
+            metavar="TRACE_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path,
+        Option(
+            help="File to save the output to",
+            metavar="OUTPUT_PATH",
+        ),
+    ],
+    hashtable_size: Annotated[
+        int,
+        Option(help="Size of ngram hashtable"),
+    ] = 65536,
+    num_logits_per_token: Annotated[
+        int,
+        Option(help="Top K tokens to keep in ngram hashtable"),
+    ] = 8,
+    ngram_size: Annotated[
+        int,
+        Option(help="Length of ngrams"),
+    ] = 2,
+    subsample_size: Annotated[
+        int | None,
+        Option(
+            help="Exit early after training the model on this number of tokens",
+            show_default="all",
+        ),
+    ] = None,
+) -> None:
+    with open(trace_path, "rb") as trace_fd:
+        traces = LalamoCompletion.deserialize_many(trace_fd)
+
+    speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        MofNCompleteColumn(),
+        TimeElapsedColumn(),
+        TimeRemainingColumn(),
+    ) as progress:
+        inference_task = progress.add_task("🔮 [cyan]Training speculator...[/cyan]", total=subsample_size)
+
+        def progress_callback(event: SpeculatorTrainingEvent) -> None:
+            progress.update(inference_task, completed=event.trained_tokens)
+
+        train_speculator(speculator, traces, subsample_size, progress_callback)
+
+        progress.update(inference_task, description="✅ Completed")
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "wb+") as fd:
+        fd.write(speculator.serialize())
+
+
+@speculator_app.command(help="Run speculator as an autoregressive llm")
+def test(
+    speculator_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the speculator file.",
+            metavar="SPECULATOR_PATH",
+        ),
+    ],
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory for detokenization.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    seed: Annotated[
+        int | None,
+        Option(help="Set seed for deterministic sampling"),
+    ] = None,
+    num_sequences: Annotated[
+        int,
+        Option(help="Number of sequences to generate"),
+    ] = 8,
+) -> None:
+    model = LanguageModel.load(model_path)
+
+    with open(speculator_path, "rb") as fd:
+        speculator = NGramSpeculator.deserialize(fd.read())
+
+    table = Table(
+        show_header=False,
+        show_lines=True,
+        box=box.ROUNDED,
+    )
+
+    if seed is not None:
+        random.seed(seed)
+
+    for _ in range(num_sequences):
+        sequence = test_speculator(speculator)
+        detokenized = model.message_processor.detokenize(sequence)
+        table.add_row(detokenized)
+
+    console.print(table)
+
+
+@app.callback()
+def _profile_memory(
+    ctx: Context,
+    profile_memory: Annotated[
+        Path | None,
+        Option(
+            help="Record and save the XLA memory profile to specified path",
+            show_default="Don't save the XLA memory profile",
+            envvar="LALAMO_PROFILE_MEMORY",
+        ),
+    ] = None,
+) -> None:
+    if profile_memory is None:
+        return
+
+    if profile_memory.is_dir():
+        profile_memory /= "lalamo-memory.prof"
+
+    def _save_memory_profile() -> None:
+        console.print(f"Saving XLA memory profile to {profile_memory}")
+        jax.profiler.save_device_memory_profile(profile_memory)
+
+    ctx.call_on_close(_save_memory_profile)
+
+
 if __name__ == "__main__":
     app()
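The new `speculator` sub-app wires up a collect-traces, train, test workflow around `NGramSpeculator` (implemented in the new `lalamo/speculator/ngram.py`, not shown in this diff). As a rough mental model only, the `hashtable_size`, `num_logits_per_token`, and `ngram_size` knobs suggest a structure like the following toy sketch; everything below is illustrative, not the shipped class:

    # Toy n-gram speculator; class name, fields, and logic are illustrative
    # assumptions, not the NGramSpeculator that ships in lalamo.
    from collections import Counter

    class ToyNGramSpeculator:
        def __init__(self, hashtable_size: int, top_k: int, ngram_size: int) -> None:
            self.hashtable_size = hashtable_size
            self.top_k = top_k
            self.ngram_size = ngram_size
            self.table: list[Counter[int]] = [Counter() for _ in range(hashtable_size)]

        def _bucket(self, context: tuple[int, ...]) -> int:
            return hash(context) % self.hashtable_size

        def observe(self, tokens: list[int]) -> None:
            # Count which token followed each n-gram in a recorded trace.
            for i in range(len(tokens) - self.ngram_size):
                context = tuple(tokens[i : i + self.ngram_size])
                self.table[self._bucket(context)][tokens[i + self.ngram_size]] += 1

        def propose(self, tokens: list[int]) -> list[int]:
            # Draft candidates: the top-K continuations seen after this n-gram.
            context = tuple(tokens[-self.ngram_size :])
            return [t for t, _ in self.table[self._bucket(context)].most_common(self.top_k)]

    spec = ToyNGramSpeculator(hashtable_size=65536, top_k=8, ngram_size=2)
    spec.observe([1, 2, 3, 1, 2, 4, 1, 2, 3])
    print(spec.propose([5, 1, 2]))  # [3, 4]: continuations seen after (1, 2)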
lalamo/message_processor.py CHANGED
@@ -1,6 +1,7 @@
 import re
 from collections.abc import Iterable
 from dataclasses import dataclass
+from datetime import datetime
 from functools import cached_property
 from re import Pattern
 from typing import NotRequired, TypedDict
@@ -24,6 +25,10 @@ type ToolSchema = None  # WIP
 type Image = None  # WIP


+def _strftime_now(format_string: str) -> str:
+    return datetime.now().strftime(format_string)  # noqa: DTZ005
+
+
 class HuggingFaceMessage(TypedDict):
     role: str
     content: str
@@ -141,7 +146,7 @@ class MessageProcessor:

     def render_request(self, messages: Iterable[Message]) -> str:
         request_dict = self.request_to_dict(messages)
-        return self.prompt_template.render(request_dict)
+        return self.prompt_template.render({**request_dict, "strftime_now": _strftime_now})

     def parse_response(self, response: str) -> AssistantMessage:
         if self.output_parser_regex is None:
@@ -154,6 +159,11 @@ class MessageProcessor:
     def tokenize(self, text: str) -> list[int]:
         return self.tokenizer.encode(text, add_special_tokens=False).ids

+    def tokenize_request(self, messages: Iterable[Message]) -> list[int]:
+        rendered = self.render_request(messages)
+        tokenized = self.tokenize(rendered)
+        return tokenized
+
     def detokenize(self, tokens: list[int]) -> str:
         return self.tokenizer.decode(tokens, skip_special_tokens=False)
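Exposing `_strftime_now` to the template context matches the `strftime_now` callable that Hugging Face-style chat templates may invoke to stamp the current date. A minimal illustration; the template string is invented, and treating `prompt_template` as a jinja2 `Template` is an assumption:

    # Illustrative template that calls strftime_now, rendered the same way
    # MessageProcessor.render_request now passes the callable in.
    from datetime import datetime

    from jinja2 import Template

    def _strftime_now(format_string: str) -> str:
        return datetime.now().strftime(format_string)

    template = Template("Today Date: {{ strftime_now('%d %b %Y') }}\n{{ prompt }}")
    print(template.render({"prompt": "Hello!", "strftime_now": _strftime_now}))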
lalamo/model_import/common.py CHANGED
@@ -1,6 +1,7 @@
 import importlib.metadata
 from collections import ChainMap
 from collections.abc import Callable
+from contextlib import ExitStack
 from dataclasses import dataclass
 from pathlib import Path
 from typing import NamedTuple
@@ -138,7 +139,13 @@ def import_message_processor(
         raise ValueError("Conflicting chat template specifications.")
     prompt_template = tokenizer_config.chat_template
     tokenizer = Tokenizer.from_file(str(tokenizer_file))
-    tokenizer.add_special_tokens(tokenizer_config.added_tokens())
+
+    added_tokens = tokenizer_config.added_tokens()
+    added_special_tokens = [token for token in added_tokens if token.special]
+    added_not_special_tokens = [token for token in added_tokens if not token.special]
+    tokenizer.add_special_tokens(added_special_tokens)
+    tokenizer.add_tokens(added_not_special_tokens)
+
     message_processor_config = MessageProcessorConfig(
         prompt_template=prompt_template,
         output_parser_regex=model_spec.output_parser_regex,
@@ -171,14 +178,17 @@ def import_model(
         precision = foreign_decoder_config.default_precision

     weights_paths = download_weights(model_spec, progress_callback=progress_callback)
-    weights_dict: ChainMap[str, Array] = ChainMap(
-        *[model_spec.weights_type.load(weights_path, precision) for weights_path in weights_paths],  # type: ignore
-    )
+    with ExitStack() as stack:
+        weights_shards = []
+        for weights_path in weights_paths:
+            weights_shard = stack.enter_context(model_spec.weights_type.load(weights_path, precision))
+            weights_shards.append(weights_shard)
+        weights_dict: ChainMap[str, Array] = ChainMap(*weights_shards)

-    if progress_callback is not None:
-        progress_callback(InitializingModelEvent())
+        if progress_callback is not None:
+            progress_callback(InitializingModelEvent())

-    decoder = foreign_decoder_config.load_decoder(context_length, precision, accumulation_precision, weights_dict)
+        decoder = foreign_decoder_config.load_decoder(context_length, precision, accumulation_precision, weights_dict)

     if progress_callback is not None:
         progress_callback(FinishedInitializingModelEvent())
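The `ExitStack` rewrite above keeps every shard's context manager open until the decoder has been built, which matters if `weights_type.load` returns memory-mapped tensors that die with their file handle. The pattern in isolation, with a hypothetical `open_shard` loader standing in for the real one:

    # The ExitStack pattern in isolation; open_shard() is a hypothetical
    # loader standing in for model_spec.weights_type.load().
    from collections import ChainMap
    from contextlib import ExitStack, contextmanager

    @contextmanager
    def open_shard(path: str):
        # Pretend this memory-maps a shard; the mapping is only valid
        # while the context manager remains open.
        yield {f"{path}/weight": [1.0, 2.0]}

    with ExitStack() as stack:
        shards = [stack.enter_context(open_shard(p)) for p in ("shard-00001", "shard-00002")]
        weights = ChainMap(*shards)
        print(sorted(weights))  # everything using the weights happens inside the stack
    # all shards are released here, in reverse order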
lalamo/model_import/decoder_configs/__init__.py CHANGED
@@ -1,9 +1,11 @@
 from .common import ForeignConfig
+
 # from .executorch import ETLlamaConfig
 from .huggingface import (
     HFGemma2Config,
     HFGemma3Config,
     HFGemma3TextConfig,
+    HFGPTOssConfig,
     HFLlamaConfig,
     HFMistralConfig,
     HFQwen2Config,
@@ -13,6 +15,7 @@ from .huggingface import (
 __all__ = [
     # "ETLlamaConfig",
     "ForeignConfig",
+    "HFGPTOssConfig",
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",
lalamo/model_import/decoder_configs/executorch.py CHANGED
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass

 import jax.numpy as jnp
@@ -5,18 +6,18 @@ from jaxtyping import Array, DTypeLike

 from lalamo.model_import.loaders.executorch import load_executorch
 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     Decoder,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     LlamaRoPEConfig,
-    MLPConfig,
     QLoRALinearConfig,
     QuantizedTiedEmbeddingConfig,
     RMSNormConfig,
     UpcastMode,
 )
+from lalamo.modules.activations import SiLU
 from lalamo.quantization import QuantizationMode

 from .common import ForeignConfig
@@ -58,7 +59,7 @@ class ExecutorchConfig(ForeignConfig):
     def _load_weights(
         cls,
         model: Decoder,
-        weights_dict: dict[str, Array],
+        weights_dict: Mapping[str, Array],
     ) -> Decoder:
         return load_executorch(model, weights_dict)

@@ -97,7 +98,7 @@ class ETLlamaConfig(ExecutorchConfig):

         embedding_config = QuantizedTiedEmbeddingConfig(
             input_scale=None,
-            logits_soft_cap=None,
+            logit_soft_cap=None,
             embedding_quantization_mode=EMBEDDING_QUANTIZATION_MODE,
             activation_quantization_mode=ACTIVATION_QUANTIZATION_MODE,
             activation_precision=activation_precision,
@@ -132,12 +133,17 @@ class ETLlamaConfig(ExecutorchConfig):
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=None,
+            has_sinks=False,
            has_qkv_biases=False,
            has_out_biases=False,
        )
-        mlp_config = MLPConfig(
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=Activation.SILU,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,
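`MLPConfig` becomes `DenseMLPConfig` (presumably to distinguish it from the mixture-of-experts MLP that gpt-oss support introduces), activations become instances such as `SiLU()`, and the new `up_clipping`/`gate_clipping` fields point at optional clamping of the gated-MLP branches. A sketch of what such a forward pass could look like; the field names come from this diff, but the clamping semantics are an assumption:

    # Hypothetical gated-MLP forward pass with optional clipping; the
    # up_clipping/gate_clipping names come from DenseMLPConfig, but the
    # clamping behavior shown here is assumed, not taken from lalamo.
    import jax.numpy as jnp
    from jax.nn import silu

    def dense_mlp(x, w_gate, w_up, w_down, up_clipping=None, gate_clipping=None):
        gate = x @ w_gate
        up = x @ w_up
        if gate_clipping is not None:
            gate = jnp.clip(gate, -gate_clipping, gate_clipping)
        if up_clipping is not None:
            up = jnp.clip(up, -up_clipping, up_clipping)
        return (silu(gate) * up) @ w_down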
lalamo/model_import/decoder_configs/huggingface/__init__.py CHANGED
@@ -1,12 +1,14 @@
 from .common import HuggingFaceConfig
 from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
+from .gpt_oss import HFGPTOssConfig
 from .llama import HFLlamaConfig
 from .mistral import HFMistralConfig
 from .qwen2 import HFQwen2Config
 from .qwen3 import HFQwen3Config

 __all__ = [
+    "HFGPTOssConfig",
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",
lalamo/model_import/decoder_configs/huggingface/common.py CHANGED
@@ -58,15 +58,13 @@ class GPTQQuantizationConfig:

 @dataclass(frozen=True)
 class HuggingFaceConfig(ForeignConfig):
-    torch_dtype: Literal["bfloat16", "float16", "float32"]
-
     @property
     def eos_token_ids(self) -> list[int]:
         return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id

     @property
     def default_precision(self) -> DTypeLike:
-        return jnp.dtype(self.torch_dtype)
+        return jnp.dtype(getattr(self, "torch_dtype", "bfloat16"))

     @classmethod
     def _load_weights(
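`torch_dtype` moves off the shared base class and onto the concrete configs that declare it (see the gemma2 hunk below), presumably because some newer checkpoint configs omit the field; `default_precision` now falls back to `bfloat16` via `getattr`. The fallback in isolation, with stand-in dataclasses:

    # Stand-in dataclasses demonstrating the fallback; not the real
    # HuggingFaceConfig hierarchy.
    from dataclasses import dataclass

    import jax.numpy as jnp

    @dataclass(frozen=True)
    class WithDtype:
        torch_dtype: str = "float16"

    @dataclass(frozen=True)
    class WithoutDtype:
        pass

    for config in (WithDtype(), WithoutDtype()):
        print(jnp.dtype(getattr(config, "torch_dtype", "bfloat16")))
    # float16, then bfloat16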
lalamo/model_import/decoder_configs/huggingface/gemma2.py CHANGED
@@ -4,17 +4,17 @@ from typing import Literal
 from jaxtyping import DTypeLike

 from lalamo.modules import (
-    Activation,
     AttentionConfig,
     DecoderConfig,
     DecoderLayerConfig,
+    DenseMLPConfig,
     FullPrecisionLinearConfig,
-    MLPConfig,
     RMSNormConfig,
     TiedEmbeddingConfig,
     UnscaledRoPEConfig,
     UpcastMode,
 )
+from lalamo.modules.activations import GELU

 from .common import HuggingFaceConfig

@@ -50,6 +50,7 @@ class HFGemma2Config(HuggingFaceConfig):
     transformers_version: str
     use_cache: bool
     vocab_size: int
+    torch_dtype: Literal["bfloat16", "float16", "float32"]

     def to_decoder_config(
         self,
@@ -64,7 +65,7 @@
         attention_scale = self.query_pre_attn_scalar**-0.5
         embedding_config = TiedEmbeddingConfig(
             input_scale=embedding_input_scale,
-            logits_soft_cap=self.final_logit_softcapping,
+            logit_soft_cap=self.final_logit_softcapping,
             precision=activation_precision,
         )
         rope_config = UnscaledRoPEConfig(
@@ -88,12 +89,17 @@
             query_norm_config=None,
             key_norm_config=None,
             logit_soft_cap=self.attn_logit_softcapping,
+            has_sinks=False,
             has_qkv_biases=self.attention_bias,
             has_out_biases=False,
         )
-        mlp_config = MLPConfig(
+        mlp_config = DenseMLPConfig(
             linear_config=linear_config,
-            activation=Activation.GELU,
+            activation=GELU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
             pre_attention_norm_config=rmsnorm_config,
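This hunk also carries the repo-wide rename from `logits_soft_cap` to `logit_soft_cap`. For reference, Gemma-2-style soft-capping bounds logits smoothly as `cap * tanh(logits / cap)`; the sketch below is written from the published description of the technique, not from lalamo's internals:

    # Reference formula for logit soft-capping as popularized by Gemma 2;
    # written from the published description, not lalamo's implementation.
    import jax.numpy as jnp

    def soft_cap(logits, cap):
        # Smoothly bounds values to (-cap, cap) while staying differentiable.
        return cap * jnp.tanh(logits / cap)

    print(soft_cap(jnp.array([-100.0, 0.0, 100.0]), 30.0))  # ~[-30, 0, 30]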