lalamo 0.5.17__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/commands.py +69 -17
- lalamo/common.py +14 -1
- lalamo/main.py +148 -27
- lalamo/message_processor.py +4 -1
- lalamo/model_import/common.py +8 -17
- lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
- lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
- lalamo/model_import/huggingface_generation_config.py +21 -3
- lalamo/model_import/loaders/executorch.py +2 -2
- lalamo/model_import/loaders/huggingface.py +3 -3
- lalamo/model_import/model_specs/common.py +4 -2
- lalamo/model_import/model_specs/lfm2.py +41 -9
- lalamo/models/language_model.py +7 -6
- lalamo/modules/activations.py +1 -1
- lalamo/modules/classifier.py +11 -24
- lalamo/modules/common.py +4 -1
- lalamo/modules/decoder.py +5 -11
- lalamo/modules/embedding.py +25 -62
- lalamo/modules/linear.py +19 -33
- lalamo/modules/mlp.py +9 -19
- lalamo/modules/mlx_interop.py +1 -1
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +1 -1
- lalamo/modules/token_mixers/attention.py +9 -27
- lalamo/modules/token_mixers/mamba.py +9 -24
- lalamo/modules/token_mixers/short_conv.py +5 -12
- lalamo/modules/transformer.py +10 -20
- lalamo/modules/transformer_layer.py +8 -20
- lalamo/registry_abc.py +4 -4
- lalamo/sampling.py +14 -0
- lalamo/speculator/estimator.py +3 -3
- lalamo/speculator/ngram.py +1 -1
- {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -1
- {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/RECORD +40 -40
- {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
- {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
@@ -289,7 +289,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     combined_up_gate_b = jnp.concatenate([up_b + 1.0, gate_b], axis=-1)
 
     up_projection = load_parameters(
-        lambda m: (m.weights, m.biases),
+        lambda m: (m.weights, m.biases),
         module.experts.up_projection,
         (combined_up_gate_w, combined_up_gate_b),
     )
@@ -309,7 +309,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))
 
     down_projection = load_parameters(
-        lambda m: (m.weights, m.biases),
+        lambda m: (m.weights, m.biases),
         module.experts.down_projection,
         (down_w, down_b),
     )
@@ -807,7 +807,7 @@ def load_huggingface_decoder(
         weights_dict,
         decoder_path / "layers" / ((i * 2) if alternating_layers else i),
         decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-        mixer_key[type(layer.config.mixer_config)],
+        mixer_key[type(layer.config.mixer_config)],
         mlp_key,
         pre_mixer_norm_key,
         pre_mlp_norm_key,
@@ -7,13 +7,14 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum, StrEnum
 from pathlib import Path
-from typing import ClassVar, cast, get_args, get_origin
+from typing import Any, ClassVar, cast, get_args, get_origin
 
 import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
 
 from lalamo.model_import.decoder_configs import ForeignConfig
+from lalamo.models.language_model import GenerationConfig
 from lalamo.quantization import QuantizationMode
 from lalamo.safetensors import safe_read
 from lalamo.utils import MapDictValues
@@ -86,7 +87,7 @@ class ConfigMap:
     model_config: FileSpec = field(default=FileSpec("config.json"))
     tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
     tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
-    generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
+    generation_config: FileSpec | GenerationConfig | None = field(default=FileSpec("generation_config.json"))
     chat_template: FileSpec | JSONFieldSpec | str | None = None
 
 
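Note: with this change a model spec can either point `generation_config` at a JSON file in the checkpoint repo or embed sampling defaults directly. A minimal sketch of the two forms, using only names that appear in this diff:

    # Read sampling defaults from the repo's generation_config.json (previous behaviour):
    configs = ConfigMap(generation_config=FileSpec("generation_config.json"))

    # Or embed them in the spec itself, as the LFM2 entries below now do:
    configs = ConfigMap(generation_config=GenerationConfig(temperature=0.3, min_p=0.15))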
@@ -123,6 +124,7 @@ def _structure_chat_template(value: object, _type: object) -> FileSpec | JSONFie
     if isinstance(value, str):
         return value
     if isinstance(value, dict):
+        value = cast("dict[Any, Any]", value)  # ty bug??? Why is just `dict` != `dict[Any, Any]`?
         if "file_spec" in value and "field_name" in value:
             return JSONFieldSpec(
                 file_spec=FileSpec(**value["file_spec"]),
@@ -1,4 +1,7 @@
+from itertools import chain, product
+
 from lalamo.model_import.decoder_configs import HFLFM2Config
+from lalamo.models.language_model import GenerationConfig
 from lalamo.quantization import QuantizationMode
 
 from .common import ConfigMap, FileSpec, ModelSpec
@@ -6,26 +9,55 @@ from .common import ConfigMap, FileSpec, ModelSpec
 __all__ = ["LFM2_MODELS"]
 
 
-def
-
-
-
+def _lfm_repo(family: str, size: str, variant: str | None, quantization: QuantizationMode | None) -> tuple[str, str]:
+    return (
+        "LiquidAI" if quantization is None else "mlx-community",
+        f"{family}-{size}"
+        f"{f'-{variant}' if variant is not None else ''}"
+        f"{f'-{quantization.bits}bit' if quantization is not None else ''}",
+    )
 
 
-
+_LFM20_MODELS = [
     ModelSpec(
         vendor="LiquidAI",
         family="LFM2",
-        name=
+        name=_lfm_repo("LFM2", size, variant, quantization)[1],
         size=size,
-        repo="/".join(
+        repo="/".join(_lfm_repo("LFM2", size, variant, quantization)),
         config_type=HFLFM2Config,
         quantization=quantization,
         configs=ConfigMap(
+            generation_config=GenerationConfig(temperature=0.3, min_p=0.15),  # , repetition_penalty=1.05
            chat_template=FileSpec("chat_template.jinja"),
        ),
        use_cases=tuple(),
    )
-    for size
-
+    for size, variant, quantization in chain(
+        product(["350M", "700M", "1.2B"], [None], [None, QuantizationMode.UINT4, QuantizationMode.UINT8]),
+        product(["2.6B"], [None, "Exp"], [None]),
+        product(["2.6B"], ["Exp"], [QuantizationMode.UINT4, QuantizationMode.UINT8]),
+    )
 ]
+
+_LFM25_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2.5",
+        name=_lfm_repo("LFM2.5", size, variant, quantization)[1],
+        size=size,
+        repo="/".join(_lfm_repo("LFM2.5", size, variant, quantization)),
+        config_type=HFLFM2Config,
+        quantization=quantization,
+        configs=ConfigMap(
+            generation_config=GenerationConfig(temperature=0.1, top_k=50, top_p=0.1),  # , repetition_penalty=1.05
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    )
+    for size, variant, quantization in chain(
+        product(["1.2B"], ["Instruct"], [None]),
+    )
+]
+
+LFM2_MODELS = _LFM20_MODELS + _LFM25_MODELS
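Note: the `chain`/`product` comprehensions above yield one `ModelSpec` per `(size, variant, quantization)` combination. A small sketch counting them (quantization modes replaced by string stand-ins; the example repo names are an assumption based on `_lfm_repo`):

    from itertools import chain, product

    combos = list(chain(
        product(["350M", "700M", "1.2B"], [None], [None, "uint4", "uint8"]),
        product(["2.6B"], [None, "Exp"], [None]),
        product(["2.6B"], ["Exp"], ["uint4", "uint8"]),
    ))
    print(len(combos))  # 13 LFM2 entries: 9 + 2 + 2
    # Unquantized specs come from the "LiquidAI" org (e.g. "LiquidAI/LFM2-350M");
    # quantized ones from "mlx-community", presumably e.g. "mlx-community/LFM2-350M-4bit"
    # if QuantizationMode.UINT4.bits == 4.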
lalamo/models/language_model.py CHANGED

@@ -64,14 +64,15 @@ class GenerationResults(NamedTuple):
 
 @dataclass(frozen=True)
 class GenerationConfig:
-    stop_token_ids: tuple[int, ...]
-    temperature: float | None
-    top_k: int | None
-    top_p: float | None
-
+    stop_token_ids: tuple[int, ...] = tuple()
+    temperature: float | None = None
+    top_k: int | None = None
+    top_p: float | None = None
+    min_p: float | None = None
+    banned_tokens: tuple[int, ...] | None = None
 
     def default_policy(self) -> SamplingPolicy:
-        return make_policy(self.temperature, self.top_k, self.top_p, self.banned_tokens)
+        return make_policy(self.temperature, self.top_k, self.top_p, self.min_p, self.banned_tokens)
 
 
 @dataclass(frozen=True)
lalamo/modules/activations.py CHANGED

lalamo/modules/classifier.py CHANGED

@@ -9,7 +9,7 @@ from jax import numpy as jnp
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules import Activation
 from lalamo.modules.normalization import NormalizationConfig
 from lalamo.modules.transformer import (
@@ -67,7 +67,7 @@ class PredictionHeadConfig:
     def random_init(self, input_size: int, num_labels: int, key: PRNGKeyArray) -> "PredictionHead":
         dense_key, readout_key = jax.random.split(key)
         dense_layer = self.dense_config.random_init(
-            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key
+            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key,
         )
         norm = self.normalization_config.empty(input_size)
         readout = self.readout_config.random_init(
@@ -117,19 +117,13 @@ class PredictionHead(LalamoModule[PredictionHeadConfig]):
         )
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["dense"], Mapping)
-        assert isinstance(weights["norm"], Mapping)
-        assert isinstance(weights["readout"], Mapping)
         return replace(
             self,
-            dense=self.dense.import_weights(weights["dense"]),
-            norm=self.norm.import_weights(weights["norm"]),
-            readout=self.readout.import_weights(weights["readout"]),
+            dense=self.dense.import_weights(require_tree(weights["dense"])),
+            norm=self.norm.import_weights(require_tree(weights["norm"])),
+            readout=self.readout.import_weights(require_tree(weights["readout"])),
         )
 
 
@@ -321,19 +315,12 @@ class Classifier(LalamoModule[ClassifierConfig]):
         )
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["embedding_norm"], Mapping)
-        assert isinstance(weights["transformer"], Mapping)
-        assert isinstance(weights["prediction_head"], Mapping)
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]),
-            embedding_norm=self.embedding_norm.import_weights(weights["embedding_norm"]),
-            transformer=self.transformer.import_weights(weights["transformer"]),
-            prediction_head=self.prediction_head.import_weights(weights["prediction_head"]),
+            embedding=self.embedding.import_weights(require_tree(weights["embedding"])),
+            embedding_norm=self.embedding_norm.import_weights(require_tree(weights["embedding_norm"])),
+            transformer=self.transformer.import_weights(require_tree(weights["transformer"])),
+            prediction_head=self.prediction_head.import_weights(require_tree(weights["prediction_head"])),
         )
lalamo/modules/common.py CHANGED

@@ -9,15 +9,18 @@ from cattrs import Converter
 from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_array, require_tree
 
 __all__ = [
     "DummyUnionMember",
     "ForwardPassMode",
     "LalamoModule",
+    "ParameterTree",
     "PositionalEmbeddingSelector",
     "config_converter",
     "register_config_union",
+    "require_array",
+    "require_tree",
 ]
 
 
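Note: `require_array` and `require_tree` are new helpers from `lalamo/common.py` (its +14 -1 change is not shown in this excerpt); they replace the per-key `isinstance` asserts in the `import_weights` methods throughout this diff. Their real implementation is unknown here; a plausible sketch of the intent, hypothetical only:

    # Hypothetical reconstruction, not the actual lalamo/common.py code.
    from collections.abc import Mapping, Sequence
    from typing import cast

    from jax import Array

    def require_array(value: object) -> Array:
        assert isinstance(value, Array), f"expected an Array, got {type(value)!r}"
        return value

    def require_tree(value: object) -> "ParameterTree[Array]":
        assert isinstance(value, (Mapping, Sequence)), f"expected a parameter tree, got {type(value)!r}"
        return cast("ParameterTree[Array]", value)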
lalamo/modules/decoder.py CHANGED

@@ -7,7 +7,7 @@ import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 
 from .common import ForwardPassMode, LalamoModule
 from .embedding import EmbeddingBase, EmbeddingConfig
@@ -126,7 +126,7 @@ class Decoder(LalamoModule[DecoderConfig]):
         return self.embedding.activation_precision
 
     @eqx.filter_jit
-    def __call__(
+    def __call__(
         self,
         token_ids: Int[Array, "batch suffix_tokens"],
         token_positions: Int[Array, "batch suffix_tokens"],
@@ -193,16 +193,10 @@ class Decoder(LalamoModule[DecoderConfig]):
             transformer=self.transformer.export_weights(),
         )
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["transformer"], Mapping)
-
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]),
-            transformer=self.transformer.import_weights(weights["transformer"]),
+            embedding=self.embedding.import_weights(require_tree(weights["embedding"])),
+            transformer=self.transformer.import_weights(require_tree(weights["transformer"])),
         )
lalamo/modules/embedding.py CHANGED

@@ -9,7 +9,7 @@ import jax.numpy as jnp
 from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree, dummy_array
+from lalamo.common import ParameterTree, dummy_array, require_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
@@ -355,21 +355,15 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
             "scales": self.scales,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-        stored_weights = weights["weights"]
-
+        stored_weights = require_array(weights["weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
             stored_weights = jax_uint8_to_unpacked_uint4(stored_weights)
-
         return replace(
             self,
             weights=stored_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
+            scales=require_array(weights["scales"]),
         )
 
 
@@ -472,25 +466,16 @@ class MLXQuantizedTiedEmbedding(EmbeddingBase[MLXQuantizedTiedEmbeddingConfig]):
             "biases": self.biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-        assert isinstance(weights["scales"], Array)
-        assert isinstance(weights["biases"], Array)
-
-        unpacked_weights = weights["weights"]
-
+        unpacked_weights = require_array(weights["weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
-            biases=weights["biases"],
+            scales=require_array(weights["scales"]),
+            biases=require_array(weights["biases"]),
         )
 
 
@@ -630,33 +615,21 @@ class MLXQuantizedUntiedEmbedding(EmbeddingBase[MLXQuantizedUntiedEmbeddingConfi
             "output_biases": self.output_biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-
-        assert isinstance(weights["input_biases"], Array)
-        assert isinstance(weights["output_weights"], Array)
-        assert isinstance(weights["output_scales"], Array)
-        assert isinstance(weights["output_biases"], Array)
-
-        unpacked_input_weights = weights["input_weights"]
-        unpacked_output_weights = weights["output_weights"]
-
+        unpacked_input_weights = require_array(weights["input_weights"])
+        unpacked_output_weights = require_array(weights["output_weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_input_weights = jax_uint8_to_unpacked_uint4(
-            unpacked_output_weights = jax_uint8_to_unpacked_uint4(
-
+            unpacked_input_weights = jax_uint8_to_unpacked_uint4(unpacked_input_weights)
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(unpacked_output_weights)
         return replace(
             self,
             input_weights=unpacked_input_weights.astype(self.input_weights.dtype),
-            input_scales=weights["input_scales"],
-            input_biases=weights["input_biases"],
+            input_scales=require_array(weights["input_scales"]),
+            input_biases=require_array(weights["input_biases"]),
             output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
-            output_scales=weights["output_scales"],
-            output_biases=weights["output_biases"],
+            output_scales=require_array(weights["output_scales"]),
+            output_biases=require_array(weights["output_biases"]),
         )
 
 
@@ -765,27 +738,17 @@ class MLXSemiQuantizedUntiedEmbedding(EmbeddingBase[MLXSemiQuantizedUntiedEmbedd
             "output_biases": self.output_biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-        assert isinstance(weights["output_weights"], Array)
-        assert isinstance(weights["output_scales"], Array)
-        assert isinstance(weights["output_biases"], Array)
-
-        unpacked_output_weights = weights["output_weights"]
-
+        unpacked_output_weights = require_array(weights["output_weights"])
        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_output_weights = jax_uint8_to_unpacked_uint4(
-
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(unpacked_output_weights)
         return replace(
             self,
-            input_weights=weights["input_weights"],
+            input_weights=require_array(weights["input_weights"]),
             output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
-            output_scales=weights["output_scales"],
-            output_biases=weights["output_biases"],
+            output_scales=require_array(weights["output_scales"]),
+            output_biases=require_array(weights["output_biases"]),
         )
 
 
@@ -799,4 +762,4 @@ EmbeddingConfig = (
 )
 
 
-register_config_union(EmbeddingConfig)
+register_config_union(EmbeddingConfig)
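Note: the UINT4 branches above unpack checkpoints that store two 4-bit values per uint8 before casting to the module's weight dtype. A self-contained illustration of that idea (my own helper name and nibble order, not the `jax_uint8_to_unpacked_uint4` implementation):

    import jax.numpy as jnp

    def unpack_nibbles(packed: jnp.ndarray) -> jnp.ndarray:
        """Split each uint8 into two 4-bit values along the last axis."""
        low = packed & 0x0F
        high = (packed >> 4) & 0x0F
        return jnp.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)

    packed = jnp.array([[0x21, 0x43]], dtype=jnp.uint8)
    print(unpack_nibbles(packed))  # [[1 2 3 4]] under this (assumed) ordering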
lalamo/modules/linear.py CHANGED

@@ -2,7 +2,7 @@ import math
 from abc import ABC, abstractmethod
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass, replace
-from typing import Self
+from typing import Self, cast
 
 import equinox as eqx
 import jax
@@ -10,7 +10,7 @@ import jax.numpy as jnp
 from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree, dummy_array
+from lalamo.common import ParameterTree, dummy_array, require_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
@@ -464,7 +464,7 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLin
 
         return packed
 
-    def __post_init__(self) -> None:
+    def __post_init__(self) -> None:
         if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
@@ -572,26 +572,19 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLin
             result["biases"] = self.biases
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-
-        unpacked_weights = weights["weights"]
-        unpacked_zero_points = weights["zero_points"]
-
+        unpacked_weights = require_array(weights["weights"])
+        unpacked_zero_points = require_array(weights["zero_points"])
         if self.config.weight_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(
-            unpacked_zero_points = jax_uint8_to_unpacked_uint4(
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
+            unpacked_zero_points = jax_uint8_to_unpacked_uint4(unpacked_zero_points)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
+            scales=require_array(weights["scales"]),
             zero_points=unpacked_zero_points.astype(self.zero_points.dtype),
-            biases=weights["biases"] if self.has_biases else None,
+            biases=require_array(weights["biases"]) if self.has_biases else None,
         )
 
 
@@ -740,7 +733,7 @@ class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearB
 
         return packed
 
-    def __post_init__(self) -> None:
+    def __post_init__(self) -> None:
         if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
@@ -847,24 +840,17 @@ class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearB
             result["biases"] = self.biases
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-
-        unpacked_weights = weights["weights"]
-
+        unpacked_weights = require_array(weights["weights"])
         if self.config.weight_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
-            deq_biases=weights["deq_biases"],
-            biases=weights["biases"] if self.has_biases else None,
+            scales=require_array(weights["scales"]),
+            deq_biases=require_array(weights["deq_biases"]),
+            biases=require_array(weights["biases"]) if self.has_biases else None,
         )
 
 
@@ -1113,7 +1099,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
         self,
         weights: ParameterTree[Array],
     ) -> Self:
-        base = super().import_weights(weights)
+        base = cast("Self", super().import_weights(weights))  # ty bug
         assert isinstance(weights, Mapping)
         assert isinstance(weights["up_weights"], Sequence)
         return replace(
@@ -1126,4 +1112,4 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
 LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | MLXQuantizedLinearConfig | QLoRALinearConfig
 
 
-register_config_union(LinearConfig)
+register_config_union(LinearConfig)
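Note: the `cast("Self", super().import_weights(weights))` line works around the `ty` checker losing the `Self` type across `super()` (the diff marks it `# ty bug`). A minimal standalone reproduction of the same workaround pattern, with hypothetical class names:

    from dataclasses import dataclass, replace
    from typing import Self, cast

    @dataclass(frozen=True)
    class QuantizedBase:
        scales: float

        def rescaled(self, factor: float) -> Self:
            return replace(self, scales=self.scales * factor)

    @dataclass(frozen=True)
    class LoRAVariant(QuantizedBase):
        rank: int

        def rescaled(self, factor: float) -> Self:
            # If the checker types super().rescaled() as the base class rather than
            # Self, the cast restores the derived type before further replace() calls.
            base = cast("Self", super().rescaled(factor))
            return replace(base, rank=self.rank)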
lalamo/modules/mlp.py CHANGED

@@ -12,7 +12,7 @@ from einops import rearrange
 from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules.utils import vmap_twice
 
 from .activations import Activation
@@ -242,17 +242,12 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
             "down_projection": self.down_projection.export_weights(),
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["up_projection"], Mapping)
-        assert isinstance(weights["down_projection"], Mapping)
         return replace(
             self,
-            up_projection=self.up_projection.import_weights(weights["up_projection"]),
-            down_projection=self.down_projection.import_weights(weights["down_projection"]),
+            up_projection=self.up_projection.import_weights(require_tree(weights["up_projection"])),
+            down_projection=self.down_projection.import_weights(require_tree(weights["down_projection"])),
         )
 
 
@@ -285,7 +280,7 @@ class SoftmaxRouting(RoutingFunctionBase):
 RoutingFunction = SoftmaxRouting | DummyUnionMember
 
 
-register_config_union(RoutingFunction)
+register_config_union(RoutingFunction)
 
 
 @dataclass(frozen=True)
@@ -486,21 +481,16 @@ class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
             "experts": self.experts.export_weights(),
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["router"], Mapping)
-        assert isinstance(weights["experts"], Mapping)
         return replace(
             self,
-            router=self.router.import_weights(weights["router"]),
-            experts=self.experts.import_weights(weights["experts"]),
+            router=self.router.import_weights(require_tree(weights["router"])),
+            experts=self.experts.import_weights(require_tree(weights["experts"])),
         )
 
 
 MLPConfig = DenseMLPConfig | MixtureOfExpertsConfig
 
 
-register_config_union(MLPConfig)
+register_config_union(MLPConfig)
lalamo/modules/mlx_interop.py CHANGED

lalamo/modules/rope.py CHANGED