lalamo 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. lalamo/__init__.py +20 -5
  2. lalamo/data/__init__.py +8 -0
  3. lalamo/data/huggingface_message.py +38 -0
  4. lalamo/data/lalamo_completions.py +43 -0
  5. lalamo/data/utils.py +8 -0
  6. lalamo/language_model.py +152 -69
  7. lalamo/main.py +273 -45
  8. lalamo/message_processor.py +11 -1
  9. lalamo/model_import/common.py +10 -6
  10. lalamo/model_import/decoder_configs/__init__.py +3 -0
  11. lalamo/model_import/decoder_configs/executorch.py +12 -6
  12. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  13. lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  14. lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  15. lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  16. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  17. lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  18. lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  21. lalamo/model_import/huggingface_tokenizer_config.py +1 -3
  22. lalamo/model_import/loaders/executorch.py +10 -9
  23. lalamo/model_import/loaders/huggingface.py +104 -9
  24. lalamo/model_import/loaders/utils.py +92 -0
  25. lalamo/model_import/model_specs/__init__.py +4 -1
  26. lalamo/model_import/model_specs/common.py +15 -12
  27. lalamo/model_import/model_specs/gpt_oss.py +21 -0
  28. lalamo/modules/__init__.py +35 -7
  29. lalamo/modules/activations.py +24 -14
  30. lalamo/modules/attention.py +73 -20
  31. lalamo/modules/common.py +8 -57
  32. lalamo/modules/decoder.py +48 -34
  33. lalamo/modules/decoder_layer.py +57 -43
  34. lalamo/modules/embedding.py +13 -19
  35. lalamo/modules/kv_cache.py +53 -16
  36. lalamo/modules/linear.py +260 -79
  37. lalamo/modules/mlp.py +395 -23
  38. lalamo/modules/normalization.py +2 -3
  39. lalamo/modules/rope.py +32 -21
  40. lalamo/modules/utils.py +10 -0
  41. lalamo/speculator/__init__.py +11 -0
  42. lalamo/speculator/common.py +22 -0
  43. lalamo/speculator/inference.py +75 -0
  44. lalamo/speculator/ngram.py +154 -0
  45. lalamo/speculator/utils.py +52 -0
  46. lalamo/utils.py +27 -0
  47. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/METADATA +11 -4
  48. lalamo-0.4.1.dist-info/RECORD +71 -0
  49. lalamo-0.3.4.dist-info/RECORD +0 -59
  50. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/WHEEL +0 -0
  51. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/entry_points.txt +0 -0
  52. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/licenses/LICENSE +0 -0
  53. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder_layer.py
@@ -1,41 +1,48 @@
  from collections.abc import Mapping
  from dataclasses import dataclass, replace
+ from functools import partial
  from typing import Self

  import equinox as eqx
  import jax
+ import jax.numpy as jnp
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

  from lalamo.common import ParameterTree

  from .attention import Attention, AttentionConfig
- from .common import AttentionType, LalamoModule, WeightLayout
+ from .common import AttentionType, ForwardPassMode, LalamoModule
  from .kv_cache import KVCacheLayer, StaticKVCacheLayer
- from .mlp import MLP, MLPConfig
+ from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
  from .normalization import RMSNorm, RMSNormConfig
  from .rope import PositionalEmbeddings
+ from .utils import vmap_twice

  __all__ = [
      "DecoderLayer",
      "DecoderLayerActivationTrace",
      "DecoderLayerConfig",
+     "DecoderLayerForwardPassConfig",
      "DecoderLayerResult",
  ]


+ type DecoderLayerForwardPassConfig = MLPForwardPassConfig
+
+
  class DecoderLayerActivationTrace(eqx.Module):
-     inputs: Float[Array, "suffix_tokens channels"]
+     inputs: Float[Array, "batch suffix_tokens channels"]
      positional_embeddings: PositionalEmbeddings
      kv_cache: KVCacheLayer | None

-     mlp_inputs: Float[Array, "suffix_tokens channels"]
-     pre_attention_norm: Float[Array, "suffix_tokens channels"]
-     attention: Float[Array, "suffix_tokens channels"]
-     post_attention_norm: Float[Array, "suffix_tokens channels"] | None
-     pre_mlp_norm: Float[Array, "suffix_tokens channels"]
-     mlp: Float[Array, "suffix_tokens channels"]
-     post_mlp_norm: Float[Array, "suffix_tokens channels"] | None
+     mlp_inputs: Float[Array, "batch suffix_tokens channels"]
+     pre_attention_norm: Float[Array, "batch suffix_tokens channels"]
+     attention: Float[Array, "batch suffix_tokens channels"]
+     post_attention_norm: Float[Array, "batch suffix_tokens channels"] | None
+     pre_mlp_norm: Float[Array, "batch suffix_tokens channels"]
+     mlp: Float[Array, "batch suffix_tokens channels"]
+     post_mlp_norm: Float[Array, "batch suffix_tokens channels"] | None

      def export(self) -> ParameterTree:
          result = dict(
@@ -171,7 +178,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
      attention: Attention
      post_attention_norm: RMSNorm | None
      pre_mlp_norm: RMSNorm
-     mlp: MLP
+     mlp: MLPBase
      post_mlp_norm: RMSNorm | None

      @property
@@ -201,44 +208,50 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
          )
          if self.mlp.model_dim != model_dim:
              raise ValueError(
-                 f"MLP up projection dim {self.mlp.up_projection.input_dim} does not match"
+                 f"MLP up projection dim {self.mlp.model_dim} does not match"
                  f" the first normalization layer dim {model_dim}",
              )
-         if self.mlp.hidden_dim != self.mlp.down_projection.input_dim:
-             raise ValueError(
-                 f"MLP down projection dim {self.mlp.down_projection.input_dim} does not match"
-                 f" the up projection dim {self.mlp.hidden_dim}",
-             )

      @eqx.filter_jit
      def __call__(
          self,
-         inputs: Float[Array, "suffix_tokens channels"],
+         inputs: Float[Array, "batch suffix_tokens channels"],
          positional_embeddings: PositionalEmbeddings,
          kv_cache: KVCacheLayer | None = None,
          return_updated_kv_cache: bool = False,
          return_activation_trace: bool = False,
-         length_without_padding: Int[Array, ""] | int | None = None,
+         lengths_without_padding: Int[Array, " batch"] | None = None,
+         forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+         forward_pass_config: DecoderLayerForwardPassConfig | None = None,
      ) -> DecoderLayerResult:
-         normalized_attention_inputs = vmap(self.pre_attention_norm, in_axes=0)(inputs)
-         attention_outputs, updated_kv_cache = self.attention(
+         if inputs.ndim != 3:
+             raise ValueError(
+                 f"Inputs to decoder layers must be a 3D arrays of size (batch_size, sequence_length, hidden_dim),"
+                 f" got {inputs.shape}",
+             )
+         normalized_attention_inputs = vmap_twice(self.pre_attention_norm)(inputs)
+         batched_attention_fn = vmap(partial(self.attention, return_updated_kv_cache=return_updated_kv_cache))
+         attention_outputs, updated_kv_cache = batched_attention_fn(
              normalized_attention_inputs,
              positional_embeddings,
              kv_cache=kv_cache,
-             return_updated_kv_cache=return_updated_kv_cache,
-             length_without_padding=length_without_padding,
+             length_without_padding=lengths_without_padding,
          )
          if self.post_attention_norm is not None:
-             normalized_attention_outputs = vmap(self.post_attention_norm, in_axes=0)(attention_outputs)
+             normalized_attention_outputs = vmap_twice(self.post_attention_norm)(attention_outputs)
              mlp_inputs = inputs + normalized_attention_outputs
          else:
              normalized_attention_outputs = None
              mlp_inputs = inputs + attention_outputs

-         normalized_mlp_inputs = vmap(self.pre_mlp_norm, in_axes=0)(mlp_inputs)
-         mlp_outputs = vmap(self.mlp, in_axes=0)(normalized_mlp_inputs)
+         normalized_mlp_inputs = vmap_twice(self.pre_mlp_norm)(mlp_inputs)
+         mlp_outputs = self.mlp(
+             normalized_mlp_inputs,
+             forward_pass_mode=forward_pass_mode,
+             forward_pass_config=forward_pass_config,
+         )
          if self.post_mlp_norm is not None:
-             normalized_mlp_outputs = vmap(self.post_mlp_norm, in_axes=0)(mlp_outputs)
+             normalized_mlp_outputs = vmap_twice(self.post_mlp_norm)(mlp_outputs)
              outputs = mlp_inputs + normalized_mlp_outputs
          else:
              normalized_mlp_outputs = None
@@ -266,26 +279,28 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
              activation_trace=activation_trace,
          )

-     def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
-         return self.attention.init_static_kv_cache(capacity)
+     def init_static_kv_cache(self, batch_size: int, capacity: int) -> StaticKVCacheLayer:
+         return jax.tree.map(
+             lambda array: jnp.repeat(array[None, ...], batch_size, axis=0),
+             self.attention.init_static_kv_cache(capacity),
+         )

-     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+     def export_weights(self) -> ParameterTree:
          result = dict(
-             pre_attention_norm=self.pre_attention_norm.export_weights(weight_layout),
-             attention=self.attention.export_weights(weight_layout),
-             pre_mlp_norm=self.pre_mlp_norm.export_weights(weight_layout),
-             mlp=self.mlp.export_weights(weight_layout),
+             pre_attention_norm=self.pre_attention_norm.export_weights(),
+             attention=self.attention.export_weights(),
+             pre_mlp_norm=self.pre_mlp_norm.export_weights(),
+             mlp=self.mlp.export_weights(),
          )
          if self.post_attention_norm is not None:
-             result["post_attention_norm"] = self.post_attention_norm.export_weights(weight_layout)
+             result["post_attention_norm"] = self.post_attention_norm.export_weights()
          if self.post_mlp_norm is not None:
-             result["post_mlp_norm"] = self.post_mlp_norm.export_weights(weight_layout)
+             result["post_mlp_norm"] = self.post_mlp_norm.export_weights()
          return result

      def import_weights(
          self,
          weights: ParameterTree[Array],
-         weight_layout: WeightLayout = WeightLayout.AUTO,
      ) -> Self:
          assert isinstance(weights, Mapping)
          assert isinstance(weights["pre_attention_norm"], Mapping)
@@ -297,21 +312,20 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
              assert isinstance(weights["post_attention_norm"], Mapping)
              post_attention_norm = self.post_attention_norm.import_weights(
                  weights["post_attention_norm"],
-                 weight_layout,
              )
          else:
              post_attention_norm = None
          if self.post_mlp_norm is not None:
              assert isinstance(weights["post_mlp_norm"], Mapping)
-             post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"], weight_layout)
+             post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
          else:
              post_mlp_norm = None
          return replace(
              self,
-             pre_attention_norm=self.pre_attention_norm.import_weights(weights["pre_attention_norm"], weight_layout),
-             attention=self.attention.import_weights(weights["attention"], weight_layout),
+             pre_attention_norm=self.pre_attention_norm.import_weights(weights["pre_attention_norm"]),
+             attention=self.attention.import_weights(weights["attention"]),
              post_attention_norm=post_attention_norm,
-             pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"], weight_layout),
-             mlp=self.mlp.import_weights(weights["mlp"], weight_layout),
+             pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
+             mlp=self.mlp.import_weights(weights["mlp"]),
              post_mlp_norm=post_mlp_norm,
          )
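
Note on the decoder-layer hunks above: per-token modules are now mapped over a leading batch axis via a vmap_twice helper imported from .utils, whose ten added lines are not shown in this diff. The following is a hypothetical sketch, consistent only with how the helper is used here, assuming it simply nests two jax.vmap calls; it is not taken from the package source.

    import jax.numpy as jnp
    from jax import vmap

    # Hypothetical stand-in for lalamo.modules.utils.vmap_twice (not shown in this diff):
    # map a function defined on a single (channels,) vector over both the batch axis
    # and the token axis of a (batch, tokens, channels) array.
    def vmap_twice(fn):
        return vmap(vmap(fn))

    # Illustrative per-vector function playing the role of a norm layer.
    def rms_normalize(x):
        return x / jnp.sqrt(jnp.mean(x * x) + 1e-6)

    activations = jnp.ones((2, 5, 8))                      # (batch, suffix_tokens, channels)
    normalized = vmap_twice(rms_normalize)(activations)    # still (2, 5, 8)
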
lalamo/modules/embedding.py
@@ -13,9 +13,6 @@ from lalamo.quantization import QuantizationMode, dynamically_quantize_activations

  from .common import (
      LalamoModule,
-     WeightLayout,
-     from_layout,
-     into_layout,
      register_config_union,
  )
  from .utils import apply_soft_capping
@@ -35,7 +32,7 @@ __all__ = [
  @dataclass(frozen=True)
  class EmbeddingConfigBase:
      input_scale: float | None
-     logits_soft_cap: float | None
+     logit_soft_cap: float | None

      @abstractmethod
      def random_init(
@@ -79,8 +76,8 @@ class EmbeddingBase[ConfigT: EmbeddingConfigBase](LalamoModule[ConfigT]):
      @eqx.filter_jit
      def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
          logits = self._prepare_output_weights() @ x
-         if self.config.logits_soft_cap is not None:
-             logits = apply_soft_capping(logits, self.config.logits_soft_cap)
+         if self.config.logit_soft_cap is not None:
+             logits = apply_soft_capping(logits, self.config.logit_soft_cap)
          return logits


@@ -136,13 +133,12 @@ class TiedEmbedding(EmbeddingBase[TiedEmbeddingConfig]):
      def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
          return self.weights

-     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+     def export_weights(self) -> ParameterTree:
          return {"weights": self.weights}

      def import_weights(
          self,
          weights: ParameterTree[Array],
-         weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
      ) -> Self:
          assert isinstance(weights, Mapping)
          return replace(self, weights=weights["weights"])
@@ -184,7 +180,7 @@ class UntiedEmbeddingConfig(EmbeddingConfigBase):

  class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
      input_weights: Float[Array, "vocabulary channels"]
-     output_weights: Float[Array, "vocabulary channels"]
+     output_weights: Float[Array, "channels vocabulary"]

      @property
      def activation_precision(self) -> DTypeLike:
@@ -228,22 +224,21 @@ class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
      def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
          return self.output_weights

-     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+     def export_weights(self) -> ParameterTree:
          return {
              "input_weights": self.input_weights,
-             "output_weights": into_layout(self.output_weights, weight_layout),
+             "output_weights": self.output_weights,
          }

      def import_weights(
          self,
          weights: ParameterTree[Array],
-         weight_layout: WeightLayout = WeightLayout.AUTO,
      ) -> Self:
          assert isinstance(weights, Mapping)
          return replace(
              self,
              input_weights=weights["input_weights"],
-             output_weights=from_layout(weights["output_weights"], weight_layout),
+             output_weights=weights["output_weights"],
          )


@@ -339,23 +334,22 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
          x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
          return super().readout(x)

-     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+     def export_weights(self) -> ParameterTree:
          return {
-             "weights": into_layout(self.int_weights, weight_layout),
-             "scales": into_layout(self.scales, weight_layout),
+             "weights": self.int_weights,
+             "scales": self.scales,
          }

      def import_weights(
          self,
          weights: ParameterTree[Array],
-         weight_layout: WeightLayout = WeightLayout.AUTO,
      ) -> Self:
          assert isinstance(weights, Mapping)
          assert isinstance(weights["weights"], Array)
          return replace(
              self,
-             weights=from_layout(weights["weights"].astype(self.weights.dtype), weight_layout),
-             scales=from_layout(weights["scales"], weight_layout),
+             weights=weights["weights"].astype(self.weights.dtype),
+             scales=weights["scales"],
          )
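
Aside on the readout hunk above: logits_soft_cap is renamed to logit_soft_cap, and readout applies apply_soft_capping only when that config value is set. apply_soft_capping itself lives in lalamo/modules/utils.py and is not shown in this diff; the conventional soft-capping formula is cap * tanh(logits / cap), and the sketch below is an assumption illustrating that convention, not the package's actual implementation.

    import jax.numpy as jnp

    # Assumed behaviour of logit soft-capping (apply_soft_capping is not shown in this diff):
    # smoothly squash logits into the open interval (-cap, cap).
    def soft_cap(logits, cap):
        return cap * jnp.tanh(logits / cap)

    logits = jnp.array([-30.0, 0.0, 30.0])
    print(soft_cap(logits, 10.0))  # values bounded by +/- 10
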
lalamo/modules/kv_cache.py
@@ -13,13 +13,14 @@ __all__ = ["DynamicKVCacheLayer", "KVCache", "KVCacheLayer", "StaticKVCacheLayer"]


  class KVCacheLayer(eqx.Module):
-     keys: Float[Array, "tokens groups head_channels"]
-     values: Float[Array, "tokens groups head_channels"]
+     has_sinks: bool = eqx.field(static=True)
+     keys: Float[Array, "*batch tokens groups head_channels"]
+     values: Float[Array, "*batch tokens groups head_channels"]

      def __post_init__(self) -> None:
-         if self.keys.ndim != 3:
+         if self.keys.ndim not in (3, 4):
              raise ValueError(
-                 f"Key and value buffers must have 3 dimensions: capacity, groups, head_channels,"
+                 f"Key and value buffers must have 3 or 4 dimensions: [batch], capacity, groups, head_channels,"
                  f" got shape {self.keys.shape}",
              )
          if self.keys.shape != self.values.shape:
@@ -27,11 +28,18 @@ class KVCacheLayer(eqx.Module):
          if self.keys.dtype != self.values.dtype:
              raise ValueError("Keys and values buffers must have the same dtype")

+     def _raise_if_batched(self) -> None:
+         if self.keys.ndim != 3:
+             raise ValueError(
+                 "Attempted to call a method on a batched version of KVCacheLayer. Use vmap instead.",
+             )
+
      @abstractmethod
      def attention_mask(
          self,
          suffix_length: int,
          is_causal: bool,
+         suffix_length_without_padding: Int[Array, ""] | int | None = None,
          sliding_window_size: int | None = None,
      ) -> Bool[Array, "suffix_tokens tokens"]: ...

@@ -68,29 +76,42 @@ class DynamicKVCacheLayer(KVCacheLayer):
      @classmethod
      def init(
          cls,
+         has_sinks: bool,
          keys: Float[Array, "tokens groups head_channels"],
          values: Float[Array, "tokens groups head_channels"],
          length: Int[Array, ""] | int | None = None,
      ) -> "DynamicKVCacheLayer":
-         num_tokens, _, _ = keys.shape
+         num_tokens, num_groups, head_dim = keys.shape
          if length is None:
              padding_mask = None
          else:
-             padding_mask = jnp.arange(num_tokens, dtype=jnp.int32) < length
-         return cls(keys, values, padding_mask)
+             token_indices = jnp.arange(num_tokens, dtype=jnp.int32)
+             padding_mask = token_indices < length
+         if has_sinks:
+             sinks = jnp.zeros((1, num_groups, head_dim), dtype=keys.dtype)
+             keys = jnp.concatenate([sinks, keys], axis=0)
+             values = jnp.concatenate([sinks, values], axis=0)
+             if padding_mask is not None:
+                 true = jnp.ones((1,), dtype=jnp.bool)
+                 padding_mask = jnp.concatenate([true, padding_mask], axis=0)
+         return cls(has_sinks, keys, values, padding_mask)

      def attention_mask(
          self,
          suffix_length: int,
          is_causal: bool,
+         suffix_length_without_padding: Int[Array, ""] | int | None = None,  # noqa: ARG002
          sliding_window_size: int | None = None,
      ) -> Bool[Array, "suffix_tokens tokens"]:
+         self._raise_if_batched()
          total_num_tokens, _, _ = self.keys.shape
          result = jnp.ones((suffix_length, total_num_tokens), dtype=jnp.bool)
          if is_causal:
              result = jnp.tril(result, k=total_num_tokens - suffix_length)
          if sliding_window_size is not None:
              result = jnp.triu(result, k=1 - sliding_window_size)
+         if self.has_sinks:
+             result = result.at[:, 0].set(True)
          if self.padding_mask is not None:
              result = result & self.padding_mask[None, :]
          return result
@@ -101,12 +122,13 @@ class DynamicKVCacheLayer(KVCacheLayer):
          added_values: Float[Array, "new_tokens groups head_channels"],
          added_length: Int[Array, ""] | int | None = None,
      ) -> "DynamicKVCacheLayer":
+         self._raise_if_batched()
          updated_keys = jnp.concatenate([self.keys, added_keys], axis=0)
          updated_values = jnp.concatenate([self.values, added_values], axis=0)

          added_padded_length, _, _ = added_keys.shape
          if self.padding_mask is None and added_length is None:
-             return DynamicKVCacheLayer(updated_keys, updated_values)
+             return DynamicKVCacheLayer(self.has_sinks, updated_keys, updated_values)
          if added_length is None:
              added_length = added_padded_length

@@ -118,20 +140,24 @@ class DynamicKVCacheLayer(KVCacheLayer):

          added_padding_mask = jnp.arange(added_padded_length, dtype=jnp.int32) < added_length
          updated_padding_mask = jnp.concatenate([old_padding_mask, added_padding_mask], axis=0)
-         return DynamicKVCacheLayer(updated_keys, updated_values, updated_padding_mask)
+         return DynamicKVCacheLayer(self.has_sinks, updated_keys, updated_values, updated_padding_mask)


  class StaticKVCacheLayer(KVCacheLayer):
-     current_length: Int[Array, ""]
+     current_length: Int[Array, "*batch"]

      def attention_mask(
          self,
          suffix_length: int,
          is_causal: bool,
+         suffix_length_without_padding: Int[Array, ""] | int | None = None,
          sliding_window_size: int | None = None,
      ) -> Bool[Array, "suffix_tokens tokens"]:
+         self._raise_if_batched()
+         if suffix_length_without_padding is None:
+             suffix_length_without_padding = suffix_length
          if is_causal:
-             query_offsets = jnp.arange(-suffix_length, 0, dtype=jnp.int32)
+             query_offsets = jnp.arange(0, suffix_length, dtype=jnp.int32) - suffix_length_without_padding
          else:
              query_offsets = jnp.zeros(suffix_length, dtype=jnp.int32)

@@ -142,15 +168,19 @@ class StaticKVCacheLayer(KVCacheLayer):
          if sliding_window_size is not None:
              swa_mask = query_indices[:, None] < (key_indices[None, :] + sliding_window_size)
              result = result & swa_mask
+         if self.has_sinks:
+             result = result.at[:, 0].set(True)

          return result

      @property
      def padding_mask(self) -> Bool[Array, " tokens"] | None:
+         self._raise_if_batched()
          return jnp.arange(self.capacity, dtype=jnp.int32) < self.current_length

      @property
      def capacity(self) -> int:
+         self._raise_if_batched()
          result, _, _ = self.keys.shape
          return result

@@ -160,6 +190,7 @@ class StaticKVCacheLayer(KVCacheLayer):
          added_values: Float[Array, "tokens groups head_channels"],
          added_length: Int[Array, ""] | int | None = None,
      ) -> "StaticKVCacheLayer":
+         self._raise_if_batched()
          if added_keys.shape != added_values.shape:
              raise ValueError("Keys and values must have the same shape")
          num_added_tokens, new_num_groups, new_head_dim = added_keys.shape
@@ -185,12 +216,18 @@ class StaticKVCacheLayer(KVCacheLayer):
              allow_negative_indices=False,
          )
          updated_sequence_length = self.current_length + added_length
-         return StaticKVCacheLayer(keys=updated_keys, values=updated_values, current_length=updated_sequence_length)
+         return StaticKVCacheLayer(
+             has_sinks=self.has_sinks,
+             keys=updated_keys,
+             values=updated_values,
+             current_length=updated_sequence_length,
+         )

      @classmethod
-     def empty(cls, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
+     def empty(cls, has_sinks: bool, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
          return cls(
-             keys=jnp.empty((capacity, num_groups, head_dim), dtype=dtype),
-             values=jnp.empty((capacity, num_groups, head_dim), dtype=dtype),
-             current_length=jnp.array(0, dtype=jnp.int32),
+             has_sinks=has_sinks,
+             keys=jnp.zeros((capacity, num_groups, head_dim), dtype=dtype),
+             values=jnp.zeros((capacity, num_groups, head_dim), dtype=dtype),
+             current_length=jnp.array(has_sinks, dtype=jnp.int32),
          )
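
Taken together with the init_static_kv_cache change in decoder_layer.py above, the KV cache in this release carries an optional leading batch axis and an attention-sink slot: StaticKVCacheLayer.empty reserves index 0 for the sink when has_sinks is true (so current_length starts at 1), and per-sample methods refuse to run on batched buffers. A minimal usage sketch follows; the capacity, group, head-dim, and batch-size values are illustrative only, and the import path is inferred from the file list above.

    import jax
    import jax.numpy as jnp

    from lalamo.modules.kv_cache import StaticKVCacheLayer

    # Unbatched per-layer cache; slot 0 is the attention sink, so the cache
    # starts with current_length == 1 when has_sinks=True.
    layer_cache = StaticKVCacheLayer.empty(
        has_sinks=True, capacity=1024, num_groups=8, head_dim=128, dtype=jnp.bfloat16
    )

    # Broadcast every array leaf to a leading batch axis, mirroring the
    # DecoderLayer.init_static_kv_cache(batch_size, capacity) change above.
    batch_size = 4
    batched_cache = jax.tree.map(
        lambda array: jnp.repeat(array[None, ...], batch_size, axis=0),
        layer_cache,
    )
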