lalamo 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- lalamo/__init__.py +15 -2
- lalamo/data/__init__.py +0 -1
- lalamo/data/huggingface_message.py +1 -0
- lalamo/main.py +167 -18
- lalamo/message_processor.py +2 -3
- lalamo/model_import/common.py +120 -27
- lalamo/model_import/decoder_configs/__init__.py +4 -2
- lalamo/model_import/decoder_configs/common.py +62 -21
- lalamo/model_import/decoder_configs/executorch.py +14 -9
- lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
- lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
- lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
- lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
- lalamo/model_import/loaders/__init__.py +3 -2
- lalamo/model_import/loaders/executorch.py +24 -12
- lalamo/model_import/loaders/huggingface.py +258 -30
- lalamo/model_import/model_specs/__init__.py +4 -2
- lalamo/model_import/model_specs/common.py +8 -2
- lalamo/model_import/model_specs/gemma.py +5 -1
- lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo/model_import/model_specs/mirai.py +20 -0
- lalamo/models/__init__.py +10 -0
- lalamo/models/common.py +81 -0
- lalamo/{language_model.py → models/language_model.py} +32 -49
- lalamo/models/router.py +59 -0
- lalamo/modules/__init__.py +33 -16
- lalamo/modules/classifier.py +339 -0
- lalamo/modules/common.py +6 -3
- lalamo/modules/decoder.py +52 -180
- lalamo/modules/mlp.py +28 -5
- lalamo/modules/normalization.py +13 -8
- lalamo/modules/token_mixers/attention.py +10 -6
- lalamo/modules/token_mixers/state/kv_cache.py +14 -4
- lalamo/modules/transformer.py +273 -0
- lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
- lalamo/speculator/__init__.py +6 -2
- lalamo/speculator/estimator.py +91 -0
- lalamo/speculator/inference.py +28 -9
- lalamo/speculator/ngram.py +7 -3
- lalamo/speculator/utils.py +4 -2
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/METADATA +1 -1
- lalamo-0.5.4.dist-info/RECORD +88 -0
- lalamo-0.5.2.dist-info/RECORD +0 -80
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/WHEEL +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/top_level.txt +0 -0
lalamo/model_import/loaders/huggingface.py

@@ -9,7 +9,6 @@ from lalamo.common import ParameterPath
 from lalamo.modules import (
     Attention,
     Decoder,
-    DecoderLayer,
     DenseMLP,
     FullPrecisionLinear,
     GroupQuantizedLinear,
@@ -18,18 +17,20 @@ from lalamo.modules import (
     MLXQuantizedLinear,
     MLXQuantizedTiedEmbedding,
     MLXSemiQuantizedUntiedEmbedding,
-
+    Normalization,
     SeparableCausalConv,
     TiedEmbedding,
+    TransformerLayer,
     UntiedEmbedding,
 )
+from lalamo.modules.classifier import Classifier
 from lalamo.modules.mlp import MixtureOfExperts, MLPBase
 from lalamo.quantization import QuantizationMode

 from .common import load_parameters
 from .utils import decode_mxfp4, deinterleave_pairwise_columns

-__all__ = ["
+__all__ = ["load_huggingface_decoder"]


 AWQ_UINT4_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)
@@ -42,15 +43,20 @@ def _reverse_uint4_order(array: Array, reverse_order: Array) -> Array:
     if last_dim % pack_factor != 0:
         return array

-    array_reshaped = rearrange(
+    array_reshaped = rearrange(
+        array,
+        "... (group pack_factor) -> ... group pack_factor",
+        pack_factor=pack_factor,
+    )
     array_reordered = array_reshaped[..., reverse_order]
     return rearrange(array_reordered, "... group pack_factor -> ... (group pack_factor)")


 def unpack_int32(packed_weights: Array, mode: QuantizationMode) -> Array:
-    assert packed_weights.dtype in (
-
-
+    assert packed_weights.dtype in (
+        jnp.int32,
+        jnp.uint32,
+    ), f"Expected packed_weights to be of dtype jnp.(u)int32, got {packed_weights.dtype}"
     assert 32 % mode.bits == 0

     shifts = jnp.arange(0, 32, mode.bits)
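The `shifts = jnp.arange(0, 32, mode.bits)` line above drives the whole unpacking: each 32-bit word is shifted right by every offset at once, and a mask keeps only the low bits (the AWQ constant at the top of the file then undoes the interleaved nibble order). A minimal, self-contained sketch of the same shift-and-mask idiom — not the library function itself; `unpack_uint4_words` is a hypothetical name:

```python
import jax.numpy as jnp

def unpack_uint4_words(packed: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch of the shift-and-mask idiom used by unpack_int32."""
    bits = 4
    shifts = jnp.arange(0, 32, bits)  # bit offsets of the 8 nibbles in each word
    # Broadcast: shift every word right by every offset, then mask the low 4 bits.
    unpacked = (packed[..., None] >> shifts) & ((1 << bits) - 1)
    # Fold the nibble axis back into the last dimension: (..., n) -> (..., n * 8)
    return unpacked.reshape(*packed.shape[:-1], -1)

words = jnp.array([0x76543210], dtype=jnp.uint32)
print(unpack_uint4_words(words))  # [0 1 2 3 4 5 6 7]
```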
@@ -309,7 +315,14 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
         )
     else:
         # Fallback: recursively load a standard DenseMLP experts module
-        experts = load_mlp(
+        experts = load_mlp(
+            module.experts,
+            weights_dict,
+            experts_path,
+            "up_proj",
+            "gate_proj",
+            "down_proj",
+        )

     return load_parameters(
         lambda m: (m.router, m.experts),
@@ -319,10 +332,10 @@


 def load_rmsnorm(
-    module:
+    module: Normalization,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
-) ->
+) -> Normalization:
     scales = weights_dict[path / "weight"]
     return load_parameters(lambda m: (m.scales,), module, (scales,))

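Every loader in this file follows the same out-of-place pattern: a lambda selects the leaves to replace, and a tuple supplies the new arrays. A minimal sketch of how such a `load_parameters` helper can be written on top of `equinox.tree_at` — an assumption about the actual implementation, which lives in `.common`:

```python
from collections.abc import Callable

import equinox as eqx

def load_parameters[M: eqx.Module](
    where: Callable[[M], tuple],
    module: M,
    values: tuple,
) -> M:
    # tree_at returns an updated copy; modules stay immutable PyTrees.
    # is_leaf lets a selected field that is currently None be replaced too.
    return eqx.tree_at(where, module, values, is_leaf=lambda x: x is None)
```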
@@ -357,7 +370,13 @@ def load_attention(
         sinks = module.sinks

     return load_parameters(
-        lambda m: (
+        lambda m: (
+            m.qkv_projection,
+            m.out_projection,
+            m.query_norm,
+            m.key_norm,
+            m.sinks,
+        ),
         module,
         (qkv_projection, out_projection, query_norm, key_norm, sinks),
     )
@@ -420,14 +439,20 @@ def load_mamba2(
         gate_bias = module.gate_bias

     return load_parameters(
-        lambda m: (
+        lambda m: (
+            m.in_projection,
+            m.out_projection,
+            m.conv,
+            m.skip_connection_weight,
+            m.gate_bias,
+        ),
         module,
         (in_projection, out_projection, conv, skip_connection_weight, gate_bias),
     )


-def load_decoder_layer(
-    module: DecoderLayer,
+def load_transformer_layer(
+    module: TransformerLayer,
     weights_dict: Mapping[str, Array],
     mixer_path: ParameterPath,
     mlp_path: ParameterPath,
@@ -438,13 +463,16 @@ def load_decoder_layer(
     up_proj_key: str,
     gate_proj_key: str,
     down_proj_key: str,
-) -> DecoderLayer:
-
-
-
-
-
+) -> TransformerLayer:
+    if module.pre_mixer_norm is not None:
+        pre_attention_norm = load_rmsnorm(
+            module.pre_mixer_norm,
+            weights_dict,
+            mixer_path / pre_mixer_norm_key,
+        )

+    else:
+        pre_attention_norm = None
     # Load mixer (attention or mamba)
     if isinstance(module.mixer, Attention):
         mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
@@ -474,7 +502,14 @@ def load_decoder_layer(
         mlp_path / pre_mlp_norm_key,
     )

-    mlp = load_mlp(
+    mlp = load_mlp(
+        module.mlp,
+        weights_dict,
+        mlp_path / mlp_key,
+        up_proj_key,
+        gate_proj_key,
+        down_proj_key,
+    )

     if module.post_mlp_norm is not None:
         post_mlp_norm = load_rmsnorm(
@@ -486,9 +521,23 @@ def load_decoder_layer(
         post_mlp_norm = None

     return load_parameters(
-        lambda m: (
+        lambda m: (
+            m.pre_mixer_norm,
+            m.mixer,
+            m.post_mixer_norm,
+            m.pre_mlp_norm,
+            m.mlp,
+            m.post_mlp_norm,
+        ),
         module,
-        (
+        (
+            pre_attention_norm,
+            mixer,
+            post_attention_norm,
+            pre_mlp_norm,
+            mlp,
+            post_mlp_norm,
+        ),
     )


@@ -558,10 +607,14 @@ def load_untied_embedding(
 ) -> UntiedEmbedding:
     input_weights = weights_dict[embedding_path / "weight"]
     output_weights = weights_dict[lm_head_path / "weight"]
-    return load_parameters(
+    return load_parameters(
+        lambda m: (m.input_weights, m.output_weights),
+        module,
+        (input_weights, output_weights),
+    )


-def load_huggingface(
+def load_huggingface_decoder(
     module: Decoder,
     weights_dict: Mapping[str, Array],
 ) -> Decoder:
@@ -629,7 +682,7 @@ def load_huggingface(
         raise TypeError(f"Unsupported embedding type: {type(module.embedding)}")

     decoder_layers = tuple(
-        load_decoder_layer(
+        load_transformer_layer(
             layer,
             weights_dict,
             decoder_path / "layers" / ((i * 2) if alternating_layers else i),
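The `(i * 2) if alternating_layers else i` index above presumably handles checkpoints whose `layers` list interleaves two kinds of entries, so the i-th transformer layer sits at every second slot. A toy illustration of the mapping, with hypothetical layer names:

```python
# Checkpoint layout with alternating entries: transformer layers at 0, 2, 4, ...
checkpoint_layers = ["attn_0", "other_0", "attn_1", "other_1", "attn_2", "other_2"]
alternating_layers = True

for i in range(3):
    index = (i * 2) if alternating_layers else i
    print(i, "->", checkpoint_layers[index])  # 0 -> attn_0, 1 -> attn_1, 2 -> attn_2
```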
@@ -642,12 +695,187 @@ def load_huggingface(
             gate_proj_key,
             down_proj_key,
         )
-        for i, layer in enumerate(module.layers)
+        for i, layer in enumerate(module.transformer.layers)
     )
-
-    output_norm = load_rmsnorm(module.output_norm, weights_dict, decoder_path / norm_key)
+    output_norm = load_rmsnorm(module.transformer.output_norm, weights_dict, decoder_path / norm_key)
     return load_parameters(
-        lambda m: (m.embedding, m.layers, m.output_norm),
+        lambda m: (m.embedding, m.transformer.layers, m.transformer.output_norm),
         module,
         (embedding, decoder_layers, output_norm),
     )
+
+
+def load_huggingface_classifier(
+    module: Classifier,
+    weights_dict: Mapping[str, Array],
+) -> Classifier:
+    def load_tied_embedding_local(
+        module: TiedEmbedding,
+        weights_dict: Mapping[str, Array],
+        decoder_path: ParameterPath,
+    ) -> TiedEmbedding:
+        input_weights = weights_dict[decoder_path / "embeddings" / "tok_embeddings" / "weight"]
+        return load_parameters(lambda m: (m.weights,), module, (input_weights,))
+
+    def load_linear_with_reshufling(
+        module: LinearBase,
+        weights_dict: Mapping[str, Array],
+        path: ParameterPath,
+    ) -> LinearBase:
+        """Loads a linear layer and reshufle some weights in resulting matrix to meet
+        requirements of downstream 'split' in MLP layer in attention."""
+
+        assert not module.has_biases, "Expecting no biases in FullPrecisionLinear"
+        assert isinstance(module, FullPrecisionLinear), "Expecting FullPrecisionLinear module as input"
+
+        weights = weights_dict[path / "weight"]
+        rows, _ = weights.shape
+        shuffled_weights = jnp.vstack((weights[rows // 2 :, :], weights[: rows // 2, :]))
+        return load_parameters(lambda m: (m.weights, m.biases), module, (shuffled_weights, None))
+
+    def load_attention_local(
+        module: Attention,
+        weights_dict: Mapping[str, Array],
+        path: ParameterPath,
+    ) -> Attention:
+        qkv_projection = load_linear(
+            module.qkv_projection,
+            weights_dict,
+            path / "Wqkv",
+            sublayers_to_fuse=None,
+        )
+        out_projection = load_linear(module.out_projection, weights_dict, path / "Wo")
+
+        if module.query_norm is not None:
+            query_norm = load_rmsnorm(module.query_norm, weights_dict, path / "q_norm")
+        else:
+            query_norm = None
+
+        if module.key_norm is not None:
+            key_norm = load_rmsnorm(module.key_norm, weights_dict, path / "k_norm")
+        else:
+            key_norm = None
+
+        return load_parameters(
+            lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm),
+            module,
+            (qkv_projection, out_projection, query_norm, key_norm),
+        )
+
+    def load_mlp_local(module: MLPBase, weights_dict: Mapping[str, Array], path: ParameterPath) -> MLPBase:
+        assert isinstance(module, DenseMLP)
+        up_projection = load_linear_with_reshufling(
+            module.up_projection,
+            weights_dict,
+            path / "Wi",
+        )
+        down_projection = load_linear(module.down_projection, weights_dict, path / "Wo")
+        return load_parameters(
+            lambda m: (m.up_projection, m.down_projection),
+            module,
+            (up_projection, down_projection),
+        )
+
+    def load_transformer_layer_local(
+        module: TransformerLayer,
+        weights_dict: Mapping[str, Array],
+        path: ParameterPath,
+    ) -> TransformerLayer:
+        if module.pre_mixer_norm is not None:
+            pre_attention_norm = load_rmsnorm(
+                module.pre_mixer_norm,
+                weights_dict,
+                path / "attn_norm",
+            )
+        else:
+            pre_attention_norm = None
+
+        assert isinstance(module.mixer, Attention)
+        attention = load_attention_local(module.mixer, weights_dict, path / "attn")
+        if module.post_mixer_norm is not None:
+            post_attention_norm = load_rmsnorm(
+                module.post_mixer_norm,
+                weights_dict,
+                path / "post_attention_layernorm",
+            )
+
+            pre_mlp_norm = load_rmsnorm(
+                module.pre_mlp_norm,
+                weights_dict,
+                path / "pre_feedforward_layernorm",
+            )
+        else:
+            post_attention_norm = None
+
+            pre_mlp_norm = load_rmsnorm(
+                module.pre_mlp_norm,
+                weights_dict,
+                path / "mlp_norm",
+            )
+
+        mlp = load_mlp_local(module.mlp, weights_dict, path / "mlp")
+        if module.post_mlp_norm is not None:
+            post_mlp_norm = load_rmsnorm(
+                module.post_mlp_norm,
+                weights_dict,
+                path / "post_feedforward_layernorm",
+            )
+        else:
+            post_mlp_norm = None
+        return load_parameters(
+            lambda m: (
+                m.pre_mixer_norm,
+                m.mixer,
+                m.post_mixer_norm,
+                m.pre_mlp_norm,
+                m.mlp,
+                m.post_mlp_norm,
+            ),
+            module,
+            (
+                pre_attention_norm,
+                attention,
+                post_attention_norm,
+                pre_mlp_norm,
+                mlp,
+                post_mlp_norm,
+            ),
+        )
+
+    base_path = ParameterPath()
+    decoder_path = base_path / "model"
+    head_path = base_path / "head"
+    classifier_path = base_path / "classifier"
+    assert isinstance(module.embedding, TiedEmbedding)
+    embedding = load_tied_embedding_local(module.embedding, weights_dict, decoder_path)
+    embedding_norm = load_rmsnorm(module.embedding_norm, weights_dict, base_path / "model" / "embeddings" / "norm")
+
+    decoder_layers = tuple(
+        load_transformer_layer_local(layer, weights_dict, decoder_path / "layers" / i)
+        for i, layer in enumerate(module.transformer.layers)
+    )
+    output_norm = load_rmsnorm(module.transformer.output_norm, weights_dict, decoder_path / "final_norm")
+    head_dense = load_linear(module.prediction_head.dense, weights_dict, head_path / "dense")
+    head_norm = load_rmsnorm(module.prediction_head.norm, weights_dict, head_path / "norm")
+    head_readout = load_linear(module.prediction_head.readout, weights_dict, classifier_path)
+    return load_parameters(
+        lambda m: (
+            m.embedding,
+            m.embedding_norm,
+            m.transformer.layers,
+            m.transformer.output_norm,
+            m.prediction_head.dense,
+            m.prediction_head.norm,
+            m.prediction_head.readout,
+        ),
+        module,
+        (
+            embedding,
+            embedding_norm,
+            decoder_layers,
+            output_norm,
+            head_dense,
+            head_norm,
+            head_readout,
+        ),
+    )
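The one genuinely non-mechanical step above is `load_linear_with_reshufling`: the two row-halves of the fused `Wi` matrix are stacked in reverse order so that the downstream split inside the MLP reads the sub-matrices in the order the runtime expects. A standalone sketch of that half-swap; the 4×2 toy matrix is illustrative only:

```python
import jax.numpy as jnp

weights = jnp.arange(8).reshape(4, 2)  # stand-in for the fused (2 * hidden, in) matrix
rows, _ = weights.shape

# Swap the top and bottom halves, exactly as in load_linear_with_reshufling.
shuffled = jnp.vstack((weights[rows // 2 :, :], weights[: rows // 2, :]))

# A downstream split now yields the halves in the opposite order.
first, second = jnp.split(shuffled, 2, axis=0)
print(first)   # rows 2..3 of the original matrix
print(second)  # rows 0..1 of the original matrix
```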
lalamo/model_import/model_specs/__init__.py

@@ -1,10 +1,11 @@
-from .common import FileSpec, ModelSpec, UseCase, build_quantized_models
+from .common import FileSpec, ModelSpec, ModelType, UseCase, build_quantized_models
 from .deepseek import DEEPSEEK_MODELS
 from .gemma import GEMMA_MODELS
 from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
 from .llama import LLAMA_MODELS
 from .llamba import LLAMBA_MODELS
+from .mirai import MIRAI_ROUTER_MODELS
 from .mistral import MISTRAL_MODELS

 # from .pleias import PLEIAS_MODELS

@@ -17,6 +18,7 @@ __all__ = [
     "REPO_TO_MODEL",
     "FileSpec",
     "ModelSpec",
+    "ModelType",
     "UseCase",
 ]

@@ -33,9 +35,9 @@ ALL_MODEL_LISTS = [
     POLARIS_MODELS,
     QWEN_MODELS,
     REKA_MODELS,
+    MIRAI_ROUTER_MODELS,
 ]

-
 ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]

lalamo/model_import/model_specs/common.py

@@ -5,7 +5,7 @@ from collections.abc import (
 )
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from enum import Enum
+from enum import Enum, StrEnum
 from pathlib import Path
 from typing import ClassVar, cast, get_args, get_origin

@@ -22,6 +22,7 @@ __all__ = [
     "FileSpec",
     "JSONFieldSpec",
     "ModelSpec",
+    "ModelType",
     "UseCase",
     "WeightsType",
     "awq_model_spec",

@@ -29,6 +30,11 @@ __all__ = [
 ]


+class ModelType(StrEnum):
+    LANGUAGE_MODEL = "language_model"
+    ROUTER_MODEL = "router_model"
+
+
 def cast_if_float(array: Array, cast_to: DTypeLike) -> Array:
     if array.dtype in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]:
         return array.astype(cast_to)

@@ -50,7 +56,6 @@ class WeightsType(Enum):
             yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
-
             from lalamo.modules.torch_interop import torch_to_jax

             torch_weights = torch.load(filename, map_location="cpu", weights_only=True)

@@ -129,6 +134,7 @@ class ModelSpec:
     assistant_role_name: str = "assistant"
     tool_role_name: str = "tool"
     weights_type: WeightsType = WeightsType.SAFETENSORS
+    model_type: ModelType = ModelType.LANGUAGE_MODEL
     configs: ConfigMap = field(default=ConfigMap())
     use_cases: tuple[UseCase, ...] = tuple()

lalamo/model_import/model_specs/mirai.py
ADDED

@@ -0,0 +1,20 @@
+from lalamo.model_import.decoder_configs.huggingface import ModernBERTConfig
+
+from .common import ConfigMap, FileSpec, ModelSpec, ModelType
+
+__all__ = ["MIRAI_ROUTER_MODELS"]
+
+MIRAI_ROUTER_MODELS = [
+    ModelSpec(
+        vendor="trymirai",
+        family="ModernBERT",
+        name="ModernBERT-Chat-Moderation",
+        size="0.15B",
+        quantization=None,
+        repo="trymirai/chat-moderation-router",
+        config_type=ModernBERTConfig,
+        use_cases=tuple(),
+        model_type=ModelType("router_model"),
+        configs=ConfigMap(chat_template=FileSpec("chat_template.jinja")),
+    ),
+]
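`ModelType` being a `StrEnum` is what makes the `ModelType("router_model")` call in the spec above work: `StrEnum` members are looked up by value and behave as plain strings, so they pass through JSON-like specs unchanged. A quick self-contained illustration:

```python
from enum import StrEnum

class ModelType(StrEnum):
    LANGUAGE_MODEL = "language_model"
    ROUTER_MODEL = "router_model"

assert ModelType("router_model") is ModelType.ROUTER_MODEL  # lookup by value
assert ModelType.ROUTER_MODEL == "router_model"             # compares as a str
print(f"{ModelType.LANGUAGE_MODEL}")                        # prints: language_model
```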
lalamo/models/common.py
ADDED

@@ -0,0 +1,81 @@
+import json
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from dataclasses import dataclass, replace
+from pathlib import Path
+from typing import Self
+
+import equinox as eqx
+from jax import Array
+from jax import numpy as jnp
+from tokenizers import Tokenizer
+
+from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
+from lalamo.message_processor import Message, MessageProcessor, MessageProcessorConfig, UserMessage
+from lalamo.modules import Classifier, Decoder, LalamoModule, config_converter
+from lalamo.modules.classifier import ClassifierConfig, ClassifierResult
+from lalamo.modules.decoder import DecoderConfig, DecoderResult
+from lalamo.utils import open_safetensors
+
+__all__ = [
+    "TextModel",
+    "TextModelConfig",
+]
+
+
+@dataclass(frozen=True)
+class TextModelConfig[ConfigT: ClassifierConfig | DecoderConfig](ABC):
+    model_config: ConfigT
+    message_processor_config: MessageProcessorConfig
+
+    @abstractmethod
+    def init(
+        self,
+        model: LalamoModule,
+        message_processor: MessageProcessor,
+    ) -> LalamoModule[Self]: ...
+
+    @classmethod
+    def load_model(cls, path: Path | str) -> LalamoModule[Self]:
+        if isinstance(path, str):
+            path = Path(path)
+        with open(path / "config.json") as config_file:
+            config_json = json.load(config_file)
+        config = config_converter.structure(config_json["model_config"], cls)
+        with open_safetensors(path / "model.safetensors") as open_results:
+            weights_dict, _ = open_results
+            weights = unflatten_parameters(weights_dict)
+            model = config.model_config.empty().import_weights(weights)
+        tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
+        message_processor = MessageProcessor(config.message_processor_config, tokenizer)
+        return config.init(model, message_processor)
+
+
+class TextModel[ConfigT, ModelT: Decoder | Classifier](LalamoModule[ConfigT]):
+    model: ModelT
+    message_processor: MessageProcessor = eqx.field(static=True)
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.model.activation_precision
+
+    def export_weights(self) -> ParameterTree:
+        return self.model.export_weights()
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        return replace(
+            self,
+            model=self.model.import_weights(weights),
+        )
+
+    def record_trace(self, messages: Iterable[Message] | None = None) -> ClassifierResult | DecoderResult:
+        if messages is None:
+            messages = [UserMessage("Tell me about London")]
+
+        token_ids = jnp.array(self.message_processor.tokenize_request(messages))[None, :]
+        _, num_tokens = token_ids.shape
+        token_positions = jnp.arange(num_tokens)[None, :]
+        return self.model(token_ids=token_ids, token_positions=token_positions, return_activation_trace=True)
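Taken together, `load_model` wires the on-disk layout (`config.json`, `model.safetensors`, `tokenizer.json`) into a ready-to-run `TextModel`. A hypothetical usage sketch; `RouterModelConfig` and the directory path are stand-ins, not names confirmed by this diff:

```python
# Hypothetical names: RouterModelConfig stands in for whichever concrete
# TextModelConfig subclass applies; the directory is a locally exported model.
model = RouterModelConfig.load_model("exports/chat-moderation-router")

# With no messages, record_trace defaults to a single "Tell me about London" turn;
# a Classifier-backed model yields a ClassifierResult, a Decoder a DecoderResult.
trace = model.record_trace()
```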