lalamo 0.2.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/common.py +79 -29
- lalamo/language_model.py +106 -83
- lalamo/main.py +91 -18
- lalamo/message_processor.py +170 -0
- lalamo/model_import/common.py +159 -43
- lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
- lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
- lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
- lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
- lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
- lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
- lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
- lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
- lalamo/model_import/huggingface_generation_config.py +44 -0
- lalamo/model_import/huggingface_tokenizer_config.py +85 -0
- lalamo/model_import/loaders/common.py +2 -1
- lalamo/model_import/loaders/huggingface.py +12 -10
- lalamo/model_import/model_specs/__init__.py +3 -2
- lalamo/model_import/model_specs/common.py +31 -32
- lalamo/model_import/model_specs/deepseek.py +1 -10
- lalamo/model_import/model_specs/gemma.py +2 -25
- lalamo/model_import/model_specs/huggingface.py +2 -12
- lalamo/model_import/model_specs/llama.py +2 -58
- lalamo/model_import/model_specs/mistral.py +9 -19
- lalamo/model_import/model_specs/pleias.py +3 -13
- lalamo/model_import/model_specs/polaris.py +5 -7
- lalamo/model_import/model_specs/qwen.py +12 -111
- lalamo/model_import/model_specs/reka.py +4 -13
- lalamo/modules/__init__.py +2 -1
- lalamo/modules/attention.py +90 -10
- lalamo/modules/common.py +51 -4
- lalamo/modules/decoder.py +90 -8
- lalamo/modules/decoder_layer.py +85 -8
- lalamo/modules/embedding.py +95 -29
- lalamo/modules/kv_cache.py +3 -3
- lalamo/modules/linear.py +170 -130
- lalamo/modules/mlp.py +40 -7
- lalamo/modules/normalization.py +24 -6
- lalamo/modules/rope.py +24 -6
- lalamo/sampling.py +99 -0
- lalamo/utils.py +86 -1
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/METADATA +6 -6
- lalamo-0.3.1.dist-info/RECORD +58 -0
- lalamo-0.2.7.dist-info/RECORD +0 -54
- /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
- /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
- /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/WHEEL +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder.py
CHANGED
@@ -1,11 +1,13 @@
-from
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass, replace
+from typing import Self

 import equinox as eqx
 import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

-from lalamo.common import
+from lalamo.common import ParameterTree

 from .common import AttentionType, LalamoModule, WeightLayout
 from .decoder_layer import DecoderLayer, DecoderLayerConfig, DecoderLayerResult
@@ -34,8 +36,8 @@ class DecoderActivationTrace(eqx.Module):

     output_norm: Float[Array, "suffix_tokens channels"]

-    def export(self) ->
-        result =
+    def export(self) -> ParameterTree:
+        result = dict(
             token_ids=self.token_ids,
             token_positions=self.token_positions,
             local_positional_embeddings=self.local_positional_embeddings.export(),
@@ -53,8 +55,8 @@ class DecoderResult(eqx.Module):
     updated_kv_cache: KVCache | None = None
     activation_trace: DecoderActivationTrace | None = None

-    def export(self) ->
-        result =
+    def export(self) -> ParameterTree:
+        result: dict[str, ParameterTree | Array] = dict(
             logits=self.logits,
         )
         if self.updated_kv_cache is not None:
@@ -152,6 +154,56 @@ class DecoderConfig:
             output_norm=output_norm,
         )

+    def empty(
+        self,
+    ) -> "Decoder":
+        embedding = self.embedding_config.empty(
+            vocab_size=self.vocab_size,
+            model_dim=self.model_dim,
+        )
+        global_rope = self.global_rope_config.init(
+            head_dim=self.head_dim,
+            num_timesteps=self.context_length,
+        )
+
+        if self.local_rope_config:
+            assert self.sliding_window_sizes is not None
+            max_sliding_window_size = max(
+                window_size for window_size in self.sliding_window_sizes if window_size is not None
+            )
+            local_rope = self.local_rope_config.init(
+                head_dim=self.head_dim,
+                num_timesteps=max(max_sliding_window_size, self.context_length),
+            )
+        else:
+            local_rope = None
+
+        if self.sliding_window_sizes is None:
+            sliding_window_sizes = [None] * self.num_layers
+        else:
+            sliding_window_sizes = self.sliding_window_sizes
+        layers = tuple(
+            self.layer_config.empty(
+                model_dim=self.model_dim,
+                hidden_dim=self.hidden_dim,
+                num_heads=self.num_heads,
+                num_groups=self.num_groups,
+                head_dim=self.head_dim,
+                attention_scale=self.attention_scale,
+                sliding_window_size=sliding_window_size,
+            )
+            for sliding_window_size in sliding_window_sizes
+        )
+        output_norm = self.output_norm_config.empty(self.model_dim)
+        return Decoder(
+            self,
+            embedding=embedding,
+            global_rope=global_rope,
+            local_rope=local_rope,
+            layers=layers,
+            output_norm=output_norm,
+        )
+

 class Decoder(LalamoModule[DecoderConfig]):
     embedding: EmbeddingBase
@@ -164,6 +216,7 @@ class Decoder(LalamoModule[DecoderConfig]):
     def activation_precision(self) -> DTypeLike:
         return self.embedding.activation_precision

+    @eqx.filter_jit
     def __call__(
         self,
         token_ids: Int[Array, " suffix_tokens"],
@@ -232,8 +285,8 @@ class Decoder(LalamoModule[DecoderConfig]):
     def init_static_kv_cache(self, capacity: int) -> KVCache:
         return KVCache(layer.init_static_kv_cache(capacity) for layer in self.layers)

-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        result =
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        result = dict(
             embedding=self.embedding.export_weights(weight_layout),
             global_rope=self.global_rope.export_weights(weight_layout),
             layers=[layer.export_weights(weight_layout) for layer in self.layers],
@@ -242,3 +295,32 @@ class Decoder(LalamoModule[DecoderConfig]):
         if self.local_rope:
             result["local_rope"] = self.local_rope.export_weights(weight_layout)
         return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["embedding"], Mapping)
+        assert isinstance(weights["global_rope"], Mapping)
+        assert isinstance(weights["layers"], Sequence)
+        assert isinstance(weights["output_norm"], Mapping)
+        if self.local_rope:
+            assert isinstance(weights["local_rope"], Mapping)
+            local_rope = self.local_rope.import_weights(weights["local_rope"], weight_layout)
+        else:
+            local_rope = None
+
+        layers = []
+        for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
+            assert isinstance(layer_weights, Mapping)
+            layers.append(layer.import_weights(layer_weights, weight_layout))
+        return replace(
+            self,
+            embedding=self.embedding.import_weights(weights["embedding"], weight_layout),
+            global_rope=self.global_rope.import_weights(weights["global_rope"], weight_layout),
+            layers=tuple(layers),
+            output_norm=self.output_norm.import_weights(weights["output_norm"], weight_layout),
+            local_rope=local_rope,
+        )
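Version 0.3.1 pairs the existing `export_weights()` with a new `DecoderConfig.empty()` constructor and a `Decoder.import_weights()` method. The sketch below shows one way these could be combined to rebuild a decoder from exported weights; the helper name `rebuild_decoder` and the exact import paths are assumptions based on the file layout above, not a documented API.

```python
# Hypothetical usage sketch derived from the 0.3.1 diff; not part of the package docs.
from jaxtyping import Array

from lalamo.common import ParameterTree
from lalamo.modules.common import WeightLayout
from lalamo.modules.decoder import Decoder, DecoderConfig


def rebuild_decoder(config: DecoderConfig, weights: ParameterTree[Array]) -> Decoder:
    # empty() allocates a Decoder skeleton backed by dummy parameter arrays, so no
    # PRNG key or random initialization is needed before loading real weights.
    skeleton = config.empty()
    # import_weights() mirrors export_weights(): it expects a mapping with "embedding",
    # "global_rope", "layers", and "output_norm" entries, plus "local_rope" when present.
    return skeleton.import_weights(weights, WeightLayout.AUTO)


# Round trip (assuming `trained` is an existing Decoder):
#   restored = rebuild_decoder(trained.config, trained.export_weights(WeightLayout.AUTO))
```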
lalamo/modules/decoder_layer.py
CHANGED
@@ -1,11 +1,13 @@
-from
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import Self

 import equinox as eqx
 import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

-from lalamo.common import
+from lalamo.common import ParameterTree

 from .attention import Attention, AttentionConfig
 from .common import AttentionType, LalamoModule, WeightLayout
@@ -35,8 +37,8 @@ class DecoderLayerActivationTrace(eqx.Module):
     mlp: Float[Array, "suffix_tokens channels"]
     post_mlp_norm: Float[Array, "suffix_tokens channels"] | None

-    def export(self) ->
-        result =
+    def export(self) -> ParameterTree:
+        result = dict(
             inputs=self.inputs,
             positional_embeddings=self.positional_embeddings.export(),
             mlp_inputs=self.mlp_inputs,
@@ -59,8 +61,8 @@ class DecoderLayerResult(eqx.Module):
     updated_kv_cache: KVCacheLayer | None
     activation_trace: DecoderLayerActivationTrace | None

-    def export(self) ->
-        result =
+    def export(self) -> ParameterTree:
+        result: dict[str, ParameterTree | Array] = dict(
             outputs=self.outputs,
         )
         if self.updated_kv_cache is not None:
@@ -123,6 +125,46 @@ class DecoderLayerConfig:
             post_mlp_norm=post_mlp_norm,
         )

+    def empty(
+        self,
+        model_dim: int,
+        hidden_dim: int,
+        num_heads: int,
+        num_groups: int,
+        head_dim: int,
+        attention_scale: float | None,
+        sliding_window_size: int | None,
+    ) -> "DecoderLayer":
+        pre_attention_norm = self.pre_attention_norm_config.empty(model_dim)
+        attention = self.attention_config.empty(
+            model_dim=model_dim,
+            num_heads=num_heads,
+            num_groups=num_groups,
+            head_dim=head_dim,
+            is_causal=True,
+            scale=attention_scale,
+            sliding_window_size=sliding_window_size,
+        )
+        if self.post_attention_norm_config is not None:
+            post_attention_norm = self.post_attention_norm_config.empty(model_dim)
+        else:
+            post_attention_norm = None
+        pre_mlp_norm = self.pre_mlp_norm_config.empty(model_dim)
+        mlp = self.mlp_config.empty(model_dim, hidden_dim)
+        if self.post_mlp_norm_config is not None:
+            post_mlp_norm = self.post_mlp_norm_config.empty(model_dim)
+        else:
+            post_mlp_norm = None
+        return DecoderLayer(
+            config=self,
+            pre_attention_norm=pre_attention_norm,
+            attention=attention,
+            post_attention_norm=post_attention_norm,
+            pre_mlp_norm=pre_mlp_norm,
+            mlp=mlp,
+            post_mlp_norm=post_mlp_norm,
+        )
+

 class DecoderLayer(LalamoModule[DecoderLayerConfig]):
     pre_attention_norm: RMSNorm
@@ -168,6 +210,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
                 f" the up projection dim {self.mlp.hidden_dim}",
             )

+    @eqx.filter_jit
     def __call__(
         self,
         inputs: Float[Array, "suffix_tokens channels"],
@@ -226,8 +269,8 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
     def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
         return self.attention.init_static_kv_cache(capacity)

-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        result =
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        result = dict(
             pre_attention_norm=self.pre_attention_norm.export_weights(weight_layout),
             attention=self.attention.export_weights(weight_layout),
             pre_mlp_norm=self.pre_mlp_norm.export_weights(weight_layout),
@@ -238,3 +281,37 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         if self.post_mlp_norm is not None:
             result["post_mlp_norm"] = self.post_mlp_norm.export_weights(weight_layout)
         return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["pre_attention_norm"], Mapping)
+        assert isinstance(weights["attention"], Mapping)
+        assert isinstance(weights["mlp"], Mapping)
+        assert isinstance(weights["pre_mlp_norm"], Mapping)
+
+        if self.post_attention_norm is not None:
+            assert isinstance(weights["post_attention_norm"], Mapping)
+            post_attention_norm = self.post_attention_norm.import_weights(
+                weights["post_attention_norm"],
+                weight_layout,
+            )
+        else:
+            post_attention_norm = None
+        if self.post_mlp_norm is not None:
+            assert isinstance(weights["post_mlp_norm"], Mapping)
+            post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"], weight_layout)
+        else:
+            post_mlp_norm = None
+        return replace(
+            self,
+            pre_attention_norm=self.pre_attention_norm.import_weights(weights["pre_attention_norm"], weight_layout),
+            attention=self.attention.import_weights(weights["attention"], weight_layout),
+            post_attention_norm=post_attention_norm,
+            pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"], weight_layout),
+            mlp=self.mlp.import_weights(weights["mlp"], weight_layout),
+            post_mlp_norm=post_mlp_norm,
+        )
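`DecoderLayerConfig.empty()` takes the layer geometry explicitly rather than reading it from the config, mirroring how `Decoder.empty()` calls it in the previous file. Below is a minimal sketch, assuming an existing `DecoderLayerConfig` instance; the dimension values are placeholders, not package defaults.

```python
# Illustrative only: `layer_config` and the dimensions below are placeholders.
from lalamo.modules.decoder_layer import DecoderLayer, DecoderLayerConfig


def allocate_layer(layer_config: DecoderLayerConfig) -> DecoderLayer:
    # Allocates an uninitialized layer whose parameters can later be replaced
    # via import_weights(), as in the decoder-level round trip above.
    return layer_config.empty(
        model_dim=2048,
        hidden_dim=8192,
        num_heads=16,
        num_groups=4,
        head_dim=128,
        attention_scale=None,      # typed float | None in the diff above
        sliding_window_size=None,  # None is what Decoder.empty() passes for non-sliding layers
    )
```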
lalamo/modules/embedding.py
CHANGED
@@ -1,15 +1,23 @@
 from abc import abstractmethod
-from
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from typing import Self

+import equinox as eqx
 import jax
 import jax.numpy as jnp
-from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

-from lalamo.common import
+from lalamo.common import ParameterTree, dummy_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights

-from .common import
+from .common import (
+    LalamoModule,
+    WeightLayout,
+    from_layout,
+    into_layout,
+    register_config_union,
+)
 from .utils import apply_soft_capping

 __all__ = [
@@ -38,6 +46,13 @@ class EmbeddingConfigBase:
         key: PRNGKeyArray,
     ) -> "EmbeddingBase": ...

+    @abstractmethod
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "EmbeddingBase": ...
+

 class EmbeddingBase[ConfigT: EmbeddingConfigBase](LalamoModule[ConfigT]):
     @abstractmethod
@@ -54,16 +69,14 @@ class EmbeddingBase[ConfigT: EmbeddingConfigBase](LalamoModule[ConfigT]):
     @abstractmethod
     def model_dim(self) -> int: ...

-    @
-    def _default_weight_layout(cls) -> WeightLayout:
-        return WeightLayout.INPUT_OUTPUT
-
+    @eqx.filter_jit
     def embed(self, x: Int[Array, " tokens"]) -> Float[Array, "tokens channels"]:
         result = self._prepare_input_weights()[x]
         if self.config.input_scale is not None:
             result = result * jnp.array(self.config.input_scale, dtype=result.dtype)
         return result

+    @eqx.filter_jit
     def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
         logits = self._prepare_output_weights() @ x
         if self.config.logits_soft_cap is not None:
@@ -85,6 +98,14 @@ class TiedEmbeddingConfig(EmbeddingConfigBase):
         weights = jax.random.normal(key, (vocab_size, model_dim), dtype=self.precision)
         return TiedEmbedding(config=self, weights=weights)

+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "TiedEmbedding":
+        weights = dummy_array((vocab_size, model_dim), dtype=self.precision)
+        return TiedEmbedding(config=self, weights=weights)
+

 class TiedEmbedding(EmbeddingBase[TiedEmbeddingConfig]):
     weights: Float[Array, "vocabulary channels"]
@@ -115,8 +136,16 @@ class TiedEmbedding(EmbeddingBase[TiedEmbeddingConfig]):
     def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
         return self.weights

-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        return
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+        return {"weights": self.weights}
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        return replace(self, weights=weights["weights"])


 @dataclass(frozen=True)
@@ -139,6 +168,19 @@ class UntiedEmbeddingConfig(EmbeddingConfigBase):
             output_weights=output_weights,
         )

+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "UntiedEmbedding":
+        input_weights = dummy_array((vocab_size, model_dim), dtype=self.precision)
+        output_weights = dummy_array((vocab_size, model_dim), dtype=self.precision)
+        return UntiedEmbedding(
+            config=self,
+            input_weights=input_weights,
+            output_weights=output_weights,
+        )
+

 class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
     input_weights: Float[Array, "vocabulary channels"]
@@ -186,21 +228,22 @@ class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
     def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
         return self.output_weights

-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-
-
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        return {
+            "input_weights": self.input_weights,
+            "output_weights": into_layout(self.output_weights, weight_layout),
+        }

-
-
-
-
-
-
-
-
-
-            output_weights=output_weights,
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        return replace(
+            self,
+            input_weights=weights["input_weights"],
+            output_weights=from_layout(weights["output_weights"], weight_layout),
         )


@@ -225,6 +268,15 @@ class QuantizedTiedEmbeddingConfig(EmbeddingConfigBase):
         weights = quantize_weights(weights * min_abs_val, self.embedding_quantization_mode)
         return QuantizedTiedEmbedding(config=self, weights=weights, scales=scales)

+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "QuantizedTiedEmbedding":
+        scales = dummy_array(vocab_size, dtype=self.activation_precision)
+        weights = dummy_array((vocab_size, model_dim), dtype=self.activation_precision)
+        return QuantizedTiedEmbedding(config=self, weights=weights, scales=scales)
+

 class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
     weights: Float[Array, "vocabulary channels"]
@@ -257,7 +309,7 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
                 f" {self.config.activation_precision}"
                 " Quantized layers require parameter dtypes to be equal to the activation precision.",
             )
-        weights_vocab_size,
+        weights_vocab_size, _ = self.weights.shape
         (scales_vocab_size,) = self.scales.shape
         if weights_vocab_size != scales_vocab_size:
             raise ValueError(
@@ -281,15 +333,29 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
     def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
         return self._prepare_weights()

+    @eqx.filter_jit
     def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
         if self.config.activation_quantization_mode is not None:
             x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
         return super().readout(x)

-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) ->
-        return
-            weights
-            scales
+    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+        return {
+            "weights": into_layout(self.int_weights, weight_layout),
+            "scales": into_layout(self.scales, weight_layout),
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+        weight_layout: WeightLayout = WeightLayout.AUTO,
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["weights"], Array)
+        return replace(
+            self,
+            weights=from_layout(weights["weights"].astype(self.weights.dtype), weight_layout),
+            scales=from_layout(weights["scales"], weight_layout),
         )

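Each embedding config now exposes the same `empty()` hook (declared abstract on `EmbeddingConfigBase`), and the embedding modules gain `import_weights()` methods whose expected mappings mirror their `export_weights()` output. A minimal sketch for the tied case follows; the helper name `load_tied_embedding` and the idea of loading a precomputed table are assumptions.

```python
# Hypothetical sketch based on the diff above; not a documented entry point.
from jaxtyping import Array, Float

from lalamo.modules.embedding import TiedEmbedding, TiedEmbeddingConfig


def load_tied_embedding(
    config: TiedEmbeddingConfig,
    table: Float[Array, "vocabulary channels"],
) -> TiedEmbedding:
    vocab_size, model_dim = table.shape
    # empty() builds the module around dummy placeholder arrays of the right shape...
    skeleton = config.empty(vocab_size=vocab_size, model_dim=model_dim)
    # ...and import_weights() swaps in the real table; for TiedEmbedding the mapping
    # has a single "weights" entry, matching what export_weights() returns.
    return skeleton.import_weights({"weights": table})
```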
lalamo/modules/kv_cache.py
CHANGED
@@ -7,7 +7,7 @@ from jax.lax import dynamic_update_slice_in_dim
 from jax.tree_util import register_pytree_node_class
 from jaxtyping import Array, Bool, DTypeLike, Float, Int

-from lalamo.common import
+from lalamo.common import ParameterTree

 __all__ = ["DynamicKVCacheLayer", "KVCache", "KVCacheLayer", "StaticKVCacheLayer"]

@@ -43,8 +43,8 @@ class KVCacheLayer(eqx.Module):
         added_length: Int[Array, ""] | int | None = None,
     ) -> Self: ...

-    def export(self) ->
-        return
+    def export(self) -> ParameterTree:
+        return dict(
             keys=self.keys,
             values=self.values,
         )