lalamo 0.5.8__tar.gz → 0.5.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.5.8 → lalamo-0.5.10}/PKG-INFO +1 -1
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/__init__.py +1 -1
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/common.py +2 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/__init__.py +2 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +31 -9
- lalamo-0.5.10/lalamo/model_import/decoder_configs/huggingface/lfm2.py +174 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/loaders/huggingface.py +71 -10
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/__init__.py +4 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/common.py +1 -0
- lalamo-0.5.10/lalamo/model_import/model_specs/essential_ai.py +17 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo-0.5.10/lalamo/model_import/model_specs/lfm2.py +21 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/__init__.py +6 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/token_mixers/__init__.py +15 -2
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/token_mixers/common.py +1 -1
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/token_mixers/mamba.py +2 -2
- lalamo-0.5.10/lalamo/modules/token_mixers/short_conv.py +168 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/token_mixers/state/__init__.py +2 -0
- lalamo-0.5.10/lalamo/modules/token_mixers/state/short_conv_state.py +33 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/transformer.py +18 -6
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/transformer_layer.py +1 -1
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/utils.py +7 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo.egg-info/PKG-INFO +1 -1
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo.egg-info/SOURCES.txt +6 -0
- lalamo-0.5.10/tests/test_lfm2_models.py +14 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/LICENSE +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/README.md +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/common.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/data/__init__.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/data/huggingface_message.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/data/utils.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/main.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/message_processor.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/common.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/executorch.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/loaders/executorch.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/llamba.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/mirai.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/models/__init__.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/models/classifier.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/models/common.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/models/language_model.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/activations.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/classifier.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/common.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/decoder.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/embedding.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/linear.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/mlp.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/mlx_interop.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/normalization.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/rope.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/token_mixers/attention.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/token_mixers/state/common.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/modules/utils.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/quantization.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/registry_abc.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/sampling.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/speculator/__init__.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/speculator/common.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/speculator/estimator.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/speculator/inference.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/speculator/ngram.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo/speculator/utils.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/pyproject.toml +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/setup.cfg +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_cartesia_mlx_models.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_chat_template.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_generation.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_huggingface_model_conversion.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_huggingface_models.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_mlx_models.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_model_spec.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_models.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_moe.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_parameter_tree.py +0 -0
- {lalamo-0.5.8 → lalamo-0.5.10}/tests/test_registry_abc.py +0 -0

--- lalamo-0.5.8/lalamo/model_import/common.py
+++ lalamo-0.5.10/lalamo/model_import/common.py
@@ -17,6 +17,7 @@ from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
 from lalamo.models import ClassifierModel, ClassifierModelConfig, GenerationConfig, LanguageModel, LanguageModelConfig
 from lalamo.modules import Classifier, Decoder, LalamoModule
 from lalamo.quantization import QuantizationMode
+from lalamo.utils import process_chat_template

 from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
 from .huggingface_generation_config import HFGenerationConfig
@@ -154,6 +155,7 @@ def import_message_processor(
     if model_spec.configs.chat_template is not None:
         raise ValueError("Conflicting chat template specifications.")
     prompt_template = tokenizer_config.chat_template
+    prompt_template = process_chat_template(prompt_template)
     tokenizer = Tokenizer.from_file(str(tokenizer_file))

     added_tokens = tokenizer_config.added_tokens()

--- lalamo-0.5.8/lalamo/model_import/decoder_configs/__init__.py
+++ lalamo-0.5.10/lalamo/model_import/decoder_configs/__init__.py
@@ -6,6 +6,7 @@ from .huggingface import (
     HFGemma3Config,
     HFGemma3TextConfig,
     HFGPTOssConfig,
+    HFLFM2Config,
     HFLlamaConfig,
     HFLlambaConfig,
     HFMistralConfig,
@@ -22,6 +23,7 @@ __all__ = [
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",
+    "HFLFM2Config",
     "HFLlamaConfig",
     "HFLlambaConfig",
     "HFMistralConfig",

--- lalamo-0.5.8/lalamo/model_import/decoder_configs/huggingface/__init__.py
+++ lalamo-0.5.10/lalamo/model_import/decoder_configs/huggingface/__init__.py
@@ -2,6 +2,7 @@ from .common import HuggingFaceLMConfig
 from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
 from .gpt_oss import HFGPTOssConfig
+from .lfm2 import HFLFM2Config
 from .llama import HFLlamaConfig
 from .llamba import HFLlambaConfig
 from .mistral import HFMistralConfig
@@ -14,6 +15,7 @@ __all__ = [
     "HFGemma2Config",
     "HFGemma3Config",
     "HFGemma3TextConfig",
+    "HFLFM2Config",
     "HFLlamaConfig",
     "HFLlambaConfig",
     "HFMistralConfig",

--- lalamo-0.5.8/lalamo/model_import/decoder_configs/huggingface/gemma3.py
+++ lalamo-0.5.10/lalamo/model_import/decoder_configs/huggingface/gemma3.py
@@ -10,7 +10,7 @@ from lalamo.modules.activations import GELU
 from lalamo.modules.linear import FullPrecisionLinearConfig
 from lalamo.modules.mlp import DenseMLPConfig
 from lalamo.modules.normalization import NormalizationConfig, UpcastMode
-from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
+from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig, YARNRoPEConfig
 from lalamo.modules.token_mixers.attention import AttentionConfig
 from lalamo.modules.transformer_layer import TransformerLayerConfig

@@ -19,9 +19,6 @@ from .common import HuggingFaceLMConfig
 __all__ = ["HFGemma3Config", "HFGemma3TextConfig"]


-NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER = 6
-
-
 def _round_to_bfloat16(x: float) -> float:
     return jnp.asarray(x).astype(jnp.bfloat16).item()

@@ -32,6 +29,16 @@ class GemmaRoPEScalingConfig:
     rope_type: Literal["linear"]


+@dataclass(frozen=True)
+class YarnRopeScalingConfig:
+    factor: float
+    beta_fast: float
+    beta_slow: float
+    original_max_position_embeddings: int
+    rope_type: Literal["yarn"]
+    truncate: bool = False
+
+
 @dataclass(frozen=True)
 class HFGemma3TextConfigRaw:
     hidden_size: int
@@ -39,6 +46,7 @@ class HFGemma3TextConfigRaw:
     model_type: Literal["gemma3_text"]
     num_hidden_layers: int
     sliding_window: int
+    sliding_window_pattern: int
     rms_norm_eps: float = 1e-06
     query_pre_attn_scalar: float = 256.0
     attention_bias: bool = False
@@ -49,7 +57,7 @@ class HFGemma3TextConfigRaw:
     max_position_embeddings: int = 131072
     rope_theta: float = 1000000.0
     rope_local_base_freq: float = 10000.0
-    rope_scaling: GemmaRoPEScalingConfig | None = None
+    rope_scaling: GemmaRoPEScalingConfig | YarnRopeScalingConfig | None = None
     final_logit_softcapping: float | None = None
     vocab_size: int = 262208

@@ -57,7 +65,7 @@ class HFGemma3TextConfigRaw:
     def sliding_window_sizes(self) -> list[int | None]:
         result = []
         for i in range(self.num_hidden_layers):
-            if (i + 1) %
+            if (i + 1) % self.sliding_window_pattern == 0:
                 result.append(None)
             else:
                 result.append(self.sliding_window)
@@ -74,7 +82,7 @@ class HFGemma3TextConfigRaw:
         attention_scale = self.query_pre_attn_scalar**-0.5
         embedding_config = TiedEmbeddingConfig(
             input_scale=input_scale,
-            logit_soft_cap=
+            logit_soft_cap=self.final_logit_softcapping,
             precision=activation_precision,
         )
         rms_norm_config = NormalizationConfig(
@@ -86,19 +94,33 @@ class HFGemma3TextConfigRaw:
             subtract_mean=False,
         )

-        if self.rope_scaling
+        if isinstance(self.rope_scaling, GemmaRoPEScalingConfig):
             global_rope_config = LinearScalingRoPEConfig(
                 precision=activation_precision,
                 base=self.rope_theta,
                 max_sequence_length=self.max_position_embeddings,
                 scaling_factor=self.rope_scaling.factor,
             )
-
+        elif isinstance(self.rope_scaling, YarnRopeScalingConfig):
+            global_rope_config = YARNRoPEConfig(
+                precision=activation_precision,
+                base=self.rope_theta,
+                scaling_factor=self.rope_scaling.factor,
+                max_sequence_length=self.max_position_embeddings,
+                original_context_length=self.rope_scaling.original_max_position_embeddings,
+                beta_fast=self.rope_scaling.beta_fast,
+                beta_slow=self.rope_scaling.beta_slow,
+                truncate=self.rope_scaling.truncate,
+            )
+        elif self.rope_scaling is None:
             global_rope_config = UnscaledRoPEConfig(
                 precision=activation_precision,
                 base=self.rope_theta,
                 max_sequence_length=context_length or self.max_position_embeddings,
             )
+        else:
+            raise ValueError("Invalid rope scaling configuration")
+
         local_rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_local_base_freq,
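
The Gemma 3 change above drops the hard-coded NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER constant and instead reads sliding_window_pattern from the HuggingFace config: every pattern-th layer uses full attention, the rest use the sliding window. A minimal standalone sketch of that schedule (the helper name here is hypothetical, not the lalamo API):

```python
# Illustrative sketch of the per-layer attention schedule implied by
# `sliding_window_pattern`; the function name is hypothetical.
def sliding_window_sizes(num_layers: int, sliding_window: int, pattern: int) -> list[int | None]:
    # Every `pattern`-th layer (1-indexed) gets full attention (None);
    # all other layers attend within a window of `sliding_window` tokens.
    return [None if (i + 1) % pattern == 0 else sliding_window for i in range(num_layers)]


# With pattern=6 and 12 layers, layers 6 and 12 use full attention:
print(sliding_window_sizes(num_layers=12, sliding_window=512, pattern=6))
# [512, 512, 512, 512, 512, None, 512, 512, 512, 512, 512, None]
```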

--- /dev/null
+++ lalamo-0.5.10/lalamo/model_import/decoder_configs/huggingface/lfm2.py
@@ -0,0 +1,174 @@
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    AttentionConfig,
+    DecoderConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    NormalizationConfig,
+    SeparableCausalConvConfig,
+    ShortConvConfig,
+    SiLU,
+    TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
+    UnscaledRoPEConfig,
+    UntiedEmbeddingConfig,
+    UpcastMode,
+)
+
+from .common import HuggingFaceLMConfig
+
+
+@dataclass(frozen=True)
+class HFLFM2Config(HuggingFaceLMConfig):
+    architectures: list[Literal["Lfm2ForCausalLM"]]
+    block_auto_adjust_ff_dim: Literal[False]
+    block_dim: int
+    block_ff_dim: int
+    block_ffn_dim_multiplier: float
+    block_mlp_init_scale: float
+    block_multiple_of: int
+    block_norm_eps: float
+    block_out_init_scale: float
+    block_use_swiglu: bool
+    block_use_xavier_init: bool
+    bos_token_id: int
+    conv_L_cache: int  # noqa: N815
+    conv_bias: int
+    conv_dim: int
+    conv_dim_out: int
+    conv_use_xavier_init: bool
+    eos_token_id: int
+    hidden_size: int
+    initializer_range: float
+    intermediate_size: int
+    layer_types: list[Literal["conv", "full_attention"]]
+    max_position_embeddings: int
+    model_type: Literal["lfm2"]
+    norm_eps: float
+    num_attention_heads: int
+    num_heads: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    pad_token_id: int
+    rope_theta: float
+    theta: float
+    tie_embedding: bool
+    torch_dtype: Literal["bfloat16"]
+    transformers_version: str
+    use_cache: bool
+    use_pos_enc: bool
+    vocab_size: int
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
+    ) -> DecoderConfig:
+        assert self.num_attention_heads == self.num_heads
+
+        if self.tie_embedding:
+            embedding_config = TiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+        else:
+            embedding_config = UntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+
+        rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.rope_theta,
+            max_sequence_length=context_length or self.max_position_embeddings,
+        )
+
+        linear_config = FullPrecisionLinearConfig(activation_precision)
+
+        block_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.block_norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
+        )
+
+        attention_config = AttentionConfig(
+            qkv_projection_config=linear_config,
+            out_projection_config=linear_config,
+            query_norm_config=block_norm_config,
+            key_norm_config=block_norm_config,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=self.hidden_size // self.num_heads,
+            is_causal=True,
+            scale=None,
+            sliding_window_size=None,
+            logit_soft_cap=None,
+            has_sinks=False,
+            has_qkv_biases=False,
+            has_out_biases=False,
+        )
+
+        short_conv_config = ShortConvConfig(
+            in_projection_config=linear_config,
+            conv_config=SeparableCausalConvConfig(activation_precision, has_biases=False),
+            out_projection_config=linear_config,
+            kernel_size=self.conv_L_cache,
+        )
+
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
+        )
+
+        layer_configs = [
+            TransformerLayerConfig(
+                pre_mixer_norm_config=block_norm_config,
+                mixer_config={"conv": short_conv_config, "full_attention": attention_config}[layer_type],
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=block_norm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            ) for layer_type in self.layer_types
+        ]
+
+        output_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
+        )
+
+        transformer_config = TransformerConfig(
+            global_rope_config=rope_config,
+            local_rope_config=None,
+            layer_configs=tuple(layer_configs),
+            output_norm_config=output_norm_config,
+            model_dim=self.hidden_size,
+            hidden_dim=self.intermediate_size,
+            context_length=context_length or self.max_position_embeddings,
+        )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )
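
The new HFLFM2Config builds one transformer layer per entry in layer_types, choosing the short-convolution mixer for "conv" layers and regular attention for "full_attention" layers via a dict lookup. A toy version of that dispatch, using made-up placeholder configs rather than lalamo's config classes:

```python
# Toy illustration of per-layer mixer selection from `layer_types`;
# ConvMixer and AttentionMixer are hypothetical stand-ins.
from dataclasses import dataclass


@dataclass(frozen=True)
class ConvMixer:
    kernel_size: int


@dataclass(frozen=True)
class AttentionMixer:
    num_heads: int


def build_layer_mixers(layer_types: list[str]) -> list[object]:
    mixer_by_type = {"conv": ConvMixer(kernel_size=3), "full_attention": AttentionMixer(num_heads=32)}
    return [mixer_by_type[layer_type] for layer_type in layer_types]


print(build_layer_mixers(["conv", "conv", "full_attention"]))
# [ConvMixer(kernel_size=3), ConvMixer(kernel_size=3), AttentionMixer(num_heads=32)]
```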

--- lalamo-0.5.8/lalamo/model_import/loaders/huggingface.py
+++ lalamo-0.5.10/lalamo/model_import/loaders/huggingface.py
@@ -8,17 +8,21 @@ from jaxtyping import Array, DTypeLike
 from lalamo.common import ParameterPath
 from lalamo.modules import (
     Attention,
+    AttentionConfig,
     Decoder,
     DenseMLP,
     FullPrecisionLinear,
     GroupQuantizedLinear,
     LinearBase,
     Mamba2,
+    Mamba2Config,
     MLXQuantizedLinear,
     MLXQuantizedTiedEmbedding,
     MLXSemiQuantizedUntiedEmbedding,
     Normalization,
     SeparableCausalConv,
+    ShortConv,
+    ShortConvConfig,
     TiedEmbedding,
     TransformerLayer,
     UntiedEmbedding,
@@ -300,7 +304,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     down_w = rearrange(down_w, "e o ib ie -> e o (ib ie)")
     down_b = weights_dict[experts_path / "down_proj_bias"]
     if down_b.ndim == 1:
-        down_b = jnp.broadcast_to(down_b, down_w.shape[:-1]
+        down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))

     down_projection = load_parameters(
         lambda m: (m.weights, m.biases),  # type: ignore
@@ -345,21 +349,42 @@ def load_attention(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> Attention:
+    if (path / "o_proj.weight") in weights_dict:
+        o_proj_name = "o_proj"
+    elif (path / "out_proj.weight") in weights_dict:
+        o_proj_name = "out_proj"
+    else:
+        raise NotImplementedError("Can't determine attention output projection name")
+
     qkv_projection = load_linear(
         module.qkv_projection,
         weights_dict,
         path,
         sublayers_to_fuse=["q_proj", "k_proj", "v_proj"],
     )
-    out_projection = load_linear(module.out_projection, weights_dict, path /
+    out_projection = load_linear(module.out_projection, weights_dict, path / o_proj_name)

     if module.query_norm is not None:
-
+        if (path / "q_norm.weight") in weights_dict:
+            q_norm_name = "q_norm"
+        elif (path / "q_layernorm.weight") in weights_dict:
+            q_norm_name = "q_layernorm"
+        else:
+            raise NotImplementedError("Can't determine attention query projection parameter name")
+
+        query_norm = load_rmsnorm(module.query_norm, weights_dict, path / q_norm_name)
     else:
         query_norm = None

     if module.key_norm is not None:
-
+        if (path / "k_norm.weight") in weights_dict:
+            k_norm_name = "k_norm"
+        elif (path / "k_layernorm.weight") in weights_dict:
+            k_norm_name = "k_layernorm"
+        else:
+            raise NotImplementedError("Can't determine attention key projection parameter name")
+
+        key_norm = load_rmsnorm(module.key_norm, weights_dict, path / k_norm_name)
     else:
         key_norm = None

@@ -382,7 +407,7 @@ def load_attention(
     )


-def
+def _load_conv(
     conv_module: SeparableCausalConv,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
@@ -390,6 +415,8 @@ def _load_mamba_conv(
     weight_path = path / "conv1d" / "weight"
     if weight_path not in weights_dict:
         weight_path = path / "conv_weight"
+    if weight_path not in weights_dict:
+        weight_path = path / "conv.weight"
     if weight_path not in weights_dict:
         weight_path = None

@@ -402,6 +429,8 @@ def _load_mamba_conv(
     bias_path = path / "conv1d" / "bias"
     if bias_path not in weights_dict:
         bias_path = path / "conv_bias"
+    if bias_path not in weights_dict:
+        bias_path = path / "conv.bias"
     if bias_path not in weights_dict:
         bias_path = None

@@ -424,7 +453,7 @@ def load_mamba2(
 ) -> Mamba2:
     in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
     out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
-    conv =
+    conv = _load_conv(module.conv, weights_dict, path)

     skip_connection_weight_path = path / "D"
     if skip_connection_weight_path in weights_dict:
@@ -451,6 +480,22 @@ def load_mamba2(
     )


+def load_short_conv(
+    module: ShortConv,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+) -> ShortConv:
+    in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
+    out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
+    conv = _load_conv(module.conv, weights_dict, path)
+
+    return load_parameters(
+        lambda m: (m.in_projection, m.out_projection, m.conv),
+        module,
+        (in_projection, out_projection, conv),
+    )
+
+
 def load_transformer_layer(
     module: TransformerLayer,
     weights_dict: Mapping[str, Array],
@@ -478,6 +523,8 @@ def load_transformer_layer(
         mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
     elif isinstance(module.mixer, Mamba2):
         mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
+    elif isinstance(module.mixer, ShortConv):
+        mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key)
     else:
         mixer = module.mixer

@@ -625,11 +672,12 @@ def load_huggingface_decoder(

     is_llamba_full_precision = any(key.startswith("backbone.") for key in weights_dict)
     is_llamba_mlx = any(key.startswith("embedding.encoder.") for key in weights_dict)
+    is_lfm2 = any(key.startswith("model.layers.0.operator_norm.weight") for key in weights_dict)
     if is_llamba_full_precision:
         decoder_path = base_path / "backbone"
         embedding_path = decoder_path / "embedding"
         pre_mixer_norm_key = "input_layernorm"
-        mixer_key = "mixer"
+        mixer_key = {Mamba2Config: "mixer"}
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -642,7 +690,7 @@ def load_huggingface_decoder(
         decoder_path = base_path / "model"
         embedding_path = base_path / "embedding.encoder"
         pre_mixer_norm_key = "norm"
-        mixer_key = "layer"
+        mixer_key = {Mamba2Config: "layer"}
         pre_mlp_norm_key = "norm"
         mlp_key = "layer"
         up_proj_key = "gate_proj"
@@ -651,11 +699,24 @@ def load_huggingface_decoder(
         alternating_layers = True
         norm_key = "norm"
         lm_head_path = base_path / "head.linear"
+    elif is_lfm2:
+        decoder_path = base_path / "model"
+        embedding_path = decoder_path / "embed_tokens"
+        pre_mixer_norm_key = "operator_norm"
+        mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}
+        pre_mlp_norm_key = "ffn_norm"
+        mlp_key = "feed_forward"
+        up_proj_key = "w3"
+        gate_proj_key = "w1"
+        down_proj_key = "w2"
+        alternating_layers = False
+        norm_key = "embedding_norm"
+        lm_head_path = base_path / "lm_head"
     else:
         decoder_path = base_path / "model"
         embedding_path = decoder_path / "embed_tokens"
         pre_mixer_norm_key = "input_layernorm"
-        mixer_key = "self_attn"
+        mixer_key = {AttentionConfig: "self_attn"}
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -687,7 +748,7 @@ def load_huggingface_decoder(
             weights_dict,
             decoder_path / "layers" / ((i * 2) if alternating_layers else i),
            decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-            mixer_key,
+            mixer_key[type(layer.config.mixer_config)],  # type: ignore
             mlp_key,
             pre_mixer_norm_key,
             pre_mlp_norm_key,
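
The loader changes above make one code path tolerant of several checkpoint naming schemes: it probes o_proj vs. out_proj, q_norm vs. q_layernorm, k_norm vs. k_layernorm, and conv1d.weight vs. conv_weight vs. conv.weight, and the per-layer mixer key is now looked up by mixer config type instead of being a single string. A generic sketch of the name-probing pattern, using plain string keys instead of lalamo's ParameterPath:

```python
# Minimal sketch of probing alternative parameter names in a flat weights dict;
# the helper and example keys are illustrative, not lalamo's loader API.
from collections.abc import Mapping
from typing import Any


def resolve_key(weights: Mapping[str, Any], prefix: str, candidates: list[str]) -> str:
    # Return the first candidate key that exists under `prefix`.
    for name in candidates:
        key = f"{prefix}.{name}"
        if key in weights:
            return key
    raise NotImplementedError(f"None of {candidates} found under {prefix!r}")


weights = {"model.layers.0.self_attn.out_proj.weight": object()}
print(resolve_key(weights, "model.layers.0.self_attn", ["o_proj.weight", "out_proj.weight"]))
# model.layers.0.self_attn.out_proj.weight
```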

--- lalamo-0.5.8/lalamo/model_import/model_specs/__init__.py
+++ lalamo-0.5.10/lalamo/model_import/model_specs/__init__.py
@@ -1,8 +1,10 @@
 from .common import FileSpec, ModelSpec, ModelType, UseCase, build_quantized_models
 from .deepseek import DEEPSEEK_MODELS
+from .essential_ai import RNJ_MODELS
 from .gemma import GEMMA_MODELS
 from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
+from .lfm2 import LFM2_MODELS
 from .llama import LLAMA_MODELS
 from .llamba import LLAMBA_MODELS
 from .mirai import MIRAI_CLASSIFIER_MODELS
@@ -24,6 +26,7 @@ __all__ = [


 ALL_MODEL_LISTS = [
+    LFM2_MODELS,
     LLAMA_MODELS,
     LLAMBA_MODELS,
     DEEPSEEK_MODELS,
@@ -36,6 +39,7 @@ ALL_MODEL_LISTS = [
     QWEN_MODELS,
     REKA_MODELS,
     MIRAI_CLASSIFIER_MODELS,
+    RNJ_MODELS,
 ]

 ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]

--- lalamo-0.5.8/lalamo/model_import/model_specs/common.py
+++ lalamo-0.5.10/lalamo/model_import/model_specs/common.py
@@ -56,6 +56,7 @@ class WeightsType(Enum):
             yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
+
             from lalamo.modules.torch_interop import torch_to_jax

             torch_weights = torch.load(filename, map_location="cpu", weights_only=True)

--- /dev/null
+++ lalamo-0.5.10/lalamo/model_import/model_specs/essential_ai.py
@@ -0,0 +1,17 @@
+from lalamo.model_import.decoder_configs.huggingface import HFGemma3TextConfig
+
+from .common import ModelSpec
+
+__all__ = ["RNJ_MODELS"]
+
+RNJ_MODELS = [
+    ModelSpec(
+        vendor="EssentialAI",
+        family="Rnj-1",
+        name="Rnj-1-Instruct",
+        size="8B",
+        quantization=None,
+        repo="EssentialAI/rnj-1-instruct",
+        config_type=HFGemma3TextConfig,
+    ),
+]

--- /dev/null
+++ lalamo-0.5.10/lalamo/model_import/model_specs/lfm2.py
@@ -0,0 +1,21 @@
+from lalamo.model_import.decoder_configs import HFLFM2Config
+
+from .common import ConfigMap, FileSpec, ModelSpec
+
+__all__ = ["LFM2_MODELS"]
+
+LFM2_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2",
+        name="LFM2-2.6B",
+        size="2.6B",
+        repo="LiquidAI/LFM2-2.6B",
+        config_type=HFLFM2Config,
+        quantization=None,
+        configs=ConfigMap(
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    ),
+]

--- lalamo-0.5.8/lalamo/modules/__init__.py
+++ lalamo-0.5.10/lalamo/modules/__init__.py
@@ -69,6 +69,9 @@ from .token_mixers import (
     Mamba2Config,
     SeparableCausalConv,
     SeparableCausalConvConfig,
+    ShortConv,
+    ShortConvConfig,
+    ShortConvStateLayer,
     State,
     StaticKVCacheLayer,
 )
@@ -136,6 +139,9 @@ __all__ = [
     "RoutingFunction",
     "SeparableCausalConv",
     "SeparableCausalConvConfig",
+    "ShortConv",
+    "ShortConvConfig",
+    "ShortConvStateLayer",
     "SiLU",
     "SoftmaxRouting",
     "State",

--- lalamo-0.5.8/lalamo/modules/token_mixers/__init__.py
+++ lalamo-0.5.10/lalamo/modules/token_mixers/__init__.py
@@ -3,9 +3,18 @@ from lalamo.modules.common import register_config_union
 from .attention import Attention, AttentionConfig, AttentionResult
 from .common import TokenMixerBase, TokenMixerResult
 from .mamba import Mamba2, Mamba2Config, Mamba2Result, SeparableCausalConv, SeparableCausalConvConfig
-from .
+from .short_conv import ShortConv, ShortConvConfig, ShortConvResult
+from .state import (
+    DynamicKVCacheLayer,
+    KVCacheLayer,
+    Mamba2StateLayer,
+    ShortConvStateLayer,
+    State,
+    StateLayerBase,
+    StaticKVCacheLayer,
+)

-TokenMixerConfig = AttentionConfig | Mamba2Config
+TokenMixerConfig = AttentionConfig | Mamba2Config | ShortConvConfig

 register_config_union(TokenMixerConfig)  # type: ignore (pyright bug)

@@ -21,6 +30,10 @@ __all__ = [
     "Mamba2StateLayer",
     "SeparableCausalConv",
     "SeparableCausalConvConfig",
+    "ShortConv",
+    "ShortConvConfig",
+    "ShortConvResult",
+    "ShortConvStateLayer",
     "State",
     "StateLayerBase",
     "StaticKVCacheLayer",
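
The new ShortConv token mixer (with its ShortConvStateLayer for incremental decoding) pairs input and output projections with a short depthwise causal convolution whose kernel length comes from conv_L_cache. As a rough, self-contained illustration of the core operation only, assuming nothing about lalamo's actual ShortConv implementation, here is a depthwise causal convolution over a token sequence in NumPy:

```python
# Rough sketch of a depthwise (per-channel) causal convolution, the core of a
# short-conv token mixer. Purely illustrative; not lalamo's ShortConv module.
import numpy as np


def depthwise_causal_conv(x: np.ndarray, kernels: np.ndarray) -> np.ndarray:
    """x: (seq_len, channels); kernels: (kernel_size, channels)."""
    kernel_size, channels = kernels.shape
    # Left-pad so position t only sees positions t - kernel_size + 1 .. t.
    padded = np.concatenate([np.zeros((kernel_size - 1, channels)), x], axis=0)
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        window = padded[t : t + kernel_size]       # (kernel_size, channels)
        out[t] = np.sum(window * kernels, axis=0)  # each channel convolved independently
    return out


x = np.random.default_rng(0).normal(size=(6, 4))
kernels = np.ones((3, 4)) / 3.0  # a 3-tap causal moving average per channel
print(depthwise_causal_conv(x, kernels).shape)  # (6, 4)
```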