lalamo 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/language_model.py +22 -23
  3. lalamo/main.py +4 -18
  4. lalamo/model_import/common.py +24 -6
  5. lalamo/model_import/decoder_configs/__init__.py +2 -0
  6. lalamo/model_import/decoder_configs/common.py +4 -4
  7. lalamo/model_import/decoder_configs/executorch.py +17 -10
  8. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  9. lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
  10. lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
  11. lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
  12. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
  13. lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
  14. lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
  15. lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
  16. lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
  17. lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
  18. lalamo/model_import/loaders/executorch.py +5 -4
  19. lalamo/model_import/loaders/huggingface.py +321 -69
  20. lalamo/model_import/model_specs/__init__.py +2 -0
  21. lalamo/model_import/model_specs/common.py +16 -5
  22. lalamo/model_import/model_specs/llamba.py +40 -0
  23. lalamo/model_import/model_specs/qwen.py +29 -1
  24. lalamo/modules/__init__.py +33 -6
  25. lalamo/modules/activations.py +9 -2
  26. lalamo/modules/common.py +10 -5
  27. lalamo/modules/decoder.py +93 -97
  28. lalamo/modules/decoder_layer.py +85 -103
  29. lalamo/modules/embedding.py +279 -5
  30. lalamo/modules/linear.py +335 -30
  31. lalamo/modules/mlp.py +6 -7
  32. lalamo/modules/mlx_interop.py +19 -0
  33. lalamo/modules/rope.py +1 -1
  34. lalamo/modules/token_mixers/__init__.py +30 -0
  35. lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
  36. lalamo/modules/token_mixers/common.py +78 -0
  37. lalamo/modules/token_mixers/mamba.py +553 -0
  38. lalamo/modules/token_mixers/state/__init__.py +12 -0
  39. lalamo/modules/token_mixers/state/common.py +26 -0
  40. lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
  41. lalamo/modules/token_mixers/state/mamba_state.py +51 -0
  42. lalamo/utils.py +24 -2
  43. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
  44. lalamo-0.5.0.dist-info/RECORD +80 -0
  45. lalamo-0.4.0.dist-info/RECORD +0 -71
  46. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
  47. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/modules/token_mixers/mamba.py
@@ -0,0 +1,553 @@
+import math
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import NamedTuple, Self
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from einops import einsum, rearrange
+from jax import vmap
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.common import ParameterTree, dummy_array
+from lalamo.modules.activations import Activation
+from lalamo.modules.common import LalamoModule, PositionalEmbeddingSelector
+from lalamo.modules.linear import LinearBase, LinearConfig
+from lalamo.modules.rope import PositionalEmbeddings
+
+from .common import TokenMixerBase, TokenMixerConfigBase, TokenMixerResult
+from .state import Mamba2StateLayer
+
+__all__ = [
+    "Mamba2",
+    "Mamba2Config",
+    "Mamba2Result",
+    "SeparableCausalConv",
+    "SeparableCausalConvConfig",
+]
+
+
+Mamba2Result = TokenMixerResult[Mamba2StateLayer]
+
+
+class CausalConvResult(NamedTuple):
+    outputs: Float[Array, "*batch suffix_tokens channels"]
+    state: Float[Array, "*batch tokens channels"] | None = None
+
+
+@dataclass(frozen=True)
+class SeparableCausalConvConfig:
+    precision: DTypeLike
+    has_biases: bool
+
+    def random_init(
+        self,
+        input_dim: int,
+        kernel_size: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "SeparableCausalConv":
+        scale = 1 / math.sqrt(kernel_size * input_dim)
+        weights = jax.random.uniform(
+            key,
+            (input_dim, kernel_size),
+            minval=-scale,
+            maxval=scale,
+            dtype=self.precision,
+        )
+        if self.has_biases:
+            biases = jnp.zeros((input_dim,), dtype=self.precision)
+        else:
+            biases = None
+        return SeparableCausalConv(self, weights=weights, biases=biases)
+
+    def empty(
+        self,
+        input_dim: int,
+        kernel_size: int,
+    ) -> "SeparableCausalConv":
+        weights = dummy_array(
+            (input_dim, kernel_size),
+            dtype=self.precision,
+        )
+        if self.has_biases:
+            biases = dummy_array((input_dim,), dtype=self.precision)
+        else:
+            biases = None
+        return SeparableCausalConv(self, weights=weights, biases=biases)
+
+
+class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):
+    weights: Float[Array, "channels kernel"]
+    biases: Float[Array, " channels"] | None
+
+    def __post_init__(self) -> None:
+        input_dim, _ = self.weights.shape
+        if self.biases is not None:
+            (output_dim,) = self.biases.shape
+            if output_dim != input_dim:
+                raise ValueError(
+                    f"Output dimension of biases ({output_dim}) must match input dimension ({input_dim})",
+                )
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.precision
+
+    @property
+    def input_dim(self) -> int:
+        input_dim, _ = self.weights.shape
+        return input_dim
+
+    @property
+    def kernel_size(self) -> int:
+        _, kernel_size = self.weights.shape
+        return kernel_size
+
+    @property
+    def has_biases(self) -> bool:
+        return self.biases is not None
+
+    def __call__(
+        self,
+        inputs: Float[Array, "suffix_tokens channels"],
+        state: Float[Array, "prefix_tokens channels"] | None = None,
+        return_updated_state: bool = False,
+    ) -> CausalConvResult:
+        num_suffix_tokens, input_dim = inputs.shape
+
+        if state is None:
+            state = jnp.zeros((self.kernel_size - 1, input_dim), dtype=inputs.dtype)
+
+        required_context = num_suffix_tokens + self.kernel_size - 1
+
+        inputs_with_history = jnp.concatenate([state, inputs], axis=0)
+        conv_outputs = jax.lax.conv_general_dilated(
+            inputs_with_history[None, -required_context:, :],
+            self.weights[:, :, None],
+            window_strides=(1,),
+            feature_group_count=input_dim,
+            padding="VALID",
+            dimension_numbers=("NTC", "OTI", "NTC"),
+        )
+
+        results = conv_outputs.squeeze(0)
+        if self.biases is not None:
+            results = results + self.biases
+
+        return CausalConvResult(
+            results,
+            (inputs_with_history if return_updated_state else None),
+        )
+
+    def export_weights(self) -> ParameterTree:
+        result: dict[str, Array] = {"weights": self.weights}
+        if self.biases is not None:
+            result["biases"] = self.biases
+        return result
+
+    def import_weights(self, weights: ParameterTree[Array]) -> "SeparableCausalConv":
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        if self.biases is not None:
+            assert isinstance(weights["biases"], Array)
+            biases = weights["biases"]
+        else:
+            biases = None
+        return replace(
+            self,
+            weights=weights["weights"],
+            biases=biases,
+        )
+
+
+@dataclass(frozen=True)
+class Mamba2Config(TokenMixerConfigBase):
+    in_projection_config: LinearConfig
+    out_projection_config: LinearConfig
+    conv_config: SeparableCausalConvConfig
+    activation: Activation
+
+    kernel_size: int
+    num_heads: int
+    num_groups: int
+    head_dim: int
+    state_dim: int
+    expansion_factor: int
+
+    has_in_biases: bool
+    has_out_biases: bool
+
+    @property
+    def inner_dim(self) -> int:
+        return self.num_heads * self.head_dim
+
+    @property
+    def rope_dim(self) -> int:
+        return self.head_dim
+
+    def random_init(
+        self,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "Mamba2":
+        in_key, out_key, conv_key, skip_key = jax.random.split(key, 4)
+
+        in_projection = self.in_projection_config.random_init(
+            input_dim=model_dim,
+            output_dims=(
+                self.inner_dim + 2 * self.num_groups * self.state_dim,
+                self.inner_dim,
+                self.num_heads,
+            ),
+            has_biases=self.has_in_biases,
+            key=in_key,
+        )
+
+        out_projection = self.out_projection_config.random_init(
+            self.inner_dim,
+            (model_dim,),
+            has_biases=self.has_out_biases,
+            key=out_key,
+        )
+
+        conv_channels = self.inner_dim + 2 * self.num_groups * self.state_dim
+        conv = self.conv_config.random_init(conv_channels, self.kernel_size, key=conv_key)
+
+        skip_connection_weight = jax.random.normal(
+            skip_key,
+            (self.num_heads,),
+            dtype=in_projection.activation_precision,
+        )
+
+        gate_bias = jnp.zeros((self.inner_dim,), dtype=in_projection.activation_precision)
+
+        return Mamba2(
+            self,
+            in_projection=in_projection,
+            conv=conv,
+            out_projection=out_projection,
+            skip_connection_weight=skip_connection_weight,
+            gate_bias=gate_bias,
+            num_heads=self.num_heads,
+            num_groups=self.num_groups,
+            head_dim=self.head_dim,
+            state_dim=self.state_dim,
+        )
+
+    def empty(
+        self,
+        model_dim: int,
+    ) -> "Mamba2":
+        in_projection = self.in_projection_config.empty(
+            input_dim=model_dim,
+            output_dims=(
+                self.inner_dim + 2 * self.num_groups * self.state_dim,
+                self.inner_dim,
+                self.num_heads,
+            ),
+            has_biases=self.has_in_biases,
+        )
+
+        out_projection = self.out_projection_config.empty(
+            self.inner_dim,
+            (model_dim,),
+            has_biases=self.has_out_biases,
+        )
+
+        conv_channels = self.inner_dim + 2 * self.num_groups * self.state_dim
+        conv = self.conv_config.empty(conv_channels, self.kernel_size)
+
+        skip_connection_weight = dummy_array((self.num_heads,), in_projection.activation_precision)
+        gate_bias = dummy_array((self.inner_dim,), in_projection.activation_precision)
+
+        return Mamba2(
+            self,
+            in_projection=in_projection,
+            conv=conv,
+            out_projection=out_projection,
+            skip_connection_weight=skip_connection_weight,
+            gate_bias=gate_bias,
+            num_heads=self.num_heads,
+            num_groups=self.num_groups,
+            head_dim=self.head_dim,
+            state_dim=self.state_dim,
+        )
+
+
+class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
+    in_projection: LinearBase
+    conv: SeparableCausalConv
+    out_projection: LinearBase
+
+    skip_connection_weight: Float[Array, " heads"]
+    gate_bias: Float[Array, " inner_channels"]
+
+    num_heads: int = eqx.field(static=True)
+    num_groups: int = eqx.field(static=True)
+    head_dim: int = eqx.field(static=True)
+    state_dim: int = eqx.field(static=True)
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.in_projection.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        return self.in_projection.input_dim
+
+    @property
+    def inner_dim(self) -> int:
+        return self.num_heads * self.head_dim
+
+    @property
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector:
+        return PositionalEmbeddingSelector.NONE
+
+    def __post_init__(self) -> None:
+        if self.skip_connection_weight.shape != (self.num_heads,):
+            raise ValueError(
+                f"Skip connection weight must have shape (num_heads,) = ({self.num_heads},), "
+                f"got {self.skip_connection_weight.shape}",
+            )
+        if self.gate_bias.shape != (self.inner_dim,):
+            raise ValueError(
+                f"Gate bias must have shape (inner_dim,) = ({self.inner_dim},), got {self.gate_bias.shape}",
+            )
+        if self.num_heads % self.num_groups != 0:
+            raise ValueError(
+                f"Number of value heads ({self.num_heads}) must be divisible by number of groups ({self.num_groups})",
+            )
+
+    def _scan(
+        self,
+        hidden_states: Float[Array, "suffix_tokens heads head_channels"],
+        input_projection: Float[Array, "suffix_tokens groups state_channels"],
+        output_projection: Float[Array, "suffix_tokens groups state_channels"],
+        time_delta_log: Float[Array, "suffix_tokens heads"],
+        initial_state: Float[Array, "heads head_channels state_channels"],
+        num_steps: Int[Array, ""] | int,
+    ) -> tuple[
+        Float[Array, "suffix_tokens heads head_channels"],
+        Float[Array, "heads head_channels state_channels"],
+    ]:
+        def scan_fn(
+            index_and_carry_state: tuple[Int[Array, ""], Float[Array, "heads head_channels state_channels"]],
+            step_inputs: tuple[
+                Float[Array, "heads head_channels"],
+                Float[Array, "groups state_channels"],
+                Float[Array, "groups state_channels"],
+                Float[Array, " heads"],
+            ],
+        ) -> tuple[
+            tuple[Int[Array, ""], Float[Array, "heads head_channels state_channels"]],
+            Float[Array, "heads head_channels"],
+        ]:
+            index, carry_state = index_and_carry_state
+            hidden_state_t, input_proj_t, output_proj_t, time_delta_log_t = step_inputs
+            dt = jax.nn.softplus(time_delta_log_t)[:, None]
+            heads_per_group = self.num_heads // self.num_groups
+
+            hidden_grouped = rearrange(
+                hidden_state_t,
+                "(groups heads) head_channels -> groups heads head_channels",
+                groups=self.num_groups,
+                heads=heads_per_group,
+            )
+            x_norm_grouped = hidden_grouped / (
+                dt.reshape(self.num_heads)[
+                    rearrange(
+                        jnp.arange(self.num_heads),
+                        "(groups heads)-> groups heads",
+                        groups=self.num_groups,
+                        heads=heads_per_group,
+                    )
+                ][:, :, None]
+                + 1e-8
+            )
+
+            decay = jnp.exp(-dt)[:, :, None]
+            mix = dt[:, :, None]
+            decay_group = rearrange(
+                decay,
+                "(groups heads) 1 1 -> groups heads 1 1",
+                groups=self.num_groups,
+                heads=heads_per_group,
+            )
+            mix_group = rearrange(
+                mix,
+                "(groups heads) 1 1 -> groups heads 1 1",
+                groups=self.num_groups,
+                heads=heads_per_group,
+            )
+
+            input_contribution_group = mix_group * x_norm_grouped[:, :, :, None] * input_proj_t[:, None, None, :]
+            carry_state_group = rearrange(
+                carry_state,
+                "(groups heads) head_channels state_channels -> groups heads head_channels state_channels",
+                groups=self.num_groups,
+                heads=heads_per_group,
+            )
+            updated_state_group = decay_group * carry_state_group + input_contribution_group
+
+            output_group = einsum(
+                updated_state_group,
+                output_proj_t,
+                "groups heads head_channels state_channels, groups state_channels -> groups heads head_channels",
+            )
+            updated_state = rearrange(
+                updated_state_group,
+                "groups heads head_channels state_channels -> (groups heads) head_channels state_channels",
+            )
+            output_t = rearrange(output_group, "groups heads head_channels -> (groups heads) head_channels")
+
+            propagated_state = jax.lax.cond(index < num_steps, lambda: updated_state, lambda: carry_state)
+
+            return (index + 1, propagated_state), output_t
+
+        (_, final_state), outputs = jax.lax.scan(
+            scan_fn,
+            (jnp.zeros((), dtype=jnp.int32), initial_state),
+            (hidden_states, input_projection, output_projection, time_delta_log),
+        )
+
+        return outputs, final_state
+
+    @eqx.filter_jit
+    def __call__(
+        self,
+        inputs: Float[Array, "suffix_tokens channels"],
+        positional_embeddings: PositionalEmbeddings | None,
+        state: Mamba2StateLayer | None = None,
+        return_updated_state: bool = False,
+        length_without_padding: Int[Array, ""] | int | None = None,
+    ) -> Mamba2Result:
+        if positional_embeddings is not None:
+            raise ValueError("Positional embeddings are not supported for Mamba2.")
+
+        conv_inputs, gate_values, time_delta_log = vmap(self.in_projection)(inputs)
+
+        if state is None:
+            state = Mamba2StateLayer.init(
+                self.config.kernel_size,
+                self.inner_dim,
+                self.num_heads,
+                self.num_groups,
+                self.head_dim,
+                self.state_dim,
+                self.activation_precision,
+            )
+
+        conv_output, updated_conv_state = self.conv(
+            conv_inputs,
+            state.conv_state,
+            return_updated_state=return_updated_state,
+        )
+        conv_activated = self.config.activation(conv_output)
+
+        x_channels, input_proj_channels, output_proj_channels = jnp.split(
+            conv_activated,
+            [
+                self.inner_dim,
+                self.inner_dim + self.num_groups * self.state_dim,
+            ],
+            axis=-1,
+        )
+
+        hidden_states = rearrange(
+            x_channels,
+            "suffix_tokens (heads head_channels) -> suffix_tokens heads head_channels",
+            heads=self.num_heads,
+        )
+        input_projection = rearrange(
+            input_proj_channels,
+            "suffix_tokens (groups state_channels) -> suffix_tokens groups state_channels",
+            groups=self.num_groups,
+        )
+        output_projection = rearrange(
+            output_proj_channels,
+            "suffix_tokens (groups state_channels) -> suffix_tokens groups state_channels",
+            groups=self.num_groups,
+        )
+        time_delta_log = rearrange(
+            time_delta_log,
+            "suffix_tokens heads -> suffix_tokens heads",
+            heads=self.num_heads,
+        )
+
+        if length_without_padding is None:
+            length_without_padding, _ = inputs.shape
+
+        ssm_outputs, final_ssm_state = self._scan(
+            hidden_states,
+            input_projection,
+            output_projection,
+            time_delta_log,
+            state.ssm_state,
+            length_without_padding,
+        )
+
+        skip_contribution = self.skip_connection_weight[None, :, None] * hidden_states
+        ssm_outputs = ssm_outputs + skip_contribution
+
+        ssm_outputs = rearrange(
+            ssm_outputs,
+            "suffix_tokens heads head_channels -> suffix_tokens (heads head_channels)",
+        )
+
+        gated_outputs = ssm_outputs * jax.nn.silu(gate_values + self.gate_bias)
+
+        (outputs,) = vmap(self.out_projection)(gated_outputs)
+
+        if return_updated_state:
+            assert updated_conv_state is not None
+            updated_state = Mamba2StateLayer(updated_conv_state, final_ssm_state)
+        else:
+            updated_state = None
+
+        return Mamba2Result(
+            outputs=outputs,
+            state=updated_state,
+        )
+
+    def init_static_state(self, capacity: int) -> Mamba2StateLayer:  # noqa: ARG002
+        return Mamba2StateLayer.init(
+            self.config.kernel_size,
+            self.inner_dim,
+            self.num_heads,
+            self.num_groups,
+            self.head_dim,
+            self.state_dim,
+            self.activation_precision,
+        )

+    def export_weights(self) -> ParameterTree:
+        return {
+            "in_projection": self.in_projection.export_weights(),
+            "out_projection": self.out_projection.export_weights(),
+            "conv": self.conv.export_weights(),
+            "skip_connection_weight": self.skip_connection_weight,
+            "gate_bias": self.gate_bias,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["in_projection"], Mapping)
+        assert isinstance(weights["out_projection"], Mapping)
+        assert isinstance(weights["conv"], Mapping)
+        assert isinstance(weights["skip_connection_weight"], Array)
+        assert isinstance(weights["gate_bias"], Array)
+
+        return replace(
+            self,
+            in_projection=self.in_projection.import_weights(weights["in_projection"]),
+            out_projection=self.out_projection.import_weights(weights["out_projection"]),
+            conv=self.conv.import_weights(weights["conv"]),
+            skip_connection_weight=weights["skip_connection_weight"],
+            gate_bias=weights["gate_bias"],
+        )
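For readers skimming the hunk above: per head, `Mamba2._scan` is an input-gated exponential moving average over an outer-product state, with the decay and input gate both derived from `softplus(time_delta_log)`. The standalone sketch below (plain JAX, not part of the lalamo API; all names and sizes are illustrative) reproduces that per-head, single-group update:

import jax
import jax.numpy as jnp


def ssm_step(h, step_inputs):
    # One token of the recurrence used in scan_fn above, for a single head/group:
    #   dt = softplus(time_delta_log)
    #   h  = exp(-dt) * h + dt * (x / (dt + 1e-8)) outer b
    #   y  = h @ c
    x_t, b_t, c_t, dt_log_t = step_inputs
    dt = jax.nn.softplus(dt_log_t)
    x_norm = x_t / (dt + 1e-8)
    h = jnp.exp(-dt) * h + dt * jnp.outer(x_norm, b_t)  # (head_dim, state_dim)
    return h, h @ c_t  # new carry, per-token output of shape (head_dim,)


def ssm_scan(x, b, c, dt_log, head_dim, state_dim):
    h0 = jnp.zeros((head_dim, state_dim))
    _, ys = jax.lax.scan(ssm_step, h0, (x, b, c, dt_log))
    return ys


key = jax.random.PRNGKey(0)
kx, kb, kc, kd = jax.random.split(key, 4)
num_tokens, head_dim, state_dim = 6, 4, 8
ys = ssm_scan(
    jax.random.normal(kx, (num_tokens, head_dim)),
    jax.random.normal(kb, (num_tokens, state_dim)),
    jax.random.normal(kc, (num_tokens, state_dim)),
    jax.random.normal(kd, (num_tokens,)),
    head_dim,
    state_dim,
)
print(ys.shape)  # (6, 4)

The full module runs this scan for every head, adds the `skip_connection_weight` residual, and gates the result with `silu(gate_values + gate_bias)` before the output projection, as shown in `__call__` above.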
lalamo/modules/token_mixers/state/__init__.py
@@ -0,0 +1,12 @@
+from .common import State, StateLayerBase
+from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
+from .mamba_state import Mamba2StateLayer
+
+__all__ = [
+    "DynamicKVCacheLayer",
+    "KVCacheLayer",
+    "Mamba2StateLayer",
+    "State",
+    "StateLayerBase",
+    "StaticKVCacheLayer",
+]
lalamo/modules/token_mixers/state/common.py
@@ -0,0 +1,26 @@
+from abc import abstractmethod
+from typing import Self
+
+import equinox as eqx
+from jax.tree_util import register_pytree_node_class
+
+from lalamo.common import ParameterTree
+
+__all__ = ["State", "StateLayerBase"]
+
+
+class StateLayerBase(eqx.Module):
+    @abstractmethod
+    def export(self) -> ParameterTree: ...
+
+
+@register_pytree_node_class
+class State(tuple[StateLayerBase, ...]):
+    __slots__ = ()
+
+    def tree_flatten(self) -> tuple[tuple[StateLayerBase, ...], None]:
+        return (tuple(self), None)
+
+    @classmethod
+    def tree_unflatten(cls, aux_data: None, children: tuple[StateLayerBase, ...]) -> Self:  # noqa: ARG003
+        return cls(children)
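The `State` container above is simply a tuple subclass registered as a JAX pytree, so a per-layer state stack passes through `jit` and `tree_map` transparently while keeping its container type; it generalizes the `KVCache` tuple that the next hunk removes. A toy sketch of the same pattern (standalone code, not part of lalamo; `ToyLayer` and `ToyState` are made-up names):

import equinox as eqx
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class ToyLayer(eqx.Module):  # stand-in for a StateLayerBase subclass
    value: jax.Array


@register_pytree_node_class
class ToyState(tuple):
    __slots__ = ()

    def tree_flatten(self):
        # Children are the per-layer states; no static aux data.
        return (tuple(self), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children)


state = ToyState((ToyLayer(jnp.zeros(3)), ToyLayer(jnp.ones(3))))
doubled = jax.tree_util.tree_map(lambda x: 2 * x, state)
print(type(doubled).__name__, doubled[1].value)  # ToyState [2. 2. 2.]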
lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py}
@@ -4,15 +4,16 @@ from typing import Self
 import equinox as eqx
 import jax.numpy as jnp
 from jax.lax import dynamic_update_slice_in_dim
-from jax.tree_util import register_pytree_node_class
 from jaxtyping import Array, Bool, DTypeLike, Float, Int
 
 from lalamo.common import ParameterTree
 
-__all__ = ["DynamicKVCacheLayer", "KVCache", "KVCacheLayer", "StaticKVCacheLayer"]
+from .common import StateLayerBase
 
+__all__ = ["DynamicKVCacheLayer", "KVCacheLayer", "StaticKVCacheLayer"]
 
-class KVCacheLayer(eqx.Module):
+
+class KVCacheLayer(StateLayerBase):
     has_sinks: bool = eqx.field(static=True)
     keys: Float[Array, "*batch tokens groups head_channels"]
     values: Float[Array, "*batch tokens groups head_channels"]
@@ -58,18 +59,6 @@ class KVCacheLayer(eqx.Module):
         )
 
 
-@register_pytree_node_class
-class KVCache(tuple[KVCacheLayer, ...]):
-    __slots__ = ()
-
-    def tree_flatten(self) -> tuple[tuple[KVCacheLayer, ...], None]:
-        return (tuple(self), None)
-
-    @classmethod
-    def tree_unflatten(cls, aux_data: None, children: tuple[KVCacheLayer, ...]) -> Self:  # noqa: ARG003
-        return cls(children)
-
-
 class DynamicKVCacheLayer(KVCacheLayer):
     padding_mask: Bool[Array, " tokens"] | None = None
 
@@ -224,7 +213,7 @@ class StaticKVCacheLayer(KVCacheLayer):
         )
 
     @classmethod
-    def empty(cls, has_sinks: bool, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
+    def init(cls, has_sinks: bool, capacity: int, num_groups: int, head_dim: int, dtype: DTypeLike) -> Self:
        return cls(
            has_sinks=has_sinks,
            keys=jnp.zeros((capacity, num_groups, head_dim), dtype=dtype),
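Call sites that built a static cache with `StaticKVCacheLayer.empty(...)` now use `init(...)` with the same signature, matching `Mamba2StateLayer.init` in the next hunk. A minimal usage sketch (import path inferred from the new `state/__init__.py` above; the sizes are arbitrary):

import jax.numpy as jnp
from lalamo.modules.token_mixers.state import StaticKVCacheLayer

# Preallocate an empty cache layer for a context of 512 tokens.
cache = StaticKVCacheLayer.init(
    has_sinks=False,
    capacity=512,
    num_groups=8,
    head_dim=64,
    dtype=jnp.bfloat16,
)
print(cache.keys.shape)  # (512, 8, 64), per the jnp.zeros call shown above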
lalamo/modules/token_mixers/state/mamba_state.py
@@ -0,0 +1,51 @@
+from typing import Self
+
+import jax.numpy as jnp
+from jaxtyping import Array, DTypeLike, Float
+
+from lalamo.common import ParameterTree
+
+from .common import StateLayerBase
+
+__all__ = ["Mamba2StateLayer"]
+
+
+class Mamba2StateLayer(StateLayerBase):
+    conv_state: Float[Array, "*batch tokens conv_channels"]
+    ssm_state: Float[Array, "*batch groups head_channels state_channels"]
+
+    def __post_init__(self) -> None:
+        if self.conv_state.ndim not in (2, 3):
+            raise ValueError(
+                f"Conv state must have 2 or 3 dimensions: [batch], tokens, conv_channels,"
+                f" got shape {self.conv_state.shape}",
+            )
+        if self.ssm_state.ndim not in (3, 4):
+            raise ValueError(
+                f"SSM state must have 3 or 4 dimensions: [batch], groups, head_channels, state_channels,"
+                f" got shape {self.ssm_state.shape}",
+            )
+        if self.conv_state.dtype != self.ssm_state.dtype:
+            raise ValueError("Conv state and SSM state must have the same dtype")
+
+    @classmethod
+    def init(
+        cls,
+        kernel_size: int,
+        inner_dim: int,
+        num_heads: int,
+        num_groups: int,
+        head_dim: int,
+        state_dim: int,
+        dtype: DTypeLike,
+    ) -> Self:
+        return cls(
+            conv_state=jnp.zeros((kernel_size - 1, inner_dim + 2 * num_groups * state_dim), dtype=dtype),
+            ssm_state=jnp.zeros((num_heads, head_dim, state_dim), dtype=dtype),
+        )
+
+    def export(self) -> ParameterTree:
+        return dict(
+            conv_state=self.conv_state,
+            ssm_state=self.ssm_state,
+        )