lalamo 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/language_model.py +22 -23
  3. lalamo/main.py +2 -16
  4. lalamo/model_import/common.py +24 -6
  5. lalamo/model_import/decoder_configs/__init__.py +2 -0
  6. lalamo/model_import/decoder_configs/common.py +4 -4
  7. lalamo/model_import/decoder_configs/executorch.py +17 -10
  8. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  9. lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
  10. lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
  11. lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
  12. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
  13. lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
  14. lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
  15. lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
  16. lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
  17. lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
  18. lalamo/model_import/loaders/executorch.py +5 -4
  19. lalamo/model_import/loaders/huggingface.py +321 -69
  20. lalamo/model_import/model_specs/__init__.py +2 -0
  21. lalamo/model_import/model_specs/common.py +16 -5
  22. lalamo/model_import/model_specs/llamba.py +40 -0
  23. lalamo/model_import/model_specs/qwen.py +29 -1
  24. lalamo/modules/__init__.py +33 -6
  25. lalamo/modules/activations.py +9 -2
  26. lalamo/modules/common.py +10 -5
  27. lalamo/modules/decoder.py +93 -97
  28. lalamo/modules/decoder_layer.py +85 -103
  29. lalamo/modules/embedding.py +279 -5
  30. lalamo/modules/linear.py +335 -30
  31. lalamo/modules/mlp.py +6 -7
  32. lalamo/modules/mlx_interop.py +19 -0
  33. lalamo/modules/rope.py +1 -1
  34. lalamo/modules/token_mixers/__init__.py +30 -0
  35. lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
  36. lalamo/modules/token_mixers/common.py +78 -0
  37. lalamo/modules/token_mixers/mamba.py +553 -0
  38. lalamo/modules/token_mixers/state/__init__.py +12 -0
  39. lalamo/modules/token_mixers/state/common.py +26 -0
  40. lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
  41. lalamo/modules/token_mixers/state/mamba_state.py +51 -0
  42. lalamo/utils.py +24 -2
  43. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
  44. lalamo-0.5.0.dist-info/RECORD +80 -0
  45. lalamo-0.4.1.dist-info/RECORD +0 -71
  46. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
  47. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.4.1.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/modules/{attention.py → token_mixers/attention.py}
@@ -1,6 +1,6 @@
 from collections.abc import Mapping
 from dataclasses import dataclass, replace
-from typing import NamedTuple, Self
+from typing import Self

 import equinox as eqx
 import jax
@@ -10,17 +10,19 @@ from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray

 from lalamo.common import dummy_array
+from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector
+from lalamo.modules.linear import LinearBase, LinearConfig
 from lalamo.modules.normalization import RMSNorm, RMSNormConfig
+from lalamo.modules.rope import PositionalEmbeddings
+from lalamo.modules.utils import apply_soft_capping

-from .common import AttentionType, LalamoModule, ParameterTree
-from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
-from .linear import LinearBase, LinearConfig
-from .rope import PositionalEmbeddings
-from .utils import apply_soft_capping
+from .common import TokenMixerBase, TokenMixerConfigBase, TokenMixerResult
+from .state import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer

 __all__ = [
     "Attention",
     "AttentionConfig",
+    "AttentionResult",
 ]


@@ -72,33 +74,36 @@ def _soft_capped_attention_kernel(
     )


-class AttentionResult(NamedTuple):
-    outputs: Float[Array, "*batch suffix_tokens channels"]
-    kv_cache: KVCacheLayer | None = None
+AttentionResult = TokenMixerResult[KVCacheLayer]


 @dataclass(frozen=True)
-class AttentionConfig:
+class AttentionConfig(TokenMixerConfigBase):
     qkv_projection_config: LinearConfig
     out_projection_config: LinearConfig

     query_norm_config: RMSNormConfig | None
     key_norm_config: RMSNormConfig | None

+    num_heads: int
+    num_groups: int
+    head_dim: int
+    is_causal: bool
+    scale: float | None
+    sliding_window_size: int | None
+
     logit_soft_cap: float | None
     has_sinks: bool
     has_qkv_biases: bool
     has_out_biases: bool

+    @property
+    def rope_dim(self) -> int:
+        return self.head_dim
+
     def random_init(
         self,
         model_dim: int,
-        num_heads: int,
-        num_groups: int,
-        head_dim: int,
-        is_causal: bool,
-        scale: float | None,
-        sliding_window_size: int | None,
         *,
         key: PRNGKeyArray,
     ) -> "Attention":
@@ -106,15 +111,15 @@ class AttentionConfig:
         qkv_projection = self.qkv_projection_config.random_init(
             input_dim=model_dim,
             output_dims=(
-                num_heads * head_dim,
-                num_groups * head_dim,
-                num_groups * head_dim,
+                self.num_heads * self.head_dim,
+                self.num_groups * self.head_dim,
+                self.num_groups * self.head_dim,
             ),
             has_biases=self.has_qkv_biases,
             key=qkv_key,
         )
         out_projection = self.out_projection_config.random_init(
-            num_heads * head_dim,
+            self.num_heads * self.head_dim,
             (model_dim,),
             has_biases=self.has_out_biases,
             key=out_key,
@@ -122,20 +127,20 @@ class AttentionConfig:

         if self.query_norm_config is not None:
             query_norm = self.query_norm_config.init(
-                input_dim=head_dim,
+                input_dim=self.head_dim,
             )
         else:
             query_norm = None

         if self.key_norm_config is not None:
             key_norm = self.key_norm_config.init(
-                input_dim=head_dim,
+                input_dim=self.head_dim,
             )
         else:
             key_norm = None

         if self.has_sinks:
-            sinks = jnp.zeros((num_heads,), dtype=qkv_projection.activation_precision)
+            sinks = jnp.zeros((self.num_heads,), dtype=qkv_projection.activation_precision)
         else:
             sinks = None

@@ -146,55 +151,49 @@ class AttentionConfig:
            query_norm=query_norm,
            key_norm=key_norm,
            sinks=sinks,
-            num_heads=num_heads,
-            num_groups=num_groups,
-            head_dim=head_dim,
-            is_causal=is_causal,
-            scale=scale,
-            sliding_window_size=sliding_window_size,
+            num_heads=self.num_heads,
+            num_groups=self.num_groups,
+            head_dim=self.head_dim,
+            is_causal=self.is_causal,
+            scale=self.scale,
+            sliding_window_size=self.sliding_window_size,
         )

     def empty(
         self,
         model_dim: int,
-        num_heads: int,
-        num_groups: int,
-        head_dim: int,
-        is_causal: bool,
-        scale: float | None,
-        sliding_window_size: int | None,
     ) -> "Attention":
         qkv_projection = self.qkv_projection_config.empty(
             input_dim=model_dim,
             output_dims=(
-                num_heads * head_dim,
-                num_groups * head_dim,
-                num_groups * head_dim,
+                self.num_heads * self.head_dim,
+                self.num_groups * self.head_dim,
+                self.num_groups * self.head_dim,
             ),
             has_biases=self.has_qkv_biases,
         )
         out_projection = self.out_projection_config.empty(
-            num_heads * head_dim,
+            self.num_heads * self.head_dim,
             (model_dim,),
             has_biases=self.has_out_biases,
         )

         if self.query_norm_config is not None:
             query_norm = self.query_norm_config.empty(
-                input_dim=head_dim,
+                input_dim=self.head_dim,
             )
         else:
             query_norm = None

         if self.key_norm_config is not None:
             key_norm = self.key_norm_config.empty(
-                input_dim=head_dim,
+                input_dim=self.head_dim,
             )
         else:
             key_norm = None

         if self.has_sinks:
-            sinks = dummy_array(num_heads, qkv_projection.activation_precision)
+            sinks = dummy_array(self.num_heads, qkv_projection.activation_precision)
         else:
             sinks = None

@@ -205,16 +204,16 @@ class AttentionConfig:
            query_norm=query_norm,
            key_norm=key_norm,
            sinks=sinks,
-            num_heads=num_heads,
-            num_groups=num_groups,
-            head_dim=head_dim,
-            is_causal=is_causal,
-            scale=scale,
-            sliding_window_size=sliding_window_size,
+            num_heads=self.num_heads,
+            num_groups=self.num_groups,
+            head_dim=self.head_dim,
+            is_causal=self.is_causal,
+            scale=self.scale,
+            sliding_window_size=self.sliding_window_size,
         )


-class Attention(LalamoModule[AttentionConfig]):
+class Attention(TokenMixerBase[AttentionConfig, KVCacheLayer]):
     qkv_projection: LinearBase
     out_projection: LinearBase

@@ -249,8 +248,10 @@ class Attention(LalamoModule[AttentionConfig]):
         return self.sliding_window_size is not None

     @property
-    def attention_type(self) -> AttentionType:
-        return AttentionType.SLIDING_WINDOW if self.sliding_window_size is not None else AttentionType.GLOBAL
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector:
+        if self.use_sliding_window:
+            return PositionalEmbeddingSelector.LOCAL
+        return PositionalEmbeddingSelector.GLOBAL

     @property
     def has_sinks(self) -> bool:
@@ -318,9 +319,9 @@ class Attention(LalamoModule[AttentionConfig]):
     def __call__(
         self,
         inputs: Float[Array, "suffix_tokens channels"],
-        positional_embeddings: PositionalEmbeddings,
-        kv_cache: KVCacheLayer | None = None,
-        return_updated_kv_cache: bool = False,
+        positional_embeddings: PositionalEmbeddings | None,
+        state: KVCacheLayer | None = None,
+        return_updated_state: bool = False,
         length_without_padding: Int[Array, ""] | int | None = None,
     ) -> AttentionResult:
         queries, keys, values = vmap(self.qkv_projection, in_axes=0)(inputs)
@@ -348,17 +349,18 @@ class Attention(LalamoModule[AttentionConfig]):
         if self.key_norm is not None:
             keys = vmap(vmap(self.key_norm))(keys)

-        apply_positional_embeddings = vmap(positional_embeddings.apply, in_axes=1, out_axes=1)
-        queries = apply_positional_embeddings(queries)
-        keys = apply_positional_embeddings(keys)
+        if positional_embeddings is not None:
+            apply_positional_embeddings = vmap(positional_embeddings.apply, in_axes=1, out_axes=1)
+            queries = apply_positional_embeddings(queries)
+            keys = apply_positional_embeddings(keys)

-        if kv_cache is None:
-            updated_kv_cache = DynamicKVCacheLayer.init(self.has_sinks, keys, values, length=length_without_padding)
+        if state is None:
+            updated_state = DynamicKVCacheLayer.init(self.has_sinks, keys, values, length=length_without_padding)
         else:
-            updated_kv_cache = kv_cache.extend(keys, values, added_length=length_without_padding)
+            updated_state = state.extend(keys, values, added_length=length_without_padding)

         num_suffix_tokens, _, _ = queries.shape
-        mask = updated_kv_cache.attention_mask(
+        mask = updated_state.attention_mask(
             num_suffix_tokens,
             self.is_causal,
             length_without_padding,
@@ -373,8 +375,8 @@ class Attention(LalamoModule[AttentionConfig]):
         if self.config.logit_soft_cap is not None:
             attention_output = _soft_capped_attention_kernel(
                 queries,
-                updated_kv_cache.keys,
-                updated_kv_cache.values,
+                updated_state.keys,
+                updated_state.values,
                 mask=mask,
                 scale=self.scale,
                 logit_soft_cap=self.config.logit_soft_cap,
@@ -382,8 +384,8 @@ class Attention(LalamoModule[AttentionConfig]):
         else:
             attention_output = jax.nn.dot_product_attention(
                 queries,
-                updated_kv_cache.keys,
-                updated_kv_cache.values,
+                updated_state.keys,
+                updated_state.values,
                 bias=sink_bias,
                 mask=mask,
                 scale=self.scale,
@@ -396,16 +398,16 @@ class Attention(LalamoModule[AttentionConfig]):
         )
         (result,) = vmap(self.out_projection, in_axes=0)(attention_output)

-        if not return_updated_kv_cache:
-            updated_kv_cache = None
+        if not return_updated_state:
+            updated_state = None

         return AttentionResult(
             outputs=result,
-            kv_cache=updated_kv_cache,
+            state=updated_state,
         )

-    def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
-        return StaticKVCacheLayer.empty(
+    def init_static_state(self, capacity: int) -> StaticKVCacheLayer:
+        return StaticKVCacheLayer.init(
             self.has_sinks,
             capacity,
             self.num_groups,
lalamo/modules/token_mixers/common.py
@@ -0,0 +1,78 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import NamedTuple, Self
+
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.modules.common import LalamoModule, ParameterTree, PositionalEmbeddingSelector
+from lalamo.modules.rope import PositionalEmbeddings
+
+from .state.common import StateLayerBase
+
+__all__ = [
+    "TokenMixerBase",
+    "TokenMixerConfigBase",
+    "TokenMixerResult",
+]
+
+
+class TokenMixerResult[StateLayerT](NamedTuple):
+    outputs: Float[Array, "*batch suffix_tokens channels"]
+    state: StateLayerT | None = None
+
+
+@dataclass(frozen=True)
+class TokenMixerConfigBase(ABC):
+    @property
+    @abstractmethod
+    def rope_dim(self) -> int: ...
+
+    @abstractmethod
+    def random_init(
+        self,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "TokenMixerBase": ...
+
+    @abstractmethod
+    def empty(
+        self,
+        model_dim: int,
+    ) -> "TokenMixerBase": ...
+
+
+class TokenMixerBase[ConfigT, StateLayerT: StateLayerBase](LalamoModule[ConfigT]):
+    @property
+    @abstractmethod
+    def activation_precision(self) -> DTypeLike: ...
+
+    @property
+    @abstractmethod
+    def model_dim(self) -> int: ...
+
+    @property
+    @abstractmethod
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector: ...
+
+    @abstractmethod
+    def __call__(
+        self,
+        inputs: Float[Array, "suffix_tokens channels"],
+        positional_embeddings: PositionalEmbeddings | None,
+        state: StateLayerT | None = None,
+        return_updated_state: bool = False,
+        length_without_padding: Int[Array, ""] | int | None = None,
+    ) -> TokenMixerResult[StateLayerT]: ...
+
+    @abstractmethod
+    def init_static_state(self, capacity: int) -> StateLayerT: ...
+
+    @abstractmethod
+    def export_weights(self) -> ParameterTree: ...
+
+    @abstractmethod
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self: ...
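
The hunks above replace the attention-specific kv_cache plumbing with the generic TokenMixerBase/TokenMixerResult interface, so attention and the new Mamba mixer can be driven through one API. A minimal usage sketch, assuming only the signatures visible in this diff; the helper name run_mixer_with_state and the exact import path of token_mixers.common are illustrative, not part of the package:

from jaxtyping import Array, Float

from lalamo.modules.rope import PositionalEmbeddings
from lalamo.modules.token_mixers.common import TokenMixerBase, TokenMixerResult


def run_mixer_with_state(
    mixer: TokenMixerBase,
    inputs: Float[Array, "suffix_tokens channels"],
    positional_embeddings: PositionalEmbeddings | None,
    capacity: int,
) -> TokenMixerResult:
    # Allocate a fixed-capacity state layer (a StaticKVCacheLayer for Attention,
    # the corresponding Mamba state layer for the new Mamba mixer) and run one
    # step, asking for the updated state back via return_updated_state.
    state = mixer.init_static_state(capacity)
    return mixer(
        inputs,
        positional_embeddings,
        state=state,
        return_updated_state=True,
    )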