lalamo-0.4.0-py3-none-any.whl → lalamo-0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/language_model.py +22 -23
  3. lalamo/main.py +4 -18
  4. lalamo/model_import/common.py +24 -6
  5. lalamo/model_import/decoder_configs/__init__.py +2 -0
  6. lalamo/model_import/decoder_configs/common.py +4 -4
  7. lalamo/model_import/decoder_configs/executorch.py +17 -10
  8. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  9. lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
  10. lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
  11. lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
  12. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
  13. lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
  14. lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
  15. lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
  16. lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
  17. lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
  18. lalamo/model_import/loaders/executorch.py +5 -4
  19. lalamo/model_import/loaders/huggingface.py +321 -69
  20. lalamo/model_import/model_specs/__init__.py +2 -0
  21. lalamo/model_import/model_specs/common.py +16 -5
  22. lalamo/model_import/model_specs/llamba.py +40 -0
  23. lalamo/model_import/model_specs/qwen.py +29 -1
  24. lalamo/modules/__init__.py +33 -6
  25. lalamo/modules/activations.py +9 -2
  26. lalamo/modules/common.py +10 -5
  27. lalamo/modules/decoder.py +93 -97
  28. lalamo/modules/decoder_layer.py +85 -103
  29. lalamo/modules/embedding.py +279 -5
  30. lalamo/modules/linear.py +335 -30
  31. lalamo/modules/mlp.py +6 -7
  32. lalamo/modules/mlx_interop.py +19 -0
  33. lalamo/modules/rope.py +1 -1
  34. lalamo/modules/token_mixers/__init__.py +30 -0
  35. lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
  36. lalamo/modules/token_mixers/common.py +78 -0
  37. lalamo/modules/token_mixers/mamba.py +553 -0
  38. lalamo/modules/token_mixers/state/__init__.py +12 -0
  39. lalamo/modules/token_mixers/state/common.py +26 -0
  40. lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
  41. lalamo/modules/token_mixers/state/mamba_state.py +51 -0
  42. lalamo/utils.py +24 -2
  43. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
  44. lalamo-0.5.0.dist-info/RECORD +80 -0
  45. lalamo-0.4.0.dist-info/RECORD +0 -71
  46. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
  47. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/model_specs/llamba.py ADDED
@@ -0,0 +1,40 @@
+ from lalamo.model_import.decoder_configs import HFLlambaConfig
+ from lalamo.quantization import QuantizationMode
+
+ from .common import ConfigMap, FileSpec, ModelSpec
+
+ __all__ = ["LLAMBA_MODELS"]
+
+ LLAMBA_MODELS = [
+     ModelSpec(
+         vendor="Cartesia",
+         family="Llamba",
+         name="Llamba-1B",
+         size="1B",
+         quantization=None,
+         repo="cartesia-ai/Llamba-1B",
+         config_type=HFLlambaConfig,
+         configs=ConfigMap(
+             tokenizer=FileSpec("tokenizer.json", "meta-llama/Llama-3.2-1B-Instruct"),
+             tokenizer_config=FileSpec("tokenizer_config.json", "meta-llama/Llama-3.2-1B-Instruct"),
+             generation_config=FileSpec("generation_config.json", "meta-llama/Llama-3.2-1B-Instruct"),
+         ),
+         use_cases=tuple(),
+     ),
+     ModelSpec(
+         vendor="Cartesia",
+         family="Llamba",
+         name="Llamba-1B-4bit-mlx",
+         size="1B",
+         quantization=QuantizationMode.UINT4,
+         repo="cartesia-ai/Llamba-1B-4bit-mlx",
+         config_type=HFLlambaConfig,
+         configs=ConfigMap(
+             model_config=FileSpec("config.json", "cartesia-ai/Llamba-1B"),
+             tokenizer=FileSpec("tokenizer.json", "meta-llama/Llama-3.2-1B-Instruct"),
+             tokenizer_config=FileSpec("tokenizer_config.json", "meta-llama/Llama-3.2-1B-Instruct"),
+             generation_config=FileSpec("generation_config.json", "meta-llama/Llama-3.2-1B-Instruct"),
+         ),
+         use_cases=tuple(),
+     ),
+ ]
lalamo/model_import/model_specs/qwen.py CHANGED
@@ -1,7 +1,7 @@
  from lalamo.model_import.decoder_configs import HFQwen2Config, HFQwen3Config
  from lalamo.quantization import QuantizationMode

- from .common import ModelSpec, UseCase, WeightsType
+ from .common import ConfigMap, FileSpec, ModelSpec, UseCase, WeightsType

  __all__ = ["QWEN_MODELS"]

@@ -148,6 +148,20 @@ QWEN3 = [
          repo="Qwen/Qwen3-0.6B",
          config_type=HFQwen3Config,
      ),
+     ModelSpec(
+         vendor="Alibaba",
+         family="Qwen3",
+         name="Qwen3-0.6B-MLX-4bit",
+         size="0.6B",
+         quantization=QuantizationMode.UINT4,
+         repo="Qwen/Qwen3-0.6B-MLX-4bit",
+         config_type=HFQwen3Config,
+         configs=ConfigMap(
+             tokenizer=FileSpec("tokenizer.json", "Qwen/Qwen3-0.6B"),
+             tokenizer_config=FileSpec("tokenizer_config.json", "Qwen/Qwen3-0.6B"),
+             generation_config=FileSpec("generation_config.json", "Qwen/Qwen3-0.6B"),
+         ),
+     ),
      ModelSpec(
          vendor="Alibaba",
          family="Qwen3",
@@ -177,6 +191,20 @@ QWEN3 = [
          repo="Qwen/Qwen3-4B-AWQ",
          config_type=HFQwen3Config,
      ),
+     ModelSpec(
+         vendor="Alibaba",
+         family="Qwen3",
+         name="Qwen3-4B-MLX-4bit",
+         size="4B",
+         quantization=QuantizationMode.UINT4,
+         repo="Qwen/Qwen3-4B-MLX-4bit",
+         config_type=HFQwen3Config,
+         configs=ConfigMap(
+             tokenizer=FileSpec("tokenizer.json", "Qwen/Qwen3-4B"),
+             tokenizer_config=FileSpec("tokenizer_config.json", "Qwen/Qwen3-4B"),
+             generation_config=FileSpec("generation_config.json", "Qwen/Qwen3-4B"),
+         ),
+     ),
      ModelSpec(
          vendor="Alibaba",
          family="Qwen3",
lalamo/modules/__init__.py CHANGED
@@ -1,6 +1,5 @@
- from .activations import GELU, Activation, SiLU
- from .attention import Attention, AttentionConfig
- from .common import AttentionType, ForwardPassMode, LalamoModule, config_converter
+ from .activations import GELU, Activation, Identity, SiLU
+ from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector, config_converter
  from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderForwardPassConfig, DecoderResult
  from .decoder_layer import (
      DecoderLayer,
@@ -12,6 +11,10 @@ from .decoder_layer import (
  from .embedding import (
      EmbeddingBase,
      EmbeddingConfig,
+     MLXQuantizedTiedEmbedding,
+     MLXQuantizedTiedEmbeddingConfig,
+     MLXSemiQuantizedUntiedEmbedding,
+     MLXSemiQuantizedUntiedEmbeddingConfig,
      QuantizedTiedEmbedding,
      QuantizedTiedEmbeddingConfig,
      TiedEmbedding,
@@ -19,7 +22,6 @@ from .embedding import (
      UntiedEmbedding,
      UntiedEmbeddingConfig,
  )
- from .kv_cache import DynamicKVCacheLayer, KVCache, KVCacheLayer, StaticKVCacheLayer
  from .linear import (
      FullPrecisionLinear,
      FullPrecisionLinearConfig,
@@ -27,6 +29,8 @@ from .linear import (
      GroupQuantizedLinearConfig,
      LinearBase,
      LinearConfig,
+     MLXQuantizedLinear,
+     MLXQuantizedLinearConfig,
      QLoRALinear,
      QLoRALinearConfig,
  )
@@ -51,13 +55,24 @@ from .rope import (
      UnscaledRoPEConfig,
      YARNRoPEConfig,
  )
+ from .token_mixers import (
+     Attention,
+     AttentionConfig,
+     DynamicKVCacheLayer,
+     KVCacheLayer,
+     Mamba2,
+     Mamba2Config,
+     SeparableCausalConv,
+     SeparableCausalConvConfig,
+     State,
+     StaticKVCacheLayer,
+ )

  __all__ = [
      "GELU",
      "Activation",
      "Attention",
      "AttentionConfig",
-     "AttentionType",
      "Decoder",
      "DecoderActivationTrace",
      "DecoderConfig",
@@ -78,7 +93,7 @@ __all__ = [
      "FullPrecisionLinearConfig",
      "GroupQuantizedLinear",
      "GroupQuantizedLinearConfig",
-     "KVCache",
+     "Identity",
      "KVCacheLayer",
      "LalamoModule",
      "LinearBase",
@@ -88,8 +103,17 @@ __all__ = [
      "MLPBase",
      "MLPConfig",
      "MLPForwardPassConfig",
+     "MLXQuantizedLinear",
+     "MLXQuantizedLinearConfig",
+     "MLXQuantizedTiedEmbedding",
+     "MLXQuantizedTiedEmbeddingConfig",
+     "MLXSemiQuantizedUntiedEmbedding",
+     "MLXSemiQuantizedUntiedEmbeddingConfig",
+     "Mamba2",
+     "Mamba2Config",
      "MixtureOfExperts",
      "MixtureOfExpertsConfig",
+     "PositionalEmbeddingSelector",
      "PositionalEmbeddings",
      "QLoRALinear",
      "QLoRALinearConfig",
@@ -100,8 +124,11 @@ __all__ = [
      "RoPE",
      "RoPEConfig",
      "RoutingFunction",
+     "SeparableCausalConv",
+     "SeparableCausalConvConfig",
      "SiLU",
      "SoftmaxRouting",
+     "State",
      "StaticKVCacheLayer",
      "TiedEmbedding",
      "TiedEmbeddingConfig",
lalamo/modules/activations.py CHANGED
@@ -10,6 +10,7 @@ from lalamo.modules.common import register_config_union
  __all__ = [
      "GELU",
      "Activation",
+     "Identity",
      "SiLU",
  ]

@@ -34,7 +35,13 @@ class GELU(ActivationBase):
          return jax.nn.gelu(x)


- Activation = SiLU | GELU
+ @dataclass(frozen=True)
+ class Identity(ActivationBase):
+     def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
+         return x
+
+
+ Activation = SiLU | GELU | Identity


- register_config_union(Activation)
+ register_config_union(Activation)  # type: ignore (pyright bug)
lalamo/modules/common.py CHANGED
@@ -2,7 +2,7 @@ from abc import abstractmethod
  from dataclasses import dataclass
  from enum import Enum
  from types import UnionType
- from typing import Self
+ from typing import Any, Self

  import equinox as eqx
  from cattrs import Converter
@@ -12,18 +12,19 @@ from jaxtyping import Array, DTypeLike
  from lalamo.common import ParameterTree

  __all__ = [
-     "AttentionType",
      "DummyUnionMember",
      "ForwardPassMode",
      "LalamoModule",
+     "PositionalEmbeddingSelector",
      "config_converter",
      "register_config_union",
  ]


- class AttentionType(Enum):
+ class PositionalEmbeddingSelector(Enum):
      GLOBAL = "global"
-     SLIDING_WINDOW = "sliding_window"
+     LOCAL = "sliding_window"
+     NONE = "none"


  class ForwardPassMode(Enum):
@@ -128,4 +129,8 @@ def register_config_union(union_type: UnionType) -> None:

  @dataclass
  class DummyUnionMember:
-     pass
+     def __getattribute__(self, name: str, /) -> Any:  # noqa: ANN401
+         raise NotImplementedError
+
+     def __call__(self, *args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
+         raise NotImplementedError
lalamo/modules/decoder.py CHANGED
@@ -8,14 +8,14 @@ from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

  from lalamo.common import ParameterTree
- from lalamo.modules.utils import vmap_twice

- from .common import AttentionType, ForwardPassMode, LalamoModule
+ from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
  from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerForwardPassConfig, DecoderLayerResult
  from .embedding import EmbeddingBase, EmbeddingConfig
- from .kv_cache import KVCache
  from .normalization import RMSNorm, RMSNormConfig
  from .rope import PositionalEmbeddings, RoPE, RoPEConfig
+ from .token_mixers import AttentionConfig, State
+ from .utils import vmap_twice

  __all__ = [
      "Decoder",
@@ -32,42 +32,42 @@ type DecoderForwardPassConfig = DecoderLayerForwardPassConfig
  class DecoderActivationTrace(eqx.Module):
      token_ids: Int[Array, "batch suffix_tokens"]
      token_positions: Int[Array, "batch suffix_tokens"]
-     kv_cache: KVCache | None
+     state: State | None

-     local_positional_embeddings: PositionalEmbeddings
-     global_positional_embeddings: PositionalEmbeddings
+     local_positional_embeddings: PositionalEmbeddings | None
+     global_positional_embeddings: PositionalEmbeddings | None

      layer_results: tuple[DecoderLayerResult, ...]

      output_norm: Float[Array, "batch suffix_tokens channels"]

      def export(self) -> ParameterTree:
-         result = dict(
+         result: dict[str, ParameterTree | Array] = dict(
              token_ids=self.token_ids,
              token_positions=self.token_positions,
-             local_positional_embeddings=self.local_positional_embeddings.export(),
-             global_positional_embeddings=self.global_positional_embeddings.export(),
              layer_results=[layer_result.export() for layer_result in self.layer_results],
              output_norm=self.output_norm,
          )
-         if self.kv_cache is not None:
-             result["kv_cache"] = [kv_cache_layer_slice.export() for kv_cache_layer_slice in self.kv_cache]
+         if self.local_positional_embeddings is not None:
+             result["local_positional_embeddings"] = self.local_positional_embeddings.export()
+         if self.global_positional_embeddings is not None:
+             result["global_positional_embeddings"] = self.global_positional_embeddings.export()
+         if self.state is not None:
+             result["state"] = [state_layer.export() for state_layer in self.state]
          return result


  class DecoderResult(eqx.Module):
      logits: Float[Array, "batch suffix_tokens channels"]
-     updated_kv_cache: KVCache | None = None
+     updated_state: State | None = None
      activation_trace: DecoderActivationTrace | None = None

      def export(self) -> ParameterTree:
          result: dict[str, ParameterTree | Array] = dict(
              logits=self.logits,
          )
-         if self.updated_kv_cache is not None:
-             result["updated_kv_cache"] = [
-                 kv_cache_layer_slice.export() for kv_cache_layer_slice in self.updated_kv_cache
-             ]
+         if self.updated_state is not None:
+             result["updated_state"] = [state_layer.export() for state_layer in self.updated_state]
          if self.activation_trace is not None:
              result["activation_trace"] = self.activation_trace.export()
          return result
@@ -76,33 +76,16 @@ class DecoderResult(eqx.Module):
  @dataclass(frozen=True)
  class DecoderConfig:
      embedding_config: EmbeddingConfig
-     global_rope_config: RoPEConfig
+     global_rope_config: RoPEConfig | None
      local_rope_config: RoPEConfig | None
-     layer_config: DecoderLayerConfig
+     layer_configs: tuple[DecoderLayerConfig, ...]
      output_norm_config: RMSNormConfig

      vocab_size: int
      model_dim: int
      hidden_dim: int
-     num_heads: int
-     num_groups: int
-     head_dim: int
-     attention_scale: float | None
-     num_layers: int
-     sliding_window_sizes: tuple[int | None, ...] | None
      context_length: int

-     def __post_init__(self) -> None:
-         if self.local_rope_config is not None and self.sliding_window_sizes is None:
-             raise ValueError("Sliding window sizes must be provided when using local RoPE")
-         if self.sliding_window_sizes is None:
-             return
-         if len(self.sliding_window_sizes) != self.num_layers:
-             raise ValueError(
-                 f"Number of sliding window sizes {len(self.sliding_window_sizes)} does not match"
-                 f" the number of layers {self.num_layers}",
-             )
-
      def random_init(
          self,
          *,
@@ -114,40 +97,38 @@ class DecoderConfig:
              model_dim=self.model_dim,
              key=embedding_key,
          )
-         global_rope = self.global_rope_config.init(
-             head_dim=self.head_dim,
-             num_timesteps=self.context_length,
-         )
+
+         first_layer_config, *_ = self.layer_configs
+
+         if self.global_rope_config:
+             global_rope = self.global_rope_config.init(
+                 head_dim=first_layer_config.rope_dim,
+                 num_timesteps=self.context_length,
+             )
+         else:
+             global_rope = None

          if self.local_rope_config:
-             assert self.sliding_window_sizes is not None
              max_sliding_window_size = max(
-                 window_size for window_size in self.sliding_window_sizes if window_size is not None
+                 layer_config.mixer_config.sliding_window_size or 0
+                 for layer_config in self.layer_configs
+                 if isinstance(layer_config.mixer_config, AttentionConfig)
              )
              local_rope = self.local_rope_config.init(
-                 head_dim=self.head_dim,
+                 head_dim=first_layer_config.rope_dim,
                  num_timesteps=max(max_sliding_window_size, self.context_length),
              )
          else:
              local_rope = None

-         if self.sliding_window_sizes is None:
-             sliding_window_sizes = [None] * self.num_layers
-         else:
-             sliding_window_sizes = self.sliding_window_sizes
-         layers_keys = jax.random.split(layers_key, self.num_layers)
+         layers_keys = jax.random.split(layers_key, len(self.layer_configs))
          layers = tuple(
-             self.layer_config.random_init(
+             layer_config.random_init(
                  model_dim=self.model_dim,
                  hidden_dim=self.hidden_dim,
-                 num_heads=self.num_heads,
-                 num_groups=self.num_groups,
-                 head_dim=self.head_dim,
-                 attention_scale=self.attention_scale,
-                 sliding_window_size=sliding_window_size,
                  key=key,
              )
-             for sliding_window_size, key in zip(sliding_window_sizes, layers_keys, strict=True)
+             for layer_config, key in zip(self.layer_configs, layers_keys, strict=False)
          )
          output_norm = self.output_norm_config.init(self.model_dim)
          return Decoder(
@@ -166,34 +147,35 @@ class DecoderConfig:
              vocab_size=self.vocab_size,
              model_dim=self.model_dim,
          )
-         global_rope = self.global_rope_config.init(
-             head_dim=self.head_dim,
-             num_timesteps=self.context_length,
-         )

-         if self.local_rope_config:
-             local_rope = self.local_rope_config.init(
-                 head_dim=self.head_dim,
+         first_layer_config, *_ = self.layer_configs
+
+         if self.global_rope_config:
+             global_rope = self.global_rope_config.init(
+                 head_dim=first_layer_config.rope_dim,
                  num_timesteps=self.context_length,
              )
          else:
-             local_rope = None
+             global_rope = None

-         if self.sliding_window_sizes is None:
-             sliding_window_sizes = [None] * self.num_layers
+         if self.local_rope_config:
+             max_sliding_window_size = max(
+                 layer_config.mixer_config.sliding_window_size or 0
+                 for layer_config in self.layer_configs
+                 if isinstance(layer_config.mixer_config, AttentionConfig)
+             )
+             local_rope = self.local_rope_config.init(
+                 head_dim=first_layer_config.rope_dim,
+                 num_timesteps=max(max_sliding_window_size, self.context_length),
+             )
          else:
-             sliding_window_sizes = self.sliding_window_sizes
+             local_rope = None
          layers = tuple(
-             self.layer_config.empty(
+             layer_config.empty(
                  model_dim=self.model_dim,
                  hidden_dim=self.hidden_dim,
-                 num_heads=self.num_heads,
-                 num_groups=self.num_groups,
-                 head_dim=self.head_dim,
-                 attention_scale=self.attention_scale,
-                 sliding_window_size=sliding_window_size,
              )
-             for sliding_window_size in sliding_window_sizes
+             for layer_config in self.layer_configs
          )
          output_norm = self.output_norm_config.empty(self.model_dim)
          return Decoder(
@@ -208,7 +190,7 @@ class DecoderConfig:

  class Decoder(LalamoModule[DecoderConfig]):
      embedding: EmbeddingBase
-     global_rope: RoPE
+     global_rope: RoPE | None
      local_rope: RoPE | None
      layers: tuple[DecoderLayer, ...]
      output_norm: RMSNorm
@@ -218,12 +200,12 @@ class Decoder(LalamoModule[DecoderConfig]):
          return self.embedding.activation_precision

      @eqx.filter_jit
-     def __call__(
+     def __call__(  # noqa: PLR0912
          self,
          token_ids: Int[Array, "batch suffix_tokens"],
          token_positions: Int[Array, "batch suffix_tokens"],
-         kv_cache: KVCache | None = None,
-         return_updated_kv_cache: bool = False,
+         state: State | None = None,
+         return_updated_state: bool = False,
          return_activation_trace: bool = False,
          lengths_without_padding: Int[Array, " batch"] | None = None,
          forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
@@ -239,28 +221,35 @@ class Decoder(LalamoModule[DecoderConfig]):
                  f" got {token_positions.shape}",
              )

-         maybe_kv_cache = kv_cache or ([None] * len(self.layers))
+         maybe_state = state or ([None] * len(self.layers))
          inner_features = vmap(self.embedding.embed)(token_ids)

-         global_positional_embeddings = vmap(self.global_rope)(token_positions)
+         if self.global_rope is not None:
+             global_positional_embeddings = vmap(self.global_rope)(token_positions)
+         else:
+             global_positional_embeddings = None
+
          if self.local_rope is not None:
              local_positional_embeddings = vmap(self.local_rope)(token_positions)
          else:
              local_positional_embeddings = global_positional_embeddings

-         updated_kv_cache_layers = []
+         updated_state_layers = []
          layer_results = []
-         for layer, kv_cache_slice in zip(self.layers, maybe_kv_cache, strict=True):
-             if layer.attention_type == AttentionType.SLIDING_WINDOW:
-                 positional_embeddings_to_use = local_positional_embeddings
-             else:
-                 positional_embeddings_to_use = global_positional_embeddings
+         for layer, state_layer in zip(self.layers, maybe_state, strict=True):
+             match layer.positional_embedding_selector:
+                 case PositionalEmbeddingSelector.LOCAL:
+                     positional_embeddings_to_use = local_positional_embeddings
+                 case PositionalEmbeddingSelector.GLOBAL:
+                     positional_embeddings_to_use = global_positional_embeddings
+                 case PositionalEmbeddingSelector.NONE:
+                     positional_embeddings_to_use = None

              layer_result = layer(
                  inner_features,
                  positional_embeddings_to_use,
-                 kv_cache=kv_cache_slice,
-                 return_updated_kv_cache=return_updated_kv_cache,
+                 state=state_layer,
+                 return_updated_state=return_updated_state,
                  return_activation_trace=return_activation_trace,
                  lengths_without_padding=lengths_without_padding,
                  forward_pass_mode=forward_pass_mode,
@@ -268,7 +257,7 @@ class Decoder(LalamoModule[DecoderConfig]):
              )
              inner_features = layer_result.outputs
              layer_results.append(layer_result)
-             updated_kv_cache_layers.append(layer_result.updated_kv_cache)
+             updated_state_layers.append(layer_result.updated_state)

          normalized_outputs = vmap_twice(self.output_norm)(inner_features)
          logits = vmap_twice(self.embedding.readout)(normalized_outputs)
@@ -277,7 +266,7 @@ class Decoder(LalamoModule[DecoderConfig]):
              activation_trace = DecoderActivationTrace(
                  token_ids=token_ids,
                  token_positions=token_positions,
-                 kv_cache=kv_cache,
+                 state=state,
                  global_positional_embeddings=global_positional_embeddings,
                  local_positional_embeddings=local_positional_embeddings,
                  layer_results=tuple(layer_results),
@@ -286,27 +275,28 @@ class Decoder(LalamoModule[DecoderConfig]):
          else:
              activation_trace = None

-         if return_updated_kv_cache:
-             updated_kv_cache = KVCache(updated_kv_cache_layers)
+         if return_updated_state:
+             updated_state = State(updated_state_layers)
          else:
-             updated_kv_cache = None
+             updated_state = None

          return DecoderResult(
              logits=logits,
-             updated_kv_cache=updated_kv_cache,
+             updated_state=updated_state,
              activation_trace=activation_trace,
          )

-     def init_static_kv_cache(self, batch_size: int, capacity: int) -> KVCache:
-         return KVCache(layer.init_static_kv_cache(batch_size, capacity) for layer in self.layers)
+     def init_static_state(self, batch_size: int, capacity: int) -> State:
+         return State(layer.init_static_state(batch_size, capacity) for layer in self.layers)

      def export_weights(self) -> ParameterTree:
          result = dict(
              embedding=self.embedding.export_weights(),
-             global_rope=self.global_rope.export_weights(),
              layers=[layer.export_weights() for layer in self.layers],
              output_norm=self.output_norm.export_weights(),
          )
+         if self.global_rope:
+             result["global_rope"] = self.global_rope.export_weights()
          if self.local_rope:
              result["local_rope"] = self.local_rope.export_weights()
          return result
@@ -317,15 +307,21 @@ class Decoder(LalamoModule[DecoderConfig]):
      ) -> Self:
          assert isinstance(weights, Mapping)
          assert isinstance(weights["embedding"], Mapping)
-         assert isinstance(weights["global_rope"], Mapping)
          assert isinstance(weights["layers"], Sequence)
          assert isinstance(weights["output_norm"], Mapping)
+
          if self.local_rope:
              assert isinstance(weights["local_rope"], Mapping)
              local_rope = self.local_rope.import_weights(weights["local_rope"])
          else:
              local_rope = None

+         if self.global_rope:
+             assert isinstance(weights["global_rope"], Mapping)
+             global_rope = self.global_rope.import_weights(weights["global_rope"])
+         else:
+             global_rope = None
+
          layers = []
          for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
              assert isinstance(layer_weights, Mapping)
@@ -333,7 +329,7 @@ class Decoder(LalamoModule[DecoderConfig]):
          return replace(
              self,
              embedding=self.embedding.import_weights(weights["embedding"]),
-             global_rope=self.global_rope.import_weights(weights["global_rope"]),
+             global_rope=global_rope,
              layers=tuple(layers),
              output_norm=self.output_norm.import_weights(weights["output_norm"]),
              local_rope=local_rope,