lalamo 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/language_model.py +22 -23
  3. lalamo/main.py +4 -18
  4. lalamo/model_import/common.py +24 -6
  5. lalamo/model_import/decoder_configs/__init__.py +2 -0
  6. lalamo/model_import/decoder_configs/common.py +4 -4
  7. lalamo/model_import/decoder_configs/executorch.py +17 -10
  8. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  9. lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
  10. lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
  11. lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
  12. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
  13. lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
  14. lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
  15. lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
  16. lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
  17. lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
  18. lalamo/model_import/loaders/executorch.py +5 -4
  19. lalamo/model_import/loaders/huggingface.py +321 -69
  20. lalamo/model_import/model_specs/__init__.py +2 -0
  21. lalamo/model_import/model_specs/common.py +16 -5
  22. lalamo/model_import/model_specs/llamba.py +40 -0
  23. lalamo/model_import/model_specs/qwen.py +29 -1
  24. lalamo/modules/__init__.py +33 -6
  25. lalamo/modules/activations.py +9 -2
  26. lalamo/modules/common.py +10 -5
  27. lalamo/modules/decoder.py +93 -97
  28. lalamo/modules/decoder_layer.py +85 -103
  29. lalamo/modules/embedding.py +279 -5
  30. lalamo/modules/linear.py +335 -30
  31. lalamo/modules/mlp.py +6 -7
  32. lalamo/modules/mlx_interop.py +19 -0
  33. lalamo/modules/rope.py +1 -1
  34. lalamo/modules/token_mixers/__init__.py +30 -0
  35. lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
  36. lalamo/modules/token_mixers/common.py +78 -0
  37. lalamo/modules/token_mixers/mamba.py +553 -0
  38. lalamo/modules/token_mixers/state/__init__.py +12 -0
  39. lalamo/modules/token_mixers/state/common.py +26 -0
  40. lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
  41. lalamo/modules/token_mixers/state/mamba_state.py +51 -0
  42. lalamo/utils.py +24 -2
  43. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
  44. lalamo-0.5.0.dist-info/RECORD +80 -0
  45. lalamo-0.4.0.dist-info/RECORD +0 -71
  46. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
  47. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder_layer.py

@@ -11,12 +11,11 @@ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree
 
-from .attention import Attention, AttentionConfig
-from .common import AttentionType, ForwardPassMode, LalamoModule
-from .kv_cache import KVCacheLayer, StaticKVCacheLayer
+from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
 from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
 from .normalization import RMSNorm, RMSNormConfig
 from .rope import PositionalEmbeddings
+from .token_mixers import KVCacheLayer, StateLayerBase, StaticKVCacheLayer, TokenMixerBase, TokenMixerConfig
 from .utils import vmap_twice
 
 __all__ = [
@@ -33,31 +32,32 @@ type DecoderLayerForwardPassConfig = MLPForwardPassConfig
 
 class DecoderLayerActivationTrace(eqx.Module):
     inputs: Float[Array, "batch suffix_tokens channels"]
-    positional_embeddings: PositionalEmbeddings
-    kv_cache: KVCacheLayer | None
+    positional_embeddings: PositionalEmbeddings | None
+    state: StateLayerBase | None
 
     mlp_inputs: Float[Array, "batch suffix_tokens channels"]
-    pre_attention_norm: Float[Array, "batch suffix_tokens channels"]
-    attention: Float[Array, "batch suffix_tokens channels"]
-    post_attention_norm: Float[Array, "batch suffix_tokens channels"] | None
+    pre_mixer_norm: Float[Array, "batch suffix_tokens channels"]
+    mixer: Float[Array, "batch suffix_tokens channels"]
+    post_mixer_norm: Float[Array, "batch suffix_tokens channels"] | None
     pre_mlp_norm: Float[Array, "batch suffix_tokens channels"]
     mlp: Float[Array, "batch suffix_tokens channels"]
     post_mlp_norm: Float[Array, "batch suffix_tokens channels"] | None
 
     def export(self) -> ParameterTree:
-        result = dict(
+        result: dict[str, ParameterTree | Array] = dict(
             inputs=self.inputs,
-            positional_embeddings=self.positional_embeddings.export(),
             mlp_inputs=self.mlp_inputs,
-            pre_attention_norm=self.pre_attention_norm,
-            attention=self.attention,
+            pre_mixer_norm=self.pre_mixer_norm,
+            mixer=self.mixer,
             pre_mlp_norm=self.pre_mlp_norm,
             mlp=self.mlp,
         )
-        if self.kv_cache is not None:
-            result["kv_cache"] = self.kv_cache.export()
-        if self.post_attention_norm is not None:
-            result["post_attention_norm"] = self.post_attention_norm
+        if self.positional_embeddings is not None:
+            result["positional_embeddings"] = self.positional_embeddings.export()
+        if self.state is not None:
+            result["state"] = self.state.export()
+        if self.post_mixer_norm is not None:
+            result["post_mixer_norm"] = self.post_mixer_norm
         if self.post_mlp_norm is not None:
            result["post_mlp_norm"] = self.post_mlp_norm
         return result
@@ -65,15 +65,15 @@ class DecoderLayerActivationTrace(eqx.Module):
 
 class DecoderLayerResult(eqx.Module):
     outputs: Float[Array, "suffix_tokens channels"]
-    updated_kv_cache: KVCacheLayer | None
+    updated_state: KVCacheLayer | None
     activation_trace: DecoderLayerActivationTrace | None
 
     def export(self) -> ParameterTree:
         result: dict[str, ParameterTree | Array] = dict(
             outputs=self.outputs,
         )
-        if self.updated_kv_cache is not None:
-            result["updated_kv_cache"] = self.updated_kv_cache.export()
+        if self.updated_state is not None:
+            result["updated_state"] = self.updated_state.export()
         if self.activation_trace is not None:
             result["activation_trace"] = self.activation_trace.export()
         return result
@@ -81,39 +81,32 @@ class DecoderLayerResult(eqx.Module):
 
 @dataclass(frozen=True)
 class DecoderLayerConfig:
-    pre_attention_norm_config: RMSNormConfig
-    attention_config: AttentionConfig
-    post_attention_norm_config: RMSNormConfig | None
+    pre_mixer_norm_config: RMSNormConfig
+    mixer_config: TokenMixerConfig
+    post_mixer_norm_config: RMSNormConfig | None
     pre_mlp_norm_config: RMSNormConfig
     mlp_config: MLPConfig
     post_mlp_norm_config: RMSNormConfig | None
 
+    @property
+    def rope_dim(self) -> int:
+        return self.mixer_config.rope_dim
+
     def random_init(
         self,
         model_dim: int,
         hidden_dim: int,
-        num_heads: int,
-        num_groups: int,
-        head_dim: int,
-        attention_scale: float | None,
-        sliding_window_size: int | None,
         *,
         key: PRNGKeyArray,
     ) -> "DecoderLayer":
         attention_key, mlp_key = jax.random.split(key)
-        pre_attention_norm = self.pre_attention_norm_config.init(model_dim)
-        attention = self.attention_config.random_init(
+        pre_attention_norm = self.pre_mixer_norm_config.init(model_dim)
+        mixer = self.mixer_config.random_init(
             model_dim=model_dim,
-            num_heads=num_heads,
-            num_groups=num_groups,
-            head_dim=head_dim,
-            is_causal=True,
-            scale=attention_scale,
-            sliding_window_size=sliding_window_size,
             key=attention_key,
         )
-        if self.post_attention_norm_config is not None:
-            post_attention_norm = self.post_attention_norm_config.init(model_dim)
+        if self.post_mixer_norm_config is not None:
+            post_attention_norm = self.post_mixer_norm_config.init(model_dim)
         else:
            post_attention_norm = None
         pre_mlp_norm = self.pre_mlp_norm_config.init(model_dim)
@@ -124,9 +117,9 @@ class DecoderLayerConfig:
             post_mlp_norm = None
         return DecoderLayer(
             config=self,
-            pre_attention_norm=pre_attention_norm,
-            attention=attention,
-            post_attention_norm=post_attention_norm,
+            pre_mixer_norm=pre_attention_norm,
+            mixer=mixer,
+            post_mixer_norm=post_attention_norm,
             pre_mlp_norm=pre_mlp_norm,
             mlp=mlp,
             post_mlp_norm=post_mlp_norm,
@@ -136,24 +129,13 @@ class DecoderLayerConfig:
         self,
         model_dim: int,
         hidden_dim: int,
-        num_heads: int,
-        num_groups: int,
-        head_dim: int,
-        attention_scale: float | None,
-        sliding_window_size: int | None,
     ) -> "DecoderLayer":
-        pre_attention_norm = self.pre_attention_norm_config.empty(model_dim)
-        attention = self.attention_config.empty(
+        pre_attention_norm = self.pre_mixer_norm_config.empty(model_dim)
+        attention = self.mixer_config.empty(
             model_dim=model_dim,
-            num_heads=num_heads,
-            num_groups=num_groups,
-            head_dim=head_dim,
-            is_causal=True,
-            scale=attention_scale,
-            sliding_window_size=sliding_window_size,
         )
-        if self.post_attention_norm_config is not None:
-            post_attention_norm = self.post_attention_norm_config.empty(model_dim)
+        if self.post_mixer_norm_config is not None:
+            post_attention_norm = self.post_mixer_norm_config.empty(model_dim)
         else:
            post_attention_norm = None
         pre_mlp_norm = self.pre_mlp_norm_config.empty(model_dim)
@@ -164,9 +146,9 @@ class DecoderLayerConfig:
             post_mlp_norm = None
         return DecoderLayer(
             config=self,
-            pre_attention_norm=pre_attention_norm,
-            attention=attention,
-            post_attention_norm=post_attention_norm,
+            pre_mixer_norm=pre_attention_norm,
+            mixer=attention,
+            post_mixer_norm=post_attention_norm,
             pre_mlp_norm=pre_mlp_norm,
             mlp=mlp,
             post_mlp_norm=post_mlp_norm,
@@ -174,31 +156,31 @@ class DecoderLayerConfig:
 
 
 class DecoderLayer(LalamoModule[DecoderLayerConfig]):
-    pre_attention_norm: RMSNorm
-    attention: Attention
-    post_attention_norm: RMSNorm | None
+    pre_mixer_norm: RMSNorm
+    mixer: TokenMixerBase
+    post_mixer_norm: RMSNorm | None
     pre_mlp_norm: RMSNorm
     mlp: MLPBase
     post_mlp_norm: RMSNorm | None
 
     @property
     def activation_precision(self) -> DTypeLike:
-        return self.attention.activation_precision
+        return self.mixer.activation_precision
 
     @property
-    def attention_type(self) -> AttentionType:
-        return self.attention.attention_type
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector:
+        return self.mixer.positional_embedding_selector
 
     def __post_init__(self) -> None:
-        model_dim = self.pre_attention_norm.input_dim
-        if self.attention.model_dim != model_dim:
+        model_dim = self.pre_mixer_norm.input_dim
+        if self.mixer.model_dim != model_dim:
             raise ValueError(
-                f"Attention model dim {self.attention.model_dim} does not match"
+                f"Attention model dim {self.mixer.model_dim} does not match"
                 f" the first normalization layer dim {model_dim}",
             )
-        if self.post_attention_norm is not None and self.post_attention_norm.input_dim != model_dim:
+        if self.post_mixer_norm is not None and self.post_mixer_norm.input_dim != model_dim:
             raise ValueError(
-                f"Post attention normalization dim {self.post_attention_norm.input_dim} does not match"
+                f"Post mixer normalization dim {self.post_mixer_norm.input_dim} does not match"
                 f" the first normalization layer dim {model_dim}",
             )
         if self.pre_mlp_norm.input_dim != model_dim:
@@ -216,9 +198,9 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
     def __call__(
         self,
         inputs: Float[Array, "batch suffix_tokens channels"],
-        positional_embeddings: PositionalEmbeddings,
-        kv_cache: KVCacheLayer | None = None,
-        return_updated_kv_cache: bool = False,
+        positional_embeddings: PositionalEmbeddings | None,
+        state: StateLayerBase | None = None,
+        return_updated_state: bool = False,
         return_activation_trace: bool = False,
         lengths_without_padding: Int[Array, " batch"] | None = None,
         forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
@@ -229,20 +211,20 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
                 f"Inputs to decoder layers must be a 3D arrays of size (batch_size, sequence_length, hidden_dim),"
                 f" got {inputs.shape}",
             )
-        normalized_attention_inputs = vmap_twice(self.pre_attention_norm)(inputs)
-        batched_attention_fn = vmap(partial(self.attention, return_updated_kv_cache=return_updated_kv_cache))
-        attention_outputs, updated_kv_cache = batched_attention_fn(
-            normalized_attention_inputs,
+        normalized_mixer_inputs = vmap_twice(self.pre_mixer_norm)(inputs)
+        batched_mixer_fn = vmap(partial(self.mixer, return_updated_state=return_updated_state))
+        mixer_outputs, updated_state = batched_mixer_fn(
+            normalized_mixer_inputs,
             positional_embeddings,
-            kv_cache=kv_cache,
+            state=state,
             length_without_padding=lengths_without_padding,
         )
-        if self.post_attention_norm is not None:
-            normalized_attention_outputs = vmap_twice(self.post_attention_norm)(attention_outputs)
-            mlp_inputs = inputs + normalized_attention_outputs
+        if self.post_mixer_norm is not None:
+            normalized_mixer_outputs = vmap_twice(self.post_mixer_norm)(mixer_outputs)
+            mlp_inputs = inputs + normalized_mixer_outputs
         else:
-            normalized_attention_outputs = None
-            mlp_inputs = inputs + attention_outputs
+            normalized_mixer_outputs = None
+            mlp_inputs = inputs + mixer_outputs
 
         normalized_mlp_inputs = vmap_twice(self.pre_mlp_norm)(mlp_inputs)
         mlp_outputs = self.mlp(
@@ -261,10 +243,10 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
             activation_trace = DecoderLayerActivationTrace(
                 inputs=inputs,
                 positional_embeddings=positional_embeddings,
-                kv_cache=kv_cache,
-                pre_attention_norm=normalized_attention_inputs,
-                attention=attention_outputs,
-                post_attention_norm=normalized_attention_outputs,
+                state=state,
+                pre_mixer_norm=normalized_mixer_inputs,
+                mixer=mixer_outputs,
+                post_mixer_norm=normalized_mixer_outputs,
                 mlp_inputs=mlp_inputs,
                 pre_mlp_norm=normalized_mlp_inputs,
                 mlp=mlp_outputs,
@@ -275,25 +257,25 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
 
         return DecoderLayerResult(
             outputs=outputs,
-            updated_kv_cache=updated_kv_cache,
+            updated_state=updated_state,
             activation_trace=activation_trace,
         )
 
-    def init_static_kv_cache(self, batch_size: int, capacity: int) -> StaticKVCacheLayer:
+    def init_static_state(self, batch_size: int, capacity: int) -> StaticKVCacheLayer:
         return jax.tree.map(
             lambda array: jnp.repeat(array[None, ...], batch_size, axis=0),
-            self.attention.init_static_kv_cache(capacity),
+            self.mixer.init_static_state(capacity),
         )
 
     def export_weights(self) -> ParameterTree:
         result = dict(
-            pre_attention_norm=self.pre_attention_norm.export_weights(),
-            attention=self.attention.export_weights(),
+            pre_mixer_norm=self.pre_mixer_norm.export_weights(),
+            mixer=self.mixer.export_weights(),
             pre_mlp_norm=self.pre_mlp_norm.export_weights(),
             mlp=self.mlp.export_weights(),
         )
-        if self.post_attention_norm is not None:
-            result["post_attention_norm"] = self.post_attention_norm.export_weights()
+        if self.post_mixer_norm is not None:
+            result["post_mixer_norm"] = self.post_mixer_norm.export_weights()
         if self.post_mlp_norm is not None:
             result["post_mlp_norm"] = self.post_mlp_norm.export_weights()
         return result
@@ -303,18 +285,18 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         weights: ParameterTree[Array],
     ) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["pre_attention_norm"], Mapping)
-        assert isinstance(weights["attention"], Mapping)
+        assert isinstance(weights["pre_mixer_norm"], Mapping)
+        assert isinstance(weights["mixer"], Mapping)
         assert isinstance(weights["mlp"], Mapping)
         assert isinstance(weights["pre_mlp_norm"], Mapping)
 
-        if self.post_attention_norm is not None:
-            assert isinstance(weights["post_attention_norm"], Mapping)
-            post_attention_norm = self.post_attention_norm.import_weights(
-                weights["post_attention_norm"],
+        if self.post_mixer_norm is not None:
+            assert isinstance(weights["post_mixer_norm"], Mapping)
+            post_mixer_norm = self.post_mixer_norm.import_weights(
+                weights["post_mixer_norm"],
             )
         else:
-            post_attention_norm = None
+            post_mixer_norm = None
         if self.post_mlp_norm is not None:
             assert isinstance(weights["post_mlp_norm"], Mapping)
             post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
@@ -322,9 +304,9 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
             post_mlp_norm = None
         return replace(
             self,
-            pre_attention_norm=self.pre_attention_norm.import_weights(weights["pre_attention_norm"]),
-            attention=self.attention.import_weights(weights["attention"]),
-            post_attention_norm=post_attention_norm,
+            pre_mixer_norm=self.pre_mixer_norm.import_weights(weights["pre_mixer_norm"]),
+            mixer=self.mixer.import_weights(weights["mixer"]),
+            post_mixer_norm=post_mixer_norm,
             pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
             mlp=self.mlp.import_weights(weights["mlp"]),
             post_mlp_norm=post_mlp_norm,
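
Note: the hunks above rename the attention-specific DecoderLayer interface to the generic token-mixer one (kv_cache becomes state, return_updated_kv_cache becomes return_updated_state, init_static_kv_cache becomes init_static_state). A minimal usage sketch of the 0.5.0 call signature follows; `layer`, `inputs`, and `rope` are hypothetical placeholders, only the attribute and keyword names are taken from the diff:

    # Sketch only: assumes an already constructed DecoderLayer and batched inputs.
    state = layer.init_static_state(batch_size=1, capacity=512)  # was init_static_kv_cache
    result = layer(
        inputs,                      # (batch, suffix_tokens, channels)
        rope,                        # PositionalEmbeddings | None in 0.5.0
        state=state,                 # was kv_cache=...
        return_updated_state=True,   # was return_updated_kv_cache=...
    )
    outputs, new_state = result.outputs, result.updated_state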
lalamo/modules/embedding.py

@@ -6,10 +6,12 @@ from typing import Self
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree, dummy_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
+from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
 from .common import (
     LalamoModule,
@@ -20,6 +22,10 @@ from .utils import apply_soft_capping
 __all__ = [
     "EmbeddingBase",
     "EmbeddingConfig",
+    "MLXQuantizedTiedEmbedding",
+    "MLXQuantizedTiedEmbeddingConfig",
+    "MLXSemiQuantizedUntiedEmbedding",
+    "MLXSemiQuantizedUntiedEmbeddingConfig",
     "QuantizedTiedEmbedding",
     "QuantizedTiedEmbeddingConfig",
     "TiedEmbedding",
@@ -314,8 +320,15 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
 
     @property
     def int_weights(self) -> Int[Array, "vocabulary channels"]:
-        result = quantize_weights(self.weights, self.config.embedding_quantization_mode)
-        return result.astype(self.config.embedding_quantization_mode.dtype)
+        quantized = quantize_weights(self.weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
 
     def _prepare_weights(self) -> Float[Array, "vocabulary channels"]:
         quantized_weights = quantize_weights(self.weights, self.config.embedding_quantization_mode)
@@ -346,14 +359,275 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["weights"], Array)
+        stored_weights = weights["weights"]
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            stored_weights = jax_uint8_to_unpacked_uint4(stored_weights)
+
+        return replace(
+            self,
+            weights=stored_weights.astype(self.weights.dtype),
+            scales=weights["scales"],
+        )
+
+
+@dataclass(frozen=True)
+class MLXQuantizedTiedEmbeddingConfig(EmbeddingConfigBase):
+    group_size: int
+    embedding_quantization_mode: QuantizationMode
+    activation_quantization_mode: QuantizationMode | None
+    activation_precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "QuantizedTiedEmbedding":
+        raise NotImplementedError
+
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "MLXQuantizedTiedEmbedding":
+        assert model_dim % self.group_size == 0
+        model_groups = model_dim // self.group_size
+        weights = dummy_array((vocab_size, model_dim), dtype=self.activation_precision)
+        scales = dummy_array((vocab_size, model_groups), dtype=self.activation_precision)
+        biases = dummy_array((vocab_size, model_groups), dtype=self.activation_precision)
+        return MLXQuantizedTiedEmbedding(config=self, weights=weights, scales=scales, biases=biases)
+
+
+class MLXQuantizedTiedEmbedding(EmbeddingBase[MLXQuantizedTiedEmbeddingConfig]):
+    weights: Float[Array, "vocabulary channels"]
+    scales: Float[Array, "vocabulary groups"]
+    biases: Float[Array, "vocabulary groups"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.weights.shape
+        return vocab_size
+
+    @property
+    def int_weights(self) -> Int[Array, "vocabulary channels"]:
+        quantized = quantize_weights(self.weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    def _prepare_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.weights, self.config.embedding_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "vocab (groups elements) -> vocab groups elements",
+            elements=self.config.group_size,
+        )
+
+        scales = rearrange(self.scales, "vocab groups -> vocab groups 1")
+
+        biases = rearrange(self.biases, "vocab groups -> vocab groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + biases
+
+        result = rearrange(
+            scaled_grouped_weights,
+            "vocab groups elements -> vocab (groups elements)",
+        )
+        return result
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self._prepare_weights()
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self._prepare_weights()
+
+    @eqx.filter_jit
+    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
+        if self.config.activation_quantization_mode is not None:
+            x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
+        return super().readout(x)
+
+    def export_weights(self) -> ParameterTree:
+        return {
+            "weights": self.int_weights,
+            "scales": self.scales,
+            "biases": self.biases,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        assert isinstance(weights["scales"], Array)
+        assert isinstance(weights["biases"], Array)
+
+        unpacked_weights = weights["weights"]
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
+
         return replace(
             self,
-            weights=weights["weights"].astype(self.weights.dtype),
+            weights=unpacked_weights.astype(self.weights.dtype),
             scales=weights["scales"],
+            biases=weights["biases"],
+        )
+
+
+@dataclass(frozen=True)
+class MLXSemiQuantizedUntiedEmbeddingConfig(EmbeddingConfigBase):
+    group_size: int
+    embedding_quantization_mode: QuantizationMode
+    activation_quantization_mode: QuantizationMode | None
+    activation_precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "MLXSemiQuantizedUntiedEmbedding":
+        raise NotImplementedError
+
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "MLXSemiQuantizedUntiedEmbedding":
+        assert model_dim % self.group_size == 0
+        model_groups = model_dim // self.group_size
+        input_weights = dummy_array((vocab_size, model_dim), dtype=self.activation_precision)
+        output_weights = dummy_array((vocab_size, model_dim), dtype=self.activation_precision)
+        output_scales = dummy_array((vocab_size, model_groups), dtype=self.activation_precision)
+        output_biases = dummy_array((vocab_size, model_groups), dtype=self.activation_precision)
+        return MLXSemiQuantizedUntiedEmbedding(
+            config=self,
+            input_weights=input_weights,
+            output_weights=output_weights,
+            output_scales=output_scales,
+            output_biases=output_biases,
+        )
+
+
+class MLXSemiQuantizedUntiedEmbedding(EmbeddingBase[MLXSemiQuantizedUntiedEmbeddingConfig]):
+    input_weights: Float[Array, "vocabulary channels"]
+    output_weights: Float[Array, "vocabulary channels"]
+    output_scales: Float[Array, "vocabulary groups"]
+    output_biases: Float[Array, "vocabulary groups"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.input_weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.input_weights.shape
+        return vocab_size
+
+    @property
+    def int_output_weights(self) -> Int[Array, "vocabulary channels"]:
+        quantized = quantize_weights(self.output_weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self.input_weights
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.output_weights, self.config.embedding_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "vocab (groups elements) -> vocab groups elements",
+            elements=self.config.group_size,
+        )
+
+        scales = rearrange(self.output_scales, "vocab groups -> vocab groups 1")
+
+        biases = rearrange(self.output_biases, "vocab groups -> vocab groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + biases
+
+        result = rearrange(
+            scaled_grouped_weights,
+            "vocab groups elements -> vocab (groups elements)",
+        )
+        return result
+
+    @eqx.filter_jit
+    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
+        if self.config.activation_quantization_mode is not None:
+            x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
+        return super().readout(x)
+
+    def export_weights(self) -> ParameterTree:
+        return {
+            "input_weights": self.input_weights,
+            "output_weights": self.int_output_weights,
+            "output_scales": self.output_scales,
+            "output_biases": self.output_biases,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["input_weights"], Array)
+        assert isinstance(weights["output_weights"], Array)
+        assert isinstance(weights["output_scales"], Array)
+        assert isinstance(weights["output_biases"], Array)
+
+        unpacked_output_weights = weights["output_weights"]
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(weights["output_weights"])
+
+        return replace(
+            self,
+            input_weights=weights["input_weights"],
+            output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
+            output_scales=weights["output_scales"],
+            output_biases=weights["output_biases"],
         )
 
 
-EmbeddingConfig = TiedEmbeddingConfig | UntiedEmbeddingConfig | QuantizedTiedEmbeddingConfig
+EmbeddingConfig = (
+    TiedEmbeddingConfig
+    | UntiedEmbeddingConfig
+    | QuantizedTiedEmbeddingConfig
+    | MLXQuantizedTiedEmbeddingConfig
+    | MLXSemiQuantizedUntiedEmbeddingConfig
+)
 
 
-register_config_union(EmbeddingConfig)
+register_config_union(EmbeddingConfig)  # type: ignore (pyright bug)
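
Note: the UINT4 branches in the embedding hunks rely on jax_uint4_to_packed_uint8 and jax_uint8_to_unpacked_uint4 from lalamo.utils, whose implementations are not part of this diff. A rough sketch of what such a nibble pack/unpack round trip typically looks like; the low-nibble-first layout and function names below are assumptions for illustration, not taken from the package:

    import jax.numpy as jnp

    # Illustrative only: the real helpers live in lalamo.utils and may differ.
    # Assumes two 4-bit codes per byte, with even indices stored in the low nibble.
    def pack_uint4_to_uint8(values: jnp.ndarray) -> jnp.ndarray:
        values = values.astype(jnp.uint8)
        low, high = values[..., 0::2], values[..., 1::2]
        return low | (high << 4)

    def unpack_uint8_to_uint4(packed: jnp.ndarray) -> jnp.ndarray:
        low = packed & 0x0F
        high = (packed >> 4) & 0x0F
        return jnp.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)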