lalamo 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +3 -2
- lalamo/data/__init__.py +0 -1
- lalamo/data/huggingface_message.py +1 -0
- lalamo/main.py +167 -18
- lalamo/message_processor.py +2 -3
- lalamo/model_import/common.py +120 -27
- lalamo/model_import/decoder_configs/__init__.py +4 -2
- lalamo/model_import/decoder_configs/common.py +62 -21
- lalamo/model_import/decoder_configs/executorch.py +14 -9
- lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
- lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
- lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
- lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
- lalamo/model_import/loaders/__init__.py +3 -2
- lalamo/model_import/loaders/executorch.py +24 -12
- lalamo/model_import/loaders/huggingface.py +258 -30
- lalamo/model_import/model_specs/__init__.py +4 -2
- lalamo/model_import/model_specs/common.py +8 -2
- lalamo/model_import/model_specs/gemma.py +5 -1
- lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo/model_import/model_specs/mirai.py +20 -0
- lalamo/models/__init__.py +10 -0
- lalamo/models/common.py +81 -0
- lalamo/{language_model.py → models/language_model.py} +32 -49
- lalamo/models/router.py +59 -0
- lalamo/modules/__init__.py +33 -16
- lalamo/modules/classifier.py +339 -0
- lalamo/modules/common.py +6 -3
- lalamo/modules/decoder.py +52 -180
- lalamo/modules/mlp.py +28 -5
- lalamo/modules/normalization.py +13 -8
- lalamo/modules/token_mixers/attention.py +10 -6
- lalamo/modules/token_mixers/state/kv_cache.py +14 -4
- lalamo/modules/transformer.py +273 -0
- lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
- lalamo/speculator/__init__.py +2 -0
- lalamo/speculator/estimator.py +91 -0
- lalamo/speculator/inference.py +28 -9
- lalamo/speculator/ngram.py +7 -3
- lalamo/speculator/utils.py +4 -2
- {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/METADATA +1 -1
- lalamo-0.5.3.dist-info/RECORD +88 -0
- lalamo-0.5.2.dist-info/RECORD +0 -80
- {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/WHEEL +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder.py
CHANGED
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping
 from dataclasses import dataclass, replace
 from typing import Self

@@ -9,12 +9,16 @@ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

 from lalamo.common import ParameterTree

-from .common import ForwardPassMode, LalamoModule
-from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerForwardPassConfig, DecoderLayerResult
+from .common import ForwardPassMode, LalamoModule
 from .embedding import EmbeddingBase, EmbeddingConfig
-from .
-from .
-from .
+from .rope import PositionalEmbeddings
+from .token_mixers import State
+from .transformer import (
+    Transformer,
+    TransformerConfig,
+    TransformerForwardPassConfig,
+    TransformerLayerResult,
+)
 from .utils import vmap_twice

 __all__ = [
@@ -26,7 +30,7 @@ __all__ = [
 ]


-type DecoderForwardPassConfig =
+type DecoderForwardPassConfig = TransformerForwardPassConfig


 class DecoderActivationTrace(eqx.Module):
@@ -37,7 +41,7 @@ class DecoderActivationTrace(eqx.Module):
     local_positional_embeddings: PositionalEmbeddings | None
     global_positional_embeddings: PositionalEmbeddings | None

-    layer_results: tuple[
+    layer_results: tuple[TransformerLayerResult, ...]

     output_norm: Float[Array, "batch suffix_tokens channels"]

@@ -48,12 +52,12 @@ class DecoderActivationTrace(eqx.Module):
             layer_results=[layer_result.export() for layer_result in self.layer_results],
             output_norm=self.output_norm,
         )
+        if self.state is not None:
+            result["state"] = [state_layer.export() for state_layer in self.state]
         if self.local_positional_embeddings is not None:
             result["local_positional_embeddings"] = self.local_positional_embeddings.export()
         if self.global_positional_embeddings is not None:
             result["global_positional_embeddings"] = self.global_positional_embeddings.export()
-        if self.state is not None:
-            result["state"] = [state_layer.export() for state_layer in self.state]
         return result


@@ -76,124 +80,46 @@ class DecoderResult(eqx.Module):
 @dataclass(frozen=True)
 class DecoderConfig:
     embedding_config: EmbeddingConfig
-
-    local_rope_config: RoPEConfig | None
-    layer_configs: tuple[DecoderLayerConfig, ...]
-    output_norm_config: RMSNormConfig
+    transformer_config: TransformerConfig

     vocab_size: int
-    model_dim: int
-    hidden_dim: int
-    context_length: int

     def random_init(
         self,
         *,
         key: PRNGKeyArray,
     ) -> "Decoder":
-        embedding_key,
+        embedding_key, transformer_key = jax.random.split(key)
         embedding = self.embedding_config.random_init(
             vocab_size=self.vocab_size,
-            model_dim=self.model_dim,
+            model_dim=self.transformer_config.model_dim,
             key=embedding_key,
         )
+        transformer = self.transformer_config.random_init(key=transformer_key)

-        first_layer_config, *_ = self.layer_configs
-
-        if self.global_rope_config:
-            global_rope = self.global_rope_config.init(
-                head_dim=first_layer_config.rope_dim,
-                num_timesteps=self.context_length,
-            )
-        else:
-            global_rope = None
-
-        if self.local_rope_config:
-            max_sliding_window_size = max(
-                layer_config.mixer_config.sliding_window_size or 0
-                for layer_config in self.layer_configs
-                if isinstance(layer_config.mixer_config, AttentionConfig)
-            )
-            local_rope = self.local_rope_config.init(
-                head_dim=first_layer_config.rope_dim,
-                num_timesteps=max(max_sliding_window_size, self.context_length),
-            )
-        else:
-            local_rope = None
-
-        layers_keys = jax.random.split(layers_key, len(self.layer_configs))
-        layers = tuple(
-            layer_config.random_init(
-                model_dim=self.model_dim,
-                hidden_dim=self.hidden_dim,
-                key=key,
-            )
-            for layer_config, key in zip(self.layer_configs, layers_keys, strict=False)
-        )
-        output_norm = self.output_norm_config.init(self.model_dim)
         return Decoder(
-            self,
+            config=self,
             embedding=embedding,
-
-            local_rope=local_rope,
-            layers=layers,
-            output_norm=output_norm,
+            transformer=transformer,
         )

-    def empty(
-        self,
-    ) -> "Decoder":
+    def empty(self) -> "Decoder":
         embedding = self.embedding_config.empty(
             vocab_size=self.vocab_size,
-            model_dim=self.model_dim,
+            model_dim=self.transformer_config.model_dim,
         )
+        transformer = self.transformer_config.empty()

-        first_layer_config, *_ = self.layer_configs
-
-        if self.global_rope_config:
-            global_rope = self.global_rope_config.init(
-                head_dim=first_layer_config.rope_dim,
-                num_timesteps=self.context_length,
-            )
-        else:
-            global_rope = None
-
-        if self.local_rope_config:
-            max_sliding_window_size = max(
-                layer_config.mixer_config.sliding_window_size or 0
-                for layer_config in self.layer_configs
-                if isinstance(layer_config.mixer_config, AttentionConfig)
-            )
-            local_rope = self.local_rope_config.init(
-                head_dim=first_layer_config.rope_dim,
-                num_timesteps=max(max_sliding_window_size, self.context_length),
-            )
-        else:
-            local_rope = None
-        layers = tuple(
-            layer_config.empty(
-                model_dim=self.model_dim,
-                hidden_dim=self.hidden_dim,
-            )
-            for layer_config in self.layer_configs
-        )
-        output_norm = self.output_norm_config.empty(self.model_dim)
         return Decoder(
-            self,
+            config=self,
             embedding=embedding,
-
-            local_rope=local_rope,
-            layers=layers,
-            output_norm=output_norm,
+            transformer=transformer,
         )


 class Decoder(LalamoModule[DecoderConfig]):
     embedding: EmbeddingBase
-
-    local_rope: RoPE | None
-    layers: tuple[DecoderLayer, ...]
-    output_norm: RMSNorm
+    transformer: Transformer

     @property
     def activation_precision(self) -> DTypeLike:
@@ -213,93 +139,59 @@ class Decoder(LalamoModule[DecoderConfig]):
     ) -> DecoderResult:
         if token_ids.ndim != 2:
             raise ValueError(
-                f"token_ids must be a 2D
+                f"token_ids must be a 2D array of size (batch_size, sequence_length), got {token_ids.shape}",
             )
         if token_positions.ndim != 2:
             raise ValueError(
-                "token_positions must be a 2D
+                "token_positions must be a 2D array of size (batch_size, sequence_length),"
                 f" got {token_positions.shape}",
             )

-        maybe_state = state or ([None] * len(self.layers))
         inner_features = vmap(self.embedding.embed)(token_ids)

-
-
-
-
-
-
-
-
-
-
-
-        layer_results = []
-        for layer, state_layer in zip(self.layers, maybe_state, strict=True):
-            match layer.positional_embedding_selector:
-                case PositionalEmbeddingSelector.LOCAL:
-                    positional_embeddings_to_use = local_positional_embeddings
-                case PositionalEmbeddingSelector.GLOBAL:
-                    positional_embeddings_to_use = global_positional_embeddings
-                case PositionalEmbeddingSelector.NONE:
-                    positional_embeddings_to_use = None
-
-            layer_result = layer(
-                inner_features,
-                positional_embeddings_to_use,
-                state=state_layer,
-                return_updated_state=return_updated_state,
-                return_activation_trace=return_activation_trace,
-                lengths_without_padding=lengths_without_padding,
-                forward_pass_mode=forward_pass_mode,
-                forward_pass_config=forward_pass_config,
-            )
-            inner_features = layer_result.outputs
-            layer_results.append(layer_result)
-            updated_state_layers.append(layer_result.updated_state)
+        transformer_result = self.transformer(
+            inner_features=inner_features,
+            token_positions=token_positions,
+            state=state,
+            return_updated_state=return_updated_state,
+            return_layer_results=return_activation_trace,
+            return_positional_embeddings=return_activation_trace,
+            lengths_without_padding=lengths_without_padding,
+            forward_pass_mode=forward_pass_mode,
+            forward_pass_config=forward_pass_config,
+        )

-
-        logits = vmap_twice(self.embedding.readout)(normalized_outputs)
+        logits = vmap_twice(self.embedding.readout)(transformer_result.outputs)

         if return_activation_trace:
+            assert transformer_result.layer_results is not None
+
             activation_trace = DecoderActivationTrace(
                 token_ids=token_ids,
                 token_positions=token_positions,
                 state=state,
-                global_positional_embeddings=global_positional_embeddings,
-                local_positional_embeddings=local_positional_embeddings,
-                layer_results=
-                output_norm=
+                global_positional_embeddings=transformer_result.global_positional_embeddings,
+                local_positional_embeddings=transformer_result.local_positional_embeddings,
+                layer_results=transformer_result.layer_results,
+                output_norm=transformer_result.outputs,
             )
         else:
             activation_trace = None

-        if return_updated_state:
-            updated_state = State(updated_state_layers)
-        else:
-            updated_state = None
-
         return DecoderResult(
             logits=logits,
-            updated_state=updated_state,
+            updated_state=transformer_result.updated_state,
             activation_trace=activation_trace,
         )

     def init_static_state(self, batch_size: int, capacity: int) -> State:
-        return
+        return self.transformer.init_static_state(batch_size, capacity)

     def export_weights(self) -> ParameterTree:
-
+        return dict(
             embedding=self.embedding.export_weights(),
-
-            output_norm=self.output_norm.export_weights(),
+            transformer=self.transformer.export_weights(),
         )
-        if self.global_rope:
-            result["global_rope"] = self.global_rope.export_weights()
-        if self.local_rope:
-            result["local_rope"] = self.local_rope.export_weights()
-        return result

     def import_weights(
         self,
@@ -307,30 +199,10 @@ class Decoder(LalamoModule[DecoderConfig]):
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["
-        assert isinstance(weights["output_norm"], Mapping)
-
-        if self.local_rope:
-            assert isinstance(weights["local_rope"], Mapping)
-            local_rope = self.local_rope.import_weights(weights["local_rope"])
-        else:
-            local_rope = None
-
-        if self.global_rope:
-            assert isinstance(weights["global_rope"], Mapping)
-            global_rope = self.global_rope.import_weights(weights["global_rope"])
-        else:
-            global_rope = None
+        assert isinstance(weights["transformer"], Mapping)

-        layers = []
-        for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
-            assert isinstance(layer_weights, Mapping)
-            layers.append(layer.import_weights(layer_weights))
         return replace(
             self,
             embedding=self.embedding.import_weights(weights["embedding"]),
-
-            layers=tuple(layers),
-            output_norm=self.output_norm.import_weights(weights["output_norm"]),
-            local_rope=local_rope,
+            transformer=self.transformer.import_weights(weights["transformer"]),
         )
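
The decoder hunks above fold the per-layer loop, RoPE selection, and output norm into a single `Transformer` submodule, leaving `Decoder` as embed → transformer → readout. A minimal sketch of that composition, using hypothetical plain-function stand-ins (`embed`, `transformer_stack`, `readout`) rather than the package's equinox modules:

```python
# Sketch of the composition the refactor moves to: the decoder only embeds
# tokens, delegates the stack to one "transformer" callable, and reads logits
# out. The function names below are illustrative stand-ins, not the lalamo API.
import jax
import jax.numpy as jnp

def embed(table: jnp.ndarray, token_ids: jnp.ndarray) -> jnp.ndarray:
    return table[token_ids]                       # (batch, tokens, channels)

def transformer_stack(features: jnp.ndarray) -> jnp.ndarray:
    # Stand-in for the Transformer: layers + output norm now live behind one call.
    return features / (jnp.linalg.norm(features, axis=-1, keepdims=True) + 1e-6)

def readout(table: jnp.ndarray, features: jnp.ndarray) -> jnp.ndarray:
    return features @ table.T                     # tied embeddings -> logits

def decode(table: jnp.ndarray, token_ids: jnp.ndarray) -> jnp.ndarray:
    features = embed(table, token_ids)
    features = transformer_stack(features)
    return readout(table, features)

key = jax.random.PRNGKey(0)
table = jax.random.normal(key, (100, 16))         # vocab_size=100, model_dim=16
logits = decode(table, jnp.array([[1, 2, 3]]))
print(logits.shape)                               # (1, 3, 100)
```

The same shape shows up in the weight tree: everything that previously sat next to `embedding` (layers, norms, RoPE tables) now nests under a single `transformer` key in `export_weights`/`import_weights`.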
lalamo/modules/mlp.py
CHANGED
@@ -16,7 +16,12 @@ from lalamo.common import ParameterTree
 from lalamo.modules.utils import vmap_twice

 from .activations import Activation
-from .common import
+from .common import (
+    DummyUnionMember,
+    ForwardPassMode,
+    LalamoModule,
+    register_config_union,
+)
 from .linear import LinearBase, LinearConfig

 __all__ = [
@@ -192,7 +197,10 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
                 f" the gate output dimension {gate_output_dim}",
             )
         (down_output_dim,) = self.down_projection.output_dims
-        if (self.up_projection.input_dim, up_output_dim) != (
+        if (self.up_projection.input_dim, up_output_dim) != (
+            down_output_dim,
+            self.down_projection.input_dim,
+        ):
             raise ValueError(
                 f"Down projection dimensions {self.down_projection.input_dim, down_output_dim} do not match"
                 f" the up projection output dimensions {self.up_projection.input_dim, up_output_dim}",
@@ -209,7 +217,10 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
         return vmap_twice(self.call_unbatched)(inputs)

     @eqx.filter_jit
-    def call_unbatched(
+    def call_unbatched(
+        self,
+        inputs: Float[Array, " channels"],
+    ) -> Float[Array, " channels"]:
         if self.mixture_size is not None:
             raise ValueError(
                 "Mixtures of linear layers cannot be called directly."
@@ -222,6 +233,7 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
         up_proj = jnp.clip(up_proj, *self.config.up_clipping)
         gate = self.config.activation(gate)
         (result,) = self.down_projection(up_proj * gate)
+
         return result

     def export_weights(self) -> ParameterTree:
@@ -450,10 +462,21 @@ class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
                mode="drop",
            )

-            return
+            return (
+                jax.lax.cond(
+                    jnp.any(token_indices_for_chunk != _SENTINEL),
+                    inner,
+                    lambda: accumulator,
+                ),
+                None,
+            )

         result, _ = jax.lax.scan(loop_iteration, jnp.zeros_like(flattened_inputs), chunked_token_indices)
-        return rearrange(
+        return rearrange(
+            result,
+            "(batch suffix_tokens) channels -> batch suffix_tokens channels",
+            batch=batch_size,
+        )

     def export_weights(
         self,
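
The `MixtureOfExperts` hunk wraps each scan chunk in `jax.lax.cond`, so a chunk that contains only the sentinel index returns the accumulator unchanged instead of running the expert computation. A standalone sketch of that guarded-scan pattern; the sentinel value, shapes, and the scatter-add body are illustrative, not lalamo's:

```python
# Sketch of guarding a lax.scan body with lax.cond: if a chunk holds only the
# sentinel index, the carry is passed through untouched and the expensive
# branch is skipped. _SENTINEL and the toy data are illustrative.
import jax
import jax.numpy as jnp

_SENTINEL = -1

def scatter_add_chunks(values: jnp.ndarray, chunked_indices: jnp.ndarray) -> jnp.ndarray:
    def loop_iteration(accumulator, token_indices_for_chunk):
        def inner():
            # "Expensive" branch: gather the chunk's rows and scatter-add them.
            safe = jnp.where(token_indices_for_chunk == _SENTINEL, 0, token_indices_for_chunk)
            updates = values[safe] * (token_indices_for_chunk != _SENTINEL)[:, None]
            return accumulator.at[safe].add(updates)

        return (
            jax.lax.cond(
                jnp.any(token_indices_for_chunk != _SENTINEL),
                inner,
                lambda: accumulator,   # empty chunk: keep the carry as-is
            ),
            None,
        )

    result, _ = jax.lax.scan(loop_iteration, jnp.zeros_like(values), chunked_indices)
    return result

values = jnp.arange(12.0).reshape(6, 2)
chunks = jnp.array([[0, 1], [_SENTINEL, _SENTINEL], [2, _SENTINEL]])  # middle chunk is skipped
print(scatter_add_chunks(values, chunks))
```

Both branches of `lax.cond` must return a value with the same structure as the carry, which is why the empty-chunk branch simply returns the accumulator.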
lalamo/modules/normalization.py
CHANGED
@@ -13,8 +13,8 @@ from lalamo.common import ParameterTree, dummy_array
 from .common import LalamoModule

 __all__ = [
-    "
-    "
+    "Normalization",
+    "NormalizationConfig",
     "UpcastMode",
 ]

@@ -25,25 +25,26 @@ class UpcastMode(Enum):


 @dataclass(frozen=True)
-class
+class NormalizationConfig:
     scale_precision: DTypeLike
     accumulation_precision: DTypeLike
     epsilon: float
     scale_offset: float | None
     upcast_mode: UpcastMode
+    subtract_mean: bool

-    def init(self, input_dim: int) -> "
+    def init(self, input_dim: int) -> "Normalization":
         scales = jnp.ones(input_dim, dtype=self.scale_precision)
-        return
+        return Normalization(self, scales=scales)

-    def empty(self, input_dim: int) -> "
-        return
+    def empty(self, input_dim: int) -> "Normalization":
+        return Normalization(
             config=self,
             scales=dummy_array(input_dim, dtype=self.scale_precision),
         )


-class
+class Normalization(LalamoModule[NormalizationConfig]):
     scales: Float[Array, " channels"]

     @property
@@ -66,6 +67,10 @@ class RMSNorm(LalamoModule[RMSNormConfig]):
     def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
         upcasted_inputs = inputs.astype(self.config.accumulation_precision)

+        if self.config.subtract_mean:
+            mean = jnp.mean(upcasted_inputs)
+            upcasted_inputs = upcasted_inputs - mean
+
         adjusted_variance = jnp.mean(jnp.square(upcasted_inputs)) + self.config.epsilon
         normalized_x = upcasted_inputs * jax.lax.rsqrt(adjusted_variance)

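
The renamed `Normalization` module gains a `subtract_mean` flag: with it off, the forward pass is the usual RMSNorm; with it on, the input is centered first, giving a LayerNorm-style variant. A sketch of just that arithmetic as a plain function (no learned scales or offset, illustrative only):

```python
# Sketch of the normalization arithmetic shown in the diff: optional mean
# subtraction followed by RMS scaling, accumulated in higher precision.
import jax
import jax.numpy as jnp

def normalize(x: jnp.ndarray, subtract_mean: bool, epsilon: float = 1e-6) -> jnp.ndarray:
    x = x.astype(jnp.float32)                        # accumulation precision
    if subtract_mean:
        x = x - jnp.mean(x)                          # LayerNorm-style centering
    adjusted_variance = jnp.mean(jnp.square(x)) + epsilon
    return x * jax.lax.rsqrt(adjusted_variance)      # RMS scaling

v = jnp.array([1.0, 2.0, 3.0, 4.0])
print(normalize(v, subtract_mean=False))  # classic RMSNorm
print(normalize(v, subtract_mean=True))   # centered variant, zero-mean output
```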
lalamo/modules/token_mixers/attention.py
CHANGED
@@ -12,7 +12,7 @@ from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 from lalamo.common import dummy_array
 from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector
 from lalamo.modules.linear import LinearBase, LinearConfig
-from lalamo.modules.normalization import
+from lalamo.modules.normalization import Normalization, NormalizationConfig
 from lalamo.modules.rope import PositionalEmbeddings
 from lalamo.modules.utils import apply_soft_capping

@@ -58,7 +58,11 @@ def _soft_capped_attention_kernel(
         "heads dst_tokens channels, heads src_tokens channels -> heads dst_tokens src_tokens",
     )
     if mask is not None:
-        attention_logits = jnp.where(
+        attention_logits = jnp.where(
+            mask,
+            attention_logits,
+            jnp.array(float("-inf"), dtype=attention_logits.dtype),
+        )

     if scale is None:
         scale_val = head_dim**-0.5
@@ -82,8 +86,8 @@ class AttentionConfig(TokenMixerConfigBase):
     qkv_projection_config: LinearConfig
     out_projection_config: LinearConfig

-    query_norm_config:
-    key_norm_config:
+    query_norm_config: NormalizationConfig | None
+    key_norm_config: NormalizationConfig | None

     num_heads: int
     num_groups: int
@@ -217,8 +221,8 @@ class Attention(TokenMixerBase[AttentionConfig, KVCacheLayer]):
     qkv_projection: LinearBase
     out_projection: LinearBase

-    query_norm:
-    key_norm:
+    query_norm: Normalization | None
+    key_norm: Normalization | None

     sinks: Float[Array, " heads"] | None

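
In the soft-capped attention kernel, masked logits are now filled with `-inf` in the logits' own dtype before the softmax, so disallowed positions receive zero weight. A tiny illustration of that pattern with toy shapes (not the kernel itself):

```python
# Masking attention logits before softmax: disallowed positions are set to
# -inf in the logits' dtype so they contribute zero probability.
import jax
import jax.numpy as jnp

logits = jnp.array([[0.5, 1.0, -0.2],
                    [0.1, 0.3, 0.9]], dtype=jnp.bfloat16)
mask = jnp.array([[True, True, False],
                  [True, False, True]])

masked = jnp.where(mask, logits, jnp.array(float("-inf"), dtype=logits.dtype))
weights = jax.nn.softmax(masked.astype(jnp.float32), axis=-1)
print(weights)  # masked positions have probability 0
```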
lalamo/modules/token_mixers/state/kv_cache.py
CHANGED
@@ -89,7 +89,7 @@ class DynamicKVCacheLayer(KVCacheLayer):
         self,
         suffix_length: int,
         is_causal: bool,
-        suffix_length_without_padding: Int[Array, ""] | int | None = None,  # noqa: ARG002
+        suffix_length_without_padding: (Int[Array, ""] | int | None) = None,  # noqa: ARG002
         sliding_window_size: int | None = None,
     ) -> Bool[Array, "suffix_tokens tokens"]:
         self._raise_if_batched()
@@ -97,8 +97,11 @@ class DynamicKVCacheLayer(KVCacheLayer):
         result = jnp.ones((suffix_length, total_num_tokens), dtype=jnp.bool)
         if is_causal:
             result = jnp.tril(result, k=total_num_tokens - suffix_length)
-
-
+            if sliding_window_size is not None:
+                result = jnp.triu(result, k=1 - sliding_window_size)
+        elif sliding_window_size is not None:
+            top_zeroed = jnp.tril(result, k=sliding_window_size // 2)
+            result = jnp.triu(top_zeroed, k=-sliding_window_size // 2)
         if self.has_sinks:
             result = result.at[:, 0].set(True)
         if self.padding_mask is not None:
@@ -213,7 +216,14 @@ class StaticKVCacheLayer(KVCacheLayer):
         )

     @classmethod
-    def init(
+    def init(
+        cls,
+        has_sinks: bool,
+        capacity: int,
+        num_groups: int,
+        head_dim: int,
+        dtype: DTypeLike,
+    ) -> Self:
         return cls(
             has_sinks=has_sinks,
             keys=jnp.zeros((capacity, num_groups, head_dim), dtype=dtype),