lalamo 0.5.17__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. lalamo/__init__.py +1 -1
  2. lalamo/commands.py +69 -17
  3. lalamo/common.py +14 -1
  4. lalamo/main.py +148 -27
  5. lalamo/message_processor.py +4 -1
  6. lalamo/model_import/common.py +8 -17
  7. lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
  8. lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
  9. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
  10. lalamo/model_import/huggingface_generation_config.py +21 -3
  11. lalamo/model_import/loaders/executorch.py +2 -2
  12. lalamo/model_import/loaders/huggingface.py +3 -3
  13. lalamo/model_import/model_specs/common.py +4 -2
  14. lalamo/model_import/model_specs/lfm2.py +41 -9
  15. lalamo/models/language_model.py +7 -6
  16. lalamo/modules/activations.py +1 -1
  17. lalamo/modules/classifier.py +11 -24
  18. lalamo/modules/common.py +4 -1
  19. lalamo/modules/decoder.py +5 -11
  20. lalamo/modules/embedding.py +25 -62
  21. lalamo/modules/linear.py +19 -33
  22. lalamo/modules/mlp.py +9 -19
  23. lalamo/modules/mlx_interop.py +1 -1
  24. lalamo/modules/rope.py +1 -1
  25. lalamo/modules/token_mixers/__init__.py +1 -1
  26. lalamo/modules/token_mixers/attention.py +9 -27
  27. lalamo/modules/token_mixers/mamba.py +9 -24
  28. lalamo/modules/token_mixers/short_conv.py +5 -12
  29. lalamo/modules/transformer.py +10 -20
  30. lalamo/modules/transformer_layer.py +8 -20
  31. lalamo/registry_abc.py +4 -4
  32. lalamo/sampling.py +14 -0
  33. lalamo/speculator/estimator.py +3 -3
  34. lalamo/speculator/ngram.py +1 -1
  35. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -1
  36. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/RECORD +40 -40
  37. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
  38. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
  39. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
  40. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/__init__.py CHANGED
@@ -27,7 +27,7 @@ from lalamo.speculator import (
     SpeculatorTrainingEvent,
 )
 
-__version__ = "0.5.17"
+__version__ = "0.6.0"
 
 __all__ = [
     "AssistantMessage",
lalamo/commands.py CHANGED
@@ -1,5 +1,5 @@
 import json
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import Enum
 from itertools import chain
@@ -10,7 +10,7 @@ from jaxtyping import DTypeLike
 from lalamo.common import flatten_parameters
 from lalamo.data import import_hf_parquet
 from lalamo.data.lalamo_completions import LalamoCompletion
-from lalamo.message_processor import UserMessage
+from lalamo.message_processor import Message
 from lalamo.model_import import ModelMetadata, ModelSpec, import_model
 from lalamo.model_import.common import (
     DownloadingFileEvent,
@@ -41,8 +41,6 @@ class ConversionCallbacks:
     output_dir: Path
     precision: Precision | None
     context_length: int | None
-    include_traces: bool
-    message_for_trace: str | None
 
     def started(self) -> None:
         pass
@@ -74,16 +72,12 @@ def convert(
     output_dir: Path,
     precision: Precision | None = None,
     context_length: int | None = None,
-    include_traces: bool = False,
-    message_for_trace: str | None = None,
     callbacks_type: Callable[
         [
             ModelSpec,
             Path,
             Precision | None,
             int | None,
-            bool,
-            str | None,
         ],
         ConversionCallbacks,
     ] = ConversionCallbacks,
@@ -93,8 +87,6 @@ def convert(
         output_dir,
         precision,
         context_length,
-        include_traces,
-        message_for_trace,
     )
 
     if precision is not None:
@@ -127,13 +119,6 @@ def convert(
     callbacks.saving_model()
     output_dir.mkdir(parents=True, exist_ok=True)
 
-    if include_traces:
-        message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
-        result = model.record_trace(message)
-        traces = flatten_parameters(result.export())
-        with Path(output_dir / "traces.safetensors").open("wb") as fd:
-            safe_write(fd, traces)
-
     model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
     weights = flatten_parameters(model.export_weights())
     del model
@@ -148,6 +133,73 @@ def convert(
     callbacks.finished_saving_model()
 
 
+@dataclass
+class TraceCallbacks:
+    model_path: Path
+    output_path: Path
+    messages: Iterable[Message] | None
+
+    def output_exists(self) -> None:
+        raise RuntimeError(f"{self.output_path=} already exists, refusing to overwrite!")
+
+    def started(self) -> None:
+        pass
+
+    def loading_model(self) -> None:
+        pass
+
+    def finished_loading_model(self) -> None:
+        pass
+
+    def tracing_model(self) -> None:
+        pass
+
+    def finished_tracing_model(self) -> None:
+        pass
+
+    def saving_trace(self) -> None:
+        pass
+
+    def finished_saving_trace(self) -> None:
+        pass
+
+
+def trace(
+    model_path: Path,
+    output_path: Path,
+    messages: Iterable[Message] | None = None,
+    callbacks_type: Callable[
+        [
+            Path,
+            Path,
+            Iterable[Message] | None,
+        ],
+        TraceCallbacks,
+    ] = TraceCallbacks,
+) -> None:
+    callbacks = callbacks_type(model_path, output_path, messages)
+
+    if output_path.exists():
+        callbacks.output_exists()
+
+    callbacks.started()
+
+    callbacks.loading_model()
+    model = LanguageModelConfig.load_model(model_path)
+    callbacks.finished_loading_model()
+
+    callbacks.tracing_model()
+    result = model.record_trace(messages)
+    callbacks.finished_tracing_model()
+
+    callbacks.saving_trace()
+    traces = flatten_parameters(result.export())
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with Path(output_path).open("wb") as fd:
+        safe_write(fd, traces)
+    callbacks.finished_saving_trace()
+
+
 @dataclass
 class EstimateBatchsizeCallbacks:
     model_path: Path
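The trace logic removed from `convert` above now lives in a standalone `trace` entry point with the same callback-driven shape. A minimal sketch of driving it programmatically with the default no-op callbacks (the model directory path below is hypothetical):

from pathlib import Path

from lalamo.commands import trace
from lalamo.message_processor import UserMessage

# Record activation traces for an already-converted model directory and
# write them out as a safetensors file next to it.
trace(
    model_path=Path("models/my-model"),
    output_path=Path("models/my-model/traces.safetensors"),
    messages=[UserMessage(content="Hello!")],
)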
lalamo/common.py CHANGED
@@ -15,6 +15,8 @@ __all__ = [
     "ParameterTree",
     "dummy_array",
     "flatten_parameters",
+    "require_array",
+    "require_tree",
     "unflatten_parameters",
 ]
 
@@ -29,6 +31,16 @@ type ParameterTree[ArrayType: ArrayLike] = (
 )
 
 
+def require_array[ArrayType: ArrayLike](value: ArrayType | ParameterTree[ArrayType]) -> ArrayType:
+    assert not isinstance(value, (Mapping, Sequence))
+    return value
+
+
+def require_tree[ArrayType: ArrayLike](value: ArrayType | ParameterTree[ArrayType]) -> ParameterTree[ArrayType]:
+    assert not isinstance(value, (Array, ShapeDtypeStruct))
+    return value
+
+
 def dummy_array(shape: int | tuple[int, ...], dtype: DTypeLike) -> Array:
     if isinstance(shape, int):
         shape = (shape,)
@@ -40,9 +52,10 @@ def flatten_parameters[ArrayType: ArrayLike](nested_parameters: ParameterTree[Ar
     if not isinstance(nested_parameters, Mapping):
         nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
     for key, value in nested_parameters.items():
+        value = cast("ArrayType | ParameterTree[ArrayType]", value)
         key_path = ParameterPath(key)
         if isinstance(value, (Array, ShapeDtypeStruct)):
-            result[key_path] = value
+            result[key_path] = cast("ArrayType", value)
         else:
             update: dict[str, ArrayType] = {
                 str(key_path / subkey): subvalue for subkey, subvalue in flatten_parameters(value).items()
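The new `require_array`/`require_tree` helpers give callers a typed way to narrow a value declared as `ArrayType | ParameterTree[ArrayType]`, which is what the added `cast` calls in `flatten_parameters` paper over internally. A rough sketch of how the flattening behaves on a nested tree (the key format assumes `ParameterPath` joins segments with `/`):

import jax.numpy as jnp

from lalamo.common import flatten_parameters

# A nested parameter tree: mappings and sequences of arrays.
nested = {"decoder": {"layers": [{"weight": jnp.zeros((2, 2))}]}}

flat = flatten_parameters(nested)
# Expected keys look like "decoder/layers/0/weight"; sequence indices become string keys.
print(list(flat))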
lalamo/main.py CHANGED
@@ -36,11 +36,13 @@ from lalamo.commands import (
     ConversionCallbacks,
     EstimateBatchsizeCallbacks,
     Precision,
+    TraceCallbacks,
     TrainCallbacks,
 )
 from lalamo.commands import collect_traces as _collect_traces
 from lalamo.commands import convert as _convert
 from lalamo.commands import estimate_batchsize as _estimate_batchsize
+from lalamo.commands import trace as _trace
 from lalamo.commands import train as _train
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.message_processor import UserMessage
@@ -83,7 +85,7 @@ class ModelParser(ParamType):
                 f"\n\nUse the `{SCRIPT_NAME} list-models` command to see the list of currently supported models.",
             )
             error_message = "".join(error_message_parts)
-            self.fail(error_message, param, ctx)
+            return self.fail(error_message, param, ctx)
         return result
 
 
@@ -111,10 +113,18 @@ def chat(
             metavar="MODEL_PATH",
         ),
     ],
+    message: Annotated[
+        str | None,
+        Option(
+            help="Message for non-interactive mode",
+            show_default="None, run interactively",
+        ),
+    ] = None,
 ) -> None:
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
+        console=err_console,
         transient=True,
     ) as progress:
         loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
@@ -123,21 +133,28 @@ def chat(
         warmup_task = progress.add_task("🔥 Warming up compilation cache...")
         list(model.stream_reply_text([UserMessage("")], max_output_length=1))
         progress.remove_task(warmup_task)
-    console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
-    messages = []
-    while True:
-        user_text = console.input("[cyan]user> [/cyan]")
-        user_message = UserMessage(user_text)
-        messages.append(user_message)
 
-        console.print("[red]assistant> [/red]", end="")
-        model_response_tokens = []
-        for token in model.stream_reply_text(messages):
+    if message is None:
+        console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
+
+        messages = []
+        while True:
+            user_text = console.input("[cyan]user> [/cyan]")
+            user_message = UserMessage(user_text)
+            messages.append(user_message)
+
+            console.print("[red]assistant> [/red]", end="")
+            model_response_tokens = []
+            for token in model.stream_reply_text(messages):
+                console.print(token, end="")
+                model_response_tokens.append(token)
+            console.print()
+            model_response_text = "".join(model_response_tokens)
+            messages.append(model.message_processor.parse_response(model_response_text))
+    else:
+        for token in model.stream_reply_text([UserMessage(message)]):
             console.print(token, end="")
-            model_response_tokens.append(token)
         console.print()
-        model_response_text = "".join(model_response_tokens)
-        messages.append(model.message_processor.parse_response(model_response_text))
 
 
 @app.command(help="Classify given message with a Classifier type of model.")
@@ -178,6 +195,7 @@ class CliConversionCallbacks(ConversionCallbacks):
     overwrite: bool = False
 
     stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
     downloading_tasks: dict[FileSpec, TaskID] = field(default_factory=dict)
     initializing_task: TaskID | None = None
     saving_task: TaskID | None = None
@@ -211,23 +229,33 @@ class CliConversionCallbacks(ConversionCallbacks):
         shutil.rmtree(self.output_dir)
 
     def downloading(self, file_spec: FileSpec) -> None:
+        assert self.progress is not None
+
         self.downloading_tasks[file_spec] = self.progress.add_task(f"Retrieving {file_spec.filename}...")
 
     def finished_downloading(self, file_spec: FileSpec) -> None:
+        assert self.progress is not None
+
         self.progress.remove_task(self.downloading_tasks[file_spec])
 
     def initializing_model(self) -> None:
+        assert self.progress is not None
+
         self.initializing_task = self.progress.add_task("Initializing model...")
 
     def finished_initializing_model(self) -> None:
+        assert self.progress is not None
         assert self.initializing_task is not None
 
         self.progress.remove_task(self.initializing_task)
 
     def saving_model(self) -> None:
+        assert self.progress is not None
+
         self.saving_task = self.progress.add_task(f"💾 Saving the model to {self.output_dir}")
 
     def finished_saving_model(self) -> None:
+        assert self.progress is not None
         assert self.saving_task is not None
 
         self.progress.remove_task(self.saving_task)
@@ -272,24 +300,12 @@ def convert(
             show_default="Model's native maximum context length.",
         ),
     ] = None,
-    include_traces: Annotated[
-        bool,
-        Option(
-            help="Export activation traces for debugging purposes.",
-        ),
-    ] = False,
     overwrite: Annotated[
         bool,
         Option(
             help="Overwrite existing model files.",
         ),
     ] = False,
-    message_for_trace: Annotated[
-        str | None,
-        Option(
-            help="Text message to use as prompt when recording trace",
-        ),
-    ] = None,
 ) -> None:
     if output_dir is None:
         output_dir = DEFAULT_OUTPUT_DIR / model_repo.name
@@ -299,12 +315,117 @@ def convert(
         output_dir,
         precision,
         context_length,
-        include_traces,
-        message_for_trace,
         partial(CliConversionCallbacks, overwrite=overwrite),
     )
 
 
+@dataclass
+class CliTraceCallbacks(TraceCallbacks):
+    overwrite: bool = False
+
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    loading_task: TaskID | None = None
+    tracing_task: TaskID | None = None
+    saving_task: TaskID | None = None
+
+    def output_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output [cyan]{self.output_path}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
+            raise Exit
+
+        self.output_path.unlink()
+
+    def started(self) -> None:
+        console.print(f"🔍 Tracing [cyan]{self.model_path}[/cyan]")
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+
+    def loading_model(self) -> None:
+        assert self.progress is not None
+
+        self.loading_task = self.progress.add_task("🧠 Loading model...")
+
+    def finished_loading_model(self) -> None:
+        assert self.progress is not None
+        assert self.loading_task is not None
+
+        self.progress.remove_task(self.loading_task)
+
+    def tracing_model(self) -> None:
+        assert self.progress is not None
+
+        self.tracing_task = self.progress.add_task("🔍 Recording trace...")
+
+    def finished_tracing_model(self) -> None:
+        assert self.progress is not None
+        assert self.tracing_task is not None
+
+        self.progress.remove_task(self.tracing_task)
+
+    def saving_trace(self) -> None:
+        assert self.progress is not None
+
+        self.saving_task = self.progress.add_task(f"💾 Saving trace to {self.output_path}")
+
+    def finished_saving_trace(self) -> None:
+        assert self.progress is not None
+        assert self.saving_task is not None
+
+        self.progress.remove_task(self.saving_task)
+        self.stack.close()
+        console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")
+
+@app.command(help="Trace a model.")
+def trace(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path | None,
+        Option(
+            help="Path to save the trace to.",
+            show_default="${MODEL_PATH}/traces.safetensors",
+        ),
+    ] = None,
+    overwrite: Annotated[
+        bool,
+        Option(
+            help="Overwrite existing trace file.",
+        ),
+    ] = False,
+    message: Annotated[
+        str | None,
+        Option(
+            help="Text message to use as prompt when recording trace",
+        ),
+    ] = None,
+) -> None:
+    if output_path is None:
+        output_path = model_path / "traces.safetensors"
+
+    messages = None if message is None else [UserMessage(content=message)]
+
+    _trace(
+        model_path,
+        output_path,
+        messages,
+        partial(CliTraceCallbacks, overwrite=overwrite),
+    )
+
+
 def _model_size_string_to_int(
     size_str: str,
     _regex: re.Pattern = re.compile(r"(?P<number>(\d+)(\.\d*)?)(?P<suffix>[KMBT])"),
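With tracing split out of `convert`, the CLI gains a dedicated `trace` subcommand, and `chat` gets a `--message` option for one-shot, non-interactive replies. A hedged smoke-test sketch using Typer's test runner (it assumes `app` is the Typer application defined in lalamo/main.py; the model path is hypothetical):

from typer.testing import CliRunner

from lalamo.main import app

runner = CliRunner()

# One-shot reply without entering the interactive loop.
runner.invoke(app, ["chat", "models/my-model", "--message", "Hello!"])

# Record a trace next to the model, overwriting any existing traces.safetensors.
runner.invoke(app, ["trace", "models/my-model", "--overwrite", "--message", "Hello!"])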
lalamo/message_processor.py CHANGED
@@ -169,7 +169,10 @@ class MessageProcessor:
     def __post_init__(self) -> None:
         if self.output_parser_regex is not None:
             all_fields = AssistantMessage.__dataclass_fields__
-            required_fields = {k: v for k, v in all_fields.items() if v.type == v.type | None}
+            # NOTE: str type annotations are assumed to be required
+            required_fields = {
+                k: v for k, v in all_fields.items() if isinstance(v.type, str) or v.type == (v.type | None)
+            }
             named_groups = self.output_parser_regex.groupindex
             invalid_groups = set(named_groups) - set(all_fields)
             if invalid_groups:
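The extra `isinstance(v.type, str)` guard matters because dataclass fields can carry their annotations as plain strings (for example under `from __future__ import annotations` or with quoted forward references), in which case evaluating `v.type | None` would raise a `TypeError`; per the NOTE above, such fields are treated as required. A standalone illustration, not lalamo code:

from dataclasses import dataclass, fields

@dataclass
class Example:
    name: str  # without deferred annotations, f.type is the class `str`

for f in fields(Example):
    # With `from __future__ import annotations` f.type would instead be the
    # string "str", and `f.type | None` would raise a TypeError, which is why
    # the string case is checked first.
    print(f.name, repr(f.type), isinstance(f.type, str))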
lalamo/model_import/common.py CHANGED
@@ -3,7 +3,7 @@ import json
 from collections import ChainMap
 from collections.abc import Callable
 from contextlib import ExitStack
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from pathlib import Path
 from typing import NamedTuple
 
@@ -20,7 +20,7 @@ from lalamo.quantization import QuantizationMode
 from lalamo.utils import process_chat_template
 
 from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
-from .huggingface_generation_config import HFGenerationConfig
+from .huggingface_generation_config import HFGenerationConfig, _policy_from_hf_config
 from .huggingface_tokenizer_config import HFTokenizerConfig
 from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, ModelType, UseCase
 from .model_specs.common import JSONFieldSpec
@@ -34,6 +34,7 @@ __all__ = [
     "ModelSpec",
     "ModelType",
     "StatusEvent",
+    "download_file",
     "import_model",
 ]
 
@@ -239,24 +240,14 @@ def _import_language_model(
 
     stop_token_ids = tuple(foreign_decoder_config.eos_token_ids)
 
-    if model_spec.configs.generation_config is not None:
+    if isinstance(model_spec.configs.generation_config, GenerationConfig):
+        generation_config = replace(model_spec.configs.generation_config, stop_token_ids=stop_token_ids)
+    elif isinstance(model_spec.configs.generation_config, FileSpec):
         hf_generation_config_file = download_file(model_spec.configs.generation_config, model_spec.repo)
         hf_generation_config = HFGenerationConfig.from_json(hf_generation_config_file)
-        generation_config = GenerationConfig(
-            stop_token_ids=stop_token_ids,
-            temperature=hf_generation_config.temperature,
-            top_p=hf_generation_config.top_p,
-            top_k=hf_generation_config.top_k,
-            banned_tokens=None,
-        )
+        generation_config = _policy_from_hf_config(hf_generation_config, stop_token_ids)
     else:
-        generation_config = GenerationConfig(
-            stop_token_ids=stop_token_ids,
-            temperature=None,
-            top_p=None,
-            top_k=None,
-            banned_tokens=None,
-        )
+        generation_config = GenerationConfig(stop_token_ids)
 
     language_model_config = LanguageModelConfig(
         model_config=decoder.config,
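`_import_language_model` now distinguishes a spec that ships a ready-made `GenerationConfig` from one that points at a Hugging Face generation_config.json via a `FileSpec`; in the first case `dataclasses.replace` only swaps in the discovered stop tokens. A minimal sketch of that branch (the token ids are illustrative):

from dataclasses import replace

from lalamo.models import GenerationConfig

# A spec-provided policy with no stop tokens yet; sampling fields keep their defaults.
spec_config = GenerationConfig(stop_token_ids=())

# The importer preserves the policy and injects the EOS ids from the decoder config.
generation_config = replace(spec_config, stop_token_ids=(2, 32000))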
lalamo/model_import/decoder_configs/huggingface/lfm2.py CHANGED
@@ -2,6 +2,7 @@ from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
+import jax.numpy as jnp
 from jaxtyping import DTypeLike
 
 from lalamo.modules import (
@@ -50,7 +51,6 @@ class HFLFM2Config(HuggingFaceLMConfig):
     conv_L_cache: int  # noqa: N815
     conv_bias: bool
     conv_dim: int
-    conv_dim_out: int
     conv_use_xavier_init: bool
     eos_token_id: int
     hidden_size: int
@@ -64,13 +64,15 @@ class HFLFM2Config(HuggingFaceLMConfig):
     num_key_value_heads: int
     pad_token_id: int
     rope_theta: float
-    torch_dtype: Literal["bfloat16"]
     transformers_version: str
     use_cache: bool
     use_pos_enc: bool
     vocab_size: int
 
+    dtype: Literal["bfloat16", "float16", "float32"] | None = None
+    torch_dtype: Literal["bfloat16", "float16", "float32"] | None = None
     intermediate_size: int | None = None
+    conv_dim_out: int | None = None
     layer_types: list[Literal["conv", "full_attention"]] | None = None
     full_attn_idxs: list[int] | None = None
     tie_embedding: bool = True
@@ -79,6 +81,14 @@ class HFLFM2Config(HuggingFaceLMConfig):
     quantization: QuantizationConfig | None = None
     quantization_config: QuantizationConfig | None = None
 
+    @property
+    def default_precision(self) -> DTypeLike:
+        assert self.dtype is not None or self.torch_dtype is not None, (
+            "at least one of dtype or torch_dtype must be specified"
+        )
+
+        return jnp.dtype(self.dtype or self.torch_dtype)
+
     def to_decoder_config(
         self,
         context_length: int | None,
@@ -200,8 +210,8 @@ class HFLFM2Config(HuggingFaceLMConfig):
             subtract_mean=False,
         )
 
-        if self.intermediate_size is not None:
-            hidden_dim = self.intermediate_size
+        if not self.block_auto_adjust_ff_dim:
+            hidden_dim = self.intermediate_size or self.block_ff_dim
         else:
             hidden_dim_adjusted = self.block_ff_dim * self.block_ffn_dim_multiplier * (2 / 3)
             hidden_dim = int(
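Newer LFM2 checkpoints publish `dtype` instead of (or alongside) `torch_dtype`, so `default_precision` now accepts either and prefers `dtype`. Roughly, the fallback behaves like this standalone sketch:

import jax.numpy as jnp

def resolve_precision(dtype: str | None, torch_dtype: str | None):
    # Mirrors HFLFM2Config.default_precision: prefer `dtype`, fall back to `torch_dtype`.
    assert dtype is not None or torch_dtype is not None
    return jnp.dtype(dtype or torch_dtype)

print(resolve_precision(None, "bfloat16"))       # bfloat16
print(resolve_precision("float32", "bfloat16"))  # float32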
lalamo/model_import/decoder_configs/huggingface/llamba.py CHANGED
@@ -76,7 +76,7 @@ class HFLlambaConfig(HuggingFaceLMConfig):
             logit_soft_cap=None,
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
             embedding_quantization_mode=QuantizationMode.from_num_bits(
-                int(metadata_dict["quantization_kwargs.bits"])
+                int(metadata_dict["quantization_kwargs.bits"]),
             ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
@@ -107,7 +107,7 @@ class HFLlambaConfig(HuggingFaceLMConfig):
         linear_config = MLXQuantizedLinearConfig(
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
             weight_quantization_mode=QuantizationMode.from_num_bits(
-                int(metadata_dict["quantization_kwargs.bits"])
+                int(metadata_dict["quantization_kwargs.bits"]),
             ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
lalamo/model_import/decoder_configs/huggingface/modern_bert.py CHANGED
@@ -41,7 +41,7 @@ def activation_from_str(activation: str) -> type[Activation]:
         return supported_activations[activation]
 
     raise ValueError(
-        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}"
+        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}",
     )
 
 
@@ -97,7 +97,7 @@ class ModernBERTConfig(HuggingFaceClassifierConfig):
         result = [None] * num_layers
         for index in range(len(result)):
             if index % global_attn_every_n_layers != 0:
-                result[index] = self.local_attention  # type: ignore
+                result[index] = self.local_attention
             else:
                 pass
         return tuple(result)
lalamo/model_import/huggingface_generation_config.py CHANGED
@@ -5,7 +5,9 @@ from typing import ClassVar
 
 import cattrs
 
-__all__ = ["HFGenerationConfig"]
+from lalamo.models import GenerationConfig
+
+__all__ = ["HFGenerationConfig", "_policy_from_hf_config"]
 
 
 @dataclass(frozen=True)
@@ -27,10 +29,11 @@ class HFGenerationConfig:
     cache_implementation: str | None = None  # “hybrid” for Gemma 3/2
 
     # -------- sampling strategy -------------
-    do_sample: bool | None = None
+    do_sample: bool | None = False
     temperature: float | None = None
+    min_p: float | None = None
     top_p: float | None = None
-    top_k: int | None = None
+    top_k: int | None = 50
     repetition_penalty: float | None = None
 
     # -------- length limits -----------------
@@ -42,3 +45,18 @@ class HFGenerationConfig:
         with open(json_path) as f:
             config = json.load(f)
         return cls._converter.structure(config, cls)
+
+
+def _policy_from_hf_config(
+    hf_config: HFGenerationConfig,
+    stop_token_ids: tuple[int, ...] = (),
+    banned_tokens: tuple[int, ...] | None = None,
+) -> GenerationConfig:
+    return GenerationConfig(
+        stop_token_ids=stop_token_ids,
+        temperature=hf_config.temperature,
+        top_k=hf_config.top_k,
+        top_p=hf_config.top_p,
+        min_p=hf_config.min_p,
+        banned_tokens=banned_tokens,
+    )
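`_policy_from_hf_config` centralises the mapping from a parsed Hugging Face generation config onto lalamo's `GenerationConfig`, now carrying `min_p` through as well. A short sketch of the round trip (the JSON path is hypothetical):

from lalamo.model_import.huggingface_generation_config import (
    HFGenerationConfig,
    _policy_from_hf_config,
)

# Parse the repo's generation_config.json, then build the sampling policy;
# stop tokens are supplied separately by the importer.
hf_config = HFGenerationConfig.from_json("generation_config.json")
policy = _policy_from_hf_config(hf_config, stop_token_ids=(2,))
print(policy.temperature, policy.top_k, policy.top_p, policy.min_p)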
lalamo/model_import/loaders/huggingface.py CHANGED
@@ -97,7 +97,7 @@ def load_mlp(module: DenseMLP, weights_dict: Mapping[str, Array], path: Paramete
     fused_up_gate_params = merge_linear_params([up_proj_params, gate_proj_params])
 
     return load_parameters(
-        lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),  # type: ignore
+        lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),
         module,
         (*fused_up_gate_params, *down_proj_params),
     )
@@ -177,7 +177,7 @@ def load_attention(
 
     qkv_params = merge_linear_params([q_params, k_params, v_params])
     return load_parameters(
-        lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),  # type: ignore
+        lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),
         module,
         (*qkv_params, *out_params),
     )