lalamo 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. lalamo/__init__.py +3 -2
  2. lalamo/data/__init__.py +0 -1
  3. lalamo/data/huggingface_message.py +1 -0
  4. lalamo/main.py +167 -18
  5. lalamo/message_processor.py +2 -3
  6. lalamo/model_import/common.py +120 -27
  7. lalamo/model_import/decoder_configs/__init__.py +4 -2
  8. lalamo/model_import/decoder_configs/common.py +62 -21
  9. lalamo/model_import/decoder_configs/executorch.py +14 -9
  10. lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
  11. lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
  12. lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
  13. lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
  14. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
  15. lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
  16. lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
  17. lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
  18. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
  21. lalamo/model_import/loaders/__init__.py +3 -2
  22. lalamo/model_import/loaders/executorch.py +24 -12
  23. lalamo/model_import/loaders/huggingface.py +258 -30
  24. lalamo/model_import/model_specs/__init__.py +4 -2
  25. lalamo/model_import/model_specs/common.py +8 -2
  26. lalamo/model_import/model_specs/gemma.py +5 -1
  27. lalamo/model_import/model_specs/huggingface.py +1 -1
  28. lalamo/model_import/model_specs/mirai.py +20 -0
  29. lalamo/models/__init__.py +10 -0
  30. lalamo/models/common.py +81 -0
  31. lalamo/{language_model.py → models/language_model.py} +32 -49
  32. lalamo/models/router.py +59 -0
  33. lalamo/modules/__init__.py +33 -16
  34. lalamo/modules/classifier.py +339 -0
  35. lalamo/modules/common.py +6 -3
  36. lalamo/modules/decoder.py +52 -180
  37. lalamo/modules/mlp.py +28 -5
  38. lalamo/modules/normalization.py +13 -8
  39. lalamo/modules/token_mixers/attention.py +10 -6
  40. lalamo/modules/token_mixers/state/kv_cache.py +14 -4
  41. lalamo/modules/transformer.py +273 -0
  42. lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
  43. lalamo/speculator/__init__.py +2 -0
  44. lalamo/speculator/estimator.py +91 -0
  45. lalamo/speculator/inference.py +28 -9
  46. lalamo/speculator/ngram.py +7 -3
  47. lalamo/speculator/utils.py +4 -2
  48. {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/METADATA +1 -1
  49. lalamo-0.5.3.dist-info/RECORD +88 -0
  50. lalamo-0.5.2.dist-info/RECORD +0 -80
  51. {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/WHEEL +0 -0
  52. {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/entry_points.txt +0 -0
  53. {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/licenses/LICENSE +0 -0
  54. {lalamo-0.5.2.dist-info → lalamo-0.5.3.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder.py CHANGED
@@ -1,4 +1,4 @@
- from collections.abc import Mapping, Sequence
+ from collections.abc import Mapping
  from dataclasses import dataclass, replace
  from typing import Self
 
@@ -9,12 +9,16 @@ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
  from lalamo.common import ParameterTree
 
- from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
- from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerForwardPassConfig, DecoderLayerResult
+ from .common import ForwardPassMode, LalamoModule
  from .embedding import EmbeddingBase, EmbeddingConfig
- from .normalization import RMSNorm, RMSNormConfig
- from .rope import PositionalEmbeddings, RoPE, RoPEConfig
- from .token_mixers import AttentionConfig, State
+ from .rope import PositionalEmbeddings
+ from .token_mixers import State
+ from .transformer import (
+ Transformer,
+ TransformerConfig,
+ TransformerForwardPassConfig,
+ TransformerLayerResult,
+ )
  from .utils import vmap_twice
 
  __all__ = [
@@ -26,7 +30,7 @@ __all__ = [
  ]
 
 
- type DecoderForwardPassConfig = DecoderLayerForwardPassConfig
+ type DecoderForwardPassConfig = TransformerForwardPassConfig
 
 
  class DecoderActivationTrace(eqx.Module):
@@ -37,7 +41,7 @@ class DecoderActivationTrace(eqx.Module):
  local_positional_embeddings: PositionalEmbeddings | None
  global_positional_embeddings: PositionalEmbeddings | None
 
- layer_results: tuple[DecoderLayerResult, ...]
+ layer_results: tuple[TransformerLayerResult, ...]
 
  output_norm: Float[Array, "batch suffix_tokens channels"]
 
@@ -48,12 +52,12 @@ class DecoderActivationTrace(eqx.Module):
  layer_results=[layer_result.export() for layer_result in self.layer_results],
  output_norm=self.output_norm,
  )
+ if self.state is not None:
+ result["state"] = [state_layer.export() for state_layer in self.state]
  if self.local_positional_embeddings is not None:
  result["local_positional_embeddings"] = self.local_positional_embeddings.export()
  if self.global_positional_embeddings is not None:
  result["global_positional_embeddings"] = self.global_positional_embeddings.export()
- if self.state is not None:
- result["state"] = [state_layer.export() for state_layer in self.state]
  return result
 
 
@@ -76,124 +80,46 @@ class DecoderResult(eqx.Module):
  @dataclass(frozen=True)
  class DecoderConfig:
  embedding_config: EmbeddingConfig
- global_rope_config: RoPEConfig | None
- local_rope_config: RoPEConfig | None
- layer_configs: tuple[DecoderLayerConfig, ...]
- output_norm_config: RMSNormConfig
+ transformer_config: TransformerConfig
 
  vocab_size: int
- model_dim: int
- hidden_dim: int
- context_length: int
 
  def random_init(
  self,
  *,
  key: PRNGKeyArray,
  ) -> "Decoder":
- embedding_key, layers_key = jax.random.split(key)
+ embedding_key, transformer_key = jax.random.split(key)
  embedding = self.embedding_config.random_init(
  vocab_size=self.vocab_size,
- model_dim=self.model_dim,
+ model_dim=self.transformer_config.model_dim,
  key=embedding_key,
  )
+ transformer = self.transformer_config.random_init(key=transformer_key)
 
- first_layer_config, *_ = self.layer_configs
-
- if self.global_rope_config:
- global_rope = self.global_rope_config.init(
- head_dim=first_layer_config.rope_dim,
- num_timesteps=self.context_length,
- )
- else:
- global_rope = None
-
- if self.local_rope_config:
- max_sliding_window_size = max(
- layer_config.mixer_config.sliding_window_size or 0
- for layer_config in self.layer_configs
- if isinstance(layer_config.mixer_config, AttentionConfig)
- )
- local_rope = self.local_rope_config.init(
- head_dim=first_layer_config.rope_dim,
- num_timesteps=max(max_sliding_window_size, self.context_length),
- )
- else:
- local_rope = None
-
- layers_keys = jax.random.split(layers_key, len(self.layer_configs))
- layers = tuple(
- layer_config.random_init(
- model_dim=self.model_dim,
- hidden_dim=self.hidden_dim,
- key=key,
- )
- for layer_config, key in zip(self.layer_configs, layers_keys, strict=False)
- )
- output_norm = self.output_norm_config.init(self.model_dim)
  return Decoder(
- self,
+ config=self,
  embedding=embedding,
- global_rope=global_rope,
- local_rope=local_rope,
- layers=layers,
- output_norm=output_norm,
+ transformer=transformer,
  )
 
- def empty(
- self,
- ) -> "Decoder":
+ def empty(self) -> "Decoder":
  embedding = self.embedding_config.empty(
  vocab_size=self.vocab_size,
- model_dim=self.model_dim,
+ model_dim=self.transformer_config.model_dim,
  )
+ transformer = self.transformer_config.empty()
 
- first_layer_config, *_ = self.layer_configs
-
- if self.global_rope_config:
- global_rope = self.global_rope_config.init(
- head_dim=first_layer_config.rope_dim,
- num_timesteps=self.context_length,
- )
- else:
- global_rope = None
-
- if self.local_rope_config:
- max_sliding_window_size = max(
- layer_config.mixer_config.sliding_window_size or 0
- for layer_config in self.layer_configs
- if isinstance(layer_config.mixer_config, AttentionConfig)
- )
- local_rope = self.local_rope_config.init(
- head_dim=first_layer_config.rope_dim,
- num_timesteps=max(max_sliding_window_size, self.context_length),
- )
- else:
- local_rope = None
- layers = tuple(
- layer_config.empty(
- model_dim=self.model_dim,
- hidden_dim=self.hidden_dim,
- )
- for layer_config in self.layer_configs
- )
- output_norm = self.output_norm_config.empty(self.model_dim)
  return Decoder(
- self,
+ config=self,
  embedding=embedding,
- global_rope=global_rope,
- local_rope=local_rope,
- layers=layers,
- output_norm=output_norm,
+ transformer=transformer,
  )
 
 
  class Decoder(LalamoModule[DecoderConfig]):
  embedding: EmbeddingBase
- global_rope: RoPE | None
- local_rope: RoPE | None
- layers: tuple[DecoderLayer, ...]
- output_norm: RMSNorm
+ transformer: Transformer
 
  @property
  def activation_precision(self) -> DTypeLike:
@@ -213,93 +139,59 @@ class Decoder(LalamoModule[DecoderConfig]):
  ) -> DecoderResult:
  if token_ids.ndim != 2:
  raise ValueError(
- f"token_ids must be a 2D arrays of size (batch_size, sequence_length), got {token_ids.shape}",
+ f"token_ids must be a 2D array of size (batch_size, sequence_length), got {token_ids.shape}",
  )
  if token_positions.ndim != 2:
  raise ValueError(
- "token_positions must be a 2D arrays of size (batch_size, sequence_length),"
+ "token_positions must be a 2D array of size (batch_size, sequence_length),"
  f" got {token_positions.shape}",
  )
 
- maybe_state = state or ([None] * len(self.layers))
  inner_features = vmap(self.embedding.embed)(token_ids)
 
- if self.global_rope is not None:
- global_positional_embeddings = vmap(self.global_rope)(token_positions)
- else:
- global_positional_embeddings = None
-
- if self.local_rope is not None:
- local_positional_embeddings = vmap(self.local_rope)(token_positions)
- else:
- local_positional_embeddings = global_positional_embeddings
-
- updated_state_layers = []
- layer_results = []
- for layer, state_layer in zip(self.layers, maybe_state, strict=True):
- match layer.positional_embedding_selector:
- case PositionalEmbeddingSelector.LOCAL:
- positional_embeddings_to_use = local_positional_embeddings
- case PositionalEmbeddingSelector.GLOBAL:
- positional_embeddings_to_use = global_positional_embeddings
- case PositionalEmbeddingSelector.NONE:
- positional_embeddings_to_use = None
-
- layer_result = layer(
- inner_features,
- positional_embeddings_to_use,
- state=state_layer,
- return_updated_state=return_updated_state,
- return_activation_trace=return_activation_trace,
- lengths_without_padding=lengths_without_padding,
- forward_pass_mode=forward_pass_mode,
- forward_pass_config=forward_pass_config,
- )
- inner_features = layer_result.outputs
- layer_results.append(layer_result)
- updated_state_layers.append(layer_result.updated_state)
+ transformer_result = self.transformer(
+ inner_features=inner_features,
+ token_positions=token_positions,
+ state=state,
+ return_updated_state=return_updated_state,
+ return_layer_results=return_activation_trace,
+ return_positional_embeddings=return_activation_trace,
+ lengths_without_padding=lengths_without_padding,
+ forward_pass_mode=forward_pass_mode,
+ forward_pass_config=forward_pass_config,
+ )
 
- normalized_outputs = vmap_twice(self.output_norm)(inner_features)
- logits = vmap_twice(self.embedding.readout)(normalized_outputs)
+ logits = vmap_twice(self.embedding.readout)(transformer_result.outputs)
 
  if return_activation_trace:
+ assert transformer_result.layer_results is not None
+
  activation_trace = DecoderActivationTrace(
  token_ids=token_ids,
  token_positions=token_positions,
  state=state,
- global_positional_embeddings=global_positional_embeddings,
- local_positional_embeddings=local_positional_embeddings,
- layer_results=tuple(layer_results),
- output_norm=normalized_outputs,
+ global_positional_embeddings=transformer_result.global_positional_embeddings,
+ local_positional_embeddings=transformer_result.local_positional_embeddings,
+ layer_results=transformer_result.layer_results,
+ output_norm=transformer_result.outputs,
  )
  else:
  activation_trace = None
 
- if return_updated_state:
- updated_state = State(updated_state_layers)
- else:
- updated_state = None
-
  return DecoderResult(
  logits=logits,
- updated_state=updated_state,
+ updated_state=transformer_result.updated_state,
  activation_trace=activation_trace,
  )
 
  def init_static_state(self, batch_size: int, capacity: int) -> State:
- return State(layer.init_static_state(batch_size, capacity) for layer in self.layers)
+ return self.transformer.init_static_state(batch_size, capacity)
 
  def export_weights(self) -> ParameterTree:
- result = dict(
+ return dict(
  embedding=self.embedding.export_weights(),
- layers=[layer.export_weights() for layer in self.layers],
- output_norm=self.output_norm.export_weights(),
+ transformer=self.transformer.export_weights(),
  )
- if self.global_rope:
- result["global_rope"] = self.global_rope.export_weights()
- if self.local_rope:
- result["local_rope"] = self.local_rope.export_weights()
- return result
 
  def import_weights(
  self,
@@ -307,30 +199,10 @@ class Decoder(LalamoModule[DecoderConfig]):
  ) -> Self:
  assert isinstance(weights, Mapping)
  assert isinstance(weights["embedding"], Mapping)
- assert isinstance(weights["layers"], Sequence)
- assert isinstance(weights["output_norm"], Mapping)
-
- if self.local_rope:
- assert isinstance(weights["local_rope"], Mapping)
- local_rope = self.local_rope.import_weights(weights["local_rope"])
- else:
- local_rope = None
-
- if self.global_rope:
- assert isinstance(weights["global_rope"], Mapping)
- global_rope = self.global_rope.import_weights(weights["global_rope"])
- else:
- global_rope = None
+ assert isinstance(weights["transformer"], Mapping)
 
- layers = []
- for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
- assert isinstance(layer_weights, Mapping)
- layers.append(layer.import_weights(layer_weights))
  return replace(
  self,
  embedding=self.embedding.import_weights(weights["embedding"]),
- global_rope=global_rope,
- layers=tuple(layers),
- output_norm=self.output_norm.import_weights(weights["output_norm"]),
- local_rope=local_rope,
+ transformer=self.transformer.import_weights(weights["transformer"]),
  )
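Note: after this refactor, Decoder owns only the embedding plus a single Transformer submodule; the RoPE tables, the per-layer loop, and the output norm now live in lalamo/modules/transformer.py. A minimal sketch of that composition pattern, using toy stand-ins rather than the lalamo API:

    # Toy sketch of the new Decoder/Transformer split; names and shapes are
    # illustrative stand-ins, not the lalamo modules themselves.
    import jax
    import jax.numpy as jnp

    def toy_transformer(params, features):
        # Stand-in for Transformer.__call__: run the layer stack, then the final norm.
        for w in params["layers"]:
            features = features + jnp.tanh(features @ w)  # placeholder mixer/MLP layer
        rms = jax.lax.rsqrt(jnp.mean(features**2, axis=-1, keepdims=True) + 1e-6)
        return features * rms

    def toy_decoder(params, token_ids):
        # Decoder now only embeds, delegates to the transformer, and reads out logits.
        features = params["embedding"][token_ids]
        features = toy_transformer(params["transformer"], features)
        return features @ params["embedding"].T  # tied readout

    key = jax.random.PRNGKey(0)
    params = {
        "embedding": jax.random.normal(key, (1000, 64)),
        "transformer": {"layers": [0.1 * jnp.eye(64) for _ in range(4)]},
    }
    print(toy_decoder(params, jnp.array([[1, 2, 3]])).shape)  # (1, 3, 1000)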
lalamo/modules/mlp.py CHANGED
@@ -16,7 +16,12 @@ from lalamo.common import ParameterTree
  from lalamo.modules.utils import vmap_twice
 
  from .activations import Activation
- from .common import DummyUnionMember, ForwardPassMode, LalamoModule, register_config_union
+ from .common import (
+ DummyUnionMember,
+ ForwardPassMode,
+ LalamoModule,
+ register_config_union,
+ )
  from .linear import LinearBase, LinearConfig
 
  __all__ = [
@@ -192,7 +197,10 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
  f" the gate output dimension {gate_output_dim}",
  )
  (down_output_dim,) = self.down_projection.output_dims
- if (self.up_projection.input_dim, up_output_dim) != (down_output_dim, self.down_projection.input_dim):
+ if (self.up_projection.input_dim, up_output_dim) != (
+ down_output_dim,
+ self.down_projection.input_dim,
+ ):
  raise ValueError(
  f"Down projection dimensions {self.down_projection.input_dim, down_output_dim} do not match"
  f" the up projection output dimensions {self.up_projection.input_dim, up_output_dim}",
@@ -209,7 +217,10 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
  return vmap_twice(self.call_unbatched)(inputs)
 
  @eqx.filter_jit
- def call_unbatched(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
+ def call_unbatched(
+ self,
+ inputs: Float[Array, " channels"],
+ ) -> Float[Array, " channels"]:
  if self.mixture_size is not None:
  raise ValueError(
  "Mixtures of linear layers cannot be called directly."
@@ -222,6 +233,7 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
  up_proj = jnp.clip(up_proj, *self.config.up_clipping)
  gate = self.config.activation(gate)
  (result,) = self.down_projection(up_proj * gate)
+
  return result
 
  def export_weights(self) -> ParameterTree:
@@ -450,10 +462,21 @@ class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
  mode="drop",
  )
 
- return jax.lax.cond(jnp.any(token_indices_for_chunk != _SENTINEL), inner, lambda: accumulator), None
+ return (
+ jax.lax.cond(
+ jnp.any(token_indices_for_chunk != _SENTINEL),
+ inner,
+ lambda: accumulator,
+ ),
+ None,
+ )
 
  result, _ = jax.lax.scan(loop_iteration, jnp.zeros_like(flattened_inputs), chunked_token_indices)
- return rearrange(result, "(batch suffix_tokens) channels -> batch suffix_tokens channels", batch=batch_size)
+ return rearrange(
+ result,
+ "(batch suffix_tokens) channels -> batch suffix_tokens channels",
+ batch=batch_size,
+ )
 
  def export_weights(
  self,
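Note: the DenseMLP hunks above are line-wrapping only; the computation they wrap is the usual gated MLP. A standalone sketch of that forward pass, with the optional up-projection clipping from the diff, under assumed shapes (plain JAX, not the lalamo module):

    # Toy sketch of the gated-MLP forward seen in DenseMLP.call_unbatched:
    # up/gate projections, optional clipping of the up branch, activation on the
    # gate, then the down projection.
    import jax
    import jax.numpy as jnp

    def gated_mlp(x, w_up, w_gate, w_down, up_clipping=None):
        up = x @ w_up
        gate = x @ w_gate
        if up_clipping is not None:  # mirrors jnp.clip(up_proj, *self.config.up_clipping)
            up = jnp.clip(up, *up_clipping)
        gate = jax.nn.silu(gate)  # placeholder for self.config.activation
        return (up * gate) @ w_down

    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    out = gated_mlp(
        jnp.ones(64),
        jax.random.normal(k1, (64, 256)),
        jax.random.normal(k2, (64, 256)),
        jax.random.normal(k3, (256, 64)),
        up_clipping=(-10.0, 10.0),
    )
    print(out.shape)  # (64,)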
lalamo/modules/normalization.py CHANGED
@@ -13,8 +13,8 @@ from lalamo.common import ParameterTree, dummy_array
  from .common import LalamoModule
 
  __all__ = [
- "RMSNorm",
- "RMSNormConfig",
+ "Normalization",
+ "NormalizationConfig",
  "UpcastMode",
  ]
 
@@ -25,25 +25,26 @@ class UpcastMode(Enum):
 
 
  @dataclass(frozen=True)
- class RMSNormConfig:
+ class NormalizationConfig:
  scale_precision: DTypeLike
  accumulation_precision: DTypeLike
  epsilon: float
  scale_offset: float | None
  upcast_mode: UpcastMode
+ subtract_mean: bool
 
- def init(self, input_dim: int) -> "RMSNorm":
+ def init(self, input_dim: int) -> "Normalization":
  scales = jnp.ones(input_dim, dtype=self.scale_precision)
- return RMSNorm(self, scales=scales)
+ return Normalization(self, scales=scales)
 
- def empty(self, input_dim: int) -> "RMSNorm":
- return RMSNorm(
+ def empty(self, input_dim: int) -> "Normalization":
+ return Normalization(
  config=self,
  scales=dummy_array(input_dim, dtype=self.scale_precision),
  )
 
 
- class RMSNorm(LalamoModule[RMSNormConfig]):
+ class Normalization(LalamoModule[NormalizationConfig]):
  scales: Float[Array, " channels"]
 
  @property
@@ -66,6 +67,10 @@ class RMSNorm(LalamoModule[RMSNormConfig]):
  def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
  upcasted_inputs = inputs.astype(self.config.accumulation_precision)
 
+ if self.config.subtract_mean:
+ mean = jnp.mean(upcasted_inputs)
+ upcasted_inputs = upcasted_inputs - mean
+
  adjusted_variance = jnp.mean(jnp.square(upcasted_inputs)) + self.config.epsilon
  normalized_x = upcasted_inputs * jax.lax.rsqrt(adjusted_variance)
 
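Note: RMSNorm/RMSNormConfig become Normalization/NormalizationConfig, and the new subtract_mean flag centers the input before the RMS step, giving a LayerNorm-style norm when enabled. A simplified sketch of the __call__ math above (scale_offset and upcast handling omitted):

    # Toy sketch of Normalization.__call__: optional mean subtraction, then
    # RMS normalization; the real module also applies scale_offset/upcast_mode.
    import jax
    import jax.numpy as jnp

    def normalize(inputs, scales, epsilon=1e-6, subtract_mean=False):
        x = inputs.astype(jnp.float32)  # stand-in for accumulation_precision
        if subtract_mean:
            x = x - jnp.mean(x)
        variance = jnp.mean(jnp.square(x)) + epsilon
        return x * jax.lax.rsqrt(variance) * scales

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    print(normalize(x, jnp.ones(4)))                      # RMSNorm behaviour
    print(normalize(x, jnp.ones(4), subtract_mean=True))  # centered variant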
lalamo/modules/token_mixers/attention.py CHANGED
@@ -12,7 +12,7 @@ from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
  from lalamo.common import dummy_array
  from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector
  from lalamo.modules.linear import LinearBase, LinearConfig
- from lalamo.modules.normalization import RMSNorm, RMSNormConfig
+ from lalamo.modules.normalization import Normalization, NormalizationConfig
  from lalamo.modules.rope import PositionalEmbeddings
  from lalamo.modules.utils import apply_soft_capping
 
@@ -58,7 +58,11 @@ def _soft_capped_attention_kernel(
  "heads dst_tokens channels, heads src_tokens channels -> heads dst_tokens src_tokens",
  )
  if mask is not None:
- attention_logits = jnp.where(mask, attention_logits, jnp.array(float("-inf"), dtype=attention_logits.dtype))
+ attention_logits = jnp.where(
+ mask,
+ attention_logits,
+ jnp.array(float("-inf"), dtype=attention_logits.dtype),
+ )
 
  if scale is None:
  scale_val = head_dim**-0.5
@@ -82,8 +86,8 @@ class AttentionConfig(TokenMixerConfigBase):
  qkv_projection_config: LinearConfig
  out_projection_config: LinearConfig
 
- query_norm_config: RMSNormConfig | None
- key_norm_config: RMSNormConfig | None
+ query_norm_config: NormalizationConfig | None
+ key_norm_config: NormalizationConfig | None
 
  num_heads: int
  num_groups: int
@@ -217,8 +221,8 @@ class Attention(TokenMixerBase[AttentionConfig, KVCacheLayer]):
  qkv_projection: LinearBase
  out_projection: LinearBase
 
- query_norm: Normalization | None
- key_norm: Normalization | None
+ query_norm: Normalization | None
+ key_norm: Normalization | None
 
  sinks: Float[Array, " heads"] | None
 
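Note: the attention changes are the Normalization rename plus a reformat of the mask application. The pattern itself, forcing masked logits to -inf so those positions receive zero weight after the softmax, can be sketched as follows (single head, soft-capping and sinks omitted):

    # Toy sketch of masking attention logits with -inf before softmax, as in
    # _soft_capped_attention_kernel; one head only, soft-capping omitted.
    import jax
    import jax.numpy as jnp

    def masked_attention_weights(queries, keys, mask, scale=None):
        logits = jnp.einsum("qc,kc->qk", queries, keys)
        if scale is None:
            scale = queries.shape[-1] ** -0.5
        logits = logits * scale
        if mask is not None:
            logits = jnp.where(mask, logits, jnp.array(float("-inf"), dtype=logits.dtype))
        return jax.nn.softmax(logits, axis=-1)

    queries = jnp.ones((2, 8))
    keys = jnp.ones((3, 8))
    causal_mask = jnp.tril(jnp.ones((2, 3), dtype=bool), k=1)  # suffix of 2, 3 cached tokens
    print(masked_attention_weights(queries, keys, causal_mask))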
lalamo/modules/token_mixers/state/kv_cache.py CHANGED
@@ -89,7 +89,7 @@ class DynamicKVCacheLayer(KVCacheLayer):
  self,
  suffix_length: int,
  is_causal: bool,
- suffix_length_without_padding: Int[Array, ""] | int | None = None, # noqa: ARG002
+ suffix_length_without_padding: (Int[Array, ""] | int | None) = None, # noqa: ARG002
  sliding_window_size: int | None = None,
  ) -> Bool[Array, "suffix_tokens tokens"]:
  self._raise_if_batched()
@@ -97,8 +97,11 @@ class DynamicKVCacheLayer(KVCacheLayer):
  result = jnp.ones((suffix_length, total_num_tokens), dtype=jnp.bool)
  if is_causal:
  result = jnp.tril(result, k=total_num_tokens - suffix_length)
- if sliding_window_size is not None:
- result = jnp.triu(result, k=1 - sliding_window_size)
+ if sliding_window_size is not None:
+ result = jnp.triu(result, k=1 - sliding_window_size)
+ elif sliding_window_size is not None:
+ top_zeroed = jnp.tril(result, k=sliding_window_size // 2)
+ result = jnp.triu(top_zeroed, k=-sliding_window_size // 2)
  if self.has_sinks:
  result = result.at[:, 0].set(True)
  if self.padding_mask is not None:
@@ -213,7 +216,14 @@ class StaticKVCacheLayer(KVCacheLayer):
  )
 
  @classmethod
- def init(cls, has_sinks: bool, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
+ def init(
+ cls,
+ has_sinks: bool,
+ capacity: int,
+ num_groups: int,
+ head_dim: int,
+ dtype: DTypeLike,
+ ) -> Self:
  return cls(
  has_sinks=has_sinks,
  keys=jnp.zeros((capacity, num_groups, head_dim), dtype=dtype),
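Note: the DynamicKVCacheLayer change moves the sliding-window cut inside the causal branch and adds a window band for the non-causal case. A standalone sketch of the mask construction shown above (square suffix == total case for readability, toy sizes):

    # Toy sketch of the attention-mask construction from DynamicKVCacheLayer:
    # causal lower-triangular mask, with the sliding window applied per branch.
    import jax.numpy as jnp

    def build_mask(suffix_length, total_num_tokens, is_causal, sliding_window_size=None):
        result = jnp.ones((suffix_length, total_num_tokens), dtype=jnp.bool_)
        if is_causal:
            result = jnp.tril(result, k=total_num_tokens - suffix_length)
            if sliding_window_size is not None:
                result = jnp.triu(result, k=1 - sliding_window_size)
        elif sliding_window_size is not None:
            # New branch: keep a band of roughly sliding_window_size around the diagonal.
            top_zeroed = jnp.tril(result, k=sliding_window_size // 2)
            result = jnp.triu(top_zeroed, k=-sliding_window_size // 2)
        return result

    print(build_mask(5, 5, is_causal=True, sliding_window_size=3).astype(int))
    print(build_mask(5, 5, is_causal=False, sliding_window_size=3).astype(int))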