lalamo 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. lalamo/__init__.py +26 -2
  2. lalamo/commands.py +429 -0
  3. lalamo/common.py +14 -1
  4. lalamo/main.py +375 -229
  5. lalamo/message_processor.py +4 -1
  6. lalamo/model_import/common.py +8 -17
  7. lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
  8. lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
  9. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
  10. lalamo/model_import/huggingface_generation_config.py +21 -3
  11. lalamo/model_import/loaders/executorch.py +2 -2
  12. lalamo/model_import/loaders/huggingface.py +3 -3
  13. lalamo/model_import/model_specs/common.py +8 -4
  14. lalamo/model_import/model_specs/lfm2.py +41 -9
  15. lalamo/models/common.py +3 -3
  16. lalamo/models/language_model.py +7 -6
  17. lalamo/modules/activations.py +1 -1
  18. lalamo/modules/classifier.py +11 -24
  19. lalamo/modules/common.py +4 -1
  20. lalamo/modules/decoder.py +5 -11
  21. lalamo/modules/embedding.py +25 -62
  22. lalamo/modules/linear.py +19 -33
  23. lalamo/modules/mlp.py +9 -19
  24. lalamo/modules/mlx_interop.py +1 -1
  25. lalamo/modules/rope.py +1 -1
  26. lalamo/modules/token_mixers/__init__.py +1 -1
  27. lalamo/modules/token_mixers/attention.py +9 -27
  28. lalamo/modules/token_mixers/mamba.py +9 -24
  29. lalamo/modules/token_mixers/short_conv.py +5 -12
  30. lalamo/modules/transformer.py +10 -20
  31. lalamo/modules/transformer_layer.py +8 -20
  32. lalamo/registry_abc.py +4 -4
  33. lalamo/safetensors.py +97 -0
  34. lalamo/sampling.py +14 -0
  35. lalamo/speculator/estimator.py +11 -4
  36. lalamo/speculator/ngram.py +1 -1
  37. lalamo/utils.py +0 -13
  38. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -2
  39. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/RECORD +43 -41
  40. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
  41. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
  42. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/main.py CHANGED
@@ -1,20 +1,19 @@
-import json
 import random
 import re
 import shutil
 import sys
-from enum import Enum
-from itertools import chain, islice
+from contextlib import ExitStack
+from dataclasses import dataclass, field
+from functools import partial
+from itertools import islice
 from pathlib import Path
 from typing import Annotated
 
-import jax
 import jax.profiler
 import thefuzz.process
 from click import Context as ClickContext
 from click import Parameter as ClickParameter
 from click import ParamType
-from jaxtyping import DTypeLike
 from rich import box
 from rich.console import Console
 from rich.live import Live
@@ -23,48 +22,42 @@ from rich.progress import (
     MofNCompleteColumn,
     Progress,
     SpinnerColumn,
+    TaskID,
     TextColumn,
     TimeElapsedColumn,
     TimeRemainingColumn,
 )
+from rich.prompt import Confirm
 from rich.table import Table
-from safetensors.flax import save_file
 from typer import Argument, Context, Exit, Option, Typer
 
-from lalamo.common import flatten_parameters
-from lalamo.data import import_hf_parquet
+from lalamo.commands import (
+    CollectTracesCallbacks,
+    ConversionCallbacks,
+    EstimateBatchsizeCallbacks,
+    Precision,
+    TraceCallbacks,
+    TrainCallbacks,
+)
+from lalamo.commands import collect_traces as _collect_traces
+from lalamo.commands import convert as _convert
+from lalamo.commands import estimate_batchsize as _estimate_batchsize
+from lalamo.commands import trace as _trace
+from lalamo.commands import train as _train
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.message_processor import UserMessage
-from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
-from lalamo.model_import.common import (
-    DownloadingFileEvent,
-    FinishedDownloadingFileEvent,
-    FinishedInitializingModelEvent,
-    InitializingModelEvent,
-    StatusEvent,
-)
+from lalamo.model_import import REPO_TO_MODEL, ModelSpec
+from lalamo.model_import.common import FileSpec
 from lalamo.models import ClassifierModelConfig, LanguageModelConfig
-from lalamo.modules import config_converter
-from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
-from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+from lalamo.speculator.estimator import get_default_device_memory
 from lalamo.speculator.ngram import NGramSpeculator
-from lalamo.speculator.utils import (
-    SpeculatorTrainingEvent,
-    test_speculator,
-    train_speculator,
-)
+from lalamo.speculator.utils import test_speculator
 
 SCRIPT_NAME = Path(sys.argv[0]).name
 
 DEFAULT_OUTPUT_DIR = Path("models")
 
 
-class Precision(Enum):
-    FLOAT32 = "float32"
-    FLOAT16 = "float16"
-    BFLOAT16 = "bfloat16"
-
-
 console = Console()
 err_console = Console(stderr=True)
 app = Typer(
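The import changes above are the visible edge of this release's central refactor: the bodies of the CLI commands now live in lalamo.commands behind per-command callback interfaces (ConversionCallbacks, TraceCallbacks, and so on), with main.py keeping only the Rich-specific presentation. A minimal sketch of driving a conversion programmatically, assuming the base callback classes provide no-op hook implementations (the bases are not shown in this diff, so treat that as an assumption):

from pathlib import Path

from lalamo.commands import ConversionCallbacks, convert
from lalamo.model_import import REPO_TO_MODEL


class HeadlessConversionCallbacks(ConversionCallbacks):
    # Assumption: inherited hooks are no-ops, so nothing is printed.
    pass


convert(
    REPO_TO_MODEL["some-org/some-model"],  # hypothetical repo key
    Path("models/some-model"),             # output directory
    None,                                  # precision: keep native dtype
    None,                                  # context_length: model default
    HeadlessConversionCallbacks,           # any callable producing the callbacks works
)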
@@ -92,7 +85,7 @@ class ModelParser(ParamType):
                 f"\n\nUse the `{SCRIPT_NAME} list-models` command to see the list of currently supported models.",
             )
             error_message = "".join(error_message_parts)
-            self.fail(error_message, param, ctx)
+            return self.fail(error_message, param, ctx)
         return result
 
 
@@ -120,10 +113,18 @@ def chat(
             metavar="MODEL_PATH",
         ),
     ],
+    message: Annotated[
+        str | None,
+        Option(
+            help="Message for non-interactive mode",
+            show_default="None, run interactively",
+        ),
+    ] = None,
 ) -> None:
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
+        console=err_console,
         transient=True,
     ) as progress:
         loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
@@ -132,21 +133,28 @@ def chat(
         warmup_task = progress.add_task("🔥 Warming up compilation cache...")
         list(model.stream_reply_text([UserMessage("")], max_output_length=1))
         progress.remove_task(warmup_task)
-    console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
-    messages = []
-    while True:
-        user_text = console.input("[cyan]user> [/cyan]")
-        user_message = UserMessage(user_text)
-        messages.append(user_message)
 
-        console.print("[red]assistant> [/red]", end="")
-        model_response_tokens = []
-        for token in model.stream_reply_text(messages):
+    if message is None:
+        console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
+
+        messages = []
+        while True:
+            user_text = console.input("[cyan]user> [/cyan]")
+            user_message = UserMessage(user_text)
+            messages.append(user_message)
+
+            console.print("[red]assistant> [/red]", end="")
+            model_response_tokens = []
+            for token in model.stream_reply_text(messages):
+                console.print(token, end="")
+                model_response_tokens.append(token)
+            console.print()
+            model_response_text = "".join(model_response_tokens)
+            messages.append(model.message_processor.parse_response(model_response_text))
+    else:
+        for token in model.stream_reply_text([UserMessage(message)]):
             console.print(token, end="")
-            model_response_tokens.append(token)
         console.print()
-        model_response_text = "".join(model_response_tokens)
-        messages.append(model.message_processor.parse_response(model_response_text))
 
 
 @app.command(help="Classify given message with a Classifier type of model.")
@@ -182,6 +190,79 @@ def classify(
         console.print()
 
 
+@dataclass
+class CliConversionCallbacks(ConversionCallbacks):
+    overwrite: bool = False
+
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    downloading_tasks: dict[FileSpec, TaskID] = field(default_factory=dict)
+    initializing_task: TaskID | None = None
+    saving_task: TaskID | None = None
+
+    def started(self) -> None:
+        conversion_strs = [
+            f"🚀 Converting [cyan]{self.model_spec.name}[/cyan] by [cyan]{self.model_spec.vendor}[/cyan]",
+        ]
+        if self.precision is not None:
+            conversion_strs.append(
+                f" and converting floating-point weights into [cyan]{self.precision.name.lower()}[/cyan] precision",
+            )
+        conversion_strs.append(".")
+        console.print("".join(conversion_strs))
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+
+    def output_dir_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output directory [cyan]{self.output_dir}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
+            raise Exit
+
+        shutil.rmtree(self.output_dir)
+
+    def downloading(self, file_spec: FileSpec) -> None:
+        assert self.progress is not None
+
+        self.downloading_tasks[file_spec] = self.progress.add_task(f"Retrieving {file_spec.filename}...")
+
+    def finished_downloading(self, file_spec: FileSpec) -> None:
+        assert self.progress is not None
+
+        self.progress.remove_task(self.downloading_tasks[file_spec])
+
+    def initializing_model(self) -> None:
+        assert self.progress is not None
+
+        self.initializing_task = self.progress.add_task("Initializing model...")
+
+    def finished_initializing_model(self) -> None:
+        assert self.progress is not None
+        assert self.initializing_task is not None
+
+        self.progress.remove_task(self.initializing_task)
+
+    def saving_model(self) -> None:
+        assert self.progress is not None
+
+        self.saving_task = self.progress.add_task(f"💾 Saving the model to {self.output_dir}")
+
+    def finished_saving_model(self) -> None:
+        assert self.progress is not None
+        assert self.saving_task is not None
+
+        self.progress.remove_task(self.saving_task)
+        self.stack.close()
+        console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{self.output_dir}`[/cyan]!")
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
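CliConversionCallbacks above keeps one Rich Progress alive across separate callback invocations by parking it in an ExitStack: started() enters the context manager, and finished_saving_model() closes the stack. The same pattern in isolation, for reference:

from contextlib import ExitStack
from dataclasses import dataclass, field

from rich.progress import Progress


@dataclass
class ProgressSession:
    stack: ExitStack = field(default_factory=ExitStack)
    progress: Progress | None = None

    def start(self) -> None:
        # Enter `with Progress(...):` now; the exit is deferred.
        self.progress = self.stack.enter_context(Progress(transient=True))

    def finish(self) -> None:
        # Runs the deferred __exit__, stopping the live display.
        self.stack.close()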
@@ -219,104 +300,130 @@ def convert(
             show_default="Model's native maximum context length.",
         ),
     ] = None,
-    include_traces: Annotated[
-        bool,
-        Option(
-            help="Export activation traces for debugging purposes.",
-        ),
-    ] = False,
     overwrite: Annotated[
         bool,
         Option(
             help="Overwrite existing model files.",
         ),
     ] = False,
-    message_for_trace: Annotated[
-        str | None,
-        Option(
-            help="Text message to use as prompt when recording trace",
-        ),
-    ] = None,
 ) -> None:
-    if precision is not None:
-        precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
-    else:
-        precision_dtype = None
-
     if output_dir is None:
         output_dir = DEFAULT_OUTPUT_DIR / model_repo.name
 
-    conversion_strs = [f"🚀 Converting [cyan]{model_repo.name}[/cyan] by [cyan]{model_repo.vendor}[/cyan]"]
-    if precision is not None:
-        conversion_strs.append(
-            f" and converting floating-point weights into [cyan]{precision.name.lower()}[/cyan] precision",
-        )
-    conversion_strs.append(".")
-    console.print("".join(conversion_strs))
+    _convert(
+        model_repo,
+        output_dir,
+        precision,
+        context_length,
+        partial(CliConversionCallbacks, overwrite=overwrite),
+    )
 
-    if output_dir.exists() and not overwrite:
-        answer = console.input(
-            rf"⚠️ Output directory [cyan]{output_dir}[/cyan] already exists."
-            r" Do you want to overwrite it? [cyan]\[y/n][/cyan]: ",
-        )
-        while answer.lower() not in ["y", "n", "yes", "no"]:
-            answer = console.input("Please enter 'y' or 'n': ")
-        if answer.lower() in ["y", "yes"]:
-            shutil.rmtree(output_dir)
-        else:
-            console.print("Exiting...")
+
+@dataclass
+class CliTraceCallbacks(TraceCallbacks):
+    overwrite: bool = False
+
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    loading_task: TaskID | None = None
+    tracing_task: TaskID | None = None
+    saving_task: TaskID | None = None
+
+    def output_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output [cyan]{self.output_path}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
             raise Exit
 
-    message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
+        self.output_path.unlink()
 
-    with Progress(
-        SpinnerColumn(),
-        TextColumn("[progress.description]{task.description}"),
-        transient=True,
-    ) as progress:
-        event_to_task = {}
-
-        def progress_callback(event: StatusEvent) -> None:
-            match event:
-                case DownloadingFileEvent(file_spec):
-                    event_to_task[event] = progress.add_task(f"Retrieving {file_spec.filename}...")
-                case FinishedDownloadingFileEvent(file_spec):
-                    progress.remove_task(event_to_task[event])
-                case InitializingModelEvent():
-                    event_to_task[event] = progress.add_task("Initializing model...")
-                case FinishedInitializingModelEvent():
-                    progress.remove_task(event_to_task[event])
-
-        main_task = progress.add_task("👨‍🍳 Cooking...")
-        model, metadata = import_model(
-            model_repo,
-            precision=precision_dtype,
-            context_length=context_length,
-            progress_callback=progress_callback,
+    def started(self) -> None:
+        console.print(f"🔍 Tracing [cyan]{self.model_path}[/cyan]")
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
         )
-        save_task = progress.add_task(f"💾 Saving the model to {output_dir}")
-        output_dir.mkdir(parents=True, exist_ok=True)
 
-        if include_traces:
-            trace_task = progress.add_task("🚁 Generating traces...")
-            result = model.record_trace(message)
-            traces = flatten_parameters(result.export())
-            save_file(traces, output_dir / "traces.safetensors")
-            progress.remove_task(trace_task)
-        progress.remove_task(main_task)
+    def loading_model(self) -> None:
+        assert self.progress is not None
+
+        self.loading_task = self.progress.add_task("🧠 Loading model...")
+
+    def finished_loading_model(self) -> None:
+        assert self.progress is not None
+        assert self.loading_task is not None
 
-        model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
-        weights = flatten_parameters(model.export_weights())
-        del model
+        self.progress.remove_task(self.loading_task)
 
-        save_file(weights, output_dir / "model.safetensors")
+    def tracing_model(self) -> None:
+        assert self.progress is not None
 
-        config_json = config_converter.unstructure(metadata, ModelMetadata)
-        with open(output_dir / "config.json", "w") as file:
-            json.dump(config_json, file, indent=4)
-        progress.remove_task(save_task)
+        self.tracing_task = self.progress.add_task("🔍 Recording trace...")
 
-    console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{output_dir}`[/cyan]!")
+    def finished_tracing_model(self) -> None:
+        assert self.progress is not None
+        assert self.tracing_task is not None
+
+        self.progress.remove_task(self.tracing_task)
+
+    def saving_trace(self) -> None:
+        assert self.progress is not None
+
+        self.saving_task = self.progress.add_task(f"💾 Saving trace to {self.output_path}")
+
+    def finished_saving_trace(self) -> None:
+        assert self.progress is not None
+        assert self.saving_task is not None
+
+        self.progress.remove_task(self.saving_task)
+        self.stack.close()
+        console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")
+
+@app.command(help="Trace a model.")
+def trace(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path | None,
+        Option(
+            help="Path to save the trace to.",
+            show_default="${MODEL_PATH}/traces.safetensors",
+        ),
+    ] = None,
+    overwrite: Annotated[
+        bool,
+        Option(
+            help="Overwrite existing trace file.",
+        ),
+    ] = False,
+    message: Annotated[
+        str | None,
+        Option(
+            help="Text message to use as prompt when recording trace",
+        ),
+    ] = None,
+) -> None:
+    if output_path is None:
+        output_path = model_path / "traces.safetensors"
+
+    messages = None if message is None else [UserMessage(content=message)]
+
+    _trace(
+        model_path,
+        output_path,
+        messages,
+        partial(CliTraceCallbacks, overwrite=overwrite),
+    )
 
 
 def _model_size_string_to_int(
@@ -385,6 +492,41 @@ speculator_app = Typer()
 app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")
 
 
+@dataclass
+class CliEstimateBatchsizeCallbacks(EstimateBatchsizeCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    loading_task: TaskID | None = None
+    estimating_task: TaskID | None = None
+
+    def loading_model(self) -> None:
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+        self.loading_task = self.progress.add_task("[cyan]Loading model...[/cyan]")
+
+    def finished_loading_model(self) -> None:
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def estimating_batchsize(self, lo: int, hi: int | None) -> None:
+        hi_str = str(hi) if hi is not None else "?"
+        description = f"[cyan]Estimating batch size... ({lo}..{hi_str})[/cyan]"
+        if self.estimating_task is None:
+            self.estimating_task = self.progress.add_task(description)
+        else:
+            self.progress.update(self.estimating_task, description=description)
+
+    def finished_estimating_batchsize(self, batchsize: int) -> None:
+        if self.estimating_task is not None:
+            self.progress.remove_task(self.estimating_task)
+        self.stack.close()
+        console.print(f"Found maximum batch size: [cyan]{batchsize}[/cyan]")
+
+
 @speculator_app.command(help="Estimate maximum batch size at which a model can be run.")
 def estimate_batchsize(
     model_path: Annotated[
@@ -416,44 +558,64 @@ def estimate_batchsize(
 ) -> None:
     if vram_gb is not None:
         mem = vram_gb * 1024 * 1024 * 1024
-    else:
-        memory_stats = jax.local_devices()[0].memory_stats()
-        if memory_stats is None:
-            err_console.print("Cannot get the default device's memory stats, use --vram-gb")
-            raise Exit(1)
-        if "bytes_limit" not in memory_stats:
-            err_console.print("Cannot get the default device's bytes limit, use --vram-gb")
-            raise Exit(1)
-        mem = memory_stats["bytes_limit"]
+    elif (mem := get_default_device_memory()) is None:
+        err_console.print("Cannot get the default device's memory stats, use --vram-gb")
+        raise Exit(1)
 
-    with Progress(
-        SpinnerColumn(),
-        TextColumn("[progress.description]{task.description}"),
-        transient=True,
-    ) as progress:
-        loading_model_task = progress.add_task("[cyan]Loading model...[/cyan]")
-        model = LanguageModelConfig.load_model(model_path)
-        progress.remove_task(loading_model_task)
-
-        estimating_batchsize_task = progress.add_task("[cyan]Estimating batch size...[/cyan]")
-
-        def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
-            lo = str(event.lo)
-            hi = str(event.hi) if event.hi is not None else "?"
-            description = f"[cyan]Estimating batch size... ({lo}..{hi})[/cyan]"
-            progress.update(estimating_batchsize_task, description=description)
-
-        bs = estimate_batchsize_from_memory(
-            model,
-            max_input_length,
-            max_output_length,
-            num_logits_per_token,
-            mem,
-            progress_callback,
+    callbacks_type = CliEstimateBatchsizeCallbacks
+
+    _estimate_batchsize(model_path, mem, max_input_length, max_output_length, num_logits_per_token, callbacks_type)
+
+
+@dataclass
+class CliCollectTracesCallbacks(CollectTracesCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    live: Live | None = None
+    loading_task: TaskID | None = None
+    inference_task: TaskID | None = None
+
+    def loading_model(self) -> None:
+        self.live = self.stack.enter_context(Live(refresh_per_second=10))
+        self.progress = Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            transient=True,
+        )
+        self.live.update(self.progress, refresh=True)
+        self.loading_task = self.progress.add_task("🧠 [cyan]Loading model...[/cyan]")
+
+    def finished_loading_model(self) -> None:
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def loading_dataset(self) -> None:
+        self.loading_task = self.progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
+
+    def finished_loading_dataset(self) -> None:
+        assert self.loading_task is not None
+        assert self.live is not None
+        self.progress.remove_task(self.loading_task)
+        self.progress = Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            MofNCompleteColumn(),
+            TimeElapsedColumn(),
+            TimeRemainingColumn(),
         )
-        progress.remove_task(estimating_batchsize_task)
+        self.live.update(self.progress, refresh=True)
+        self.inference_task = self.progress.add_task(
+            "🔮 [cyan]Running inference...[/cyan]",
+            total=self.num_tokens_to_generate,
+        )
+
+    def inference_progress(self, tokens_generated: int) -> None:
+        assert self.inference_task is not None
+        self.progress.update(self.inference_task, completed=tokens_generated)
 
-    console.print(f"Found maximum batch size: [cyan]{bs}[/cyan]")
+    def finished_inference(self) -> None:
+        assert self.inference_task is not None
+        self.progress.update(self.inference_task, description="✅ Completed")
+        self.stack.close()
 
 
 @speculator_app.command(help="Run model inference and collect traces for speculator training")
@@ -503,55 +665,17 @@ def collect_traces(
         ),
     ] = None,
 ) -> None:
-    with Live(refresh_per_second=10) as live:
-        with Progress(
-            SpinnerColumn(),
-            TextColumn("[progress.description]{task.description}"),
-            transient=True,
-            disable=True,
-        ) as progress:
-            live.update(progress, refresh=True)
-            loading_model_task = progress.add_task("🧠 [cyan]Loading model...[/cyan]")
-            model = LanguageModelConfig.load_model(model_path)
-            progress.remove_task(loading_model_task)
-
-            loading_dataset_task = progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
-            dataset = iter(import_hf_parquet(dataset_path))
-            dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
-            progress.remove_task(loading_dataset_task)
-
-        with Progress(
-            SpinnerColumn(),
-            TextColumn("[progress.description]{task.description}"),
-            MofNCompleteColumn(),
-            TimeElapsedColumn(),
-            TimeRemainingColumn(),
-            disable=True,
-        ) as progress:
-            live.update(progress, refresh=True)
-            inference_task = progress.add_task("🔮 [cyan]Running inference...[/cyan]", total=num_tokens_to_generate)
-
-            def progress_callback(event: CollectTracesEvent) -> None:
-                progress.update(inference_task, completed=event.tokens_generated)
-
-            traces = inference_collect_traces(
-                model,
-                dataset,
-                num_logits_per_token,
-                batch_size,
-                max_input_length,
-                max_output_length,
-                num_tokens_to_generate,
-                progress_callback,
-            )
-
-            output_path.parent.mkdir(parents=True, exist_ok=True)
-            with open(output_path, "wb+") as output_fd:
-                for trace in traces:
-                    blob = trace.serialize()
-                    output_fd.write(blob)
-
-            progress.update(inference_task, description="✅ Completed")
+    _collect_traces(
+        model_path,
+        dataset_path,
+        output_path,
+        num_logits_per_token,
+        max_input_length,
+        max_output_length,
+        batch_size,
+        num_tokens_to_generate,
+        CliCollectTracesCallbacks,
+    )
 
 
 @speculator_app.command(help="View model inference traces")
@@ -597,6 +721,43 @@ def view_traces(
     console.print(table)
 
 
+@dataclass
+class CliTrainCallbacks(TrainCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    training_task: TaskID | None = None
+
+    def started(self) -> None:
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                MofNCompleteColumn(),
+                TimeElapsedColumn(),
+                TimeRemainingColumn(),
+            ),
+        )
+        self.training_task = self.progress.add_task(
+            "🔮 [cyan]Training speculator...[/cyan]",
+            total=self.subsample_size,
+        )
+
+    def training_progress(self, trained_tokens: int) -> None:
+        assert self.training_task is not None
+        self.progress.update(self.training_task, completed=trained_tokens)
+
+    def finished_training(self) -> None:
+        assert self.training_task is not None
+        self.progress.update(self.training_task, description="✅ Completed")
+        self.progress.remove_task(self.training_task)
+        self.stack.close()
+
+    def saving_speculator(self) -> None:
+        pass
+
+    def finished_saving_speculator(self) -> None:
+        console.print(f"💾 Speculator saved to [cyan]{self.output_path}[/cyan]")
+
+
 @speculator_app.command(help="Train a speculator from inference traces")
 def train(
     trace_path: Annotated[
633
794
  ),
634
795
  ] = None,
635
796
  ) -> None:
636
- with open(trace_path, "rb") as trace_fd:
637
- traces = LalamoCompletion.deserialize_many(trace_fd)
638
-
639
- speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
640
-
641
- with Progress(
642
- SpinnerColumn(),
643
- TextColumn("[progress.description]{task.description}"),
644
- MofNCompleteColumn(),
645
- TimeElapsedColumn(),
646
- TimeRemainingColumn(),
647
- ) as progress:
648
- inference_task = progress.add_task("🔮 [cyan]Training speculator...[/cyan]", total=subsample_size)
649
-
650
- def progress_callback(event: SpeculatorTrainingEvent) -> None:
651
- progress.update(inference_task, completed=event.trained_tokens)
652
-
653
- train_speculator(speculator, traces, subsample_size, progress_callback)
654
-
655
- progress.update(inference_task, description="✅ Completed")
656
-
657
- output_path.parent.mkdir(parents=True, exist_ok=True)
658
- with open(output_path, "wb+") as fd:
659
- fd.write(speculator.serialize())
797
+ _train(
798
+ trace_path,
799
+ output_path,
800
+ hashtable_size,
801
+ num_logits_per_token,
802
+ ngram_size,
803
+ subsample_size,
804
+ CliTrainCallbacks,
805
+ )
660
806
 
661
807
 
662
808
  @speculator_app.command(help="Run speculator as an autoregressive llm")