lalamo 0.5.17__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.5.17 → lalamo-0.6.0}/PKG-INFO +1 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/__init__.py +1 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/commands.py +69 -17
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/common.py +14 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/main.py +148 -27
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/message_processor.py +4 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/common.py +8 -17
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/huggingface_generation_config.py +21 -3
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/loaders/executorch.py +2 -2
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/loaders/huggingface.py +3 -3
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/common.py +4 -2
- lalamo-0.6.0/lalamo/model_import/model_specs/lfm2.py +63 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/models/language_model.py +7 -6
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/activations.py +1 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/classifier.py +11 -24
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/common.py +4 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/decoder.py +5 -11
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/embedding.py +25 -62
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/linear.py +19 -33
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/mlp.py +9 -19
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/mlx_interop.py +1 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/rope.py +1 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/__init__.py +1 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/attention.py +9 -27
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/mamba.py +9 -24
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/short_conv.py +5 -12
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/transformer.py +10 -20
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/transformer_layer.py +8 -20
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/registry_abc.py +4 -4
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/sampling.py +14 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/speculator/estimator.py +3 -3
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/speculator/ngram.py +1 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo.egg-info/PKG-INFO +1 -1
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo.egg-info/SOURCES.txt +1 -13
- {lalamo-0.5.17 → lalamo-0.6.0}/pyproject.toml +8 -0
- lalamo-0.5.17/lalamo/model_import/model_specs/lfm2.py +0 -31
- lalamo-0.5.17/tests/test_cartesia_mlx_models.py +0 -22
- lalamo-0.5.17/tests/test_chat_template.py +0 -173
- lalamo-0.5.17/tests/test_generation.py +0 -198
- lalamo-0.5.17/tests/test_huggingface_model_conversion.py +0 -109
- lalamo-0.5.17/tests/test_huggingface_models.py +0 -41
- lalamo-0.5.17/tests/test_lfm2_models.py +0 -13
- lalamo-0.5.17/tests/test_mlx_models.py +0 -20
- lalamo-0.5.17/tests/test_model_spec.py +0 -52
- lalamo-0.5.17/tests/test_models.py +0 -497
- lalamo-0.5.17/tests/test_moe.py +0 -58
- lalamo-0.5.17/tests/test_parameter_tree.py +0 -103
- lalamo-0.5.17/tests/test_registry_abc.py +0 -147
- {lalamo-0.5.17 → lalamo-0.6.0}/LICENSE +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/README.md +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/data/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/data/huggingface_message.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/data/utils.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/common.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/executorch.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/essential_ai.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/huggingface.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/llamba.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/mirai.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/models/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/models/classifier.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/models/common.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/normalization.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/common.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/state/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/state/common.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/modules/utils.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/quantization.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/safetensors.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/speculator/__init__.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/speculator/common.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/speculator/inference.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/speculator/utils.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo/utils.py +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.5.17 → lalamo-0.6.0}/setup.cfg +0 -0
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/commands.py
RENAMED
@@ -1,5 +1,5 @@
 import json
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import Enum
 from itertools import chain
@@ -10,7 +10,7 @@ from jaxtyping import DTypeLike
 from lalamo.common import flatten_parameters
 from lalamo.data import import_hf_parquet
 from lalamo.data.lalamo_completions import LalamoCompletion
-from lalamo.message_processor import
+from lalamo.message_processor import Message
 from lalamo.model_import import ModelMetadata, ModelSpec, import_model
 from lalamo.model_import.common import (
     DownloadingFileEvent,
@@ -41,8 +41,6 @@ class ConversionCallbacks:
     output_dir: Path
     precision: Precision | None
     context_length: int | None
-    include_traces: bool
-    message_for_trace: str | None
 
     def started(self) -> None:
         pass
@@ -74,16 +72,12 @@ def convert(
     output_dir: Path,
     precision: Precision | None = None,
     context_length: int | None = None,
-    include_traces: bool = False,
-    message_for_trace: str | None = None,
     callbacks_type: Callable[
         [
             ModelSpec,
             Path,
             Precision | None,
             int | None,
-            bool,
-            str | None,
         ],
         ConversionCallbacks,
     ] = ConversionCallbacks,
@@ -93,8 +87,6 @@ def convert(
         output_dir,
         precision,
         context_length,
-        include_traces,
-        message_for_trace,
     )
 
     if precision is not None:
@@ -127,13 +119,6 @@ def convert(
     callbacks.saving_model()
     output_dir.mkdir(parents=True, exist_ok=True)
 
-    if include_traces:
-        message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
-        result = model.record_trace(message)
-        traces = flatten_parameters(result.export())
-        with Path(output_dir / "traces.safetensors").open("wb") as fd:
-            safe_write(fd, traces)
-
     model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
     weights = flatten_parameters(model.export_weights())
     del model
@@ -148,6 +133,73 @@ def convert(
     callbacks.finished_saving_model()
 
 
+@dataclass
+class TraceCallbacks:
+    model_path: Path
+    output_path: Path
+    messages: Iterable[Message] | None
+
+    def output_exists(self) -> None:
+        raise RuntimeError(f"{self.output_path=} already exists, refusing to overwrite!")
+
+    def started(self) -> None:
+        pass
+
+    def loading_model(self) -> None:
+        pass
+
+    def finished_loading_model(self) -> None:
+        pass
+
+    def tracing_model(self) -> None:
+        pass
+
+    def finished_tracing_model(self) -> None:
+        pass
+
+    def saving_trace(self) -> None:
+        pass
+
+    def finished_saving_trace(self) -> None:
+        pass
+
+
+def trace(
+    model_path: Path,
+    output_path: Path,
+    messages: Iterable[Message] | None = None,
+    callbacks_type: Callable[
+        [
+            Path,
+            Path,
+            Iterable[Message] | None,
+        ],
+        TraceCallbacks,
+    ] = TraceCallbacks,
+) -> None:
+    callbacks = callbacks_type(model_path, output_path, messages)
+
+    if output_path.exists():
+        callbacks.output_exists()
+
+    callbacks.started()
+
+    callbacks.loading_model()
+    model = LanguageModelConfig.load_model(model_path)
+    callbacks.finished_loading_model()
+
+    callbacks.tracing_model()
+    result = model.record_trace(messages)
+    callbacks.finished_tracing_model()
+
+    callbacks.saving_trace()
+    traces = flatten_parameters(result.export())
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with Path(output_path).open("wb") as fd:
+        safe_write(fd, traces)
+    callbacks.finished_saving_trace()
+
+
 @dataclass
 class EstimateBatchsizeCallbacks:
     model_path: Path
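Side note (not part of the diff): the trace logic removed from `convert` now lives in a standalone `trace()` entry point with its own `TraceCallbacks` hooks. A minimal sketch of driving it programmatically, assuming a previously converted model directory (the paths below are hypothetical):

```python
from pathlib import Path

from lalamo.commands import TraceCallbacks, trace
from lalamo.message_processor import UserMessage


class LoggingTraceCallbacks(TraceCallbacks):
    # Hook names mirror the TraceCallbacks protocol added in this release.
    def tracing_model(self) -> None:
        print(f"recording trace for {self.model_path}")

    def finished_saving_trace(self) -> None:
        print(f"trace written to {self.output_path}")


# "models/my-model" stands in for a directory produced by `convert`.
trace(
    Path("models/my-model"),
    Path("models/my-model/traces.safetensors"),
    messages=[UserMessage(content="Hello!")],
    callbacks_type=LoggingTraceCallbacks,
)
```

The CLI wires this same function up through `CliTraceCallbacks` in `main.py` (see the `main.py` hunks below).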
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/common.py
RENAMED
@@ -15,6 +15,8 @@ __all__ = [
     "ParameterTree",
     "dummy_array",
     "flatten_parameters",
+    "require_array",
+    "require_tree",
     "unflatten_parameters",
 ]
 
@@ -29,6 +31,16 @@ type ParameterTree[ArrayType: ArrayLike] = (
 )
 
 
+def require_array[ArrayType: ArrayLike](value: ArrayType | ParameterTree[ArrayType]) -> ArrayType:
+    assert not isinstance(value, (Mapping, Sequence))
+    return value
+
+
+def require_tree[ArrayType: ArrayLike](value: ArrayType | ParameterTree[ArrayType]) -> ParameterTree[ArrayType]:
+    assert not isinstance(value, (Array, ShapeDtypeStruct))
+    return value
+
+
 def dummy_array(shape: int | tuple[int, ...], dtype: DTypeLike) -> Array:
     if isinstance(shape, int):
         shape = (shape,)
@@ -40,9 +52,10 @@ def flatten_parameters[ArrayType: ArrayLike](nested_parameters: ParameterTree[Ar
     if not isinstance(nested_parameters, Mapping):
         nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
     for key, value in nested_parameters.items():
+        value = cast("ArrayType | ParameterTree[ArrayType]", value)
         key_path = ParameterPath(key)
         if isinstance(value, (Array, ShapeDtypeStruct)):
-            result[key_path] = value
+            result[key_path] = cast("ArrayType", value)
         else:
            update: dict[str, ArrayType] = {
                str(key_path / subkey): subvalue for subkey, subvalue in flatten_parameters(value).items()
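Side note (not part of the diff): `require_array` and `require_tree` are small runtime narrowing helpers for the `ParameterTree` union; each one asserts which side of `ArrayType | ParameterTree[ArrayType]` a value is on and returns it unchanged. A minimal usage sketch:

```python
import jax.numpy as jnp

from lalamo.common import flatten_parameters, require_array, require_tree

# A nested parameter tree: mappings (or sequences) of arrays.
params = {"mlp": {"up": jnp.zeros((4, 8)), "down": jnp.zeros((8, 4))}}

subtree = require_tree(params["mlp"])  # asserts this is a nested tree, not a bare array
leaf = require_array(subtree["up"])    # asserts this is a leaf array, not another tree

flat = flatten_parameters(params)      # single flat mapping from parameter paths to arrays
```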
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/main.py
RENAMED
@@ -36,11 +36,13 @@ from lalamo.commands import (
     ConversionCallbacks,
     EstimateBatchsizeCallbacks,
     Precision,
+    TraceCallbacks,
     TrainCallbacks,
 )
 from lalamo.commands import collect_traces as _collect_traces
 from lalamo.commands import convert as _convert
 from lalamo.commands import estimate_batchsize as _estimate_batchsize
+from lalamo.commands import trace as _trace
 from lalamo.commands import train as _train
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.message_processor import UserMessage
@@ -83,7 +85,7 @@ class ModelParser(ParamType):
             f"\n\nUse the `{SCRIPT_NAME} list-models` command to see the list of currently supported models.",
         )
         error_message = "".join(error_message_parts)
-        self.fail(error_message, param, ctx)
+        return self.fail(error_message, param, ctx)
         return result
 
 
@@ -111,10 +113,18 @@ def chat(
             metavar="MODEL_PATH",
         ),
     ],
+    message: Annotated[
+        str | None,
+        Option(
+            help="Message for non-interactive mode",
+            show_default="None, run interactively",
+        ),
+    ] = None,
 ) -> None:
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
+        console=err_console,
         transient=True,
     ) as progress:
         loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
@@ -123,21 +133,28 @@ def chat(
         warmup_task = progress.add_task("🔥 Warming up compilation cache...")
         list(model.stream_reply_text([UserMessage("")], max_output_length=1))
         progress.remove_task(warmup_task)
-    console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
-    messages = []
-    while True:
-        user_text = console.input("[cyan]user> [/cyan]")
-        user_message = UserMessage(user_text)
-        messages.append(user_message)
 
-
-
-
+    if message is None:
+        console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
+
+        messages = []
+        while True:
+            user_text = console.input("[cyan]user> [/cyan]")
+            user_message = UserMessage(user_text)
+            messages.append(user_message)
+
+            console.print("[red]assistant> [/red]", end="")
+            model_response_tokens = []
+            for token in model.stream_reply_text(messages):
+                console.print(token, end="")
+                model_response_tokens.append(token)
+            console.print()
+            model_response_text = "".join(model_response_tokens)
+            messages.append(model.message_processor.parse_response(model_response_text))
+    else:
+        for token in model.stream_reply_text([UserMessage(message)]):
             console.print(token, end="")
-            model_response_tokens.append(token)
         console.print()
-        model_response_text = "".join(model_response_tokens)
-        messages.append(model.message_processor.parse_response(model_response_text))
 
 
 @app.command(help="Classify given message with a Classifier type of model.")
@@ -178,6 +195,7 @@ class CliConversionCallbacks(ConversionCallbacks):
     overwrite: bool = False
 
     stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
     downloading_tasks: dict[FileSpec, TaskID] = field(default_factory=dict)
     initializing_task: TaskID | None = None
     saving_task: TaskID | None = None
@@ -211,23 +229,33 @@ class CliConversionCallbacks(ConversionCallbacks):
             shutil.rmtree(self.output_dir)
 
     def downloading(self, file_spec: FileSpec) -> None:
+        assert self.progress is not None
+
         self.downloading_tasks[file_spec] = self.progress.add_task(f"Retrieving {file_spec.filename}...")
 
     def finished_downloading(self, file_spec: FileSpec) -> None:
+        assert self.progress is not None
+
         self.progress.remove_task(self.downloading_tasks[file_spec])
 
     def initializing_model(self) -> None:
+        assert self.progress is not None
+
         self.initializing_task = self.progress.add_task("Initializing model...")
 
     def finished_initializing_model(self) -> None:
+        assert self.progress is not None
         assert self.initializing_task is not None
 
         self.progress.remove_task(self.initializing_task)
 
     def saving_model(self) -> None:
+        assert self.progress is not None
+
         self.saving_task = self.progress.add_task(f"💾 Saving the model to {self.output_dir}")
 
     def finished_saving_model(self) -> None:
+        assert self.progress is not None
         assert self.saving_task is not None
 
         self.progress.remove_task(self.saving_task)
@@ -272,24 +300,12 @@ def convert(
             show_default="Model's native maximum context length.",
         ),
     ] = None,
-    include_traces: Annotated[
-        bool,
-        Option(
-            help="Export activation traces for debugging purposes.",
-        ),
-    ] = False,
     overwrite: Annotated[
         bool,
         Option(
             help="Overwrite existing model files.",
         ),
     ] = False,
-    message_for_trace: Annotated[
-        str | None,
-        Option(
-            help="Text message to use as prompt when recording trace",
-        ),
-    ] = None,
 ) -> None:
     if output_dir is None:
         output_dir = DEFAULT_OUTPUT_DIR / model_repo.name
@@ -299,12 +315,117 @@ def convert(
         output_dir,
         precision,
         context_length,
-        include_traces,
-        message_for_trace,
         partial(CliConversionCallbacks, overwrite=overwrite),
     )
 
 
+@dataclass
+class CliTraceCallbacks(TraceCallbacks):
+    overwrite: bool = False
+
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    loading_task: TaskID | None = None
+    tracing_task: TaskID | None = None
+    saving_task: TaskID | None = None
+
+    def output_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output [cyan]{self.output_path}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
+            raise Exit
+
+        self.output_path.unlink()
+
+    def started(self) -> None:
+        console.print(f"🔍 Tracing [cyan]{self.model_path}[/cyan]")
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+
+    def loading_model(self) -> None:
+        assert self.progress is not None
+
+        self.loading_task = self.progress.add_task("🧠 Loading model...")
+
+    def finished_loading_model(self) -> None:
+        assert self.progress is not None
+        assert self.loading_task is not None
+
+        self.progress.remove_task(self.loading_task)
+
+    def tracing_model(self) -> None:
+        assert self.progress is not None
+
+        self.tracing_task = self.progress.add_task("🔍 Recording trace...")
+
+    def finished_tracing_model(self) -> None:
+        assert self.progress is not None
+        assert self.tracing_task is not None
+
+        self.progress.remove_task(self.tracing_task)
+
+    def saving_trace(self) -> None:
+        assert self.progress is not None
+
+        self.saving_task = self.progress.add_task(f"💾 Saving trace to {self.output_path}")
+
+    def finished_saving_trace(self) -> None:
+        assert self.progress is not None
+        assert self.saving_task is not None
+
+        self.progress.remove_task(self.saving_task)
+        self.stack.close()
+        console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")
+
+
+@app.command(help="Trace a model.")
+def trace(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path | None,
+        Option(
+            help="Path to save the trace to.",
+            show_default="${MODEL_PATH}/traces.safetensors",
+        ),
+    ] = None,
+    overwrite: Annotated[
+        bool,
+        Option(
+            help="Overwrite existing trace file.",
+        ),
+    ] = False,
+    message: Annotated[
+        str | None,
+        Option(
+            help="Text message to use as prompt when recording trace",
+        ),
+    ] = None,
+) -> None:
+    if output_path is None:
+        output_path = model_path / "traces.safetensors"
+
+    messages = None if message is None else [UserMessage(content=message)]
+
+    _trace(
+        model_path,
+        output_path,
+        messages,
+        partial(CliTraceCallbacks, overwrite=overwrite),
+    )
+
+
 def _model_size_string_to_int(
     size_str: str,
     _regex: re.Pattern = re.compile(r"(?P<number>(\d+)(\.\d*)?)(?P<suffix>[KMBT])"),
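Side note (not part of the diff): the `chat` command gains a non-interactive `--message` option, and a new `trace` command is registered on the Typer app. A hedged sketch of exercising both through Typer's test runner, assuming the application object is importable as `lalamo.main.app` and `models/my-model` is a converted model directory:

```python
from typer.testing import CliRunner

from lalamo.main import app  # assumes the Typer application object is exposed as `app`

runner = CliRunner()

# Non-interactive chat: the new --message option replies once and exits
# instead of entering the interactive loop.
runner.invoke(app, ["chat", "models/my-model", "--message", "Hello!"])

# New trace command; --output-path defaults to ${MODEL_PATH}/traces.safetensors.
runner.invoke(app, ["trace", "models/my-model", "--message", "Hello!", "--overwrite"])
```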
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/message_processor.py
RENAMED
@@ -169,7 +169,10 @@ class MessageProcessor:
     def __post_init__(self) -> None:
         if self.output_parser_regex is not None:
             all_fields = AssistantMessage.__dataclass_fields__
-
+            # NOTE: str type annotations are assumed to be required
+            required_fields = {
+                k: v for k, v in all_fields.items() if isinstance(v.type, str) or v.type == (v.type | None)
+            }
             named_groups = self.output_parser_regex.groupindex
             invalid_groups = set(named_groups) - set(all_fields)
             if invalid_groups:
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/common.py
RENAMED
@@ -3,7 +3,7 @@ import json
 from collections import ChainMap
 from collections.abc import Callable
 from contextlib import ExitStack
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from pathlib import Path
 from typing import NamedTuple
 
@@ -20,7 +20,7 @@ from lalamo.quantization import QuantizationMode
 from lalamo.utils import process_chat_template
 
 from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
-from .huggingface_generation_config import HFGenerationConfig
+from .huggingface_generation_config import HFGenerationConfig, _policy_from_hf_config
 from .huggingface_tokenizer_config import HFTokenizerConfig
 from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, ModelType, UseCase
 from .model_specs.common import JSONFieldSpec
@@ -34,6 +34,7 @@ __all__ = [
     "ModelSpec",
     "ModelType",
     "StatusEvent",
+    "download_file",
     "import_model",
 ]
 
@@ -239,24 +240,14 @@ def _import_language_model(
 
     stop_token_ids = tuple(foreign_decoder_config.eos_token_ids)
 
-    if model_spec.configs.generation_config
+    if isinstance(model_spec.configs.generation_config, GenerationConfig):
+        generation_config = replace(model_spec.configs.generation_config, stop_token_ids=stop_token_ids)
+    elif isinstance(model_spec.configs.generation_config, FileSpec):
         hf_generation_config_file = download_file(model_spec.configs.generation_config, model_spec.repo)
         hf_generation_config = HFGenerationConfig.from_json(hf_generation_config_file)
-        generation_config =
-            stop_token_ids=stop_token_ids,
-            temperature=hf_generation_config.temperature,
-            top_p=hf_generation_config.top_p,
-            top_k=hf_generation_config.top_k,
-            banned_tokens=None,
-        )
+        generation_config = _policy_from_hf_config(hf_generation_config, stop_token_ids)
     else:
-        generation_config = GenerationConfig(
-            stop_token_ids=stop_token_ids,
-            temperature=None,
-            top_p=None,
-            top_k=None,
-            banned_tokens=None,
-        )
+        generation_config = GenerationConfig(stop_token_ids)
 
     language_model_config = LanguageModelConfig(
         model_config=decoder.config,
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/lfm2.py
RENAMED
@@ -2,6 +2,7 @@ from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
+import jax.numpy as jnp
 from jaxtyping import DTypeLike
 
 from lalamo.modules import (
@@ -50,7 +51,6 @@ class HFLFM2Config(HuggingFaceLMConfig):
     conv_L_cache: int  # noqa: N815
     conv_bias: bool
     conv_dim: int
-    conv_dim_out: int
     conv_use_xavier_init: bool
     eos_token_id: int
     hidden_size: int
@@ -64,13 +64,15 @@ class HFLFM2Config(HuggingFaceLMConfig):
     num_key_value_heads: int
     pad_token_id: int
     rope_theta: float
-    torch_dtype: Literal["bfloat16"]
     transformers_version: str
     use_cache: bool
     use_pos_enc: bool
     vocab_size: int
 
+    dtype: Literal["bfloat16", "float16", "float32"] | None = None
+    torch_dtype: Literal["bfloat16", "float16", "float32"] | None = None
     intermediate_size: int | None = None
+    conv_dim_out: int | None = None
     layer_types: list[Literal["conv", "full_attention"]] | None = None
     full_attn_idxs: list[int] | None = None
     tie_embedding: bool = True
@@ -79,6 +81,14 @@ class HFLFM2Config(HuggingFaceLMConfig):
     quantization: QuantizationConfig | None = None
     quantization_config: QuantizationConfig | None = None
 
+    @property
+    def default_precision(self) -> DTypeLike:
+        assert self.dtype is not None or self.torch_dtype is not None, (
+            "at least one of dtype or torch_dtype must be specified"
+        )
+
+        return jnp.dtype(self.dtype or self.torch_dtype)
+
     def to_decoder_config(
         self,
         context_length: int | None,
@@ -200,8 +210,8 @@ class HFLFM2Config(HuggingFaceLMConfig):
             subtract_mean=False,
         )
 
-        if self.
-            hidden_dim = self.intermediate_size
+        if not self.block_auto_adjust_ff_dim:
+            hidden_dim = self.intermediate_size or self.block_ff_dim
         else:
             hidden_dim_adjusted = self.block_ff_dim * self.block_ffn_dim_multiplier * (2 / 3)
             hidden_dim = int(
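Side note (not part of the diff): `HFLFM2Config` now accepts either the newer `dtype` field or the legacy `torch_dtype` field and resolves the default precision from whichever is present. A standalone, illustrative sketch of that fallback:

```python
import jax.numpy as jnp

# Mirrors the new HFLFM2Config.default_precision property: prefer the newer
# `dtype` field and fall back to the legacy `torch_dtype` field; at least one
# must be present in the checkpoint config.
def default_precision(dtype: str | None, torch_dtype: str | None) -> jnp.dtype:
    assert dtype is not None or torch_dtype is not None
    return jnp.dtype(dtype or torch_dtype)

print(default_precision(dtype=None, torch_dtype="float16"))   # float16
print(default_precision(dtype="bfloat16", torch_dtype=None))  # bfloat16
```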
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/llamba.py
RENAMED
@@ -76,7 +76,7 @@ class HFLlambaConfig(HuggingFaceLMConfig):
             logit_soft_cap=None,
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
             embedding_quantization_mode=QuantizationMode.from_num_bits(
-                int(metadata_dict["quantization_kwargs.bits"])
+                int(metadata_dict["quantization_kwargs.bits"]),
             ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
@@ -107,7 +107,7 @@ class HFLlambaConfig(HuggingFaceLMConfig):
         linear_config = MLXQuantizedLinearConfig(
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
             weight_quantization_mode=QuantizationMode.from_num_bits(
-                int(metadata_dict["quantization_kwargs.bits"])
+                int(metadata_dict["quantization_kwargs.bits"]),
             ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py
RENAMED
@@ -41,7 +41,7 @@ def activation_from_str(activation: str) -> type[Activation]:
         return supported_activations[activation]
 
     raise ValueError(
-        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}"
+        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}",
     )
 
 
@@ -97,7 +97,7 @@ class ModernBERTConfig(HuggingFaceClassifierConfig):
         result = [None] * num_layers
         for index in range(len(result)):
             if index % global_attn_every_n_layers != 0:
-                result[index] = self.local_attention
+                result[index] = self.local_attention
             else:
                 pass
         return tuple(result)
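Side note (not part of the diff): the surrounding ModernBERT logic assigns the local attention window to every layer except each `global_attn_every_n_layers`-th one, which stays `None` (global attention). An illustrative standalone sketch with made-up values:

```python
# Sketch of the sliding-window layout built by ModernBERTConfig: every
# `global_attn_every_n_layers`-th layer uses global attention (None), all
# other layers use the local attention window.
def attention_windows(
    num_layers: int,
    global_attn_every_n_layers: int,
    local_attention: int,
) -> tuple[int | None, ...]:
    return tuple(
        None if index % global_attn_every_n_layers == 0 else local_attention
        for index in range(num_layers)
    )

print(attention_windows(6, 3, 128))  # (None, 128, 128, None, 128, 128)
```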
{lalamo-0.5.17 → lalamo-0.6.0}/lalamo/model_import/huggingface_generation_config.py
RENAMED
@@ -5,7 +5,9 @@ from typing import ClassVar
 
 import cattrs
 
-
+from lalamo.models import GenerationConfig
+
+__all__ = ["HFGenerationConfig", "_policy_from_hf_config"]
 
 
 @dataclass(frozen=True)
@@ -27,10 +29,11 @@ class HFGenerationConfig:
     cache_implementation: str | None = None  # “hybrid” for Gemma 3/2
 
     # -------- sampling strategy -------------
-    do_sample: bool | None =
+    do_sample: bool | None = False
     temperature: float | None = None
+    min_p: float | None = None
     top_p: float | None = None
-    top_k: int | None =
+    top_k: int | None = 50
     repetition_penalty: float | None = None
 
     # -------- length limits -----------------
@@ -42,3 +45,18 @@
         with open(json_path) as f:
             config = json.load(f)
         return cls._converter.structure(config, cls)
+
+
+def _policy_from_hf_config(
+    hf_config: HFGenerationConfig,
+    stop_token_ids: tuple[int, ...] = (),
+    banned_tokens: tuple[int, ...] | None = None,
+) -> GenerationConfig:
+    return GenerationConfig(
+        stop_token_ids=stop_token_ids,
+        temperature=hf_config.temperature,
+        top_k=hf_config.top_k,
+        top_p=hf_config.top_p,
+        min_p=hf_config.min_p,
+        banned_tokens=banned_tokens,
+    )
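Side note (not part of the diff): `_policy_from_hf_config` replaces the inline field copying previously done in `model_import/common.py`, mapping an `HFGenerationConfig` onto the internal `GenerationConfig`. A sketch of the round trip from a `generation_config.json`, assuming the remaining `HFGenerationConfig` fields all have defaults like the sampling fields shown above:

```python
import json
from pathlib import Path

from lalamo.model_import.huggingface_generation_config import HFGenerationConfig, _policy_from_hf_config

# A trimmed-down generation_config.json, as shipped with many Hugging Face checkpoints.
config_path = Path("generation_config.json")
config_path.write_text(json.dumps({"do_sample": True, "temperature": 0.7, "top_p": 0.9, "top_k": 40}))

hf_config = HFGenerationConfig.from_json(config_path)
policy = _policy_from_hf_config(hf_config, stop_token_ids=(2,))
# policy is a GenerationConfig with temperature=0.7, top_k=40, top_p=0.9,
# min_p=None, and the supplied stop_token_ids.
```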
@@ -97,7 +97,7 @@ def load_mlp(module: DenseMLP, weights_dict: Mapping[str, Array], path: Paramete
     fused_up_gate_params = merge_linear_params([up_proj_params, gate_proj_params])
 
     return load_parameters(
-        lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),
+        lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),
         module,
         (*fused_up_gate_params, *down_proj_params),
     )
@@ -177,7 +177,7 @@ def load_attention(
 
     qkv_params = merge_linear_params([q_params, k_params, v_params])
     return load_parameters(
-        lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),
+        lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),
         module,
         (*qkv_params, *out_params),
     )