lalamo 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- lalamo/__init__.py +26 -2
- lalamo/commands.py +429 -0
- lalamo/common.py +14 -1
- lalamo/main.py +375 -229
- lalamo/message_processor.py +4 -1
- lalamo/model_import/common.py +8 -17
- lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
- lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
- lalamo/model_import/huggingface_generation_config.py +21 -3
- lalamo/model_import/loaders/executorch.py +2 -2
- lalamo/model_import/loaders/huggingface.py +3 -3
- lalamo/model_import/model_specs/common.py +8 -4
- lalamo/model_import/model_specs/lfm2.py +41 -9
- lalamo/models/common.py +3 -3
- lalamo/models/language_model.py +7 -6
- lalamo/modules/activations.py +1 -1
- lalamo/modules/classifier.py +11 -24
- lalamo/modules/common.py +4 -1
- lalamo/modules/decoder.py +5 -11
- lalamo/modules/embedding.py +25 -62
- lalamo/modules/linear.py +19 -33
- lalamo/modules/mlp.py +9 -19
- lalamo/modules/mlx_interop.py +1 -1
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +1 -1
- lalamo/modules/token_mixers/attention.py +9 -27
- lalamo/modules/token_mixers/mamba.py +9 -24
- lalamo/modules/token_mixers/short_conv.py +5 -12
- lalamo/modules/transformer.py +10 -20
- lalamo/modules/transformer_layer.py +8 -20
- lalamo/registry_abc.py +4 -4
- lalamo/safetensors.py +97 -0
- lalamo/sampling.py +14 -0
- lalamo/speculator/estimator.py +11 -4
- lalamo/speculator/ngram.py +1 -1
- lalamo/utils.py +0 -13
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -2
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/RECORD +43 -41
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/main.py
CHANGED
@@ -1,20 +1,19 @@
-import json
 import random
 import re
 import shutil
 import sys
-from
-from
+from contextlib import ExitStack
+from dataclasses import dataclass, field
+from functools import partial
+from itertools import islice
 from pathlib import Path
 from typing import Annotated

-import jax
 import jax.profiler
 import thefuzz.process
 from click import Context as ClickContext
 from click import Parameter as ClickParameter
 from click import ParamType
-from jaxtyping import DTypeLike
 from rich import box
 from rich.console import Console
 from rich.live import Live
@@ -23,48 +22,42 @@ from rich.progress import (
     MofNCompleteColumn,
     Progress,
     SpinnerColumn,
+    TaskID,
     TextColumn,
     TimeElapsedColumn,
     TimeRemainingColumn,
 )
+from rich.prompt import Confirm
 from rich.table import Table
-from safetensors.flax import save_file
 from typer import Argument, Context, Exit, Option, Typer

-from lalamo.
-
+from lalamo.commands import (
+    CollectTracesCallbacks,
+    ConversionCallbacks,
+    EstimateBatchsizeCallbacks,
+    Precision,
+    TraceCallbacks,
+    TrainCallbacks,
+)
+from lalamo.commands import collect_traces as _collect_traces
+from lalamo.commands import convert as _convert
+from lalamo.commands import estimate_batchsize as _estimate_batchsize
+from lalamo.commands import trace as _trace
+from lalamo.commands import train as _train
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.message_processor import UserMessage
-from lalamo.model_import import REPO_TO_MODEL,
-from lalamo.model_import.common import
-    DownloadingFileEvent,
-    FinishedDownloadingFileEvent,
-    FinishedInitializingModelEvent,
-    InitializingModelEvent,
-    StatusEvent,
-)
+from lalamo.model_import import REPO_TO_MODEL, ModelSpec
+from lalamo.model_import.common import FileSpec
 from lalamo.models import ClassifierModelConfig, LanguageModelConfig
-from lalamo.
-from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
-from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+from lalamo.speculator.estimator import get_default_device_memory
 from lalamo.speculator.ngram import NGramSpeculator
-from lalamo.speculator.utils import
-    SpeculatorTrainingEvent,
-    test_speculator,
-    train_speculator,
-)
+from lalamo.speculator.utils import test_speculator

 SCRIPT_NAME = Path(sys.argv[0]).name

 DEFAULT_OUTPUT_DIR = Path("models")


-class Precision(Enum):
-    FLOAT32 = "float32"
-    FLOAT16 = "float16"
-    BFLOAT16 = "bfloat16"
-
-
 console = Console()
 err_console = Console(stderr=True)
 app = Typer(
@@ -92,7 +85,7 @@ class ModelParser(ParamType):
                 f"\n\nUse the `{SCRIPT_NAME} list-models` command to see the list of currently supported models.",
             )
             error_message = "".join(error_message_parts)
-            self.fail(error_message, param, ctx)
+            return self.fail(error_message, param, ctx)
         return result


@@ -120,10 +113,18 @@ def chat(
             metavar="MODEL_PATH",
         ),
     ],
+    message: Annotated[
+        str | None,
+        Option(
+            help="Message for non-interactive mode",
+            show_default="None, run interactively",
+        ),
+    ] = None,
 ) -> None:
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
+        console=err_console,
         transient=True,
     ) as progress:
         loading_task = progress.add_task("🚀 [cyan]Loading model...[/cyan]")
@@ -132,21 +133,28 @@ def chat(
         warmup_task = progress.add_task("🔥 Warming up compilation cache...")
         list(model.stream_reply_text([UserMessage("")], max_output_length=1))
         progress.remove_task(warmup_task)
-    console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
-    messages = []
-    while True:
-        user_text = console.input("[cyan]user> [/cyan]")
-        user_message = UserMessage(user_text)
-        messages.append(user_message)

-
-
-
+    if message is None:
+        console.print(f"🤖 Chatting with [blue]{model_path}[/blue]:")
+
+        messages = []
+        while True:
+            user_text = console.input("[cyan]user> [/cyan]")
+            user_message = UserMessage(user_text)
+            messages.append(user_message)
+
+            console.print("[red]assistant> [/red]", end="")
+            model_response_tokens = []
+            for token in model.stream_reply_text(messages):
+                console.print(token, end="")
+                model_response_tokens.append(token)
+            console.print()
+            model_response_text = "".join(model_response_tokens)
+            messages.append(model.message_processor.parse_response(model_response_text))
+    else:
+        for token in model.stream_reply_text([UserMessage(message)]):
             console.print(token, end="")
-            model_response_tokens.append(token)
         console.print()
-        model_response_text = "".join(model_response_tokens)
-        messages.append(model.message_processor.parse_response(model_response_text))


 @app.command(help="Classify given message with a Classifier type of model.")
@@ -182,6 +190,79 @@ def classify(
     console.print()


+@dataclass
+class CliConversionCallbacks(ConversionCallbacks):
+    overwrite: bool = False
+
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    downloading_tasks: dict[FileSpec, TaskID] = field(default_factory=dict)
+    initializing_task: TaskID | None = None
+    saving_task: TaskID | None = None
+
+    def started(self) -> None:
+        conversion_strs = [
+            f"🚀 Converting [cyan]{self.model_spec.name}[/cyan] by [cyan]{self.model_spec.vendor}[/cyan]",
+        ]
+        if self.precision is not None:
+            conversion_strs.append(
+                f" and converting floating-point weights into [cyan]{self.precision.name.lower()}[/cyan] precision",
+            )
+        conversion_strs.append(".")
+        console.print("".join(conversion_strs))
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+
+    def output_dir_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output directory [cyan]{self.output_dir}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
+            raise Exit
+
+        shutil.rmtree(self.output_dir)
+
+    def downloading(self, file_spec: FileSpec) -> None:
+        assert self.progress is not None
+
+        self.downloading_tasks[file_spec] = self.progress.add_task(f"Retrieving {file_spec.filename}...")
+
+    def finished_downloading(self, file_spec: FileSpec) -> None:
+        assert self.progress is not None
+
+        self.progress.remove_task(self.downloading_tasks[file_spec])
+
+    def initializing_model(self) -> None:
+        assert self.progress is not None
+
+        self.initializing_task = self.progress.add_task("Initializing model...")
+
+    def finished_initializing_model(self) -> None:
+        assert self.progress is not None
+        assert self.initializing_task is not None
+
+        self.progress.remove_task(self.initializing_task)
+
+    def saving_model(self) -> None:
+        assert self.progress is not None
+
+        self.saving_task = self.progress.add_task(f"💾 Saving the model to {self.output_dir}")
+
+    def finished_saving_model(self) -> None:
+        assert self.progress is not None
+        assert self.saving_task is not None
+
+        self.progress.remove_task(self.saving_task)
+        self.stack.close()
+        console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{self.output_dir}`[/cyan]!")
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
@@ -219,104 +300,130 @@ def convert(
             show_default="Model's native maximum context length.",
         ),
     ] = None,
-    include_traces: Annotated[
-        bool,
-        Option(
-            help="Export activation traces for debugging purposes.",
-        ),
-    ] = False,
     overwrite: Annotated[
         bool,
         Option(
             help="Overwrite existing model files.",
         ),
     ] = False,
-    message_for_trace: Annotated[
-        str | None,
-        Option(
-            help="Text message to use as prompt when recording trace",
-        ),
-    ] = None,
 ) -> None:
-    if precision is not None:
-        precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
-    else:
-        precision_dtype = None
-
     if output_dir is None:
         output_dir = DEFAULT_OUTPUT_DIR / model_repo.name

-
-
-
-
-
-
-
+    _convert(
+        model_repo,
+        output_dir,
+        precision,
+        context_length,
+        partial(CliConversionCallbacks, overwrite=overwrite),
+    )

-
-
-
-
-
-
-
-
-
-
-
+
+@dataclass
+class CliTraceCallbacks(TraceCallbacks):
+    overwrite: bool = False
+
+    stack: ExitStack = field(default_factory=ExitStack)
+    progress: Progress | None = None
+    loading_task: TaskID | None = None
+    tracing_task: TaskID | None = None
+    saving_task: TaskID | None = None
+
+    def output_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output [cyan]{self.output_path}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
             raise Exit

-
+        self.output_path.unlink()

-
-
-
-
-
-
-
-
-
-            case DownloadingFileEvent(file_spec):
-                event_to_task[event] = progress.add_task(f"Retrieving {file_spec.filename}...")
-            case FinishedDownloadingFileEvent(file_spec):
-                progress.remove_task(event_to_task[event])
-            case InitializingModelEvent():
-                event_to_task[event] = progress.add_task("Initializing model...")
-            case FinishedInitializingModelEvent():
-                progress.remove_task(event_to_task[event])
-
-        main_task = progress.add_task("👨‍🍳 Cooking...")
-        model, metadata = import_model(
-            model_repo,
-            precision=precision_dtype,
-            context_length=context_length,
-            progress_callback=progress_callback,
+    def started(self) -> None:
+        console.print(f"🔍 Tracing [cyan]{self.model_path}[/cyan]")
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
         )
-        save_task = progress.add_task(f"💾 Saving the model to {output_dir}")
-        output_dir.mkdir(parents=True, exist_ok=True)

-
-
-
-
-
-
-        progress
+    def loading_model(self) -> None:
+        assert self.progress is not None
+
+        self.loading_task = self.progress.add_task("🧠 Loading model...")
+
+    def finished_loading_model(self) -> None:
+        assert self.progress is not None
+        assert self.loading_task is not None

-
-        weights = flatten_parameters(model.export_weights())
-        del model
+        self.progress.remove_task(self.loading_task)

-
+    def tracing_model(self) -> None:
+        assert self.progress is not None

-
-        with open(output_dir / "config.json", "w") as file:
-            json.dump(config_json, file, indent=4)
-        progress.remove_task(save_task)
+        self.tracing_task = self.progress.add_task("🔍 Recording trace...")

-
+    def finished_tracing_model(self) -> None:
+        assert self.progress is not None
+        assert self.tracing_task is not None
+
+        self.progress.remove_task(self.tracing_task)
+
+    def saving_trace(self) -> None:
+        assert self.progress is not None
+
+        self.saving_task = self.progress.add_task(f"💾 Saving trace to {self.output_path}")
+
+    def finished_saving_trace(self) -> None:
+        assert self.progress is not None
+        assert self.saving_task is not None
+
+        self.progress.remove_task(self.saving_task)
+        self.stack.close()
+        console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")
+
+@app.command(help="Trace a model.")
+def trace(
+    model_path: Annotated[
+        Path,
+        Argument(
+            help="Path to the model directory.",
+            metavar="MODEL_PATH",
+        ),
+    ],
+    output_path: Annotated[
+        Path | None,
+        Option(
+            help="Path to save the trace to.",
+            show_default="${MODEL_PATH}/traces.safetensors",
+        ),
+    ] = None,
+    overwrite: Annotated[
+        bool,
+        Option(
+            help="Overwrite existing trace file.",
+        ),
+    ] = False,
+    message: Annotated[
+        str | None,
+        Option(
+            help="Text message to use as prompt when recording trace",
+        ),
+    ] = None,
+) -> None:
+    if output_path is None:
+        output_path = model_path / "traces.safetensors"
+
+    messages = None if message is None else [UserMessage(content=message)]
+
+    _trace(
+        model_path,
+        output_path,
+        messages,
+        partial(CliTraceCallbacks, overwrite=overwrite),
+    )


 def _model_size_string_to_int(
@@ -385,6 +492,41 @@ speculator_app = Typer()
 app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")


+@dataclass
+class CliEstimateBatchsizeCallbacks(EstimateBatchsizeCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    loading_task: TaskID | None = None
+    estimating_task: TaskID | None = None
+
+    def loading_model(self) -> None:
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+        self.loading_task = self.progress.add_task("[cyan]Loading model...[/cyan]")
+
+    def finished_loading_model(self) -> None:
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def estimating_batchsize(self, lo: int, hi: int | None) -> None:
+        hi_str = str(hi) if hi is not None else "?"
+        description = f"[cyan]Estimating batch size... ({lo}..{hi_str})[/cyan]"
+        if self.estimating_task is None:
+            self.estimating_task = self.progress.add_task(description)
+        else:
+            self.progress.update(self.estimating_task, description=description)
+
+    def finished_estimating_batchsize(self, batchsize: int) -> None:
+        if self.estimating_task is not None:
+            self.progress.remove_task(self.estimating_task)
+        self.stack.close()
+        console.print(f"Found maximum batch size: [cyan]{batchsize}[/cyan]")
+
+
 @speculator_app.command(help="Estimate maximum batch size at which a model can be run.")
 def estimate_batchsize(
     model_path: Annotated[
@@ -416,44 +558,64 @@ def estimate_batchsize(
 ) -> None:
     if vram_gb is not None:
         mem = vram_gb * 1024 * 1024 * 1024
-
-
-
-        err_console.print("Cannot get the default device's memory stats, use --vram-gb")
-        raise Exit(1)
-    if "bytes_limit" not in memory_stats:
-        err_console.print("Cannot get the default device's bytes limit, use --vram-gb")
-        raise Exit(1)
-    mem = memory_stats["bytes_limit"]
+    elif (mem := get_default_device_memory()) is None:
+        err_console.print("Cannot get the default device's memory stats, use --vram-gb")
+        raise Exit(1)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    callbacks_type = CliEstimateBatchsizeCallbacks
+
+    _estimate_batchsize(model_path, mem, max_input_length, max_output_length, num_logits_per_token, callbacks_type)
+
+
+@dataclass
+class CliCollectTracesCallbacks(CollectTracesCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    live: Live | None = None
+    loading_task: TaskID | None = None
+    inference_task: TaskID | None = None
+
+    def loading_model(self) -> None:
+        self.live = self.stack.enter_context(Live(refresh_per_second=10))
+        self.progress = Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            transient=True,
+        )
+        self.live.update(self.progress, refresh=True)
+        self.loading_task = self.progress.add_task("🧠 [cyan]Loading model...[/cyan]")
+
+    def finished_loading_model(self) -> None:
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def loading_dataset(self) -> None:
+        self.loading_task = self.progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
+
+    def finished_loading_dataset(self) -> None:
+        assert self.loading_task is not None
+        assert self.live is not None
+        self.progress.remove_task(self.loading_task)
+        self.progress = Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            MofNCompleteColumn(),
+            TimeElapsedColumn(),
+            TimeRemainingColumn(),
         )
-
+        self.live.update(self.progress, refresh=True)
+        self.inference_task = self.progress.add_task(
+            "🔮 [cyan]Running inference...[/cyan]",
+            total=self.num_tokens_to_generate,
+        )

-
+    def inference_progress(self, tokens_generated: int) -> None:
+        assert self.inference_task is not None
+        self.progress.update(self.inference_task, completed=tokens_generated)

-
+    def finished_inference(self) -> None:
+        assert self.inference_task is not None
+        self.progress.update(self.inference_task, description="✅ Completed")
+        self.stack.close()


 @speculator_app.command(help="Run model inference and collect traces for speculator training")
@@ -503,55 +665,17 @@ def collect_traces(
         ),
     ] = None,
 ) -> None:
-
-
-
-
-
-
-
-
-
-
-
-
-        loading_dataset_task = progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
-        dataset = iter(import_hf_parquet(dataset_path))
-        dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
-        progress.remove_task(loading_dataset_task)
-
-    with Progress(
-        SpinnerColumn(),
-        TextColumn("[progress.description]{task.description}"),
-        MofNCompleteColumn(),
-        TimeElapsedColumn(),
-        TimeRemainingColumn(),
-        disable=True,
-    ) as progress:
-        live.update(progress, refresh=True)
-        inference_task = progress.add_task("🔮 [cyan]Running inference...[/cyan]", total=num_tokens_to_generate)
-
-        def progress_callback(event: CollectTracesEvent) -> None:
-            progress.update(inference_task, completed=event.tokens_generated)
-
-        traces = inference_collect_traces(
-            model,
-            dataset,
-            num_logits_per_token,
-            batch_size,
-            max_input_length,
-            max_output_length,
-            num_tokens_to_generate,
-            progress_callback,
-        )
-
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(output_path, "wb+") as output_fd:
-            for trace in traces:
-                blob = trace.serialize()
-                output_fd.write(blob)
-
-        progress.update(inference_task, description="✅ Completed")
+    _collect_traces(
+        model_path,
+        dataset_path,
+        output_path,
+        num_logits_per_token,
+        max_input_length,
+        max_output_length,
+        batch_size,
+        num_tokens_to_generate,
+        CliCollectTracesCallbacks,
+    )


 @speculator_app.command(help="View model inference traces")
@@ -597,6 +721,43 @@ def view_traces(
     console.print(table)


+@dataclass
+class CliTrainCallbacks(TrainCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    training_task: TaskID | None = None
+
+    def started(self) -> None:
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                MofNCompleteColumn(),
+                TimeElapsedColumn(),
+                TimeRemainingColumn(),
+            ),
+        )
+        self.training_task = self.progress.add_task(
+            "🔮 [cyan]Training speculator...[/cyan]",
+            total=self.subsample_size,
+        )
+
+    def training_progress(self, trained_tokens: int) -> None:
+        assert self.training_task is not None
+        self.progress.update(self.training_task, completed=trained_tokens)
+
+    def finished_training(self) -> None:
+        assert self.training_task is not None
+        self.progress.update(self.training_task, description="✅ Completed")
+        self.progress.remove_task(self.training_task)
+        self.stack.close()
+
+    def saving_speculator(self) -> None:
+        pass
+
+    def finished_saving_speculator(self) -> None:
+        console.print(f"💾 Speculator saved to [cyan]{self.output_path}[/cyan]")
+
+
 @speculator_app.command(help="Train a speculator from inference traces")
 def train(
     trace_path: Annotated[
@@ -633,30 +794,15 @@ def train(
         ),
     ] = None,
 ) -> None:
-
-
-
-
-
-
-
-
-
-        TimeElapsedColumn(),
-        TimeRemainingColumn(),
-    ) as progress:
-        inference_task = progress.add_task("🔮 [cyan]Training speculator...[/cyan]", total=subsample_size)
-
-        def progress_callback(event: SpeculatorTrainingEvent) -> None:
-            progress.update(inference_task, completed=event.trained_tokens)
-
-        train_speculator(speculator, traces, subsample_size, progress_callback)
-
-        progress.update(inference_task, description="✅ Completed")
-
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(output_path, "wb+") as fd:
-            fd.write(speculator.serialize())
+    _train(
+        trace_path,
+        output_path,
+        hashtable_size,
+        num_logits_per_token,
+        ngram_size,
+        subsample_size,
+        CliTrainCallbacks,
+    )


 @speculator_app.command(help="Run speculator as an autoregressive llm")
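
The shape of this refactor is visible from the imports alone: in 0.5.x each command body in main.py pattern-matched on progress event objects (DownloadingFileEvent, CollectTracesEvent, SpeculatorTrainingEvent), while in 0.6.0 the command logic lives in the new lalamo/commands.py and main.py only supplies Rich-specific callback classes with one hook per lifecycle stage. Below is a minimal sketch of what driving the relocated convert command programmatically might look like. The hook names, the model_spec/precision/output_dir attributes, and the positional call shape are read off the diff above; HeadlessConversionCallbacks itself and the REPO_TO_MODEL key are illustrative assumptions, not part of the package.

import shutil
from pathlib import Path

from lalamo.commands import ConversionCallbacks, convert
from lalamo.model_import import REPO_TO_MODEL
from lalamo.model_import.common import FileSpec


class HeadlessConversionCallbacks(ConversionCallbacks):
    """Hypothetical plain-stdout stand-in for the Rich-based CliConversionCallbacks."""

    def started(self) -> None:
        print(f"converting {self.model_spec.name} by {self.model_spec.vendor}")

    def output_dir_exists(self) -> None:
        # Non-interactive policy: always overwrite instead of prompting like the CLI does.
        shutil.rmtree(self.output_dir)

    def downloading(self, file_spec: FileSpec) -> None:
        print(f"retrieving {file_spec.filename}")

    def finished_downloading(self, file_spec: FileSpec) -> None:
        print(f"finished {file_spec.filename}")

    def initializing_model(self) -> None:
        print("initializing model")

    def finished_initializing_model(self) -> None:
        pass

    def saving_model(self) -> None:
        print(f"saving model to {self.output_dir}")

    def finished_saving_model(self) -> None:
        print("done")


# Assumed registry key; REPO_TO_MODEL maps repo ids to ModelSpec entries.
model_spec = REPO_TO_MODEL["LiquidAI/LFM2-350M"]
convert(
    model_spec,
    Path("models") / "LFM2-350M",
    None,  # precision: keep the native floating-point precision
    None,  # context_length: keep the model's native maximum
    HeadlessConversionCallbacks,  # callbacks factory, mirroring the CLI wiring above
)

Compared with the 0.5.x event stream, a frontend now swaps in one callbacks class per command instead of pattern-matching on event types, which is what lets the progress-handling code in main.py collapse into the Cli*Callbacks dataclasses shown in the diff.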