lalamo 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/model_import/decoder_configs/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/lfm2.py +174 -0
- lalamo/model_import/loaders/huggingface.py +70 -9
- lalamo/model_import/model_specs/__init__.py +2 -0
- lalamo/model_import/model_specs/common.py +1 -0
- lalamo/model_import/model_specs/lfm2.py +21 -0
- lalamo/modules/__init__.py +6 -0
- lalamo/modules/token_mixers/__init__.py +15 -2
- lalamo/modules/token_mixers/common.py +1 -1
- lalamo/modules/token_mixers/mamba.py +2 -2
- lalamo/modules/token_mixers/short_conv.py +168 -0
- lalamo/modules/token_mixers/state/__init__.py +2 -0
- lalamo/modules/token_mixers/state/short_conv_state.py +33 -0
- lalamo/modules/transformer.py +18 -6
- lalamo/modules/transformer_layer.py +1 -1
- {lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/METADATA +1 -1
- {lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/RECORD +23 -19
- {lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/WHEEL +0 -0
- {lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from .huggingface import (
     HFGemma3Config,
     HFGemma3TextConfig,
     HFGPTOssConfig,
+    HFLFM2Config,
     HFLlamaConfig,
     HFLlambaConfig,
     HFMistralConfig,
@@ -22,6 +23,7 @@ __all__ = [
    "HFGemma2Config",
    "HFGemma3Config",
    "HFGemma3TextConfig",
+   "HFLFM2Config",
    "HFLlamaConfig",
    "HFLlambaConfig",
    "HFMistralConfig",

lalamo/model_import/decoder_configs/huggingface/__init__.py
CHANGED

@@ -2,6 +2,7 @@ from .common import HuggingFaceLMConfig
 from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
 from .gpt_oss import HFGPTOssConfig
+from .lfm2 import HFLFM2Config
 from .llama import HFLlamaConfig
 from .llamba import HFLlambaConfig
 from .mistral import HFMistralConfig
@@ -14,6 +15,7 @@ __all__ = [
    "HFGemma2Config",
    "HFGemma3Config",
    "HFGemma3TextConfig",
+   "HFLFM2Config",
    "HFLlamaConfig",
    "HFLlambaConfig",
    "HFMistralConfig",

lalamo/model_import/decoder_configs/huggingface/lfm2.py
ADDED

@@ -0,0 +1,174 @@
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    AttentionConfig,
+    DecoderConfig,
+    DenseMLPConfig,
+    FullPrecisionLinearConfig,
+    NormalizationConfig,
+    SeparableCausalConvConfig,
+    ShortConvConfig,
+    SiLU,
+    TiedEmbeddingConfig,
+    TransformerConfig,
+    TransformerLayerConfig,
+    UnscaledRoPEConfig,
+    UntiedEmbeddingConfig,
+    UpcastMode,
+)
+
+from .common import HuggingFaceLMConfig
+
+
+@dataclass(frozen=True)
+class HFLFM2Config(HuggingFaceLMConfig):
+    architectures: list[Literal["Lfm2ForCausalLM"]]
+    block_auto_adjust_ff_dim: Literal[False]
+    block_dim: int
+    block_ff_dim: int
+    block_ffn_dim_multiplier: float
+    block_mlp_init_scale: float
+    block_multiple_of: int
+    block_norm_eps: float
+    block_out_init_scale: float
+    block_use_swiglu: bool
+    block_use_xavier_init: bool
+    bos_token_id: int
+    conv_L_cache: int  # noqa: N815
+    conv_bias: int
+    conv_dim: int
+    conv_dim_out: int
+    conv_use_xavier_init: bool
+    eos_token_id: int
+    hidden_size: int
+    initializer_range: float
+    intermediate_size: int
+    layer_types: list[Literal["conv", "full_attention"]]
+    max_position_embeddings: int
+    model_type: Literal["lfm2"]
+    norm_eps: float
+    num_attention_heads: int
+    num_heads: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    pad_token_id: int
+    rope_theta: float
+    theta: float
+    tie_embedding: bool
+    torch_dtype: Literal["bfloat16"]
+    transformers_version: str
+    use_cache: bool
+    use_pos_enc: bool
+    vocab_size: int
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
+    ) -> DecoderConfig:
+        assert self.num_attention_heads == self.num_heads
+
+        if self.tie_embedding:
+            embedding_config = TiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+        else:
+            embedding_config = UntiedEmbeddingConfig(
+                input_scale=None,
+                logit_soft_cap=None,
+                precision=activation_precision,
+            )
+
+        rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.rope_theta,
+            max_sequence_length=context_length or self.max_position_embeddings,
+        )
+
+        linear_config = FullPrecisionLinearConfig(activation_precision)
+
+        block_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.block_norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
+        )
+
+        attention_config = AttentionConfig(
+            qkv_projection_config=linear_config,
+            out_projection_config=linear_config,
+            query_norm_config=block_norm_config,
+            key_norm_config=block_norm_config,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=self.hidden_size // self.num_heads,
+            is_causal=True,
+            scale=None,
+            sliding_window_size=None,
+            logit_soft_cap=None,
+            has_sinks=False,
+            has_qkv_biases=False,
+            has_out_biases=False,
+        )
+
+        short_conv_config = ShortConvConfig(
+            in_projection_config=linear_config,
+            conv_config=SeparableCausalConvConfig(activation_precision, has_biases=False),
+            out_projection_config=linear_config,
+            kernel_size=self.conv_L_cache,
+        )
+
+        mlp_config = DenseMLPConfig(
+            linear_config=linear_config,
+            activation=SiLU(),
+            has_up_biases=False,
+            has_down_biases=False,
+            up_clipping=None,
+            gate_clipping=None,
+        )
+
+        layer_configs = [
+            TransformerLayerConfig(
+                pre_mixer_norm_config=block_norm_config,
+                mixer_config={"conv": short_conv_config, "full_attention": attention_config}[layer_type],
+                post_mixer_norm_config=None,
+                pre_mlp_norm_config=block_norm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=None,
+            ) for layer_type in self.layer_types
+        ]
+
+        output_norm_config = NormalizationConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+            subtract_mean=False,
+        )
+
+        transformer_config = TransformerConfig(
+            global_rope_config=rope_config,
+            local_rope_config=None,
+            layer_configs=tuple(layer_configs),
+            output_norm_config=output_norm_config,
+            model_dim=self.hidden_size,
+            hidden_dim=self.intermediate_size,
+            context_length=context_length or self.max_position_embeddings,
+        )
+
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            transformer_config=transformer_config,
+            vocab_size=self.vocab_size,
+        )

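For orientation only, not part of the diff: a hedged sketch of how the new HFLFM2Config is meant to be driven, assuming an LFM2 config.json whose keys match the dataclass fields above; file name and dtype choices are assumptions. The layer_types list is what decides, per layer, whether the mixer becomes a ShortConvConfig ("conv") or an AttentionConfig ("full_attention").

    # Hypothetical usage sketch.
    import json

    import jax.numpy as jnp

    with open("config.json") as f:
        config = HFLFM2Config(**json.load(f))  # assumes the JSON keys match the fields exactly

    decoder_config = config.to_decoder_config(
        context_length=None,                   # falls back to max_position_embeddings
        activation_precision=jnp.bfloat16,
        accumulation_precision=jnp.float32,
        metadata_dict={},
    )
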
lalamo/model_import/loaders/huggingface.py
CHANGED

@@ -8,17 +8,21 @@ from jaxtyping import Array, DTypeLike
 from lalamo.common import ParameterPath
 from lalamo.modules import (
     Attention,
+    AttentionConfig,
     Decoder,
     DenseMLP,
     FullPrecisionLinear,
     GroupQuantizedLinear,
     LinearBase,
     Mamba2,
+    Mamba2Config,
     MLXQuantizedLinear,
     MLXQuantizedTiedEmbedding,
     MLXSemiQuantizedUntiedEmbedding,
     Normalization,
     SeparableCausalConv,
+    ShortConv,
+    ShortConvConfig,
     TiedEmbedding,
     TransformerLayer,
     UntiedEmbedding,
@@ -345,21 +349,42 @@ def load_attention(
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
 ) -> Attention:
+    if (path / "o_proj.weight") in weights_dict:
+        o_proj_name = "o_proj"
+    elif (path / "out_proj.weight") in weights_dict:
+        o_proj_name = "out_proj"
+    else:
+        raise NotImplementedError("Can't determine attention output projection name")
+
     qkv_projection = load_linear(
         module.qkv_projection,
         weights_dict,
         path,
         sublayers_to_fuse=["q_proj", "k_proj", "v_proj"],
     )
-    out_projection = load_linear(module.out_projection, weights_dict, path /
+    out_projection = load_linear(module.out_projection, weights_dict, path / o_proj_name)

     if module.query_norm is not None:
-
+        if (path / "q_norm.weight") in weights_dict:
+            q_norm_name = "q_norm"
+        elif (path / "q_layernorm.weight") in weights_dict:
+            q_norm_name = "q_layernorm"
+        else:
+            raise NotImplementedError("Can't determine attention query projection parameter name")
+
+        query_norm = load_rmsnorm(module.query_norm, weights_dict, path / q_norm_name)
     else:
         query_norm = None

     if module.key_norm is not None:
-
+        if (path / "k_norm.weight") in weights_dict:
+            k_norm_name = "k_norm"
+        elif (path / "k_layernorm.weight") in weights_dict:
+            k_norm_name = "k_layernorm"
+        else:
+            raise NotImplementedError("Can't determine attention key projection parameter name")
+
+        key_norm = load_rmsnorm(module.key_norm, weights_dict, path / k_norm_name)
     else:
         key_norm = None

@@ -382,7 +407,7 @@ def load_attention(
     )


-def _load_mamba_conv(
+def _load_conv(
     conv_module: SeparableCausalConv,
     weights_dict: Mapping[str, Array],
     path: ParameterPath,
@@ -390,6 +415,8 @@ def _load_mamba_conv(
     weight_path = path / "conv1d" / "weight"
     if weight_path not in weights_dict:
         weight_path = path / "conv_weight"
+    if weight_path not in weights_dict:
+        weight_path = path / "conv.weight"
     if weight_path not in weights_dict:
         weight_path = None

@@ -402,6 +429,8 @@ def _load_mamba_conv(
     bias_path = path / "conv1d" / "bias"
     if bias_path not in weights_dict:
         bias_path = path / "conv_bias"
+    if bias_path not in weights_dict:
+        bias_path = path / "conv.bias"
     if bias_path not in weights_dict:
         bias_path = None

@@ -424,7 +453,7 @@ def load_mamba2(
 ) -> Mamba2:
     in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
     out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
-    conv =
+    conv = _load_conv(module.conv, weights_dict, path)

     skip_connection_weight_path = path / "D"
     if skip_connection_weight_path in weights_dict:
@@ -451,6 +480,22 @@ def load_mamba2(
     )


+def load_short_conv(
+    module: ShortConv,
+    weights_dict: Mapping[str, Array],
+    path: ParameterPath,
+) -> ShortConv:
+    in_projection = load_linear(module.in_projection, weights_dict, path / "in_proj")
+    out_projection = load_linear(module.out_projection, weights_dict, path / "out_proj")
+    conv = _load_conv(module.conv, weights_dict, path)
+
+    return load_parameters(
+        lambda m: (m.in_projection, m.out_projection, m.conv),
+        module,
+        (in_projection, out_projection, conv),
+    )
+
+
 def load_transformer_layer(
     module: TransformerLayer,
     weights_dict: Mapping[str, Array],
@@ -478,6 +523,8 @@ def load_transformer_layer(
         mixer = load_attention(module.mixer, weights_dict, mixer_path / mixer_key)
     elif isinstance(module.mixer, Mamba2):
         mixer = load_mamba2(module.mixer, weights_dict, mixer_path / mixer_key)
+    elif isinstance(module.mixer, ShortConv):
+        mixer = load_short_conv(module.mixer, weights_dict, mixer_path / mixer_key)
     else:
         mixer = module.mixer

@@ -625,11 +672,12 @@ def load_huggingface_decoder(

     is_llamba_full_precision = any(key.startswith("backbone.") for key in weights_dict)
     is_llamba_mlx = any(key.startswith("embedding.encoder.") for key in weights_dict)
+    is_lfm2 = any(key.startswith("model.layers.0.operator_norm.weight") for key in weights_dict)
     if is_llamba_full_precision:
         decoder_path = base_path / "backbone"
         embedding_path = decoder_path / "embedding"
         pre_mixer_norm_key = "input_layernorm"
-        mixer_key = "mixer"
+        mixer_key = {Mamba2Config: "mixer"}
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -642,7 +690,7 @@ def load_huggingface_decoder(
         decoder_path = base_path / "model"
         embedding_path = base_path / "embedding.encoder"
         pre_mixer_norm_key = "norm"
-        mixer_key = "layer"
+        mixer_key = {Mamba2Config: "layer"}
         pre_mlp_norm_key = "norm"
         mlp_key = "layer"
         up_proj_key = "gate_proj"
@@ -651,11 +699,24 @@ def load_huggingface_decoder(
         alternating_layers = True
         norm_key = "norm"
         lm_head_path = base_path / "head.linear"
+    elif is_lfm2:
+        decoder_path = base_path / "model"
+        embedding_path = decoder_path / "embed_tokens"
+        pre_mixer_norm_key = "operator_norm"
+        mixer_key = {ShortConvConfig: "conv", AttentionConfig: "self_attn"}
+        pre_mlp_norm_key = "ffn_norm"
+        mlp_key = "feed_forward"
+        up_proj_key = "w3"
+        gate_proj_key = "w1"
+        down_proj_key = "w2"
+        alternating_layers = False
+        norm_key = "embedding_norm"
+        lm_head_path = base_path / "lm_head"
     else:
         decoder_path = base_path / "model"
         embedding_path = decoder_path / "embed_tokens"
         pre_mixer_norm_key = "input_layernorm"
-        mixer_key = "self_attn"
+        mixer_key = {AttentionConfig: "self_attn"}
         pre_mlp_norm_key = "post_attention_layernorm"
         mlp_key = "mlp"
         up_proj_key = "up_proj"
@@ -687,7 +748,7 @@ def load_huggingface_decoder(
             weights_dict,
             decoder_path / "layers" / ((i * 2) if alternating_layers else i),
             decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-            mixer_key,
+            mixer_key[type(layer.config.mixer_config)],  # type: ignore
             mlp_key,
             pre_mixer_norm_key,
             pre_mlp_norm_key,

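A note on the loader change above, with a small self-contained sketch: mixer_key used to be a single string, but LFM2 interleaves conv and attention layers under different Hugging Face parameter prefixes, so it is now a mapping from mixer config type to prefix, resolved per layer via mixer_key[type(layer.config.mixer_config)]. The classes below are stand-ins, not the real lalamo config types.

    # Hypothetical stand-ins for AttentionConfig / ShortConvConfig from lalamo.modules.
    class FakeAttentionConfig: ...
    class FakeShortConvConfig: ...

    mixer_key = {FakeShortConvConfig: "conv", FakeAttentionConfig: "self_attn"}

    def hf_prefix(mixer_config: object) -> str:
        # Mirrors mixer_key[type(layer.config.mixer_config)] in the loader above.
        return mixer_key[type(mixer_config)]

    assert hf_prefix(FakeShortConvConfig()) == "conv"
    assert hf_prefix(FakeAttentionConfig()) == "self_attn"
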
lalamo/model_import/model_specs/__init__.py
CHANGED

@@ -4,6 +4,7 @@ from .essential_ai import RNJ_MODELS
 from .gemma import GEMMA_MODELS
 from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
+from .lfm2 import LFM2_MODELS
 from .llama import LLAMA_MODELS
 from .llamba import LLAMBA_MODELS
 from .mirai import MIRAI_CLASSIFIER_MODELS
@@ -25,6 +26,7 @@ __all__ = [


 ALL_MODEL_LISTS = [
+    LFM2_MODELS,
     LLAMA_MODELS,
     LLAMBA_MODELS,
     DEEPSEEK_MODELS,

lalamo/model_import/model_specs/common.py
CHANGED

@@ -56,6 +56,7 @@ class WeightsType(Enum):
         yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
     else:
         import torch
+
         from lalamo.modules.torch_interop import torch_to_jax

         torch_weights = torch.load(filename, map_location="cpu", weights_only=True)

lalamo/model_import/model_specs/lfm2.py
ADDED

@@ -0,0 +1,21 @@
+from lalamo.model_import.decoder_configs import HFLFM2Config
+
+from .common import ConfigMap, FileSpec, ModelSpec
+
+__all__ = ["LFM2_MODELS"]
+
+LFM2_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2",
+        name="LFM2-2.6B",
+        size="2.6B",
+        repo="LiquidAI/LFM2-2.6B",
+        config_type=HFLFM2Config,
+        quantization=None,
+        configs=ConfigMap(
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    ),
+]

lalamo/modules/__init__.py
CHANGED
@@ -69,6 +69,9 @@ from .token_mixers import (
     Mamba2Config,
     SeparableCausalConv,
     SeparableCausalConvConfig,
+    ShortConv,
+    ShortConvConfig,
+    ShortConvStateLayer,
     State,
     StaticKVCacheLayer,
 )
@@ -136,6 +139,9 @@ __all__ = [
    "RoutingFunction",
    "SeparableCausalConv",
    "SeparableCausalConvConfig",
+   "ShortConv",
+   "ShortConvConfig",
+   "ShortConvStateLayer",
    "SiLU",
    "SoftmaxRouting",
    "State",

lalamo/modules/token_mixers/__init__.py
CHANGED

@@ -3,9 +3,18 @@ from lalamo.modules.common import register_config_union
 from .attention import Attention, AttentionConfig, AttentionResult
 from .common import TokenMixerBase, TokenMixerResult
 from .mamba import Mamba2, Mamba2Config, Mamba2Result, SeparableCausalConv, SeparableCausalConvConfig
-from .
+from .short_conv import ShortConv, ShortConvConfig, ShortConvResult
+from .state import (
+    DynamicKVCacheLayer,
+    KVCacheLayer,
+    Mamba2StateLayer,
+    ShortConvStateLayer,
+    State,
+    StateLayerBase,
+    StaticKVCacheLayer,
+)

-TokenMixerConfig = AttentionConfig | Mamba2Config
+TokenMixerConfig = AttentionConfig | Mamba2Config | ShortConvConfig

 register_config_union(TokenMixerConfig)  # type: ignore (pyright bug)

@@ -21,6 +30,10 @@ __all__ = [
    "Mamba2StateLayer",
    "SeparableCausalConv",
    "SeparableCausalConvConfig",
+   "ShortConv",
+   "ShortConvConfig",
+   "ShortConvResult",
+   "ShortConvStateLayer",
    "State",
    "StateLayerBase",
    "StaticKVCacheLayer",

lalamo/modules/token_mixers/short_conv.py
ADDED

@@ -0,0 +1,168 @@
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import Self
+
+import equinox as eqx
+from jax import vmap
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.common import ParameterTree
+from lalamo.modules.common import PositionalEmbeddingSelector
+from lalamo.modules.linear import LinearBase, LinearConfig
+from lalamo.modules.rope import PositionalEmbeddings
+
+from .common import TokenMixerBase, TokenMixerConfigBase, TokenMixerResult
+from .mamba import SeparableCausalConv, SeparableCausalConvConfig
+from .state import ShortConvStateLayer
+
+__all__ = [
+    "ShortConv",
+    "ShortConvConfig",
+    "ShortConvResult",
+]
+
+
+ShortConvResult = TokenMixerResult[ShortConvStateLayer]
+
+
+@dataclass(frozen=True)
+class ShortConvConfig(TokenMixerConfigBase):
+    in_projection_config: LinearConfig
+    conv_config: SeparableCausalConvConfig
+    out_projection_config: LinearConfig
+
+    kernel_size: int
+
+    @property
+    def rope_dim(self) -> None:
+        return None
+
+    def random_init(
+        self,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "ShortConv":
+        in_projection = self.in_projection_config.random_init(
+            input_dim=model_dim,
+            output_dims=(model_dim,)*3,
+            has_biases=False,
+            key=key,
+        )
+
+        conv = self.conv_config.random_init(model_dim, self.kernel_size, key=key)
+
+        out_projection = self.out_projection_config.random_init(
+            input_dim=model_dim,
+            output_dims=(model_dim,),
+            has_biases=False,
+            key=key,
+        )
+
+        return ShortConv(
+            self,
+            in_projection=in_projection,
+            conv=conv,
+            out_projection=out_projection,
+        )
+
+    def empty(
+        self,
+        model_dim: int,
+    ) -> "ShortConv":
+        in_projection = self.in_projection_config.empty(
+            input_dim=model_dim,
+            output_dims=(model_dim,)*3,
+            has_biases=False,
+        )
+
+        conv = self.conv_config.empty(model_dim, self.kernel_size)
+
+        out_projection = self.out_projection_config.empty(
+            input_dim=model_dim,
+            output_dims=(model_dim,),
+            has_biases=False,
+        )
+
+        return ShortConv(
+            self,
+            in_projection=in_projection,
+            conv=conv,
+            out_projection=out_projection,
+        )
+
+
+class ShortConv(TokenMixerBase[ShortConvConfig, ShortConvStateLayer]):
+    in_projection: LinearBase
+    conv: SeparableCausalConv
+    out_projection: LinearBase
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.in_projection.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        return self.in_projection.input_dim
+
+    @property
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector:
+        return PositionalEmbeddingSelector.NONE
+
+    @eqx.filter_jit
+    def __call__(
+        self,
+        inputs: Float[Array, "suffix_tokens channels"],
+        positional_embeddings: PositionalEmbeddings | None,
+        state: ShortConvStateLayer | None = None,
+        return_updated_state: bool = False,
+        length_without_padding: Int[Array, ""] | int | None = None,  # noqa: ARG002
+    ) -> TokenMixerResult[ShortConvStateLayer]:
+        if positional_embeddings is not None:
+            raise ValueError("Positional embeddings are not supported for ShortConv.")
+
+        pre_conv_gate, post_conv_gate, x = vmap(self.in_projection)(inputs)
+
+        prev_conv_state = state.conv_state if state is not None else None
+        conv_output = self.conv(x * pre_conv_gate, prev_conv_state, return_updated_state)
+
+        (outputs,) = vmap(self.out_projection)(conv_output.outputs * post_conv_gate)
+        updated_conv_state = conv_output.state
+
+        if return_updated_state:
+            assert updated_conv_state is not None
+            updated_state = ShortConvStateLayer(updated_conv_state)
+        else:
+            updated_state = None
+
+        return TokenMixerResult(outputs, updated_state)
+
+    def init_static_state(self, capacity: int) -> ShortConvStateLayer:  # noqa: ARG002
+        return ShortConvStateLayer.init(
+            self.config.kernel_size,
+            self.in_projection.input_dim,
+            self.activation_precision,
+        )
+
+    def export_weights(self) -> ParameterTree:
+        return {
+            "in_projection": self.in_projection.export_weights(),
+            "conv": self.conv.export_weights(),
+            "out_projection": self.out_projection.export_weights(),
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["in_projection"], Mapping)
+        assert isinstance(weights["conv"], Mapping)
+        assert isinstance(weights["out_projection"], Mapping)
+
+        return replace(
+            self,
+            in_projection=self.in_projection.import_weights(weights["in_projection"]),
+            conv=self.conv.import_weights(weights["conv"]),
+            out_projection=self.out_projection.import_weights(weights["out_projection"]),
+        )

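For illustration only: a minimal sketch of constructing and calling the new ShortConv mixer, with made-up sizes and dtypes. As in __call__ above, the input projection yields three streams (pre-conv gate, post-conv gate, conv input); the causal depthwise conv runs over the gated input, and its output is gated again before the output projection.

    # Hypothetical example; the config classes are the ones exported by lalamo.modules.
    import jax
    import jax.numpy as jnp

    config = ShortConvConfig(
        in_projection_config=FullPrecisionLinearConfig(jnp.bfloat16),
        conv_config=SeparableCausalConvConfig(jnp.bfloat16, has_biases=False),
        out_projection_config=FullPrecisionLinearConfig(jnp.bfloat16),
        kernel_size=3,
    )
    mixer = config.random_init(model_dim=1024, key=jax.random.PRNGKey(0))

    tokens = jnp.zeros((16, 1024), dtype=jnp.bfloat16)  # (suffix_tokens, channels)
    result = mixer(tokens, positional_embeddings=None, return_updated_state=True)
    # result carries the mixed outputs plus an updated ShortConvStateLayer.
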
lalamo/modules/token_mixers/state/__init__.py
CHANGED

@@ -1,11 +1,13 @@
 from .common import State, StateLayerBase
 from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
 from .mamba_state import Mamba2StateLayer
+from .short_conv_state import ShortConvStateLayer

 __all__ = [
    "DynamicKVCacheLayer",
    "KVCacheLayer",
    "Mamba2StateLayer",
+   "ShortConvStateLayer",
    "State",
    "StateLayerBase",
    "StaticKVCacheLayer",

lalamo/modules/token_mixers/state/short_conv_state.py
ADDED

@@ -0,0 +1,33 @@
+from typing import Self
+
+import jax.numpy as jnp
+from jaxtyping import Array, DTypeLike, Float
+
+from lalamo.common import ParameterTree
+
+from .common import StateLayerBase
+
+__all__ = ["ShortConvStateLayer"]
+
+
+class ShortConvStateLayer(StateLayerBase):
+    conv_state: Float[Array, "*batch tokens conv_channels"]
+
+    def __post_init__(self) -> None:
+        if self.conv_state.ndim not in (2, 3):
+            raise ValueError(
+                f"Conv state must have 2 or 3 dimensions: [batch], tokens, conv_channels,"
+                f" got shape {self.conv_state.shape}",
+            )
+
+    @classmethod
+    def init(
+        cls,
+        kernel_size: int,
+        model_dim: int,
+        dtype: DTypeLike,
+    ) -> Self:
+        return cls(conv_state=jnp.zeros((kernel_size - 1, model_dim), dtype=dtype))
+
+    def export(self) -> ParameterTree:
+        return dict(conv_state=self.conv_state)

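For illustration only: the per-layer state is just the tail of the token stream that the causal conv still needs, i.e. the last kernel_size - 1 input vectors, initialized to zeros.

    # Hypothetical example with made-up sizes.
    import jax.numpy as jnp

    state = ShortConvStateLayer.init(kernel_size=3, model_dim=1024, dtype=jnp.bfloat16)
    assert state.conv_state.shape == (2, 1024)  # kernel_size - 1 rows of model_dim channels
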
lalamo/modules/transformer.py
CHANGED
@@ -65,17 +65,23 @@ class TransformerConfig:
     context_length: int

     def random_init(self, *, key: PRNGKeyArray) -> "Transformer":
-
+        rope_dims = (layer.rope_dim for layer in self.layer_configs if layer.rope_dim is not None)
+        rope_dim = next(rope_dims, None)
+        assert all(d == rope_dim for d in rope_dims)

         if self.global_rope_config:
+            assert rope_dim is not None
+
             global_rope = self.global_rope_config.init(
-                head_dim=
+                head_dim=rope_dim,
                 num_timesteps=self.context_length,
             )
         else:
             global_rope = None

         if self.local_rope_config:
+            assert rope_dim is not None
+
             max_sliding_window_size = max(
                 layer_config.mixer_config.sliding_window_size or 0
                 for layer_config in self.layer_configs
@@ -83,7 +89,7 @@ class TransformerConfig:
             )

             local_rope = self.local_rope_config.init(
-                head_dim=
+                head_dim=rope_dim,
                 num_timesteps=max(max_sliding_window_size, self.context_length),
             )
         else:
@@ -109,19 +115,25 @@ class TransformerConfig:
         )

     def empty(self) -> "Transformer":
-
+        rope_dims = (layer.rope_dim for layer in self.layer_configs if layer.rope_dim is not None)
+        rope_dim = next(rope_dims, None)
+        assert all(d == rope_dim for d in rope_dims)

         if self.global_rope_config:
+            assert rope_dim is not None
+
             global_rope = self.global_rope_config.init(
-                head_dim=
+                head_dim=rope_dim,
                 num_timesteps=self.context_length,
             )
         else:
             global_rope = None

         if self.local_rope_config:
+            assert rope_dim is not None
+
             local_rope = self.local_rope_config.init(
-                head_dim=
+                head_dim=rope_dim,
                 num_timesteps=self.context_length,
             )
         else:

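A note on the transformer change above: with ShortConv layers in the stack, the RoPE head dimension can no longer be read off an arbitrary layer, since conv layers report rope_dim as None. Both random_init and empty now take the first non-None rope_dim, assert that every layer defining one agrees, and only require it when a RoPE config is actually present. A tiny self-contained illustration of that selection logic:

    # Hypothetical values: conv layers contribute None, attention layers a head dim.
    layer_rope_dims = [None, 64, None, 64]
    dims = (d for d in layer_rope_dims if d is not None)
    rope_dim = next(dims, None)             # 64 here; None if every layer is a conv layer
    assert all(d == rope_dim for d in dims)  # remaining attention layers must agree
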
{lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/RECORD
CHANGED

@@ -1,4 +1,4 @@
-lalamo/__init__.py,sha256=
+lalamo/__init__.py,sha256=sCPww-cg0OE8syJQqxdBI7CV5Mpwxj64H0FNbWdHfO4,815
 lalamo/common.py,sha256=5NUFD26yQgOnEEk3LaQnce8n-VwJxILkEpFesHZhtQU,3820
 lalamo/main.py,sha256=GgUT7lT48-XQuAEH7qzsDKG8Lx9iBf-sYBIRhZL9q7E,23978
 lalamo/message_processor.py,sha256=bSUAQg7CemLTnBV4LtPxJBicAalruDCA-JXjkTYPZ8U,5797
@@ -14,14 +14,15 @@ lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vA
 lalamo/model_import/common.py,sha256=wvyGD-iLut_Pm3HjDMI05upqdtCW3HWeoeB0YmiFeqk,12419
 lalamo/model_import/huggingface_generation_config.py,sha256=mot6VQ6ezCtEhN6VjhnvaU-nR5P5T2BuBUgpFNnWJxU,1495
 lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
-lalamo/model_import/decoder_configs/__init__.py,sha256=
+lalamo/model_import/decoder_configs/__init__.py,sha256=YvlSsJqNEQPCNKcUzCw0MLjt8H3vcfjc4sz1OK7qdIQ,679
 lalamo/model_import/decoder_configs/common.py,sha256=L8PCgF5fIt3RqPlmLiJpBzDguKk9iTjk4XSItxwVG4c,3260
 lalamo/model_import/decoder_configs/executorch.py,sha256=fTEG_j-7d8riR3Fu_H5tHDjOTrWevfyw7QbWF1mUdOQ,5924
-lalamo/model_import/decoder_configs/huggingface/__init__.py,sha256=
+lalamo/model_import/decoder_configs/huggingface/__init__.py,sha256=AboZJgZxOuIigPShskj-FqBkBqwlJZoKHP0RDqx-MyY,696
 lalamo/model_import/decoder_configs/huggingface/common.py,sha256=YYIDEQy8x7lqL2qtxUHrNqfjZEiizBZ_26sTqOzjRtQ,3792
 lalamo/model_import/decoder_configs/huggingface/gemma2.py,sha256=g8LH_GlSNyL04WWi596zI0rWsD3ahnfNjDk-9zZNcDE,4759
 lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=aSZ0TtpgDYA10rHi8eD0C_Jsn48siM_HXqfZ4O7nh94,8372
 lalamo/model_import/decoder_configs/huggingface/gpt_oss.py,sha256=MBCoPbuWyzbJiBRtHOtpaPHJjQ1UVCAYcVrfIejTnlQ,7446
+lalamo/model_import/decoder_configs/huggingface/lfm2.py,sha256=Esjg9VsIKTE9B9Vu6DHb-VZxSdqxLRgbkyUwpjnmKhc,5510
 lalamo/model_import/decoder_configs/huggingface/llama.py,sha256=UPeQiz2Dix8YaZYRxn9z44OZJ6c4xBQmcUZcM0Ymvh4,6934
 lalamo/model_import/decoder_configs/huggingface/llamba.py,sha256=ANB-vQK8U-zVFubZSTDXXt2S70T5SVOGzf7eOVvPzIQ,5773
 lalamo/model_import/decoder_configs/huggingface/mistral.py,sha256=MDGC0ivzJuUpOC11n8vFdcVzqccUyaRw_hkL74mVlAg,4599
@@ -31,15 +32,16 @@ lalamo/model_import/decoder_configs/huggingface/qwen3.py,sha256=lySVO-TvusAYUjDn
 lalamo/model_import/loaders/__init__.py,sha256=3THc1wQ4EPBzQkL_4EaKCa7Ev5Z7oczcvc4AHy9v5EI,228
 lalamo/model_import/loaders/common.py,sha256=kkugV-bMQlN1zvGHoj3uc7z0FbXKoMtXEBTvyu4KxK4,1844
 lalamo/model_import/loaders/executorch.py,sha256=t2Ey_mBMNC8bTSTdYWjuGXdPTRoohFlYrqtWyNkBU_8,9219
-lalamo/model_import/loaders/huggingface.py,sha256=
+lalamo/model_import/loaders/huggingface.py,sha256=sErBtGxODzqUkn-hJlzhCNhWmWqTeH4BneeQ8cqDhZo,32283
 lalamo/model_import/loaders/utils.py,sha256=eiX3WKFRrAfBY-dugodscNInl5o5w3KmVcgma4atpGY,2456
-lalamo/model_import/model_specs/__init__.py,sha256=
-lalamo/model_import/model_specs/common.py,sha256=
+lalamo/model_import/model_specs/__init__.py,sha256=JISqwJkloQkGD2jvi1MakNEWapIwlNXXVi5giZyXB74,1275
+lalamo/model_import/model_specs/common.py,sha256=RLySCIkmGiA1IVZgLeemssMBMo4hMYMpmBjV0cRwBb4,6586
 lalamo/model_import/model_specs/deepseek.py,sha256=Umef93_ZBuq93yYsejIRNwj3udoln1gHfrv3SK5jyMo,417
 lalamo/model_import/model_specs/essential_ai.py,sha256=xbHcwRpAWhR9gOgypVzcgunFspoUEk3iNsw-46CVR4o,390
 lalamo/model_import/model_specs/gemma.py,sha256=irWgylL-pc7y3Gn5DK3fjKoCT9kJWH3B7mTa-1Gmxqc,1306
 lalamo/model_import/model_specs/gpt_oss.py,sha256=PLo0QGrXKdX61ReTRdyOaP_EH3Dmj5lp3fpJjZRwRVA,542
 lalamo/model_import/model_specs/huggingface.py,sha256=TEkU8y95_hmUWyF-Q5hn0dE2SvXbApghAsQwhWRu4D0,431
+lalamo/model_import/model_specs/lfm2.py,sha256=UlCQkKBWu7YMlc3L_c-cMOgXKw7j2wCHIu9ELwkkoCE,498
 lalamo/model_import/model_specs/llama.py,sha256=Ml-xvRGlXBT9NJhmEpwgNo6C84oBSMYgA1_PrCYGcAw,990
 lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
 lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
@@ -52,7 +54,7 @@ lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
 lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
 lalamo/models/common.py,sha256=PDteofGxjSBWYw_mPxbN1DTUba70aOURrAIjl13SSHc,2954
 lalamo/models/language_model.py,sha256=QPeVEyhutSze7fSNhvOvwSoYt24QMk-dtTJkos38amY,13465
-lalamo/modules/__init__.py,sha256=
+lalamo/modules/__init__.py,sha256=dFCicpcx-XV9sVTMR7x4TVF2tAGpzFi_sCTPAyawoJo,3858
 lalamo/modules/activations.py,sha256=U3qTQtZawPAUcoqbkIJnmTYcaNiQuSPMLcBeJ398GhI,1022
 lalamo/modules/classifier.py,sha256=_jtJ3INEq1dJP5HpUmcDk9YYzpRYlQ04zvFGaWBV6Lg,12101
 lalamo/modules/common.py,sha256=dqDEOi-C3H4U9iWUisU32RA-wRDCGuaUNGbObRBhyQM,3315
@@ -64,26 +66,28 @@ lalamo/modules/mlx_interop.py,sha256=FdfU_1iES-HQ9r4K0SkYwJTyvE0f-_T5ursNCjPLZKY
 lalamo/modules/normalization.py,sha256=cBdOq6OmJssunVeEwFRJD0BDhgFAN7J8gOKwzIUAY8I,3005
 lalamo/modules/rope.py,sha256=rCik7vBNqRXYg3LGbmc1mezPRNbIYMg5cydTFpQy-eU,10157
 lalamo/modules/torch_interop.py,sha256=-mujd1zI4ec2w92Hd50RtDa0K3jl6ZSnPxc5r3Fp9nU,916
-lalamo/modules/transformer.py,sha256=
-lalamo/modules/transformer_layer.py,sha256=
+lalamo/modules/transformer.py,sha256=4olEO8Eh7U6RwSnaECn39ooPuTKUZp_6QmvO6vdirrQ,10532
+lalamo/modules/transformer_layer.py,sha256=ZYmGR2Ej328l7K-YpV4eEiBk8SzLsw1RiuSiUP94UpY,12731
 lalamo/modules/utils.py,sha256=t_TayWT6g5LtYKhJaod-u_COWaI_VbNd3eYek9Nj0lc,441
-lalamo/modules/token_mixers/__init__.py,sha256=
+lalamo/modules/token_mixers/__init__.py,sha256=z6x8cNjis6xIi_2llIoByKqMF2W4xJ05rDnxitHQ3jU,1139
 lalamo/modules/token_mixers/attention.py,sha256=gkGMFah2OHB_tyJpkshM1KhMnzG6U7Xt273MkBvDk58,16584
-lalamo/modules/token_mixers/common.py,sha256
-lalamo/modules/token_mixers/mamba.py,sha256=
-lalamo/modules/token_mixers/
+lalamo/modules/token_mixers/common.py,sha256=CcrbXXvGU27uxGLh5L-G8VDtcOiW5Wpm13uBEOd6lVg,1986
+lalamo/modules/token_mixers/mamba.py,sha256=fo8xvvmIQss2lKLhav19Jzk1-hTykNp2sjcN6ntcWj4,18789
+lalamo/modules/token_mixers/short_conv.py,sha256=93SmoVsuAtdX4ckAkvhHXHiO67pU6soYFpBZxdPFEwc,5219
+lalamo/modules/token_mixers/state/__init__.py,sha256=OKWPmiwszMWgwamewoVHd28owanHAO2j2e30Iivtv-4,384
 lalamo/modules/token_mixers/state/common.py,sha256=dcwBevAdeJpBjf7_YRk7TKrJHsCnpljhfzZy-3h9898,661
 lalamo/modules/token_mixers/state/kv_cache.py,sha256=QfnS3XgSmyDI9MBUbeLI4ABHLxiMcXDbZsqe0fd3KQo,8788
 lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxbnMJCl--JOoe9Fkcc9Mg,1642
+lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
 lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
 lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
 lalamo/speculator/estimator.py,sha256=4D8dPZCWsrpORb7y8pQ6VsiIg1Cblvvxe6gXCoYtcD4,2530
 lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
 lalamo/speculator/ngram.py,sha256=95mdfAWhx4d5XOnOwhyhElnvcy6nlUjYhcbJzqDs414,5875
 lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
+lalamo-0.5.10.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.5.10.dist-info/METADATA,sha256=7KSYbe35d3aafssFta83t2MzVShN0JJsVd5nPfjb2VA,3147
+lalamo-0.5.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lalamo-0.5.10.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.5.10.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.5.10.dist-info/RECORD,,

{lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/WHEEL
File without changes

{lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/entry_points.txt
File without changes

{lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/licenses/LICENSE
File without changes

{lalamo-0.5.9.dist-info → lalamo-0.5.10.dist-info}/top_level.txt
File without changes