lalamo 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. lalamo/__init__.py +15 -2
  2. lalamo/data/__init__.py +0 -1
  3. lalamo/data/huggingface_message.py +1 -0
  4. lalamo/main.py +167 -18
  5. lalamo/message_processor.py +2 -3
  6. lalamo/model_import/common.py +120 -27
  7. lalamo/model_import/decoder_configs/__init__.py +4 -2
  8. lalamo/model_import/decoder_configs/common.py +62 -21
  9. lalamo/model_import/decoder_configs/executorch.py +14 -9
  10. lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
  11. lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
  12. lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
  13. lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
  14. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
  15. lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
  16. lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
  17. lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
  18. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
  21. lalamo/model_import/loaders/__init__.py +3 -2
  22. lalamo/model_import/loaders/executorch.py +24 -12
  23. lalamo/model_import/loaders/huggingface.py +258 -30
  24. lalamo/model_import/model_specs/__init__.py +4 -2
  25. lalamo/model_import/model_specs/common.py +8 -2
  26. lalamo/model_import/model_specs/gemma.py +5 -1
  27. lalamo/model_import/model_specs/huggingface.py +1 -1
  28. lalamo/model_import/model_specs/mirai.py +20 -0
  29. lalamo/models/__init__.py +10 -0
  30. lalamo/models/common.py +81 -0
  31. lalamo/{language_model.py → models/language_model.py} +32 -49
  32. lalamo/models/router.py +59 -0
  33. lalamo/modules/__init__.py +33 -16
  34. lalamo/modules/classifier.py +339 -0
  35. lalamo/modules/common.py +6 -3
  36. lalamo/modules/decoder.py +52 -180
  37. lalamo/modules/mlp.py +28 -5
  38. lalamo/modules/normalization.py +13 -8
  39. lalamo/modules/token_mixers/attention.py +10 -6
  40. lalamo/modules/token_mixers/state/kv_cache.py +14 -4
  41. lalamo/modules/transformer.py +273 -0
  42. lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
  43. lalamo/speculator/__init__.py +6 -2
  44. lalamo/speculator/estimator.py +91 -0
  45. lalamo/speculator/inference.py +28 -9
  46. lalamo/speculator/ngram.py +7 -3
  47. lalamo/speculator/utils.py +4 -2
  48. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/METADATA +1 -1
  49. lalamo-0.5.4.dist-info/RECORD +88 -0
  50. lalamo-0.5.2.dist-info/RECORD +0 -80
  51. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/WHEEL +0 -0
  52. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/entry_points.txt +0 -0
  53. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/licenses/LICENSE +0 -0
  54. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/top_level.txt +0 -0
lalamo/__init__.py CHANGED
@@ -1,4 +1,3 @@
-from lalamo.language_model import LanguageModel
 from lalamo.message_processor import (
     AssistantMessage,
     ContentBlock,
@@ -9,18 +8,32 @@ from lalamo.message_processor import (
     UserMessage,
 )
 from lalamo.model_import import ModelSpec, import_model
+from lalamo.models import LanguageModel, Router
+from lalamo.speculator import (
+    CollectTracesEvent,
+    SpeculatorTrainingEvent,
+    estimate_batchsize_from_memory,
+    inference_collect_traces,
+    train_speculator,
+)
 
-__version__ = "0.5.2"
+__version__ = "0.5.4"
 
 __all__ = [
     "AssistantMessage",
+    "CollectTracesEvent",
     "ContentBlock",
     "Image",
     "LanguageModel",
     "Message",
     "ModelSpec",
+    "Router",
+    "SpeculatorTrainingEvent",
     "SystemMessage",
     "ToolSchema",
     "UserMessage",
+    "estimate_batchsize_from_memory",
     "import_model",
+    "inference_collect_traces",
+    "train_speculator",
 ]
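
The top-level surface changes here: LanguageModel moves from lalamo.language_model into the new lalamo.models package (alongside the new Router), and the speculator helpers are promoted to top-level exports. A minimal sketch of the resulting import surface, assuming only what this diff shows; the repo id is hypothetical:

from lalamo import LanguageModel, Router, import_model

# import_model accepts a ModelSpec or a known repo string and returns
# ImportResults(model, metadata); as of 0.5.4 the model may be either
# a LanguageModel or a Router (see model_import/common.py below).
model, metadata = import_model("vendor/model-repo")  # hypothetical repo id
assert isinstance(model, (LanguageModel, Router))
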
lalamo/data/__init__.py CHANGED
@@ -5,4 +5,3 @@ __all__ = [
     "get_prefixes_ending_in_user_message",
     "import_hf_parquet",
 ]
-
lalamo/data/huggingface_message.py CHANGED
@@ -29,6 +29,7 @@ class HFMessage:
             case other:
                 raise ValueError(f"Cannot convert {other} message")
 
+
 def import_hf_parquet(path: Path | str) -> Iterable[list[Message]]:
     path = Path(path)
 
lalamo/main.py CHANGED
@@ -4,12 +4,11 @@ import re
 import shutil
 import sys
 from enum import Enum
-from itertools import chain
+from itertools import chain, islice
 from pathlib import Path
 from typing import Annotated
 
 import jax
-import jax.numpy as jnp
 import jax.profiler
 import thefuzz.process
 from click import Context as ClickContext
@@ -35,7 +34,6 @@ from typer import Argument, Context, Exit, Option, Typer
 from lalamo.common import flatten_parameters
 from lalamo.data import import_hf_parquet
 from lalamo.data.lalamo_completions import LalamoCompletion
-from lalamo.language_model import LanguageModel
 from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
 from lalamo.model_import.common import (
@@ -45,10 +43,16 @@ from lalamo.model_import.common import (
     InitializingModelEvent,
     StatusEvent,
 )
+from lalamo.models import LanguageModelConfig, RouterConfig
 from lalamo.modules import config_converter
+from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
 from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
 from lalamo.speculator.ngram import NGramSpeculator
-from lalamo.speculator.utils import SpeculatorTrainingEvent, test_speculator, train_speculator
+from lalamo.speculator.utils import (
+    SpeculatorTrainingEvent,
+    test_speculator,
+    train_speculator,
+)
 
 SCRIPT_NAME = Path(sys.argv[0]).name
 
@@ -123,7 +127,7 @@ def chat(
         transient=True,
     ) as progress:
         loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
-        model = LanguageModel.load(model_path)
+        model = LanguageModelConfig.load_model(model_path)
         progress.remove_task(loading_task)
         warmup_task = progress.add_task("🔥 Warming up compilation cache...")
         list(model.stream_reply_text([UserMessage("")], max_output_length=1))
@@ -145,6 +149,39 @@
         messages.append(model.message_processor.parse_response(model_response_text))
 
 
+@app.command(help="Classify given message with a Router type of model.")
+def classify(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+) -> None:
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        transient=True,
+    ) as progress:
+        loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
+        model = RouterConfig.load_model(model_path)
+        progress.remove_task(loading_task)
+        warmup_task = progress.add_task("🔥 Warming up...")
+        model.classify_chat([UserMessage(content="warmup message")])
+        progress.remove_task(warmup_task)
+    console.print(f"🤖 Classifying input with [blue]{model_path}[/blue]:")
+    while True:
+        user_text = console.input("[cyan]user> [/cyan]")
+        user_message = UserMessage(user_text)
+
+        console.print("[red]assistant> [/red]", end="")
+        result = model.classify_chat([user_message])
+        for label, confidence in result.items():
+            console.print(f"{label} : {confidence}", end="")
+        console.print()
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
@@ -194,6 +231,12 @@ def convert(
             help="Overwrite existing model files.",
         ),
     ] = False,
+    message_for_trace: Annotated[
+        str | None,
+        Option(
+            help="Text message to use as prompt when recording trace",
+        ),
+    ] = None,
 ) -> None:
     if precision is not None:
         precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
@@ -224,6 +267,8 @@ def convert(
         console.print("Exiting...")
         raise Exit
 
+    message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
+
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
@@ -254,17 +299,7 @@
 
         if include_traces:
             trace_task = progress.add_task("🚁 Generating traces...")
-
-            num_tokens = 512
-            token_stride = 8
-            token_ids = jnp.arange(0, num_tokens, dtype=jnp.int32)[None, :]
-            token_positions = jnp.arange(0, num_tokens * token_stride, token_stride, dtype=jnp.int32)[None, :]
-            result = model.decoder(
-                token_ids,
-                token_positions,
-                return_updated_state=True,
-                return_activation_trace=True,
-            )
+            result = model.record_trace(message)
             traces = flatten_parameters(result.export())
             save_file(traces, output_dir / "traces.safetensors")
             progress.remove_task(trace_task)
@@ -350,6 +385,77 @@ speculator_app = Typer()
 app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")
 
 
+@speculator_app.command(help="Estimate maximum batch size at which a model can be run.")
+def estimate_batchsize(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    max_input_length: Annotated[
+        int,
+        Option(help="Max input length of a model."),
+    ] = 1024,
+    max_output_length: Annotated[
+        int,
+        Option(help="Max output length of a model."),
+    ] = 1024,
+    num_logits_per_token: Annotated[
+        int,
+        Option(help="Number of top logits that will be recorded."),
+    ] = 8,
+    vram_gb: Annotated[
+        int | None,
+        Option(
+            help="Maximum vram size in gb allowed.",
+            show_default="max on default device",
+        ),
+    ] = None,
+) -> None:
+    if vram_gb is not None:
+        mem = vram_gb * 1024 * 1024 * 1024
+    else:
+        memory_stats = jax.local_devices()[0].memory_stats()
+        if memory_stats is None:
+            err_console.print("Cannot get the default device's memory stats, use --vram-gb")
+            raise Exit(1)
+        if "bytes_limit" not in memory_stats:
+            err_console.print("Cannot get the default device's bytes limit, use --vram-gb")
+            raise Exit(1)
+        mem = memory_stats["bytes_limit"]
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        transient=True,
+    ) as progress:
+        loading_model_task = progress.add_task("[cyan]Loading model...[/cyan]")
+        model = LanguageModelConfig.load_model(model_path)
+        progress.remove_task(loading_model_task)
+
+        estimating_batchsize_task = progress.add_task("[cyan]Estimating batch size...[/cyan]")
+
+        def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
+            lo = str(event.lo)
+            hi = str(event.hi) if event.hi is not None else "?"
+            description = f"[cyan]Estimating batch size... ({lo}..{hi})[/cyan]"
+            progress.update(estimating_batchsize_task, description=description)
+
+        bs = estimate_batchsize_from_memory(
+            model,
+            max_input_length,
+            max_output_length,
+            num_logits_per_token,
+            mem,
+            progress_callback,
+        )
+        progress.remove_task(estimating_batchsize_task)
+
+    console.print(f"Found maximum batch size: [cyan]{bs}[/cyan]")
+
+
 @speculator_app.command(help="Run model inference and collect traces for speculator training")
 def collect_traces(
     model_path: Annotated[
@@ -406,7 +512,7 @@ def collect_traces(
     ) as progress:
         live.update(progress, refresh=True)
         loading_model_task = progress.add_task("🧠 [cyan]Loading model...[/cyan]")
-        model = LanguageModel.load(model_path)
+        model = LanguageModelConfig.load_model(model_path)
        progress.remove_task(loading_model_task)
 
         loading_dataset_task = progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
@@ -448,6 +554,49 @@
         progress.update(inference_task, description="✅ Completed")
 
 
+@speculator_app.command(help="View model inference traces")
+def view_traces(
+    trace_path: Annotated[
+        Path,
+        Argument(
+            help="File of inference traces to view.",
+            metavar="TRACE_PATH",
+        ),
+    ],
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory for detokenization.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    num_completions: Annotated[
+        int | None,
+        Option(
+            help="Number of completions to show.",
+        ),
+    ] = None,
+) -> None:
+    model = LanguageModelConfig.load_model(model_path)
+
+    with open(trace_path, "rb") as trace_fd:
+        traces = LalamoCompletion.deserialize_many(trace_fd)
+
+    table = Table(
+        show_lines=True,
+        box=box.ROUNDED,
+    )
+    table.add_column("Prefix")
+    table.add_column("Completion")
+
+    for completion in islice(traces, num_completions):
+        detokenized_prefix = model.message_processor.detokenize(completion.prefix_token_ids)
+        detokenized_completion = model.message_processor.detokenize(completion.completion_token_ids)
+        table.add_row(detokenized_prefix, detokenized_completion)
+
+    console.print(table)
+
+
 @speculator_app.command(help="Train a speculator from inference traces")
 def train(
     trace_path: Annotated[
@@ -535,7 +684,7 @@ def test(
         Option(help="Number of sequences to generate"),
     ] = 8,
 ) -> None:
-    model = LanguageModel.load(model_path)
+    model = LanguageModelConfig.load_model(model_path)
 
     with open(speculator_path, "rb") as fd:
         speculator = NGramSpeculator.deserialize(fd.read())
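
The new estimate_batchsize command wraps lalamo.speculator.estimator.estimate_batchsize_from_memory, which searches for the largest batch size that fits a byte budget (the lo..hi range surfaced through EstimateBatchsizeFromMemoryEvent suggests a bracketed search). A sketch of the same call made programmatically, assuming only the positional argument order visible in the CLI code above; the model path and the no-op callback are illustrative:

import jax

from lalamo.models import LanguageModelConfig
from lalamo.speculator.estimator import estimate_batchsize_from_memory

model = LanguageModelConfig.load_model("path/to/model")  # hypothetical path

# Same budget the CLI derives when --vram-gb is not given
# (the CLI above also guards the None / missing-key cases).
vram_bytes = jax.local_devices()[0].memory_stats()["bytes_limit"]

batch_size = estimate_batchsize_from_memory(
    model,
    1024,  # max_input_length
    1024,  # max_output_length
    8,     # num_logits_per_token
    vram_bytes,
    lambda _event: None,  # no-op progress callback
)
print(batch_size)
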
lalamo/message_processor.py CHANGED
@@ -156,13 +156,12 @@ class MessageProcessor:
             raise ValueError(f"Invalid response format: {response}")
         return AssistantMessage(**match.groupdict())
 
-    def tokenize(self, text: str) -> list[int]:
+    def tokenize_text(self, text: str) -> list[int]:
         return self.tokenizer.encode(text, add_special_tokens=False).ids
 
     def tokenize_request(self, messages: Iterable[Message]) -> list[int]:
         rendered = self.render_request(messages)
-        tokenized = self.tokenize(rendered)
-        return tokenized
+        return self.tokenize_text(rendered)
 
     def detokenize(self, tokens: list[int]) -> str:
         return self.tokenizer.decode(tokens, skip_special_tokens=False)
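
MessageProcessor.tokenize is renamed to tokenize_text, a small breaking change for any code that called it directly (tokenize_request and detokenize keep their names). A minimal sketch, assuming a converted model at a hypothetical path:

from lalamo.models import LanguageModelConfig

model = LanguageModelConfig.load_model("path/to/model")  # hypothetical path
processor = model.message_processor

token_ids = processor.tokenize_text("Hello!")  # was processor.tokenize(...) before 0.5.4
print(processor.detokenize(token_ids))
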
lalamo/model_import/common.py CHANGED
@@ -13,14 +13,16 @@ from jax import Array
 from jaxtyping import DTypeLike
 from tokenizers import Tokenizer
 
-from lalamo.language_model import GenerationConfig, LanguageModel, LanguageModelConfig
 from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
-from lalamo.model_import.model_specs.common import JSONFieldSpec
+from lalamo.models import GenerationConfig, LanguageModel, LanguageModelConfig, Router, RouterConfig
+from lalamo.modules import Classifier, Decoder, LalamoModule
 from lalamo.quantization import QuantizationMode
 
+from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
 from .huggingface_generation_config import HFGenerationConfig
 from .huggingface_tokenizer_config import HFTokenizerConfig
-from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, UseCase
+from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, ModelType, UseCase
+from .model_specs.common import JSONFieldSpec
 
 __all__ = [
     "REPO_TO_MODEL",
@@ -29,6 +31,7 @@ __all__ = [
     "InitializingModelEvent",
     "ModelMetadata",
     "ModelSpec",
+    "ModelType",
     "StatusEvent",
     "import_model",
 ]
@@ -68,7 +71,8 @@ class ModelMetadata:
     quantization: QuantizationMode | None
     repo: str
     use_cases: tuple[UseCase, ...]
-    model_config: LanguageModelConfig
+    model_type: ModelType
+    model_config: LanguageModelConfig | RouterConfig
 
 
 def download_file(
@@ -114,7 +118,7 @@
 
 
 class ImportResults(NamedTuple):
-    model: LanguageModel
+    model: LanguageModel | Router
     metadata: ModelMetadata
 
 
@@ -166,26 +170,14 @@ def import_message_processor(
     return MessageProcessor(config=message_processor_config, tokenizer=tokenizer)
 
 
-def import_model(
-    model_spec: ModelSpec | str,
-    *,
+def _load_main_processing_module(
+    model_spec: ModelSpec,
+    precision: DTypeLike,
+    foreign_config: ForeignConfig,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
     context_length: int | None = None,
-    precision: DTypeLike | None = None,
     accumulation_precision: DTypeLike = jnp.float32,
-    progress_callback: Callable[[StatusEvent], None] | None = None,
-) -> ImportResults:
-    if isinstance(model_spec, str):
-        try:
-            model_spec = REPO_TO_MODEL[model_spec]
-        except KeyError as e:
-            raise ValueError(f"Unknown model: {model_spec}") from e
-
-    foreign_decoder_config_file = download_config_file(model_spec)
-    foreign_decoder_config = model_spec.config_type.from_json(foreign_decoder_config_file)
-
-    if precision is None:
-        precision = foreign_decoder_config.default_precision
-
+) -> LalamoModule:
     weights_paths = download_weights(model_spec, progress_callback=progress_callback)
     with ExitStack() as stack:
         weights_shards = []
@@ -200,7 +192,7 @@
         if progress_callback is not None:
             progress_callback(InitializingModelEvent())
 
-        decoder = foreign_decoder_config.load_decoder(
+        processing_module = foreign_config.load(
             context_length,
             precision,
             accumulation_precision,
@@ -208,6 +200,33 @@
             metadata_dict,
         )
 
+    return processing_module
+
+
+def _import_language_model(
+    model_spec: ModelSpec,
+    *,
+    context_length: int | None = None,
+    precision: DTypeLike | None = None,
+    accumulation_precision: DTypeLike = jnp.float32,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> tuple[LanguageModel, LanguageModelConfig]:
+    foreign_decoder_config_file = download_config_file(model_spec)
+    foreign_decoder_config = model_spec.config_type.from_json(foreign_decoder_config_file)
+    assert isinstance(foreign_decoder_config, ForeignLMConfig)
+
+    if precision is None:
+        precision = foreign_decoder_config.default_precision
+    decoder = _load_main_processing_module(
+        model_spec,
+        precision,
+        foreign_decoder_config,
+        progress_callback,
+        context_length,
+        accumulation_precision,
+    )
+    assert isinstance(decoder, Decoder)
+
     if progress_callback is not None:
         progress_callback(FinishedInitializingModelEvent())
 
@@ -235,12 +254,85 @@
     )
 
     language_model_config = LanguageModelConfig(
-        decoder_config=decoder.config,
+        model_config=decoder.config,
         message_processor_config=message_processor.config,
         generation_config=generation_config,
     )
 
     language_model = LanguageModel(language_model_config, decoder, message_processor)
+    return language_model, language_model_config
+
+
+def _import_router(
+    model_spec: ModelSpec,
+    *,
+    context_length: int | None = None,
+    precision: DTypeLike | None = None,
+    accumulation_precision: DTypeLike = jnp.float32,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> tuple[Router, RouterConfig]:
+    foreign_classifier_config_file = download_config_file(model_spec)
+    foreign_classifier_config = model_spec.config_type.from_json(foreign_classifier_config_file)
+    assert isinstance(foreign_classifier_config, ForeignClassifierConfig)
+
+    if precision is None:
+        precision = foreign_classifier_config.default_precision
+
+    classifier = _load_main_processing_module(
+        model_spec,
+        precision,
+        foreign_classifier_config,
+        progress_callback,
+        context_length,
+        accumulation_precision,
+    )
+    assert isinstance(classifier, Classifier)
+
+    if progress_callback is not None:
+        progress_callback(FinishedInitializingModelEvent())
+
+    message_processor = import_message_processor(model_spec)
+
+    router_config = RouterConfig(
+        model_config=classifier.config,
+        message_processor_config=message_processor.config,
+    )
+    router_model = Router(router_config, classifier, message_processor)
+    return router_model, router_config
+
+
+def import_model(
+    model_spec: ModelSpec | str,
+    *,
+    context_length: int | None = None,
+    precision: DTypeLike | None = None,
+    accumulation_precision: DTypeLike = jnp.float32,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> ImportResults:
+    if isinstance(model_spec, str):
+        try:
+            model_spec = REPO_TO_MODEL[model_spec]
+        except KeyError as e:
+            raise ValueError(f"Unknown model: {model_spec}") from e
+
+    match model_spec.model_type:
+        case ModelType.LANGUAGE_MODEL:
+            model, config = _import_language_model(
+                model_spec,
+                context_length=context_length,
+                precision=precision,
+                accumulation_precision=accumulation_precision,
+                progress_callback=progress_callback,
+            )
+        case ModelType.ROUTER_MODEL:
+            model, config = _import_router(
+                model_spec,
+                context_length=context_length,
+                precision=precision,
+                accumulation_precision=accumulation_precision,
+                progress_callback=progress_callback,
+            )
+
     metadata = ModelMetadata(
         toolchain_version=LALAMO_VERSION,
         vendor=model_spec.vendor,
@@ -250,6 +342,7 @@
         quantization=model_spec.quantization,
         repo=model_spec.repo,
         use_cases=model_spec.use_cases,
-        model_config=language_model_config,
+        model_type=model_spec.model_type,
+        model_config=config,
     )
-    return ImportResults(language_model, metadata)
+    return ImportResults(model, metadata)
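
import_model now dispatches on ModelSpec.model_type: LANGUAGE_MODEL specs go through _import_language_model and yield a LanguageModel, while ROUTER_MODEL specs go through _import_router and yield a Router wrapping a Classifier. Callers that previously assumed ImportResults.model was always a LanguageModel now need a type check. A hedged sketch; the repo id and prompt text are hypothetical:

from lalamo.message_processor import UserMessage
from lalamo.model_import import import_model
from lalamo.models import Router

model, metadata = import_model("vendor/router-repo")  # hypothetical repo id
if isinstance(model, Router):
    # classify_chat returns a label -> confidence mapping, per the CLI `classify` command above.
    scores = model.classify_chat([UserMessage("Which team handles refunds?")])
    print(max(scores, key=scores.get))
else:
    print(metadata.model_type)
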
lalamo/model_import/decoder_configs/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from .common import ForeignConfig
+from .common import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
 
 # from .executorch import ETLlamaConfig
 from .huggingface import (
@@ -14,8 +14,10 @@ from .huggingface import (
 )
 
 __all__ = [
-    # "ETLlamaConfig",
+    "ForeignClassifierConfig",
     "ForeignConfig",
+    # "ETLlamaConfig",
+    "ForeignLMConfig",
     "HFGPTOssConfig",
     "HFGemma2Config",
     "HFGemma3Config",
lalamo/model_import/decoder_configs/common.py CHANGED
@@ -8,21 +8,18 @@ from typing import ClassVar, Self
 import cattrs
 from jaxtyping import Array, DTypeLike
 
-from lalamo.modules import Decoder, DecoderConfig
+from lalamo.modules import ClassifierConfig, DecoderConfig
+from lalamo.modules.common import LalamoModule
 from lalamo.registry_abc import RegistryABC
 
-__all__ = ["ForeignConfig"]
+__all__ = ["ForeignClassifierConfig", "ForeignLMConfig"]
 
 
 @dataclass(frozen=True)
-class ForeignConfig(RegistryABC):
+class ForeignConfig[ConfigT: DecoderConfig | ClassifierConfig](RegistryABC):
     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
     _converter.register_structure_hook(int | list[int], lambda v, _: v)
 
-    @property
-    def eos_token_ids(self) -> list[int]:
-        raise NotImplementedError
-
     @property
     @abstractmethod
     def default_precision(self) -> DTypeLike: ...
@@ -34,31 +31,75 @@ class ForeignConfig(RegistryABC):
         config = json.load(f)
         return cls._converter.structure(config, cls)
 
-    def to_decoder_config(
+    @abstractmethod
+    def _load_weights(
+        self,
+        model: LalamoModule,
+        weights_dict: Mapping[str, Array],
+    ) -> LalamoModule: ...
+
+    @abstractmethod
+    def to_lalamo_config(
         self,
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
         metadata_dict: Mapping[str, str],
-    ) -> DecoderConfig:
-        raise NotImplementedError
-
-    @classmethod
-    def _load_weights(
-        cls,
-        model: Decoder,
-        weights_dict: Mapping[str, Array],
-    ) -> Decoder:
-        raise NotImplementedError
+    ) -> ConfigT: ...
 
-    def load_decoder(
+    def load(
         self,
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
         weights_dict: Mapping[str, Array],
         metadata_dict: Mapping[str, str],
-    ) -> Decoder:
-        config = self.to_decoder_config(context_length, activation_precision, accumulation_precision, metadata_dict)
+    ) -> LalamoModule[ConfigT]:
+        config = self.to_lalamo_config(context_length, activation_precision, accumulation_precision, metadata_dict)
         model = config.empty()
         return self._load_weights(model, weights_dict)
+
+
+@dataclass(frozen=True)
+class ForeignLMConfig(ForeignConfig, RegistryABC):
+    @abstractmethod
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
+    ) -> DecoderConfig: ...
+
+    @property
+    @abstractmethod
+    def eos_token_ids(self) -> list[int]: ...
+
+    def to_lalamo_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
+    ) -> DecoderConfig:
+        return self.to_decoder_config(context_length, activation_precision, accumulation_precision, metadata_dict)
+
+
+@dataclass(frozen=True)
+class ForeignClassifierConfig(ForeignConfig, RegistryABC):
+    @abstractmethod
+    def to_classifier_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> ClassifierConfig: ...
+
+    def to_lalamo_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
+    ) -> ClassifierConfig:
+        return self.to_classifier_config(context_length, activation_precision, accumulation_precision)
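
ForeignConfig is now generic over the lalamo config it produces and split into ForeignLMConfig and ForeignClassifierConfig; _load_weights also changes from a classmethod over Decoder to an instance method over any LalamoModule. A hedged sketch of the skeleton a new classifier import has to fill in, based only on the abstract methods above (the HFModernBertConfig added in this release presumably follows this shape, but the class name, field, and bodies below are illustrative, not the actual implementation):

from collections.abc import Mapping
from dataclasses import dataclass

import jax.numpy as jnp
from jaxtyping import Array, DTypeLike

from lalamo.model_import.decoder_configs import ForeignClassifierConfig
from lalamo.modules import ClassifierConfig
from lalamo.modules.common import LalamoModule


@dataclass(frozen=True)
class HFToyClassifierConfig(ForeignClassifierConfig):  # illustrative name
    hidden_size: int = 768  # illustrative field, parsed from the foreign config.json

    @property
    def default_precision(self) -> DTypeLike:
        return jnp.float32

    def to_classifier_config(
        self,
        context_length: int | None,
        activation_precision: DTypeLike,
        accumulation_precision: DTypeLike,
    ) -> ClassifierConfig:
        # Map the foreign config fields onto a lalamo ClassifierConfig here.
        raise NotImplementedError

    def _load_weights(
        self,
        model: LalamoModule,
        weights_dict: Mapping[str, Array],
    ) -> LalamoModule:
        # Rename/reshape the foreign tensors into the empty module built by load().
        raise NotImplementedError
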