lalamo 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/common.py +2 -0
  3. lalamo/model_import/decoder_configs/__init__.py +2 -0
  4. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  5. lalamo/model_import/decoder_configs/huggingface/gemma3.py +31 -9
  6. lalamo/model_import/decoder_configs/huggingface/lfm2.py +174 -0
  7. lalamo/model_import/loaders/huggingface.py +71 -10
  8. lalamo/model_import/model_specs/__init__.py +4 -0
  9. lalamo/model_import/model_specs/common.py +1 -0
  10. lalamo/model_import/model_specs/essential_ai.py +17 -0
  11. lalamo/model_import/model_specs/huggingface.py +1 -1
  12. lalamo/model_import/model_specs/lfm2.py +21 -0
  13. lalamo/modules/__init__.py +6 -0
  14. lalamo/modules/token_mixers/__init__.py +15 -2
  15. lalamo/modules/token_mixers/common.py +1 -1
  16. lalamo/modules/token_mixers/mamba.py +2 -2
  17. lalamo/modules/token_mixers/short_conv.py +168 -0
  18. lalamo/modules/token_mixers/state/__init__.py +2 -0
  19. lalamo/modules/token_mixers/state/short_conv_state.py +33 -0
  20. lalamo/modules/transformer.py +18 -6
  21. lalamo/modules/transformer_layer.py +1 -1
  22. lalamo/utils.py +7 -0
  23. {lalamo-0.5.8.dist-info → lalamo-0.5.10.dist-info}/METADATA +1 -1
  24. {lalamo-0.5.8.dist-info → lalamo-0.5.10.dist-info}/RECORD +28 -23
  25. {lalamo-0.5.8.dist-info → lalamo-0.5.10.dist-info}/WHEEL +0 -0
  26. {lalamo-0.5.8.dist-info → lalamo-0.5.10.dist-info}/entry_points.txt +0 -0
  27. {lalamo-0.5.8.dist-info → lalamo-0.5.10.dist-info}/licenses/LICENSE +0 -0
  28. {lalamo-0.5.8.dist-info → lalamo-0.5.10.dist-info}/top_level.txt +0 -0
lalamo/__init__.py CHANGED
@@ -15,7 +15,7 @@ from lalamo.speculator import (
  SpeculatorTrainingEvent,
  )

- __version__ = "0.5.8"
+ __version__ = "0.5.10"

  __all__ = [
  "AssistantMessage",
lalamo/model_import/common.py CHANGED
@@ -17,6 +17,7 @@ from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
  from lalamo.models import ClassifierModel, ClassifierModelConfig, GenerationConfig, LanguageModel, LanguageModelConfig
  from lalamo.modules import Classifier, Decoder, LalamoModule
  from lalamo.quantization import QuantizationMode
+ from lalamo.utils import process_chat_template

  from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
  from .huggingface_generation_config import HFGenerationConfig
@@ -154,6 +155,7 @@ def import_message_processor(
  if model_spec.configs.chat_template is not None:
  raise ValueError("Conflicting chat template specifications.")
  prompt_template = tokenizer_config.chat_template
+ prompt_template = process_chat_template(prompt_template)
  tokenizer = Tokenizer.from_file(str(tokenizer_file))

  added_tokens = tokenizer_config.added_tokens()
lalamo/model_import/decoder_configs/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .huggingface import (
  HFGemma3Config,
  HFGemma3TextConfig,
  HFGPTOssConfig,
+ HFLFM2Config,
  HFLlamaConfig,
  HFLlambaConfig,
  HFMistralConfig,
@@ -22,6 +23,7 @@ __all__ = [
  "HFGemma2Config",
  "HFGemma3Config",
  "HFGemma3TextConfig",
+ "HFLFM2Config",
  "HFLlamaConfig",
  "HFLlambaConfig",
  "HFMistralConfig",
lalamo/model_import/decoder_configs/huggingface/__init__.py CHANGED
@@ -2,6 +2,7 @@ from .common import HuggingFaceLMConfig
  from .gemma2 import HFGemma2Config
  from .gemma3 import HFGemma3Config, HFGemma3TextConfig
  from .gpt_oss import HFGPTOssConfig
+ from .lfm2 import HFLFM2Config
  from .llama import HFLlamaConfig
  from .llamba import HFLlambaConfig
  from .mistral import HFMistralConfig
@@ -14,6 +15,7 @@ __all__ = [
  "HFGemma2Config",
  "HFGemma3Config",
  "HFGemma3TextConfig",
+ "HFLFM2Config",
  "HFLlamaConfig",
  "HFLlambaConfig",
  "HFMistralConfig",
lalamo/model_import/decoder_configs/huggingface/gemma3.py CHANGED
@@ -10,7 +10,7 @@ from lalamo.modules.activations import GELU
  from lalamo.modules.linear import FullPrecisionLinearConfig
  from lalamo.modules.mlp import DenseMLPConfig
  from lalamo.modules.normalization import NormalizationConfig, UpcastMode
- from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
+ from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig, YARNRoPEConfig
  from lalamo.modules.token_mixers.attention import AttentionConfig
  from lalamo.modules.transformer_layer import TransformerLayerConfig

@@ -19,9 +19,6 @@ from .common import HuggingFaceLMConfig
  __all__ = ["HFGemma3Config", "HFGemma3TextConfig"]


- NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER = 6
-
-
  def _round_to_bfloat16(x: float) -> float:
  return jnp.asarray(x).astype(jnp.bfloat16).item()

@@ -32,6 +29,16 @@ class GemmaRoPEScalingConfig:
  rope_type: Literal["linear"]


+ @dataclass(frozen=True)
+ class YarnRopeScalingConfig:
+ factor: float
+ beta_fast: float
+ beta_slow: float
+ original_max_position_embeddings: int
+ rope_type: Literal["yarn"]
+ truncate: bool = False
+
+
  @dataclass(frozen=True)
  class HFGemma3TextConfigRaw:
  hidden_size: int
@@ -39,6 +46,7 @@ class HFGemma3TextConfigRaw:
  model_type: Literal["gemma3_text"]
  num_hidden_layers: int
  sliding_window: int
+ sliding_window_pattern: int
  rms_norm_eps: float = 1e-06
  query_pre_attn_scalar: float = 256.0
  attention_bias: bool = False
@@ -49,7 +57,7 @@ class HFGemma3TextConfigRaw:
  max_position_embeddings: int = 131072
  rope_theta: float = 1000000.0
  rope_local_base_freq: float = 10000.0
- rope_scaling: GemmaRoPEScalingConfig | None = None
+ rope_scaling: GemmaRoPEScalingConfig | YarnRopeScalingConfig | None = None
  final_logit_softcapping: float | None = None
  vocab_size: int = 262208

@@ -57,7 +65,7 @@ class HFGemma3TextConfigRaw:
  def sliding_window_sizes(self) -> list[int | None]:
  result = []
  for i in range(self.num_hidden_layers):
- if (i + 1) % NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER == 0:
+ if (i + 1) % self.sliding_window_pattern == 0:
  result.append(None)
  else:
  result.append(self.sliding_window)
@@ -74,7 +82,7 @@ class HFGemma3TextConfigRaw:
  attention_scale = self.query_pre_attn_scalar**-0.5
  embedding_config = TiedEmbeddingConfig(
  input_scale=input_scale,
- logit_soft_cap=None,
+ logit_soft_cap=self.final_logit_softcapping,
  precision=activation_precision,
  )
  rms_norm_config = NormalizationConfig(
@@ -86,19 +94,33 @@ class HFGemma3TextConfigRaw:
  subtract_mean=False,
  )

- if self.rope_scaling is not None:
+ if isinstance(self.rope_scaling, GemmaRoPEScalingConfig):
  global_rope_config = LinearScalingRoPEConfig(
  precision=activation_precision,
  base=self.rope_theta,
  max_sequence_length=self.max_position_embeddings,
  scaling_factor=self.rope_scaling.factor,
  )
- else:
+ elif isinstance(self.rope_scaling, YarnRopeScalingConfig):
+ global_rope_config = YARNRoPEConfig(
+ precision=activation_precision,
+ base=self.rope_theta,
+ scaling_factor=self.rope_scaling.factor,
+ max_sequence_length=self.max_position_embeddings,
+ original_context_length=self.rope_scaling.original_max_position_embeddings,
+ beta_fast=self.rope_scaling.beta_fast,
+ beta_slow=self.rope_scaling.beta_slow,
+ truncate=self.rope_scaling.truncate,
+ )
+ elif self.rope_scaling is None:
  global_rope_config = UnscaledRoPEConfig(
  precision=activation_precision,
  base=self.rope_theta,
  max_sequence_length=context_length or self.max_position_embeddings,
  )
+ else:
+ raise ValueError("Invalid rope scaling configuration")
+
  local_rope_config = UnscaledRoPEConfig(
  precision=activation_precision,
  base=self.rope_local_base_freq,
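Note: the importer above now dispatches on the parsed rope_scaling variant instead of only checking for None. For reference, a minimal sketch of the two raw rope_scaling shapes it accepts; the field names come from the dataclasses in this hunk, while the numeric values are purely illustrative assumptions, not taken from any released config:

# Illustrative values only.
linear_scaling = {"rope_type": "linear", "factor": 8.0}
# parsed as GemmaRoPEScalingConfig -> LinearScalingRoPEConfig

yarn_scaling = {
    "rope_type": "yarn",
    "factor": 32.0,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
    "original_max_position_embeddings": 4096,
    "truncate": False,
}
# parsed as YarnRopeScalingConfig -> YARNRoPEConfig
# rope_scaling = None -> UnscaledRoPEConfig; any other shape raises ValueError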
@@ -0,0 +1,174 @@
1
+ from collections.abc import Mapping
2
+ from dataclasses import dataclass
3
+ from typing import Literal
4
+
5
+ from jaxtyping import DTypeLike
6
+
7
+ from lalamo.modules import (
8
+ AttentionConfig,
9
+ DecoderConfig,
10
+ DenseMLPConfig,
11
+ FullPrecisionLinearConfig,
12
+ NormalizationConfig,
13
+ SeparableCausalConvConfig,
14
+ ShortConvConfig,
15
+ SiLU,
16
+ TiedEmbeddingConfig,
17
+ TransformerConfig,
18
+ TransformerLayerConfig,
19
+ UnscaledRoPEConfig,
20
+ UntiedEmbeddingConfig,
21
+ UpcastMode,
22
+ )
23
+
24
+ from .common import HuggingFaceLMConfig
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class HFLFM2Config(HuggingFaceLMConfig):
29
+ architectures: list[Literal["Lfm2ForCausalLM"]]
30
+ block_auto_adjust_ff_dim: Literal[False]
31
+ block_dim: int
32
+ block_ff_dim: int
33
+ block_ffn_dim_multiplier: float
34
+ block_mlp_init_scale: float
35
+ block_multiple_of: int
36
+ block_norm_eps: float
37
+ block_out_init_scale: float
38
+ block_use_swiglu: bool
39
+ block_use_xavier_init: bool
40
+ bos_token_id: int
41
+ conv_L_cache: int # noqa: N815
42
+ conv_bias: int
43
+ conv_dim: int
44
+ conv_dim_out: int
45
+ conv_use_xavier_init: bool
46
+ eos_token_id: int
47
+ hidden_size: int
48
+ initializer_range: float
49
+ intermediate_size: int
50
+ layer_types: list[Literal["conv", "full_attention"]]
51
+ max_position_embeddings: int
52
+ model_type: Literal["lfm2"]
53
+ norm_eps: float
54
+ num_attention_heads: int
55
+ num_heads: int
56
+ num_hidden_layers: int
57
+ num_key_value_heads: int
58
+ pad_token_id: int
59
+ rope_theta: float
60
+ theta: float
61
+ tie_embedding: bool
62
+ torch_dtype: Literal["bfloat16"]
63
+ transformers_version: str
64
+ use_cache: bool
65
+ use_pos_enc: bool
66
+ vocab_size: int
67
+
68
+ def to_decoder_config(
69
+ self,
70
+ context_length: int | None,
71
+ activation_precision: DTypeLike,
72
+ accumulation_precision: DTypeLike,
73
+ metadata_dict: Mapping[str, str], # noqa: ARG002
74
+ ) -> DecoderConfig:
75
+ assert self.num_attention_heads == self.num_heads
76
+
77
+ if self.tie_embedding:
78
+ embedding_config = TiedEmbeddingConfig(
79
+ input_scale=None,
80
+ logit_soft_cap=None,
81
+ precision=activation_precision,
82
+ )
83
+ else:
84
+ embedding_config = UntiedEmbeddingConfig(
85
+ input_scale=None,
86
+ logit_soft_cap=None,
87
+ precision=activation_precision,
88
+ )
89
+
90
+ rope_config = UnscaledRoPEConfig(
91
+ precision=activation_precision,
92
+ base=self.rope_theta,
93
+ max_sequence_length=context_length or self.max_position_embeddings,
94
+ )
95
+
96
+ linear_config = FullPrecisionLinearConfig(activation_precision)
97
+
98
+ block_norm_config = NormalizationConfig(
99
+ scale_precision=activation_precision,
100
+ accumulation_precision=accumulation_precision,
101
+ epsilon=self.block_norm_eps,
102
+ scale_offset=None,
103
+ upcast_mode=UpcastMode.ONLY_NORMALIZATION,
104
+ subtract_mean=False,
105
+ )
106
+
107
+ attention_config = AttentionConfig(
108
+ qkv_projection_config=linear_config,
109
+ out_projection_config=linear_config,
110
+ query_norm_config=block_norm_config,
111
+ key_norm_config=block_norm_config,
112
+ num_heads=self.num_attention_heads,
113
+ num_groups=self.num_key_value_heads,
114
+ head_dim=self.hidden_size // self.num_heads,
115
+ is_causal=True,
116
+ scale=None,
117
+ sliding_window_size=None,
118
+ logit_soft_cap=None,
119
+ has_sinks=False,
120
+ has_qkv_biases=False,
121
+ has_out_biases=False,
122
+ )
123
+
124
+ short_conv_config = ShortConvConfig(
125
+ in_projection_config=linear_config,
126
+ conv_config=SeparableCausalConvConfig(activation_precision, has_biases=False),
127
+ out_projection_config=linear_config,
128
+ kernel_size=self.conv_L_cache,
129
+ )
130
+
131
+ mlp_config = DenseMLPConfig(
132
+ linear_config=linear_config,
133
+ activation=SiLU(),
134
+ has_up_biases=False,
135
+ has_down_biases=False,
136
+ up_clipping=None,
137
+ gate_clipping=None,
138
+ )
139
+
140
+ layer_configs = [
141
+ TransformerLayerConfig(
142
+ pre_mixer_norm_config=block_norm_config,
143
+ mixer_config={"conv": short_conv_config, "full_attention": attention_config}[layer_type],
144
+ post_mixer_norm_config=None,
145
+ pre_mlp_norm_config=block_norm_config,
146
+ mlp_config=mlp_config,
147
+ post_mlp_norm_config=None,
148
+ ) for layer_type in self.layer_types
149
+ ]
150
+
151
+ output_norm_config = NormalizationConfig(
152
+ scale_precision=activation_precision,
153
+ accumulation_precision=accumulation_precision,
154
+ epsilon=self.norm_eps,
155
+ scale_offset=None,
156
+ upcast_mode=UpcastMode.ONLY_NORMALIZATION,
157
+ subtract_mean=False,
158
+ )
159
+
160
+ transformer_config = TransformerConfig(
161
+ global_rope_config=rope_config,
162
+ local_rope_config=None,
163
+ layer_configs=tuple(layer_configs),
164
+ output_norm_config=output_norm_config,
165
+ model_dim=self.hidden_size,
166
+ hidden_dim=self.intermediate_size,
167
+ context_length=context_length or self.max_position_embeddings,
168
+ )
169
+
170
+ return DecoderConfig(
171
+ embedding_config=embedding_config,
172
+ transformer_config=transformer_config,
173
+ vocab_size=self.vocab_size,
174
+ )
@@ -8,17 +8,21 @@ from jaxtyping import Array, DTypeLike
8
8
  from lalamo.common import ParameterPath
9
9
  from lalamo.modules import (
10
10
  Attention,
11
+ AttentionConfig,
11
12
  Decoder,
12
13
  DenseMLP,
13
14
  FullPrecisionLinear,
14
15
  GroupQuantizedLinear,
15
16
  LinearBase,
16
17
  Mamba2,
18
+ Mamba2Config,
17
19
  MLXQuantizedLinear,
18
20
  MLXQuantizedTiedEmbedding,
19
21
  MLXSemiQuantizedUntiedEmbedding,
20
22
  Normalization,
21
23
  SeparableCausalConv,
24
+ ShortConv,
25
+ ShortConvConfig,
22
26
  TiedEmbedding,
23
27
  TransformerLayer,
24
28
  UntiedEmbedding,
@@ -300,7 +304,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
300
304
  down_w = rearrange(down_w, "e o ib ie -> e o (ib ie)")
301
305
  down_b = weights_dict[experts_path / "down_proj_bias"]
302
306
  if down_b.ndim == 1:
303
- down_b = jnp.broadcast_to(down_b, down_w.shape[:-1] + (down_b.shape[0],))
307
+ down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))
304
308
 
305
309
  down_projection = load_parameters(
306
310
  lambda m: (m.weights, m.biases), # type: ignore
@@ -345,21 +349,42 @@ def load_attention(
345
349
  weights_dict: Mapping[str, Array],
346
350
  path: ParameterPath,
347
351
  ) -> Attention:
352
+ if (path / "o_proj.weight") in weights_dict:
353
+ o_proj_name = "o_proj"
354
+ elif (path / "out_proj.weight") in weights_dict:
355
+ o_proj_name = "out_proj"
356
+ else:
357
+ raise NotImplementedError("Can't determine attention output projection name")
358
+
348
359
  qkv_projection = load_linear(
349
360
  module.qkv_projection,
350
361
  weights_dict,
351
362
  path,
352
363
  sublayers_to_fuse=["q_proj", "k_proj", "v_proj"],
353
364
  )
354
- out_projection = load_linear(module.out_projection, weights_dict, path / "o_proj")
365
+ out_projection = load_linear(module.out_projection, weights_dict, path / o_proj_name)
355
366
 
356
367
  if module.query_norm is not None:
357
- query_norm = load_rmsnorm(module.query_norm, weights_dict, path / "q_norm")
368
+ if (path / "q_norm.weight") in weights_dict:
369
+ q_norm_name = "q_norm"
370
+ elif (path / "q_layernorm.weight") in weights_dict:
371
+ q_norm_name = "q_layernorm"
372
+ else:
373
+ raise NotImplementedError("Can't determine attention query projection parameter name")
374
+
375
+ query_norm = load_rmsnorm(module.query_norm, weights_dict, path / q_norm_name)
358
376
  else:
359
377
  query_norm = None
360
378
 
361
379
  if module.key_norm is not None:
362
- key_norm = load_rmsnorm(module.key_norm, weights_dict, path / "k_norm")
380
+ if (path / "k_norm.weight") in weights_dict:
381
+ k_norm_name = "k_norm"
382
+ elif (path / "k_layernorm.weight") in weights_dict:
383
+ k_norm_name = "k_layernorm"
384
+ else:
385
+ raise NotImplementedError("Can't determine attention key projection parameter name")
386
+
387
+ key_norm = load_rmsnorm(module.key_norm, weights_dict, path / k_norm_name)
363
388
  else:
364
389
  key_norm = None
365
390
 
@@ -382,7 +407,7 @@ def load_attention(
382
407
  )
383
408
 
384
409
 
385
- def _load_mamba_conv(
410
+ def _load_conv(
386
411
  conv_module: SeparableCausalConv,
387
412
  weights_dict: Mapping[str, Array],
388
413
  path: ParameterPath,
@@ -390,6 +415,8 @@ def _load_mamba_conv(
390
415
  weight_path = path / "conv1d" / "weight"
391
416
  if weight_path not in weights_dict:
392
417
  weight_path = path / "conv_weight"
418
+ if weight_path not in weights_dict:
419
+ weight_path = path / "conv.weight"
393
420
  if weight_path not in weights_dict:
394
421
  weight_path = None
395
422
 
@@ -402,6 +429,8 @@ def _load_mamba_conv(
402
429
  bias_path = path / "conv1d" / "bias"
403
430
  if bias_path not in weights_dict:
404
431
  bias_path = path / "conv_bias"
432
+ if bias_path not in weights_dict:
433
+ bias_path = path / "conv.bias"
405
434
  if bias_path not in weights_dict:
406
435
  bias_path = None
407
436
 
@@ -424,7 +453,7 @@ def load_mamba2(
424
453
  ) -> Mamba2:
425
454
  in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
426
455
  out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
427
- conv = _load_mamba_conv(module.conv, weights_dict, path)
456
+ conv = _load_conv(module.conv, weights_dict, path)
428
457
 
429
458
  skip_connection_weight_path = path / "D"
430
459
  if skip_connection_weight_path in weights_dict:
@@ -451,6 +480,22 @@ def load_mamba2(
451
480
  )
452
481
 
453
482
 
483
+ def load_short_conv(
484
+ module: ShortConv,
485
+ weights_dict: Mapping[str, Array],
486
+ path: ParameterPath,
487
+ ) -> ShortConv:
488
+ in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
489
+ out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
490
+ conv = _load_conv(module.conv, weights_dict, path)
491
+
492
+ return load_parameters(
493
+ lambda m: (m.in_projection, m.out_projection, m.conv),
494
+ module,
495
+ (in_projection, out_projection, conv),
496
+ )
497
+
498
+
454
499
  def load_transformer_layer(
455
500
  module: TransformerLayer,
456
501
  weights_dict: Mapping[str, Array],
@@ -478,6 +523,8 @@ def load_transformer_layer(
478
523
  mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
479
524
  elif isinstance(module.mixer, Mamba2):
480
525
  mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
526
+ elif isinstance(module.mixer, ShortConv):
527
+ mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key)
481
528
  else:
482
529
  mixer = module.mixer
483
530
 
@@ -625,11 +672,12 @@ def load_huggingface_decoder(
625
672
 
626
673
  is_llamba_full_precision = any(key.startswith("backbone.") for key in weights_dict)
627
674
  is_llamba_mlx = any(key.startswith("embedding.encoder.") for key in weights_dict)
675
+ is_lfm2 = any(key.startswith("model.layers.0.operator_norm.weight") for key in weights_dict)
628
676
  if is_llamba_full_precision:
629
677
  decoder_path = base_path / "backbone"
630
678
  embedding_path = decoder_path / "embedding"
631
679
  pre_mixer_norm_key = "input_layernorm"
632
- mixer_key = "mixer"
680
+ mixer_key = {Mamba2Config: "mixer"}
633
681
  pre_mlp_norm_key = "post_attention_layernorm"
634
682
  mlp_key = "mlp"
635
683
  up_proj_key = "up_proj"
@@ -642,7 +690,7 @@ def load_huggingface_decoder(
642
690
  decoder_path = base_path / "model"
643
691
  embedding_path = base_path / "embedding.encoder"
644
692
  pre_mixer_norm_key = "norm"
645
- mixer_key = "layer"
693
+ mixer_key = {Mamba2Config: "layer"}
646
694
  pre_mlp_norm_key = "norm"
647
695
  mlp_key = "layer"
648
696
  up_proj_key = "gate_proj"
@@ -651,11 +699,24 @@ def load_huggingface_decoder(
651
699
  alternating_layers = True
652
700
  norm_key = "norm"
653
701
  lm_head_path = base_path / "head.linear"
702
+ elif is_lfm2:
703
+ decoder_path = base_path / "model"
704
+ embedding_path = decoder_path / "embed_tokens"
705
+ pre_mixer_norm_key = "operator_norm"
706
+ mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}
707
+ pre_mlp_norm_key = "ffn_norm"
708
+ mlp_key = "feed_forward"
709
+ up_proj_key = "w3"
710
+ gate_proj_key = "w1"
711
+ down_proj_key = "w2"
712
+ alternating_layers = False
713
+ norm_key = "embedding_norm"
714
+ lm_head_path = base_path / "lm_head"
654
715
  else:
655
716
  decoder_path = base_path / "model"
656
717
  embedding_path = decoder_path / "embed_tokens"
657
718
  pre_mixer_norm_key = "input_layernorm"
658
- mixer_key = "self_attn"
719
+ mixer_key = {AttentionConfig: "self_attn"}
659
720
  pre_mlp_norm_key = "post_attention_layernorm"
660
721
  mlp_key = "mlp"
661
722
  up_proj_key = "up_proj"
@@ -687,7 +748,7 @@ def load_huggingface_decoder(
687
748
  weights_dict,
688
749
  decoder_path / "layers" / ((i * 2) if alternating_layers else i),
689
750
  decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
690
- mixer_key,
751
+ mixer_key[type(layer.config.mixer_config)], # type: ignore
691
752
  mlp_key,
692
753
  pre_mixer_norm_key,
693
754
  pre_mlp_norm_key,
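Note: in the hunks above, mixer_key changes from a single string to a mapping from mixer config type to the checkpoint parameter prefix, resolved per layer, so attention and ShortConv mixers can coexist in one checkpoint (as in LFM2). A minimal sketch of the dispatch; the layer object is hypothetical, the names are taken from the diff:

# LFM2 layout from the diff: conv layers and attention layers use different prefixes.
mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}

# Per-layer resolution as performed in load_huggingface_decoder:
key = mixer_key[type(layer.config.mixer_config)]  # "conv" or "self_attn"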
lalamo/model_import/model_specs/__init__.py CHANGED
@@ -1,8 +1,10 @@
  from .common import FileSpec, ModelSpec, ModelType, UseCase, build_quantized_models
  from .deepseek import DEEPSEEK_MODELS
+ from .essential_ai import RNJ_MODELS
  from .gemma import GEMMA_MODELS
  from .gpt_oss import GPT_OSS_MODELS
  from .huggingface import HUGGINGFACE_MODELS
+ from .lfm2 import LFM2_MODELS
  from .llama import LLAMA_MODELS
  from .llamba import LLAMBA_MODELS
  from .mirai import MIRAI_CLASSIFIER_MODELS
@@ -24,6 +26,7 @@ __all__ = [


  ALL_MODEL_LISTS = [
+ LFM2_MODELS,
  LLAMA_MODELS,
  LLAMBA_MODELS,
  DEEPSEEK_MODELS,
@@ -36,6 +39,7 @@ ALL_MODEL_LISTS = [
  QWEN_MODELS,
  REKA_MODELS,
  MIRAI_CLASSIFIER_MODELS,
+ RNJ_MODELS,
  ]

  ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]
lalamo/model_import/model_specs/common.py CHANGED
@@ -56,6 +56,7 @@ class WeightsType(Enum):
  yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
  else:
  import torch
+
  from lalamo.modules.torch_interop import torch_to_jax

  torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
lalamo/model_import/model_specs/essential_ai.py ADDED
@@ -0,0 +1,17 @@
+ from lalamo.model_import.decoder_configs.huggingface import HFGemma3TextConfig
+
+ from .common import ModelSpec
+
+ __all__ = ["RNJ_MODELS"]
+
+ RNJ_MODELS = [
+ ModelSpec(
+ vendor="EssentialAI",
+ family="Rnj-1",
+ name="Rnj-1-Instruct",
+ size="8B",
+ quantization=None,
+ repo="EssentialAI/rnj-1-instruct",
+ config_type=HFGemma3TextConfig,
+ ),
+ ]
lalamo/model_import/model_specs/huggingface.py CHANGED
@@ -14,5 +14,5 @@ HUGGINGFACE_MODELS = [
  repo="HuggingFaceTB/SmolLM2-1.7B-Instruct",
  config_type=HFLlamaConfig,
  use_cases=tuple(),
- )
+ ),
  ]
lalamo/model_import/model_specs/lfm2.py ADDED
@@ -0,0 +1,21 @@
+ from lalamo.model_import.decoder_configs import HFLFM2Config
+
+ from .common import ConfigMap, FileSpec, ModelSpec
+
+ __all__ = ["LFM2_MODELS"]
+
+ LFM2_MODELS = [
+ ModelSpec(
+ vendor="LiquidAI",
+ family="LFM2",
+ name="LFM2-2.6B",
+ size="2.6B",
+ repo="LiquidAI/LFM2-2.6B",
+ config_type=HFLFM2Config,
+ quantization=None,
+ configs=ConfigMap(
+ chat_template=FileSpec("chat_template.jinja"),
+ ),
+ use_cases=tuple(),
+ ),
+ ]
lalamo/modules/__init__.py CHANGED
@@ -69,6 +69,9 @@ from .token_mixers import (
  Mamba2Config,
  SeparableCausalConv,
  SeparableCausalConvConfig,
+ ShortConv,
+ ShortConvConfig,
+ ShortConvStateLayer,
  State,
  StaticKVCacheLayer,
  )
@@ -136,6 +139,9 @@ __all__ = [
  "RoutingFunction",
  "SeparableCausalConv",
  "SeparableCausalConvConfig",
+ "ShortConv",
+ "ShortConvConfig",
+ "ShortConvStateLayer",
  "SiLU",
  "SoftmaxRouting",
  "State",
lalamo/modules/token_mixers/__init__.py CHANGED
@@ -3,9 +3,18 @@ from lalamo.modules.common import register_config_union
  from .attention import Attention, AttentionConfig, AttentionResult
  from .common import TokenMixerBase, TokenMixerResult
  from .mamba import Mamba2, Mamba2Config, Mamba2Result, SeparableCausalConv, SeparableCausalConvConfig
- from .state import DynamicKVCacheLayer, KVCacheLayer, Mamba2StateLayer, State, StateLayerBase, StaticKVCacheLayer
+ from .short_conv import ShortConv, ShortConvConfig, ShortConvResult
+ from .state import (
+ DynamicKVCacheLayer,
+ KVCacheLayer,
+ Mamba2StateLayer,
+ ShortConvStateLayer,
+ State,
+ StateLayerBase,
+ StaticKVCacheLayer,
+ )

- TokenMixerConfig = AttentionConfig | Mamba2Config
+ TokenMixerConfig = AttentionConfig | Mamba2Config | ShortConvConfig

  register_config_union(TokenMixerConfig) # type: ignore (pyright bug)

@@ -21,6 +30,10 @@ __all__ = [
  "Mamba2StateLayer",
  "SeparableCausalConv",
  "SeparableCausalConvConfig",
+ "ShortConv",
+ "ShortConvConfig",
+ "ShortConvResult",
+ "ShortConvStateLayer",
  "State",
  "StateLayerBase",
  "StaticKVCacheLayer",
lalamo/modules/token_mixers/common.py CHANGED
@@ -25,7 +25,7 @@ class TokenMixerResult[StateLayerT](NamedTuple):
  class TokenMixerConfigBase(ABC):
  @property
  @abstractmethod
- def rope_dim(self) -> int: ...
+ def rope_dim(self) -> int | None: ...

  @abstractmethod
  def random_init(
lalamo/modules/token_mixers/mamba.py CHANGED
@@ -184,8 +184,8 @@ class Mamba2Config(TokenMixerConfigBase):
  return self.num_heads * self.head_dim

  @property
- def rope_dim(self) -> int:
- return self.head_dim
+ def rope_dim(self) -> None:
+ return None

  def random_init(
  self,
@@ -0,0 +1,168 @@
1
+ from collections.abc import Mapping
2
+ from dataclasses import dataclass, replace
3
+ from typing import Self
4
+
5
+ import equinox as eqx
6
+ from jax import vmap
7
+ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
8
+
9
+ from lalamo.common import ParameterTree
10
+ from lalamo.modules.common import PositionalEmbeddingSelector
11
+ from lalamo.modules.linear import LinearBase, LinearConfig
12
+ from lalamo.modules.rope import PositionalEmbeddings
13
+
14
+ from .common import TokenMixerBase, TokenMixerConfigBase, TokenMixerResult
15
+ from .mamba import SeparableCausalConv, SeparableCausalConvConfig
16
+ from .state import ShortConvStateLayer
17
+
18
+ __all__ = [
19
+ "ShortConv",
20
+ "ShortConvConfig",
21
+ "ShortConvResult",
22
+ ]
23
+
24
+
25
+ ShortConvResult = TokenMixerResult[ShortConvStateLayer]
26
+
27
+
28
+ @dataclass(frozen=True)
29
+ class ShortConvConfig(TokenMixerConfigBase):
30
+ in_projection_config: LinearConfig
31
+ conv_config: SeparableCausalConvConfig
32
+ out_projection_config: LinearConfig
33
+
34
+ kernel_size: int
35
+
36
+ @property
37
+ def rope_dim(self) -> None:
38
+ return None
39
+
40
+ def random_init(
41
+ self,
42
+ model_dim: int,
43
+ *,
44
+ key: PRNGKeyArray,
45
+ ) -> "ShortConv":
46
+ in_projection = self.in_projection_config.random_init(
47
+ input_dim=model_dim,
48
+ output_dims=(model_dim,)*3,
49
+ has_biases=False,
50
+ key=key,
51
+ )
52
+
53
+ conv = self.conv_config.random_init(model_dim, self.kernel_size, key=key)
54
+
55
+ out_projection = self.out_projection_config.random_init(
56
+ input_dim=model_dim,
57
+ output_dims=(model_dim,),
58
+ has_biases=False,
59
+ key=key,
60
+ )
61
+
62
+ return ShortConv(
63
+ self,
64
+ in_projection=in_projection,
65
+ conv=conv,
66
+ out_projection=out_projection,
67
+ )
68
+
69
+ def empty(
70
+ self,
71
+ model_dim: int,
72
+ ) -> "ShortConv":
73
+ in_projection = self.in_projection_config.empty(
74
+ input_dim=model_dim,
75
+ output_dims=(model_dim,)*3,
76
+ has_biases=False,
77
+ )
78
+
79
+ conv = self.conv_config.empty(model_dim, self.kernel_size)
80
+
81
+ out_projection = self.out_projection_config.empty(
82
+ input_dim=model_dim,
83
+ output_dims=(model_dim,),
84
+ has_biases=False,
85
+ )
86
+
87
+ return ShortConv(
88
+ self,
89
+ in_projection=in_projection,
90
+ conv=conv,
91
+ out_projection=out_projection,
92
+ )
93
+
94
+
95
+ class ShortConv(TokenMixerBase[ShortConvConfig, ShortConvStateLayer]):
96
+ in_projection: LinearBase
97
+ conv: SeparableCausalConv
98
+ out_projection: LinearBase
99
+
100
+ @property
101
+ def activation_precision(self) -> DTypeLike:
102
+ return self.in_projection.activation_precision
103
+
104
+ @property
105
+ def model_dim(self) -> int:
106
+ return self.in_projection.input_dim
107
+
108
+ @property
109
+ def positional_embedding_selector(self) -> PositionalEmbeddingSelector:
110
+ return PositionalEmbeddingSelector.NONE
111
+
112
+ @eqx.filter_jit
113
+ def __call__(
114
+ self,
115
+ inputs: Float[Array, "suffix_tokens channels"],
116
+ positional_embeddings: PositionalEmbeddings | None,
117
+ state: ShortConvStateLayer | None = None,
118
+ return_updated_state: bool = False,
119
+ length_without_padding: Int[Array, ""] | int | None = None, # noqa: ARG002
120
+ ) -> TokenMixerResult[ShortConvStateLayer]:
121
+ if positional_embeddings is not None:
122
+ raise ValueError("Positional embeddings are not supported for ShortConv.")
123
+
124
+ pre_conv_gate, post_conv_gate, x = vmap(self.in_projection)(inputs)
125
+
126
+ prev_conv_state = state.conv_state if state is not None else None
127
+ conv_output = self.conv(x * pre_conv_gate, prev_conv_state, return_updated_state)
128
+
129
+ (outputs,) = vmap(self.out_projection)(conv_output.outputs * post_conv_gate)
130
+ updated_conv_state = conv_output.state
131
+
132
+ if return_updated_state:
133
+ assert updated_conv_state is not None
134
+ updated_state = ShortConvStateLayer(updated_conv_state)
135
+ else:
136
+ updated_state = None
137
+
138
+ return TokenMixerResult(outputs, updated_state)
139
+
140
+ def init_static_state(self, capacity: int) -> ShortConvStateLayer: # noqa: ARG002
141
+ return ShortConvStateLayer.init(
142
+ self.config.kernel_size,
143
+ self.in_projection.input_dim,
144
+ self.activation_precision,
145
+ )
146
+
147
+ def export_weights(self) -> ParameterTree:
148
+ return {
149
+ "in_projection": self.in_projection.export_weights(),
150
+ "conv": self.conv.export_weights(),
151
+ "out_projection": self.out_projection.export_weights(),
152
+ }
153
+
154
+ def import_weights(
155
+ self,
156
+ weights: ParameterTree[Array],
157
+ ) -> Self:
158
+ assert isinstance(weights, Mapping)
159
+ assert isinstance(weights["in_projection"], Mapping)
160
+ assert isinstance(weights["conv"], Mapping)
161
+ assert isinstance(weights["out_projection"], Mapping)
162
+
163
+ return replace(
164
+ self,
165
+ in_projection=self.in_projection.import_weights(weights["in_projection"]),
166
+ conv=self.conv.import_weights(weights["conv"]),
167
+ out_projection=self.out_projection.import_weights(weights["out_projection"]),
168
+ )
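Note: a minimal usage sketch of the new ShortConv mixer (not part of the package diff). The constructor signatures follow the hunks in this diff (FullPrecisionLinearConfig(precision), SeparableCausalConvConfig(precision, has_biases=...)); the model dimension, dtype, and kernel size are illustrative assumptions:

import jax
import jax.numpy as jnp

from lalamo.modules import FullPrecisionLinearConfig, SeparableCausalConvConfig, ShortConvConfig

linear = FullPrecisionLinearConfig(jnp.bfloat16)
config = ShortConvConfig(
    in_projection_config=linear,
    conv_config=SeparableCausalConvConfig(jnp.bfloat16, has_biases=False),
    out_projection_config=linear,
    kernel_size=3,
)

mixer = config.random_init(model_dim=1024, key=jax.random.PRNGKey(0))
tokens = jnp.zeros((16, 1024), dtype=jnp.bfloat16)  # [suffix_tokens, channels]

# ShortConv uses no positional embeddings; passing any raises ValueError.
outputs, state = mixer(tokens, positional_embeddings=None, return_updated_state=True)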
lalamo/modules/token_mixers/state/__init__.py CHANGED
@@ -1,11 +1,13 @@
  from .common import State, StateLayerBase
  from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
  from .mamba_state import Mamba2StateLayer
+ from .short_conv_state import ShortConvStateLayer

  __all__ = [
  "DynamicKVCacheLayer",
  "KVCacheLayer",
  "Mamba2StateLayer",
+ "ShortConvStateLayer",
  "State",
  "StateLayerBase",
  "StaticKVCacheLayer",
lalamo/modules/token_mixers/state/short_conv_state.py ADDED
@@ -0,0 +1,33 @@
+ from typing import Self
+
+ import jax.numpy as jnp
+ from jaxtyping import Array, DTypeLike, Float
+
+ from lalamo.common import ParameterTree
+
+ from .common import StateLayerBase
+
+ __all__ = ["ShortConvStateLayer"]
+
+
+ class ShortConvStateLayer(StateLayerBase):
+ conv_state: Float[Array, "*batch tokens conv_channels"]
+
+ def __post_init__(self) -> None:
+ if self.conv_state.ndim not in (2, 3):
+ raise ValueError(
+ f"Conv state must have 2 or 3 dimensions: [batch], tokens, conv_channels,"
+ f" got shape {self.conv_state.shape}",
+ )
+
+ @classmethod
+ def init(
+ cls,
+ kernel_size: int,
+ model_dim: int,
+ dtype: DTypeLike,
+ ) -> Self:
+ return cls(conv_state=jnp.zeros((kernel_size - 1, model_dim), dtype=dtype))
+
+ def export(self) -> ParameterTree:
+ return dict(conv_state=self.conv_state)
@@ -65,17 +65,23 @@ class TransformerConfig:
65
65
  context_length: int
66
66
 
67
67
  def random_init(self, *, key: PRNGKeyArray) -> "Transformer":
68
- first_layer_config, *_ = self.layer_configs
68
+ rope_dims = (layer.rope_dim for layer in self.layer_configs if layer.rope_dim is not None)
69
+ rope_dim = next(rope_dims, None)
70
+ assert all(d == rope_dim for d in rope_dims)
69
71
 
70
72
  if self.global_rope_config:
73
+ assert rope_dim is not None
74
+
71
75
  global_rope = self.global_rope_config.init(
72
- head_dim=first_layer_config.rope_dim,
76
+ head_dim=rope_dim,
73
77
  num_timesteps=self.context_length,
74
78
  )
75
79
  else:
76
80
  global_rope = None
77
81
 
78
82
  if self.local_rope_config:
83
+ assert rope_dim is not None
84
+
79
85
  max_sliding_window_size = max(
80
86
  layer_config.mixer_config.sliding_window_size or 0
81
87
  for layer_config in self.layer_configs
@@ -83,7 +89,7 @@ class TransformerConfig:
83
89
  )
84
90
 
85
91
  local_rope = self.local_rope_config.init(
86
- head_dim=first_layer_config.rope_dim,
92
+ head_dim=rope_dim,
87
93
  num_timesteps=max(max_sliding_window_size, self.context_length),
88
94
  )
89
95
  else:
@@ -109,19 +115,25 @@ class TransformerConfig:
109
115
  )
110
116
 
111
117
  def empty(self) -> "Transformer":
112
- first_layer_config, *_ = self.layer_configs
118
+ rope_dims = (layer.rope_dim for layer in self.layer_configs if layer.rope_dim is not None)
119
+ rope_dim = next(rope_dims, None)
120
+ assert all(d == rope_dim for d in rope_dims)
113
121
 
114
122
  if self.global_rope_config:
123
+ assert rope_dim is not None
124
+
115
125
  global_rope = self.global_rope_config.init(
116
- head_dim=first_layer_config.rope_dim,
126
+ head_dim=rope_dim,
117
127
  num_timesteps=self.context_length,
118
128
  )
119
129
  else:
120
130
  global_rope = None
121
131
 
122
132
  if self.local_rope_config:
133
+ assert rope_dim is not None
134
+
123
135
  local_rope = self.local_rope_config.init(
124
- head_dim=first_layer_config.rope_dim,
136
+ head_dim=rope_dim,
125
137
  num_timesteps=self.context_length,
126
138
  )
127
139
  else:
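Note: the transformer hunks above replace "take rope_dim from the first layer config" with "take it from the first layer config that reports one", since ShortConv and Mamba2 mixers now return None for rope_dim. The selection logic, restated on hypothetical values:

layer_rope_dims = [None, 64, None, 64]  # e.g. conv, attention, conv, attention (illustrative)

dims = (d for d in layer_rope_dims if d is not None)
rope_dim = next(dims, None)              # 64; stays None if no layer uses RoPE
assert all(d == rope_dim for d in dims)  # remaining dims must agree with the first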
lalamo/modules/transformer_layer.py CHANGED
@@ -89,7 +89,7 @@ class TransformerLayerConfig:
  post_mlp_norm_config: NormalizationConfig | None

  @property
- def rope_dim(self) -> int:
+ def rope_dim(self) -> int | None:
  return self.mixer_config.rope_dim

  def random_init(
lalamo/utils.py CHANGED
@@ -24,6 +24,7 @@ __all__ = [
  "MapSequence",
  "jax_uint4_to_packed_uint8",
  "open_safetensors",
+ "process_chat_template",
  ]


@@ -159,3 +160,9 @@ def jax_uint8_to_unpacked_uint4(array: Array) -> Array:
  )

  return unpacked.astype(jnp.uint4)
+
+
+ def process_chat_template(template: str) -> str:
+ template = template.replace("{% generation %}", "")
+ template = template.replace("{%- endgeneration -%}", "")
+ return template
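Note: process_chat_template strips exactly the two marker spellings shown above and nothing else. A minimal usage sketch (the template string is illustrative):

from lalamo.utils import process_chat_template

template = "{% for m in messages %}{% generation %}{{ m['content'] }}{%- endgeneration -%}{% endfor %}"
print(process_chat_template(template))
# {% for m in messages %}{{ m['content'] }}{% endfor %}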
{lalamo-0.5.8.dist-info → lalamo-0.5.10.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lalamo
- Version: 0.5.8
+ Version: 0.5.10
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
  Requires-Python: <4,>=3.12
  Description-Content-Type: text/markdown
@@ -1,27 +1,28 @@
1
- lalamo/__init__.py,sha256=ZJ5Cjq4OoGVrjba9zUYIYnFGRKZkCkhBLaakdt4D008,814
1
+ lalamo/__init__.py,sha256=sCPww-cg0OE8syJQqxdBI7CV5Mpwxj64H0FNbWdHfO4,815
2
2
  lalamo/common.py,sha256=5NUFD26yQgOnEEk3LaQnce8n-VwJxILkEpFesHZhtQU,3820
3
3
  lalamo/main.py,sha256=GgUT7lT48-XQuAEH7qzsDKG8Lx9iBf-sYBIRhZL9q7E,23978
4
4
  lalamo/message_processor.py,sha256=bSUAQg7CemLTnBV4LtPxJBicAalruDCA-JXjkTYPZ8U,5797
5
5
  lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
6
6
  lalamo/registry_abc.py,sha256=ENjXiD_wEH100fNjG-W5Em1L_EQ0Lf0pdRhRGvf3qZk,2197
7
7
  lalamo/sampling.py,sha256=g_dNiJyZrRqoQIiLid4cr6nRT9N5tSz3GtHr8Bt4n-E,3404
8
- lalamo/utils.py,sha256=9kg5P19eaqGrSyAiNSbdfOwrv4s1PJZTHYdiNctlBSY,4368
8
+ lalamo/utils.py,sha256=QwATVXAeHBsQEDyt_31SHgxFphFVZYHpv3ZaklXks9Y,4585
9
9
  lalamo/data/__init__.py,sha256=exfhBLxHrg7BWutM0tAln5QuIWlNQmOhaG2noFYxfPI,189
10
10
  lalamo/data/huggingface_message.py,sha256=-7lN9eIcETQzt1Pnx3d4d8p3_I7WYMNf4mp1P91N7fI,1115
11
11
  lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
12
12
  lalamo/data/utils.py,sha256=B96gLaULyStKYuR8wjFdTpFc6YIDC8EEvGh1eiMe_Ec,338
13
13
  lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
14
- lalamo/model_import/common.py,sha256=tdZsteRsxL6DVUFwHw_1eeNLckflOdAaIm7Wm9eJzxM,12311
14
+ lalamo/model_import/common.py,sha256=wvyGD-iLut_Pm3HjDMI05upqdtCW3HWeoeB0YmiFeqk,12419
15
15
  lalamo/model_import/huggingface_generation_config.py,sha256=mot6VQ6ezCtEhN6VjhnvaU-nR5P5T2BuBUgpFNnWJxU,1495
16
16
  lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
17
- lalamo/model_import/decoder_configs/__init__.py,sha256=1ZqMcEHvCJjMIZ9iNyY31XMXOaFxB-NbqIU01BtmcEk,641
17
+ lalamo/model_import/decoder_configs/__init__.py,sha256=YvlSsJqNEQPCNKcUzCw0MLjt8H3vcfjc4sz1OK7qdIQ,679
18
18
  lalamo/model_import/decoder_configs/common.py,sha256=L8PCgF5fIt3RqPlmLiJpBzDguKk9iTjk4XSItxwVG4c,3260
19
19
  lalamo/model_import/decoder_configs/executorch.py,sha256=fTEG_j-7d8riR3Fu_H5tHDjOTrWevfyw7QbWF1mUdOQ,5924
20
- lalamo/model_import/decoder_configs/huggingface/__init__.py,sha256=3H7GPTFNNahEvI8D1SGg2mGBgPhsIdZ213MglwbGDlE,645
20
+ lalamo/model_import/decoder_configs/huggingface/__init__.py,sha256=AboZJgZxOuIigPShskj-FqBkBqwlJZoKHP0RDqx-MyY,696
21
21
  lalamo/model_import/decoder_configs/huggingface/common.py,sha256=YYIDEQy8x7lqL2qtxUHrNqfjZEiizBZ_26sTqOzjRtQ,3792
22
22
  lalamo/model_import/decoder_configs/huggingface/gemma2.py,sha256=g8LH_GlSNyL04WWi596zI0rWsD3ahnfNjDk-9zZNcDE,4759
23
- lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=KlhL7y6lW_cUgsT2JjvlQbsuKZggI8DG5wazZZBk0zM,7415
23
+ lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=aSZ0TtpgDYA10rHi8eD0C_Jsn48siM_HXqfZ4O7nh94,8372
24
24
  lalamo/model_import/decoder_configs/huggingface/gpt_oss.py,sha256=MBCoPbuWyzbJiBRtHOtpaPHJjQ1UVCAYcVrfIejTnlQ,7446
25
+ lalamo/model_import/decoder_configs/huggingface/lfm2.py,sha256=Esjg9VsIKTE9B9Vu6DHb-VZxSdqxLRgbkyUwpjnmKhc,5510
25
26
  lalamo/model_import/decoder_configs/huggingface/llama.py,sha256=UPeQiz2Dix8YaZYRxn9z44OZJ6c4xBQmcUZcM0Ymvh4,6934
26
27
  lalamo/model_import/decoder_configs/huggingface/llamba.py,sha256=ANB-vQK8U-zVFubZSTDXXt2S70T5SVOGzf7eOVvPzIQ,5773
27
28
  lalamo/model_import/decoder_configs/huggingface/mistral.py,sha256=MDGC0ivzJuUpOC11n8vFdcVzqccUyaRw_hkL74mVlAg,4599
@@ -31,14 +32,16 @@ lalamo/model_import/decoder_configs/huggingface/qwen3.py,sha256=lySVO-TvusAYUjDn
31
32
  lalamo/model_import/loaders/__init__.py,sha256=3THc1wQ4EPBzQkL_4EaKCa7Ev5Z7oczcvc4AHy9v5EI,228
32
33
  lalamo/model_import/loaders/common.py,sha256=kkugV-bMQlN1zvGHoj3uc7z0FbXKoMtXEBTvyu4KxK4,1844
33
34
  lalamo/model_import/loaders/executorch.py,sha256=t2Ey_mBMNC8bTSTdYWjuGXdPTRoohFlYrqtWyNkBU_8,9219
34
- lalamo/model_import/loaders/huggingface.py,sha256=ITA0Y_kCDFL4Tanuvd1NWUvV77WEn0VEzkcX5Whlwys,29835
35
+ lalamo/model_import/loaders/huggingface.py,sha256=sErBtGxODzqUkn-hJlzhCNhWmWqTeH4BneeQ8cqDhZo,32283
35
36
  lalamo/model_import/loaders/utils.py,sha256=eiX3WKFRrAfBY-dugodscNInl5o5w3KmVcgma4atpGY,2456
36
- lalamo/model_import/model_specs/__init__.py,sha256=V7S5Uo3GVBUG7KD0czMtmWZcQ-FJgryTZlxC7Abn_c0,1175
37
- lalamo/model_import/model_specs/common.py,sha256=RVPlNWHG_5OvU1W3YcOpqYz59Dh8plDmd7z1xNrqmaY,6585
37
+ lalamo/model_import/model_specs/__init__.py,sha256=JISqwJkloQkGD2jvi1MakNEWapIwlNXXVi5giZyXB74,1275
38
+ lalamo/model_import/model_specs/common.py,sha256=RLySCIkmGiA1IVZgLeemssMBMo4hMYMpmBjV0cRwBb4,6586
38
39
  lalamo/model_import/model_specs/deepseek.py,sha256=Umef93_ZBuq93yYsejIRNwj3udoln1gHfrv3SK5jyMo,417
40
+ lalamo/model_import/model_specs/essential_ai.py,sha256=xbHcwRpAWhR9gOgypVzcgunFspoUEk3iNsw-46CVR4o,390
39
41
  lalamo/model_import/model_specs/gemma.py,sha256=irWgylL-pc7y3Gn5DK3fjKoCT9kJWH3B7mTa-1Gmxqc,1306
40
42
  lalamo/model_import/model_specs/gpt_oss.py,sha256=PLo0QGrXKdX61ReTRdyOaP_EH3Dmj5lp3fpJjZRwRVA,542
41
- lalamo/model_import/model_specs/huggingface.py,sha256=eF8ItF5reFrFkjYxwiAJcFwUAlN6CpXfM-aQ8a92ItM,430
43
+ lalamo/model_import/model_specs/huggingface.py,sha256=TEkU8y95_hmUWyF-Q5hn0dE2SvXbApghAsQwhWRu4D0,431
44
+ lalamo/model_import/model_specs/lfm2.py,sha256=UlCQkKBWu7YMlc3L_c-cMOgXKw7j2wCHIu9ELwkkoCE,498
42
45
  lalamo/model_import/model_specs/llama.py,sha256=Ml-xvRGlXBT9NJhmEpwgNo6C84oBSMYgA1_PrCYGcAw,990
43
46
  lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
44
47
  lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
@@ -51,7 +54,7 @@ lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
51
54
  lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
52
55
  lalamo/models/common.py,sha256=PDteofGxjSBWYw_mPxbN1DTUba70aOURrAIjl13SSHc,2954
53
56
  lalamo/models/language_model.py,sha256=QPeVEyhutSze7fSNhvOvwSoYt24QMk-dtTJkos38amY,13465
54
- lalamo/modules/__init__.py,sha256=xWJ4OPAF4gKd0evYwXIK5kTnbH6nI55oLAePcoDDHQ0,3730
57
+ lalamo/modules/__init__.py,sha256=dFCicpcx-XV9sVTMR7x4TVF2tAGpzFi_sCTPAyawoJo,3858
55
58
  lalamo/modules/activations.py,sha256=U3qTQtZawPAUcoqbkIJnmTYcaNiQuSPMLcBeJ398GhI,1022
56
59
  lalamo/modules/classifier.py,sha256=_jtJ3INEq1dJP5HpUmcDk9YYzpRYlQ04zvFGaWBV6Lg,12101
57
60
  lalamo/modules/common.py,sha256=dqDEOi-C3H4U9iWUisU32RA-wRDCGuaUNGbObRBhyQM,3315
@@ -63,26 +66,28 @@ lalamo/modules/mlx_interop.py,sha256=FdfU_1iES-HQ9r4K0SkYwJTyvE0f-_T5ursNCjPLZKY
63
66
  lalamo/modules/normalization.py,sha256=cBdOq6OmJssunVeEwFRJD0BDhgFAN7J8gOKwzIUAY8I,3005
64
67
  lalamo/modules/rope.py,sha256=rCik7vBNqRXYg3LGbmc1mezPRNbIYMg5cydTFpQy-eU,10157
65
68
  lalamo/modules/torch_interop.py,sha256=-mujd1zI4ec2w92Hd50RtDa0K3jl6ZSnPxc5r3Fp9nU,916
66
- lalamo/modules/transformer.py,sha256=67-WZX2eE314abiQOhRNSooTHeJh4q9mlZQIxQ-oASU,10162
67
- lalamo/modules/transformer_layer.py,sha256=CfkYIn8a3pR4PPsI9hmAXpyiTbjXo-Gzl2OU9lAQlkI,12724
69
+ lalamo/modules/transformer.py,sha256=4olEO8Eh7U6RwSnaECn39ooPuTKUZp_6QmvO6vdirrQ,10532
70
+ lalamo/modules/transformer_layer.py,sha256=ZYmGR2Ej328l7K-YpV4eEiBk8SzLsw1RiuSiUP94UpY,12731
68
71
  lalamo/modules/utils.py,sha256=t_TayWT6g5LtYKhJaod-u_COWaI_VbNd3eYek9Nj0lc,441
69
- lalamo/modules/token_mixers/__init__.py,sha256=_t4T25C4WBVJQ1SqkQPGrrAc7bPKhDO3K2btgefVh5s,909
72
+ lalamo/modules/token_mixers/__init__.py,sha256=z6x8cNjis6xIi_2llIoByKqMF2W4xJ05rDnxitHQ3jU,1139
70
73
  lalamo/modules/token_mixers/attention.py,sha256=gkGMFah2OHB_tyJpkshM1KhMnzG6U7Xt273MkBvDk58,16584
71
- lalamo/modules/token_mixers/common.py,sha256=-ej1pIrrp845ztavJ3oh82U3HEsV_rEHxMZOEDp7iK8,1979
72
- lalamo/modules/token_mixers/mamba.py,sha256=MIIMZAlVVE4YwyT0PsxA0OWXa13ondoJchRxQbhq678,18797
73
- lalamo/modules/token_mixers/state/__init__.py,sha256=iQaX7njz3XtwGugI5_PUOIp1wdCzd5h08UkgF6yW3zo,307
74
+ lalamo/modules/token_mixers/common.py,sha256=CcrbXXvGU27uxGLh5L-G8VDtcOiW5Wpm13uBEOd6lVg,1986
75
+ lalamo/modules/token_mixers/mamba.py,sha256=fo8xvvmIQss2lKLhav19Jzk1-hTykNp2sjcN6ntcWj4,18789
76
+ lalamo/modules/token_mixers/short_conv.py,sha256=93SmoVsuAtdX4ckAkvhHXHiO67pU6soYFpBZxdPFEwc,5219
77
+ lalamo/modules/token_mixers/state/__init__.py,sha256=OKWPmiwszMWgwamewoVHd28owanHAO2j2e30Iivtv-4,384
74
78
  lalamo/modules/token_mixers/state/common.py,sha256=dcwBevAdeJpBjf7_YRk7TKrJHsCnpljhfzZy-3h9898,661
75
79
  lalamo/modules/token_mixers/state/kv_cache.py,sha256=QfnS3XgSmyDI9MBUbeLI4ABHLxiMcXDbZsqe0fd3KQo,8788
76
80
  lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxbnMJCl--JOoe9Fkcc9Mg,1642
81
+ lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
77
82
  lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
78
83
  lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
79
84
  lalamo/speculator/estimator.py,sha256=4D8dPZCWsrpORb7y8pQ6VsiIg1Cblvvxe6gXCoYtcD4,2530
80
85
  lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
81
86
  lalamo/speculator/ngram.py,sha256=95mdfAWhx4d5XOnOwhyhElnvcy6nlUjYhcbJzqDs414,5875
82
87
  lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
83
- lalamo-0.5.8.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
84
- lalamo-0.5.8.dist-info/METADATA,sha256=miYVR0hj7X-d1X09Bwaqf9-zKUqmljZ2qrhkV1rLICQ,3146
85
- lalamo-0.5.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
86
- lalamo-0.5.8.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
87
- lalamo-0.5.8.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
88
- lalamo-0.5.8.dist-info/RECORD,,
88
+ lalamo-0.5.10.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
89
+ lalamo-0.5.10.dist-info/METADATA,sha256=7KSYbe35d3aafssFta83t2MzVShN0JJsVd5nPfjb2VA,3147
90
+ lalamo-0.5.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
91
+ lalamo-0.5.10.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
92
+ lalamo-0.5.10.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
93
+ lalamo-0.5.10.dist-info/RECORD,,