lalamo 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/language_model.py +22 -23
- lalamo/main.py +2 -16
- lalamo/model_import/common.py +24 -6
- lalamo/model_import/decoder_configs/__init__.py +2 -0
- lalamo/model_import/decoder_configs/common.py +4 -4
- lalamo/model_import/decoder_configs/executorch.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
- lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
- lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
- lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
- lalamo/model_import/loaders/executorch.py +5 -4
- lalamo/model_import/loaders/huggingface.py +321 -69
- lalamo/model_import/model_specs/__init__.py +2 -0
- lalamo/model_import/model_specs/common.py +16 -5
- lalamo/model_import/model_specs/llamba.py +40 -0
- lalamo/model_import/model_specs/qwen.py +29 -1
- lalamo/modules/__init__.py +33 -6
- lalamo/modules/activations.py +9 -2
- lalamo/modules/common.py +10 -5
- lalamo/modules/decoder.py +93 -97
- lalamo/modules/decoder_layer.py +85 -103
- lalamo/modules/embedding.py +279 -5
- lalamo/modules/linear.py +335 -30
- lalamo/modules/mlp.py +6 -7
- lalamo/modules/mlx_interop.py +19 -0
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +30 -0
- lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
- lalamo/modules/token_mixers/common.py +78 -0
- lalamo/modules/token_mixers/mamba.py +553 -0
- lalamo/modules/token_mixers/state/__init__.py +12 -0
- lalamo/modules/token_mixers/state/common.py +26 -0
- lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
- lalamo/modules/token_mixers/state/mamba_state.py +51 -0
- lalamo/utils.py +24 -2
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
- lalamo-0.5.0.dist-info/RECORD +80 -0
- lalamo-0.4.1.dist-info/RECORD +0 -71
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
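The 0.5.0 release moves attention under a new `token_mixers` package alongside a Mamba mixer and shared state layers, and the diff below shows the core of that change: per-layer geometry (`num_heads`, `num_groups`, `head_dim`, causality, scale, sliding window) now lives on `AttentionConfig` itself, so `random_init()` and `empty()` only take the model dimension. The following is a minimal illustrative sketch of that calling convention using hypothetical toy classes, not lalamo's actual `AttentionConfig` or `LinearConfig`:

```python
# Illustrative sketch only (hypothetical names, not lalamo's real classes):
# the 0.5.0 pattern of carrying layer geometry on a frozen config so that
# random_init()/empty() only need the model dimension and a PRNG key.
from dataclasses import dataclass

import jax
from jaxtyping import PRNGKeyArray


@dataclass(frozen=True)
class ToyAttentionConfig:
    num_heads: int
    num_groups: int
    head_dim: int

    def random_init(self, model_dim: int, *, key: PRNGKeyArray) -> jax.Array:
        # In lalamo the result is an Attention module; here we just build
        # a fused QKV weight with the shapes implied by the config.
        qkv_dim = (self.num_heads + 2 * self.num_groups) * self.head_dim
        return jax.random.normal(key, (qkv_dim, model_dim))


config = ToyAttentionConfig(num_heads=8, num_groups=2, head_dim=64)
qkv_weight = config.random_init(1024, key=jax.random.PRNGKey(0))
print(qkv_weight.shape)  # (768, 1024): (8 + 2 * 2) * 64 = 768 fused QKV rows
```

Keeping geometry on the frozen config means a mixer is constructible from its config plus a model dimension alone, which is exactly what the new `TokenMixerConfigBase` interface (shown further below) requires.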
Diff of lalamo/modules/{attention.py → token_mixers/attention.py} (lines lost in extraction are kept truncated):

```diff
--- a/lalamo/modules/attention.py
+++ b/lalamo/modules/token_mixers/attention.py
@@ -1,6 +1,6 @@
 from collections.abc import Mapping
 from dataclasses import dataclass, replace
-from typing import
+from typing import Self

 import equinox as eqx
 import jax
@@ -10,17 +10,19 @@ from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray

 from lalamo.common import dummy_array
+from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector
+from lalamo.modules.linear import LinearBase, LinearConfig
 from lalamo.modules.normalization import RMSNorm, RMSNormConfig
+from lalamo.modules.rope import PositionalEmbeddings
+from lalamo.modules.utils import apply_soft_capping

-from .common import
-from .
-from .linear import LinearBase, LinearConfig
-from .rope import PositionalEmbeddings
-from .utils import apply_soft_capping
+from .common import TokenMixerBase, TokenMixerConfigBase, TokenMixerResult
+from .state import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer

 __all__ = [
     "Attention",
     "AttentionConfig",
+    "AttentionResult",
 ]


@@ -72,33 +74,36 @@ def _soft_capped_attention_kernel(
     )


-
-    outputs: Float[Array, "*batch suffix_tokens channels"]
-    kv_cache: KVCacheLayer | None = None
+AttentionResult = TokenMixerResult[KVCacheLayer]


 @dataclass(frozen=True)
-class AttentionConfig:
+class AttentionConfig(TokenMixerConfigBase):
     qkv_projection_config: LinearConfig
     out_projection_config: LinearConfig

     query_norm_config: RMSNormConfig | None
     key_norm_config: RMSNormConfig | None

+    num_heads: int
+    num_groups: int
+    head_dim: int
+    is_causal: bool
+    scale: float | None
+    sliding_window_size: int | None
+
     logit_soft_cap: float | None
     has_sinks: bool
     has_qkv_biases: bool
     has_out_biases: bool

+    @property
+    def rope_dim(self) -> int:
+        return self.head_dim
+
     def random_init(
         self,
         model_dim: int,
-        num_heads: int,
-        num_groups: int,
-        head_dim: int,
-        is_causal: bool,
-        scale: float | None,
-        sliding_window_size: int | None,
         *,
         key: PRNGKeyArray,
     ) -> "Attention":
@@ -106,15 +111,15 @@ class AttentionConfig:
         qkv_projection = self.qkv_projection_config.random_init(
             input_dim=model_dim,
             output_dims=(
-                num_heads * head_dim,
-                num_groups * head_dim,
-                num_groups * head_dim,
+                self.num_heads * self.head_dim,
+                self.num_groups * self.head_dim,
+                self.num_groups * self.head_dim,
             ),
             has_biases=self.has_qkv_biases,
             key=qkv_key,
         )
         out_projection = self.out_projection_config.random_init(
-            num_heads * head_dim,
+            self.num_heads * self.head_dim,
             (model_dim,),
             has_biases=self.has_out_biases,
             key=out_key,
@@ -122,20 +127,20 @@ class AttentionConfig:

         if self.query_norm_config is not None:
             query_norm = self.query_norm_config.init(
-                input_dim=head_dim,
+                input_dim=self.head_dim,
             )
         else:
             query_norm = None

         if self.key_norm_config is not None:
             key_norm = self.key_norm_config.init(
-                input_dim=head_dim,
+                input_dim=self.head_dim,
             )
         else:
             key_norm = None

         if self.has_sinks:
-            sinks = jnp.zeros((num_heads,), dtype=qkv_projection.activation_precision)
+            sinks = jnp.zeros((self.num_heads,), dtype=qkv_projection.activation_precision)
         else:
             sinks = None

@@ -146,55 +151,49 @@ class AttentionConfig:
             query_norm=query_norm,
             key_norm=key_norm,
             sinks=sinks,
-            num_heads=num_heads,
-            num_groups=num_groups,
-            head_dim=head_dim,
-            is_causal=is_causal,
-            scale=scale,
-            sliding_window_size=sliding_window_size,
+            num_heads=self.num_heads,
+            num_groups=self.num_groups,
+            head_dim=self.head_dim,
+            is_causal=self.is_causal,
+            scale=self.scale,
+            sliding_window_size=self.sliding_window_size,
         )

     def empty(
         self,
         model_dim: int,
-        num_heads: int,
-        num_groups: int,
-        head_dim: int,
-        is_causal: bool,
-        scale: float | None,
-        sliding_window_size: int | None,
     ) -> "Attention":
         qkv_projection = self.qkv_projection_config.empty(
             input_dim=model_dim,
             output_dims=(
-                num_heads * head_dim,
-                num_groups * head_dim,
-                num_groups * head_dim,
+                self.num_heads * self.head_dim,
+                self.num_groups * self.head_dim,
+                self.num_groups * self.head_dim,
             ),
             has_biases=self.has_qkv_biases,
         )
         out_projection = self.out_projection_config.empty(
-            num_heads * head_dim,
+            self.num_heads * self.head_dim,
             (model_dim,),
             has_biases=self.has_out_biases,
         )

         if self.query_norm_config is not None:
             query_norm = self.query_norm_config.empty(
-                input_dim=head_dim,
+                input_dim=self.head_dim,
             )
         else:
             query_norm = None

         if self.key_norm_config is not None:
             key_norm = self.key_norm_config.empty(
-                input_dim=head_dim,
+                input_dim=self.head_dim,
             )
         else:
             key_norm = None

         if self.has_sinks:
-            sinks = dummy_array(num_heads, qkv_projection.activation_precision)
+            sinks = dummy_array(self.num_heads, qkv_projection.activation_precision)
         else:
             sinks = None

@@ -205,16 +204,16 @@ class AttentionConfig:
             query_norm=query_norm,
             key_norm=key_norm,
             sinks=sinks,
-            num_heads=num_heads,
-            num_groups=num_groups,
-            head_dim=head_dim,
-            is_causal=is_causal,
-            scale=scale,
-            sliding_window_size=sliding_window_size,
+            num_heads=self.num_heads,
+            num_groups=self.num_groups,
+            head_dim=self.head_dim,
+            is_causal=self.is_causal,
+            scale=self.scale,
+            sliding_window_size=self.sliding_window_size,
         )


-class Attention(
+class Attention(TokenMixerBase[AttentionConfig, KVCacheLayer]):
     qkv_projection: LinearBase
     out_projection: LinearBase

@@ -249,8 +248,10 @@ class Attention(LalamoModule[AttentionConfig]):
         return self.sliding_window_size is not None

     @property
-    def
-
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector:
+        if self.use_sliding_window:
+            return PositionalEmbeddingSelector.LOCAL
+        return PositionalEmbeddingSelector.GLOBAL

     @property
     def has_sinks(self) -> bool:
@@ -318,9 +319,9 @@ class Attention(LalamoModule[AttentionConfig]):
     def __call__(
         self,
         inputs: Float[Array, "suffix_tokens channels"],
-        positional_embeddings: PositionalEmbeddings,
-
-
+        positional_embeddings: PositionalEmbeddings | None,
+        state: KVCacheLayer | None = None,
+        return_updated_state: bool = False,
         length_without_padding: Int[Array, ""] | int | None = None,
     ) -> AttentionResult:
         queries, keys, values = vmap(self.qkv_projection, in_axes=0)(inputs)
@@ -348,17 +349,18 @@ class Attention(LalamoModule[AttentionConfig]):
         if self.key_norm is not None:
             keys = vmap(vmap(self.key_norm))(keys)

-
-
-
+        if positional_embeddings is not None:
+            apply_positional_embeddings = vmap(positional_embeddings.apply, in_axes=1, out_axes=1)
+            queries = apply_positional_embeddings(queries)
+            keys = apply_positional_embeddings(keys)

-        if
-
+        if state is None:
+            updated_state = DynamicKVCacheLayer.init(self.has_sinks, keys, values, length=length_without_padding)
         else:
-
+            updated_state = state.extend(keys, values, added_length=length_without_padding)

         num_suffix_tokens, _, _ = queries.shape
-        mask =
+        mask = updated_state.attention_mask(
             num_suffix_tokens,
             self.is_causal,
             length_without_padding,
@@ -373,8 +375,8 @@ class Attention(LalamoModule[AttentionConfig]):
         if self.config.logit_soft_cap is not None:
             attention_output = _soft_capped_attention_kernel(
                 queries,
-
-
+                updated_state.keys,
+                updated_state.values,
                 mask=mask,
                 scale=self.scale,
                 logit_soft_cap=self.config.logit_soft_cap,
@@ -382,8 +384,8 @@ class Attention(LalamoModule[AttentionConfig]):
         else:
             attention_output = jax.nn.dot_product_attention(
                 queries,
-
-
+                updated_state.keys,
+                updated_state.values,
                 bias=sink_bias,
                 mask=mask,
                 scale=self.scale,
@@ -396,16 +398,16 @@ class Attention(LalamoModule[AttentionConfig]):
         )
         (result,) = vmap(self.out_projection, in_axes=0)(attention_output)

-        if not
-
+        if not return_updated_state:
+            updated_state = None

         return AttentionResult(
             outputs=result,
-
+            state=updated_state,
         )

-    def
-        return StaticKVCacheLayer.
+    def init_static_state(self, capacity: int) -> StaticKVCacheLayer:
+        return StaticKVCacheLayer.init(
             self.has_sinks,
             capacity,
             self.num_groups,
```
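In the rewritten `__call__`, the explicit KV-cache argument is replaced by a generic `state` that is only returned when `return_updated_state=True`, positional embeddings become optional, and the cache is either created via `DynamicKVCacheLayer.init` or grown via `state.extend`. Below is a hedged sketch of how a caller might drive this for a prefill followed by one decode step; `attn`, `pos_emb`, `prompt_acts`, and `next_act` are hypothetical stand-ins for objects constructed elsewhere, not part of lalamo's API surface:

```python
# Hypothetical driver code sketch; only the __call__ signature and the
# AttentionResult fields (outputs, state) come from the diff above.
def prefill_and_step(attn, pos_emb, prompt_acts, next_act):
    # Prefill: no prior state, ask for the updated KV cache back.
    prefill = attn(
        prompt_acts,        # [prompt_tokens, channels]
        pos_emb,            # positional embeddings for the prompt positions
        state=None,
        return_updated_state=True,
    )
    # Single-token step: feed the cached state from the prefill back in.
    step = attn(
        next_act[None, :],  # [1, channels]
        pos_emb,            # embeddings computed for the new position
        state=prefill.state,
        return_updated_state=True,
    )
    return step.outputs, step.state
```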
New file lalamo/modules/token_mixers/common.py:

```diff
--- /dev/null
+++ b/lalamo/modules/token_mixers/common.py
@@ -0,0 +1,78 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import NamedTuple, Self
+
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.modules.common import LalamoModule, ParameterTree, PositionalEmbeddingSelector
+from lalamo.modules.rope import PositionalEmbeddings
+
+from .state.common import StateLayerBase
+
+__all__ = [
+    "TokenMixerBase",
+    "TokenMixerConfigBase",
+    "TokenMixerResult",
+]
+
+
+class TokenMixerResult[StateLayerT](NamedTuple):
+    outputs: Float[Array, "*batch suffix_tokens channels"]
+    state: StateLayerT | None = None
+
+
+@dataclass(frozen=True)
+class TokenMixerConfigBase(ABC):
+    @property
+    @abstractmethod
+    def rope_dim(self) -> int: ...
+
+    @abstractmethod
+    def random_init(
+        self,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "TokenMixerBase": ...
+
+    @abstractmethod
+    def empty(
+        self,
+        model_dim: int,
+    ) -> "TokenMixerBase": ...
+
+
+class TokenMixerBase[ConfigT, StateLayerT: StateLayerBase](LalamoModule[ConfigT]):
+    @property
+    @abstractmethod
+    def activation_precision(self) -> DTypeLike: ...
+
+    @property
+    @abstractmethod
+    def model_dim(self) -> int: ...
+
+    @property
+    @abstractmethod
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector: ...
+
+    @abstractmethod
+    def __call__(
+        self,
+        inputs: Float[Array, "suffix_tokens channels"],
+        positional_embeddings: PositionalEmbeddings | None,
+        state: StateLayerT | None = None,
+        return_updated_state: bool = False,
+        length_without_padding: Int[Array, ""] | int | None = None,
+    ) -> TokenMixerResult[StateLayerT]: ...
+
+    @abstractmethod
+    def init_static_state(self, capacity: int) -> StateLayerT: ...
+
+    @abstractmethod
+    def export_weights(self) -> ParameterTree: ...
+
+    @abstractmethod
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self: ...
```