lalamo 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- lalamo/__init__.py +1 -1
- lalamo/language_model.py +22 -23
- lalamo/main.py +4 -18
- lalamo/model_import/common.py +24 -6
- lalamo/model_import/decoder_configs/__init__.py +2 -0
- lalamo/model_import/decoder_configs/common.py +4 -4
- lalamo/model_import/decoder_configs/executorch.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +34 -26
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
- lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
- lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
- lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
- lalamo/model_import/loaders/executorch.py +5 -4
- lalamo/model_import/loaders/huggingface.py +321 -69
- lalamo/model_import/model_specs/__init__.py +2 -0
- lalamo/model_import/model_specs/common.py +16 -5
- lalamo/model_import/model_specs/llamba.py +40 -0
- lalamo/model_import/model_specs/qwen.py +29 -1
- lalamo/modules/__init__.py +33 -6
- lalamo/modules/activations.py +9 -2
- lalamo/modules/common.py +10 -5
- lalamo/modules/decoder.py +93 -97
- lalamo/modules/decoder_layer.py +85 -103
- lalamo/modules/embedding.py +279 -5
- lalamo/modules/linear.py +335 -30
- lalamo/modules/mlp.py +6 -7
- lalamo/modules/mlx_interop.py +19 -0
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +30 -0
- lalamo/modules/{attention.py → token_mixers/attention.py} +72 -70
- lalamo/modules/token_mixers/common.py +78 -0
- lalamo/modules/token_mixers/mamba.py +553 -0
- lalamo/modules/token_mixers/state/__init__.py +12 -0
- lalamo/modules/token_mixers/state/common.py +26 -0
- lalamo/modules/{kv_cache.py → token_mixers/state/kv_cache.py} +5 -16
- lalamo/modules/token_mixers/state/mamba_state.py +51 -0
- lalamo/utils.py +24 -2
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/METADATA +3 -2
- lalamo-0.5.0.dist-info/RECORD +80 -0
- lalamo-0.4.0.dist-info/RECORD +0 -71
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/WHEEL +0 -0
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.4.0.dist-info → lalamo-0.5.0.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder_layer.py
CHANGED
```diff
@@ -11,12 +11,11 @@ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree
 
-from .…
-from .common import AttentionType, ForwardPassMode, LalamoModule
-from .kv_cache import KVCacheLayer, StaticKVCacheLayer
+from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
 from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
 from .normalization import RMSNorm, RMSNormConfig
 from .rope import PositionalEmbeddings
+from .token_mixers import KVCacheLayer, StateLayerBase, StaticKVCacheLayer, TokenMixerBase, TokenMixerConfig
 from .utils import vmap_twice
 
 __all__ = [
@@ -33,31 +32,32 @@ type DecoderLayerForwardPassConfig = MLPForwardPassConfig
 
 class DecoderLayerActivationTrace(eqx.Module):
     inputs: Float[Array, "batch suffix_tokens channels"]
-    positional_embeddings: PositionalEmbeddings
-    …
+    positional_embeddings: PositionalEmbeddings | None
+    state: StateLayerBase | None
 
     mlp_inputs: Float[Array, "batch suffix_tokens channels"]
-    …
-    …
-    …
+    pre_mixer_norm: Float[Array, "batch suffix_tokens channels"]
+    mixer: Float[Array, "batch suffix_tokens channels"]
+    post_mixer_norm: Float[Array, "batch suffix_tokens channels"] | None
     pre_mlp_norm: Float[Array, "batch suffix_tokens channels"]
     mlp: Float[Array, "batch suffix_tokens channels"]
     post_mlp_norm: Float[Array, "batch suffix_tokens channels"] | None
 
     def export(self) -> ParameterTree:
-        result = dict(
+        result: dict[str, ParameterTree | Array] = dict(
             inputs=self.inputs,
-            positional_embeddings=self.positional_embeddings.export(),
             mlp_inputs=self.mlp_inputs,
-            …
-            …
+            pre_mixer_norm=self.pre_mixer_norm,
+            mixer=self.mixer,
             pre_mlp_norm=self.pre_mlp_norm,
             mlp=self.mlp,
         )
-        if self.…
-            result["…
-        if self.…
-            result["…
+        if self.positional_embeddings is not None:
+            result["positional_embeddings"] = self.positional_embeddings.export()
+        if self.state is not None:
+            result["state"] = self.state.export()
+        if self.post_mixer_norm is not None:
+            result["post_mixer_norm"] = self.post_mixer_norm
         if self.post_mlp_norm is not None:
             result["post_mlp_norm"] = self.post_mlp_norm
         return result
@@ -65,15 +65,15 @@ class DecoderLayerActivationTrace(eqx.Module):
 
 class DecoderLayerResult(eqx.Module):
     outputs: Float[Array, "suffix_tokens channels"]
-    …
+    updated_state: KVCacheLayer | None
     activation_trace: DecoderLayerActivationTrace | None
 
     def export(self) -> ParameterTree:
         result: dict[str, ParameterTree | Array] = dict(
             outputs=self.outputs,
         )
-        if self.…
-            result["…
+        if self.updated_state is not None:
+            result["updated_state"] = self.updated_state.export()
         if self.activation_trace is not None:
             result["activation_trace"] = self.activation_trace.export()
         return result
@@ -81,39 +81,32 @@ class DecoderLayerResult(eqx.Module):
 
 @dataclass(frozen=True)
 class DecoderLayerConfig:
-    …
-    …
-    …
+    pre_mixer_norm_config: RMSNormConfig
+    mixer_config: TokenMixerConfig
+    post_mixer_norm_config: RMSNormConfig | None
     pre_mlp_norm_config: RMSNormConfig
     mlp_config: MLPConfig
     post_mlp_norm_config: RMSNormConfig | None
 
+    @property
+    def rope_dim(self) -> int:
+        return self.mixer_config.rope_dim
+
     def random_init(
         self,
         model_dim: int,
         hidden_dim: int,
-        num_heads: int,
-        num_groups: int,
-        head_dim: int,
-        attention_scale: float | None,
-        sliding_window_size: int | None,
         *,
         key: PRNGKeyArray,
     ) -> "DecoderLayer":
         attention_key, mlp_key = jax.random.split(key)
-        pre_attention_norm = self.…
-        …
+        pre_attention_norm = self.pre_mixer_norm_config.init(model_dim)
+        mixer = self.mixer_config.random_init(
             model_dim=model_dim,
-            num_heads=num_heads,
-            num_groups=num_groups,
-            head_dim=head_dim,
-            is_causal=True,
-            scale=attention_scale,
-            sliding_window_size=sliding_window_size,
             key=attention_key,
         )
-        if self.…
-            post_attention_norm = self.…
+        if self.post_mixer_norm_config is not None:
+            post_attention_norm = self.post_mixer_norm_config.init(model_dim)
         else:
            post_attention_norm = None
         pre_mlp_norm = self.pre_mlp_norm_config.init(model_dim)
@@ -124,9 +117,9 @@ class DecoderLayerConfig:
            post_mlp_norm = None
        return DecoderLayer(
            config=self,
-            …
-            …
-            …
+            pre_mixer_norm=pre_attention_norm,
+            mixer=mixer,
+            post_mixer_norm=post_attention_norm,
            pre_mlp_norm=pre_mlp_norm,
            mlp=mlp,
            post_mlp_norm=post_mlp_norm,
@@ -136,24 +129,13 @@ class DecoderLayerConfig:
        self,
        model_dim: int,
        hidden_dim: int,
-        num_heads: int,
-        num_groups: int,
-        head_dim: int,
-        attention_scale: float | None,
-        sliding_window_size: int | None,
    ) -> "DecoderLayer":
-        pre_attention_norm = self.…
-        attention = self.…
+        pre_attention_norm = self.pre_mixer_norm_config.empty(model_dim)
+        attention = self.mixer_config.empty(
            model_dim=model_dim,
-            num_heads=num_heads,
-            num_groups=num_groups,
-            head_dim=head_dim,
-            is_causal=True,
-            scale=attention_scale,
-            sliding_window_size=sliding_window_size,
        )
-        if self.…
-            post_attention_norm = self.…
+        if self.post_mixer_norm_config is not None:
+            post_attention_norm = self.post_mixer_norm_config.empty(model_dim)
        else:
            post_attention_norm = None
        pre_mlp_norm = self.pre_mlp_norm_config.empty(model_dim)
@@ -164,9 +146,9 @@ class DecoderLayerConfig:
            post_mlp_norm = None
        return DecoderLayer(
            config=self,
-            …
-            …
-            …
+            pre_mixer_norm=pre_attention_norm,
+            mixer=attention,
+            post_mixer_norm=post_attention_norm,
            pre_mlp_norm=pre_mlp_norm,
            mlp=mlp,
            post_mlp_norm=post_mlp_norm,
@@ -174,31 +156,31 @@ class DecoderLayerConfig:
 
 
 class DecoderLayer(LalamoModule[DecoderLayerConfig]):
-    …
-    …
-    …
+    pre_mixer_norm: RMSNorm
+    mixer: TokenMixerBase
+    post_mixer_norm: RMSNorm | None
     pre_mlp_norm: RMSNorm
     mlp: MLPBase
     post_mlp_norm: RMSNorm | None
 
     @property
     def activation_precision(self) -> DTypeLike:
-        return self.…
+        return self.mixer.activation_precision
 
     @property
-    def …
-        return self.…
+    def positional_embedding_selector(self) -> PositionalEmbeddingSelector:
+        return self.mixer.positional_embedding_selector
 
     def __post_init__(self) -> None:
-        model_dim = self.…
-        if self.…
+        model_dim = self.pre_mixer_norm.input_dim
+        if self.mixer.model_dim != model_dim:
             raise ValueError(
-                f"Attention model dim {self.…
+                f"Attention model dim {self.mixer.model_dim} does not match"
                 f" the first normalization layer dim {model_dim}",
             )
-        if self.…
+        if self.post_mixer_norm is not None and self.post_mixer_norm.input_dim != model_dim:
             raise ValueError(
-                f"Post …
+                f"Post mixer normalization dim {self.post_mixer_norm.input_dim} does not match"
                 f" the first normalization layer dim {model_dim}",
             )
         if self.pre_mlp_norm.input_dim != model_dim:
@@ -216,9 +198,9 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
     def __call__(
         self,
         inputs: Float[Array, "batch suffix_tokens channels"],
-        positional_embeddings: PositionalEmbeddings,
-        …
-        …
+        positional_embeddings: PositionalEmbeddings | None,
+        state: StateLayerBase | None = None,
+        return_updated_state: bool = False,
         return_activation_trace: bool = False,
         lengths_without_padding: Int[Array, " batch"] | None = None,
         forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
@@ -229,20 +211,20 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
                 f"Inputs to decoder layers must be a 3D arrays of size (batch_size, sequence_length, hidden_dim),"
                 f" got {inputs.shape}",
             )
-        …
-        …
-        …
-        …
+        normalized_mixer_inputs = vmap_twice(self.pre_mixer_norm)(inputs)
+        batched_mixer_fn = vmap(partial(self.mixer, return_updated_state=return_updated_state))
+        mixer_outputs, updated_state = batched_mixer_fn(
+            normalized_mixer_inputs,
             positional_embeddings,
-            …
+            state=state,
             length_without_padding=lengths_without_padding,
         )
-        if self.…
-            …
-            mlp_inputs = inputs + …
+        if self.post_mixer_norm is not None:
+            normalized_mixer_outputs = vmap_twice(self.post_mixer_norm)(mixer_outputs)
+            mlp_inputs = inputs + normalized_mixer_outputs
         else:
-            …
-            mlp_inputs = inputs + …
+            normalized_mixer_outputs = None
+            mlp_inputs = inputs + mixer_outputs
 
         normalized_mlp_inputs = vmap_twice(self.pre_mlp_norm)(mlp_inputs)
         mlp_outputs = self.mlp(
@@ -261,10 +243,10 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
             activation_trace = DecoderLayerActivationTrace(
                 inputs=inputs,
                 positional_embeddings=positional_embeddings,
-                …
-                …
-                …
-                …
+                state=state,
+                pre_mixer_norm=normalized_mixer_inputs,
+                mixer=mixer_outputs,
+                post_mixer_norm=normalized_mixer_outputs,
                 mlp_inputs=mlp_inputs,
                 pre_mlp_norm=normalized_mlp_inputs,
                 mlp=mlp_outputs,
@@ -275,25 +257,25 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
 
         return DecoderLayerResult(
             outputs=outputs,
-            …
+            updated_state=updated_state,
             activation_trace=activation_trace,
         )
 
-    def …
+    def init_static_state(self, batch_size: int, capacity: int) -> StaticKVCacheLayer:
         return jax.tree.map(
             lambda array: jnp.repeat(array[None, ...], batch_size, axis=0),
-            self.…
+            self.mixer.init_static_state(capacity),
         )
 
     def export_weights(self) -> ParameterTree:
         result = dict(
-            …
-            …
+            pre_mixer_norm=self.pre_mixer_norm.export_weights(),
+            mixer=self.mixer.export_weights(),
             pre_mlp_norm=self.pre_mlp_norm.export_weights(),
             mlp=self.mlp.export_weights(),
         )
-        if self.…
-            result["…
+        if self.post_mixer_norm is not None:
+            result["post_mixer_norm"] = self.post_mixer_norm.export_weights()
         if self.post_mlp_norm is not None:
             result["post_mlp_norm"] = self.post_mlp_norm.export_weights()
         return result
@@ -303,18 +285,18 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         weights: ParameterTree[Array],
     ) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["…
-        assert isinstance(weights["…
+        assert isinstance(weights["pre_mixer_norm"], Mapping)
+        assert isinstance(weights["mixer"], Mapping)
         assert isinstance(weights["mlp"], Mapping)
         assert isinstance(weights["pre_mlp_norm"], Mapping)
 
-        if self.…
-            assert isinstance(weights["…
-            …
-                weights["…
+        if self.post_mixer_norm is not None:
+            assert isinstance(weights["post_mixer_norm"], Mapping)
+            post_mixer_norm = self.post_mixer_norm.import_weights(
+                weights["post_mixer_norm"],
            )
        else:
-            …
+            post_mixer_norm = None
        if self.post_mlp_norm is not None:
            assert isinstance(weights["post_mlp_norm"], Mapping)
            post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
@@ -322,9 +304,9 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
            post_mlp_norm = None
        return replace(
            self,
-            …
-            …
-            …
+            pre_mixer_norm=self.pre_mixer_norm.import_weights(weights["pre_mixer_norm"]),
+            mixer=self.mixer.import_weights(weights["mixer"]),
+            post_mixer_norm=post_mixer_norm,
            pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
            mlp=self.mlp.import_weights(weights["mlp"]),
            post_mlp_norm=post_mlp_norm,
```
(Removed lines marked with … were truncated or blanked by the diff renderer; their original content is not recoverable from this page.)
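The net effect of this refactor is that the decoder layer no longer hard-codes attention: the former pre_attention_norm / attention / post_attention_norm triple becomes pre_mixer_norm / mixer / post_mixer_norm over a generic TokenMixerBase, and the per-layer KV cache is generalized into StateLayerBase so a Mamba-style recurrent state fits the same slot. A rough usage sketch of the new call signature is below; `layer` is assumed to be an already-constructed DecoderLayer, and `prompt_activations`, `step_activations`, and the positional-embedding values are illustrative placeholders, not names from the diff.

```python
# Prefill: allocate a per-layer static state (a KV cache for attention mixers)
# and run the whole prompt through the layer once.
state = layer.init_static_state(batch_size=1, capacity=512)
result = layer(
    prompt_activations,            # (batch, prompt_tokens, channels)
    prompt_positional_embeddings,  # may now be None for mixers that ignore RoPE
    state=state,
    return_updated_state=True,
)
state = result.updated_state

# Decode: feed one token at a time, threading the updated state forward.
for _ in range(num_new_tokens):
    result = layer(
        step_activations,          # (batch, 1, channels)
        step_positional_embeddings,
        state=state,
        return_updated_state=True,
    )
    state = result.updated_state
    step_activations = ...         # produced by the rest of the decoder stack
```

Since `positional_embeddings` is now optional and `state` defaults to None, a stateless single forward pass (for example during training) still needs only the inputs and the positional embeddings.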
lalamo/modules/embedding.py
CHANGED
```diff
@@ -6,10 +6,12 @@ from typing import Self
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree, dummy_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
+from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
 from .common import (
     LalamoModule,
@@ -20,6 +22,10 @@ from .utils import apply_soft_capping
 __all__ = [
     "EmbeddingBase",
     "EmbeddingConfig",
+    "MLXQuantizedTiedEmbedding",
+    "MLXQuantizedTiedEmbeddingConfig",
+    "MLXSemiQuantizedUntiedEmbedding",
+    "MLXSemiQuantizedUntiedEmbeddingConfig",
     "QuantizedTiedEmbedding",
     "QuantizedTiedEmbeddingConfig",
     "TiedEmbedding",
@@ -314,8 +320,15 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
 
     @property
     def int_weights(self) -> Int[Array, "vocabulary channels"]:
-        …
-        …
+        quantized = quantize_weights(self.weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
 
     def _prepare_weights(self) -> Float[Array, "vocabulary channels"]:
         quantized_weights = quantize_weights(self.weights, self.config.embedding_quantization_mode)
@@ -346,14 +359,275 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["weights"], Array)
+        stored_weights = weights["weights"]
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            stored_weights = jax_uint8_to_unpacked_uint4(stored_weights)
+
+        return replace(
+            self,
+            weights=stored_weights.astype(self.weights.dtype),
+            scales=weights["scales"],
+        )
+
+
+@dataclass(frozen=True)
+class MLXQuantizedTiedEmbeddingConfig(EmbeddingConfigBase):
+    group_size: int
+    embedding_quantization_mode: QuantizationMode
+    activation_quantization_mode: QuantizationMode | None
+    activation_precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "QuantizedTiedEmbedding":
+        raise NotImplementedError
+
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "MLXQuantizedTiedEmbedding":
+        assert model_dim % self.group_size == 0
+        model_groups = model_dim // self.group_size
+        weights = dummy_array((vocab_size, model_dim), dtype=self.activation_precision)
+        scales = dummy_array((vocab_size, model_groups), dtype=self.activation_precision)
+        biases = dummy_array((vocab_size, model_groups), dtype=self.activation_precision)
+        return MLXQuantizedTiedEmbedding(config=self, weights=weights, scales=scales, biases=biases)
+
+
+class MLXQuantizedTiedEmbedding(EmbeddingBase[MLXQuantizedTiedEmbeddingConfig]):
+    weights: Float[Array, "vocabulary channels"]
+    scales: Float[Array, "vocabulary groups"]
+    biases: Float[Array, "vocabulary groups"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.weights.shape
+        return vocab_size
+
+    @property
+    def int_weights(self) -> Int[Array, "vocabulary channels"]:
+        quantized = quantize_weights(self.weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    def _prepare_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.weights, self.config.embedding_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "vocab (groups elements) -> vocab groups elements",
+            elements=self.config.group_size,
+        )
+
+        scales = rearrange(self.scales, "vocab groups -> vocab groups 1")
+
+        biases = rearrange(self.biases, "vocab groups -> vocab groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + biases
+
+        result = rearrange(
+            scaled_grouped_weights,
+            "vocab groups elements -> vocab (groups elements)",
+        )
+        return result
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self._prepare_weights()
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self._prepare_weights()
+
+    @eqx.filter_jit
+    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
+        if self.config.activation_quantization_mode is not None:
+            x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
+        return super().readout(x)
+
+    def export_weights(self) -> ParameterTree:
+        return {
+            "weights": self.int_weights,
+            "scales": self.scales,
+            "biases": self.biases,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        assert isinstance(weights["scales"], Array)
+        assert isinstance(weights["biases"], Array)
+
+        unpacked_weights = weights["weights"]
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
+
         return replace(
             self,
-            weights=…
+            weights=unpacked_weights.astype(self.weights.dtype),
             scales=weights["scales"],
+            biases=weights["biases"],
+        )
+
+
+@dataclass(frozen=True)
+class MLXSemiQuantizedUntiedEmbeddingConfig(EmbeddingConfigBase):
+    group_size: int
+    embedding_quantization_mode: QuantizationMode
+    activation_quantization_mode: QuantizationMode | None
+    activation_precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "MLXSemiQuantizedUntiedEmbedding":
+        raise NotImplementedError
+
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "MLXSemiQuantizedUntiedEmbedding":
+        assert model_dim % self.group_size == 0
+        model_groups = model_dim // self.group_size
+        input_weights = dummy_array((vocab_size, model_dim), dtype=self.activation_precision)
+        output_weights = dummy_array((vocab_size, model_dim), dtype=self.activation_precision)
+        output_scales = dummy_array((vocab_size, model_groups), dtype=self.activation_precision)
+        output_biases = dummy_array((vocab_size, model_groups), dtype=self.activation_precision)
+        return MLXSemiQuantizedUntiedEmbedding(
+            config=self,
+            input_weights=input_weights,
+            output_weights=output_weights,
+            output_scales=output_scales,
+            output_biases=output_biases,
+        )
+
+
+class MLXSemiQuantizedUntiedEmbedding(EmbeddingBase[MLXSemiQuantizedUntiedEmbeddingConfig]):
+    input_weights: Float[Array, "vocabulary channels"]
+    output_weights: Float[Array, "vocabulary channels"]
+    output_scales: Float[Array, "vocabulary groups"]
+    output_biases: Float[Array, "vocabulary groups"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.input_weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.input_weights.shape
+        return vocab_size
+
+    @property
+    def int_output_weights(self) -> Int[Array, "vocabulary channels"]:
+        quantized = quantize_weights(self.output_weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        return self.input_weights
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.output_weights, self.config.embedding_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "vocab (groups elements) -> vocab groups elements",
+            elements=self.config.group_size,
+        )
+
+        scales = rearrange(self.output_scales, "vocab groups -> vocab groups 1")
+
+        biases = rearrange(self.output_biases, "vocab groups -> vocab groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + biases
+
+        result = rearrange(
+            scaled_grouped_weights,
+            "vocab groups elements -> vocab (groups elements)",
+        )
+        return result
+
+    @eqx.filter_jit
+    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
+        if self.config.activation_quantization_mode is not None:
+            x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
+        return super().readout(x)
+
+    def export_weights(self) -> ParameterTree:
+        return {
+            "input_weights": self.input_weights,
+            "output_weights": self.int_output_weights,
+            "output_scales": self.output_scales,
+            "output_biases": self.output_biases,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["input_weights"], Array)
+        assert isinstance(weights["output_weights"], Array)
+        assert isinstance(weights["output_scales"], Array)
+        assert isinstance(weights["output_biases"], Array)
+
+        unpacked_output_weights = weights["output_weights"]
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(weights["output_weights"])
+
+        return replace(
+            self,
+            input_weights=weights["input_weights"],
+            output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
+            output_scales=weights["output_scales"],
+            output_biases=weights["output_biases"],
         )
 
 
-EmbeddingConfig = …
+EmbeddingConfig = (
+    TiedEmbeddingConfig
+    | UntiedEmbeddingConfig
+    | QuantizedTiedEmbeddingConfig
+    | MLXQuantizedTiedEmbeddingConfig
+    | MLXSemiQuantizedUntiedEmbeddingConfig
+)
 
 
-register_config_union(EmbeddingConfig)
+register_config_union(EmbeddingConfig)  # type: ignore (pyright bug)
```
(Removed lines marked with … were truncated by the diff renderer; their original content is not recoverable from this page.)
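Both new MLX embedding variants reconstruct weights with the same group-wise affine scheme used in _prepare_weights and _prepare_output_weights above: each row is split into groups of group_size channels, and every group carries its own scale and bias, so a weight is recovered as roughly q * scale + bias. A self-contained sketch of just that arithmetic is below; the function and array names are illustrative, not lalamo API.

```python
import jax.numpy as jnp
from einops import rearrange


def dequantize_grouped(q, scales, biases, group_size):
    """Group-wise affine dequantization: every block of `group_size` columns
    shares one scale and one bias, mirroring the layout used by the MLX
    embedding classes in this diff."""
    grouped = rearrange(q, "vocab (groups elements) -> vocab groups elements", elements=group_size)
    scales = rearrange(scales, "vocab groups -> vocab groups 1")
    biases = rearrange(biases, "vocab groups -> vocab groups 1")
    dequantized = grouped * scales + biases
    return rearrange(dequantized, "vocab groups elements -> vocab (groups elements)")


# Toy example: a 2-row "vocabulary" with 8 channels and group_size=4.
q = jnp.arange(16, dtype=jnp.float32).reshape(2, 8)      # quantized levels
scales = jnp.full((2, 2), 0.1, dtype=jnp.float32)        # one scale per group
biases = jnp.zeros((2, 2), dtype=jnp.float32)            # one bias per group
print(dequantize_grouped(q, scales, biases, group_size=4).shape)  # (2, 8)
```

The UINT4 paths in the diff additionally pack two 4-bit codes into each uint8 on export via jax_uint4_to_packed_uint8 and unpack them again in import_weights before casting back to the activation dtype.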