lalamo 0.5.7__tar.gz → 0.5.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.5.7 → lalamo-0.5.8}/PKG-INFO +1 -1
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/__init__.py +5 -4
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/main.py +3 -3
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/common.py +14 -10
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/__init__.py +2 -2
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/common.py +21 -2
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/mirai.py +3 -3
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/models/__init__.py +3 -3
- lalamo-0.5.7/lalamo/models/router.py → lalamo-0.5.8/lalamo/models/classifier.py +8 -8
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo.egg-info/PKG-INFO +1 -1
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo.egg-info/SOURCES.txt +2 -1
- lalamo-0.5.8/tests/test_chat_template.py +173 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_huggingface_model_conversion.py +3 -3
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_models.py +3 -3
- {lalamo-0.5.7 → lalamo-0.5.8}/LICENSE +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/README.md +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/data/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/data/huggingface_message.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/data/utils.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/message_processor.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/executorch.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/loaders/executorch.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/loaders/huggingface.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/huggingface.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/llamba.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/mistral.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/models/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/models/language_model.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/activations.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/classifier.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/decoder.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/embedding.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/linear.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/mlp.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/mlx_interop.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/normalization.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/rope.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/token_mixers/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/token_mixers/attention.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/token_mixers/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/token_mixers/mamba.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/token_mixers/state/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/token_mixers/state/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/transformer.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/transformer_layer.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/modules/utils.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/quantization.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/registry_abc.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/sampling.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/speculator/__init__.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/speculator/common.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/speculator/estimator.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/speculator/inference.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/speculator/ngram.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/speculator/utils.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo/utils.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/pyproject.toml +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/setup.cfg +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_cartesia_mlx_models.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_generation.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_huggingface_models.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_mlx_models.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_model_spec.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_moe.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_parameter_tree.py +0 -0
- {lalamo-0.5.7 → lalamo-0.5.8}/tests/test_registry_abc.py +0 -0
|
@@ -8,24 +8,24 @@ from lalamo.message_processor import (
|
|
|
8
8
|
ToolSchema,
|
|
9
9
|
UserMessage,
|
|
10
10
|
)
|
|
11
|
-
from lalamo.model_import import ModelSpec
|
|
12
|
-
from lalamo.models import
|
|
11
|
+
from lalamo.model_import import ModelSpec, import_model
|
|
12
|
+
from lalamo.models import ClassifierModel, LanguageModel
|
|
13
13
|
from lalamo.speculator import (
|
|
14
14
|
CollectTracesEvent,
|
|
15
15
|
SpeculatorTrainingEvent,
|
|
16
16
|
)
|
|
17
17
|
|
|
18
|
-
__version__ = "0.5.
|
|
18
|
+
__version__ = "0.5.8"
|
|
19
19
|
|
|
20
20
|
__all__ = [
|
|
21
21
|
"AssistantMessage",
|
|
22
|
+
"ClassifierModel",
|
|
22
23
|
"CollectTracesEvent",
|
|
23
24
|
"ContentBlock",
|
|
24
25
|
"Image",
|
|
25
26
|
"LanguageModel",
|
|
26
27
|
"Message",
|
|
27
28
|
"ModelSpec",
|
|
28
|
-
"Router",
|
|
29
29
|
"SpeculatorTrainingEvent",
|
|
30
30
|
"SystemMessage",
|
|
31
31
|
"ToolSchema",
|
|
@@ -33,5 +33,6 @@ __all__ = [
|
|
|
33
33
|
"collect_traces",
|
|
34
34
|
"convert",
|
|
35
35
|
"estimate_batchsize",
|
|
36
|
+
"import_model",
|
|
36
37
|
"train",
|
|
37
38
|
]
|
|
@@ -43,7 +43,7 @@ from lalamo.model_import.common import (
|
|
|
43
43
|
InitializingModelEvent,
|
|
44
44
|
StatusEvent,
|
|
45
45
|
)
|
|
46
|
-
from lalamo.models import
|
|
46
|
+
from lalamo.models import ClassifierModelConfig, LanguageModelConfig
|
|
47
47
|
from lalamo.modules import config_converter
|
|
48
48
|
from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
|
|
49
49
|
from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
|
|
@@ -149,7 +149,7 @@ def chat(
|
|
|
149
149
|
messages.append(model.message_processor.parse_response(model_response_text))
|
|
150
150
|
|
|
151
151
|
|
|
152
|
-
@app.command(help="Classify given message with a
|
|
152
|
+
@app.command(help="Classify given message with a Classifier type of model.")
|
|
153
153
|
def classify(
|
|
154
154
|
model_path: Annotated[
|
|
155
155
|
Path,
|
|
@@ -165,7 +165,7 @@ def classify(
|
|
|
165
165
|
transient=True,
|
|
166
166
|
) as progress:
|
|
167
167
|
loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
|
|
168
|
-
model =
|
|
168
|
+
model = ClassifierModelConfig.load_model(model_path)
|
|
169
169
|
progress.remove_task(loading_task)
|
|
170
170
|
warmup_task = progress.add_task("🔥 Warming up...")
|
|
171
171
|
model.classify_chat([UserMessage(content="warmup message")])
|
|
@@ -14,7 +14,7 @@ from jaxtyping import DTypeLike
|
|
|
14
14
|
from tokenizers import Tokenizer
|
|
15
15
|
|
|
16
16
|
from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
|
|
17
|
-
from lalamo.models import
|
|
17
|
+
from lalamo.models import ClassifierModel, ClassifierModelConfig, GenerationConfig, LanguageModel, LanguageModelConfig
|
|
18
18
|
from lalamo.modules import Classifier, Decoder, LalamoModule
|
|
19
19
|
from lalamo.quantization import QuantizationMode
|
|
20
20
|
|
|
@@ -72,7 +72,8 @@ class ModelMetadata:
|
|
|
72
72
|
repo: str
|
|
73
73
|
use_cases: tuple[UseCase, ...]
|
|
74
74
|
model_type: ModelType
|
|
75
|
-
model_config: LanguageModelConfig |
|
|
75
|
+
model_config: LanguageModelConfig | ClassifierModelConfig
|
|
76
|
+
grammar_start_tokens: tuple[str, ...]
|
|
76
77
|
|
|
77
78
|
|
|
78
79
|
def download_file(
|
|
@@ -118,7 +119,7 @@ def download_config_file(
|
|
|
118
119
|
|
|
119
120
|
|
|
120
121
|
class ImportResults(NamedTuple):
|
|
121
|
-
model: LanguageModel |
|
|
122
|
+
model: LanguageModel | ClassifierModel
|
|
122
123
|
metadata: ModelMetadata
|
|
123
124
|
|
|
124
125
|
|
|
@@ -145,6 +146,8 @@ def import_message_processor(
|
|
|
145
146
|
case FileSpec(_) as file_spec:
|
|
146
147
|
chat_template_file = download_file(file_spec, model_spec.repo, output_dir)
|
|
147
148
|
prompt_template = chat_template_file.read_text()
|
|
149
|
+
case str() as template_string:
|
|
150
|
+
prompt_template = template_string
|
|
148
151
|
case None:
|
|
149
152
|
raise ValueError("No chat template specified.")
|
|
150
153
|
else:
|
|
@@ -263,14 +266,14 @@ def _import_language_model(
|
|
|
263
266
|
return language_model, language_model_config
|
|
264
267
|
|
|
265
268
|
|
|
266
|
-
def
|
|
269
|
+
def _import_classifier(
|
|
267
270
|
model_spec: ModelSpec,
|
|
268
271
|
*,
|
|
269
272
|
context_length: int | None = None,
|
|
270
273
|
precision: DTypeLike | None = None,
|
|
271
274
|
accumulation_precision: DTypeLike = jnp.float32,
|
|
272
275
|
progress_callback: Callable[[StatusEvent], None] | None = None,
|
|
273
|
-
) -> tuple[
|
|
276
|
+
) -> tuple[ClassifierModel, ClassifierModelConfig]:
|
|
274
277
|
foreign_classifier_config_file = download_config_file(model_spec)
|
|
275
278
|
foreign_classifier_config = model_spec.config_type.from_json(foreign_classifier_config_file)
|
|
276
279
|
assert isinstance(foreign_classifier_config, ForeignClassifierConfig)
|
|
@@ -293,12 +296,12 @@ def _import_router(
|
|
|
293
296
|
|
|
294
297
|
message_processor = import_message_processor(model_spec)
|
|
295
298
|
|
|
296
|
-
|
|
299
|
+
classifier_model_config = ClassifierModelConfig(
|
|
297
300
|
model_config=classifier.config,
|
|
298
301
|
message_processor_config=message_processor.config,
|
|
299
302
|
)
|
|
300
|
-
|
|
301
|
-
return
|
|
303
|
+
classifier_model = ClassifierModel(classifier_model_config, classifier, message_processor)
|
|
304
|
+
return classifier_model, classifier_model_config
|
|
302
305
|
|
|
303
306
|
|
|
304
307
|
def import_model(
|
|
@@ -324,8 +327,8 @@ def import_model(
|
|
|
324
327
|
accumulation_precision=accumulation_precision,
|
|
325
328
|
progress_callback=progress_callback,
|
|
326
329
|
)
|
|
327
|
-
case ModelType.
|
|
328
|
-
model, config =
|
|
330
|
+
case ModelType.CLASSIFIER_MODEL:
|
|
331
|
+
model, config = _import_classifier(
|
|
329
332
|
model_spec,
|
|
330
333
|
context_length=context_length,
|
|
331
334
|
precision=precision,
|
|
@@ -344,5 +347,6 @@ def import_model(
|
|
|
344
347
|
use_cases=model_spec.use_cases,
|
|
345
348
|
model_type=model_spec.model_type,
|
|
346
349
|
model_config=config,
|
|
350
|
+
grammar_start_tokens=model_spec.grammar_start_tokens,
|
|
347
351
|
)
|
|
348
352
|
return ImportResults(model, metadata)
|
|
@@ -5,7 +5,7 @@ from .gpt_oss import GPT_OSS_MODELS
|
|
|
5
5
|
from .huggingface import HUGGINGFACE_MODELS
|
|
6
6
|
from .llama import LLAMA_MODELS
|
|
7
7
|
from .llamba import LLAMBA_MODELS
|
|
8
|
-
from .mirai import
|
|
8
|
+
from .mirai import MIRAI_CLASSIFIER_MODELS
|
|
9
9
|
from .mistral import MISTRAL_MODELS
|
|
10
10
|
|
|
11
11
|
# from .pleias import PLEIAS_MODELS
|
|
@@ -35,7 +35,7 @@ ALL_MODEL_LISTS = [
|
|
|
35
35
|
POLARIS_MODELS,
|
|
36
36
|
QWEN_MODELS,
|
|
37
37
|
REKA_MODELS,
|
|
38
|
-
|
|
38
|
+
MIRAI_CLASSIFIER_MODELS,
|
|
39
39
|
]
|
|
40
40
|
|
|
41
41
|
ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]
|
|
@@ -32,7 +32,7 @@ __all__ = [
|
|
|
32
32
|
|
|
33
33
|
class ModelType(StrEnum):
|
|
34
34
|
LANGUAGE_MODEL = "language_model"
|
|
35
|
-
|
|
35
|
+
CLASSIFIER_MODEL = "classifier_model"
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
def cast_if_float(array: Array, cast_to: DTypeLike) -> Array:
|
|
@@ -84,7 +84,7 @@ class ConfigMap:
|
|
|
84
84
|
tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
|
|
85
85
|
tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
|
|
86
86
|
generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
|
|
87
|
-
chat_template: FileSpec | JSONFieldSpec | None = None
|
|
87
|
+
chat_template: FileSpec | JSONFieldSpec | str | None = None
|
|
88
88
|
|
|
89
89
|
|
|
90
90
|
def _is_foreign_config_type(t: object) -> bool:
|
|
@@ -114,12 +114,29 @@ def _unstructure_foreign_config_factory(t: object, c: cattrs.Converter) -> Calla
|
|
|
114
114
|
return _hook
|
|
115
115
|
|
|
116
116
|
|
|
117
|
+
def _structure_chat_template(value: object, _type: object) -> FileSpec | JSONFieldSpec | str | None:
|
|
118
|
+
if value is None:
|
|
119
|
+
return None
|
|
120
|
+
if isinstance(value, str):
|
|
121
|
+
return value
|
|
122
|
+
if isinstance(value, dict):
|
|
123
|
+
if "file_spec" in value and "field_name" in value:
|
|
124
|
+
return JSONFieldSpec(
|
|
125
|
+
file_spec=FileSpec(**value["file_spec"]),
|
|
126
|
+
field_name=value["field_name"],
|
|
127
|
+
)
|
|
128
|
+
if "filename" in value:
|
|
129
|
+
return FileSpec(**value)
|
|
130
|
+
raise ValueError(f"Invalid chat_template value: {value}")
|
|
131
|
+
|
|
132
|
+
|
|
117
133
|
@dataclass(frozen=True)
|
|
118
134
|
class ModelSpec:
|
|
119
135
|
_converter: ClassVar[cattrs.Converter] = cattrs.Converter()
|
|
120
136
|
|
|
121
137
|
_converter.register_structure_hook_factory(_is_foreign_config_type, _structure_foreign_config_factory)
|
|
122
138
|
_converter.register_unstructure_hook_factory(_is_foreign_config_type, _unstructure_foreign_config_factory)
|
|
139
|
+
_converter.register_structure_hook(FileSpec | JSONFieldSpec | str | None, _structure_chat_template)
|
|
123
140
|
|
|
124
141
|
vendor: str
|
|
125
142
|
family: str
|
|
@@ -137,6 +154,7 @@ class ModelSpec:
|
|
|
137
154
|
model_type: ModelType = ModelType.LANGUAGE_MODEL
|
|
138
155
|
configs: ConfigMap = field(default=ConfigMap())
|
|
139
156
|
use_cases: tuple[UseCase, ...] = tuple()
|
|
157
|
+
grammar_start_tokens: tuple[str, ...] = tuple()
|
|
140
158
|
|
|
141
159
|
@classmethod
|
|
142
160
|
def from_json(cls, json_data: dict) -> "ModelSpec":
|
|
@@ -162,6 +180,7 @@ def awq_model_spec(
|
|
|
162
180
|
configs=model_spec.configs,
|
|
163
181
|
weights_type=model_spec.weights_type,
|
|
164
182
|
use_cases=model_spec.use_cases,
|
|
183
|
+
grammar_start_tokens=model_spec.grammar_start_tokens,
|
|
165
184
|
)
|
|
166
185
|
|
|
167
186
|
|
|
@@ -2,9 +2,9 @@ from lalamo.model_import.decoder_configs.huggingface import ModernBERTConfig
|
|
|
2
2
|
|
|
3
3
|
from .common import ConfigMap, FileSpec, ModelSpec, ModelType
|
|
4
4
|
|
|
5
|
-
__all__ = ["
|
|
5
|
+
__all__ = ["MIRAI_CLASSIFIER_MODELS"]
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
MIRAI_CLASSIFIER_MODELS = [
|
|
8
8
|
ModelSpec(
|
|
9
9
|
vendor="trymirai",
|
|
10
10
|
family="ModernBERT",
|
|
@@ -14,7 +14,7 @@ MIRAI_ROUTER_MODELS = [
|
|
|
14
14
|
repo="trymirai/chat-moderation-router",
|
|
15
15
|
config_type=ModernBERTConfig,
|
|
16
16
|
use_cases=tuple(),
|
|
17
|
-
model_type=ModelType("
|
|
17
|
+
model_type=ModelType("classifier_model"),
|
|
18
18
|
configs=ConfigMap(chat_template=FileSpec("chat_template.jinja")),
|
|
19
19
|
),
|
|
20
20
|
]
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
+
from .classifier import ClassifierModel, ClassifierModelConfig
|
|
1
2
|
from .language_model import GenerationConfig, LanguageModel, LanguageModelConfig
|
|
2
|
-
from .router import Router, RouterConfig
|
|
3
3
|
|
|
4
4
|
__all__ = [
|
|
5
|
+
"ClassifierModel",
|
|
6
|
+
"ClassifierModelConfig",
|
|
5
7
|
"GenerationConfig",
|
|
6
8
|
"LanguageModel",
|
|
7
9
|
"LanguageModelConfig",
|
|
8
|
-
"Router",
|
|
9
|
-
"RouterConfig",
|
|
10
10
|
]
|
|
@@ -13,29 +13,29 @@ from lalamo.modules import Classifier, ClassifierConfig, LalamoModule
|
|
|
13
13
|
from .common import TextModel, TextModelConfig
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
16
|
-
"
|
|
17
|
-
"
|
|
16
|
+
"ClassifierModel",
|
|
17
|
+
"ClassifierModelConfig",
|
|
18
18
|
]
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
@dataclass(frozen=True)
|
|
22
|
-
class
|
|
22
|
+
class ClassifierModelConfig(TextModelConfig[ClassifierConfig]):
|
|
23
23
|
def init(
|
|
24
24
|
self,
|
|
25
25
|
model: LalamoModule,
|
|
26
26
|
message_processor: MessageProcessor,
|
|
27
|
-
) -> "
|
|
27
|
+
) -> "ClassifierModel":
|
|
28
28
|
assert isinstance(model, Classifier)
|
|
29
|
-
return
|
|
29
|
+
return ClassifierModel(self, model, message_processor)
|
|
30
30
|
|
|
31
31
|
@classmethod
|
|
32
|
-
def load_model(cls, path: Path | str) -> "
|
|
32
|
+
def load_model(cls, path: Path | str) -> "ClassifierModel":
|
|
33
33
|
result = super().load_model(path)
|
|
34
|
-
assert isinstance(result,
|
|
34
|
+
assert isinstance(result, ClassifierModel)
|
|
35
35
|
return result
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
class
|
|
38
|
+
class ClassifierModel(TextModel[ClassifierModelConfig, Classifier]):
|
|
39
39
|
def label_output_logits(self, logits: Float[Array, "batch logits"]) -> dict[str, Float[Array, " batch"]]:
|
|
40
40
|
output_labels = self.model.config.output_labels
|
|
41
41
|
probabilities = jax.nn.sigmoid(logits)
|
|
@@ -57,9 +57,9 @@ lalamo/model_import/model_specs/polaris.py
|
|
|
57
57
|
lalamo/model_import/model_specs/qwen.py
|
|
58
58
|
lalamo/model_import/model_specs/reka.py
|
|
59
59
|
lalamo/models/__init__.py
|
|
60
|
+
lalamo/models/classifier.py
|
|
60
61
|
lalamo/models/common.py
|
|
61
62
|
lalamo/models/language_model.py
|
|
62
|
-
lalamo/models/router.py
|
|
63
63
|
lalamo/modules/__init__.py
|
|
64
64
|
lalamo/modules/activations.py
|
|
65
65
|
lalamo/modules/classifier.py
|
|
@@ -90,6 +90,7 @@ lalamo/speculator/inference.py
|
|
|
90
90
|
lalamo/speculator/ngram.py
|
|
91
91
|
lalamo/speculator/utils.py
|
|
92
92
|
tests/test_cartesia_mlx_models.py
|
|
93
|
+
tests/test_chat_template.py
|
|
93
94
|
tests/test_generation.py
|
|
94
95
|
tests/test_huggingface_model_conversion.py
|
|
95
96
|
tests/test_huggingface_models.py
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from lalamo.model_import.decoder_configs.huggingface.llama import HFLlamaConfig
|
|
4
|
+
from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec, ModelSpec
|
|
5
|
+
|
|
6
|
+
DIRECT_TEMPLATE = "{% for message in messages %}{{ message.content }}{% endfor %}"
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestConfigMapChatTemplate:
|
|
10
|
+
def test_chat_template_as_string(self) -> None:
|
|
11
|
+
config = ConfigMap(chat_template=DIRECT_TEMPLATE)
|
|
12
|
+
assert config.chat_template == DIRECT_TEMPLATE
|
|
13
|
+
assert isinstance(config.chat_template, str)
|
|
14
|
+
|
|
15
|
+
def test_chat_template_as_file_spec(self) -> None:
|
|
16
|
+
file_spec = FileSpec("chat_template.jinja")
|
|
17
|
+
config = ConfigMap(chat_template=file_spec)
|
|
18
|
+
assert config.chat_template == file_spec
|
|
19
|
+
assert isinstance(config.chat_template, FileSpec)
|
|
20
|
+
|
|
21
|
+
def test_chat_template_as_file_spec_with_repo(self) -> None:
|
|
22
|
+
file_spec = FileSpec("chat_template.jinja", repo="some/repo")
|
|
23
|
+
config = ConfigMap(chat_template=file_spec)
|
|
24
|
+
assert config.chat_template == file_spec
|
|
25
|
+
assert isinstance(config.chat_template, FileSpec)
|
|
26
|
+
assert config.chat_template.repo == "some/repo"
|
|
27
|
+
|
|
28
|
+
def test_chat_template_as_json_field_spec(self) -> None:
|
|
29
|
+
json_spec = JSONFieldSpec(FileSpec("config.json"), "chat_template")
|
|
30
|
+
config = ConfigMap(chat_template=json_spec)
|
|
31
|
+
assert config.chat_template == json_spec
|
|
32
|
+
assert isinstance(config.chat_template, JSONFieldSpec)
|
|
33
|
+
|
|
34
|
+
def test_chat_template_none(self) -> None:
|
|
35
|
+
config = ConfigMap()
|
|
36
|
+
assert config.chat_template is None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class TestModelSpecWithChatTemplate:
|
|
40
|
+
def test_model_spec_with_string_chat_template(self) -> None:
|
|
41
|
+
spec = ModelSpec(
|
|
42
|
+
vendor="Test",
|
|
43
|
+
family="Test",
|
|
44
|
+
name="Test",
|
|
45
|
+
size="1B",
|
|
46
|
+
repo="test/test",
|
|
47
|
+
config_type=HFLlamaConfig,
|
|
48
|
+
configs=ConfigMap(chat_template=DIRECT_TEMPLATE),
|
|
49
|
+
)
|
|
50
|
+
assert spec.configs.chat_template == DIRECT_TEMPLATE
|
|
51
|
+
|
|
52
|
+
def test_model_spec_with_file_spec_chat_template(self) -> None:
|
|
53
|
+
spec = ModelSpec(
|
|
54
|
+
vendor="Test",
|
|
55
|
+
family="Test",
|
|
56
|
+
name="Test",
|
|
57
|
+
size="1B",
|
|
58
|
+
repo="test/test",
|
|
59
|
+
config_type=HFLlamaConfig,
|
|
60
|
+
configs=ConfigMap(chat_template=FileSpec("chat_template.jinja")),
|
|
61
|
+
)
|
|
62
|
+
assert isinstance(spec.configs.chat_template, FileSpec)
|
|
63
|
+
assert spec.configs.chat_template.filename == "chat_template.jinja"
|
|
64
|
+
|
|
65
|
+
def test_model_spec_with_json_field_spec_chat_template(self) -> None:
|
|
66
|
+
spec = ModelSpec(
|
|
67
|
+
vendor="Test",
|
|
68
|
+
family="Test",
|
|
69
|
+
name="Test",
|
|
70
|
+
size="1B",
|
|
71
|
+
repo="test/test",
|
|
72
|
+
config_type=HFLlamaConfig,
|
|
73
|
+
configs=ConfigMap(chat_template=JSONFieldSpec(FileSpec("tokenizer_config.json"), "chat_template")),
|
|
74
|
+
)
|
|
75
|
+
assert isinstance(spec.configs.chat_template, JSONFieldSpec)
|
|
76
|
+
assert spec.configs.chat_template.field_name == "chat_template"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class TestModelSpecJsonSerialization:
|
|
80
|
+
def test_roundtrip_with_string_chat_template(self) -> None:
|
|
81
|
+
spec = ModelSpec(
|
|
82
|
+
vendor="Test",
|
|
83
|
+
family="Test",
|
|
84
|
+
name="Test",
|
|
85
|
+
size="1B",
|
|
86
|
+
repo="test/test",
|
|
87
|
+
config_type=HFLlamaConfig,
|
|
88
|
+
configs=ConfigMap(chat_template=DIRECT_TEMPLATE),
|
|
89
|
+
)
|
|
90
|
+
json_data = spec.to_json()
|
|
91
|
+
restored = ModelSpec.from_json(json_data)
|
|
92
|
+
assert restored.configs.chat_template == DIRECT_TEMPLATE
|
|
93
|
+
|
|
94
|
+
def test_roundtrip_with_file_spec_chat_template(self) -> None:
|
|
95
|
+
spec = ModelSpec(
|
|
96
|
+
vendor="Test",
|
|
97
|
+
family="Test",
|
|
98
|
+
name="Test",
|
|
99
|
+
size="1B",
|
|
100
|
+
repo="test/test",
|
|
101
|
+
config_type=HFLlamaConfig,
|
|
102
|
+
configs=ConfigMap(chat_template=FileSpec("chat_template.jinja")),
|
|
103
|
+
)
|
|
104
|
+
json_data = spec.to_json()
|
|
105
|
+
restored = ModelSpec.from_json(json_data)
|
|
106
|
+
assert isinstance(restored.configs.chat_template, FileSpec)
|
|
107
|
+
assert restored.configs.chat_template.filename == "chat_template.jinja"
|
|
108
|
+
|
|
109
|
+
def test_roundtrip_with_json_field_spec_chat_template(self) -> None:
|
|
110
|
+
spec = ModelSpec(
|
|
111
|
+
vendor="Test",
|
|
112
|
+
family="Test",
|
|
113
|
+
name="Test",
|
|
114
|
+
size="1B",
|
|
115
|
+
repo="test/test",
|
|
116
|
+
config_type=HFLlamaConfig,
|
|
117
|
+
configs=ConfigMap(chat_template=JSONFieldSpec(FileSpec("config.json"), "chat_template")),
|
|
118
|
+
)
|
|
119
|
+
json_data = spec.to_json()
|
|
120
|
+
restored = ModelSpec.from_json(json_data)
|
|
121
|
+
assert isinstance(restored.configs.chat_template, JSONFieldSpec)
|
|
122
|
+
assert restored.configs.chat_template.field_name == "chat_template"
|
|
123
|
+
assert restored.configs.chat_template.file_spec.filename == "config.json"
|
|
124
|
+
|
|
125
|
+
def test_from_json_with_string_chat_template(self) -> None:
|
|
126
|
+
json_data = {
|
|
127
|
+
"vendor": "Test",
|
|
128
|
+
"family": "Test",
|
|
129
|
+
"name": "Test",
|
|
130
|
+
"size": "1B",
|
|
131
|
+
"repo": "test/test",
|
|
132
|
+
"config_type": "HFLlamaConfig",
|
|
133
|
+
"configs": {
|
|
134
|
+
"chat_template": DIRECT_TEMPLATE,
|
|
135
|
+
},
|
|
136
|
+
}
|
|
137
|
+
spec = ModelSpec.from_json(json_data)
|
|
138
|
+
assert spec.configs.chat_template == DIRECT_TEMPLATE
|
|
139
|
+
|
|
140
|
+
def test_from_json_with_file_spec_chat_template(self) -> None:
|
|
141
|
+
json_data = {
|
|
142
|
+
"vendor": "Test",
|
|
143
|
+
"family": "Test",
|
|
144
|
+
"name": "Test",
|
|
145
|
+
"size": "1B",
|
|
146
|
+
"repo": "test/test",
|
|
147
|
+
"config_type": "HFLlamaConfig",
|
|
148
|
+
"configs": {
|
|
149
|
+
"chat_template": {"filename": "chat_template.jinja"},
|
|
150
|
+
},
|
|
151
|
+
}
|
|
152
|
+
spec = ModelSpec.from_json(json_data)
|
|
153
|
+
assert isinstance(spec.configs.chat_template, FileSpec)
|
|
154
|
+
assert spec.configs.chat_template.filename == "chat_template.jinja"
|
|
155
|
+
|
|
156
|
+
def test_from_json_with_json_field_spec_chat_template(self) -> None:
|
|
157
|
+
json_data = {
|
|
158
|
+
"vendor": "Test",
|
|
159
|
+
"family": "Test",
|
|
160
|
+
"name": "Test",
|
|
161
|
+
"size": "1B",
|
|
162
|
+
"repo": "test/test",
|
|
163
|
+
"config_type": "HFLlamaConfig",
|
|
164
|
+
"configs": {
|
|
165
|
+
"chat_template": {
|
|
166
|
+
"file_spec": {"filename": "config.json"},
|
|
167
|
+
"field_name": "chat_template",
|
|
168
|
+
},
|
|
169
|
+
},
|
|
170
|
+
}
|
|
171
|
+
spec = ModelSpec.from_json(json_data)
|
|
172
|
+
assert isinstance(spec.configs.chat_template, JSONFieldSpec)
|
|
173
|
+
assert spec.configs.chat_template.field_name == "chat_template"
|
|
@@ -14,7 +14,7 @@ from safetensors.flax import save_file
|
|
|
14
14
|
from lalamo.common import flatten_parameters
|
|
15
15
|
from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, import_model
|
|
16
16
|
from lalamo.model_import.model_specs import ModelType
|
|
17
|
-
from lalamo.models import
|
|
17
|
+
from lalamo.models import ClassifierModelConfig, LanguageModelConfig
|
|
18
18
|
from lalamo.modules import config_converter
|
|
19
19
|
from tests.test_models import DType, ModelTestSpec
|
|
20
20
|
|
|
@@ -91,7 +91,7 @@ def test_model_conversion(test_spec: ModelTestSpec, tmp_path: pathlib.Path) -> N
|
|
|
91
91
|
match metadata.model_type:
|
|
92
92
|
case ModelType.LANGUAGE_MODEL:
|
|
93
93
|
model = LanguageModelConfig.load_model(tmp_path)
|
|
94
|
-
case ModelType.
|
|
95
|
-
model =
|
|
94
|
+
case ModelType.CLASSIFIER_MODEL:
|
|
95
|
+
model = ClassifierModelConfig.load_model(tmp_path)
|
|
96
96
|
assert model is not None, f"Failed to load model {model_repo_name}"
|
|
97
97
|
del model
|
|
@@ -13,7 +13,7 @@ import torch
|
|
|
13
13
|
from jaxtyping import Array
|
|
14
14
|
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssAttention
|
|
15
15
|
|
|
16
|
-
from lalamo import
|
|
16
|
+
from lalamo import ClassifierModel, LanguageModel, import_model
|
|
17
17
|
from lalamo.model_import.common import ModelType
|
|
18
18
|
from lalamo.modules.classifier import ClassifierActivationTrace, ClassifierResult
|
|
19
19
|
from lalamo.modules.decoder import (
|
|
@@ -477,8 +477,8 @@ def _test_model(test_spec: ModelTestSpec, model_tracer: type[ModelTracer]) -> No
|
|
|
477
477
|
)
|
|
478
478
|
err.throw()
|
|
479
479
|
|
|
480
|
-
case ModelType.
|
|
481
|
-
assert isinstance(model,
|
|
480
|
+
case ModelType.CLASSIFIER_MODEL:
|
|
481
|
+
assert isinstance(model, ClassifierModel)
|
|
482
482
|
err, inference_results = checkify_forward(model.model)(
|
|
483
483
|
token_ids=token_ids,
|
|
484
484
|
token_positions=token_positions,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{lalamo-0.5.7 → lalamo-0.5.8}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|