lalamo 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +23 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +101 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
  44. lalamo-0.2.2.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/modules/attention.py
@@ -0,0 +1,326 @@
+ from dataclasses import dataclass
+ from typing import NamedTuple
+
+ import equinox as eqx
+ import jax
+ from einops import einsum, rearrange, repeat
+ from jax import numpy as jnp
+ from jax import vmap
+ from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
+
+ from lalamo.common import ParameterDict
+ from lalamo.modules.normalization import RMSNorm, RMSNormConfig
+
+ from .common import AttentionType, LalamoModule, WeightLayout
+ from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
+ from .linear import LinearBase, LinearConfig
+ from .rope import PositionalEmbeddings
+ from .utils import apply_soft_capping
+
+ __all__ = [
+     "Attention",
+     "AttentionConfig",
+ ]
+
+
+ def _repeat_kv(
+     keys_or_values: Float[Array, "tokens groups channels"],
+     group_size: int,
+ ) -> Float[Array, "tokens groups*group_size channels"]:
+     return repeat(
+         keys_or_values,
+         "tokens groups channels -> tokens (groups group_size) channels",
+         group_size=group_size,
+     )
+
+
+ def _soft_capped_attention_kernel(
+     queries: Float[Array, "dst_tokens heads head_channels"],
+     keys: Float[Array, "src_tokens groups head_channels"],
+     values: Float[Array, "src_tokens groups head_channels"],
+     mask: Bool[Array, "dst_tokens src_tokens"] | None,
+     scale: float | None,
+     logit_soft_cap: float,
+ ) -> Float[Array, "dst_tokens heads head_channels"]:
+     dst_length, num_heads, head_dim = queries.shape
+     src_length, num_groups, _ = keys.shape
+     if scale is None:
+         scale = head_dim**-0.5
+     group_size = num_heads // num_groups
+     keys = _repeat_kv(keys, group_size)
+     values = _repeat_kv(values, group_size)
+     queries_head_first = rearrange(queries, "dst_tokens heads channels -> heads dst_tokens channels")
+     keys_head_first = rearrange(keys, "src_tokens heads channels -> heads src_tokens channels")
+     attention_logits = einsum(
+         queries_head_first,
+         keys_head_first,
+         "heads dst_tokens channels, heads src_tokens channels -> heads dst_tokens src_tokens",
+     )
+     if mask is not None:
+         attention_logits = jnp.where(mask, attention_logits, jnp.array(float("-inf"), dtype=attention_logits.dtype))
+
+     attention_logits = attention_logits * scale
+     attention_logits = apply_soft_capping(attention_logits, logit_soft_cap)
+     attention_weights = jax.nn.softmax(attention_logits, axis=-1)
+     return einsum(
+         attention_weights,
+         values,
+         "heads dst_tokens src_tokens, src_tokens heads channels -> dst_tokens heads channels",
+     )
+
+
+ class AttentionResult(NamedTuple):
+     outputs: Float[Array, "suffix_tokens channels"]
+     kv_cache: KVCacheLayer | None = None
+
+
+ @dataclass(frozen=True)
+ class AttentionConfig:
+     qkv_projection_config: LinearConfig
+     out_projection_config: LinearConfig
+
+     query_norm_config: RMSNormConfig | None
+     key_norm_config: RMSNormConfig | None
+
+     logit_soft_cap: float | None
+     has_qkv_biases: bool
+     has_out_biases: bool
+
+     def random_init(
+         self,
+         model_dim: int,
+         num_heads: int,
+         num_groups: int,
+         head_dim: int,
+         is_causal: bool,
+         scale: float | None,
+         sliding_window_size: int | None,
+         *,
+         key: PRNGKeyArray,
+     ) -> "Attention":
+         qkv_key, out_key = jax.random.split(key)
+         qkv_projection = self.qkv_projection_config.random_init(
+             input_dim=model_dim,
+             output_dims=(
+                 num_heads * head_dim,
+                 num_groups * head_dim,
+                 num_groups * head_dim,
+             ),
+             has_biases=self.has_qkv_biases,
+             key=qkv_key,
+         )
+         out_projection = self.out_projection_config.random_init(
+             num_heads * head_dim,
+             (model_dim,),
+             has_biases=self.has_out_biases,
+             key=out_key,
+         )
+
+         if self.query_norm_config is not None:
+             query_norm = self.query_norm_config.init(
+                 channels=head_dim,
+             )
+         else:
+             query_norm = None
+
+         if self.key_norm_config is not None:
+             key_norm = self.key_norm_config.init(
+                 channels=head_dim,
+             )
+         else:
+             key_norm = None
+
+         return Attention(
+             self,
+             qkv_projection=qkv_projection,
+             out_projection=out_projection,
+             query_norm=query_norm,
+             key_norm=key_norm,
+             num_heads=num_heads,
+             num_groups=num_groups,
+             head_dim=head_dim,
+             is_causal=is_causal,
+             scale=scale,
+             sliding_window_size=sliding_window_size,
+         )
+
+
+ class Attention(LalamoModule[AttentionConfig]):
+     qkv_projection: LinearBase
+     out_projection: LinearBase
+
+     query_norm: RMSNorm | None
+     key_norm: RMSNorm | None
+
+     num_heads: int = eqx.field(static=True)
+     num_groups: int = eqx.field(static=True)
+     head_dim: int = eqx.field(static=True)
+
+     is_causal: bool = eqx.field(static=True)
+
+     scale: float | None = eqx.field(static=True)
+     sliding_window_size: int | None = eqx.field(static=True)
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.qkv_projection.activation_precision
+
+     @property
+     def model_dim(self) -> int:
+         return self.qkv_projection.input_dim
+
+     @property
+     def group_size(self) -> int:
+         return self.num_heads // self.num_groups
+
+     @property
+     def use_sliding_window(self) -> bool:
+         return self.sliding_window_size is not None
+
+     @property
+     def attention_type(self) -> AttentionType:
+         return AttentionType.SLIDING_WINDOW if self.sliding_window_size is not None else AttentionType.GLOBAL
+
+     def __post_init__(self) -> None:
+         if self.qkv_projection.has_biases != self.config.has_qkv_biases:
+             raise ValueError(
+                 f"QKV projection has_biases {self.qkv_projection.has_biases} does not match"
+                 f" the specified config has_qkv_biases {self.config.has_qkv_biases}",
+             )
+         if self.out_projection.has_biases != self.config.has_out_biases:
+             raise ValueError(
+                 f"Output projection has_biases {self.out_projection.has_biases} does not match"
+                 f" the specified config has_out_biases {self.config.has_out_biases}",
+             )
+         if self.query_norm is not None and self.query_norm.input_dim != self.head_dim:
+             raise ValueError(
+                 f"Query normalization input dimension must match head_dim ({self.head_dim}),"
+                 f" got {self.query_norm.input_dim}",
+             )
+         if self.key_norm is not None and self.key_norm.input_dim != self.head_dim:
+             raise ValueError(
+                 f"Key normalization input dimension must match head_dim ({self.head_dim}),"
+                 f" got {self.key_norm.input_dim}",
+             )
+         if self.num_heads % self.num_groups != 0:
+             raise ValueError(
+                 "Number of heads must be divisible by the number of groups,"
+                 f" got {self.num_heads} heads and {self.num_groups} groups",
+             )
+         if self.out_projection.input_dim != self.num_heads * self.head_dim:
+             raise ValueError(
+                 f"Output projection input dimension must be num_heads * head_dim"
+                 f" ({self.num_heads} * {self.head_dim} = {self.num_heads * self.head_dim}),"
+                 f" got {self.out_projection.input_dim}",
+             )
+         q_output_dim, k_output_dim, v_output_dim = self.qkv_projection.output_dims
+         if q_output_dim != self.num_heads * self.head_dim:
+             raise ValueError(
+                 f"Query projection output dimension must be num_heads * head_dim"
+                 f" ({self.num_heads} * {self.head_dim} = {self.num_heads * self.head_dim}),"
+                 f" got {q_output_dim}",
+             )
+         if k_output_dim != self.num_groups * self.head_dim:
+             raise ValueError(
+                 f"Key projection output dimension must be num_groups * head_dim"
+                 f" ({self.num_groups} * {self.head_dim} = {self.num_groups * self.head_dim}),"
+                 f" got {k_output_dim}",
+             )
+         if v_output_dim != self.num_groups * self.head_dim:
+             raise ValueError(
+                 f"Value projection output dimension must be num_groups * head_dim"
+                 f" ({self.num_groups} * {self.head_dim} = {self.num_groups * self.head_dim}),"
+                 f" got {v_output_dim}",
+             )
+
+     def __call__(
+         self,
+         inputs: Float[Array, "suffix_tokens channels"],
+         positional_embeddings: PositionalEmbeddings,
+         kv_cache: KVCacheLayer | None = None,
+         return_updated_kv_cache: bool = False,
+         length_without_padding: Int[Array, ""] | int | None = None,
+     ) -> AttentionResult:
+         queries, keys, values = vmap(self.qkv_projection, in_axes=0)(inputs)
+         queries = rearrange(
+             queries,
+             "tokens (heads head_channels) -> tokens heads head_channels",
+             heads=self.num_heads,
+             head_channels=self.head_dim,
+         )
+         keys = rearrange(
+             keys,
+             "tokens (groups head_channels) -> tokens groups head_channels",
+             groups=self.num_groups,
+             head_channels=self.head_dim,
+         )
+         values = rearrange(
+             values,
+             "tokens (groups head_channels) -> tokens groups head_channels",
+             groups=self.num_groups,
+             head_channels=self.head_dim,
+         )
+
+         if self.query_norm is not None:
+             queries = vmap(vmap(self.query_norm))(queries)
+         if self.key_norm is not None:
+             keys = vmap(vmap(self.key_norm))(keys)
+
+         apply_positional_embeddings = vmap(positional_embeddings.apply, in_axes=1, out_axes=1)
+         queries = apply_positional_embeddings(queries)
+         keys = apply_positional_embeddings(keys)
+
+         if kv_cache is None:
+             updated_kv_cache = DynamicKVCacheLayer.init(keys, values, length=length_without_padding)
+         else:
+             updated_kv_cache = kv_cache.extend(keys, values, added_length=length_without_padding)
+
+         num_suffix_tokens, _, _ = queries.shape
+         mask = updated_kv_cache.attention_mask(num_suffix_tokens, self.is_causal, self.sliding_window_size)
+
+         if self.config.logit_soft_cap is not None:
+             attention_output = _soft_capped_attention_kernel(
+                 queries,
+                 updated_kv_cache.keys,
+                 updated_kv_cache.values,
+                 mask=mask,
+                 scale=self.scale,
+                 logit_soft_cap=self.config.logit_soft_cap,
+             )
+         else:
+             attention_output = jax.nn.dot_product_attention(
+                 queries,
+                 updated_kv_cache.keys,
+                 updated_kv_cache.values,
+                 mask=mask,
+                 scale=self.scale,
+             )
+         attention_output = rearrange(
+             attention_output,
+             "tokens heads head_channels -> tokens (heads head_channels)",
+             heads=self.num_heads,
+             head_channels=self.head_dim,
+         )
+         (result,) = vmap(self.out_projection, in_axes=0)(attention_output)
+
+         if not return_updated_kv_cache:
+             updated_kv_cache = None
+
+         return AttentionResult(
+             outputs=result,
+             kv_cache=updated_kv_cache,
+         )
+
+     def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
+         return StaticKVCacheLayer.empty(capacity, self.num_groups, self.head_dim, self.activation_precision)
+
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
+         result = ParameterDict(
+             qkv_projection=self.qkv_projection.export_weights(weight_layout),
+             out_projection=self.out_projection.export_weights(weight_layout),
+         )
+         if self.query_norm is not None:
+             result["query_norm"] = self.query_norm.export_weights(weight_layout)
+         if self.key_norm is not None:
+             result["key_norm"] = self.key_norm.export_weights(weight_layout)
+         return result
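
For orientation, the following standalone sketch (not part of the package) spells out what _soft_capped_attention_kernel above computes: grouped-query attention in which each key/value group is shared by several query heads, followed by a tanh soft cap on the logits. The cap is written out explicitly as cap * tanh(x / cap), which is the assumed behaviour of apply_soft_capping from lalamo/modules/utils.py (not shown in this diff).

import jax
from jax import numpy as jnp
from einops import einsum, repeat


def soft_capped_gqa(queries, keys, values, mask, scale, cap):
    # queries: (dst_tokens, heads, head_dim); keys/values: (src_tokens, groups, head_dim)
    _, num_heads, head_dim = queries.shape
    _, num_groups, _ = keys.shape
    if scale is None:
        scale = head_dim**-0.5
    # Grouped-query attention: each KV group serves num_heads // num_groups query heads.
    repeats = num_heads // num_groups
    keys = repeat(keys, "s g d -> s (g r) d", r=repeats)
    values = repeat(values, "s g d -> s (g r) d", r=repeats)
    logits = einsum(queries, keys, "t h d, s h d -> h t s")
    if mask is not None:
        logits = jnp.where(mask, logits, -jnp.inf)
    # Same order as the kernel above: mask, then scale, then the tanh soft cap.
    logits = cap * jnp.tanh(logits * scale / cap)
    weights = jax.nn.softmax(logits, axis=-1)
    return einsum(weights, values, "h t s, s h d -> t h d")


# Example shapes: 5 tokens, 8 query heads sharing 2 key/value groups of width 64.
q = jnp.ones((5, 8, 64))
k = jnp.ones((5, 2, 64))
v = jnp.ones((5, 2, 64))
causal_mask = jnp.tril(jnp.ones((5, 5), dtype=bool))
out = soft_capped_gqa(q, k, v, causal_mask, scale=None, cap=50.0)  # shape (5, 8, 64)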
lalamo/modules/common.py
@@ -0,0 +1,133 @@
+ from abc import abstractmethod
+ from dataclasses import dataclass
+ from enum import Enum
+ from types import UnionType
+
+ import equinox as eqx
+ from cattrs import Converter
+ from jax import numpy as jnp
+ from jaxtyping import DTypeLike
+
+ from lalamo.common import ParameterDict
+
+ __all__ = [
+     "AttentionType",
+     "DummyUnionMember",
+     "LalamoModule",
+     "config_converter",
+     "register_config_union",
+ ]
+
+
+ class WeightLayout(Enum):
+     AUTO = "auto"
+     INPUT_OUTPUT = "input_output"
+     OUTPUT_INPUT = "output_input"
+
+     def __str__(self) -> str:
+         match self:
+             case WeightLayout.AUTO:
+                 return "auto"
+             case WeightLayout.INPUT_OUTPUT:
+                 return "(input, output)"
+             case WeightLayout.OUTPUT_INPUT:
+                 return "(output, input)"
+
+
+ class AttentionType(Enum):
+     GLOBAL = "global"
+     SLIDING_WINDOW = "sliding_window"
+
+
+ class LalamoModule[ConfigT](eqx.Module):
+     config: ConfigT = eqx.field(static=True)
+
+     @property
+     @abstractmethod
+     def activation_precision(self) -> DTypeLike: ...
+
+     @abstractmethod
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict: ...
+
+
+ def _dtype_to_str(dtype: DTypeLike) -> str:
+     if dtype == jnp.bfloat16:
+         return "bfloat16"
+     try:
+         return str(dtype.dtype)  # type: ignore
+     except AttributeError:
+         return str(dtype)
+
+
+ def _str_to_dtype(dtype_str: str) -> jnp.dtype:
+     return {
+         "int4": jnp.int4,
+         "int8": jnp.int8,
+         "int16": jnp.int16,
+         "int32": jnp.int32,
+         "int64": jnp.int64,
+         "bfloat16": jnp.bfloat16,
+         "float16": jnp.float16,
+         "float32": jnp.float32,
+         "float64": jnp.float64,
+     }[dtype_str]
+
+
+ config_converter = Converter()
+
+
+ config_converter.register_unstructure_hook_func(
+     lambda t: t in [jnp.dtype, DTypeLike],
+     _dtype_to_str,
+ )
+
+ config_converter.register_structure_hook_func(
+     lambda t: t in [jnp.dtype, DTypeLike],
+     lambda s, _: _str_to_dtype(s),
+ )
+
+
+ def register_config_union(union_type: UnionType) -> None:
+     union_members = union_type.__args__
+     name_to_type = {m.__name__: m for m in union_members}
+
+     def unstructure(obj: object) -> dict | None:
+         if obj is None:
+             return None
+         return {
+             "type": obj.__class__.__name__,
+             **config_converter.unstructure(obj),
+         }
+
+     config_converter.register_unstructure_hook(
+         union_type,
+         unstructure,
+     )
+
+     config_converter.register_unstructure_hook(
+         union_type | None,
+         unstructure,
+     )
+
+     def structure[T](config: dict | None, _: type[T]) -> T | None:
+         if config is None:
+             return None
+         new_config = dict(config)
+         type_name = new_config.pop("type")
+         target_type = name_to_type[type_name]
+         return name_to_type[type_name](**config_converter.structure(new_config, target_type))
+
+     config_converter.register_structure_hook(
+         union_type,
+         structure,
+     )
+
+     config_converter.register_structure_hook(
+         union_type | None,
+         structure,
+     )
+
+
+ @dataclass
+ class DummyUnionMember:
+     pass
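
As a rough illustration of how the LalamoModule interface above is meant to be implemented, here is a hypothetical minimal module. ScaleConfig and Scale are invented for this sketch and are not part of lalamo, and ParameterDict is assumed to behave like a dict of arrays constructed from keyword arguments, as in the export_weights methods elsewhere in this diff.

from dataclasses import dataclass

from jax import numpy as jnp
from jaxtyping import Array, DTypeLike

from lalamo.common import ParameterDict
from lalamo.modules.common import LalamoModule, WeightLayout


@dataclass(frozen=True)
class ScaleConfig:
    dim: int


class Scale(LalamoModule[ScaleConfig]):
    # A single learnable per-channel scale; the config is stored as a static field.
    weight: Array

    @property
    def activation_precision(self) -> DTypeLike:
        return self.weight.dtype

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        return ParameterDict(weight=self.weight)


module = Scale(config=ScaleConfig(dim=4), weight=jnp.ones(4))
weights = module.export_weights()  # ParameterDict with a single "weight" entry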
lalamo/modules/decoder.py
@@ -0,0 +1,244 @@
+ from dataclasses import dataclass
+
+ import equinox as eqx
+ import jax
+ from jax import vmap
+ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+ from lalamo.common import ParameterDict
+
+ from .common import AttentionType, LalamoModule, WeightLayout
+ from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerResult
+ from .embedding import EmbeddingBase, EmbeddingConfig
+ from .kv_cache import KVCache
+ from .normalization import RMSNorm, RMSNormConfig
+ from .rope import PositionalEmbeddings, RoPE, RoPEConfig
+
+ __all__ = [
+     "Decoder",
+     "DecoderActivationTrace",
+     "DecoderConfig",
+     "DecoderResult",
+ ]
+
+
+ class DecoderActivationTrace(eqx.Module):
+     token_ids: Int[Array, " suffix_tokens"]
+     token_positions: Int[Array, " suffix_tokens"]
+     kv_cache: KVCache | None
+
+     local_positional_embeddings: PositionalEmbeddings
+     global_positional_embeddings: PositionalEmbeddings
+
+     layer_results: tuple[DecoderLayerResult, ...]
+
+     output_norm: Float[Array, "suffix_tokens channels"]
+
+     def export(self) -> ParameterDict:
+         result = ParameterDict(
+             token_ids=self.token_ids,
+             token_positions=self.token_positions,
+             local_positional_embeddings=self.local_positional_embeddings.export(),
+             global_positional_embeddings=self.global_positional_embeddings.export(),
+             layer_results=[layer_result.export() for layer_result in self.layer_results],
+             output_norm=self.output_norm,
+         )
+         if self.kv_cache is not None:
+             result["kv_cache"] = [kv_cache_layer_slice.export() for kv_cache_layer_slice in self.kv_cache]
+         return result
+
+
+ class DecoderResult(eqx.Module):
+     logits: Float[Array, "suffix_tokens channels"]
+     updated_kv_cache: KVCache | None = None
+     activation_trace: DecoderActivationTrace | None = None
+
+     def export(self) -> ParameterDict:
+         result = ParameterDict(
+             logits=self.logits,
+         )
+         if self.updated_kv_cache is not None:
+             result["updated_kv_cache"] = [
+                 kv_cache_layer_slice.export() for kv_cache_layer_slice in self.updated_kv_cache
+             ]
+         if self.activation_trace is not None:
+             result["activation_trace"] = self.activation_trace.export()
+         return result
+
+
+ @dataclass(frozen=True)
+ class DecoderConfig:
+     embedding_config: EmbeddingConfig
+     global_rope_config: RoPEConfig
+     local_rope_config: RoPEConfig | None
+     layer_config: DecoderLayerConfig
+     output_norm_config: RMSNormConfig
+
+     vocab_size: int
+     model_dim: int
+     hidden_dim: int
+     num_heads: int
+     num_groups: int
+     head_dim: int
+     attention_scale: float | None
+     num_layers: int
+     sliding_window_sizes: tuple[int | None, ...] | None
+     context_length: int
+
+     def __post_init__(self) -> None:
+         if self.local_rope_config is not None and self.sliding_window_sizes is None:
+             raise ValueError("Sliding window sizes must be provided when using local RoPE")
+         if self.sliding_window_sizes is None:
+             return
+         if len(self.sliding_window_sizes) != self.num_layers:
+             raise ValueError(
+                 f"Number of sliding window sizes {len(self.sliding_window_sizes)} does not match"
+                 f" the number of layers {self.num_layers}",
+             )
+
+     def random_init(
+         self,
+         *,
+         key: PRNGKeyArray,
+     ) -> "Decoder":
+         embedding_key, layers_key = jax.random.split(key)
+         embedding = self.embedding_config.random_init(
+             vocab_size=self.vocab_size,
+             model_dim=self.model_dim,
+             key=embedding_key,
+         )
+         global_rope = self.global_rope_config.init(
+             head_dim=self.head_dim,
+             num_timesteps=self.context_length,
+         )
+
+         if self.local_rope_config:
+             assert self.sliding_window_sizes is not None
+             max_sliding_window_size = max(
+                 window_size for window_size in self.sliding_window_sizes if window_size is not None
+             )
+             local_rope = self.local_rope_config.init(
+                 head_dim=self.head_dim,
+                 num_timesteps=max(max_sliding_window_size, self.context_length),
+             )
+         else:
+             local_rope = None
+
+         if self.sliding_window_sizes is None:
+             sliding_window_sizes = [None] * self.num_layers
+         else:
+             sliding_window_sizes = self.sliding_window_sizes
+         layers_keys = jax.random.split(layers_key, self.num_layers)
+         layers = tuple(
+             self.layer_config.random_init(
+                 model_dim=self.model_dim,
+                 hidden_dim=self.hidden_dim,
+                 num_heads=self.num_heads,
+                 num_groups=self.num_groups,
+                 head_dim=self.head_dim,
+                 attention_scale=self.attention_scale,
+                 sliding_window_size=sliding_window_size,
+                 key=key,
+             )
+             for sliding_window_size, key in zip(sliding_window_sizes, layers_keys, strict=True)
+         )
+         output_norm = self.output_norm_config.init(self.model_dim)
+         return Decoder(
+             self,
+             embedding=embedding,
+             global_rope=global_rope,
+             local_rope=local_rope,
+             layers=layers,
+             output_norm=output_norm,
+         )
+
+
+ class Decoder(LalamoModule[DecoderConfig]):
+     embedding: EmbeddingBase
+     global_rope: RoPE
+     local_rope: RoPE | None
+     layers: tuple[DecoderLayer, ...]
+     output_norm: RMSNorm
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.embedding.activation_precision
+
+     def __call__(
+         self,
+         token_ids: Int[Array, " suffix_tokens"],
+         token_positions: Int[Array, " suffix_tokens"],
+         kv_cache: KVCache | None = None,
+         return_updated_kv_cache: bool = False,
+         return_activation_trace: bool = False,
+         length_without_padding: Int[Array, ""] | int | None = None,
+     ) -> DecoderResult:
+         maybe_kv_cache = kv_cache or ([None] * len(self.layers))
+         inner_features = self.embedding.embed(token_ids)
+
+         global_positional_embeddings = self.global_rope(token_positions)
+         if self.local_rope is not None:
+             local_positional_embeddings = self.local_rope(token_positions)
+         else:
+             local_positional_embeddings = global_positional_embeddings
+
+         updated_kv_cache_layers = []
+         layer_results = []
+         for layer, kv_cache_slice in zip(self.layers, maybe_kv_cache, strict=True):
+             if layer.attention_type == AttentionType.SLIDING_WINDOW:
+                 positional_embeddings_to_use = local_positional_embeddings
+             else:
+                 positional_embeddings_to_use = global_positional_embeddings
+
+             layer_result = layer(
+                 inner_features,
+                 positional_embeddings_to_use,
+                 kv_cache=kv_cache_slice,
+                 return_updated_kv_cache=return_updated_kv_cache,
+                 return_activation_trace=return_activation_trace,
+                 length_without_padding=length_without_padding,
+             )
+             inner_features = layer_result.outputs
+             layer_results.append(layer_result)
+             updated_kv_cache_layers.append(layer_result.updated_kv_cache)
+
+         normalized_outputs = vmap(self.output_norm, in_axes=0)(inner_features)
+         logits = vmap(self.embedding.readout, in_axes=0)(normalized_outputs)
+
+         if return_activation_trace:
+             activation_trace = DecoderActivationTrace(
+                 token_ids=token_ids,
+                 token_positions=token_positions,
+                 kv_cache=kv_cache,
+                 global_positional_embeddings=global_positional_embeddings,
+                 local_positional_embeddings=local_positional_embeddings,
+                 layer_results=tuple(layer_results),
+                 output_norm=normalized_outputs,
+             )
+         else:
+             activation_trace = None
+
+         if return_updated_kv_cache:
+             updated_kv_cache = KVCache(updated_kv_cache_layers)
+         else:
+             updated_kv_cache = None
+
+         return DecoderResult(
+             logits=logits,
+             updated_kv_cache=updated_kv_cache,
+             activation_trace=activation_trace,
+         )
+
+     def init_static_kv_cache(self, capacity: int) -> KVCache:
+         return KVCache(layer.init_static_kv_cache(capacity) for layer in self.layers)
+
+     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
+         result = ParameterDict(
+             embedding=self.embedding.export_weights(weight_layout),
+             global_rope=self.global_rope.export_weights(weight_layout),
+             layers=[layer.export_weights(weight_layout) for layer in self.layers],
+             output_norm=self.output_norm.export_weights(weight_layout),
+         )
+         if self.local_rope:
+             result["local_rope"] = self.local_rope.export_weights(weight_layout)
+         return result
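
A minimal usage sketch for Decoder.__call__ above, assuming decoder is an already-constructed Decoder (for example from DecoderConfig.random_init or one of the importers under lalamo/model_import); the token ids and positions are made up for illustration.

from jax import numpy as jnp

# Hypothetical prompt; `decoder` is assumed to be an already-built Decoder instance.
token_ids = jnp.array([1, 42, 7, 256])
token_positions = jnp.arange(token_ids.shape[0])

result = decoder(
    token_ids,
    token_positions,
    return_updated_kv_cache=True,  # keep the per-layer KV cache for incremental decoding
)
next_token = jnp.argmax(result.logits[-1])

# On the next step, feed only the new token and its position, reusing the cache.
step = decoder(
    jnp.array([next_token]),
    jnp.array([token_ids.shape[0]]),
    kv_cache=result.updated_kv_cache,
    return_updated_kv_cache=True,
)
# step.logits[-1] now scores the token that follows next_token.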