lalamo 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. lalamo/__init__.py +20 -5
  2. lalamo/data/__init__.py +8 -0
  3. lalamo/data/huggingface_message.py +38 -0
  4. lalamo/data/lalamo_completions.py +43 -0
  5. lalamo/data/utils.py +8 -0
  6. lalamo/language_model.py +152 -69
  7. lalamo/main.py +271 -43
  8. lalamo/message_processor.py +11 -1
  9. lalamo/model_import/common.py +17 -7
  10. lalamo/model_import/decoder_configs/__init__.py +3 -0
  11. lalamo/model_import/decoder_configs/executorch.py +12 -6
  12. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  13. lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  14. lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  15. lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  16. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  17. lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  18. lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  21. lalamo/model_import/huggingface_tokenizer_config.py +1 -4
  22. lalamo/model_import/loaders/executorch.py +10 -9
  23. lalamo/model_import/loaders/huggingface.py +104 -9
  24. lalamo/model_import/loaders/utils.py +92 -0
  25. lalamo/model_import/model_specs/__init__.py +4 -1
  26. lalamo/model_import/model_specs/common.py +15 -12
  27. lalamo/model_import/model_specs/gpt_oss.py +21 -0
  28. lalamo/modules/__init__.py +35 -7
  29. lalamo/modules/activations.py +24 -14
  30. lalamo/modules/attention.py +73 -20
  31. lalamo/modules/common.py +8 -57
  32. lalamo/modules/decoder.py +48 -34
  33. lalamo/modules/decoder_layer.py +57 -43
  34. lalamo/modules/embedding.py +13 -19
  35. lalamo/modules/kv_cache.py +53 -16
  36. lalamo/modules/linear.py +260 -79
  37. lalamo/modules/mlp.py +395 -23
  38. lalamo/modules/normalization.py +2 -3
  39. lalamo/modules/rope.py +32 -21
  40. lalamo/modules/utils.py +10 -0
  41. lalamo/speculator/__init__.py +11 -0
  42. lalamo/speculator/common.py +22 -0
  43. lalamo/speculator/inference.py +75 -0
  44. lalamo/speculator/ngram.py +154 -0
  45. lalamo/speculator/utils.py +52 -0
  46. lalamo/utils.py +27 -0
  47. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
  48. lalamo-0.4.0.dist-info/RECORD +71 -0
  49. lalamo-0.3.3.dist-info/RECORD +0 -59
  50. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
  51. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
  52. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
  53. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/modules/attention.py CHANGED
@@ -9,9 +9,10 @@ from jax import numpy as jnp
  from jax import vmap
  from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray

+ from lalamo.common import dummy_array
  from lalamo.modules.normalization import RMSNorm, RMSNormConfig

- from .common import AttentionType, LalamoModule, ParameterTree, WeightLayout
+ from .common import AttentionType, LalamoModule, ParameterTree
  from .kv_cache import DynamicKVCacheLayer, KVCacheLayer, StaticKVCacheLayer
  from .linear import LinearBase, LinearConfig
  from .rope import PositionalEmbeddings
@@ -44,8 +45,6 @@ def _soft_capped_attention_kernel(
  ) -> Float[Array, "dst_tokens heads head_channels"]:
      _, num_heads, head_dim = queries.shape
      _, num_groups, _ = keys.shape
-     if scale is None:
-         scale = head_dim**-0.5
      group_size = num_heads // num_groups
      keys = _repeat_kv(keys, group_size)
      values = _repeat_kv(values, group_size)
@@ -59,7 +58,11 @@ def _soft_capped_attention_kernel(
      if mask is not None:
          attention_logits = jnp.where(mask, attention_logits, jnp.array(float("-inf"), dtype=attention_logits.dtype))

-     attention_logits = attention_logits * scale
+     if scale is None:
+         scale_val = head_dim**-0.5
+     else:
+         scale_val = float(scale)
+     attention_logits = attention_logits * scale_val
      attention_logits = apply_soft_capping(attention_logits, logit_soft_cap)
      attention_weights = jax.nn.softmax(attention_logits, axis=-1)
      return einsum(
@@ -70,7 +73,7 @@


  class AttentionResult(NamedTuple):
-     outputs: Float[Array, "suffix_tokens channels"]
+     outputs: Float[Array, "*batch suffix_tokens channels"]
      kv_cache: KVCacheLayer | None = None


@@ -83,6 +86,7 @@ class AttentionConfig:
      key_norm_config: RMSNormConfig | None

      logit_soft_cap: float | None
+     has_sinks: bool
      has_qkv_biases: bool
      has_out_biases: bool

@@ -130,12 +134,18 @@ class AttentionConfig:
          else:
              key_norm = None

+         if self.has_sinks:
+             sinks = jnp.zeros((num_heads,), dtype=qkv_projection.activation_precision)
+         else:
+             sinks = None
+
          return Attention(
              self,
              qkv_projection=qkv_projection,
              out_projection=out_projection,
              query_norm=query_norm,
              key_norm=key_norm,
+             sinks=sinks,
              num_heads=num_heads,
              num_groups=num_groups,
              head_dim=head_dim,
@@ -183,12 +193,18 @@ class AttentionConfig:
          else:
              key_norm = None

+         if self.has_sinks:
+             sinks = dummy_array(num_heads, qkv_projection.activation_precision)
+         else:
+             sinks = None
+
          return Attention(
              self,
              qkv_projection=qkv_projection,
              out_projection=out_projection,
              query_norm=query_norm,
              key_norm=key_norm,
+             sinks=sinks,
              num_heads=num_heads,
              num_groups=num_groups,
              head_dim=head_dim,
@@ -205,6 +221,8 @@ class Attention(LalamoModule[AttentionConfig]):
      query_norm: RMSNorm | None
      key_norm: RMSNorm | None

+     sinks: Float[Array, " heads"] | None
+
      num_heads: int = eqx.field(static=True)
      num_groups: int = eqx.field(static=True)
      head_dim: int = eqx.field(static=True)
@@ -234,6 +252,10 @@ class Attention(LalamoModule[AttentionConfig]):
      def attention_type(self) -> AttentionType:
          return AttentionType.SLIDING_WINDOW if self.sliding_window_size is not None else AttentionType.GLOBAL

+     @property
+     def has_sinks(self) -> bool:
+         return self.sinks is not None
+
      def __post_init__(self) -> None:
          if self.qkv_projection.has_biases != self.config.has_qkv_biases:
              raise ValueError(
@@ -285,6 +307,12 @@ class Attention(LalamoModule[AttentionConfig]):
                  f" ({self.num_groups} * {self.head_dim} = {self.num_groups * self.head_dim}),"
                  f" got {v_output_dim}",
              )
+         if self.sinks is not None:
+             (num_sink_heads,) = self.sinks.shape
+             if num_sink_heads != self.num_heads:
+                 raise ValueError(
+                     f"Number of sink heads must be equal to number of heads ({self.num_heads}), got {num_sink_heads}",
+                 )

      @eqx.filter_jit
      def __call__(
@@ -325,12 +353,22 @@ class Attention(LalamoModule[AttentionConfig]):
          keys = apply_positional_embeddings(keys)

          if kv_cache is None:
-             updated_kv_cache = DynamicKVCacheLayer.init(keys, values, length=length_without_padding)
+             updated_kv_cache = DynamicKVCacheLayer.init(self.has_sinks, keys, values, length=length_without_padding)
          else:
              updated_kv_cache = kv_cache.extend(keys, values, added_length=length_without_padding)

          num_suffix_tokens, _, _ = queries.shape
-         mask = updated_kv_cache.attention_mask(num_suffix_tokens, self.is_causal, self.sliding_window_size)
+         mask = updated_kv_cache.attention_mask(
+             num_suffix_tokens,
+             self.is_causal,
+             length_without_padding,
+             self.sliding_window_size,
+         )
+         if self.sinks is not None:
+             sink_bias = jnp.zeros((self.num_heads, *mask.shape), dtype=queries.dtype)
+             sink_bias = sink_bias.at[:, :, 0].set(self.sinks[:, None])
+         else:
+             sink_bias = None

          if self.config.logit_soft_cap is not None:
              attention_output = _soft_capped_attention_kernel(
@@ -346,6 +384,7 @@ class Attention(LalamoModule[AttentionConfig]):
                  queries,
                  updated_kv_cache.keys,
                  updated_kv_cache.values,
+                 bias=sink_bias,
                  mask=mask,
                  scale=self.scale,
              )
@@ -366,41 +405,55 @@ class Attention(LalamoModule[AttentionConfig]):
          )

      def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
-         return StaticKVCacheLayer.empty(capacity, self.num_groups, self.head_dim, self.activation_precision)
-
-     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
-         result = dict(
-             qkv_projection=self.qkv_projection.export_weights(weight_layout),
-             out_projection=self.out_projection.export_weights(weight_layout),
+         return StaticKVCacheLayer.empty(
+             self.has_sinks,
+             capacity,
+             self.num_groups,
+             self.head_dim,
+             self.activation_precision,
          )
+
+     def export_weights(self) -> ParameterTree:
+         result: dict[str, ParameterTree | Array] = {
+             "qkv_projection": self.qkv_projection.export_weights(),
+             "out_projection": self.out_projection.export_weights(),
+         }
          if self.query_norm is not None:
-             result["query_norm"] = self.query_norm.export_weights(weight_layout)
+             result["query_norm"] = self.query_norm.export_weights()
          if self.key_norm is not None:
-             result["key_norm"] = self.key_norm.export_weights(weight_layout)
+             result["key_norm"] = self.key_norm.export_weights()
+         if self.sinks is not None:
+             assert isinstance(self.sinks, Array)
+             result["sinks"] = self.sinks
          return result

      def import_weights(
          self,
          weights: ParameterTree[Array],
-         weight_layout: WeightLayout = WeightLayout.AUTO,
      ) -> Self:
          assert isinstance(weights, Mapping)
          assert isinstance(weights["qkv_projection"], Mapping)
          assert isinstance(weights["out_projection"], Mapping)
          if self.query_norm is not None:
              assert isinstance(weights["query_norm"], Mapping)
-             query_norm = self.query_norm.import_weights(weights["query_norm"], weight_layout)
+             query_norm = self.query_norm.import_weights(weights["query_norm"])
          else:
              query_norm = None
          if self.key_norm is not None:
              assert isinstance(weights["key_norm"], Mapping)
-             key_norm = self.key_norm.import_weights(weights["key_norm"], weight_layout)
+             key_norm = self.key_norm.import_weights(weights["key_norm"])
          else:
              key_norm = None
+         if self.sinks is not None:
+             assert isinstance(weights["sinks"], Array)
+             sinks = weights["sinks"]
+         else:
+             sinks = None
          return replace(
              self,
-             qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"], weight_layout),
-             out_projection=self.out_projection.import_weights(weights["out_projection"], weight_layout),
+             qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"]),
+             out_projection=self.out_projection.import_weights(weights["out_projection"]),
              query_norm=query_norm,
              key_norm=key_norm,
+             sinks=sinks,
          )
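A note on the sink mechanism introduced above: the per-head sinks values are written into an additive sink_bias on the logit for key position 0, which the reworked KV cache appears to reserve when has_sinks is set. Conceptually this gives every head one extra learned "sink" logit that competes in the softmax and can absorb attention mass without contributing to the output. The snippet below is a self-contained sketch of that idea only; softmax_with_sink and its shapes are illustrative and not part of the lalamo API.

import jax
import jax.numpy as jnp

def softmax_with_sink(logits, sinks):
    # logits: (heads, dst_tokens, src_tokens) attention logits for one layer.
    # sinks:  (heads,) learned per-head sink logits.
    heads, dst_tokens, _ = logits.shape
    sink_column = jnp.broadcast_to(sinks[:, None, None], (heads, dst_tokens, 1))
    # The sink competes in the softmax like an extra key position...
    weights = jax.nn.softmax(jnp.concatenate([sink_column, logits], axis=-1), axis=-1)
    # ...but contributes nothing to the output, so the real keys receive weights summing to < 1.
    return weights[..., 1:]

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 5))
weights = softmax_with_sink(logits, jnp.zeros((4,)))
print(weights.shape, float(weights.sum(axis=-1).max()))  # (4, 3, 5), strictly below 1.0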
lalamo/modules/common.py CHANGED
@@ -6,79 +6,31 @@ from typing import Self

  import equinox as eqx
  from cattrs import Converter
- from einops import rearrange
  from jax import numpy as jnp
- from jaxtyping import Array, DTypeLike, Float
+ from jaxtyping import Array, DTypeLike

  from lalamo.common import ParameterTree

  __all__ = [
      "AttentionType",
      "DummyUnionMember",
+     "ForwardPassMode",
      "LalamoModule",
      "config_converter",
-     "from_layout",
-     "into_layout",
      "register_config_union",
  ]


- class WeightLayout(Enum):
-     AUTO = "auto"
-     INPUT_OUTPUT = "input_output"
-     OUTPUT_INPUT = "output_input"
-
-     def __str__(self) -> str:
-         match self:
-             case WeightLayout.AUTO:
-                 return "auto"
-             case WeightLayout.INPUT_OUTPUT:
-                 return "(input, output)"
-             case WeightLayout.OUTPUT_INPUT:
-                 return "(output, input)"
-
-
- _DEFAULT_WEIGHT_LAYOUT = WeightLayout.INPUT_OUTPUT
-
-
- def into_layout(
-     weights: Float[Array, "in_channels out_channels"],
-     layout: WeightLayout,
- ) -> Float[Array, "in_channels out_channels"] | Float[Array, "out_channels in_channels"]:
-     if layout == WeightLayout.AUTO:
-         layout = _DEFAULT_WEIGHT_LAYOUT
-     match layout:
-         case WeightLayout.OUTPUT_INPUT:
-             return weights
-         case WeightLayout.INPUT_OUTPUT:
-             return rearrange(
-                 weights,
-                 "total_out_channels in_channels -> in_channels total_out_channels",
-             )
-
-
- def from_layout(
-     weights: ParameterTree | Array,
-     layout: WeightLayout,
- ) -> Array:
-     assert isinstance(weights, Array)
-     if layout == WeightLayout.AUTO:
-         layout = _DEFAULT_WEIGHT_LAYOUT
-     match layout:
-         case WeightLayout.OUTPUT_INPUT:
-             return weights
-         case WeightLayout.INPUT_OUTPUT:
-             return rearrange(
-                 weights,
-                 "in_channels total_out_channels -> total_out_channels in_channels",
-             )
-
-
  class AttentionType(Enum):
      GLOBAL = "global"
      SLIDING_WINDOW = "sliding_window"


+ class ForwardPassMode(Enum):
+     MULTI_TOKEN = "multi_token"
+     SINGLE_TOKEN = "single_token"
+
+
  class LalamoModule[ConfigT](eqx.Module):
      config: ConfigT = eqx.field(static=True)

@@ -87,13 +39,12 @@ class LalamoModule[ConfigT](eqx.Module):
      def activation_precision(self) -> DTypeLike: ...

      @abstractmethod
-     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree[Array]: ...
+     def export_weights(self) -> ParameterTree[Array]: ...

      @abstractmethod
      def import_weights(
          self,
          weights: ParameterTree[Array],
-         weight_layout: WeightLayout = WeightLayout.AUTO,
      ) -> Self: ...

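The removal of WeightLayout above means export_weights/import_weights no longer take a layout argument and always hand weights over in the modules' stored layout. As a rough illustration of the slimmed-down interface, a hypothetical toy module conforming to the new abstract signatures could look like the sketch below (the PerChannelScale class is invented purely for illustration; only the method signatures come from the diff).

from dataclasses import replace
from typing import Self

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, DTypeLike

class PerChannelScale(eqx.Module):
    # Toy module: one learned scale per channel.
    weight: Array

    @property
    def activation_precision(self) -> DTypeLike:
        return self.weight.dtype

    def export_weights(self) -> dict[str, Array]:
        # New-style export: no weight_layout argument, weights leave as stored.
        return {"weight": self.weight}

    def import_weights(self, weights: dict[str, Array]) -> Self:
        # New-style import: rebuild the module from the tree, again with no layout flag.
        return replace(self, weight=weights["weight"])

    def __call__(self, x: Array) -> Array:
        return x * self.weight

scale = PerChannelScale(weight=jnp.ones((4,)))
restored = scale.import_weights(scale.export_weights())
print(restored(jnp.arange(4.0)))  # [0. 1. 2. 3.]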
lalamo/modules/decoder.py CHANGED
@@ -8,9 +8,10 @@ from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

  from lalamo.common import ParameterTree
+ from lalamo.modules.utils import vmap_twice

- from .common import AttentionType, LalamoModule, WeightLayout
+ from .common import AttentionType, ForwardPassMode, LalamoModule
- from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerResult
+ from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerForwardPassConfig, DecoderLayerResult
  from .embedding import EmbeddingBase, EmbeddingConfig
  from .kv_cache import KVCache
  from .normalization import RMSNorm, RMSNormConfig
@@ -20,13 +21,17 @@
      "Decoder",
      "DecoderActivationTrace",
      "DecoderConfig",
+     "DecoderForwardPassConfig",
      "DecoderResult",
  ]


+ type DecoderForwardPassConfig = DecoderLayerForwardPassConfig
+
+
  class DecoderActivationTrace(eqx.Module):
-     token_ids: Int[Array, " suffix_tokens"]
-     token_positions: Int[Array, " suffix_tokens"]
+     token_ids: Int[Array, "batch suffix_tokens"]
+     token_positions: Int[Array, "batch suffix_tokens"]
      kv_cache: KVCache | None

      local_positional_embeddings: PositionalEmbeddings
@@ -34,7 +39,7 @@


      layer_results: tuple[DecoderLayerResult, ...]
-     output_norm: Float[Array, "suffix_tokens channels"]
+     output_norm: Float[Array, "batch suffix_tokens channels"]

      def export(self) -> ParameterTree:
          result = dict(
@@ -51,7 +56,7 @@ class DecoderActivationTrace(eqx.Module):


  class DecoderResult(eqx.Module):
-     logits: Float[Array, "suffix_tokens channels"]
+     logits: Float[Array, "batch suffix_tokens channels"]
      updated_kv_cache: KVCache | None = None
      activation_trace: DecoderActivationTrace | None = None

@@ -167,13 +172,9 @@
          )

          if self.local_rope_config:
-             assert self.sliding_window_sizes is not None
-             max_sliding_window_size = max(
-                 window_size for window_size in self.sliding_window_sizes if window_size is not None
-             )
              local_rope = self.local_rope_config.init(
                  head_dim=self.head_dim,
-                 num_timesteps=max(max_sliding_window_size, self.context_length),
+                 num_timesteps=self.context_length,
              )
          else:
              local_rope = None
@@ -219,19 +220,31 @@ class Decoder(LalamoModule[DecoderConfig]):
      @eqx.filter_jit
      def __call__(
          self,
-         token_ids: Int[Array, " suffix_tokens"],
-         token_positions: Int[Array, " suffix_tokens"],
+         token_ids: Int[Array, "batch suffix_tokens"],
+         token_positions: Int[Array, "batch suffix_tokens"],
          kv_cache: KVCache | None = None,
          return_updated_kv_cache: bool = False,
          return_activation_trace: bool = False,
-         length_without_padding: Int[Array, ""] | int | None = None,
+         lengths_without_padding: Int[Array, " batch"] | None = None,
+         forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+         forward_pass_config: DecoderForwardPassConfig | None = None,
      ) -> DecoderResult:
+         if token_ids.ndim != 2:
+             raise ValueError(
+                 f"token_ids must be a 2D arrays of size (batch_size, sequence_length), got {token_ids.shape}",
+             )
+         if token_positions.ndim != 2:
+             raise ValueError(
+                 "token_positions must be a 2D arrays of size (batch_size, sequence_length),"
+                 f" got {token_positions.shape}",
+             )
+
          maybe_kv_cache = kv_cache or ([None] * len(self.layers))
-         inner_features = self.embedding.embed(token_ids)
+         inner_features = vmap(self.embedding.embed)(token_ids)

-         global_positional_embeddings = self.global_rope(token_positions)
+         global_positional_embeddings = vmap(self.global_rope)(token_positions)
          if self.local_rope is not None:
-             local_positional_embeddings = self.local_rope(token_positions)
+             local_positional_embeddings = vmap(self.local_rope)(token_positions)
          else:
              local_positional_embeddings = global_positional_embeddings

@@ -249,14 +262,16 @@
                  kv_cache=kv_cache_slice,
                  return_updated_kv_cache=return_updated_kv_cache,
                  return_activation_trace=return_activation_trace,
-                 length_without_padding=length_without_padding,
+                 lengths_without_padding=lengths_without_padding,
+                 forward_pass_mode=forward_pass_mode,
+                 forward_pass_config=forward_pass_config,
              )
              inner_features = layer_result.outputs
              layer_results.append(layer_result)
              updated_kv_cache_layers.append(layer_result.updated_kv_cache)

-         normalized_outputs = vmap(self.output_norm, in_axes=0)(inner_features)
-         logits = vmap(self.embedding.readout, in_axes=0)(normalized_outputs)
+         normalized_outputs = vmap_twice(self.output_norm)(inner_features)
+         logits = vmap_twice(self.embedding.readout)(normalized_outputs)

          if return_activation_trace:
              activation_trace = DecoderActivationTrace(
@@ -282,24 +297,23 @@ class Decoder(LalamoModule[DecoderConfig]):
                  activation_trace=activation_trace,
              )

-     def init_static_kv_cache(self, capacity: int) -> KVCache:
-         return KVCache(layer.init_static_kv_cache(capacity) for layer in self.layers)
+     def init_static_kv_cache(self, batch_size: int, capacity: int) -> KVCache:
+         return KVCache(layer.init_static_kv_cache(batch_size, capacity) for layer in self.layers)

-     def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+     def export_weights(self) -> ParameterTree:
          result = dict(
-             embedding=self.embedding.export_weights(weight_layout),
-             global_rope=self.global_rope.export_weights(weight_layout),
-             layers=[layer.export_weights(weight_layout) for layer in self.layers],
-             output_norm=self.output_norm.export_weights(weight_layout),
+             embedding=self.embedding.export_weights(),
+             global_rope=self.global_rope.export_weights(),
+             layers=[layer.export_weights() for layer in self.layers],
+             output_norm=self.output_norm.export_weights(),
          )
          if self.local_rope:
-             result["local_rope"] = self.local_rope.export_weights(weight_layout)
+             result["local_rope"] = self.local_rope.export_weights()
          return result

      def import_weights(
          self,
          weights: ParameterTree[Array],
-         weight_layout: WeightLayout = WeightLayout.AUTO,
      ) -> Self:
          assert isinstance(weights, Mapping)
          assert isinstance(weights["embedding"], Mapping)
@@ -308,19 +322,19 @@ class Decoder(LalamoModule[DecoderConfig]):
          assert isinstance(weights["output_norm"], Mapping)
          if self.local_rope:
              assert isinstance(weights["local_rope"], Mapping)
-             local_rope = self.local_rope.import_weights(weights["local_rope"], weight_layout)
+             local_rope = self.local_rope.import_weights(weights["local_rope"])
          else:
              local_rope = None

          layers = []
          for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
              assert isinstance(layer_weights, Mapping)
-             layers.append(layer.import_weights(layer_weights, weight_layout))
+             layers.append(layer.import_weights(layer_weights))
          return replace(
              self,
-             embedding=self.embedding.import_weights(weights["embedding"], weight_layout),
-             global_rope=self.global_rope.import_weights(weights["global_rope"], weight_layout),
+             embedding=self.embedding.import_weights(weights["embedding"]),
+             global_rope=self.global_rope.import_weights(weights["global_rope"]),
              layers=tuple(layers),
-             output_norm=self.output_norm.import_weights(weights["output_norm"], weight_layout),
+             output_norm=self.output_norm.import_weights(weights["output_norm"]),
              local_rope=local_rope,
          )
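The batched decoder path above relies on vmap_twice from the new lalamo/modules/utils.py, whose body is not shown in this diff. Judging only from how it is used (applying per-token modules such as output_norm and embedding.readout to arrays shaped (batch, suffix_tokens, channels)), a plausible minimal implementation is a double jax.vmap; the sketch below is an assumption, not the packaged code.

from collections.abc import Callable

import jax
import jax.numpy as jnp

def vmap_twice(fn: Callable) -> Callable:
    # Map a per-item function over the two leading axes (batch, tokens).
    return jax.vmap(jax.vmap(fn))

# Usage: a per-token function written for a (channels,) vector,
# applied to a (batch, tokens, channels) activation tensor.
def rms_normalize(x: jnp.ndarray) -> jnp.ndarray:
    return x * jax.lax.rsqrt(jnp.mean(x * x) + 1e-6)

activations = jnp.ones((2, 3, 8))
print(vmap_twice(rms_normalize)(activations).shape)  # (2, 3, 8)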