lalamo 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/language_model.py +22 -23
- lalamo/main.py +4 -18
- lalamo/model_import/common.py +24 -6
- lalamo/model_import/decoder_configs/__init__.py +2 -0
- lalamo/model_import/decoder_configs/common.py +4 -4
- lalamo/model_import/decoder_configs/executorch.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
- lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
- lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
- lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
- lalamo/model_import/loaders/executorch.py +5 -4
- lalamo/model_import/loaders/huggingface.py +321 -69
- lalamo/model_import/model_specs/__init__.py +2 -0
- lalamo/model_import/model_specs/common.py +16 -5
- lalamo/model_import/model_specs/llamba.py +40 -0
- lalamo/model_import/model_specs/qwen.py +29 -1
- lalamo/modules/__init__.py +33 -6
- lalamo/modules/activations.py +9 -2
- lalamo/modules/common.py +10 -5
- lalamo/modules/decoder.py +93 -97
- lalamo/modules/decoder_layer.py +85 -103
- lalamo/modules/embedding.py +279 -5
- lalamo/modules/linear.py +335 -30
- lalamo/modules/mlp.py +6 -7
- lalamo/modules/mlx_interop.py +19 -0
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +30 -0
- lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
- lalamo/modules/token_mixers/common.py +78 -0
- lalamo/modules/token_mixers/mamba.py +553 -0
- lalamo/modules/token_mixers/state/__init__.py +12 -0
- lalamo/modules/token_mixers/state/common.py +26 -0
- lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
- lalamo/modules/token_mixers/state/mamba_state.py +51 -0
- lalamo/utils.py +24 -2
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
- lalamo-0.5.0.dist-info/RECORD +80 -0
- lalamo-0.4.0.dist-info/RECORD +0 -71
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/model_specs/llamba.py
ADDED

@@ -0,0 +1,40 @@
+from lalamo.model_import.decoder_configs import HFLlambaConfig
+from lalamo.quantization import QuantizationMode
+
+from .common import ConfigMap, FileSpec, ModelSpec
+
+__all__ = ["LLAMBA_MODELS"]
+
+LLAMBA_MODELS = [
+    ModelSpec(
+        vendor="Cartesia",
+        family="Llamba",
+        name="Llamba-1B",
+        size="1B",
+        quantization=None,
+        repo="cartesia-ai/Llamba-1B",
+        config_type=HFLlambaConfig,
+        configs=ConfigMap(
+            tokenizer=FileSpec("tokenizer.json", "meta-llama/Llama-3.2-1B-Instruct"),
+            tokenizer_config=FileSpec("tokenizer_config.json", "meta-llama/Llama-3.2-1B-Instruct"),
+            generation_config=FileSpec("generation_config.json", "meta-llama/Llama-3.2-1B-Instruct"),
+        ),
+        use_cases=tuple(),
+    ),
+    ModelSpec(
+        vendor="Cartesia",
+        family="Llamba",
+        name="Llamba-1B-4bit-mlx",
+        size="1B",
+        quantization=QuantizationMode.UINT4,
+        repo="cartesia-ai/Llamba-1B-4bit-mlx",
+        config_type=HFLlambaConfig,
+        configs=ConfigMap(
+            model_config=FileSpec("config.json", "cartesia-ai/Llamba-1B"),
+            tokenizer=FileSpec("tokenizer.json", "meta-llama/Llama-3.2-1B-Instruct"),
+            tokenizer_config=FileSpec("tokenizer_config.json", "meta-llama/Llama-3.2-1B-Instruct"),
+            generation_config=FileSpec("generation_config.json", "meta-llama/Llama-3.2-1B-Instruct"),
+        ),
+        use_cases=tuple(),
+    ),
+]
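The new spec file registers two Cartesia Llamba-1B variants and borrows tokenizer and generation-config files from `meta-llama/Llama-3.2-1B-Instruct` via `FileSpec` entries. A minimal lookup sketch, assuming `ModelSpec` and `ConfigMap` expose the keyword arguments above as plain attributes (not confirmed by this diff):

```python
# Sketch only: assumes ModelSpec/ConfigMap behave like the dataclass-style
# records that the keyword construction above suggests.
from lalamo.model_import.model_specs.llamba import LLAMBA_MODELS
from lalamo.quantization import QuantizationMode

# Pick the 4-bit MLX variant of Llamba-1B by its quantization mode.
quantized = next(
    spec for spec in LLAMBA_MODELS if spec.quantization is QuantizationMode.UINT4
)
print(quantized.name)  # "Llamba-1B-4bit-mlx"
print(quantized.repo)  # "cartesia-ai/Llamba-1B-4bit-mlx"

# Tokenizer files are borrowed from the Llama 3.2 Instruct repo via FileSpec.
print(quantized.configs.tokenizer)
```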
lalamo/model_import/model_specs/qwen.py
CHANGED

@@ -1,7 +1,7 @@
 from lalamo.model_import.decoder_configs import HFQwen2Config, HFQwen3Config
 from lalamo.quantization import QuantizationMode
 
-from .common import ModelSpec, UseCase, WeightsType
+from .common import ConfigMap, FileSpec, ModelSpec, UseCase, WeightsType
 
 __all__ = ["QWEN_MODELS"]
 
@@ -148,6 +148,20 @@ QWEN3 = [
         repo="Qwen/Qwen3-0.6B",
         config_type=HFQwen3Config,
     ),
+    ModelSpec(
+        vendor="Alibaba",
+        family="Qwen3",
+        name="Qwen3-0.6B-MLX-4bit",
+        size="0.6B",
+        quantization=QuantizationMode.UINT4,
+        repo="Qwen/Qwen3-0.6B-MLX-4bit",
+        config_type=HFQwen3Config,
+        configs=ConfigMap(
+            tokenizer=FileSpec("tokenizer.json", "Qwen/Qwen3-0.6B"),
+            tokenizer_config=FileSpec("tokenizer_config.json", "Qwen/Qwen3-0.6B"),
+            generation_config=FileSpec("generation_config.json", "Qwen/Qwen3-0.6B"),
+        ),
+    ),
     ModelSpec(
         vendor="Alibaba",
         family="Qwen3",
@@ -177,6 +191,20 @@ QWEN3 = [
         repo="Qwen/Qwen3-4B-AWQ",
         config_type=HFQwen3Config,
     ),
+    ModelSpec(
+        vendor="Alibaba",
+        family="Qwen3",
+        name="Qwen3-4B-MLX-4bit",
+        size="4B",
+        quantization=QuantizationMode.UINT4,
+        repo="Qwen/Qwen3-4B-MLX-4bit",
+        config_type=HFQwen3Config,
+        configs=ConfigMap(
+            tokenizer=FileSpec("tokenizer.json", "Qwen/Qwen3-4B"),
+            tokenizer_config=FileSpec("tokenizer_config.json", "Qwen/Qwen3-4B"),
+            generation_config=FileSpec("generation_config.json", "Qwen/Qwen3-4B"),
+        ),
+    ),
     ModelSpec(
         vendor="Alibaba",
         family="Qwen3",
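The new Qwen3 MLX-4bit entries follow the same pattern: the quantized repos carry the weights, while tokenizer and generation configs are redirected to the unquantized base repo through `ConfigMap`/`FileSpec`. Purely as an illustration of what such an override points at (this is not lalamo's loader, and `fetch` is a hypothetical helper), one could pull the file with `huggingface_hub`:

```python
# Illustration only: one way a FileSpec("tokenizer.json", "Qwen/Qwen3-0.6B")
# override could be resolved. How lalamo itself downloads files is not shown
# in this diff.
from huggingface_hub import hf_hub_download


def fetch(filename: str, repo_id: str) -> str:
    """Download a config file from the repo named in the FileSpec override."""
    return hf_hub_download(repo_id=repo_id, filename=filename)


# The Qwen3-0.6B-MLX-4bit weights live in one repo, but its tokenizer is
# declared to come from the unquantized base repo.
tokenizer_path = fetch("tokenizer.json", "Qwen/Qwen3-0.6B")
```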
lalamo/modules/__init__.py
CHANGED

@@ -1,6 +1,5 @@
-from .activations import GELU, Activation, SiLU
-from .
-from .common import AttentionType, ForwardPassMode, LalamoModule, config_converter
+from .activations import GELU, Activation, Identity, SiLU
+from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector, config_converter
 from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderForwardPassConfig, DecoderResult
 from .decoder_layer import (
     DecoderLayer,
@@ -12,6 +11,10 @@ from .decoder_layer import (
 from .embedding import (
     EmbeddingBase,
     EmbeddingConfig,
+    MLXQuantizedTiedEmbedding,
+    MLXQuantizedTiedEmbeddingConfig,
+    MLXSemiQuantizedUntiedEmbedding,
+    MLXSemiQuantizedUntiedEmbeddingConfig,
     QuantizedTiedEmbedding,
     QuantizedTiedEmbeddingConfig,
     TiedEmbedding,
@@ -19,7 +22,6 @@ from .embedding import (
     UntiedEmbedding,
     UntiedEmbeddingConfig,
 )
-from .kv_cache import DynamicKVCacheLayer, KVCache, KVCacheLayer, StaticKVCacheLayer
 from .linear import (
     FullPrecisionLinear,
     FullPrecisionLinearConfig,
@@ -27,6 +29,8 @@ from .linear import (
     GroupQuantizedLinearConfig,
     LinearBase,
     LinearConfig,
+    MLXQuantizedLinear,
+    MLXQuantizedLinearConfig,
     QLoRALinear,
     QLoRALinearConfig,
 )
@@ -51,13 +55,24 @@ from .rope import (
     UnscaledRoPEConfig,
     YARNRoPEConfig,
 )
+from .token_mixers import (
+    Attention,
+    AttentionConfig,
+    DynamicKVCacheLayer,
+    KVCacheLayer,
+    Mamba2,
+    Mamba2Config,
+    SeparableCausalConv,
+    SeparableCausalConvConfig,
+    State,
+    StaticKVCacheLayer,
+)
 
 __all__ = [
     "GELU",
     "Activation",
     "Attention",
     "AttentionConfig",
-    "AttentionType",
     "Decoder",
     "DecoderActivationTrace",
     "DecoderConfig",
@@ -78,7 +93,7 @@ __all__ = [
     "FullPrecisionLinearConfig",
     "GroupQuantizedLinear",
     "GroupQuantizedLinearConfig",
-    "
+    "Identity",
     "KVCacheLayer",
     "LalamoModule",
     "LinearBase",
@@ -88,8 +103,17 @@ __all__ = [
     "MLPBase",
     "MLPConfig",
     "MLPForwardPassConfig",
+    "MLXQuantizedLinear",
+    "MLXQuantizedLinearConfig",
+    "MLXQuantizedTiedEmbedding",
+    "MLXQuantizedTiedEmbeddingConfig",
+    "MLXSemiQuantizedUntiedEmbedding",
+    "MLXSemiQuantizedUntiedEmbeddingConfig",
+    "Mamba2",
+    "Mamba2Config",
     "MixtureOfExperts",
     "MixtureOfExpertsConfig",
+    "PositionalEmbeddingSelector",
     "PositionalEmbeddings",
     "QLoRALinear",
     "QLoRALinearConfig",
@@ -100,8 +124,11 @@ __all__ = [
     "RoPE",
     "RoPEConfig",
     "RoutingFunction",
+    "SeparableCausalConv",
+    "SeparableCausalConvConfig",
     "SiLU",
     "SoftmaxRouting",
+    "State",
     "StaticKVCacheLayer",
     "TiedEmbedding",
     "TiedEmbeddingConfig",
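The public surface of `lalamo.modules` is reorganized in 0.5.0: attention and the new Mamba2 mixer, their cache/state types, and the MLX-quantized linear and embedding variants are re-exported from the package root, while `AttentionType` and the `kv_cache` module exports are gone. A quick sketch of imports that the new `__all__` above implies (illustrative only):

```python
# Names that downstream code can now import from the package root; code that
# previously used lalamo.modules.kv_cache or AttentionType needs to migrate.
from lalamo.modules import (
    Attention,
    Identity,
    KVCacheLayer,
    Mamba2,
    MLXQuantizedLinear,
    PositionalEmbeddingSelector,
    State,
)
```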
lalamo/modules/activations.py
CHANGED

@@ -10,6 +10,7 @@ from lalamo.modules.common import register_config_union
 __all__ = [
     "GELU",
     "Activation",
+    "Identity",
     "SiLU",
 ]
 
@@ -34,7 +35,13 @@ class GELU(ActivationBase):
         return jax.nn.gelu(x)
 
 
-
+@dataclass(frozen=True)
+class Identity(ActivationBase):
+    def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
+        return x
+
+
+Activation = SiLU | GELU | Identity
 
 
-register_config_union(Activation)
+register_config_union(Activation)  # type: ignore (pyright bug)
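A tiny sketch of the new `Identity` activation, assuming it is constructed with no arguments as the frozen dataclass above suggests:

```python
# Minimal sketch: Identity is a no-op member of the Activation union
# (SiLU | GELU | Identity), for configs that require an activation slot
# without actually transforming the values.
import jax.numpy as jnp

from lalamo.modules import Identity

x = jnp.array([-1.0, 0.0, 2.0])
assert (Identity()(x) == x).all()
```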
lalamo/modules/common.py
CHANGED

@@ -2,7 +2,7 @@ from abc import abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from types import UnionType
-from typing import Self
+from typing import Any, Self
 
 import equinox as eqx
 from cattrs import Converter
@@ -12,18 +12,19 @@ from jaxtyping import Array, DTypeLike
 from lalamo.common import ParameterTree
 
 __all__ = [
-    "AttentionType",
     "DummyUnionMember",
     "ForwardPassMode",
    "LalamoModule",
+    "PositionalEmbeddingSelector",
     "config_converter",
     "register_config_union",
 ]
 
 
-class
+class PositionalEmbeddingSelector(Enum):
     GLOBAL = "global"
-
+    LOCAL = "sliding_window"
+    NONE = "none"
 
 
 class ForwardPassMode(Enum):
@@ -128,4 +129,8 @@ def register_config_union(union_type: UnionType) -> None:
 
 @dataclass
 class DummyUnionMember:
-
+    def __getattribute__(self, name: str, /) -> Any:  # noqa: ANN401
+        raise NotImplementedError
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
+        raise NotImplementedError
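`PositionalEmbeddingSelector` replaces the removed `AttentionType` enum, and `DummyUnionMember` now raises instead of silently doing nothing. A small sketch of the enum's value strings as shown above (illustrative; how model configs map onto it is not part of this diff):

```python
# The three selector values and their string forms, per the definition above.
from lalamo.modules import PositionalEmbeddingSelector

assert PositionalEmbeddingSelector("sliding_window") is PositionalEmbeddingSelector.LOCAL
assert PositionalEmbeddingSelector("none") is PositionalEmbeddingSelector.NONE
assert PositionalEmbeddingSelector.GLOBAL.value == "global"
```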
lalamo/modules/decoder.py
CHANGED

@@ -8,14 +8,14 @@ from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree
-from lalamo.modules.utils import vmap_twice
 
-from .common import
+from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
 from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerForwardPassConfig, DecoderLayerResult
 from .embedding import EmbeddingBase, EmbeddingConfig
-from .kv_cache import KVCache
 from .normalization import RMSNorm, RMSNormConfig
 from .rope import PositionalEmbeddings, RoPE, RoPEConfig
+from .token_mixers import AttentionConfig, State
+from .utils import vmap_twice
 
 __all__ = [
     "Decoder",
@@ -32,42 +32,42 @@ type DecoderForwardPassConfig = DecoderLayerForwardPassConfig
 class DecoderActivationTrace(eqx.Module):
     token_ids: Int[Array, "batch suffix_tokens"]
     token_positions: Int[Array, "batch suffix_tokens"]
-
+    state: State | None
 
-    local_positional_embeddings: PositionalEmbeddings
-    global_positional_embeddings: PositionalEmbeddings
+    local_positional_embeddings: PositionalEmbeddings | None
+    global_positional_embeddings: PositionalEmbeddings | None
 
     layer_results: tuple[DecoderLayerResult, ...]
 
     output_norm: Float[Array, "batch suffix_tokens channels"]
 
     def export(self) -> ParameterTree:
-        result = dict(
+        result: dict[str, ParameterTree | Array] = dict(
             token_ids=self.token_ids,
             token_positions=self.token_positions,
-            local_positional_embeddings=self.local_positional_embeddings.export(),
-            global_positional_embeddings=self.global_positional_embeddings.export(),
             layer_results=[layer_result.export() for layer_result in self.layer_results],
             output_norm=self.output_norm,
         )
-        if self.
-            result["
+        if self.local_positional_embeddings is not None:
+            result["local_positional_embeddings"] = self.local_positional_embeddings.export()
+        if self.global_positional_embeddings is not None:
+            result["global_positional_embeddings"] = self.global_positional_embeddings.export()
+        if self.state is not None:
+            result["state"] = [state_layer.export() for state_layer in self.state]
         return result
 
 
 class DecoderResult(eqx.Module):
     logits: Float[Array, "batch suffix_tokens channels"]
-
+    updated_state: State | None = None
     activation_trace: DecoderActivationTrace | None = None
 
     def export(self) -> ParameterTree:
         result: dict[str, ParameterTree | Array] = dict(
             logits=self.logits,
         )
-        if self.
-            result["
-                kv_cache_layer_slice.export() for kv_cache_layer_slice in self.updated_kv_cache
-            ]
+        if self.updated_state is not None:
+            result["updated_state"] = [state_layer.export() for state_layer in self.updated_state]
         if self.activation_trace is not None:
             result["activation_trace"] = self.activation_trace.export()
         return result
@@ -76,33 +76,16 @@
 @dataclass(frozen=True)
 class DecoderConfig:
     embedding_config: EmbeddingConfig
-    global_rope_config: RoPEConfig
+    global_rope_config: RoPEConfig | None
     local_rope_config: RoPEConfig | None
-
+    layer_configs: tuple[DecoderLayerConfig, ...]
     output_norm_config: RMSNormConfig
 
     vocab_size: int
     model_dim: int
     hidden_dim: int
-    num_heads: int
-    num_groups: int
-    head_dim: int
-    attention_scale: float | None
-    num_layers: int
-    sliding_window_sizes: tuple[int | None, ...] | None
     context_length: int
 
-    def __post_init__(self) -> None:
-        if self.local_rope_config is not None and self.sliding_window_sizes is None:
-            raise ValueError("Sliding window sizes must be provided when using local RoPE")
-        if self.sliding_window_sizes is None:
-            return
-        if len(self.sliding_window_sizes) != self.num_layers:
-            raise ValueError(
-                f"Number of sliding window sizes {len(self.sliding_window_sizes)} does not match"
-                f" the number of layers {self.num_layers}",
-            )
-
     def random_init(
         self,
         *,
@@ -114,40 +97,38 @@
             model_dim=self.model_dim,
             key=embedding_key,
         )
-
-
-
-
+
+        first_layer_config, *_ = self.layer_configs
+
+        if self.global_rope_config:
+            global_rope = self.global_rope_config.init(
+                head_dim=first_layer_config.rope_dim,
+                num_timesteps=self.context_length,
+            )
+        else:
+            global_rope = None
 
         if self.local_rope_config:
-            assert self.sliding_window_sizes is not None
             max_sliding_window_size = max(
-
+                layer_config.mixer_config.sliding_window_size or 0
+                for layer_config in self.layer_configs
+                if isinstance(layer_config.mixer_config, AttentionConfig)
             )
             local_rope = self.local_rope_config.init(
-                head_dim=
+                head_dim=first_layer_config.rope_dim,
                 num_timesteps=max(max_sliding_window_size, self.context_length),
            )
         else:
             local_rope = None
 
-
-            sliding_window_sizes = [None] * self.num_layers
-        else:
-            sliding_window_sizes = self.sliding_window_sizes
-        layers_keys = jax.random.split(layers_key, self.num_layers)
+        layers_keys = jax.random.split(layers_key, len(self.layer_configs))
         layers = tuple(
-
+            layer_config.random_init(
                 model_dim=self.model_dim,
                 hidden_dim=self.hidden_dim,
-                num_heads=self.num_heads,
-                num_groups=self.num_groups,
-                head_dim=self.head_dim,
-                attention_scale=self.attention_scale,
-                sliding_window_size=sliding_window_size,
                 key=key,
             )
-            for
+            for layer_config, key in zip(self.layer_configs, layers_keys, strict=False)
         )
         output_norm = self.output_norm_config.init(self.model_dim)
         return Decoder(
@@ -166,34 +147,35 @@
             vocab_size=self.vocab_size,
             model_dim=self.model_dim,
         )
-        global_rope = self.global_rope_config.init(
-            head_dim=self.head_dim,
-            num_timesteps=self.context_length,
-        )
 
-
-
-
+        first_layer_config, *_ = self.layer_configs
+
+        if self.global_rope_config:
+            global_rope = self.global_rope_config.init(
+                head_dim=first_layer_config.rope_dim,
                 num_timesteps=self.context_length,
             )
         else:
-
+            global_rope = None
 
-        if self.
-
+        if self.local_rope_config:
+            max_sliding_window_size = max(
+                layer_config.mixer_config.sliding_window_size or 0
+                for layer_config in self.layer_configs
+                if isinstance(layer_config.mixer_config, AttentionConfig)
+            )
+            local_rope = self.local_rope_config.init(
+                head_dim=first_layer_config.rope_dim,
+                num_timesteps=max(max_sliding_window_size, self.context_length),
+            )
         else:
-
+            local_rope = None
         layers = tuple(
-
+            layer_config.empty(
                 model_dim=self.model_dim,
                 hidden_dim=self.hidden_dim,
-                num_heads=self.num_heads,
-                num_groups=self.num_groups,
-                head_dim=self.head_dim,
-                attention_scale=self.attention_scale,
-                sliding_window_size=sliding_window_size,
             )
-            for
+            for layer_config in self.layer_configs
         )
         output_norm = self.output_norm_config.empty(self.model_dim)
         return Decoder(
@@ -208,7 +190,7 @@
 
 class Decoder(LalamoModule[DecoderConfig]):
     embedding: EmbeddingBase
-    global_rope: RoPE
+    global_rope: RoPE | None
     local_rope: RoPE | None
     layers: tuple[DecoderLayer, ...]
     output_norm: RMSNorm
@@ -218,12 +200,12 @@ class Decoder(LalamoModule[DecoderConfig]):
         return self.embedding.activation_precision
 
     @eqx.filter_jit
-    def __call__(
+    def __call__(  # noqa: PLR0912
         self,
         token_ids: Int[Array, "batch suffix_tokens"],
         token_positions: Int[Array, "batch suffix_tokens"],
-
-
+        state: State | None = None,
+        return_updated_state: bool = False,
         return_activation_trace: bool = False,
         lengths_without_padding: Int[Array, " batch"] | None = None,
         forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
@@ -239,28 +221,35 @@
                 f" got {token_positions.shape}",
             )
 
-
+        maybe_state = state or ([None] * len(self.layers))
         inner_features = vmap(self.embedding.embed)(token_ids)
 
-
+        if self.global_rope is not None:
+            global_positional_embeddings = vmap(self.global_rope)(token_positions)
+        else:
+            global_positional_embeddings = None
+
         if self.local_rope is not None:
             local_positional_embeddings = vmap(self.local_rope)(token_positions)
         else:
             local_positional_embeddings = global_positional_embeddings
 
-
+        updated_state_layers = []
         layer_results = []
-        for layer,
-
-
-
-
+        for layer, state_layer in zip(self.layers, maybe_state, strict=True):
+            match layer.positional_embedding_selector:
+                case PositionalEmbeddingSelector.LOCAL:
+                    positional_embeddings_to_use = local_positional_embeddings
+                case PositionalEmbeddingSelector.GLOBAL:
+                    positional_embeddings_to_use = global_positional_embeddings
+                case PositionalEmbeddingSelector.NONE:
+                    positional_embeddings_to_use = None
 
             layer_result = layer(
                 inner_features,
                 positional_embeddings_to_use,
-
-
+                state=state_layer,
+                return_updated_state=return_updated_state,
                 return_activation_trace=return_activation_trace,
                 lengths_without_padding=lengths_without_padding,
                 forward_pass_mode=forward_pass_mode,
@@ -268,7 +257,7 @@ class Decoder(LalamoModule[DecoderConfig]):
             )
             inner_features = layer_result.outputs
             layer_results.append(layer_result)
-
+            updated_state_layers.append(layer_result.updated_state)
 
         normalized_outputs = vmap_twice(self.output_norm)(inner_features)
         logits = vmap_twice(self.embedding.readout)(normalized_outputs)
@@ -277,7 +266,7 @@ class Decoder(LalamoModule[DecoderConfig]):
             activation_trace = DecoderActivationTrace(
                 token_ids=token_ids,
                 token_positions=token_positions,
-
+                state=state,
                 global_positional_embeddings=global_positional_embeddings,
                 local_positional_embeddings=local_positional_embeddings,
                 layer_results=tuple(layer_results),
@@ -286,27 +275,28 @@
         else:
             activation_trace = None
 
-        if
-
+        if return_updated_state:
+            updated_state = State(updated_state_layers)
         else:
-
+            updated_state = None
 
         return DecoderResult(
             logits=logits,
-
+            updated_state=updated_state,
             activation_trace=activation_trace,
         )
 
-    def
-        return
+    def init_static_state(self, batch_size: int, capacity: int) -> State:
+        return State(layer.init_static_state(batch_size, capacity) for layer in self.layers)
 
     def export_weights(self) -> ParameterTree:
         result = dict(
             embedding=self.embedding.export_weights(),
-            global_rope=self.global_rope.export_weights(),
             layers=[layer.export_weights() for layer in self.layers],
             output_norm=self.output_norm.export_weights(),
         )
+        if self.global_rope:
+            result["global_rope"] = self.global_rope.export_weights()
         if self.local_rope:
             result["local_rope"] = self.local_rope.export_weights()
         return result
@@ -317,15 +307,21 @@
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["global_rope"], Mapping)
         assert isinstance(weights["layers"], Sequence)
         assert isinstance(weights["output_norm"], Mapping)
+
         if self.local_rope:
             assert isinstance(weights["local_rope"], Mapping)
             local_rope = self.local_rope.import_weights(weights["local_rope"])
         else:
             local_rope = None
 
+        if self.global_rope:
+            assert isinstance(weights["global_rope"], Mapping)
+            global_rope = self.global_rope.import_weights(weights["global_rope"])
+        else:
+            global_rope = None
+
         layers = []
         for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
             assert isinstance(layer_weights, Mapping)
@@ -333,7 +329,7 @@ class Decoder(LalamoModule[DecoderConfig]):
         return replace(
             self,
             embedding=self.embedding.import_weights(weights["embedding"]),
-            global_rope=
+            global_rope=global_rope,
             layers=tuple(layers),
             output_norm=self.output_norm.import_weights(weights["output_norm"]),
             local_rope=local_rope,
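Taken together, the decoder now threads a generic per-layer `State` (rather than a KV cache) through the forward pass, tolerates models without a global RoPE, and derives layer shapes from per-layer configs. A usage sketch under assumptions not shown in this diff (an already-constructed `decoder: Decoder` and a static-state capacity of 1024):

```python
# Usage sketch only: `decoder` is assumed to be an existing Decoder instance;
# construction and weight loading are not part of this diff.
import jax.numpy as jnp

token_ids = jnp.zeros((1, 8), dtype=jnp.int32)
token_positions = jnp.arange(8, dtype=jnp.int32)[None, :]

# Pre-allocate per-layer state (presumably a KV cache for attention layers and
# Mamba state for SSM layers) and thread it through successive calls.
state = decoder.init_static_state(batch_size=1, capacity=1024)
result = decoder(
    token_ids,
    token_positions,
    state=state,
    return_updated_state=True,
)
logits = result.logits
next_state = result.updated_state
```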