lalamo 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. lalamo/__init__.py +15 -2
  2. lalamo/data/__init__.py +0 -1
  3. lalamo/data/huggingface_message.py +1 -0
  4. lalamo/main.py +167 -18
  5. lalamo/message_processor.py +2 -3
  6. lalamo/model_import/common.py +120 -27
  7. lalamo/model_import/decoder_configs/__init__.py +4 -2
  8. lalamo/model_import/decoder_configs/common.py +62 -21
  9. lalamo/model_import/decoder_configs/executorch.py +14 -9
  10. lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
  11. lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
  12. lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
  13. lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
  14. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
  15. lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
  16. lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
  17. lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
  18. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
  21. lalamo/model_import/loaders/__init__.py +3 -2
  22. lalamo/model_import/loaders/executorch.py +24 -12
  23. lalamo/model_import/loaders/huggingface.py +258 -30
  24. lalamo/model_import/model_specs/__init__.py +4 -2
  25. lalamo/model_import/model_specs/common.py +8 -2
  26. lalamo/model_import/model_specs/gemma.py +5 -1
  27. lalamo/model_import/model_specs/huggingface.py +1 -1
  28. lalamo/model_import/model_specs/mirai.py +20 -0
  29. lalamo/models/__init__.py +10 -0
  30. lalamo/models/common.py +81 -0
  31. lalamo/{language_model.py → models/language_model.py} +32 -49
  32. lalamo/models/router.py +59 -0
  33. lalamo/modules/__init__.py +33 -16
  34. lalamo/modules/classifier.py +339 -0
  35. lalamo/modules/common.py +6 -3
  36. lalamo/modules/decoder.py +52 -180
  37. lalamo/modules/mlp.py +28 -5
  38. lalamo/modules/normalization.py +13 -8
  39. lalamo/modules/token_mixers/attention.py +10 -6
  40. lalamo/modules/token_mixers/state/kv_cache.py +14 -4
  41. lalamo/modules/transformer.py +273 -0
  42. lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
  43. lalamo/speculator/__init__.py +6 -2
  44. lalamo/speculator/estimator.py +91 -0
  45. lalamo/speculator/inference.py +28 -9
  46. lalamo/speculator/ngram.py +7 -3
  47. lalamo/speculator/utils.py +4 -2
  48. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/METADATA +1 -1
  49. lalamo-0.5.4.dist-info/RECORD +88 -0
  50. lalamo-0.5.2.dist-info/RECORD +0 -80
  51. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/WHEEL +0 -0
  52. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/entry_points.txt +0 -0
  53. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/licenses/LICENSE +0 -0
  54. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/top_level.txt +0 -0
lalamo/model_import/loaders/huggingface.py

@@ -9,7 +9,6 @@ from lalamo.common import ParameterPath
 from lalamo.modules import (
     Attention,
     Decoder,
-    DecoderLayer,
     DenseMLP,
     FullPrecisionLinear,
     GroupQuantizedLinear,
@@ -18,18 +17,20 @@ from lalamo.modules import (
     MLXQuantizedLinear,
     MLXQuantizedTiedEmbedding,
     MLXSemiQuantizedUntiedEmbedding,
-    RMSNorm,
+    Normalization,
     SeparableCausalConv,
     TiedEmbedding,
+    TransformerLayer,
     UntiedEmbedding,
 )
+from lalamo.modules.classifier import Classifier
 from lalamo.modules.mlp import MixtureOfExperts, MLPBase
 from lalamo.quantization import QuantizationMode
 
 from .common import load_parameters
 from .utils import decode_mxfp4, deinterleave_pairwise_columns
 
-__all__ = ["load_huggingface"]
+__all__ = ["load_huggingface_decoder"]
 
 
 AWQ_UINT4_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)
@@ -42,15 +43,20 @@ def _reverse_uint4_order(array: Array, reverse_order: Array) -> Array:
     if last_dim % pack_factor != 0:
         return array
 
-    array_reshaped = rearrange(array, "... (group pack_factor) -> ... group pack_factor", pack_factor=pack_factor)
+    array_reshaped = rearrange(
+        array,
+        "... (group pack_factor) -> ... group pack_factor",
+        pack_factor=pack_factor,
+    )
     array_reordered = array_reshaped[..., reverse_order]
     return rearrange(array_reordered, "... group pack_factor -> ... (group pack_factor)")
 
 
 def unpack_int32(packed_weights: Array, mode: QuantizationMode) -> Array:
-    assert packed_weights.dtype in (jnp.int32, jnp.uint32), (
-        f"Expected packed_weights to be of dtype jnp.(u)int32, got {packed_weights.dtype}"
-    )
+    assert packed_weights.dtype in (
+        jnp.int32,
+        jnp.uint32,
+    ), f"Expected packed_weights to be of dtype jnp.(u)int32, got {packed_weights.dtype}"
     assert 32 % mode.bits == 0
 
     shifts = jnp.arange(0, 32, mode.bits)
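The `shifts = jnp.arange(0, 32, mode.bits)` line above is the heart of the packed-weight decoding: each 32-bit word holds `32 // bits` quantized values, which are extracted by shifting every lane down into the low bits and masking. A minimal standalone sketch of the same idea for 4-bit packing (illustrative only; `unpack_uint4_words` is a hypothetical helper, not part of the package):

    import jax.numpy as jnp

    def unpack_uint4_words(packed):
        # Each uint32 word packs eight 4-bit values: shift every lane
        # into the low bits, then mask off everything else.
        shifts = jnp.arange(0, 32, 4)  # [0, 4, ..., 28]
        lanes = (packed[..., None] >> shifts) & 0xF
        return lanes.reshape(*packed.shape[:-1], -1)

    words = jnp.array([0x76543210], dtype=jnp.uint32)
    print(unpack_uint4_words(words))  # [0 1 2 3 4 5 6 7]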
@@ -309,7 +315,14 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
         )
     else:
         # Fallback: recursively load a standard DenseMLP experts module
-        experts = load_mlp(module.experts, weights_dict, experts_path, "up_proj", "gate_proj", "down_proj")
+        experts = load_mlp(
+            module.experts,
+            weights_dict,
+            experts_path,
+            "up_proj",
+            "gate_proj",
+            "down_proj",
+        )
 
     return load_parameters(
         lambda m: (m.router, m.experts),
@@ -319,10 +332,10 @@
 
 
 def load_rmsnorm(
-    module: RMSNorm,
+    module: Normalization,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
-) -> RMSNorm:
+) -> Normalization:
     scales = weights_dict[path / "weight"]
     return load_parameters(lambda m: (m.scales,), module, (scales,))
 
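Every loader in this file funnels through the same `load_parameters(selector, module, values)` call: modules are immutable pytrees, so loading a tensor means building a new module with the selected leaves replaced rather than mutating in place. A rough equivalent, on the assumption that lalamo modules behave like Equinox pytrees (the real `load_parameters` lives in `.common` and may do extra validation):

    import equinox as eqx

    def load_parameters(selector, module, new_values):
        # Return a copy of `module` with the leaves picked out by
        # `selector` swapped for `new_values`; the original is untouched.
        return eqx.tree_at(selector, module, new_values)

Usage then reads exactly like the diff: `load_parameters(lambda m: (m.scales,), norm_module, (scales,))`.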
@@ -357,7 +370,13 @@ def load_attention(
         sinks = module.sinks
 
     return load_parameters(
-        lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm, m.sinks),
+        lambda m: (
+            m.qkv_projection,
+            m.out_projection,
+            m.query_norm,
+            m.key_norm,
+            m.sinks,
+        ),
         module,
         (qkv_projection, out_projection, query_norm, key_norm, sinks),
     )
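The `qkv_projection` being loaded here is a single fused projection; when a checkpoint ships separate query, key and value matrices, fusing amounts to concatenating them along the output axis so one matmul produces all three (a sketch with assumed shapes; in the package this path is handled by `load_linear` and its `sublayers_to_fuse` argument):

    import jax.numpy as jnp

    # Hypothetical shapes, laid out as (out_features, in_features).
    w_q = jnp.zeros((2048, 1024))
    w_k = jnp.zeros((512, 1024))
    w_v = jnp.zeros((512, 1024))

    # One fused weight; the attention module later splits the output
    # back into q, k and v slices after the single matmul.
    w_qkv = jnp.concatenate([w_q, w_k, w_v], axis=0)  # (3072, 1024)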
@@ -420,14 +439,20 @@ def load_mamba2(
     gate_bias = module.gate_bias
 
     return load_parameters(
-        lambda m: (m.in_projection, m.out_projection, m.conv, m.skip_connection_weight, m.gate_bias),
+        lambda m: (
+            m.in_projection,
+            m.out_projection,
+            m.conv,
+            m.skip_connection_weight,
+            m.gate_bias,
+        ),
         module,
         (in_projection, out_projection, conv, skip_connection_weight, gate_bias),
     )
 
 
-def load_decoder_layer(
-    module: DecoderLayer,
+def load_transformer_layer(
+    module: TransformerLayer,
     weights_dict: Mapping[str, Array],
     mixer_path: ParameterPath,
     mlp_path: ParameterPath,
@@ -438,13 +463,16 @@
     up_proj_key: str,
     gate_proj_key: str,
     down_proj_key: str,
-) -> DecoderLayer:
-    pre_attention_norm = load_rmsnorm(
-        module.pre_mixer_norm,
-        weights_dict,
-        mixer_path / pre_mixer_norm_key,
-    )
+) -> TransformerLayer:
+    if module.pre_mixer_norm is not None:
+        pre_attention_norm = load_rmsnorm(
+            module.pre_mixer_norm,
+            weights_dict,
+            mixer_path / pre_mixer_norm_key,
+        )
 
+    else:
+        pre_attention_norm = None
     # Load mixer (attention or mamba)
     if isinstance(module.mixer, Attention):
         mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
@@ -474,7 +502,14 @@
         mlp_path / pre_mlp_norm_key,
     )
 
-    mlp = load_mlp(module.mlp, weights_dict, mlp_path / mlp_key, up_proj_key, gate_proj_key, down_proj_key)
+    mlp = load_mlp(
+        module.mlp,
+        weights_dict,
+        mlp_path / mlp_key,
+        up_proj_key,
+        gate_proj_key,
+        down_proj_key,
+    )
 
     if module.post_mlp_norm is not None:
         post_mlp_norm = load_rmsnorm(
@@ -486,9 +521,23 @@
         post_mlp_norm = None
 
     return load_parameters(
-        lambda m: (m.pre_mixer_norm, m.mixer, m.post_mixer_norm, m.pre_mlp_norm, m.mlp, m.post_mlp_norm),
+        lambda m: (
+            m.pre_mixer_norm,
+            m.mixer,
+            m.post_mixer_norm,
+            m.pre_mlp_norm,
+            m.mlp,
+            m.post_mlp_norm,
+        ),
         module,
-        (pre_attention_norm, mixer, post_attention_norm, pre_mlp_norm, mlp, post_mlp_norm),
+        (
+            pre_attention_norm,
+            mixer,
+            post_attention_norm,
+            pre_mlp_norm,
+            mlp,
+            post_mlp_norm,
+        ),
     )
 
 
@@ -558,10 +607,14 @@
 ) -> UntiedEmbedding:
     input_weights = weights_dict[embedding_path / "weight"]
     output_weights = weights_dict[lm_head_path / "weight"]
-    return load_parameters(lambda m: (m.input_weights, m.output_weights), module, (input_weights, output_weights))
+    return load_parameters(
+        lambda m: (m.input_weights, m.output_weights),
+        module,
+        (input_weights, output_weights),
+    )
 
 
-def load_huggingface(
+def load_huggingface_decoder(
     module: Decoder,
     weights_dict: Mapping[str, Array],
 ) -> Decoder:
@@ -629,7 +682,7 @@
         raise TypeError(f"Unsupported embedding type: {type(module.embedding)}")
 
     decoder_layers = tuple(
-        load_decoder_layer(
+        load_transformer_layer(
             layer,
             weights_dict,
             decoder_path / "layers" / ((i * 2) if alternating_layers else i),
@@ -642,12 +695,187 @@
             gate_proj_key,
             down_proj_key,
         )
-        for i, layer in enumerate(module.layers)
+        for i, layer in enumerate(module.transformer.layers)
     )
-
-    output_norm = load_rmsnorm(module.output_norm, weights_dict, decoder_path / norm_key)
+    output_norm = load_rmsnorm(module.transformer.output_norm, weights_dict, decoder_path / norm_key)
     return load_parameters(
-        lambda m: (m.embedding, m.layers, m.output_norm),
+        lambda m: (m.embedding, m.transformer.layers, m.transformer.output_norm),
         module,
         (embedding, decoder_layers, output_norm),
     )
+
+
+def load_huggingface_classifier(
+    module: Classifier,
+    weights_dict: Mapping[str, Array],
+) -> Classifier:
+    def load_tied_embedding_local(
+        module: TiedEmbedding,
+        weights_dict: Mapping[str, Array],
+        decoder_path: ParameterPath,
+    ) -> TiedEmbedding:
+        input_weights = weights_dict[decoder_path / "embeddings" / "tok_embeddings" / "weight"]
+        return load_parameters(lambda m: (m.weights,), module, (input_weights,))
+
+    def load_linear_with_reshufling(
+        module: LinearBase,
+        weights_dict: Mapping[str, Array],
+        path: ParameterPath,
+    ) -> LinearBase:
+ """Loads a linear layer and reshufle some weights in resulting matrix to meet
726
+ requirements of downstream 'split' in MLP layer in attention."""
+
+        assert not module.has_biases, "Expecting no biases in FullPrecisionLinear"
+        assert isinstance(module, FullPrecisionLinear), "Expecting FullPrecisionLinear module as input"
+
+        weights = weights_dict[path / "weight"]
+        rows, _ = weights.shape
+        shuffled_weights = jnp.vstack((weights[rows // 2 :, :], weights[: rows // 2, :]))
+        return load_parameters(lambda m: (m.weights, m.biases), module, (shuffled_weights, None))
+
+    def load_attention_local(
+        module: Attention,
+        weights_dict: Mapping[str, Array],
+        path: ParameterPath,
+    ) -> Attention:
+        qkv_projection = load_linear(
+            module.qkv_projection,
+            weights_dict,
+            path / "Wqkv",
+            sublayers_to_fuse=None,
+        )
+        out_projection = load_linear(module.out_projection, weights_dict, path / "Wo")
+
+        if module.query_norm is not None:
+            query_norm = load_rmsnorm(module.query_norm, weights_dict, path / "q_norm")
+        else:
+            query_norm = None
+
+        if module.key_norm is not None:
+            key_norm = load_rmsnorm(module.key_norm, weights_dict, path / "k_norm")
+        else:
+            key_norm = None
+
+        return load_parameters(
+            lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm),
+            module,
+            (qkv_projection, out_projection, query_norm, key_norm),
+        )
+
+    def load_mlp_local(module: MLPBase, weights_dict: Mapping[str, Array], path: ParameterPath) -> MLPBase:
+        assert isinstance(module, DenseMLP)
+        up_projection = load_linear_with_reshufling(
+            module.up_projection,
+            weights_dict,
+            path / "Wi",
+        )
+        down_projection = load_linear(module.down_projection, weights_dict, path / "Wo")
+        return load_parameters(
+            lambda m: (m.up_projection, m.down_projection),
+            module,
+            (up_projection, down_projection),
+        )
+
+    def load_transformer_layer_local(
+        module: TransformerLayer,
+        weights_dict: Mapping[str, Array],
+        path: ParameterPath,
+    ) -> TransformerLayer:
+        if module.pre_mixer_norm is not None:
+            pre_attention_norm = load_rmsnorm(
+                module.pre_mixer_norm,
+                weights_dict,
+                path / "attn_norm",
+            )
+        else:
+            pre_attention_norm = None
+
+        assert isinstance(module.mixer, Attention)
+        attention = load_attention_local(module.mixer, weights_dict, path / "attn")
+        if module.post_mixer_norm is not None:
+            post_attention_norm = load_rmsnorm(
+                module.post_mixer_norm,
+                weights_dict,
+                path / "post_attention_layernorm",
+            )
+
+            pre_mlp_norm = load_rmsnorm(
+                module.pre_mlp_norm,
+                weights_dict,
+                path / "pre_feedforward_layernorm",
+            )
+        else:
+            post_attention_norm = None
+
+            pre_mlp_norm = load_rmsnorm(
+                module.pre_mlp_norm,
+                weights_dict,
+                path / "mlp_norm",
+            )
+
+        mlp = load_mlp_local(module.mlp, weights_dict, path / "mlp")
+        if module.post_mlp_norm is not None:
+            post_mlp_norm = load_rmsnorm(
+                module.post_mlp_norm,
+                weights_dict,
+                path / "post_feedforward_layernorm",
+            )
+        else:
+            post_mlp_norm = None
+        return load_parameters(
+            lambda m: (
+                m.pre_mixer_norm,
+                m.mixer,
+                m.post_mixer_norm,
+                m.pre_mlp_norm,
+                m.mlp,
+                m.post_mlp_norm,
+            ),
+            module,
+            (
+                pre_attention_norm,
+                attention,
+                post_attention_norm,
+                pre_mlp_norm,
+                mlp,
+                post_mlp_norm,
+            ),
+        )
+
+    base_path = ParameterPath()
+    decoder_path = base_path / "model"
+    head_path = base_path / "head"
+    classifier_path = base_path / "classifier"
+    assert isinstance(module.embedding, TiedEmbedding)
+    embedding = load_tied_embedding_local(module.embedding, weights_dict, decoder_path)
+    embedding_norm = load_rmsnorm(module.embedding_norm, weights_dict, base_path / "model" / "embeddings" / "norm")
+
+    decoder_layers = tuple(
+        load_transformer_layer_local(layer, weights_dict, decoder_path / "layers" / i)
+        for i, layer in enumerate(module.transformer.layers)
+    )
+    output_norm = load_rmsnorm(module.transformer.output_norm, weights_dict, decoder_path / "final_norm")
+    head_dense = load_linear(module.prediction_head.dense, weights_dict, head_path / "dense")
+    head_norm = load_rmsnorm(module.prediction_head.norm, weights_dict, head_path / "norm")
+    head_readout = load_linear(module.prediction_head.readout, weights_dict, classifier_path)
+    return load_parameters(
+        lambda m: (
+            m.embedding,
+            m.embedding_norm,
+            m.transformer.layers,
+            m.transformer.output_norm,
+            m.prediction_head.dense,
+            m.prediction_head.norm,
+            m.prediction_head.readout,
+        ),
+        module,
+        (
+            embedding,
+            embedding_norm,
+            decoder_layers,
+            output_norm,
+            head_dense,
+            head_norm,
+            head_readout,
+        ),
+    )
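The only non-mechanical step above is `load_linear_with_reshufling`: per its docstring, the fused `Wi` matrix is stored with its row halves in the opposite order from the 'split' that the MLP applies downstream, so the loader swaps the top and bottom row blocks. The row swap in isolation, with toy shapes for illustration:

    import jax.numpy as jnp

    weights = jnp.arange(8).reshape(4, 2)  # stand-in for the fused Wi matrix
    rows, _ = weights.shape

    # Stack the bottom half on top of the top half.
    shuffled = jnp.vstack((weights[rows // 2:, :], weights[:rows // 2, :]))
    print(shuffled[:, 0])  # [4 6 0 2]

Applying the swap twice returns the original matrix, so the transform is its own inverse.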
lalamo/model_import/model_specs/__init__.py

@@ -1,10 +1,11 @@
-from .common import FileSpec, ModelSpec, UseCase, build_quantized_models
+from .common import FileSpec, ModelSpec, ModelType, UseCase, build_quantized_models
 from .deepseek import DEEPSEEK_MODELS
 from .gemma import GEMMA_MODELS
 from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
 from .llama import LLAMA_MODELS
 from .llamba import LLAMBA_MODELS
+from .mirai import MIRAI_ROUTER_MODELS
 from .mistral import MISTRAL_MODELS
 
 # from .pleias import PLEIAS_MODELS
@@ -17,6 +18,7 @@ __all__ = [
     "REPO_TO_MODEL",
     "FileSpec",
     "ModelSpec",
+    "ModelType",
     "UseCase",
 ]
 
@@ -33,9 +35,9 @@ ALL_MODEL_LISTS = [
     POLARIS_MODELS,
     QWEN_MODELS,
     REKA_MODELS,
+    MIRAI_ROUTER_MODELS,
 ]
 
-
 ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]
lalamo/model_import/model_specs/common.py

@@ -5,7 +5,7 @@ from collections.abc import (
 )
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from enum import Enum
+from enum import Enum, StrEnum
 from pathlib import Path
 from typing import ClassVar, cast, get_args, get_origin
 
@@ -22,6 +22,7 @@ __all__ = [
     "FileSpec",
     "JSONFieldSpec",
     "ModelSpec",
+    "ModelType",
     "UseCase",
     "WeightsType",
     "awq_model_spec",
@@ -29,6 +30,11 @@ __all__ = [
 ]
 
 
+class ModelType(StrEnum):
+    LANGUAGE_MODEL = "language_model"
+    ROUTER_MODEL = "router_model"
+
+
 def cast_if_float(array: Array, cast_to: DTypeLike) -> Array:
     if array.dtype in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]:
         return array.astype(cast_to)
@@ -50,7 +56,6 @@ class WeightsType(Enum):
             yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
-
             from lalamo.modules.torch_interop import torch_to_jax
 
             torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
@@ -129,6 +134,7 @@ class ModelSpec:
     assistant_role_name: str = "assistant"
     tool_role_name: str = "tool"
     weights_type: WeightsType = WeightsType.SAFETENSORS
+    model_type: ModelType = ModelType.LANGUAGE_MODEL
     configs: ConfigMap = field(default=ConfigMap())
     use_cases: tuple[UseCase, ...] = tuple()
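Because ModelType derives from StrEnum, its members are strings: a spec can construct one from the raw value (as the new mirai.py entry below does with ModelType("router_model")), and code can compare members directly against plain strings. A quick sketch of the semantics:

    from enum import StrEnum

    class ModelType(StrEnum):
        LANGUAGE_MODEL = "language_model"
        ROUTER_MODEL = "router_model"

    assert ModelType("router_model") is ModelType.ROUTER_MODEL
    assert ModelType.ROUTER_MODEL == "router_model"  # StrEnum members are str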
lalamo/model_import/model_specs/gemma.py

@@ -1,4 +1,8 @@
-from lalamo.model_import.decoder_configs import HFGemma2Config, HFGemma3Config, HFGemma3TextConfig
+from lalamo.model_import.decoder_configs import (
+    HFGemma2Config,
+    HFGemma3Config,
+    HFGemma3TextConfig,
+)
 
 from .common import ModelSpec, WeightsType
 
lalamo/model_import/model_specs/huggingface.py

@@ -14,5 +14,5 @@ HUGGINGFACE_MODELS = [
         repo="HuggingFaceTB/SmolLM2-1.7B-Instruct",
         config_type=HFLlamaConfig,
         use_cases=tuple(),
-    ),
+    )
 ]
lalamo/model_import/model_specs/mirai.py (new file)

@@ -0,0 +1,20 @@
+from lalamo.model_import.decoder_configs.huggingface import ModernBERTConfig
+
+from .common import ConfigMap, FileSpec, ModelSpec, ModelType
+
+__all__ = ["MIRAI_ROUTER_MODELS"]
+
+MIRAI_ROUTER_MODELS = [
+    ModelSpec(
+        vendor="trymirai",
+        family="ModernBERT",
+        name="ModernBERT-Chat-Moderation",
+        size="0.15B",
+        quantization=None,
+        repo="trymirai/chat-moderation-router",
+        config_type=ModernBERTConfig,
+        use_cases=tuple(),
+        model_type=ModelType("router_model"),
+        configs=ConfigMap(chat_template=FileSpec("chat_template.jinja")),
+    ),
+]
lalamo/models/__init__.py (new file)

@@ -0,0 +1,10 @@
+from .language_model import GenerationConfig, LanguageModel, LanguageModelConfig
+from .router import Router, RouterConfig
+
+__all__ = [
+    "GenerationConfig",
+    "LanguageModel",
+    "LanguageModelConfig",
+    "Router",
+    "RouterConfig",
+]
lalamo/models/common.py (new file)

@@ -0,0 +1,81 @@
+import json
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from dataclasses import dataclass, replace
+from pathlib import Path
+from typing import Self
+
+import equinox as eqx
+from jax import Array
+from jax import numpy as jnp
+from tokenizers import Tokenizer
+
+from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
+from lalamo.message_processor import Message, MessageProcessor, MessageProcessorConfig, UserMessage
+from lalamo.modules import Classifier, Decoder, LalamoModule, config_converter
+from lalamo.modules.classifier import ClassifierConfig, ClassifierResult
+from lalamo.modules.decoder import DecoderConfig, DecoderResult
+from lalamo.utils import open_safetensors
+
+__all__ = [
+    "TextModel",
+    "TextModelConfig",
+]
+
+
+@dataclass(frozen=True)
+class TextModelConfig[ConfigT: ClassifierConfig | DecoderConfig](ABC):
+    model_config: ConfigT
+    message_processor_config: MessageProcessorConfig
+
+    @abstractmethod
+    def init(
+        self,
+        model: LalamoModule,
+        message_processor: MessageProcessor,
+    ) -> LalamoModule[Self]: ...
+
+    @classmethod
+    def load_model(cls, path: Path | str) -> LalamoModule[Self]:
+        if isinstance(path, str):
+            path = Path(path)
+        with open(path / "config.json") as config_file:
+            config_json = json.load(config_file)
+        config = config_converter.structure(config_json["model_config"], cls)
+        with open_safetensors(path / "model.safetensors") as open_results:
+            weights_dict, _ = open_results
+            weights = unflatten_parameters(weights_dict)
+            model = config.model_config.empty().import_weights(weights)
+        tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
+        message_processor = MessageProcessor(config.message_processor_config, tokenizer)
+        return config.init(model, message_processor)
+
+
+class TextModel[ConfigT, ModelT: Decoder | Classifier](LalamoModule[ConfigT]):
+    model: ModelT
+    message_processor: MessageProcessor = eqx.field(static=True)
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.model.activation_precision
+
+    def export_weights(self) -> ParameterTree:
+        return self.model.export_weights()
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        return replace(
+            self,
+            model=self.model.import_weights(weights),
+        )
+
+    def record_trace(self, messages: Iterable[Message] | None = None) -> ClassifierResult | DecoderResult:
+        if messages is None:
+            messages = [UserMessage("Tell me about London")]
+
+        token_ids = jnp.array(self.message_processor.tokenize_request(messages))[None, :]
+        _, num_tokens = token_ids.shape
+        token_positions = jnp.arange(num_tokens)[None, :]
+        return self.model(token_ids=token_ids, token_positions=token_positions, return_activation_trace=True)
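A sketch of how these pieces are meant to fit together, assuming an exported model directory containing config.json, model.safetensors and tokenizer.json, and assuming RouterConfig (from lalamo.models) is the concrete TextModelConfig subclass for classifier models:

    from lalamo.models import RouterConfig

    # Load a router (classifier) model exported by lalamo.
    router = RouterConfig.load_model("path/to/exported-model")

    # Run a forward pass over a default or user-supplied message list,
    # keeping intermediate activations for inspection.
    trace = router.record_trace()

Note that record_trace adds the leading batch axis itself (the [None, :] indexing), so callers pass plain message objects rather than pre-batched token arrays.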