lalamo 0.5.16__py3-none-any.whl → 0.5.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +26 -2
- lalamo/commands.py +377 -0
- lalamo/main.py +239 -214
- lalamo/model_import/model_specs/common.py +4 -2
- lalamo/models/common.py +3 -3
- lalamo/safetensors.py +97 -0
- lalamo/speculator/estimator.py +9 -2
- lalamo/utils.py +0 -13
- {lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/METADATA +1 -2
- {lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/RECORD +14 -12
- {lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/WHEEL +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
@@ -1,4 +1,14 @@
-from lalamo.
+from lalamo.commands import (
+    CollectTracesCallbacks,
+    ConversionCallbacks,
+    EstimateBatchsizeCallbacks,
+    Precision,
+    TrainCallbacks,
+    collect_traces,
+    convert,
+    estimate_batchsize,
+    train,
+)
 from lalamo.message_processor import (
     AssistantMessage,
     ContentBlock,
@@ -9,27 +19,41 @@ from lalamo.message_processor import (
     UserMessage,
 )
 from lalamo.model_import import ModelSpec, import_model
+from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec, ModelType, UseCase, WeightsType
 from lalamo.models import ClassifierModel, LanguageModel
+from lalamo.quantization import QuantizationMode
 from lalamo.speculator import (
     CollectTracesEvent,
     SpeculatorTrainingEvent,
 )

-__version__ = "0.5.
+__version__ = "0.5.17"

 __all__ = [
     "AssistantMessage",
     "ClassifierModel",
+    "CollectTracesCallbacks",
     "CollectTracesEvent",
+    "ConfigMap",
     "ContentBlock",
+    "ConversionCallbacks",
+    "EstimateBatchsizeCallbacks",
+    "FileSpec",
     "Image",
+    "JSONFieldSpec",
     "LanguageModel",
     "Message",
     "ModelSpec",
+    "ModelType",
+    "Precision",
+    "QuantizationMode",
     "SpeculatorTrainingEvent",
     "SystemMessage",
     "ToolSchema",
+    "TrainCallbacks",
+    "UseCase",
     "UserMessage",
+    "WeightsType",
     "collect_traces",
     "convert",
     "estimate_batchsize",
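The refactor moves the CLI's core operations into importable, callback-driven functions and re-exports them at the top level. A minimal sketch of driving a conversion programmatically — assuming `REPO_TO_MODEL`, used by `lalamo/main.py`, maps repository ids to `ModelSpec` objects, with `"org/model"` as a hypothetical key:

```python
# Sketch only: "org/model" is a placeholder; use a key that actually
# exists in REPO_TO_MODEL.
from pathlib import Path

from lalamo import Precision, convert
from lalamo.model_import import REPO_TO_MODEL

spec = REPO_TO_MODEL["org/model"]
convert(spec, Path("models") / spec.name, precision=Precision.FLOAT16)
```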
lalamo/commands.py
ADDED
@@ -0,0 +1,377 @@
+import json
+from collections.abc import Callable
+from dataclasses import dataclass
+from enum import Enum
+from itertools import chain
+from pathlib import Path
+
+from jaxtyping import DTypeLike
+
+from lalamo.common import flatten_parameters
+from lalamo.data import import_hf_parquet
+from lalamo.data.lalamo_completions import LalamoCompletion
+from lalamo.message_processor import UserMessage
+from lalamo.model_import import ModelMetadata, ModelSpec, import_model
+from lalamo.model_import.common import (
+    DownloadingFileEvent,
+    FileSpec,
+    FinishedDownloadingFileEvent,
+    FinishedInitializingModelEvent,
+    InitializingModelEvent,
+    StatusEvent,
+)
+from lalamo.models import LanguageModelConfig
+from lalamo.modules import config_converter
+from lalamo.safetensors import safe_write
+from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
+from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+from lalamo.speculator.ngram import NGramSpeculator
+from lalamo.speculator.utils import SpeculatorTrainingEvent, train_speculator
+
+
+class Precision(Enum):
+    FLOAT32 = "float32"
+    FLOAT16 = "float16"
+    BFLOAT16 = "bfloat16"
+
+
+@dataclass
+class ConversionCallbacks:
+    model_spec: ModelSpec
+    output_dir: Path
+    precision: Precision | None
+    context_length: int | None
+    include_traces: bool
+    message_for_trace: str | None
+
+    def started(self) -> None:
+        pass
+
+    def output_dir_exists(self) -> None:
+        raise RuntimeError(f"{self.output_dir=} already exists, refusing to overwrite!")
+
+    def downloading(self, file_spec: FileSpec) -> None:
+        pass
+
+    def finished_downloading(self, file_spec: FileSpec) -> None:
+        pass
+
+    def initializing_model(self) -> None:
+        pass
+
+    def finished_initializing_model(self) -> None:
+        pass
+
+    def saving_model(self) -> None:
+        pass
+
+    def finished_saving_model(self) -> None:
+        pass
+
+
+def convert(
+    model_spec: ModelSpec,
+    output_dir: Path,
+    precision: Precision | None = None,
+    context_length: int | None = None,
+    include_traces: bool = False,
+    message_for_trace: str | None = None,
+    callbacks_type: Callable[
+        [
+            ModelSpec,
+            Path,
+            Precision | None,
+            int | None,
+            bool,
+            str | None,
+        ],
+        ConversionCallbacks,
+    ] = ConversionCallbacks,
+) -> None:
+    callbacks = callbacks_type(
+        model_spec,
+        output_dir,
+        precision,
+        context_length,
+        include_traces,
+        message_for_trace,
+    )
+
+    if precision is not None:
+        precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
+    else:
+        precision_dtype = None
+
+    if output_dir.exists():
+        callbacks.output_dir_exists()
+
+    callbacks.started()
+
+    def progress_callback(event: StatusEvent) -> None:
+        match event:
+            case DownloadingFileEvent(file_spec):
+                callbacks.downloading(file_spec)
+            case FinishedDownloadingFileEvent(file_spec):
+                callbacks.finished_downloading(file_spec)
+            case InitializingModelEvent():
+                callbacks.initializing_model()
+            case FinishedInitializingModelEvent():
+                callbacks.finished_initializing_model()
+
+    model, metadata = import_model(
+        model_spec,
+        precision=precision_dtype,
+        context_length=context_length,
+        progress_callback=progress_callback,
+    )
+    callbacks.saving_model()
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    if include_traces:
+        message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
+        result = model.record_trace(message)
+        traces = flatten_parameters(result.export())
+        with Path(output_dir / "traces.safetensors").open("wb") as fd:
+            safe_write(fd, traces)
+
+    model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
+    weights = flatten_parameters(model.export_weights())
+    del model
+
+    with Path(output_dir / "model.safetensors").open("wb") as fd:
+        safe_write(fd, weights)
+
+    config_json = config_converter.unstructure(metadata, ModelMetadata)
+    with open(output_dir / "config.json", "w") as file:
+        json.dump(config_json, file, indent=4)
+
+    callbacks.finished_saving_model()
+
+
+@dataclass
+class EstimateBatchsizeCallbacks:
+    model_path: Path
+    max_input_length: int
+    max_output_length: int
+    num_logits_per_token: int
+    mem: int
+
+    def loading_model(self) -> None:
+        pass
+
+    def finished_loading_model(self) -> None:
+        pass
+
+    def estimating_batchsize(self, lo: int, hi: int | None) -> None:
+        pass
+
+    def finished_estimating_batchsize(self, batchsize: int) -> None:
+        pass
+
+
+def estimate_batchsize(
+    model_path: Path,
+    mem: int,
+    max_input_length: int = 1024,
+    max_output_length: int = 1024,
+    num_logits_per_token: int = 8,
+    callbacks_type: Callable[
+        [
+            Path,
+            int,
+            int,
+            int,
+            int,
+        ],
+        EstimateBatchsizeCallbacks,
+    ] = EstimateBatchsizeCallbacks,
+) -> int:
+    callbacks = callbacks_type(model_path, max_input_length, max_output_length, num_logits_per_token, mem)
+
+    callbacks.loading_model()
+    model = LanguageModelConfig.load_model(model_path)
+    callbacks.finished_loading_model()
+
+    def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
+        callbacks.estimating_batchsize(event.lo, event.hi)
+
+    bs = estimate_batchsize_from_memory(
+        model,
+        max_input_length,
+        max_output_length,
+        num_logits_per_token,
+        mem,
+        progress_callback,
+    )
+
+    callbacks.finished_estimating_batchsize(bs)
+    return bs
+
+
+@dataclass
+class CollectTracesCallbacks:
+    model_path: Path
+    dataset_path: Path
+    output_path: Path
+    num_logits_per_token: int
+    max_input_length: int
+    max_output_length: int
+    batch_size: int
+    num_tokens_to_generate: int | None
+
+    def loading_model(self) -> None:
+        pass
+
+    def finished_loading_model(self) -> None:
+        pass
+
+    def loading_dataset(self) -> None:
+        pass
+
+    def finished_loading_dataset(self) -> None:
+        pass
+
+    def inference_progress(self, tokens_generated: int) -> None:
+        pass
+
+    def finished_inference(self) -> None:
+        pass
+
+
+def collect_traces(
+    model_path: Path,
+    dataset_path: Path,
+    output_path: Path,
+    num_logits_per_token: int = 8,
+    max_input_length: int = 1024,
+    max_output_length: int = 1024,
+    batch_size: int = 1,
+    num_tokens_to_generate: int | None = None,
+    callbacks_type: Callable[
+        [
+            Path,
+            Path,
+            Path,
+            int,
+            int,
+            int,
+            int,
+            int | None,
+        ],
+        CollectTracesCallbacks,
+    ] = CollectTracesCallbacks,
+) -> None:
+    callbacks = callbacks_type(
+        model_path,
+        dataset_path,
+        output_path,
+        num_logits_per_token,
+        max_input_length,
+        max_output_length,
+        batch_size,
+        num_tokens_to_generate,
+    )
+
+    callbacks.loading_model()
+    model = LanguageModelConfig.load_model(model_path)
+    callbacks.finished_loading_model()
+
+    callbacks.loading_dataset()
+    dataset = iter(import_hf_parquet(dataset_path))
+    dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
+    callbacks.finished_loading_dataset()
+
+    def progress_callback(event: CollectTracesEvent) -> None:
+        callbacks.inference_progress(event.tokens_generated)
+
+    traces = inference_collect_traces(
+        model,
+        dataset,
+        num_logits_per_token,
+        batch_size,
+        max_input_length,
+        max_output_length,
+        num_tokens_to_generate,
+        progress_callback,
+    )
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "wb") as output_fd:
+        for trace in traces:
+            blob = trace.serialize()
+            output_fd.write(blob)
+
+    callbacks.finished_inference()
+
+
+@dataclass
+class TrainCallbacks:
+    trace_path: Path
+    output_path: Path
+    hashtable_size: int
+    num_logits_per_token: int
+    ngram_size: int
+    subsample_size: int | None
+
+    def started(self) -> None:
+        pass
+
+    def training_progress(self, trained_tokens: int) -> None:
+        pass
+
+    def finished_training(self) -> None:
+        pass
+
+    def saving_speculator(self) -> None:
+        pass
+
+    def finished_saving_speculator(self) -> None:
+        pass
+
+
+def train(
+    trace_path: Path,
+    output_path: Path,
+    hashtable_size: int = 65536,
+    num_logits_per_token: int = 8,
+    ngram_size: int = 2,
+    subsample_size: int | None = None,
+    callbacks_type: Callable[
+        [
+            Path,
+            Path,
+            int,
+            int,
+            int,
+            int | None,
+        ],
+        TrainCallbacks,
+    ] = TrainCallbacks,
+) -> None:
+    callbacks = callbacks_type(
+        trace_path,
+        output_path,
+        hashtable_size,
+        num_logits_per_token,
+        ngram_size,
+        subsample_size,
+    )
+
+    callbacks.started()
+
+    with open(trace_path, "rb") as trace_fd:
+        traces = LalamoCompletion.deserialize_many(trace_fd)
+    speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
+
+    def progress_callback(event: SpeculatorTrainingEvent) -> None:
+        callbacks.training_progress(event.trained_tokens)
+
+    train_speculator(speculator, traces, subsample_size, progress_callback)
+
+    callbacks.finished_training()
+
+    callbacks.saving_speculator()
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "wb") as fd:
+        fd.write(speculator.serialize())
+    callbacks.finished_saving_speculator()
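Each command in `commands.py` takes a `callbacks_type` whose default is a no-op dataclass, so library users can observe progress without depending on Rich. A sketch of the extension point (`LoggingConversionCallbacks` is illustrative, not part of the package):

```python
import logging

from lalamo.commands import ConversionCallbacks, convert
from lalamo.model_import.common import FileSpec

logger = logging.getLogger("conversion")


class LoggingConversionCallbacks(ConversionCallbacks):
    def started(self) -> None:
        logger.info("converting %s", self.model_spec.name)

    def downloading(self, file_spec: FileSpec) -> None:
        logger.info("downloading %s", file_spec.filename)

    def finished_saving_model(self) -> None:
        logger.info("saved to %s", self.output_dir)


# convert() constructs the callbacks itself, so the hooks receive exactly
# the arguments the command was invoked with, e.g.:
# convert(spec, output_dir, callbacks_type=LoggingConversionCallbacks)
```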
lalamo/main.py
CHANGED
@@ -1,20 +1,19 @@
-import json
 import random
 import re
 import shutil
 import sys
-from
-from
+from contextlib import ExitStack
+from dataclasses import dataclass, field
+from functools import partial
+from itertools import islice
 from pathlib import Path
 from typing import Annotated

-import jax
 import jax.profiler
 import thefuzz.process
 from click import Context as ClickContext
 from click import Parameter as ClickParameter
 from click import ParamType
-from jaxtyping import DTypeLike
 from rich import box
 from rich.console import Console
 from rich.live import Live
@@ -23,48 +22,40 @@ from rich.progress import (
     MofNCompleteColumn,
     Progress,
     SpinnerColumn,
+    TaskID,
     TextColumn,
     TimeElapsedColumn,
     TimeRemainingColumn,
 )
+from rich.prompt import Confirm
 from rich.table import Table
-from safetensors.flax import save_file
 from typer import Argument, Context, Exit, Option, Typer

-from lalamo.
-
+from lalamo.commands import (
+    CollectTracesCallbacks,
+    ConversionCallbacks,
+    EstimateBatchsizeCallbacks,
+    Precision,
+    TrainCallbacks,
+)
+from lalamo.commands import collect_traces as _collect_traces
+from lalamo.commands import convert as _convert
+from lalamo.commands import estimate_batchsize as _estimate_batchsize
+from lalamo.commands import train as _train
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.message_processor import UserMessage
-from lalamo.model_import import REPO_TO_MODEL,
-from lalamo.model_import.common import (
-    DownloadingFileEvent,
-    FinishedDownloadingFileEvent,
-    FinishedInitializingModelEvent,
-    InitializingModelEvent,
-    StatusEvent,
-)
+from lalamo.model_import import REPO_TO_MODEL, ModelSpec
+from lalamo.model_import.common import FileSpec
 from lalamo.models import ClassifierModelConfig, LanguageModelConfig
-from lalamo.
-from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
-from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+from lalamo.speculator.estimator import get_default_device_memory
 from lalamo.speculator.ngram import NGramSpeculator
-from lalamo.speculator.utils import (
-    SpeculatorTrainingEvent,
-    test_speculator,
-    train_speculator,
-)
+from lalamo.speculator.utils import test_speculator

 SCRIPT_NAME = Path(sys.argv[0]).name

 DEFAULT_OUTPUT_DIR = Path("models")


-class Precision(Enum):
-    FLOAT32 = "float32"
-    FLOAT16 = "float16"
-    BFLOAT16 = "bfloat16"
-
-
 console = Console()
 err_console = Console(stderr=True)
 app = Typer(
@@ -182,6 +173,68 @@ def classify(
     console.print()


+@dataclass
+class CliConversionCallbacks(ConversionCallbacks):
+    overwrite: bool = False
+
+    stack: ExitStack = field(default_factory=ExitStack)
+    downloading_tasks: dict[FileSpec, TaskID] = field(default_factory=dict)
+    initializing_task: TaskID | None = None
+    saving_task: TaskID | None = None
+
+    def started(self) -> None:
+        conversion_strs = [
+            f"🚀 Converting [cyan]{self.model_spec.name}[/cyan] by [cyan]{self.model_spec.vendor}[/cyan]",
+        ]
+        if self.precision is not None:
+            conversion_strs.append(
+                f" and converting floating-point weights into [cyan]{self.precision.name.lower()}[/cyan] precision",
+            )
+        conversion_strs.append(".")
+        console.print("".join(conversion_strs))
+
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+
+    def output_dir_exists(self) -> None:
+        if not self.overwrite and not Confirm().ask(
+            rf"⚠️ Output directory [cyan]{self.output_dir}[/cyan] already exists."
+            r" Do you want to overwrite it?",
+        ):
+            raise Exit
+
+        shutil.rmtree(self.output_dir)
+
+    def downloading(self, file_spec: FileSpec) -> None:
+        self.downloading_tasks[file_spec] = self.progress.add_task(f"Retrieving {file_spec.filename}...")
+
+    def finished_downloading(self, file_spec: FileSpec) -> None:
+        self.progress.remove_task(self.downloading_tasks[file_spec])
+
+    def initializing_model(self) -> None:
+        self.initializing_task = self.progress.add_task("Initializing model...")
+
+    def finished_initializing_model(self) -> None:
+        assert self.initializing_task is not None
+
+        self.progress.remove_task(self.initializing_task)
+
+    def saving_model(self) -> None:
+        self.saving_task = self.progress.add_task(f"💾 Saving the model to {self.output_dir}")
+
+    def finished_saving_model(self) -> None:
+        assert self.saving_task is not None
+
+        self.progress.remove_task(self.saving_task)
+        self.stack.close()
+        console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{self.output_dir}`[/cyan]!")
+
+
 @app.command(help="Convert the model for use with the Uzu inference engine.")
 def convert(
     model_repo: Annotated[
@@ -238,85 +291,18 @@ def convert(
         ),
     ] = None,
 ) -> None:
-    if precision is not None:
-        precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
-    else:
-        precision_dtype = None
-
     if output_dir is None:
         output_dir = DEFAULT_OUTPUT_DIR / model_repo.name

-
-
-
-
-
-
-
-
-
-    answer = console.input(
-        rf"⚠️ Output directory [cyan]{output_dir}[/cyan] already exists."
-        r" Do you want to overwrite it? [cyan]\[y/n][/cyan]: ",
-    )
-    while answer.lower() not in ["y", "n", "yes", "no"]:
-        answer = console.input("Please enter 'y' or 'n': ")
-    if answer.lower() in ["y", "yes"]:
-        shutil.rmtree(output_dir)
-    else:
-        console.print("Exiting...")
-        raise Exit
-
-    message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
-
-    with Progress(
-        SpinnerColumn(),
-        TextColumn("[progress.description]{task.description}"),
-        transient=True,
-    ) as progress:
-        event_to_task = {}
-
-        def progress_callback(event: StatusEvent) -> None:
-            match event:
-                case DownloadingFileEvent(file_spec):
-                    event_to_task[event] = progress.add_task(f"Retrieving {file_spec.filename}...")
-                case FinishedDownloadingFileEvent(file_spec):
-                    progress.remove_task(event_to_task[event])
-                case InitializingModelEvent():
-                    event_to_task[event] = progress.add_task("Initializing model...")
-                case FinishedInitializingModelEvent():
-                    progress.remove_task(event_to_task[event])
-
-        main_task = progress.add_task("👨‍🍳 Cooking...")
-        model, metadata = import_model(
-            model_repo,
-            precision=precision_dtype,
-            context_length=context_length,
-            progress_callback=progress_callback,
-        )
-        save_task = progress.add_task(f"💾 Saving the model to {output_dir}")
-        output_dir.mkdir(parents=True, exist_ok=True)
-
-        if include_traces:
-            trace_task = progress.add_task("🚁 Generating traces...")
-            result = model.record_trace(message)
-            traces = flatten_parameters(result.export())
-            save_file(traces, output_dir / "traces.safetensors")
-            progress.remove_task(trace_task)
-        progress.remove_task(main_task)
-
-        model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
-        weights = flatten_parameters(model.export_weights())
-        del model
-
-        save_file(weights, output_dir / "model.safetensors")
-
-        config_json = config_converter.unstructure(metadata, ModelMetadata)
-        with open(output_dir / "config.json", "w") as file:
-            json.dump(config_json, file, indent=4)
-        progress.remove_task(save_task)
-
-        console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{output_dir}`[/cyan]!")
+    _convert(
+        model_repo,
+        output_dir,
+        precision,
+        context_length,
+        include_traces,
+        message_for_trace,
+        partial(CliConversionCallbacks, overwrite=overwrite),
+    )


 def _model_size_string_to_int(
@@ -385,6 +371,41 @@ speculator_app = Typer()
 app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")


+@dataclass
+class CliEstimateBatchsizeCallbacks(EstimateBatchsizeCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    loading_task: TaskID | None = None
+    estimating_task: TaskID | None = None
+
+    def loading_model(self) -> None:
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                transient=True,
+            ),
+        )
+        self.loading_task = self.progress.add_task("[cyan]Loading model...[/cyan]")
+
+    def finished_loading_model(self) -> None:
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def estimating_batchsize(self, lo: int, hi: int | None) -> None:
+        hi_str = str(hi) if hi is not None else "?"
+        description = f"[cyan]Estimating batch size... ({lo}..{hi_str})[/cyan]"
+        if self.estimating_task is None:
+            self.estimating_task = self.progress.add_task(description)
+        else:
+            self.progress.update(self.estimating_task, description=description)
+
+    def finished_estimating_batchsize(self, batchsize: int) -> None:
+        if self.estimating_task is not None:
+            self.progress.remove_task(self.estimating_task)
+        self.stack.close()
+        console.print(f"Found maximum batch size: [cyan]{batchsize}[/cyan]")
+
+
 @speculator_app.command(help="Estimate maximum batch size at which a model can be run.")
 def estimate_batchsize(
     model_path: Annotated[
@@ -416,44 +437,64 @@ def estimate_batchsize(
 ) -> None:
     if vram_gb is not None:
         mem = vram_gb * 1024 * 1024 * 1024
-
-
-
-        err_console.print("Cannot get the default device's memory stats, use --vram-gb")
-        raise Exit(1)
-    if "bytes_limit" not in memory_stats:
-        err_console.print("Cannot get the default device's bytes limit, use --vram-gb")
-        raise Exit(1)
-    mem = memory_stats["bytes_limit"]
+    elif (mem := get_default_device_memory()) is None:
+        err_console.print("Cannot get the default device's memory stats, use --vram-gb")
+        raise Exit(1)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    callbacks_type = CliEstimateBatchsizeCallbacks
+
+    _estimate_batchsize(model_path, mem, max_input_length, max_output_length, num_logits_per_token, callbacks_type)
+
+
+@dataclass
+class CliCollectTracesCallbacks(CollectTracesCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    live: Live | None = None
+    loading_task: TaskID | None = None
+    inference_task: TaskID | None = None
+
+    def loading_model(self) -> None:
+        self.live = self.stack.enter_context(Live(refresh_per_second=10))
+        self.progress = Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            transient=True,
+        )
+        self.live.update(self.progress, refresh=True)
+        self.loading_task = self.progress.add_task("🧠 [cyan]Loading model...[/cyan]")
+
+    def finished_loading_model(self) -> None:
+        assert self.loading_task is not None
+        self.progress.remove_task(self.loading_task)
+
+    def loading_dataset(self) -> None:
+        self.loading_task = self.progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
+
+    def finished_loading_dataset(self) -> None:
+        assert self.loading_task is not None
+        assert self.live is not None
+        self.progress.remove_task(self.loading_task)
+        self.progress = Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            MofNCompleteColumn(),
+            TimeElapsedColumn(),
+            TimeRemainingColumn(),
         )
-
+        self.live.update(self.progress, refresh=True)
+        self.inference_task = self.progress.add_task(
+            "🔮 [cyan]Running inference...[/cyan]",
+            total=self.num_tokens_to_generate,
+        )
+
+    def inference_progress(self, tokens_generated: int) -> None:
+        assert self.inference_task is not None
+        self.progress.update(self.inference_task, completed=tokens_generated)

-
+    def finished_inference(self) -> None:
+        assert self.inference_task is not None
+        self.progress.update(self.inference_task, description="✅ Completed")
+        self.stack.close()


 @speculator_app.command(help="Run model inference and collect traces for speculator training")
@@ -503,55 +544,17 @@ def collect_traces(
         ),
     ] = None,
 ) -> None:
-
-
-
-
-
-
-
-
-
-
-
-
-        loading_dataset_task = progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
-        dataset = iter(import_hf_parquet(dataset_path))
-        dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
-        progress.remove_task(loading_dataset_task)
-
-    with Progress(
-        SpinnerColumn(),
-        TextColumn("[progress.description]{task.description}"),
-        MofNCompleteColumn(),
-        TimeElapsedColumn(),
-        TimeRemainingColumn(),
-        disable=True,
-    ) as progress:
-        live.update(progress, refresh=True)
-        inference_task = progress.add_task("🔮 [cyan]Running inference...[/cyan]", total=num_tokens_to_generate)
-
-        def progress_callback(event: CollectTracesEvent) -> None:
-            progress.update(inference_task, completed=event.tokens_generated)
-
-        traces = inference_collect_traces(
-            model,
-            dataset,
-            num_logits_per_token,
-            batch_size,
-            max_input_length,
-            max_output_length,
-            num_tokens_to_generate,
-            progress_callback,
-        )
-
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(output_path, "wb+") as output_fd:
-            for trace in traces:
-                blob = trace.serialize()
-                output_fd.write(blob)
-
-        progress.update(inference_task, description="✅ Completed")
+    _collect_traces(
+        model_path,
+        dataset_path,
+        output_path,
+        num_logits_per_token,
+        max_input_length,
+        max_output_length,
+        batch_size,
+        num_tokens_to_generate,
+        CliCollectTracesCallbacks,
+    )


 @speculator_app.command(help="View model inference traces")
@@ -597,6 +600,43 @@ def view_traces(
     console.print(table)


+@dataclass
+class CliTrainCallbacks(TrainCallbacks):
+    stack: ExitStack = field(default_factory=ExitStack)
+    training_task: TaskID | None = None
+
+    def started(self) -> None:
+        self.progress = self.stack.enter_context(
+            Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                MofNCompleteColumn(),
+                TimeElapsedColumn(),
+                TimeRemainingColumn(),
+            ),
+        )
+        self.training_task = self.progress.add_task(
+            "🔮 [cyan]Training speculator...[/cyan]",
+            total=self.subsample_size,
+        )
+
+    def training_progress(self, trained_tokens: int) -> None:
+        assert self.training_task is not None
+        self.progress.update(self.training_task, completed=trained_tokens)
+
+    def finished_training(self) -> None:
+        assert self.training_task is not None
+        self.progress.update(self.training_task, description="✅ Completed")
+        self.progress.remove_task(self.training_task)
+        self.stack.close()
+
+    def saving_speculator(self) -> None:
+        pass
+
+    def finished_saving_speculator(self) -> None:
+        console.print(f"💾 Speculator saved to [cyan]{self.output_path}[/cyan]")
+
+
 @speculator_app.command(help="Train a speculator from inference traces")
 def train(
     trace_path: Annotated[
@@ -633,30 +673,15 @@ def train(
         ),
     ] = None,
 ) -> None:
-
-
-
-
-
-
-
-
-
-        TimeElapsedColumn(),
-        TimeRemainingColumn(),
-    ) as progress:
-        inference_task = progress.add_task("🔮 [cyan]Training speculator...[/cyan]", total=subsample_size)
-
-        def progress_callback(event: SpeculatorTrainingEvent) -> None:
-            progress.update(inference_task, completed=event.trained_tokens)
-
-        train_speculator(speculator, traces, subsample_size, progress_callback)
-
-        progress.update(inference_task, description="✅ Completed")
-
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(output_path, "wb+") as fd:
-            fd.write(speculator.serialize())
+    _train(
+        trace_path,
+        output_path,
+        hashtable_size,
+        num_logits_per_token,
+        ngram_size,
+        subsample_size,
+        CliTrainCallbacks,
+    )


 @speculator_app.command(help="Run speculator as an autoregressive llm")
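Note how CLI-only options reach the callbacks: `convert` passes `partial(CliConversionCallbacks, overwrite=overwrite)` instead of widening the shared command signature. An illustrative sketch of the same pattern:

```python
# Illustrative only: keyword arguments are pre-bound with functools.partial,
# and commands.convert() supplies the positional dataclass fields.
from functools import partial

callbacks_type = partial(CliConversionCallbacks, overwrite=True)
# commands.convert() then calls:
# callbacks_type(model_spec, output_dir, precision, context_length,
#                include_traces, message_for_trace)
```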
lalamo/model_import/model_specs/common.py
CHANGED

@@ -15,7 +15,8 @@ from jaxtyping import Array, DTypeLike

 from lalamo.model_import.decoder_configs import ForeignConfig
 from lalamo.quantization import QuantizationMode
-from lalamo.
+from lalamo.safetensors import safe_read
+from lalamo.utils import MapDictValues

 __all__ = [
     "ConfigMap",
@@ -52,7 +53,8 @@ class WeightsType(Enum):
         float_dtype: DTypeLike,
     ) -> Iterator[tuple[Mapping[str, jnp.ndarray], Mapping[str, str]]]:
         if self == WeightsType.SAFETENSORS:
-            with
+            with Path(filename).open("rb") as fd:
+                (metadata_dict, weights_dict) = safe_read(fd)
             yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
lalamo/models/common.py
CHANGED
@@ -15,7 +15,7 @@ from lalamo.message_processor import Message, MessageProcessor, MessageProcessor
 from lalamo.modules import Classifier, Decoder, LalamoModule, config_converter
 from lalamo.modules.classifier import ClassifierConfig, ClassifierResult
 from lalamo.modules.decoder import DecoderConfig, DecoderResult
-from lalamo.
+from lalamo.safetensors import safe_read

 __all__ = [
     "TextModel",
@@ -42,8 +42,8 @@ class TextModelConfig[ConfigT: ClassifierConfig | DecoderConfig](ABC):
         with open(path / "config.json") as config_file:
             config_json = json.load(config_file)
         config = config_converter.structure(config_json["model_config"], cls)
-        with
-
+        with Path(path / "model.safetensors").open("rb") as fd:
+            _, weights_dict = safe_read(fd)
         weights = unflatten_parameters(weights_dict)
         model = config.model_config.empty().import_weights(weights)
         tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
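A converted bundle can be loaded back through the config classes; a sketch assuming `models/my-model` was produced by `lalamo convert` (it reads `config.json`, `model.safetensors`, and `tokenizer.json` from that directory):

```python
from pathlib import Path

from lalamo.models import LanguageModelConfig

model = LanguageModelConfig.load_model(Path("models/my-model"))
```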
lalamo/safetensors.py
ADDED
@@ -0,0 +1,97 @@
+import json
+import struct
+from collections.abc import Mapping
+from dataclasses import dataclass
+from io import BufferedReader, BufferedWriter
+from typing import Any, ClassVar, Self
+
+import cattrs
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jaxtyping import Array
+
+from lalamo.utils import LazyDict
+
+SF2J = {
+    "BOOL": jnp.dtype(jnp.bool_),
+    "U8": jnp.dtype(jnp.uint8),
+    "I8": jnp.dtype(jnp.int8),
+    "I16": jnp.dtype(jnp.int16),
+    "U16": jnp.dtype(jnp.uint16),
+    "F16": jnp.dtype(jnp.float16),
+    "BF16": jnp.dtype(jnp.bfloat16),
+    "I32": jnp.dtype(jnp.int32),
+    "U32": jnp.dtype(jnp.uint32),
+    "F32": jnp.dtype(jnp.float32),
+    "C64": jnp.dtype(jnp.complex64),
+    "F64": jnp.dtype(jnp.float64),
+    "I64": jnp.dtype(jnp.int64),
+    "U64": jnp.dtype(jnp.uint64),
+}
+
+J2SF = {v: k for k, v in SF2J.items()}
+
+
+@dataclass(frozen=True)
+class SFTensorInfo:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+    _converter.register_structure_hook(jnp.dtype, lambda x, _: SF2J[x])
+    _converter.register_unstructure_hook(jnp.dtype, lambda x: J2SF[x])
+
+    dtype: jnp.dtype
+    shape: tuple[int, ...]
+    data_offsets: tuple[int, int]
+
+    @property
+    def start(self) -> int:
+        return self.data_offsets[0]
+
+    @property
+    def end(self) -> int:
+        return self.data_offsets[1]
+
+    @property
+    def size(self) -> int:
+        return self.end - self.start
+
+    @classmethod
+    def from_dict(cls, obj: dict) -> Self:
+        return cls._converter.structure(obj, cls)
+
+    def to_dict(self) -> dict:
+        return self._converter.unstructure(self)
+
+
+def safe_read(fd: BufferedReader) -> tuple[dict[str, str] | None, LazyDict[str, Array]]:
+    header_size = struct.unpack("<Q", fd.read(8))[0]
+    header: dict[str, dict[str, Any]] = json.loads(fd.read(header_size))
+    metadata: dict[str, str] | None = header.pop("__metadata__", None)
+    data_offset = fd.tell()
+
+    def _load_tensor(key: str) -> Array:
+        info = SFTensorInfo.from_dict(header[key])
+        fd.seek(data_offset + info.start)
+        return jnp.asarray(np.fromfile(fd, info.dtype, info.size // info.dtype.itemsize)).reshape(info.shape)
+
+    lazy_tensors = LazyDict(set(header.keys()), _load_tensor)
+    return (metadata, lazy_tensors)
+
+
+def safe_write(fd: BufferedWriter, tensors: Mapping[str, Array]) -> None:
+    sorted_tensors = dict(sorted(tensors.items(), key=lambda x: (-x[1].dtype.alignment, x[0])))
+
+    header_dict = {}
+    offset = 0
+    for key, tensor in sorted_tensors.items():
+        assert offset % tensor.dtype.alignment == 0
+        header_dict[key] = SFTensorInfo(tensor.dtype, tensor.shape, (offset, offset + tensor.nbytes)).to_dict()
+        offset += tensor.nbytes
+
+    data_alignment = max(8, next((t.dtype.alignment for t in sorted_tensors.values()), 1))
+    header = json.dumps(header_dict).encode()
+    header += b" " * (-len(header) % data_alignment)
+    fd.write(struct.pack("<Q", len(header)) + header)
+
+    for tensor in sorted_tensors.values():
+        jax.device_get(tensor).tofile(fd)
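The new module is a dependency-free reader/writer for the safetensors container: an 8-byte little-endian header length, a JSON header mapping tensor names to `dtype`/`shape`/`data_offsets`, then the raw buffers; `safe_read` returns tensors lazily via `LazyDict`. A round-trip sketch (the file name is a placeholder, and tensors must be materialized while the file is still open):

```python
import jax.numpy as jnp

from lalamo.safetensors import safe_read, safe_write

tensors = {
    "embedding": jnp.ones((4, 8), dtype=jnp.float32),
    "bias": jnp.zeros((8,), dtype=jnp.float16),
}

with open("example.safetensors", "wb") as fd:
    safe_write(fd, tensors)

with open("example.safetensors", "rb") as fd:
    metadata, lazy = safe_read(fd)
    loaded = {name: lazy[name] for name in lazy}  # force loads before fd closes

assert metadata is None  # safe_write emits no __metadata__ block
assert all(bool(jnp.array_equal(loaded[k], tensors[k])) for k in tensors)
```

Sorting tensors by descending dtype alignment and padding the header keeps every buffer's absolute offset aligned, which is why `safe_write` can stream each array with `tofile` without extra padding between tensors.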
lalamo/speculator/estimator.py
CHANGED
@@ -9,6 +9,13 @@ import jax.numpy as jnp
 from lalamo.models import LanguageModel


+def get_default_device_memory() -> int | None:
+    memory_stats = jax.local_devices()[0].memory_stats()
+    if memory_stats is None or "bytes_limit" not in memory_stats:
+        return None
+    return memory_stats["bytes_limit"]
+
+
 def estimate_memory_from_batchsize(
     model: LanguageModel,
     max_input_length: int,
@@ -23,7 +30,7 @@ def estimate_memory_from_batchsize(
             max_output_length=max_output_length,
             num_top_logits_to_return=num_logits_per_token,
         ),
-        backend="cpu",
+        backend="cpu",  # cuda backend tries to allocate in .compile() and ooms
     )
     .lower(
         model,
@@ -41,7 +48,7 @@ def estimate_memory_from_batchsize(
     return (
         memory_analysis.argument_size_in_bytes  # type: ignore (pyright bug)
         + memory_analysis.output_size_in_bytes  # type: ignore (pyright bug)
-        + memory_analysis.temp_size_in_bytes
+        + memory_analysis.temp_size_in_bytes  # type: ignore (pyright bug)
     )

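`get_default_device_memory` replaces the inline `memory_stats` poking that previously lived in `main.py`; callers fall back to an explicit budget when the device reports nothing (JAX CPU devices typically return no memory stats). A sketch of that fallback:

```python
from lalamo.speculator.estimator import get_default_device_memory

mem = get_default_device_memory()
if mem is None:
    mem = 8 * 1024**3  # hypothetical 8 GiB budget, mirroring --vram-gb
```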
lalamo/utils.py
CHANGED
@@ -9,21 +9,17 @@ from collections.abc import (
     Sequence,
     ValuesView,
 )
-from contextlib import contextmanager
 from dataclasses import dataclass
-from pathlib import Path
 from typing import overload

 import einops
 import jax.numpy as jnp
 from jaxtyping import Array
-from safetensors import safe_open

 __all__ = [
     "MapDictValues",
     "MapSequence",
     "jax_uint4_to_packed_uint8",
-    "open_safetensors",
     "process_chat_template",
 ]

@@ -45,15 +41,6 @@ class LazyDict[K, V](Mapping[K, V]):
         return len(self.stored_keys)


-@contextmanager
-def open_safetensors(filename: Path | str) -> Iterator[tuple[Mapping[str, Array], Mapping[str, str]]]:
-    with safe_open(filename, framework="flax") as safetensors_nonsense:
-        yield (
-            LazyDict(set(safetensors_nonsense.keys()), safetensors_nonsense.get_tensor),
-            safetensors_nonsense.metadata(),
-        )
-
-
 @dataclass(frozen=True)
 class MapIterable[OldT, NewT](Iterable[NewT]):
     map_func: Callable[[OldT], NewT]
{lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.5.
+Version: 0.5.17
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
@@ -19,7 +19,6 @@ Requires-Dist: rich>=14.0.0
 Requires-Dist: thefuzz>=0.22.1
 Requires-Dist: tokenizers>=0.21.2
 Requires-Dist: typer>=0.15.1
-Requires-Dist: safetensors>=0.6.2
 Requires-Dist: polars>=1.33.1
 Requires-Dist: xxhash>=3.5.0
 Provides-Extra: cpu
{lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/RECORD
CHANGED

@@ -1,11 +1,13 @@
-lalamo/__init__.py,sha256=
+lalamo/__init__.py,sha256=asVMPmQ7BUt7bYlcuNZ7SnOSJDJUiN9QhlU5lRUehSo,1387
+lalamo/commands.py,sha256=rU9T8Mx6s7itpk-dj5ToQ4PUpGPfdmmKlrF02l2kIS0,9967
 lalamo/common.py,sha256=5NUFD26yQgOnEEk3LaQnce8n-VwJxILkEpFesHZhtQU,3820
-lalamo/main.py,sha256=
+lalamo/main.py,sha256=dE7Us9L6sfz9bp5rUSzGHUkG0Uon4xdju9dGGtXidZI,23888
 lalamo/message_processor.py,sha256=bSUAQg7CemLTnBV4LtPxJBicAalruDCA-JXjkTYPZ8U,5797
 lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
 lalamo/registry_abc.py,sha256=ENjXiD_wEH100fNjG-W5Em1L_EQ0Lf0pdRhRGvf3qZk,2197
+lalamo/safetensors.py,sha256=kUiTSgx2zhfD1hxV_AA1DOLaKAKzjRd_vOYZCFf0em0,3048
 lalamo/sampling.py,sha256=g_dNiJyZrRqoQIiLid4cr6nRT9N5tSz3GtHr8Bt4n-E,3404
-lalamo/utils.py,sha256=
+lalamo/utils.py,sha256=c88IP110gHZJ6hYDq7p36A9u-vLRM_YdavFom56gsNQ,4111
 lalamo/data/__init__.py,sha256=exfhBLxHrg7BWutM0tAln5QuIWlNQmOhaG2noFYxfPI,189
 lalamo/data/huggingface_message.py,sha256=-7lN9eIcETQzt1Pnx3d4d8p3_I7WYMNf4mp1P91N7fI,1115
 lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
@@ -35,7 +37,7 @@ lalamo/model_import/loaders/executorch.py,sha256=t2Ey_mBMNC8bTSTdYWjuGXdPTRoohFl
 lalamo/model_import/loaders/huggingface.py,sha256=qWdzoSvHvb_3prn2kwfxgnYPW2bVB0Q49m_wyRYha8Q,34677
 lalamo/model_import/loaders/utils.py,sha256=eiX3WKFRrAfBY-dugodscNInl5o5w3KmVcgma4atpGY,2456
 lalamo/model_import/model_specs/__init__.py,sha256=JISqwJkloQkGD2jvi1MakNEWapIwlNXXVi5giZyXB74,1275
-lalamo/model_import/model_specs/common.py,sha256=
+lalamo/model_import/model_specs/common.py,sha256=8ALKxHrt8uK4XiqjK25NwZj1CC7DM7jlYcFVZPGkFrw,6643
 lalamo/model_import/model_specs/deepseek.py,sha256=Umef93_ZBuq93yYsejIRNwj3udoln1gHfrv3SK5jyMo,417
 lalamo/model_import/model_specs/essential_ai.py,sha256=xbHcwRpAWhR9gOgypVzcgunFspoUEk3iNsw-46CVR4o,390
 lalamo/model_import/model_specs/gemma.py,sha256=dwKwOHU1sBJNLFAwtEyydsRUF9QENN3SHtjbfqtOSic,3876
@@ -52,7 +54,7 @@ lalamo/model_import/model_specs/qwen.py,sha256=HvN080ILpOwkqJbRLMqCa8Z8ImlLfTwiE
 lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKCvnxLkNEDCw,423
 lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
 lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
-lalamo/models/common.py,sha256=
+lalamo/models/common.py,sha256=uU6eCHtIqMeC_aRGVo09NdpAtvQ6RKSbm6pumVvL8pc,2943
 lalamo/models/language_model.py,sha256=QPeVEyhutSze7fSNhvOvwSoYt24QMk-dtTJkos38amY,13465
 lalamo/modules/__init__.py,sha256=OHIQn08jx2c3L2KIQA-7SJ4yVb2E5m6T6FqTHFJTDdM,4006
 lalamo/modules/activations.py,sha256=U3qTQtZawPAUcoqbkIJnmTYcaNiQuSPMLcBeJ398GhI,1022
@@ -81,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
 lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
 lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
 lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
-lalamo/speculator/estimator.py,sha256=
+lalamo/speculator/estimator.py,sha256=j-zmhy3RxYDmQ7W0FMTmDk3i275r_Vg1s4NCaS4c_SQ,2760
 lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
 lalamo/speculator/ngram.py,sha256=95mdfAWhx4d5XOnOwhyhElnvcy6nlUjYhcbJzqDs414,5875
 lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
-lalamo-0.5.
+lalamo-0.5.17.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.5.17.dist-info/METADATA,sha256=16-W1J0wiwrmgMTgqiE9r3vxKRmZbGgZ-zS7bNACwTA,3113
+lalamo-0.5.17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lalamo-0.5.17.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.5.17.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.5.17.dist-info/RECORD,,
{lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/WHEEL
File without changes

{lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/entry_points.txt
File without changes

{lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/licenses/LICENSE
File without changes

{lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/top_level.txt
File without changes