lalamo 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +24 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +100 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/METADATA +1 -1
  44. lalamo-0.2.3.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder_layer.py
@@ -0,0 +1,240 @@
+from dataclasses import dataclass
+
+import equinox as eqx
+import jax
+from jax import vmap
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.common import ParameterDict
+
+from .attention import Attention, AttentionConfig
+from .common import AttentionType, LalamoModule, WeightLayout
+from .kv_cache import KVCacheLayer, StaticKVCacheLayer
+from .mlp import MLP, MLPConfig
+from .normalization import RMSNorm, RMSNormConfig
+from .rope import PositionalEmbeddings
+
+__all__ = [
+    "DecoderLayer",
+    "DecoderLayerActivationTrace",
+    "DecoderLayerConfig",
+    "DecoderLayerResult",
+]
+
+
+class DecoderLayerActivationTrace(eqx.Module):
+    inputs: Float[Array, "suffix_tokens channels"]
+    positional_embeddings: PositionalEmbeddings
+    kv_cache: KVCacheLayer | None
+
+    mlp_inputs: Float[Array, "suffix_tokens channels"]
+    pre_attention_norm: Float[Array, "suffix_tokens channels"]
+    attention: Float[Array, "suffix_tokens channels"]
+    post_attention_norm: Float[Array, "suffix_tokens channels"] | None
+    pre_mlp_norm: Float[Array, "suffix_tokens channels"]
+    mlp: Float[Array, "suffix_tokens channels"]
+    post_mlp_norm: Float[Array, "suffix_tokens channels"] | None
+
+    def export(self) -> ParameterDict:
+        result = ParameterDict(
+            inputs=self.inputs,
+            positional_embeddings=self.positional_embeddings.export(),
+            mlp_inputs=self.mlp_inputs,
+            pre_attention_norm=self.pre_attention_norm,
+            attention=self.attention,
+            pre_mlp_norm=self.pre_mlp_norm,
+            mlp=self.mlp,
+        )
+        if self.kv_cache is not None:
+            result["kv_cache"] = self.kv_cache.export()
+        if self.post_attention_norm is not None:
+            result["post_attention_norm"] = self.post_attention_norm
+        if self.post_mlp_norm is not None:
+            result["post_mlp_norm"] = self.post_mlp_norm
+        return result
+
+
+class DecoderLayerResult(eqx.Module):
+    outputs: Float[Array, "suffix_tokens channels"]
+    updated_kv_cache: KVCacheLayer | None
+    activation_trace: DecoderLayerActivationTrace | None
+
+    def export(self) -> ParameterDict:
+        result = ParameterDict(
+            outputs=self.outputs,
+        )
+        if self.updated_kv_cache is not None:
+            result["updated_kv_cache"] = self.updated_kv_cache.export()
+        if self.activation_trace is not None:
+            result["activation_trace"] = self.activation_trace.export()
+        return result
+
+
+@dataclass(frozen=True)
+class DecoderLayerConfig:
+    pre_attention_norm_config: RMSNormConfig
+    attention_config: AttentionConfig
+    post_attention_norm_config: RMSNormConfig | None
+    pre_mlp_norm_config: RMSNormConfig
+    mlp_config: MLPConfig
+    post_mlp_norm_config: RMSNormConfig | None
+
+    def random_init(
+        self,
+        model_dim: int,
+        hidden_dim: int,
+        num_heads: int,
+        num_groups: int,
+        head_dim: int,
+        attention_scale: float | None,
+        sliding_window_size: int | None,
+        *,
+        key: PRNGKeyArray,
+    ) -> "DecoderLayer":
+        attention_key, mlp_key = jax.random.split(key)
+        pre_attention_norm = self.pre_attention_norm_config.init(model_dim)
+        attention = self.attention_config.random_init(
+            model_dim=model_dim,
+            num_heads=num_heads,
+            num_groups=num_groups,
+            head_dim=head_dim,
+            is_causal=True,
+            scale=attention_scale,
+            sliding_window_size=sliding_window_size,
+            key=attention_key,
+        )
+        if self.post_attention_norm_config is not None:
+            post_attention_norm = self.post_attention_norm_config.init(model_dim)
+        else:
+            post_attention_norm = None
+        pre_mlp_norm = self.pre_mlp_norm_config.init(model_dim)
+        mlp = self.mlp_config.random_init(model_dim, hidden_dim, key=mlp_key)
+        if self.post_mlp_norm_config is not None:
+            post_mlp_norm = self.post_mlp_norm_config.init(model_dim)
+        else:
+            post_mlp_norm = None
+        return DecoderLayer(
+            config=self,
+            pre_attention_norm=pre_attention_norm,
+            attention=attention,
+            post_attention_norm=post_attention_norm,
+            pre_mlp_norm=pre_mlp_norm,
+            mlp=mlp,
+            post_mlp_norm=post_mlp_norm,
+        )
+
+
+class DecoderLayer(LalamoModule[DecoderLayerConfig]):
+    pre_attention_norm: RMSNorm
+    attention: Attention
+    post_attention_norm: RMSNorm | None
+    pre_mlp_norm: RMSNorm
+    mlp: MLP
+    post_mlp_norm: RMSNorm | None
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.attention.activation_precision
+
+    @property
+    def attention_type(self) -> AttentionType:
+        return self.attention.attention_type
+
+    def __post_init__(self) -> None:
+        model_dim = self.pre_attention_norm.input_dim
+        if self.attention.model_dim != model_dim:
+            raise ValueError(
+                f"Attention model dim {self.attention.model_dim} does not match"
+                f" the first normalization layer dim {model_dim}",
+            )
+        if self.post_attention_norm is not None and self.post_attention_norm.input_dim != model_dim:
+            raise ValueError(
+                f"Post attention normalization dim {self.post_attention_norm.input_dim} does not match"
+                f" the first normalization layer dim {model_dim}",
+            )
+        if self.pre_mlp_norm.input_dim != model_dim:
+            raise ValueError(
+                f"Pre MLP normalization dim {self.pre_mlp_norm.input_dim} does not match"
+                f" the first normalization layer dim {model_dim}",
+            )
+        if self.mlp.model_dim != model_dim:
+            raise ValueError(
+                f"MLP up projection dim {self.mlp.up_projection.input_dim} does not match"
+                f" the first normalization layer dim {model_dim}",
+            )
+        if self.mlp.hidden_dim != self.mlp.down_projection.input_dim:
+            raise ValueError(
+                f"MLP down projection dim {self.mlp.down_projection.input_dim} does not match"
+                f" the up projection dim {self.mlp.hidden_dim}",
+            )
+
+    def __call__(
+        self,
+        inputs: Float[Array, "suffix_tokens channels"],
+        positional_embeddings: PositionalEmbeddings,
+        kv_cache: KVCacheLayer | None = None,
+        return_updated_kv_cache: bool = False,
+        return_activation_trace: bool = False,
+        length_without_padding: Int[Array, ""] | int | None = None,
+    ) -> DecoderLayerResult:
+        normalized_attention_inputs = vmap(self.pre_attention_norm, in_axes=0)(inputs)
+        attention_outputs, updated_kv_cache = self.attention(
+            normalized_attention_inputs,
+            positional_embeddings,
+            kv_cache=kv_cache,
+            return_updated_kv_cache=return_updated_kv_cache,
+            length_without_padding=length_without_padding,
+        )
+        if self.post_attention_norm is not None:
+            normalized_attention_outputs = vmap(self.post_attention_norm, in_axes=0)(attention_outputs)
+            mlp_inputs = inputs + normalized_attention_outputs
+        else:
+            normalized_attention_outputs = None
+            mlp_inputs = inputs + attention_outputs
+
+        normalized_mlp_inputs = vmap(self.pre_mlp_norm, in_axes=0)(mlp_inputs)
+        mlp_outputs = vmap(self.mlp, in_axes=0)(normalized_mlp_inputs)
+        if self.post_mlp_norm is not None:
+            normalized_mlp_outputs = vmap(self.post_mlp_norm, in_axes=0)(mlp_outputs)
+            outputs = mlp_inputs + normalized_mlp_outputs
+        else:
+            normalized_mlp_outputs = None
+            outputs = mlp_inputs + mlp_outputs
+
+        if return_activation_trace:
+            activation_trace = DecoderLayerActivationTrace(
+                inputs=inputs,
+                positional_embeddings=positional_embeddings,
+                kv_cache=kv_cache,
+                pre_attention_norm=normalized_attention_inputs,
+                attention=attention_outputs,
+                post_attention_norm=normalized_attention_outputs,
+                mlp_inputs=mlp_inputs,
+                pre_mlp_norm=normalized_mlp_inputs,
+                mlp=mlp_outputs,
+                post_mlp_norm=normalized_mlp_outputs,
+            )
+        else:
+            activation_trace = None
+
+        return DecoderLayerResult(
+            outputs=outputs,
+            updated_kv_cache=updated_kv_cache,
+            activation_trace=activation_trace,
+        )
+
+    def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
+        return self.attention.init_static_kv_cache(capacity)
+
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
+        result = ParameterDict(
+            pre_attention_norm=self.pre_attention_norm.export_weights(weight_layout),
+            attention=self.attention.export_weights(weight_layout),
+            pre_mlp_norm=self.pre_mlp_norm.export_weights(weight_layout),
+            mlp=self.mlp.export_weights(weight_layout),
+        )
+        if self.post_attention_norm is not None:
+            result["post_attention_norm"] = self.post_attention_norm.export_weights(weight_layout)
+        if self.post_mlp_norm is not None:
+            result["post_mlp_norm"] = self.post_mlp_norm.export_weights(weight_layout)
+        return result
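For orientation, here is a minimal usage sketch of the DecoderLayer module added above. The sub-configs (norm_config, attention_config, mlp_config), the dimensions, and the input arrays are illustrative placeholders, not values taken from the package; only the signatures match the diff.

import jax

# Hypothetical placeholder configs; RMSNormConfig, AttentionConfig and MLPConfig
# are defined in the sibling modules added by this release.
layer = DecoderLayerConfig(
    pre_attention_norm_config=norm_config,
    attention_config=attention_config,
    post_attention_norm_config=None,
    pre_mlp_norm_config=norm_config,
    mlp_config=mlp_config,
    post_mlp_norm_config=None,
).random_init(
    model_dim=2048,
    hidden_dim=8192,
    num_heads=16,
    num_groups=4,
    head_dim=128,
    attention_scale=None,
    sliding_window_size=None,
    key=jax.random.PRNGKey(0),
)

result = layer(
    inputs,                 # Float[Array, "suffix_tokens channels"], placeholder
    positional_embeddings,  # PositionalEmbeddings for the same positions, placeholder
    kv_cache=None,
    return_updated_kv_cache=True,
)
outputs, new_cache = result.outputs, result.updated_kv_cache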
lalamo/modules/embedding.py
@@ -0,0 +1,299 @@
+from abc import abstractmethod
+from dataclasses import dataclass
+
+import jax
+import jax.numpy as jnp
+from einops import rearrange
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.common import ParameterDict
+from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
+
+from .common import LalamoModule, WeightLayout, register_config_union
+from .utils import apply_soft_capping
+
+__all__ = [
+    "EmbeddingBase",
+    "EmbeddingConfig",
+    "QuantizedTiedEmbedding",
+    "QuantizedTiedEmbeddingConfig",
+    "TiedEmbedding",
+    "TiedEmbeddingConfig",
+    "UntiedEmbedding",
+    "UntiedEmbeddingConfig",
+]
+
+
+@dataclass(frozen=True)
+class EmbeddingConfigBase:
+    input_scale: float | None
+    logits_soft_cap: float | None
+
+    @abstractmethod
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "EmbeddingBase": ...
+
+
+class EmbeddingBase[ConfigT: EmbeddingConfigBase](LalamoModule[ConfigT]):
+    @abstractmethod
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]: ...
+
+    @abstractmethod
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]: ...
+
+    @property
+    @abstractmethod
+    def vocab_size(self) -> int: ...
+
+    @property
+    @abstractmethod
+    def model_dim(self) -> int: ...
+
+    @classmethod
+    def _default_weight_layout(cls) -> WeightLayout:
+        return WeightLayout.INPUT_OUTPUT
+
+    def embed(self, x: Int[Array, " tokens"]) -> Float[Array, "tokens channels"]:
+        result = self._prepare_input_weights()[x]
+        if self.config.input_scale is not None:
+            result = result * jnp.array(self.config.input_scale, dtype=result.dtype)
+        return result
+
+    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
+        logits = self._prepare_output_weights() @ x
+        if self.config.logits_soft_cap is not None:
+            logits = apply_soft_capping(logits, self.config.logits_soft_cap)
+        return logits
+
+
+@dataclass(frozen=True)
+class TiedEmbeddingConfig(EmbeddingConfigBase):
+    precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "TiedEmbedding":
+        weights = jax.random.normal(key, (vocab_size, model_dim), dtype=self.precision)
+        return TiedEmbedding(config=self, weights=weights)
+
+
+class TiedEmbedding(EmbeddingBase[TiedEmbeddingConfig]):
+    weights: Float[Array, "vocabulary channels"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.precision
+
+    def __post_init__(self) -> None:
+        if self.config.precision != self.weights.dtype:
+            raise ValueError(
+                f"Embedding dtype {self.weights.dtype} does not match the specified precision {self.config.precision}",
+            )
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.weights.shape
+        return vocab_size
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self.weights
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self.weights
+
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
+        return ParameterDict(weights=self.weights)
+
+
+@dataclass(frozen=True)
+class UntiedEmbeddingConfig(EmbeddingConfigBase):
+    precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "UntiedEmbedding":
+        input_key, output_key = jax.random.split(key)
+        input_weights = jax.random.normal(input_key, (vocab_size, model_dim), dtype=self.precision)
+        output_weights = jax.random.normal(output_key, (vocab_size, model_dim), dtype=self.precision)
+        return UntiedEmbedding(
+            config=self,
+            input_weights=input_weights,
+            output_weights=output_weights,
+        )
+
+
+class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
+    input_weights: Float[Array, "vocabulary channels"]
+    output_weights: Float[Array, "vocabulary channels"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.precision
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.input_weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.input_weights.shape
+        return vocab_size
+
+    def __post_init__(self) -> None:
+        if self.config.precision != self.input_weights.dtype:
+            raise ValueError(
+                f"Embedding dtype {self.input_weights.dtype} does not match"
+                f" the specified precision {self.config.precision}",
+            )
+        if self.config.precision != self.output_weights.dtype:
+            raise ValueError(
+                f"Embedding dtype {self.output_weights.dtype} does not match"
+                f" the specified precision {self.config.precision}",
+            )
+        input_vocab_size, input_model_dim = self.input_weights.shape
+        output_vocab_size, output_model_dim = self.output_weights.shape
+        if input_vocab_size != output_vocab_size:
+            raise ValueError(
+                f"Input vocab size {input_vocab_size} does not match the output vocab size {output_vocab_size}",
+            )
+        if input_model_dim != output_model_dim:
+            raise ValueError(
+                f"Input model dim {input_model_dim} does not match the output model dim {output_model_dim}",
+            )
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self.input_weights
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self.output_weights
+
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
+        if weight_layout == WeightLayout.AUTO:
+            weight_layout = self._default_weight_layout()
+
+        match weight_layout:
+            case WeightLayout.OUTPUT_INPUT:
+                output_weights = self.output_weights
+            case WeightLayout.INPUT_OUTPUT:
+                output_weights = rearrange(self.output_weights, "token_ids channels -> channels token_ids")
+            case _:
+                raise ValueError(f"Unsupported weight layout: {weight_layout}")
+
+        return ParameterDict(
+            input_weights=self.input_weights,
+            output_weights=output_weights,
+        )
+
+
+@dataclass(frozen=True)
+class QuantizedTiedEmbeddingConfig(EmbeddingConfigBase):
+    embedding_quantization_mode: QuantizationMode
+    activation_quantization_mode: QuantizationMode | None
+    activation_precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "QuantizedTiedEmbedding":
+        min_val, max_val = self.embedding_quantization_mode.range
+        min_abs_val = min(abs(min_val), abs(max_val))
+        scale = 1 / min_abs_val
+        scales = scale * jnp.ones(vocab_size, dtype=self.activation_precision)
+        weights = jax.random.normal(key, (vocab_size, model_dim), dtype=self.activation_precision)
+        weights = quantize_weights(weights * min_abs_val, self.embedding_quantization_mode)
+        return QuantizedTiedEmbedding(config=self, weights=weights, scales=scales)
+
+
+class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
+    weights: Float[Array, "vocabulary channels"]
+    scales: Float[Array, " vocabulary"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.weights.shape
+        return vocab_size
+
+    def __post_init__(self) -> None:
+        if self.weights.dtype != self.config.activation_precision:
+            raise ValueError(
+                f"Embedding dtype ({self.weights.dtype}) is not equal to specified activation precision"
+                f" ({self.config.activation_precision})."
+                " Quantized layers require parameter dtypes to be equal to the activation precision.",
+            )
+        if self.scales.dtype != self.config.activation_precision:
+            raise ValueError(
+                f"Scales dtype {self.scales.dtype} does not match the specified activation precision"
+                f" {self.config.activation_precision}."
+                " Quantized layers require parameter dtypes to be equal to the activation precision.",
+            )
+        weights_vocab_size, weights_model_dim = self.weights.shape
+        (scales_vocab_size,) = self.scales.shape
+        if weights_vocab_size != scales_vocab_size:
+            raise ValueError(
+                f"Embedding vocab size {weights_vocab_size} does not match"
+                f" the scales dimension size {scales_vocab_size}",
+            )
+
+    @property
+    def int_weights(self) -> Int[Array, "vocabulary channels"]:
+        result = quantize_weights(self.weights, self.config.embedding_quantization_mode)
+        return result.astype(self.config.embedding_quantization_mode.dtype)
+
+    def _prepare_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.weights, self.config.embedding_quantization_mode)
+        quantized_weights = quantized_weights * self.scales.reshape(-1, 1)
+        return quantized_weights
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self._prepare_weights()
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self._prepare_weights()
+
+    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
+        if self.config.activation_quantization_mode is not None:
+            x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
+        return super().readout(x)
+
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
+        return ParameterDict(
+            weights=self.int_weights,
+            scales=self.scales,
+        )
+
+
+EmbeddingConfig = TiedEmbeddingConfig | UntiedEmbeddingConfig | QuantizedTiedEmbeddingConfig
+
+
+register_config_union(EmbeddingConfig)
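Similarly, a minimal sketch of the embedding API added above, using TiedEmbeddingConfig as the simplest variant; the vocabulary size, model dimension, and token ids are illustrative placeholders, and the vmap over readout is one way to score a whole sequence rather than part of the package itself.

import jax
import jax.numpy as jnp

config = TiedEmbeddingConfig(
    input_scale=None,       # optional multiplier applied in embed()
    logits_soft_cap=None,   # optional soft cap applied to readout logits
    precision=jnp.float32,
)
embedding = config.random_init(vocab_size=32_000, model_dim=2048, key=jax.random.PRNGKey(0))

token_ids = jnp.array([1, 2, 3])
hidden = embedding.embed(token_ids)            # shape (tokens, channels)
logits = jax.vmap(embedding.readout)(hidden)   # shape (tokens, vocabulary), one readout per token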