lalamo 0.2.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (52)
  1. lalamo/__init__.py +1 -1
  2. lalamo/common.py +79 -29
  3. lalamo/language_model.py +106 -83
  4. lalamo/main.py +91 -18
  5. lalamo/message_processor.py +170 -0
  6. lalamo/model_import/common.py +159 -43
  7. lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
  8. lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
  9. lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
  10. lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
  11. lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
  12. lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
  13. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
  14. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
  15. lalamo/model_import/huggingface_generation_config.py +44 -0
  16. lalamo/model_import/huggingface_tokenizer_config.py +85 -0
  17. lalamo/model_import/loaders/common.py +2 -1
  18. lalamo/model_import/loaders/huggingface.py +12 -10
  19. lalamo/model_import/model_specs/__init__.py +3 -2
  20. lalamo/model_import/model_specs/common.py +31 -32
  21. lalamo/model_import/model_specs/deepseek.py +1 -10
  22. lalamo/model_import/model_specs/gemma.py +2 -25
  23. lalamo/model_import/model_specs/huggingface.py +2 -12
  24. lalamo/model_import/model_specs/llama.py +2 -58
  25. lalamo/model_import/model_specs/mistral.py +9 -19
  26. lalamo/model_import/model_specs/pleias.py +3 -13
  27. lalamo/model_import/model_specs/polaris.py +5 -7
  28. lalamo/model_import/model_specs/qwen.py +12 -111
  29. lalamo/model_import/model_specs/reka.py +4 -13
  30. lalamo/modules/__init__.py +2 -1
  31. lalamo/modules/attention.py +90 -10
  32. lalamo/modules/common.py +51 -4
  33. lalamo/modules/decoder.py +90 -8
  34. lalamo/modules/decoder_layer.py +85 -8
  35. lalamo/modules/embedding.py +95 -29
  36. lalamo/modules/kv_cache.py +3 -3
  37. lalamo/modules/linear.py +170 -130
  38. lalamo/modules/mlp.py +40 -7
  39. lalamo/modules/normalization.py +24 -6
  40. lalamo/modules/rope.py +24 -6
  41. lalamo/sampling.py +99 -0
  42. lalamo/utils.py +86 -1
  43. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/METADATA +6 -6
  44. lalamo-0.3.1.dist-info/RECORD +58 -0
  45. lalamo-0.2.7.dist-info/RECORD +0 -54
  46. /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
  47. /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
  48. /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
  49. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/WHEEL +0 -0
  50. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/entry_points.txt +0 -0
  51. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/licenses/LICENSE +0 -0
  52. {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder.py CHANGED
@@ -1,11 +1,13 @@
-from dataclasses import dataclass
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass, replace
+from typing import Self
 
 import equinox as eqx
 import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree
 
 from .common import AttentionType, LalamoModule, WeightLayout
 from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerResult
@@ -34,8 +36,8 @@ class DecoderActivationTrace(eqx.Module):
 
     output_norm: Float[Array, "suffix_tokens channels"]
 
-    def export(self) -> ParameterDict:
-        result = ParameterDict(
+    def export(self) -> ParameterTree:
+        result = dict(
             token_ids=self.token_ids,
             token_positions=self.token_positions,
             local_positional_embeddings=self.local_positional_embeddings.export(),
@@ -53,8 +55,8 @@ class DecoderResult(eqx.Module):
     updated_kv_cache: KVCache | None = None
     activation_trace: DecoderActivationTrace | None = None
 
-    def export(self) -> ParameterDict:
-        result = ParameterDict(
+    def export(self) -> ParameterTree:
+        result: dict[str, ParameterTree | Array] = dict(
             logits=self.logits,
         )
         if self.updated_kv_cache is not None:
@@ -152,6 +154,56 @@ class DecoderConfig:
             output_norm=output_norm,
         )
 
+    def empty(
+        self,
+    ) -> "Decoder":
+        embedding = self.embedding_config.empty(
+            vocab_size=self.vocab_size,
+            model_dim=self.model_dim,
+        )
+        global_rope = self.global_rope_config.init(
+            head_dim=self.head_dim,
+            num_timesteps=self.context_length,
+        )
+
+        if self.local_rope_config:
+            assert self.sliding_window_sizes is not None
+            max_sliding_window_size = max(
+                window_size for window_size in self.sliding_window_sizes if window_size is not None
+            )
+            local_rope = self.local_rope_config.init(
+                head_dim=self.head_dim,
+                num_timesteps=max(max_sliding_window_size, self.context_length),
+            )
+        else:
+            local_rope = None
+
+        if self.sliding_window_sizes is None:
+            sliding_window_sizes = [None] * self.num_layers
+        else:
+            sliding_window_sizes = self.sliding_window_sizes
+        layers = tuple(
+            self.layer_config.empty(
+                model_dim=self.model_dim,
+                hidden_dim=self.hidden_dim,
+                num_heads=self.num_heads,
+                num_groups=self.num_groups,
+                head_dim=self.head_dim,
+                attention_scale=self.attention_scale,
+                sliding_window_size=sliding_window_size,
+            )
+            for sliding_window_size in sliding_window_sizes
+        )
+        output_norm = self.output_norm_config.empty(self.model_dim)
+        return Decoder(
+            self,
+            embedding=embedding,
+            global_rope=global_rope,
+            local_rope=local_rope,
+            layers=layers,
+            output_norm=output_norm,
+        )
+
 
 class Decoder(LalamoModule[DecoderConfig]):
     embedding: EmbeddingBase
@@ -164,6 +216,7 @@ class Decoder(LalamoModule[DecoderConfig]):
     def activation_precision(self) -> DTypeLike:
         return self.embedding.activation_precision
 
+    @eqx.filter_jit
     def __call__(
         self,
         token_ids: Int[Array, " suffix_tokens"],
@@ -232,8 +285,8 @@ class Decoder(LalamoModule[DecoderConfig]):
    def init_static_kv_cache(self, capacity: int) -> KVCache:
         return KVCache(layer.init_static_kv_cache(capacity) for layer in self.layers)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        result = ParameterDict(
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        result = dict(
             embedding=self.embedding.export_weights(weight_layout),
             global_rope=self.global_rope.export_weights(weight_layout),
             layers=[layer.export_weights(weight_layout) for layer in self.layers],
@@ -242,3 +295,32 @@ class Decoder(LalamoModule[DecoderConfig]):
         if self.local_rope:
             result["local_rope"] = self.local_rope.export_weights(weight_layout)
         return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["embedding"], Mapping)
+        assert isinstance(weights["global_rope"], Mapping)
+        assert isinstance(weights["layers"], Sequence)
+        assert isinstance(weights["output_norm"], Mapping)
+        if self.local_rope:
+            assert isinstance(weights["local_rope"], Mapping)
+            local_rope = self.local_rope.import_weights(weights["local_rope"], weight_layout)
+        else:
+            local_rope = None
+
+        layers = []
+        for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
+            assert isinstance(layer_weights, Mapping)
+            layers.append(layer.import_weights(layer_weights, weight_layout))
+        return replace(
+            self,
+            embedding=self.embedding.import_weights(weights["embedding"], weight_layout),
+            global_rope=self.global_rope.import_weights(weights["global_rope"], weight_layout),
+            layers=tuple(layers),
+            output_norm=self.output_norm.import_weights(weights["output_norm"], weight_layout),
+            local_rope=local_rope,
+        )
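DecoderConfig.empty() builds a Decoder whose parameters are placeholder arrays, and Decoder.import_weights() is the inverse of export_weights(). A minimal usage sketch of the round trip, where config is an existing DecoderConfig and trained is a Decoder holding real parameters (both variable names are illustrative, not part of the package):

    skeleton = config.empty()                     # Decoder whose parameter arrays are dummy placeholders
    exported = trained.export_weights()           # ParameterTree: nested dicts/lists of jax Arrays
    restored = skeleton.import_weights(exported)  # same module structure, now carrying the exported arrays

Both export_weights and import_weights default to WeightLayout.AUTO, so the layout argument can be omitted when the exported tree is consumed by the same module class that produced it.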
lalamo/modules/decoder_layer.py CHANGED
@@ -1,11 +1,13 @@
-from dataclasses import dataclass
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import Self
 
 import equinox as eqx
 import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree
 
 from .attention import Attention, AttentionConfig
 from .common import AttentionType, LalamoModule, WeightLayout
@@ -35,8 +37,8 @@ class DecoderLayerActivationTrace(eqx.Module):
     mlp: Float[Array, "suffix_tokens channels"]
     post_mlp_norm: Float[Array, "suffix_tokens channels"] | None
 
-    def export(self) -> ParameterDict:
-        result = ParameterDict(
+    def export(self) -> ParameterTree:
+        result = dict(
             inputs=self.inputs,
             positional_embeddings=self.positional_embeddings.export(),
             mlp_inputs=self.mlp_inputs,
@@ -59,8 +61,8 @@ class DecoderLayerResult(eqx.Module):
     updated_kv_cache: KVCacheLayer | None
     activation_trace: DecoderLayerActivationTrace | None
 
-    def export(self) -> ParameterDict:
-        result = ParameterDict(
+    def export(self) -> ParameterTree:
+        result: dict[str, ParameterTree | Array] = dict(
             outputs=self.outputs,
         )
         if self.updated_kv_cache is not None:
@@ -123,6 +125,46 @@ class DecoderLayerConfig:
             post_mlp_norm=post_mlp_norm,
         )
 
+    def empty(
+        self,
+        model_dim: int,
+        hidden_dim: int,
+        num_heads: int,
+        num_groups: int,
+        head_dim: int,
+        attention_scale: float | None,
+        sliding_window_size: int | None,
+    ) -> "DecoderLayer":
+        pre_attention_norm = self.pre_attention_norm_config.empty(model_dim)
+        attention = self.attention_config.empty(
+            model_dim=model_dim,
+            num_heads=num_heads,
+            num_groups=num_groups,
+            head_dim=head_dim,
+            is_causal=True,
+            scale=attention_scale,
+            sliding_window_size=sliding_window_size,
+        )
+        if self.post_attention_norm_config is not None:
+            post_attention_norm = self.post_attention_norm_config.empty(model_dim)
+        else:
+            post_attention_norm = None
+        pre_mlp_norm = self.pre_mlp_norm_config.empty(model_dim)
+        mlp = self.mlp_config.empty(model_dim, hidden_dim)
+        if self.post_mlp_norm_config is not None:
+            post_mlp_norm = self.post_mlp_norm_config.empty(model_dim)
+        else:
+            post_mlp_norm = None
+        return DecoderLayer(
+            config=self,
+            pre_attention_norm=pre_attention_norm,
+            attention=attention,
+            post_attention_norm=post_attention_norm,
+            pre_mlp_norm=pre_mlp_norm,
+            mlp=mlp,
+            post_mlp_norm=post_mlp_norm,
+        )
+
 
 class DecoderLayer(LalamoModule[DecoderLayerConfig]):
     pre_attention_norm: RMSNorm
@@ -168,6 +210,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
                 f" the up projection dim {self.mlp.hidden_dim}",
             )
 
+    @eqx.filter_jit
     def __call__(
         self,
         inputs: Float[Array, "suffix_tokens channels"],
@@ -226,8 +269,8 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
     def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
         return self.attention.init_static_kv_cache(capacity)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        result = ParameterDict(
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        result = dict(
             pre_attention_norm=self.pre_attention_norm.export_weights(weight_layout),
             attention=self.attention.export_weights(weight_layout),
             pre_mlp_norm=self.pre_mlp_norm.export_weights(weight_layout),
@@ -238,3 +281,37 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         if self.post_mlp_norm is not None:
             result["post_mlp_norm"] = self.post_mlp_norm.export_weights(weight_layout)
         return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["pre_attention_norm"], Mapping)
+        assert isinstance(weights["attention"], Mapping)
+        assert isinstance(weights["mlp"], Mapping)
+        assert isinstance(weights["pre_mlp_norm"], Mapping)
+
+        if self.post_attention_norm is not None:
+            assert isinstance(weights["post_attention_norm"], Mapping)
+            post_attention_norm = self.post_attention_norm.import_weights(
+                weights["post_attention_norm"],
+                weight_layout,
+            )
+        else:
+            post_attention_norm = None
+        if self.post_mlp_norm is not None:
+            assert isinstance(weights["post_mlp_norm"], Mapping)
+            post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"], weight_layout)
+        else:
+            post_mlp_norm = None
+        return replace(
+            self,
+            pre_attention_norm=self.pre_attention_norm.import_weights(weights["pre_attention_norm"], weight_layout),
+            attention=self.attention.import_weights(weights["attention"], weight_layout),
+            post_attention_norm=post_attention_norm,
+            pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"], weight_layout),
+            mlp=self.mlp.import_weights(weights["mlp"], weight_layout),
+            post_mlp_norm=post_mlp_norm,
+        )
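With ParameterDict gone, exported weights are plain nested dicts and lists (the ParameterTree alias from lalamo.common). For a full decoder the tree has the shape sketched below; decoder is an illustrative variable, and the contents of the attention, MLP, and norm sub-trees come from their own export_weights implementations, which change in other files of this release:

    exported = decoder.export_weights()
    # Top-level keys written by Decoder.export_weights:
    #   exported["embedding"]    - embedding sub-tree (e.g. {"weights": ...} for a tied embedding)
    #   exported["global_rope"]  - global RoPE sub-tree
    #   exported["local_rope"]   - present only when the model uses a local RoPE
    #   exported["layers"]       - list with one sub-tree per DecoderLayer
    #   exported["output_norm"]  - output norm sub-tree
    # Each entry of exported["layers"] carries "pre_attention_norm", "attention",
    # "pre_mlp_norm", and "mlp", plus "post_attention_norm"/"post_mlp_norm" when the
    # layer has those norms. Decoder.import_weights and DecoderLayer.import_weights
    # expect exactly this shape.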
lalamo/modules/embedding.py CHANGED
@@ -1,15 +1,23 @@
 from abc import abstractmethod
-from dataclasses import dataclass
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import Self
 
+import equinox as eqx
 import jax
 import jax.numpy as jnp
-from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree, dummy_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 
-from .common import LalamoModule, WeightLayout, register_config_union
+from .common import (
+    LalamoModule,
+    WeightLayout,
+    from_layout,
+    into_layout,
+    register_config_union,
+)
 from .utils import apply_soft_capping
 
 __all__ = [
@@ -38,6 +46,13 @@ class EmbeddingConfigBase:
         key: PRNGKeyArray,
     ) -> "EmbeddingBase": ...
 
+    @abstractmethod
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "EmbeddingBase": ...
+
 
 class EmbeddingBase[ConfigT: EmbeddingConfigBase](LalamoModule[ConfigT]):
     @abstractmethod
@@ -54,16 +69,14 @@ class EmbeddingBase[ConfigT: EmbeddingConfigBase](LalamoModule[ConfigT]):
     @abstractmethod
     def model_dim(self) -> int: ...
 
-    @classmethod
-    def _default_weight_layout(cls) -> WeightLayout:
-        return WeightLayout.INPUT_OUTPUT
-
+    @eqx.filter_jit
     def embed(self, x: Int[Array, " tokens"]) -> Float[Array, "tokens channels"]:
         result = self._prepare_input_weights()[x]
         if self.config.input_scale is not None:
             result = result * jnp.array(self.config.input_scale, dtype=result.dtype)
         return result
 
+    @eqx.filter_jit
    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
         logits = self._prepare_output_weights() @ x
         if self.config.logits_soft_cap is not None:
@@ -85,6 +98,14 @@ class TiedEmbeddingConfig(EmbeddingConfigBase):
         weights = jax.random.normal(key, (vocab_size, model_dim), dtype=self.precision)
         return TiedEmbedding(config=self, weights=weights)
 
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "TiedEmbedding":
+        weights = dummy_array((vocab_size, model_dim), dtype=self.precision)
+        return TiedEmbedding(config=self, weights=weights)
+
 
 class TiedEmbedding(EmbeddingBase[TiedEmbeddingConfig]):
     weights: Float[Array, "vocabulary channels"]
@@ -115,8 +136,16 @@ class TiedEmbedding(EmbeddingBase[TiedEmbeddingConfig]):
     def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
         return self.weights
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
-        return ParameterDict(weights=self.weights)
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+        return {"weights": self.weights}
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        return replace(self, weights=weights["weights"])
 
 
 @dataclass(frozen=True)
@@ -139,6 +168,19 @@ class UntiedEmbeddingConfig(EmbeddingConfigBase):
             output_weights=output_weights,
         )
 
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "UntiedEmbedding":
+        input_weights = dummy_array((vocab_size, model_dim), dtype=self.precision)
+        output_weights = dummy_array((vocab_size, model_dim), dtype=self.precision)
+        return UntiedEmbedding(
+            config=self,
+            input_weights=input_weights,
+            output_weights=output_weights,
+        )
+
 
 class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
     input_weights: Float[Array, "vocabulary channels"]
@@ -186,21 +228,22 @@ class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
     def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
         return self.output_weights
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
-        if weight_layout == WeightLayout.AUTO:
-            weight_layout = self._default_weight_layout()
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        return {
+            "input_weights": self.input_weights,
+            "output_weights": into_layout(self.output_weights, weight_layout),
+        }
 
-        match weight_layout:
-            case WeightLayout.OUTPUT_INPUT:
-                output_weights = self.output_weights
-            case WeightLayout.INPUT_OUTPUT:
-                output_weights = rearrange(self.output_weights, "token_ids channels -> channels token_ids")
-            case _:
-                raise ValueError(f"Unsupported weight layout: {weight_layout}")
-
-        return ParameterDict(
-            input_weights=self.input_weights,
-            output_weights=output_weights,
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        return replace(
+            self,
+            input_weights=weights["input_weights"],
+            output_weights=from_layout(weights["output_weights"], weight_layout),
         )
 
 
@@ -225,6 +268,15 @@ class QuantizedTiedEmbeddingConfig(EmbeddingConfigBase):
         weights = quantize_weights(weights * min_abs_val, self.embedding_quantization_mode)
         return QuantizedTiedEmbedding(config=self, weights=weights, scales=scales)
 
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "QuantizedTiedEmbedding":
+        scales = dummy_array(vocab_size, dtype=self.activation_precision)
+        weights = dummy_array((vocab_size, model_dim), dtype=self.activation_precision)
+        return QuantizedTiedEmbedding(config=self, weights=weights, scales=scales)
+
 
 class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
     weights: Float[Array, "vocabulary channels"]
@@ -257,7 +309,7 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
                 f" {self.config.activation_precision}"
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        weights_vocab_size, weights_model_dim = self.weights.shape
+        weights_vocab_size, _ = self.weights.shape
         (scales_vocab_size,) = self.scales.shape
         if weights_vocab_size != scales_vocab_size:
             raise ValueError(
@@ -281,15 +333,29 @@
     def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
         return self._prepare_weights()
 
+    @eqx.filter_jit
     def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
         if self.config.activation_quantization_mode is not None:
             x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
         return super().readout(x)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
-        return ParameterDict(
-            weights=self.int_weights,
-            scales=self.scales,
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        return {
+            "weights": into_layout(self.int_weights, weight_layout),
+            "scales": into_layout(self.scales, weight_layout),
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        return replace(
+            self,
+            weights=from_layout(weights["weights"].astype(self.weights.dtype), weight_layout),
+            scales=from_layout(weights["scales"], weight_layout),
        )
 
 
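The layout handling that UntiedEmbedding.export_weights used to spell out with einops.rearrange is now delegated to into_layout and from_layout from lalamo.modules.common (changed elsewhere in this release; its diff is not shown here). Judging only from the branch deleted above, converting a 2-D weight stored as (vocabulary, channels) into a requested layout amounts to an optional transpose. A rough, hypothetical sketch of that behaviour, where the function name, the AUTO handling, and the treatment of 1-D arrays such as scales are assumptions rather than the package's actual implementation:

    import jax.numpy as jnp
    from jaxtyping import Array

    from lalamo.modules.common import WeightLayout


    def into_layout_sketch(weights: Array, layout: WeightLayout) -> Array:
        # Mirrors the match statement removed from UntiedEmbedding.export_weights:
        # weights are stored output-major, and INPUT_OUTPUT asks for the transposed view.
        if layout == WeightLayout.OUTPUT_INPUT:
            return weights
        if layout == WeightLayout.INPUT_OUTPUT:
            return jnp.transpose(weights)
        # WeightLayout.AUTO resolution now happens inside lalamo/modules/common.py and is
        # not reproduced in this sketch.
        raise ValueError(f"Unsupported weight layout: {layout}")

from_layout presumably applies the inverse conversion on import, which is why import_weights can accept trees exported in either layout.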
lalamo/modules/kv_cache.py CHANGED
@@ -7,7 +7,7 @@ from jax.lax import dynamic_update_slice_in_dim
 from jax.tree_util import register_pytree_node_class
 from jaxtyping import Array, Bool, DTypeLike, Float, Int
 
-from lalamo.common import ParameterDict
+from lalamo.common import ParameterTree
 
 __all__ = ["DynamicKVCacheLayer", "KVCache", "KVCacheLayer", "StaticKVCacheLayer"]
 
@@ -43,8 +43,8 @@ class KVCacheLayer(eqx.Module):
         added_length: Int[Array, ""] | int | None = None,
     ) -> Self: ...
 
-    def export(self) -> ParameterDict:
-        return ParameterDict(
+    def export(self) -> ParameterTree:
+        return dict(
            keys=self.keys,
            values=self.values,
        )