lalamo 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +26 -2
- lalamo/commands.py +429 -0
- lalamo/common.py +14 -1
- lalamo/main.py +375 -229
- lalamo/message_processor.py +4 -1
- lalamo/model_import/common.py +8 -17
- lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
- lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
- lalamo/model_import/huggingface_generation_config.py +21 -3
- lalamo/model_import/loaders/executorch.py +2 -2
- lalamo/model_import/loaders/huggingface.py +3 -3
- lalamo/model_import/model_specs/common.py +8 -4
- lalamo/model_import/model_specs/lfm2.py +41 -9
- lalamo/models/common.py +3 -3
- lalamo/models/language_model.py +7 -6
- lalamo/modules/activations.py +1 -1
- lalamo/modules/classifier.py +11 -24
- lalamo/modules/common.py +4 -1
- lalamo/modules/decoder.py +5 -11
- lalamo/modules/embedding.py +25 -62
- lalamo/modules/linear.py +19 -33
- lalamo/modules/mlp.py +9 -19
- lalamo/modules/mlx_interop.py +1 -1
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +1 -1
- lalamo/modules/token_mixers/attention.py +9 -27
- lalamo/modules/token_mixers/mamba.py +9 -24
- lalamo/modules/token_mixers/short_conv.py +5 -12
- lalamo/modules/transformer.py +10 -20
- lalamo/modules/transformer_layer.py +8 -20
- lalamo/registry_abc.py +4 -4
- lalamo/safetensors.py +97 -0
- lalamo/sampling.py +14 -0
- lalamo/speculator/estimator.py +11 -4
- lalamo/speculator/ngram.py +1 -1
- lalamo/utils.py +0 -13
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -2
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/RECORD +43 -41
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/message_processor.py
CHANGED

@@ -169,7 +169,10 @@ class MessageProcessor:
     def __post_init__(self) -> None:
         if self.output_parser_regex is not None:
            all_fields = AssistantMessage.__dataclass_fields__
-
+            # NOTE: str type annotations are assumed to be required
+            required_fields = {
+                k: v for k, v in all_fields.items() if isinstance(v.type, str) or v.type == (v.type | None)
+            }
            named_groups = self.output_parser_regex.groupindex
            invalid_groups = set(named_groups) - set(all_fields)
            if invalid_groups:
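Note: the new `required_fields` filter leans on union absorption: for a real (non-string) annotation `t`, `t == t | None` holds exactly when `t` already allows `None`, since unions absorb duplicate members. A runnable sketch of that check, with a hypothetical `Example` dataclass standing in for `AssistantMessage`:

from dataclasses import dataclass, fields

@dataclass
class Example:  # hypothetical stand-in for AssistantMessage
    content: str
    reasoning: str | None = None

for f in fields(Example):
    # String annotations (e.g. under `from __future__ import annotations`)
    # cannot be inspected this way, which is why they are assumed required.
    allows_none = not isinstance(f.type, str) and f.type == (f.type | None)
    print(f.name, allows_none)  # content False, reasoning True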
lalamo/model_import/common.py
CHANGED

@@ -3,7 +3,7 @@ import json
 from collections import ChainMap
 from collections.abc import Callable
 from contextlib import ExitStack
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from pathlib import Path
 from typing import NamedTuple
 
@@ -20,7 +20,7 @@ from lalamo.quantization import QuantizationMode
 from lalamo.utils import process_chat_template
 
 from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
-from .huggingface_generation_config import HFGenerationConfig
+from .huggingface_generation_config import HFGenerationConfig, _policy_from_hf_config
 from .huggingface_tokenizer_config import HFTokenizerConfig
 from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, ModelType, UseCase
 from .model_specs.common import JSONFieldSpec
 
@@ -34,6 +34,7 @@ __all__ = [
     "ModelSpec",
     "ModelType",
     "StatusEvent",
+    "download_file",
     "import_model",
 ]
 
@@ -239,24 +240,14 @@ def _import_language_model(
 
     stop_token_ids = tuple(foreign_decoder_config.eos_token_ids)
 
-    if model_spec.configs.generation_config:
+    if isinstance(model_spec.configs.generation_config, GenerationConfig):
+        generation_config = replace(model_spec.configs.generation_config, stop_token_ids=stop_token_ids)
+    elif isinstance(model_spec.configs.generation_config, FileSpec):
        hf_generation_config_file = download_file(model_spec.configs.generation_config, model_spec.repo)
        hf_generation_config = HFGenerationConfig.from_json(hf_generation_config_file)
-        generation_config = GenerationConfig(
-            stop_token_ids=stop_token_ids,
-            temperature=hf_generation_config.temperature,
-            top_p=hf_generation_config.top_p,
-            top_k=hf_generation_config.top_k,
-            banned_tokens=None,
-        )
+        generation_config = _policy_from_hf_config(hf_generation_config, stop_token_ids)
     else:
-        generation_config = GenerationConfig(
-            stop_token_ids=stop_token_ids,
-            temperature=None,
-            top_p=None,
-            top_k=None,
-            banned_tokens=None,
-        )
+        generation_config = GenerationConfig(stop_token_ids)
 
 
     language_model_config = LanguageModelConfig(
         model_config=decoder.config,
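The generation config in a ModelSpec can now be an inline GenerationConfig, a FileSpec pointing at a generation_config.json, or absent. A sketch of the dispatch above, factored into a standalone helper; `resolve_generation_config` and the `download` parameter are illustrative and not part of the package:

from dataclasses import replace

from lalamo.model_import.huggingface_generation_config import HFGenerationConfig, _policy_from_hf_config
from lalamo.model_import.model_specs import FileSpec
from lalamo.models.language_model import GenerationConfig


def resolve_generation_config(spec_value, repo, stop_token_ids, download):
    if isinstance(spec_value, GenerationConfig):
        # Inline config: reuse it, patching in the stop tokens derived
        # from the decoder config.
        return replace(spec_value, stop_token_ids=stop_token_ids)
    if isinstance(spec_value, FileSpec):
        # File reference: fetch generation_config.json and translate it.
        hf_config = HFGenerationConfig.from_json(download(spec_value, repo))
        return _policy_from_hf_config(hf_config, stop_token_ids)
    # No config at all: keep only the stop tokens.
    return GenerationConfig(stop_token_ids)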
lalamo/model_import/decoder_configs/huggingface/lfm2.py
CHANGED

@@ -2,6 +2,7 @@ from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
+import jax.numpy as jnp
 from jaxtyping import DTypeLike
 
 from lalamo.modules import (
 
@@ -50,7 +51,6 @@ class HFLFM2Config(HuggingFaceLMConfig):
     conv_L_cache: int  # noqa: N815
     conv_bias: bool
     conv_dim: int
-    conv_dim_out: int
     conv_use_xavier_init: bool
     eos_token_id: int
     hidden_size: int
 
@@ -64,13 +64,15 @@ class HFLFM2Config(HuggingFaceLMConfig):
     num_key_value_heads: int
     pad_token_id: int
     rope_theta: float
-    torch_dtype: Literal["bfloat16"]
     transformers_version: str
     use_cache: bool
     use_pos_enc: bool
     vocab_size: int
 
+    dtype: Literal["bfloat16", "float16", "float32"] | None = None
+    torch_dtype: Literal["bfloat16", "float16", "float32"] | None = None
     intermediate_size: int | None = None
+    conv_dim_out: int | None = None
     layer_types: list[Literal["conv", "full_attention"]] | None = None
     full_attn_idxs: list[int] | None = None
     tie_embedding: bool = True
 
@@ -79,6 +81,14 @@ class HFLFM2Config(HuggingFaceLMConfig):
     quantization: QuantizationConfig | None = None
     quantization_config: QuantizationConfig | None = None
 
+    @property
+    def default_precision(self) -> DTypeLike:
+        assert self.dtype is not None or self.torch_dtype is not None, (
+            "at least one of dtype or torch_dtype must be specified"
+        )
+
+        return jnp.dtype(self.dtype or self.torch_dtype)
+
     def to_decoder_config(
         self,
         context_length: int | None,
 
@@ -200,8 +210,8 @@ class HFLFM2Config(HuggingFaceLMConfig):
             subtract_mean=False,
         )
 
-        if self.intermediate_size is not None:
-            hidden_dim = self.intermediate_size
+        if not self.block_auto_adjust_ff_dim:
+            hidden_dim = self.intermediate_size or self.block_ff_dim
         else:
            hidden_dim_adjusted = self.block_ff_dim * self.block_ffn_dim_multiplier * (2 / 3)
            hidden_dim = int(
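Newer Hugging Face configs ship `dtype` while older ones ship `torch_dtype`; `default_precision` accepts either and prefers `dtype`. A standalone restatement of the fallback (the bare function here is illustrative, the real code is the property above):

import jax.numpy as jnp

def default_precision(dtype: str | None, torch_dtype: str | None):
    assert dtype is not None or torch_dtype is not None
    return jnp.dtype(dtype or torch_dtype)

print(default_precision(None, "bfloat16"))       # bfloat16 (legacy torch_dtype)
print(default_precision("float32", "bfloat16"))  # float32 (dtype wins)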
lalamo/model_import/decoder_configs/huggingface/llamba.py
CHANGED

@@ -76,7 +76,7 @@ class HFLlambaConfig(HuggingFaceLMConfig):
             logit_soft_cap=None,
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
             embedding_quantization_mode=QuantizationMode.from_num_bits(
-                int(metadata_dict["quantization_kwargs.bits"])
+                int(metadata_dict["quantization_kwargs.bits"]),
             ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
 
@@ -107,7 +107,7 @@ class HFLlambaConfig(HuggingFaceLMConfig):
         linear_config = MLXQuantizedLinearConfig(
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
             weight_quantization_mode=QuantizationMode.from_num_bits(
-                int(metadata_dict["quantization_kwargs.bits"])
+                int(metadata_dict["quantization_kwargs.bits"]),
             ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
lalamo/model_import/decoder_configs/huggingface/modern_bert.py
CHANGED

@@ -41,7 +41,7 @@ def activation_from_str(activation: str) -> type[Activation]:
         return supported_activations[activation]
 
     raise ValueError(
-        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}"
+        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}",
     )
 
 
@@ -97,7 +97,7 @@ class ModernBERTConfig(HuggingFaceClassifierConfig):
         result = [None] * num_layers
         for index in range(len(result)):
             if index % global_attn_every_n_layers != 0:
-                result[index] = self.local_attention
+                result[index] = self.local_attention
             else:
                 pass
         return tuple(result)
lalamo/model_import/huggingface_generation_config.py
CHANGED

@@ -5,7 +5,9 @@ from typing import ClassVar
 
 import cattrs
 
-
+from lalamo.models import GenerationConfig
+
+__all__ = ["HFGenerationConfig", "_policy_from_hf_config"]
 
 
 @dataclass(frozen=True)
 
@@ -27,10 +29,11 @@ class HFGenerationConfig:
     cache_implementation: str | None = None  # “hybrid” for Gemma 3/2
 
     # -------- sampling strategy -------------
-    do_sample: bool | None = None
+    do_sample: bool | None = False
     temperature: float | None = None
+    min_p: float | None = None
     top_p: float | None = None
-    top_k: int | None = None
+    top_k: int | None = 50
     repetition_penalty: float | None = None
 
     # -------- length limits -----------------
 
@@ -42,3 +45,18 @@ class HFGenerationConfig:
         with open(json_path) as f:
             config = json.load(f)
         return cls._converter.structure(config, cls)
+
+
+def _policy_from_hf_config(
+    hf_config: HFGenerationConfig,
+    stop_token_ids: tuple[int, ...] = (),
+    banned_tokens: tuple[int, ...] | None = None,
+) -> GenerationConfig:
+    return GenerationConfig(
+        stop_token_ids=stop_token_ids,
+        temperature=hf_config.temperature,
+        top_k=hf_config.top_k,
+        top_p=hf_config.top_p,
+        min_p=hf_config.min_p,
+        banned_tokens=banned_tokens,
+    )
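`_policy_from_hf_config` is a thin field-by-field translation. Assuming the HFGenerationConfig fields not shown in this diff are also defaulted, a config carrying only sampling knobs maps like this (a sketch, not package test code):

from lalamo.model_import.huggingface_generation_config import (
    HFGenerationConfig,
    _policy_from_hf_config,
)

hf = HFGenerationConfig(temperature=0.3, min_p=0.15)
policy = _policy_from_hf_config(hf, stop_token_ids=(2,))
# policy.temperature == 0.3, policy.min_p == 0.15,
# policy.top_k == 50 (the new HF-style default shown above)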
lalamo/model_import/loaders/huggingface.py
CHANGED

@@ -97,7 +97,7 @@ def load_mlp(module: DenseMLP, weights_dict: Mapping[str, Array], path: Paramete
     fused_up_gate_params = merge_linear_params([up_proj_params, gate_proj_params])
 
     return load_parameters(
-        lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),
+        lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),
         module,
         (*fused_up_gate_params, *down_proj_params),
     )
 
@@ -177,7 +177,7 @@ def load_attention(
 
     qkv_params = merge_linear_params([q_params, k_params, v_params])
     return load_parameters(
-        lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),
+        lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),
         module,
         (*qkv_params, *out_params),
     )
 
@@ -289,7 +289,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     combined_up_gate_b = jnp.concatenate([up_b + 1.0, gate_b], axis=-1)
 
     up_projection = load_parameters(
-        lambda m: (m.weights, m.biases),
+        lambda m: (m.weights, m.biases),
         module.experts.up_projection,
         (combined_up_gate_w, combined_up_gate_b),
     )
 
@@ -309,7 +309,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))
 
     down_projection = load_parameters(
-        lambda m: (m.weights, m.biases),
+        lambda m: (m.weights, m.biases),
         module.experts.down_projection,
         (down_w, down_b),
     )
 
@@ -807,7 +807,7 @@ def load_huggingface_decoder(
         weights_dict,
         decoder_path / "layers" / ((i * 2) if alternating_layers else i),
         decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-        mixer_key[type(layer.config.mixer_config)],
+        mixer_key[type(layer.config.mixer_config)],
         mlp_key,
         pre_mixer_norm_key,
         pre_mlp_norm_key,
lalamo/model_import/model_specs/common.py
CHANGED

@@ -7,15 +7,17 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum, StrEnum
 from pathlib import Path
-from typing import ClassVar, cast, get_args, get_origin
+from typing import Any, ClassVar, cast, get_args, get_origin
 
 import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
 
 from lalamo.model_import.decoder_configs import ForeignConfig
+from lalamo.models.language_model import GenerationConfig
 from lalamo.quantization import QuantizationMode
-from lalamo.
+from lalamo.safetensors import safe_read
+from lalamo.utils import MapDictValues
 
 __all__ = [
     "ConfigMap",
 
@@ -52,7 +54,8 @@ class WeightsType(Enum):
         float_dtype: DTypeLike,
     ) -> Iterator[tuple[Mapping[str, jnp.ndarray], Mapping[str, str]]]:
         if self == WeightsType.SAFETENSORS:
-            with
+            with Path(filename).open("rb") as fd:
+                (metadata_dict, weights_dict) = safe_read(fd)
            yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
 
@@ -84,7 +87,7 @@ class ConfigMap:
     model_config: FileSpec = field(default=FileSpec("config.json"))
     tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
     tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
-    generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
+    generation_config: FileSpec | GenerationConfig | None = field(default=FileSpec("generation_config.json"))
     chat_template: FileSpec | JSONFieldSpec | str | None = None
 
 
@@ -121,6 +124,7 @@ def _structure_chat_template(value: object, _type: object) -> FileSpec | JSONFie
     if isinstance(value, str):
         return value
     if isinstance(value, dict):
+        value = cast("dict[Any, Any]", value)  # ty bug??? Why is just `dict` != `dict[Any, Any]`?
         if "file_spec" in value and "field_name" in value:
             return JSONFieldSpec(
                 file_spec=FileSpec(**value["file_spec"]),
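Safetensors loading now goes through the new lalamo.safetensors module (the removed lalamo.utils helper is gone). A minimal sketch of the read path, assuming a local model.safetensors; both call sites in this diff show safe_read returning the header metadata first, then the name-to-array mapping:

from pathlib import Path

from lalamo.safetensors import safe_read

with Path("model.safetensors").open("rb") as fd:
    metadata, weights = safe_read(fd)

print(metadata)             # header metadata, possibly empty
print(sorted(weights)[:5])  # first few parameter names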
lalamo/model_import/model_specs/lfm2.py
CHANGED

@@ -1,4 +1,7 @@
+from itertools import chain, product
+
 from lalamo.model_import.decoder_configs import HFLFM2Config
+from lalamo.models.language_model import GenerationConfig
 from lalamo.quantization import QuantizationMode
 
 from .common import ConfigMap, FileSpec, ModelSpec
 
@@ -6,26 +9,55 @@ from .common import ConfigMap, FileSpec, ModelSpec
 __all__ = ["LFM2_MODELS"]
 
 
-def
-
-
-
+def _lfm_repo(family: str, size: str, variant: str | None, quantization: QuantizationMode | None) -> tuple[str, str]:
+    return (
+        "LiquidAI" if quantization is None else "mlx-community",
+        f"{family}-{size}"
+        f"{f'-{variant}' if variant is not None else ''}"
+        f"{f'-{quantization.bits}bit' if quantization is not None else ''}",
+    )
 
 
-LFM2_MODELS = [
+_LFM20_MODELS = [
     ModelSpec(
         vendor="LiquidAI",
         family="LFM2",
-        name=
+        name=_lfm_repo("LFM2", size, variant, quantization)[1],
         size=size,
-        repo="/".join(
+        repo="/".join(_lfm_repo("LFM2", size, variant, quantization)),
         config_type=HFLFM2Config,
         quantization=quantization,
         configs=ConfigMap(
+            generation_config=GenerationConfig(temperature=0.3, min_p=0.15),  # , repetition_penalty=1.05
            chat_template=FileSpec("chat_template.jinja"),
        ),
        use_cases=tuple(),
    )
-    for size
-
+    for size, variant, quantization in chain(
+        product(["350M", "700M", "1.2B"], [None], [None, QuantizationMode.UINT4, QuantizationMode.UINT8]),
+        product(["2.6B"], [None, "Exp"], [None]),
+        product(["2.6B"], ["Exp"], [QuantizationMode.UINT4, QuantizationMode.UINT8]),
+    )
 ]
+
+_LFM25_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2.5",
+        name=_lfm_repo("LFM2.5", size, variant, quantization)[1],
+        size=size,
+        repo="/".join(_lfm_repo("LFM2.5", size, variant, quantization)),
+        config_type=HFLFM2Config,
+        quantization=quantization,
+        configs=ConfigMap(
+            generation_config=GenerationConfig(temperature=0.1, top_k=50, top_p=0.1),  # , repetition_penalty=1.05
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    )
+    for size, variant, quantization in chain(
+        product(["1.2B"], ["Instruct"], [None]),
+    )
+]
+
+LFM2_MODELS = _LFM20_MODELS + _LFM25_MODELS
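The `_lfm_repo` helper encodes the repo naming convention in one place: unquantized weights live under the LiquidAI org, MLX-quantized ones under mlx-community with an `-Nbit` suffix. Assuming `QuantizationMode.UINT4.bits == 4` (as the f-string above suggests) and that the private helper is importable, it behaves like:

from lalamo.model_import.model_specs.lfm2 import _lfm_repo
from lalamo.quantization import QuantizationMode

assert _lfm_repo("LFM2", "700M", None, None) == ("LiquidAI", "LFM2-700M")
assert _lfm_repo("LFM2", "2.6B", "Exp", QuantizationMode.UINT4) == (
    "mlx-community",
    "LFM2-2.6B-Exp-4bit",
)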
lalamo/models/common.py
CHANGED

@@ -15,7 +15,7 @@ from lalamo.message_processor import Message, MessageProcessor, MessageProcessor
 from lalamo.modules import Classifier, Decoder, LalamoModule, config_converter
 from lalamo.modules.classifier import ClassifierConfig, ClassifierResult
 from lalamo.modules.decoder import DecoderConfig, DecoderResult
-from lalamo.
+from lalamo.safetensors import safe_read
 
 __all__ = [
     "TextModel",
 
@@ -42,8 +42,8 @@ class TextModelConfig[ConfigT: ClassifierConfig | DecoderConfig](ABC):
         with open(path / "config.json") as config_file:
             config_json = json.load(config_file)
         config = config_converter.structure(config_json["model_config"], cls)
-        with
-
+        with Path(path / "model.safetensors").open("rb") as fd:
+            _, weights_dict = safe_read(fd)
         weights = unflatten_parameters(weights_dict)
         model = config.model_config.empty().import_weights(weights)
         tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
lalamo/models/language_model.py
CHANGED

@@ -64,14 +64,15 @@ class GenerationResults(NamedTuple):
 
 @dataclass(frozen=True)
 class GenerationConfig:
-    stop_token_ids: tuple[int, ...]
-    temperature: float | None
-    top_k: int | None
-    top_p: float | None
-    banned_tokens: tuple[int, ...] | None
+    stop_token_ids: tuple[int, ...] = tuple()
+    temperature: float | None = None
+    top_k: int | None = None
+    top_p: float | None = None
+    min_p: float | None = None
+    banned_tokens: tuple[int, ...] | None = None
 
     def default_policy(self) -> SamplingPolicy:
-        return make_policy(self.temperature, self.top_k, self.top_p, self.banned_tokens)
+        return make_policy(self.temperature, self.top_k, self.top_p, self.min_p, self.banned_tokens)
 
 
 @dataclass(frozen=True)
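With every field defaulted, a GenerationConfig can now be constructed inline (as the LFM2 model specs above do) with only the knobs that matter, and the new `min_p` threads through `default_policy` into `make_policy`. A small sketch, using the import path shown in this diff:

from lalamo.models.language_model import GenerationConfig

config = GenerationConfig(temperature=0.3, min_p=0.15)  # stop tokens default to ()
policy = config.default_policy()  # make_policy(0.3, None, None, 0.15, None)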
lalamo/modules/activations.py
CHANGED
lalamo/modules/classifier.py
CHANGED

@@ -9,7 +9,7 @@ from jax import numpy as jnp
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules import Activation
 from lalamo.modules.normalization import NormalizationConfig
 from lalamo.modules.transformer import (
 
@@ -67,7 +67,7 @@ class PredictionHeadConfig:
     def random_init(self, input_size: int, num_labels: int, key: PRNGKeyArray) -> "PredictionHead":
         dense_key, readout_key = jax.random.split(key)
         dense_layer = self.dense_config.random_init(
-            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key
+            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key,
         )
         norm = self.normalization_config.empty(input_size)
         readout = self.readout_config.random_init(
 
@@ -117,19 +117,13 @@ class PredictionHead(LalamoModule[PredictionHeadConfig]):
         )
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["dense"], Mapping)
-        assert isinstance(weights["norm"], Mapping)
-        assert isinstance(weights["readout"], Mapping)
         return replace(
             self,
-            dense=self.dense.import_weights(weights["dense"]),
-            norm=self.norm.import_weights(weights["norm"]),
-            readout=self.readout.import_weights(weights["readout"]),
+            dense=self.dense.import_weights(require_tree(weights["dense"])),
+            norm=self.norm.import_weights(require_tree(weights["norm"])),
+            readout=self.readout.import_weights(require_tree(weights["readout"])),
         )
 
 
@@ -321,19 +315,12 @@ class Classifier(LalamoModule[ClassifierConfig]):
         )
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["embedding_norm"], Mapping)
-        assert isinstance(weights["transformer"], Mapping)
-        assert isinstance(weights["prediction_head"], Mapping)
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]),
-            embedding_norm=self.embedding_norm.import_weights(weights["embedding_norm"]),
-            transformer=self.transformer.import_weights(weights["transformer"]),
-            prediction_head=self.prediction_head.import_weights(weights["prediction_head"]),
+            embedding=self.embedding.import_weights(require_tree(weights["embedding"])),
+            embedding_norm=self.embedding_norm.import_weights(require_tree(weights["embedding_norm"])),
+            transformer=self.transformer.import_weights(require_tree(weights["transformer"])),
+            prediction_head=self.prediction_head.import_weights(require_tree(weights["prediction_head"])),
         )
lalamo/modules/common.py
CHANGED

@@ -9,15 +9,18 @@ from cattrs import Converter
 from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_array, require_tree
 
 __all__ = [
     "DummyUnionMember",
     "ForwardPassMode",
     "LalamoModule",
+    "ParameterTree",
     "PositionalEmbeddingSelector",
     "config_converter",
     "register_config_union",
+    "require_array",
+    "require_tree",
 ]
 
 
lalamo/modules/decoder.py
CHANGED

@@ -7,7 +7,7 @@ import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 
 from .common import ForwardPassMode, LalamoModule
 from .embedding import EmbeddingBase, EmbeddingConfig
 
@@ -126,7 +126,7 @@ class Decoder(LalamoModule[DecoderConfig]):
         return self.embedding.activation_precision
 
     @eqx.filter_jit
-    def __call__(
+    def __call__(
         self,
         token_ids: Int[Array, "batch suffix_tokens"],
         token_positions: Int[Array, "batch suffix_tokens"],
 
@@ -193,16 +193,10 @@ class Decoder(LalamoModule[DecoderConfig]):
             transformer=self.transformer.export_weights(),
         )
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["transformer"], Mapping)
-
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]),
-            transformer=self.transformer.import_weights(weights["transformer"]),
+            embedding=self.embedding.import_weights(require_tree(weights["embedding"])),
+            transformer=self.transformer.import_weights(require_tree(weights["transformer"])),
         )