lalamo 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +5 -4
- lalamo/main.py +3 -3
- lalamo/model_import/common.py +16 -10
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +31 -9
- lalamo/model_import/loaders/huggingface.py +1 -1
- lalamo/model_import/model_specs/__init__.py +4 -2
- lalamo/model_import/model_specs/common.py +21 -2
- lalamo/model_import/model_specs/essential_ai.py +17 -0
- lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo/model_import/model_specs/mirai.py +3 -3
- lalamo/models/__init__.py +3 -3
- lalamo/models/{router.py → classifier.py} +8 -8
- lalamo/utils.py +7 -0
- {lalamo-0.5.7.dist-info → lalamo-0.5.9.dist-info}/METADATA +1 -1
- {lalamo-0.5.7.dist-info → lalamo-0.5.9.dist-info}/RECORD +19 -18
- {lalamo-0.5.7.dist-info → lalamo-0.5.9.dist-info}/WHEEL +0 -0
- {lalamo-0.5.7.dist-info → lalamo-0.5.9.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.7.dist-info → lalamo-0.5.9.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.7.dist-info → lalamo-0.5.9.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
|
@@ -8,24 +8,24 @@ from lalamo.message_processor import (
|
|
|
8
8
|
ToolSchema,
|
|
9
9
|
UserMessage,
|
|
10
10
|
)
|
|
11
|
-
from lalamo.model_import import ModelSpec
|
|
12
|
-
from lalamo.models import
|
|
11
|
+
from lalamo.model_import import ModelSpec, import_model
|
|
12
|
+
from lalamo.models import ClassifierModel, LanguageModel
|
|
13
13
|
from lalamo.speculator import (
|
|
14
14
|
CollectTracesEvent,
|
|
15
15
|
SpeculatorTrainingEvent,
|
|
16
16
|
)
|
|
17
17
|
|
|
18
|
-
__version__ = "0.5.
|
|
18
|
+
__version__ = "0.5.9"
|
|
19
19
|
|
|
20
20
|
__all__ = [
|
|
21
21
|
"AssistantMessage",
|
|
22
|
+
"ClassifierModel",
|
|
22
23
|
"CollectTracesEvent",
|
|
23
24
|
"ContentBlock",
|
|
24
25
|
"Image",
|
|
25
26
|
"LanguageModel",
|
|
26
27
|
"Message",
|
|
27
28
|
"ModelSpec",
|
|
28
|
-
"Router",
|
|
29
29
|
"SpeculatorTrainingEvent",
|
|
30
30
|
"SystemMessage",
|
|
31
31
|
"ToolSchema",
|
|
@@ -33,5 +33,6 @@ __all__ = [
|
|
|
33
33
|
"collect_traces",
|
|
34
34
|
"convert",
|
|
35
35
|
"estimate_batchsize",
|
|
36
|
+
"import_model",
|
|
36
37
|
"train",
|
|
37
38
|
]
|
lalamo/main.py
CHANGED
|
@@ -43,7 +43,7 @@ from lalamo.model_import.common import (
|
|
|
43
43
|
InitializingModelEvent,
|
|
44
44
|
StatusEvent,
|
|
45
45
|
)
|
|
46
|
-
from lalamo.models import
|
|
46
|
+
from lalamo.models import ClassifierModelConfig, LanguageModelConfig
|
|
47
47
|
from lalamo.modules import config_converter
|
|
48
48
|
from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
|
|
49
49
|
from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
|
|
@@ -149,7 +149,7 @@ def chat(
|
|
|
149
149
|
messages.append(model.message_processor.parse_response(model_response_text))
|
|
150
150
|
|
|
151
151
|
|
|
152
|
-
@app.command(help="Classify given message with a
|
|
152
|
+
@app.command(help="Classify given message with a Classifier type of model.")
|
|
153
153
|
def classify(
|
|
154
154
|
model_path: Annotated[
|
|
155
155
|
Path,
|
|
@@ -165,7 +165,7 @@ def classify(
|
|
|
165
165
|
transient=True,
|
|
166
166
|
) as progress:
|
|
167
167
|
loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
|
|
168
|
-
model =
|
|
168
|
+
model = ClassifierModelConfig.load_model(model_path)
|
|
169
169
|
progress.remove_task(loading_task)
|
|
170
170
|
warmup_task = progress.add_task("🔥 Warming up...")
|
|
171
171
|
model.classify_chat([UserMessage(content="warmup message")])
|
lalamo/model_import/common.py
CHANGED
|
@@ -14,9 +14,10 @@ from jaxtyping import DTypeLike
|
|
|
14
14
|
from tokenizers import Tokenizer
|
|
15
15
|
|
|
16
16
|
from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
|
|
17
|
-
from lalamo.models import
|
|
17
|
+
from lalamo.models import ClassifierModel, ClassifierModelConfig, GenerationConfig, LanguageModel, LanguageModelConfig
|
|
18
18
|
from lalamo.modules import Classifier, Decoder, LalamoModule
|
|
19
19
|
from lalamo.quantization import QuantizationMode
|
|
20
|
+
from lalamo.utils import process_chat_template
|
|
20
21
|
|
|
21
22
|
from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
|
|
22
23
|
from .huggingface_generation_config import HFGenerationConfig
|
|
@@ -72,7 +73,8 @@ class ModelMetadata:
|
|
|
72
73
|
repo: str
|
|
73
74
|
use_cases: tuple[UseCase, ...]
|
|
74
75
|
model_type: ModelType
|
|
75
|
-
model_config: LanguageModelConfig |
|
|
76
|
+
model_config: LanguageModelConfig | ClassifierModelConfig
|
|
77
|
+
grammar_start_tokens: tuple[str, ...]
|
|
76
78
|
|
|
77
79
|
|
|
78
80
|
def download_file(
|
|
@@ -118,7 +120,7 @@ def download_config_file(
|
|
|
118
120
|
|
|
119
121
|
|
|
120
122
|
class ImportResults(NamedTuple):
|
|
121
|
-
model: LanguageModel |
|
|
123
|
+
model: LanguageModel | ClassifierModel
|
|
122
124
|
metadata: ModelMetadata
|
|
123
125
|
|
|
124
126
|
|
|
@@ -145,12 +147,15 @@ def import_message_processor(
|
|
|
145
147
|
case FileSpec(_) as file_spec:
|
|
146
148
|
chat_template_file = download_file(file_spec, model_spec.repo, output_dir)
|
|
147
149
|
prompt_template = chat_template_file.read_text()
|
|
150
|
+
case str() as template_string:
|
|
151
|
+
prompt_template = template_string
|
|
148
152
|
case None:
|
|
149
153
|
raise ValueError("No chat template specified.")
|
|
150
154
|
else:
|
|
151
155
|
if model_spec.configs.chat_template is not None:
|
|
152
156
|
raise ValueError("Conflicting chat template specifications.")
|
|
153
157
|
prompt_template = tokenizer_config.chat_template
|
|
158
|
+
prompt_template = process_chat_template(prompt_template)
|
|
154
159
|
tokenizer = Tokenizer.from_file(str(tokenizer_file))
|
|
155
160
|
|
|
156
161
|
added_tokens = tokenizer_config.added_tokens()
|
|
@@ -263,14 +268,14 @@ def _import_language_model(
|
|
|
263
268
|
return language_model, language_model_config
|
|
264
269
|
|
|
265
270
|
|
|
266
|
-
def
|
|
271
|
+
def _import_classifier(
|
|
267
272
|
model_spec: ModelSpec,
|
|
268
273
|
*,
|
|
269
274
|
context_length: int | None = None,
|
|
270
275
|
precision: DTypeLike | None = None,
|
|
271
276
|
accumulation_precision: DTypeLike = jnp.float32,
|
|
272
277
|
progress_callback: Callable[[StatusEvent], None] | None = None,
|
|
273
|
-
) -> tuple[
|
|
278
|
+
) -> tuple[ClassifierModel, ClassifierModelConfig]:
|
|
274
279
|
foreign_classifier_config_file = download_config_file(model_spec)
|
|
275
280
|
foreign_classifier_config = model_spec.config_type.from_json(foreign_classifier_config_file)
|
|
276
281
|
assert isinstance(foreign_classifier_config, ForeignClassifierConfig)
|
|
@@ -293,12 +298,12 @@ def _import_router(
|
|
|
293
298
|
|
|
294
299
|
message_processor = import_message_processor(model_spec)
|
|
295
300
|
|
|
296
|
-
|
|
301
|
+
classifier_model_config = ClassifierModelConfig(
|
|
297
302
|
model_config=classifier.config,
|
|
298
303
|
message_processor_config=message_processor.config,
|
|
299
304
|
)
|
|
300
|
-
|
|
301
|
-
return
|
|
305
|
+
classifier_model = ClassifierModel(classifier_model_config, classifier, message_processor)
|
|
306
|
+
return classifier_model, classifier_model_config
|
|
302
307
|
|
|
303
308
|
|
|
304
309
|
def import_model(
|
|
@@ -324,8 +329,8 @@ def import_model(
|
|
|
324
329
|
accumulation_precision=accumulation_precision,
|
|
325
330
|
progress_callback=progress_callback,
|
|
326
331
|
)
|
|
327
|
-
case ModelType.
|
|
328
|
-
model, config =
|
|
332
|
+
case ModelType.CLASSIFIER_MODEL:
|
|
333
|
+
model, config = _import_classifier(
|
|
329
334
|
model_spec,
|
|
330
335
|
context_length=context_length,
|
|
331
336
|
precision=precision,
|
|
@@ -344,5 +349,6 @@ def import_model(
|
|
|
344
349
|
use_cases=model_spec.use_cases,
|
|
345
350
|
model_type=model_spec.model_type,
|
|
346
351
|
model_config=config,
|
|
352
|
+
grammar_start_tokens=model_spec.grammar_start_tokens,
|
|
347
353
|
)
|
|
348
354
|
return ImportResults(model, metadata)
|
|
@@ -10,7 +10,7 @@ from lalamo.modules.activations import GELU
|
|
|
10
10
|
from lalamo.modules.linear import FullPrecisionLinearConfig
|
|
11
11
|
from lalamo.modules.mlp import DenseMLPConfig
|
|
12
12
|
from lalamo.modules.normalization import NormalizationConfig, UpcastMode
|
|
13
|
-
from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
|
|
13
|
+
from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig, YARNRoPEConfig
|
|
14
14
|
from lalamo.modules.token_mixers.attention import AttentionConfig
|
|
15
15
|
from lalamo.modules.transformer_layer import TransformerLayerConfig
|
|
16
16
|
|
|
@@ -19,9 +19,6 @@ from .common import HuggingFaceLMConfig
|
|
|
19
19
|
__all__ = ["HFGemma3Config", "HFGemma3TextConfig"]
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER = 6
|
|
23
|
-
|
|
24
|
-
|
|
25
22
|
def _round_to_bfloat16(x: float) -> float:
|
|
26
23
|
return jnp.asarray(x).astype(jnp.bfloat16).item()
|
|
27
24
|
|
|
@@ -32,6 +29,16 @@ class GemmaRoPEScalingConfig:
|
|
|
32
29
|
rope_type: Literal["linear"]
|
|
33
30
|
|
|
34
31
|
|
|
32
|
+
@dataclass(frozen=True)
|
|
33
|
+
class YarnRopeScalingConfig:
|
|
34
|
+
factor: float
|
|
35
|
+
beta_fast: float
|
|
36
|
+
beta_slow: float
|
|
37
|
+
original_max_position_embeddings: int
|
|
38
|
+
rope_type: Literal["yarn"]
|
|
39
|
+
truncate: bool = False
|
|
40
|
+
|
|
41
|
+
|
|
35
42
|
@dataclass(frozen=True)
|
|
36
43
|
class HFGemma3TextConfigRaw:
|
|
37
44
|
hidden_size: int
|
|
@@ -39,6 +46,7 @@ class HFGemma3TextConfigRaw:
|
|
|
39
46
|
model_type: Literal["gemma3_text"]
|
|
40
47
|
num_hidden_layers: int
|
|
41
48
|
sliding_window: int
|
|
49
|
+
sliding_window_pattern: int
|
|
42
50
|
rms_norm_eps: float = 1e-06
|
|
43
51
|
query_pre_attn_scalar: float = 256.0
|
|
44
52
|
attention_bias: bool = False
|
|
@@ -49,7 +57,7 @@ class HFGemma3TextConfigRaw:
|
|
|
49
57
|
max_position_embeddings: int = 131072
|
|
50
58
|
rope_theta: float = 1000000.0
|
|
51
59
|
rope_local_base_freq: float = 10000.0
|
|
52
|
-
rope_scaling: GemmaRoPEScalingConfig | None = None
|
|
60
|
+
rope_scaling: GemmaRoPEScalingConfig | YarnRopeScalingConfig | None = None
|
|
53
61
|
final_logit_softcapping: float | None = None
|
|
54
62
|
vocab_size: int = 262208
|
|
55
63
|
|
|
@@ -57,7 +65,7 @@ class HFGemma3TextConfigRaw:
|
|
|
57
65
|
def sliding_window_sizes(self) -> list[int | None]:
|
|
58
66
|
result = []
|
|
59
67
|
for i in range(self.num_hidden_layers):
|
|
60
|
-
if (i + 1) %
|
|
68
|
+
if (i + 1) % self.sliding_window_pattern == 0:
|
|
61
69
|
result.append(None)
|
|
62
70
|
else:
|
|
63
71
|
result.append(self.sliding_window)
|
|
@@ -74,7 +82,7 @@ class HFGemma3TextConfigRaw:
|
|
|
74
82
|
attention_scale = self.query_pre_attn_scalar**-0.5
|
|
75
83
|
embedding_config = TiedEmbeddingConfig(
|
|
76
84
|
input_scale=input_scale,
|
|
77
|
-
logit_soft_cap=
|
|
85
|
+
logit_soft_cap=self.final_logit_softcapping,
|
|
78
86
|
precision=activation_precision,
|
|
79
87
|
)
|
|
80
88
|
rms_norm_config = NormalizationConfig(
|
|
@@ -86,19 +94,33 @@ class HFGemma3TextConfigRaw:
|
|
|
86
94
|
subtract_mean=False,
|
|
87
95
|
)
|
|
88
96
|
|
|
89
|
-
if self.rope_scaling
|
|
97
|
+
if isinstance(self.rope_scaling, GemmaRoPEScalingConfig):
|
|
90
98
|
global_rope_config = LinearScalingRoPEConfig(
|
|
91
99
|
precision=activation_precision,
|
|
92
100
|
base=self.rope_theta,
|
|
93
101
|
max_sequence_length=self.max_position_embeddings,
|
|
94
102
|
scaling_factor=self.rope_scaling.factor,
|
|
95
103
|
)
|
|
96
|
-
|
|
104
|
+
elif isinstance(self.rope_scaling, YarnRopeScalingConfig):
|
|
105
|
+
global_rope_config = YARNRoPEConfig(
|
|
106
|
+
precision=activation_precision,
|
|
107
|
+
base=self.rope_theta,
|
|
108
|
+
scaling_factor=self.rope_scaling.factor,
|
|
109
|
+
max_sequence_length=self.max_position_embeddings,
|
|
110
|
+
original_context_length=self.rope_scaling.original_max_position_embeddings,
|
|
111
|
+
beta_fast=self.rope_scaling.beta_fast,
|
|
112
|
+
beta_slow=self.rope_scaling.beta_slow,
|
|
113
|
+
truncate=self.rope_scaling.truncate,
|
|
114
|
+
)
|
|
115
|
+
elif self.rope_scaling is None:
|
|
97
116
|
global_rope_config = UnscaledRoPEConfig(
|
|
98
117
|
precision=activation_precision,
|
|
99
118
|
base=self.rope_theta,
|
|
100
119
|
max_sequence_length=context_length or self.max_position_embeddings,
|
|
101
120
|
)
|
|
121
|
+
else:
|
|
122
|
+
raise ValueError("Invalid rope scaling configuration")
|
|
123
|
+
|
|
102
124
|
local_rope_config = UnscaledRoPEConfig(
|
|
103
125
|
precision=activation_precision,
|
|
104
126
|
base=self.rope_local_base_freq,
|
|
@@ -300,7 +300,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
|
|
|
300
300
|
down_w = rearrange(down_w, "e o ib ie -> e o (ib ie)")
|
|
301
301
|
down_b = weights_dict[experts_path / "down_proj_bias"]
|
|
302
302
|
if down_b.ndim == 1:
|
|
303
|
-
down_b = jnp.broadcast_to(down_b, down_w.shape[:-1]
|
|
303
|
+
down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))
|
|
304
304
|
|
|
305
305
|
down_projection = load_parameters(
|
|
306
306
|
lambda m: (m.weights, m.biases), # type: ignore
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
from .common import FileSpec, ModelSpec, ModelType, UseCase, build_quantized_models
|
|
2
2
|
from .deepseek import DEEPSEEK_MODELS
|
|
3
|
+
from .essential_ai import RNJ_MODELS
|
|
3
4
|
from .gemma import GEMMA_MODELS
|
|
4
5
|
from .gpt_oss import GPT_OSS_MODELS
|
|
5
6
|
from .huggingface import HUGGINGFACE_MODELS
|
|
6
7
|
from .llama import LLAMA_MODELS
|
|
7
8
|
from .llamba import LLAMBA_MODELS
|
|
8
|
-
from .mirai import
|
|
9
|
+
from .mirai import MIRAI_CLASSIFIER_MODELS
|
|
9
10
|
from .mistral import MISTRAL_MODELS
|
|
10
11
|
|
|
11
12
|
# from .pleias import PLEIAS_MODELS
|
|
@@ -35,7 +36,8 @@ ALL_MODEL_LISTS = [
|
|
|
35
36
|
POLARIS_MODELS,
|
|
36
37
|
QWEN_MODELS,
|
|
37
38
|
REKA_MODELS,
|
|
38
|
-
|
|
39
|
+
MIRAI_CLASSIFIER_MODELS,
|
|
40
|
+
RNJ_MODELS,
|
|
39
41
|
]
|
|
40
42
|
|
|
41
43
|
ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]
|
|
@@ -32,7 +32,7 @@ __all__ = [
|
|
|
32
32
|
|
|
33
33
|
class ModelType(StrEnum):
|
|
34
34
|
LANGUAGE_MODEL = "language_model"
|
|
35
|
-
|
|
35
|
+
CLASSIFIER_MODEL = "classifier_model"
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
def cast_if_float(array: Array, cast_to: DTypeLike) -> Array:
|
|
@@ -84,7 +84,7 @@ class ConfigMap:
|
|
|
84
84
|
tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
|
|
85
85
|
tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
|
|
86
86
|
generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
|
|
87
|
-
chat_template: FileSpec | JSONFieldSpec | None = None
|
|
87
|
+
chat_template: FileSpec | JSONFieldSpec | str | None = None
|
|
88
88
|
|
|
89
89
|
|
|
90
90
|
def _is_foreign_config_type(t: object) -> bool:
|
|
@@ -114,12 +114,29 @@ def _unstructure_foreign_config_factory(t: object, c: cattrs.Converter) -> Calla
|
|
|
114
114
|
return _hook
|
|
115
115
|
|
|
116
116
|
|
|
117
|
+
def _structure_chat_template(value: object, _type: object) -> FileSpec | JSONFieldSpec | str | None:
|
|
118
|
+
if value is None:
|
|
119
|
+
return None
|
|
120
|
+
if isinstance(value, str):
|
|
121
|
+
return value
|
|
122
|
+
if isinstance(value, dict):
|
|
123
|
+
if "file_spec" in value and "field_name" in value:
|
|
124
|
+
return JSONFieldSpec(
|
|
125
|
+
file_spec=FileSpec(**value["file_spec"]),
|
|
126
|
+
field_name=value["field_name"],
|
|
127
|
+
)
|
|
128
|
+
if "filename" in value:
|
|
129
|
+
return FileSpec(**value)
|
|
130
|
+
raise ValueError(f"Invalid chat_template value: {value}")
|
|
131
|
+
|
|
132
|
+
|
|
117
133
|
@dataclass(frozen=True)
|
|
118
134
|
class ModelSpec:
|
|
119
135
|
_converter: ClassVar[cattrs.Converter] = cattrs.Converter()
|
|
120
136
|
|
|
121
137
|
_converter.register_structure_hook_factory(_is_foreign_config_type, _structure_foreign_config_factory)
|
|
122
138
|
_converter.register_unstructure_hook_factory(_is_foreign_config_type, _unstructure_foreign_config_factory)
|
|
139
|
+
_converter.register_structure_hook(FileSpec | JSONFieldSpec | str | None, _structure_chat_template)
|
|
123
140
|
|
|
124
141
|
vendor: str
|
|
125
142
|
family: str
|
|
@@ -137,6 +154,7 @@ class ModelSpec:
|
|
|
137
154
|
model_type: ModelType = ModelType.LANGUAGE_MODEL
|
|
138
155
|
configs: ConfigMap = field(default=ConfigMap())
|
|
139
156
|
use_cases: tuple[UseCase, ...] = tuple()
|
|
157
|
+
grammar_start_tokens: tuple[str, ...] = tuple()
|
|
140
158
|
|
|
141
159
|
@classmethod
|
|
142
160
|
def from_json(cls, json_data: dict) -> "ModelSpec":
|
|
@@ -162,6 +180,7 @@ def awq_model_spec(
|
|
|
162
180
|
configs=model_spec.configs,
|
|
163
181
|
weights_type=model_spec.weights_type,
|
|
164
182
|
use_cases=model_spec.use_cases,
|
|
183
|
+
grammar_start_tokens=model_spec.grammar_start_tokens,
|
|
165
184
|
)
|
|
166
185
|
|
|
167
186
|
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from lalamo.model_import.decoder_configs.huggingface import HFGemma3TextConfig
|
|
2
|
+
|
|
3
|
+
from .common import ModelSpec
|
|
4
|
+
|
|
5
|
+
__all__ = ["RNJ_MODELS"]
|
|
6
|
+
|
|
7
|
+
RNJ_MODELS = [
|
|
8
|
+
ModelSpec(
|
|
9
|
+
vendor="EssentialAI",
|
|
10
|
+
family="Rnj-1",
|
|
11
|
+
name="Rnj-1-Instruct",
|
|
12
|
+
size="8B",
|
|
13
|
+
quantization=None,
|
|
14
|
+
repo="EssentialAI/rnj-1-instruct",
|
|
15
|
+
config_type=HFGemma3TextConfig,
|
|
16
|
+
),
|
|
17
|
+
]
|
|
@@ -2,9 +2,9 @@ from lalamo.model_import.decoder_configs.huggingface import ModernBERTConfig
|
|
|
2
2
|
|
|
3
3
|
from .common import ConfigMap, FileSpec, ModelSpec, ModelType
|
|
4
4
|
|
|
5
|
-
__all__ = ["
|
|
5
|
+
__all__ = ["MIRAI_CLASSIFIER_MODELS"]
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
MIRAI_CLASSIFIER_MODELS = [
|
|
8
8
|
ModelSpec(
|
|
9
9
|
vendor="trymirai",
|
|
10
10
|
family="ModernBERT",
|
|
@@ -14,7 +14,7 @@ MIRAI_ROUTER_MODELS = [
|
|
|
14
14
|
repo="trymirai/chat-moderation-router",
|
|
15
15
|
config_type=ModernBERTConfig,
|
|
16
16
|
use_cases=tuple(),
|
|
17
|
-
model_type=ModelType("
|
|
17
|
+
model_type=ModelType("classifier_model"),
|
|
18
18
|
configs=ConfigMap(chat_template=FileSpec("chat_template.jinja")),
|
|
19
19
|
),
|
|
20
20
|
]
|
lalamo/models/__init__.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
+
from .classifier import ClassifierModel, ClassifierModelConfig
|
|
1
2
|
from .language_model import GenerationConfig, LanguageModel, LanguageModelConfig
|
|
2
|
-
from .router import Router, RouterConfig
|
|
3
3
|
|
|
4
4
|
__all__ = [
|
|
5
|
+
"ClassifierModel",
|
|
6
|
+
"ClassifierModelConfig",
|
|
5
7
|
"GenerationConfig",
|
|
6
8
|
"LanguageModel",
|
|
7
9
|
"LanguageModelConfig",
|
|
8
|
-
"Router",
|
|
9
|
-
"RouterConfig",
|
|
10
10
|
]
|
|
@@ -13,29 +13,29 @@ from lalamo.modules import Classifier, ClassifierConfig, LalamoModule
|
|
|
13
13
|
from .common import TextModel, TextModelConfig
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
16
|
-
"
|
|
17
|
-
"
|
|
16
|
+
"ClassifierModel",
|
|
17
|
+
"ClassifierModelConfig",
|
|
18
18
|
]
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
@dataclass(frozen=True)
|
|
22
|
-
class
|
|
22
|
+
class ClassifierModelConfig(TextModelConfig[ClassifierConfig]):
|
|
23
23
|
def init(
|
|
24
24
|
self,
|
|
25
25
|
model: LalamoModule,
|
|
26
26
|
message_processor: MessageProcessor,
|
|
27
|
-
) -> "
|
|
27
|
+
) -> "ClassifierModel":
|
|
28
28
|
assert isinstance(model, Classifier)
|
|
29
|
-
return
|
|
29
|
+
return ClassifierModel(self, model, message_processor)
|
|
30
30
|
|
|
31
31
|
@classmethod
|
|
32
|
-
def load_model(cls, path: Path | str) -> "
|
|
32
|
+
def load_model(cls, path: Path | str) -> "ClassifierModel":
|
|
33
33
|
result = super().load_model(path)
|
|
34
|
-
assert isinstance(result,
|
|
34
|
+
assert isinstance(result, ClassifierModel)
|
|
35
35
|
return result
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
class
|
|
38
|
+
class ClassifierModel(TextModel[ClassifierModelConfig, Classifier]):
|
|
39
39
|
def label_output_logits(self, logits: Float[Array, "batch logits"]) -> dict[str, Float[Array, " batch"]]:
|
|
40
40
|
output_labels = self.model.config.output_labels
|
|
41
41
|
probabilities = jax.nn.sigmoid(logits)
|
lalamo/utils.py
CHANGED
|
@@ -24,6 +24,7 @@ __all__ = [
|
|
|
24
24
|
"MapSequence",
|
|
25
25
|
"jax_uint4_to_packed_uint8",
|
|
26
26
|
"open_safetensors",
|
|
27
|
+
"process_chat_template",
|
|
27
28
|
]
|
|
28
29
|
|
|
29
30
|
|
|
@@ -159,3 +160,9 @@ def jax_uint8_to_unpacked_uint4(array: Array) -> Array:
|
|
|
159
160
|
)
|
|
160
161
|
|
|
161
162
|
return unpacked.astype(jnp.uint4)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def process_chat_template(template: str) -> str:
|
|
166
|
+
template = template.replace("{% generation %}", "")
|
|
167
|
+
template = template.replace("{%- endgeneration -%}", "")
|
|
168
|
+
return template
|
|
@@ -1,17 +1,17 @@
|
|
|
1
|
-
lalamo/__init__.py,sha256=
|
|
1
|
+
lalamo/__init__.py,sha256=ANgYnkcN0qtWyEPNfJb_rcAmghdwvBrHUKE2WNN0zn4,814
|
|
2
2
|
lalamo/common.py,sha256=5NUFD26yQgOnEEk3LaQnce8n-VwJxILkEpFesHZhtQU,3820
|
|
3
|
-
lalamo/main.py,sha256=
|
|
3
|
+
lalamo/main.py,sha256=GgUT7lT48-XQuAEH7qzsDKG8Lx9iBf-sYBIRhZL9q7E,23978
|
|
4
4
|
lalamo/message_processor.py,sha256=bSUAQg7CemLTnBV4LtPxJBicAalruDCA-JXjkTYPZ8U,5797
|
|
5
5
|
lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
|
|
6
6
|
lalamo/registry_abc.py,sha256=ENjXiD_wEH100fNjG-W5Em1L_EQ0Lf0pdRhRGvf3qZk,2197
|
|
7
7
|
lalamo/sampling.py,sha256=g_dNiJyZrRqoQIiLid4cr6nRT9N5tSz3GtHr8Bt4n-E,3404
|
|
8
|
-
lalamo/utils.py,sha256=
|
|
8
|
+
lalamo/utils.py,sha256=QwATVXAeHBsQEDyt_31SHgxFphFVZYHpv3ZaklXks9Y,4585
|
|
9
9
|
lalamo/data/__init__.py,sha256=exfhBLxHrg7BWutM0tAln5QuIWlNQmOhaG2noFYxfPI,189
|
|
10
10
|
lalamo/data/huggingface_message.py,sha256=-7lN9eIcETQzt1Pnx3d4d8p3_I7WYMNf4mp1P91N7fI,1115
|
|
11
11
|
lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
|
|
12
12
|
lalamo/data/utils.py,sha256=B96gLaULyStKYuR8wjFdTpFc6YIDC8EEvGh1eiMe_Ec,338
|
|
13
13
|
lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
|
|
14
|
-
lalamo/model_import/common.py,sha256=
|
|
14
|
+
lalamo/model_import/common.py,sha256=wvyGD-iLut_Pm3HjDMI05upqdtCW3HWeoeB0YmiFeqk,12419
|
|
15
15
|
lalamo/model_import/huggingface_generation_config.py,sha256=mot6VQ6ezCtEhN6VjhnvaU-nR5P5T2BuBUgpFNnWJxU,1495
|
|
16
16
|
lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
|
|
17
17
|
lalamo/model_import/decoder_configs/__init__.py,sha256=1ZqMcEHvCJjMIZ9iNyY31XMXOaFxB-NbqIU01BtmcEk,641
|
|
@@ -20,7 +20,7 @@ lalamo/model_import/decoder_configs/executorch.py,sha256=fTEG_j-7d8riR3Fu_H5tHDj
|
|
|
20
20
|
lalamo/model_import/decoder_configs/huggingface/__init__.py,sha256=3H7GPTFNNahEvI8D1SGg2mGBgPhsIdZ213MglwbGDlE,645
|
|
21
21
|
lalamo/model_import/decoder_configs/huggingface/common.py,sha256=YYIDEQy8x7lqL2qtxUHrNqfjZEiizBZ_26sTqOzjRtQ,3792
|
|
22
22
|
lalamo/model_import/decoder_configs/huggingface/gemma2.py,sha256=g8LH_GlSNyL04WWi596zI0rWsD3ahnfNjDk-9zZNcDE,4759
|
|
23
|
-
lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=
|
|
23
|
+
lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=aSZ0TtpgDYA10rHi8eD0C_Jsn48siM_HXqfZ4O7nh94,8372
|
|
24
24
|
lalamo/model_import/decoder_configs/huggingface/gpt_oss.py,sha256=MBCoPbuWyzbJiBRtHOtpaPHJjQ1UVCAYcVrfIejTnlQ,7446
|
|
25
25
|
lalamo/model_import/decoder_configs/huggingface/llama.py,sha256=UPeQiz2Dix8YaZYRxn9z44OZJ6c4xBQmcUZcM0Ymvh4,6934
|
|
26
26
|
lalamo/model_import/decoder_configs/huggingface/llamba.py,sha256=ANB-vQK8U-zVFubZSTDXXt2S70T5SVOGzf7eOVvPzIQ,5773
|
|
@@ -31,26 +31,27 @@ lalamo/model_import/decoder_configs/huggingface/qwen3.py,sha256=lySVO-TvusAYUjDn
|
|
|
31
31
|
lalamo/model_import/loaders/__init__.py,sha256=3THc1wQ4EPBzQkL_4EaKCa7Ev5Z7oczcvc4AHy9v5EI,228
|
|
32
32
|
lalamo/model_import/loaders/common.py,sha256=kkugV-bMQlN1zvGHoj3uc7z0FbXKoMtXEBTvyu4KxK4,1844
|
|
33
33
|
lalamo/model_import/loaders/executorch.py,sha256=t2Ey_mBMNC8bTSTdYWjuGXdPTRoohFlYrqtWyNkBU_8,9219
|
|
34
|
-
lalamo/model_import/loaders/huggingface.py,sha256=
|
|
34
|
+
lalamo/model_import/loaders/huggingface.py,sha256=QURyxD3C4Nzwa8k9iHVx32hQHV-aMWjb29W5_U99-WA,29834
|
|
35
35
|
lalamo/model_import/loaders/utils.py,sha256=eiX3WKFRrAfBY-dugodscNInl5o5w3KmVcgma4atpGY,2456
|
|
36
|
-
lalamo/model_import/model_specs/__init__.py,sha256=
|
|
37
|
-
lalamo/model_import/model_specs/common.py,sha256=
|
|
36
|
+
lalamo/model_import/model_specs/__init__.py,sha256=8RxLEZUxpsBtTwrTUqGIwhQ-8QzOxUdx-EL__cbcTjg,1228
|
|
37
|
+
lalamo/model_import/model_specs/common.py,sha256=RVPlNWHG_5OvU1W3YcOpqYz59Dh8plDmd7z1xNrqmaY,6585
|
|
38
38
|
lalamo/model_import/model_specs/deepseek.py,sha256=Umef93_ZBuq93yYsejIRNwj3udoln1gHfrv3SK5jyMo,417
|
|
39
|
+
lalamo/model_import/model_specs/essential_ai.py,sha256=xbHcwRpAWhR9gOgypVzcgunFspoUEk3iNsw-46CVR4o,390
|
|
39
40
|
lalamo/model_import/model_specs/gemma.py,sha256=irWgylL-pc7y3Gn5DK3fjKoCT9kJWH3B7mTa-1Gmxqc,1306
|
|
40
41
|
lalamo/model_import/model_specs/gpt_oss.py,sha256=PLo0QGrXKdX61ReTRdyOaP_EH3Dmj5lp3fpJjZRwRVA,542
|
|
41
|
-
lalamo/model_import/model_specs/huggingface.py,sha256=
|
|
42
|
+
lalamo/model_import/model_specs/huggingface.py,sha256=TEkU8y95_hmUWyF-Q5hn0dE2SvXbApghAsQwhWRu4D0,431
|
|
42
43
|
lalamo/model_import/model_specs/llama.py,sha256=Ml-xvRGlXBT9NJhmEpwgNo6C84oBSMYgA1_PrCYGcAw,990
|
|
43
44
|
lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
|
|
44
|
-
lalamo/model_import/model_specs/mirai.py,sha256=
|
|
45
|
+
lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
|
|
45
46
|
lalamo/model_import/model_specs/mistral.py,sha256=HAojorjOqsJn2DoMBzYRw8A70qCslhFEsE9AF5xumlg,1278
|
|
46
47
|
lalamo/model_import/model_specs/pleias.py,sha256=5sRpZGYwLdsav6bLiW-459y1Cs9iJKgKkBIuGsOxtsQ,368
|
|
47
48
|
lalamo/model_import/model_specs/polaris.py,sha256=Mw1-6bByjDmPIKlIUIV46CsmV5xUp_laI5Qquo5DmAQ,520
|
|
48
49
|
lalamo/model_import/model_specs/qwen.py,sha256=qzLmTveATmnwNFQSFJlffcXw7syFnrCmKf9ggkkkw1Y,7050
|
|
49
50
|
lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKCvnxLkNEDCw,423
|
|
50
|
-
lalamo/models/__init__.py,sha256=
|
|
51
|
+
lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
|
|
52
|
+
lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
|
|
51
53
|
lalamo/models/common.py,sha256=PDteofGxjSBWYw_mPxbN1DTUba70aOURrAIjl13SSHc,2954
|
|
52
54
|
lalamo/models/language_model.py,sha256=QPeVEyhutSze7fSNhvOvwSoYt24QMk-dtTJkos38amY,13465
|
|
53
|
-
lalamo/models/router.py,sha256=7KZqHVhr2TA7Qh76KfwrvyfztfZnV-P-Ss11O8dzbRg,2013
|
|
54
55
|
lalamo/modules/__init__.py,sha256=xWJ4OPAF4gKd0evYwXIK5kTnbH6nI55oLAePcoDDHQ0,3730
|
|
55
56
|
lalamo/modules/activations.py,sha256=U3qTQtZawPAUcoqbkIJnmTYcaNiQuSPMLcBeJ398GhI,1022
|
|
56
57
|
lalamo/modules/classifier.py,sha256=_jtJ3INEq1dJP5HpUmcDk9YYzpRYlQ04zvFGaWBV6Lg,12101
|
|
@@ -80,9 +81,9 @@ lalamo/speculator/estimator.py,sha256=4D8dPZCWsrpORb7y8pQ6VsiIg1Cblvvxe6gXCoYtcD
|
|
|
80
81
|
lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
|
|
81
82
|
lalamo/speculator/ngram.py,sha256=95mdfAWhx4d5XOnOwhyhElnvcy6nlUjYhcbJzqDs414,5875
|
|
82
83
|
lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
|
|
83
|
-
lalamo-0.5.
|
|
84
|
-
lalamo-0.5.
|
|
85
|
-
lalamo-0.5.
|
|
86
|
-
lalamo-0.5.
|
|
87
|
-
lalamo-0.5.
|
|
88
|
-
lalamo-0.5.
|
|
84
|
+
lalamo-0.5.9.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
|
|
85
|
+
lalamo-0.5.9.dist-info/METADATA,sha256=573oeEuYV14_hFpPmW2CNVZWciVS4_V85597oKOvjpo,3146
|
|
86
|
+
lalamo-0.5.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
87
|
+
lalamo-0.5.9.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
|
|
88
|
+
lalamo-0.5.9.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
|
|
89
|
+
lalamo-0.5.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|