lalamo-0.3.3-py3-none-any.whl → lalamo-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +20 -5
- lalamo/data/__init__.py +8 -0
- lalamo/data/huggingface_message.py +38 -0
- lalamo/data/lalamo_completions.py +43 -0
- lalamo/data/utils.py +8 -0
- lalamo/language_model.py +152 -69
- lalamo/main.py +271 -43
- lalamo/message_processor.py +11 -1
- lalamo/model_import/common.py +17 -7
- lalamo/model_import/decoder_configs/__init__.py +3 -0
- lalamo/model_import/decoder_configs/executorch.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- lalamo/model_import/huggingface_tokenizer_config.py +1 -4
- lalamo/model_import/loaders/executorch.py +10 -9
- lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo/model_import/loaders/utils.py +92 -0
- lalamo/model_import/model_specs/__init__.py +4 -1
- lalamo/model_import/model_specs/common.py +15 -12
- lalamo/model_import/model_specs/gpt_oss.py +21 -0
- lalamo/modules/__init__.py +35 -7
- lalamo/modules/activations.py +24 -14
- lalamo/modules/attention.py +73 -20
- lalamo/modules/common.py +8 -57
- lalamo/modules/decoder.py +48 -34
- lalamo/modules/decoder_layer.py +57 -43
- lalamo/modules/embedding.py +13 -19
- lalamo/modules/kv_cache.py +53 -16
- lalamo/modules/linear.py +260 -79
- lalamo/modules/mlp.py +395 -23
- lalamo/modules/normalization.py +2 -3
- lalamo/modules/rope.py +32 -21
- lalamo/modules/utils.py +10 -0
- lalamo/speculator/__init__.py +11 -0
- lalamo/speculator/common.py +22 -0
- lalamo/speculator/inference.py +75 -0
- lalamo/speculator/ngram.py +154 -0
- lalamo/speculator/utils.py +52 -0
- lalamo/utils.py +27 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
- lalamo-0.4.0.dist-info/RECORD +71 -0
- lalamo-0.3.3.dist-info/RECORD +0 -59
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder_layer.py
CHANGED
@@ -1,41 +1,48 @@
 from collections.abc import Mapping
 from dataclasses import dataclass, replace
+from functools import partial
 from typing import Self
 
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree
 
 from .attention import Attention, AttentionConfig
-from .common import AttentionType,
+from .common import AttentionType, ForwardPassMode, LalamoModule
 from .kv_cache import KVCacheLayer, StaticKVCacheLayer
-from .mlp import
+from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
 from .normalization import RMSNorm, RMSNormConfig
 from .rope import PositionalEmbeddings
+from .utils import vmap_twice
 
 __all__ = [
     "DecoderLayer",
     "DecoderLayerActivationTrace",
     "DecoderLayerConfig",
+    "DecoderLayerForwardPassConfig",
     "DecoderLayerResult",
 ]
 
 
+type DecoderLayerForwardPassConfig = MLPForwardPassConfig
+
+
 class DecoderLayerActivationTrace(eqx.Module):
-    inputs: Float[Array, "suffix_tokens channels"]
+    inputs: Float[Array, "batch suffix_tokens channels"]
     positional_embeddings: PositionalEmbeddings
     kv_cache: KVCacheLayer | None
 
-    mlp_inputs: Float[Array, "suffix_tokens channels"]
-    pre_attention_norm: Float[Array, "suffix_tokens channels"]
-    attention: Float[Array, "suffix_tokens channels"]
-    post_attention_norm: Float[Array, "suffix_tokens channels"] | None
-    pre_mlp_norm: Float[Array, "suffix_tokens channels"]
-    mlp: Float[Array, "suffix_tokens channels"]
-    post_mlp_norm: Float[Array, "suffix_tokens channels"] | None
+    mlp_inputs: Float[Array, "batch suffix_tokens channels"]
+    pre_attention_norm: Float[Array, "batch suffix_tokens channels"]
+    attention: Float[Array, "batch suffix_tokens channels"]
+    post_attention_norm: Float[Array, "batch suffix_tokens channels"] | None
+    pre_mlp_norm: Float[Array, "batch suffix_tokens channels"]
+    mlp: Float[Array, "batch suffix_tokens channels"]
+    post_mlp_norm: Float[Array, "batch suffix_tokens channels"] | None
 
     def export(self) -> ParameterTree:
         result = dict(
@@ -171,7 +178,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
     attention: Attention
     post_attention_norm: RMSNorm | None
     pre_mlp_norm: RMSNorm
-    mlp:
+    mlp: MLPBase
     post_mlp_norm: RMSNorm | None
 
     @property
@@ -201,44 +208,50 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
             )
         if self.mlp.model_dim != model_dim:
             raise ValueError(
-                f"MLP up projection dim {self.mlp.
+                f"MLP up projection dim {self.mlp.model_dim} does not match"
                 f" the first normalization layer dim {model_dim}",
             )
-        if self.mlp.hidden_dim != self.mlp.down_projection.input_dim:
-            raise ValueError(
-                f"MLP down projection dim {self.mlp.down_projection.input_dim} does not match"
-                f" the up projection dim {self.mlp.hidden_dim}",
-            )
 
     @eqx.filter_jit
     def __call__(
         self,
-        inputs: Float[Array, "suffix_tokens channels"],
+        inputs: Float[Array, "batch suffix_tokens channels"],
         positional_embeddings: PositionalEmbeddings,
         kv_cache: KVCacheLayer | None = None,
        return_updated_kv_cache: bool = False,
         return_activation_trace: bool = False,
-
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+        forward_pass_config: DecoderLayerForwardPassConfig | None = None,
     ) -> DecoderLayerResult:
-
-
+        if inputs.ndim != 3:
+            raise ValueError(
+                f"Inputs to decoder layers must be a 3D arrays of size (batch_size, sequence_length, hidden_dim),"
+                f" got {inputs.shape}",
+            )
+        normalized_attention_inputs = vmap_twice(self.pre_attention_norm)(inputs)
+        batched_attention_fn = vmap(partial(self.attention, return_updated_kv_cache=return_updated_kv_cache))
+        attention_outputs, updated_kv_cache = batched_attention_fn(
            normalized_attention_inputs,
             positional_embeddings,
             kv_cache=kv_cache,
-
-            length_without_padding=length_without_padding,
+            length_without_padding=lengths_without_padding,
         )
         if self.post_attention_norm is not None:
-            normalized_attention_outputs =
+            normalized_attention_outputs = vmap_twice(self.post_attention_norm)(attention_outputs)
             mlp_inputs = inputs + normalized_attention_outputs
         else:
             normalized_attention_outputs = None
             mlp_inputs = inputs + attention_outputs
 
-        normalized_mlp_inputs =
-        mlp_outputs =
+        normalized_mlp_inputs = vmap_twice(self.pre_mlp_norm)(mlp_inputs)
+        mlp_outputs = self.mlp(
+            normalized_mlp_inputs,
+            forward_pass_mode=forward_pass_mode,
+            forward_pass_config=forward_pass_config,
+        )
         if self.post_mlp_norm is not None:
-            normalized_mlp_outputs =
+            normalized_mlp_outputs = vmap_twice(self.post_mlp_norm)(mlp_outputs)
             outputs = mlp_inputs + normalized_mlp_outputs
         else:
             normalized_mlp_outputs = None
@@ -266,26 +279,28 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
             activation_trace=activation_trace,
         )
 
-    def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
-        return
+    def init_static_kv_cache(self, batch_size: int, capacity: int) -> StaticKVCacheLayer:
+        return jax.tree.map(
+            lambda array: jnp.repeat(array[None, ...], batch_size, axis=0),
+            self.attention.init_static_kv_cache(capacity),
+        )
 
-    def export_weights(self
+    def export_weights(self) -> ParameterTree:
         result = dict(
-            pre_attention_norm=self.pre_attention_norm.export_weights(
-            attention=self.attention.export_weights(
-            pre_mlp_norm=self.pre_mlp_norm.export_weights(
-            mlp=self.mlp.export_weights(
+            pre_attention_norm=self.pre_attention_norm.export_weights(),
+            attention=self.attention.export_weights(),
+            pre_mlp_norm=self.pre_mlp_norm.export_weights(),
+            mlp=self.mlp.export_weights(),
         )
         if self.post_attention_norm is not None:
-            result["post_attention_norm"] = self.post_attention_norm.export_weights(
+            result["post_attention_norm"] = self.post_attention_norm.export_weights()
         if self.post_mlp_norm is not None:
-            result["post_mlp_norm"] = self.post_mlp_norm.export_weights(
+            result["post_mlp_norm"] = self.post_mlp_norm.export_weights()
         return result
 
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["pre_attention_norm"], Mapping)
@@ -297,21 +312,20 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
             assert isinstance(weights["post_attention_norm"], Mapping)
             post_attention_norm = self.post_attention_norm.import_weights(
                 weights["post_attention_norm"],
-                weight_layout,
             )
         else:
             post_attention_norm = None
         if self.post_mlp_norm is not None:
             assert isinstance(weights["post_mlp_norm"], Mapping)
-            post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"]
+            post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
         else:
             post_mlp_norm = None
         return replace(
             self,
-            pre_attention_norm=self.pre_attention_norm.import_weights(weights["pre_attention_norm"]
-            attention=self.attention.import_weights(weights["attention"]
+            pre_attention_norm=self.pre_attention_norm.import_weights(weights["pre_attention_norm"]),
+            attention=self.attention.import_weights(weights["attention"]),
             post_attention_norm=post_attention_norm,
-            pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]
-            mlp=self.mlp.import_weights(weights["mlp"]
+            pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
+            mlp=self.mlp.import_weights(weights["mlp"]),
             post_mlp_norm=post_mlp_norm,
         )
lalamo/modules/embedding.py
CHANGED
@@ -13,9 +13,6 @@ from lalamo.quantization import QuantizationMode, dynamically_quantize_activatio
 
 from .common import (
     LalamoModule,
-    WeightLayout,
-    from_layout,
-    into_layout,
     register_config_union,
 )
 from .utils import apply_soft_capping
@@ -35,7 +32,7 @@ __all__ = [
 @dataclass(frozen=True)
 class EmbeddingConfigBase:
     input_scale: float | None
-
+    logit_soft_cap: float | None
 
     @abstractmethod
     def random_init(
@@ -79,8 +76,8 @@ class EmbeddingBase[ConfigT: EmbeddingConfigBase](LalamoModule[ConfigT]):
     @eqx.filter_jit
     def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
         logits = self._prepare_output_weights() @ x
-        if self.config.
-            logits = apply_soft_capping(logits, self.config.
+        if self.config.logit_soft_cap is not None:
+            logits = apply_soft_capping(logits, self.config.logit_soft_cap)
         return logits
 
 
@@ -136,13 +133,12 @@ class TiedEmbedding(EmbeddingBase[TiedEmbeddingConfig]):
     def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
         return self.weights
 
-    def export_weights(self
+    def export_weights(self) -> ParameterTree:
         return {"weights": self.weights}
 
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
     ) -> Self:
         assert isinstance(weights, Mapping)
         return replace(self, weights=weights["weights"])
@@ -184,7 +180,7 @@ class UntiedEmbeddingConfig(EmbeddingConfigBase):
 
 class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
     input_weights: Float[Array, "vocabulary channels"]
-    output_weights: Float[Array, "vocabulary
+    output_weights: Float[Array, "channels vocabulary"]
 
     @property
     def activation_precision(self) -> DTypeLike:
@@ -228,22 +224,21 @@ class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
     def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
         return self.output_weights
 
-    def export_weights(self
+    def export_weights(self) -> ParameterTree:
         return {
             "input_weights": self.input_weights,
-            "output_weights":
+            "output_weights": self.output_weights,
         }
 
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         return replace(
             self,
             input_weights=weights["input_weights"],
-            output_weights=
+            output_weights=weights["output_weights"],
         )
 
 
@@ -339,23 +334,22 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
         x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
         return super().readout(x)
 
-    def export_weights(self
+    def export_weights(self) -> ParameterTree:
         return {
-            "weights":
-            "scales":
+            "weights": self.int_weights,
+            "scales": self.scales,
         }
 
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["weights"], Array)
         return replace(
             self,
-            weights=
-            scales=
+            weights=weights["weights"].astype(self.weights.dtype),
+            scales=weights["scales"],
        )
 
 
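
The embedding config now exposes a `logit_soft_cap` field, and `readout` routes logits through `apply_soft_capping` when the cap is set. The helper itself lives in `lalamo/modules/utils.py` and is not shown in this diff; the sketch below uses the standard tanh soft-capping formula (as popularized by Gemma-style models) purely for illustration, not as lalamo's exact code.

```python
import jax.numpy as jnp


def apply_soft_capping(logits, cap):
    # Assumed formula: squash logits smoothly into the interval (-cap, cap).
    return cap * jnp.tanh(logits / cap)


logits = jnp.array([-100.0, 0.0, 3.0, 100.0])
print(apply_soft_capping(logits, cap=30.0))  # approximately [-30, 0, 3, 30]
```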
lalamo/modules/kv_cache.py
CHANGED
@@ -13,13 +13,14 @@ __all__ = ["DynamicKVCacheLayer", "KVCache", "KVCacheLayer", "StaticKVCacheLayer
 
 
 class KVCacheLayer(eqx.Module):
-
-
+    has_sinks: bool = eqx.field(static=True)
+    keys: Float[Array, "*batch tokens groups head_channels"]
+    values: Float[Array, "*batch tokens groups head_channels"]
 
     def __post_init__(self) -> None:
-        if self.keys.ndim
+        if self.keys.ndim not in (3, 4):
             raise ValueError(
-                f"Key and value buffers must have 3 dimensions: capacity, groups, head_channels,"
+                f"Key and value buffers must have 3 or 4 dimensions: [batch], capacity, groups, head_channels,"
                 f" got shape {self.keys.shape}",
             )
         if self.keys.shape != self.values.shape:
@@ -27,11 +28,18 @@ class KVCacheLayer(eqx.Module):
         if self.keys.dtype != self.values.dtype:
             raise ValueError("Keys and values buffers must have the same dtype")
 
+    def _raise_if_batched(self) -> None:
+        if self.keys.ndim != 3:
+            raise ValueError(
+                "Attempted to call a method on a batched version of KVCacheLayer. Use vmap instead.",
+            )
+
     @abstractmethod
     def attention_mask(
         self,
         suffix_length: int,
         is_causal: bool,
+        suffix_length_without_padding: Int[Array, ""] | int | None = None,
         sliding_window_size: int | None = None,
     ) -> Bool[Array, "suffix_tokens tokens"]: ...
 
@@ -68,29 +76,42 @@ class DynamicKVCacheLayer(KVCacheLayer):
     @classmethod
     def init(
         cls,
+        has_sinks: bool,
         keys: Float[Array, "tokens groups head_channels"],
         values: Float[Array, "tokens groups head_channels"],
         length: Int[Array, ""] | int | None = None,
     ) -> "DynamicKVCacheLayer":
-        num_tokens,
+        num_tokens, num_groups, head_dim = keys.shape
         if length is None:
             padding_mask = None
         else:
-
-
+            token_indices = jnp.arange(num_tokens, dtype=jnp.int32)
+            padding_mask = token_indices < length
+        if has_sinks:
+            sinks = jnp.zeros((1, num_groups, head_dim), dtype=keys.dtype)
+            keys = jnp.concatenate([sinks, keys], axis=0)
+            values = jnp.concatenate([sinks, values], axis=0)
+            if padding_mask is not None:
+                true = jnp.ones((1,), dtype=jnp.bool)
+                padding_mask = jnp.concatenate([true, padding_mask], axis=0)
+        return cls(has_sinks, keys, values, padding_mask)
 
     def attention_mask(
         self,
         suffix_length: int,
         is_causal: bool,
+        suffix_length_without_padding: Int[Array, ""] | int | None = None,  # noqa: ARG002
         sliding_window_size: int | None = None,
     ) -> Bool[Array, "suffix_tokens tokens"]:
+        self._raise_if_batched()
         total_num_tokens, _, _ = self.keys.shape
         result = jnp.ones((suffix_length, total_num_tokens), dtype=jnp.bool)
         if is_causal:
             result = jnp.tril(result, k=total_num_tokens - suffix_length)
         if sliding_window_size is not None:
             result = jnp.triu(result, k=1 - sliding_window_size)
+        if self.has_sinks:
+            result = result.at[:, 0].set(True)
         if self.padding_mask is not None:
             result = result & self.padding_mask[None, :]
         return result
@@ -101,12 +122,13 @@ class DynamicKVCacheLayer(KVCacheLayer):
         added_values: Float[Array, "new_tokens groups head_channels"],
         added_length: Int[Array, ""] | int | None = None,
     ) -> "DynamicKVCacheLayer":
+        self._raise_if_batched()
         updated_keys = jnp.concatenate([self.keys, added_keys], axis=0)
         updated_values = jnp.concatenate([self.values, added_values], axis=0)
 
         added_padded_length, _, _ = added_keys.shape
         if self.padding_mask is None and added_length is None:
-            return DynamicKVCacheLayer(updated_keys, updated_values)
+            return DynamicKVCacheLayer(self.has_sinks, updated_keys, updated_values)
         if added_length is None:
             added_length = added_padded_length
 
@@ -118,20 +140,24 @@ class DynamicKVCacheLayer(KVCacheLayer):
 
         added_padding_mask = jnp.arange(added_padded_length, dtype=jnp.int32) < added_length
         updated_padding_mask = jnp.concatenate([old_padding_mask, added_padding_mask], axis=0)
-        return DynamicKVCacheLayer(updated_keys, updated_values, updated_padding_mask)
+        return DynamicKVCacheLayer(self.has_sinks, updated_keys, updated_values, updated_padding_mask)
 
 
 class StaticKVCacheLayer(KVCacheLayer):
-    current_length: Int[Array, ""]
+    current_length: Int[Array, "*batch"]
 
     def attention_mask(
         self,
         suffix_length: int,
         is_causal: bool,
+        suffix_length_without_padding: Int[Array, ""] | int | None = None,
         sliding_window_size: int | None = None,
     ) -> Bool[Array, "suffix_tokens tokens"]:
+        self._raise_if_batched()
+        if suffix_length_without_padding is None:
+            suffix_length_without_padding = suffix_length
         if is_causal:
-            query_offsets = jnp.arange(
+            query_offsets = jnp.arange(0, suffix_length, dtype=jnp.int32) - suffix_length_without_padding
         else:
             query_offsets = jnp.zeros(suffix_length, dtype=jnp.int32)
 
@@ -142,15 +168,19 @@ class StaticKVCacheLayer(KVCacheLayer):
         if sliding_window_size is not None:
             swa_mask = query_indices[:, None] < (key_indices[None, :] + sliding_window_size)
             result = result & swa_mask
+        if self.has_sinks:
+            result = result.at[:, 0].set(True)
 
         return result
 
     @property
     def padding_mask(self) -> Bool[Array, " tokens"] | None:
+        self._raise_if_batched()
         return jnp.arange(self.capacity, dtype=jnp.int32) < self.current_length
 
     @property
     def capacity(self) -> int:
+        self._raise_if_batched()
         result, _, _ = self.keys.shape
         return result
 
@@ -160,6 +190,7 @@ class StaticKVCacheLayer(KVCacheLayer):
         added_values: Float[Array, "tokens groups head_channels"],
         added_length: Int[Array, ""] | int | None = None,
     ) -> "StaticKVCacheLayer":
+        self._raise_if_batched()
         if added_keys.shape != added_values.shape:
             raise ValueError("Keys and values must have the same shape")
         num_added_tokens, new_num_groups, new_head_dim = added_keys.shape
@@ -185,12 +216,18 @@ class StaticKVCacheLayer(KVCacheLayer):
             allow_negative_indices=False,
         )
         updated_sequence_length = self.current_length + added_length
-        return StaticKVCacheLayer(
+        return StaticKVCacheLayer(
+            has_sinks=self.has_sinks,
+            keys=updated_keys,
+            values=updated_values,
+            current_length=updated_sequence_length,
+        )
 
     @classmethod
-    def empty(cls, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
+    def empty(cls, has_sinks: bool, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
         return cls(
-
-
-
+            has_sinks=has_sinks,
+            keys=jnp.zeros((capacity, num_groups, head_dim), dtype=dtype),
+            values=jnp.zeros((capacity, num_groups, head_dim), dtype=dtype),
+            current_length=jnp.array(has_sinks, dtype=jnp.int32),
         )