lalamo 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +26 -2
- lalamo/commands.py +429 -0
- lalamo/common.py +14 -1
- lalamo/main.py +375 -229
- lalamo/message_processor.py +4 -1
- lalamo/model_import/common.py +8 -17
- lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
- lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
- lalamo/model_import/huggingface_generation_config.py +21 -3
- lalamo/model_import/loaders/executorch.py +2 -2
- lalamo/model_import/loaders/huggingface.py +3 -3
- lalamo/model_import/model_specs/common.py +8 -4
- lalamo/model_import/model_specs/lfm2.py +41 -9
- lalamo/models/common.py +3 -3
- lalamo/models/language_model.py +7 -6
- lalamo/modules/activations.py +1 -1
- lalamo/modules/classifier.py +11 -24
- lalamo/modules/common.py +4 -1
- lalamo/modules/decoder.py +5 -11
- lalamo/modules/embedding.py +25 -62
- lalamo/modules/linear.py +19 -33
- lalamo/modules/mlp.py +9 -19
- lalamo/modules/mlx_interop.py +1 -1
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +1 -1
- lalamo/modules/token_mixers/attention.py +9 -27
- lalamo/modules/token_mixers/mamba.py +9 -24
- lalamo/modules/token_mixers/short_conv.py +5 -12
- lalamo/modules/transformer.py +10 -20
- lalamo/modules/transformer_layer.py +8 -20
- lalamo/registry_abc.py +4 -4
- lalamo/safetensors.py +97 -0
- lalamo/sampling.py +14 -0
- lalamo/speculator/estimator.py +11 -4
- lalamo/speculator/ngram.py +1 -1
- lalamo/utils.py +0 -13
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -2
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/RECORD +43 -41
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
|
@@ -1,4 +1,14 @@
|
|
|
1
|
-
from lalamo.
|
|
1
|
+
from lalamo.commands import (
|
|
2
|
+
CollectTracesCallbacks,
|
|
3
|
+
ConversionCallbacks,
|
|
4
|
+
EstimateBatchsizeCallbacks,
|
|
5
|
+
Precision,
|
|
6
|
+
TrainCallbacks,
|
|
7
|
+
collect_traces,
|
|
8
|
+
convert,
|
|
9
|
+
estimate_batchsize,
|
|
10
|
+
train,
|
|
11
|
+
)
|
|
2
12
|
from lalamo.message_processor import (
|
|
3
13
|
AssistantMessage,
|
|
4
14
|
ContentBlock,
|
|
@@ -9,27 +19,41 @@ from lalamo.message_processor import (
|
|
|
9
19
|
UserMessage,
|
|
10
20
|
)
|
|
11
21
|
from lalamo.model_import import ModelSpec, import_model
|
|
22
|
+
from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec, ModelType, UseCase, WeightsType
|
|
12
23
|
from lalamo.models import ClassifierModel, LanguageModel
|
|
24
|
+
from lalamo.quantization import QuantizationMode
|
|
13
25
|
from lalamo.speculator import (
|
|
14
26
|
CollectTracesEvent,
|
|
15
27
|
SpeculatorTrainingEvent,
|
|
16
28
|
)
|
|
17
29
|
|
|
18
|
-
__version__ = "0.
|
|
30
|
+
__version__ = "0.6.0"
|
|
19
31
|
|
|
20
32
|
__all__ = [
|
|
21
33
|
"AssistantMessage",
|
|
22
34
|
"ClassifierModel",
|
|
35
|
+
"CollectTracesCallbacks",
|
|
23
36
|
"CollectTracesEvent",
|
|
37
|
+
"ConfigMap",
|
|
24
38
|
"ContentBlock",
|
|
39
|
+
"ConversionCallbacks",
|
|
40
|
+
"EstimateBatchsizeCallbacks",
|
|
41
|
+
"FileSpec",
|
|
25
42
|
"Image",
|
|
43
|
+
"JSONFieldSpec",
|
|
26
44
|
"LanguageModel",
|
|
27
45
|
"Message",
|
|
28
46
|
"ModelSpec",
|
|
47
|
+
"ModelType",
|
|
48
|
+
"Precision",
|
|
49
|
+
"QuantizationMode",
|
|
29
50
|
"SpeculatorTrainingEvent",
|
|
30
51
|
"SystemMessage",
|
|
31
52
|
"ToolSchema",
|
|
53
|
+
"TrainCallbacks",
|
|
54
|
+
"UseCase",
|
|
32
55
|
"UserMessage",
|
|
56
|
+
"WeightsType",
|
|
33
57
|
"collect_traces",
|
|
34
58
|
"convert",
|
|
35
59
|
"estimate_batchsize",
|
lalamo/commands.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections.abc import Callable, Iterable
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from itertools import chain
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from jaxtyping import DTypeLike
|
|
9
|
+
|
|
10
|
+
from lalamo.common import flatten_parameters
|
|
11
|
+
from lalamo.data import import_hf_parquet
|
|
12
|
+
from lalamo.data.lalamo_completions import LalamoCompletion
|
|
13
|
+
from lalamo.message_processor import Message
|
|
14
|
+
from lalamo.model_import import ModelMetadata, ModelSpec, import_model
|
|
15
|
+
from lalamo.model_import.common import (
|
|
16
|
+
DownloadingFileEvent,
|
|
17
|
+
FileSpec,
|
|
18
|
+
FinishedDownloadingFileEvent,
|
|
19
|
+
FinishedInitializingModelEvent,
|
|
20
|
+
InitializingModelEvent,
|
|
21
|
+
StatusEvent,
|
|
22
|
+
)
|
|
23
|
+
from lalamo.models import LanguageModelConfig
|
|
24
|
+
from lalamo.modules import config_converter
|
|
25
|
+
from lalamo.safetensors import safe_write
|
|
26
|
+
from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
|
|
27
|
+
from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
|
|
28
|
+
from lalamo.speculator.ngram import NGramSpeculator
|
|
29
|
+
from lalamo.speculator.utils import SpeculatorTrainingEvent, train_speculator
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Precision(Enum):
    """Target floating-point precision for converted model weights.

    Each member's value is a dtype name string; ``convert`` turns it into an actual
    dtype via ``config_converter.structure(precision.value, DTypeLike)``.
    """

    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class ConversionCallbacks:
    """Notification hooks that ``convert`` uses to report conversion progress.

    ``convert`` builds an instance from its own arguments and calls these hooks in
    order as it works. Every hook is a no-op except :meth:`output_dir_exists`, whose
    default refuses to overwrite an existing output directory. Subclass and override
    hooks to surface progress (e.g. in a CLI) or to change that policy.
    """

    model_spec: ModelSpec
    output_dir: Path
    precision: Precision | None
    context_length: int | None

    def started(self) -> None:
        """Called once, after the output-directory check and before the model import begins."""

    def output_dir_exists(self) -> None:
        """Called when ``output_dir`` already exists.

        The default raises. An override that returns normally lets ``convert`` carry
        on and write into the existing directory.
        """
        raise RuntimeError(f"{self.output_dir=} already exists, refusing to overwrite!")

    def downloading(self, file_spec: FileSpec) -> None:
        """Called when the importer reports it is downloading ``file_spec``."""

    def finished_downloading(self, file_spec: FileSpec) -> None:
        """Called when the importer reports the download of ``file_spec`` is done."""

    def initializing_model(self) -> None:
        """Called when the importer reports it has started initializing the model."""

    def finished_initializing_model(self) -> None:
        """Called when the importer reports model initialization is done."""

    def saving_model(self) -> None:
        """Called before the tokenizer, weights and config are written to ``output_dir``."""

    def finished_saving_model(self) -> None:
        """Called after all output files have been written."""
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def convert(
|
|
71
|
+
model_spec: ModelSpec,
|
|
72
|
+
output_dir: Path,
|
|
73
|
+
precision: Precision | None = None,
|
|
74
|
+
context_length: int | None = None,
|
|
75
|
+
callbacks_type: Callable[
|
|
76
|
+
[
|
|
77
|
+
ModelSpec,
|
|
78
|
+
Path,
|
|
79
|
+
Precision | None,
|
|
80
|
+
int | None,
|
|
81
|
+
],
|
|
82
|
+
ConversionCallbacks,
|
|
83
|
+
] = ConversionCallbacks,
|
|
84
|
+
) -> None:
|
|
85
|
+
callbacks = callbacks_type(
|
|
86
|
+
model_spec,
|
|
87
|
+
output_dir,
|
|
88
|
+
precision,
|
|
89
|
+
context_length,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
if precision is not None:
|
|
93
|
+
precision_dtype = config_converter.structure(precision.value, DTypeLike) # type: ignore
|
|
94
|
+
else:
|
|
95
|
+
precision_dtype = None
|
|
96
|
+
|
|
97
|
+
if output_dir.exists():
|
|
98
|
+
callbacks.output_dir_exists()
|
|
99
|
+
|
|
100
|
+
callbacks.started()
|
|
101
|
+
|
|
102
|
+
def progress_callback(event: StatusEvent) -> None:
|
|
103
|
+
match event:
|
|
104
|
+
case DownloadingFileEvent(file_spec):
|
|
105
|
+
callbacks.downloading(file_spec)
|
|
106
|
+
case FinishedDownloadingFileEvent(file_spec):
|
|
107
|
+
callbacks.finished_downloading(file_spec)
|
|
108
|
+
case InitializingModelEvent():
|
|
109
|
+
callbacks.initializing_model()
|
|
110
|
+
case FinishedInitializingModelEvent():
|
|
111
|
+
callbacks.finished_initializing_model()
|
|
112
|
+
|
|
113
|
+
model, metadata = import_model(
|
|
114
|
+
model_spec,
|
|
115
|
+
precision=precision_dtype,
|
|
116
|
+
context_length=context_length,
|
|
117
|
+
progress_callback=progress_callback,
|
|
118
|
+
)
|
|
119
|
+
callbacks.saving_model()
|
|
120
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
|
121
|
+
|
|
122
|
+
model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
|
|
123
|
+
weights = flatten_parameters(model.export_weights())
|
|
124
|
+
del model
|
|
125
|
+
|
|
126
|
+
with Path(output_dir / "model.safetensors").open("wb") as fd:
|
|
127
|
+
safe_write(fd, weights)
|
|
128
|
+
|
|
129
|
+
config_json = config_converter.unstructure(metadata, ModelMetadata)
|
|
130
|
+
with open(output_dir / "config.json", "w") as file:
|
|
131
|
+
json.dump(config_json, file, indent=4)
|
|
132
|
+
|
|
133
|
+
callbacks.finished_saving_model()
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@dataclass
class TraceCallbacks:
    """Notification hooks that ``trace`` calls around each stage of recording a trace.

    Every hook is a no-op except :meth:`output_exists`, whose default aborts when the
    destination file is already present.
    """

    model_path: Path
    output_path: Path
    messages: Iterable[Message] | None

    def output_exists(self) -> None:
        """Called before :meth:`started` when ``output_path`` already exists; the default raises."""
        raise RuntimeError(f"{self.output_path=} already exists, refusing to overwrite!")

    def started(self) -> None:
        """Called once the output-path check has been handled."""

    def loading_model(self) -> None:
        """Called right before the model is loaded from ``model_path``."""

    def finished_loading_model(self) -> None:
        """Called right after the model has been loaded."""

    def tracing_model(self) -> None:
        """Called right before ``record_trace`` runs."""

    def finished_tracing_model(self) -> None:
        """Called right after ``record_trace`` returns."""

    def saving_trace(self) -> None:
        """Called before the trace is flattened and written to ``output_path``."""

    def finished_saving_trace(self) -> None:
        """Called after the trace file has been written."""
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def trace(
    model_path: Path,
    output_path: Path,
    messages: Iterable[Message] | None = None,
    callbacks_type: Callable[
        [
            Path,
            Path,
            Iterable[Message] | None,
        ],
        TraceCallbacks,
    ] = TraceCallbacks,
) -> None:
    """Load a model, run ``record_trace`` on ``messages`` and save the result as safetensors.

    Args:
        model_path: Model location accepted by ``LanguageModelConfig.load_model``.
        output_path: File the flattened trace is written to; missing parent
            directories are created.
        messages: Forwarded unchanged to ``model.record_trace`` (may be ``None``).
        callbacks_type: Factory for progress hooks, called positionally with
            ``(model_path, output_path, messages)``.

    Raises:
        RuntimeError: From the default ``TraceCallbacks.output_exists`` when
            ``output_path`` already exists.
    """
    callbacks = callbacks_type(model_path, output_path, messages)

    if output_path.exists():
        callbacks.output_exists()

    callbacks.started()

    callbacks.loading_model()
    model = LanguageModelConfig.load_model(model_path)
    callbacks.finished_loading_model()

    callbacks.tracing_model()
    recorded = model.record_trace(messages)
    callbacks.finished_tracing_model()

    callbacks.saving_trace()
    flat_trace = flatten_parameters(recorded.export())
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "wb") as trace_file:
        safe_write(trace_file, flat_trace)
    callbacks.finished_saving_trace()
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@dataclass
|
|
204
|
+
class EstimateBatchsizeCallbacks:
|
|
205
|
+
model_path: Path
|
|
206
|
+
max_input_length: int
|
|
207
|
+
max_output_length: int
|
|
208
|
+
num_logits_per_token: int
|
|
209
|
+
mem: int
|
|
210
|
+
|
|
211
|
+
def loading_model(self) -> None:
|
|
212
|
+
pass
|
|
213
|
+
|
|
214
|
+
def finished_loading_model(self) -> None:
|
|
215
|
+
pass
|
|
216
|
+
|
|
217
|
+
def estimating_batchsize(self, lo: int, hi: int | None) -> None:
|
|
218
|
+
pass
|
|
219
|
+
|
|
220
|
+
def finished_estimating_batchsize(self, batchsize: int) -> None:
|
|
221
|
+
pass
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def estimate_batchsize(
    model_path: Path,
    mem: int,
    max_input_length: int = 1024,
    max_output_length: int = 1024,
    num_logits_per_token: int = 8,
    callbacks_type: Callable[
        [Path, int, int, int, int],
        EstimateBatchsizeCallbacks,
    ] = EstimateBatchsizeCallbacks,
) -> int:
    """Load a model and estimate a batch size that fits in ``mem``.

    Args:
        model_path: Model location accepted by ``LanguageModelConfig.load_model``.
        mem: Memory budget forwarded to ``estimate_batchsize_from_memory``
            (units as that function expects — presumably bytes, TODO confirm).
        max_input_length: Forwarded to the estimator.
        max_output_length: Forwarded to the estimator.
        num_logits_per_token: Forwarded to the estimator.
        callbacks_type: Factory for progress hooks. Note its positional order,
            ``(model_path, max_input_length, max_output_length, num_logits_per_token, mem)``,
            differs from this function's parameter order.

    Returns:
        The batch size returned by ``estimate_batchsize_from_memory``.
    """
    callbacks = callbacks_type(
        model_path,
        max_input_length,
        max_output_length,
        num_logits_per_token,
        mem,
    )

    callbacks.loading_model()
    model = LanguageModelConfig.load_model(model_path)
    callbacks.finished_loading_model()

    def forward_progress(event: EstimateBatchsizeFromMemoryEvent) -> None:
        callbacks.estimating_batchsize(event.lo, event.hi)

    batch_size = estimate_batchsize_from_memory(
        model,
        max_input_length,
        max_output_length,
        num_logits_per_token,
        mem,
        forward_progress,
    )

    callbacks.finished_estimating_batchsize(batch_size)
    return batch_size
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@dataclass
|
|
264
|
+
class CollectTracesCallbacks:
|
|
265
|
+
model_path: Path
|
|
266
|
+
dataset_path: Path
|
|
267
|
+
output_path: Path
|
|
268
|
+
num_logits_per_token: int
|
|
269
|
+
max_input_length: int
|
|
270
|
+
max_output_length: int
|
|
271
|
+
batch_size: int
|
|
272
|
+
num_tokens_to_generate: int | None
|
|
273
|
+
|
|
274
|
+
def loading_model(self) -> None:
|
|
275
|
+
pass
|
|
276
|
+
|
|
277
|
+
def finished_loading_model(self) -> None:
|
|
278
|
+
pass
|
|
279
|
+
|
|
280
|
+
def loading_dataset(self) -> None:
|
|
281
|
+
pass
|
|
282
|
+
|
|
283
|
+
def finished_loading_dataset(self) -> None:
|
|
284
|
+
pass
|
|
285
|
+
|
|
286
|
+
def inference_progress(self, tokens_generated: int) -> None:
|
|
287
|
+
pass
|
|
288
|
+
|
|
289
|
+
def finished_inference(self) -> None:
|
|
290
|
+
pass
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def collect_traces(
    model_path: Path,
    dataset_path: Path,
    output_path: Path,
    num_logits_per_token: int = 8,
    max_input_length: int = 1024,
    max_output_length: int = 1024,
    batch_size: int = 1,
    num_tokens_to_generate: int | None = None,
    callbacks_type: Callable[
        [
            Path,
            Path,
            Path,
            int,
            int,
            int,
            int,
            int | None,
        ],
        CollectTracesCallbacks,
    ] = CollectTracesCallbacks,
) -> None:
    """Run a model over a parquet dataset and write every serialized trace to ``output_path``.

    Each trace yielded by ``inference_collect_traces`` is serialized and appended to the
    output file. The file is opened in ``"wb"`` mode with no existence check, so an
    existing file is overwritten. Missing parent directories are created.

    Args:
        model_path: Model location accepted by ``LanguageModelConfig.load_model``.
        dataset_path: Parquet dataset read via ``import_hf_parquet``.
        output_path: Destination file for the concatenated serialized traces.
        num_logits_per_token: Forwarded to ``inference_collect_traces``.
        max_input_length: Forwarded to ``inference_collect_traces``.
        max_output_length: Forwarded to ``inference_collect_traces``.
        batch_size: Forwarded to ``inference_collect_traces``.
        num_tokens_to_generate: Forwarded to ``inference_collect_traces`` (may be ``None``).
        callbacks_type: Factory for progress hooks, called positionally with the
            eight arguments above in this function's order.

    Raises:
        ValueError: If the dataset at ``dataset_path`` yields no records.
    """
    callbacks = callbacks_type(
        model_path,
        dataset_path,
        output_path,
        num_logits_per_token,
        max_input_length,
        max_output_length,
        batch_size,
        num_tokens_to_generate,
    )

    callbacks.loading_model()
    model = LanguageModelConfig.load_model(model_path)
    callbacks.finished_loading_model()

    callbacks.loading_dataset()
    records = iter(import_hf_parquet(dataset_path))
    # The parquet iterator is lazy; pulling the first record forces the file to actually
    # be opened during the "loading dataset" phase. On an empty dataset a bare
    # StopIteration must not escape this ordinary function: it would surface with no
    # message, silently end an enclosing iteration (e.g. inside ``map``), or become a
    # RuntimeError if we are called from a generator (PEP 479).
    try:
        first_record = next(records)
    except StopIteration:
        raise ValueError(f"dataset {dataset_path} contains no records") from None
    dataset = chain([first_record], records)
    callbacks.finished_loading_dataset()

    def forward_progress(event: CollectTracesEvent) -> None:
        callbacks.inference_progress(event.tokens_generated)

    # NOTE: argument order here (batch_size before the length limits) differs from this
    # function's parameter order.
    traces = inference_collect_traces(
        model,
        dataset,
        num_logits_per_token,
        batch_size,
        max_input_length,
        max_output_length,
        num_tokens_to_generate,
        forward_progress,
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "wb") as output_fd:
        for trace in traces:
            output_fd.write(trace.serialize())

    callbacks.finished_inference()
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
@dataclass
|
|
360
|
+
class TrainCallbacks:
|
|
361
|
+
trace_path: Path
|
|
362
|
+
output_path: Path
|
|
363
|
+
hashtable_size: int
|
|
364
|
+
num_logits_per_token: int
|
|
365
|
+
ngram_size: int
|
|
366
|
+
subsample_size: int | None
|
|
367
|
+
|
|
368
|
+
def started(self) -> None:
|
|
369
|
+
pass
|
|
370
|
+
|
|
371
|
+
def training_progress(self, trained_tokens: int) -> None:
|
|
372
|
+
pass
|
|
373
|
+
|
|
374
|
+
def finished_training(self) -> None:
|
|
375
|
+
pass
|
|
376
|
+
|
|
377
|
+
def saving_speculator(self) -> None:
|
|
378
|
+
pass
|
|
379
|
+
|
|
380
|
+
def finished_saving_speculator(self) -> None:
|
|
381
|
+
pass
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def train(
    trace_path: Path,
    output_path: Path,
    hashtable_size: int = 65536,
    num_logits_per_token: int = 8,
    ngram_size: int = 2,
    subsample_size: int | None = None,
    callbacks_type: Callable[
        [Path, Path, int, int, int, int | None],
        TrainCallbacks,
    ] = TrainCallbacks,
) -> None:
    """Train an ``NGramSpeculator`` on recorded traces and write it to ``output_path``.

    Args:
        trace_path: File of serialized ``LalamoCompletion`` records (as produced by
            ``collect_traces``), read with ``LalamoCompletion.deserialize_many``.
        output_path: Destination for ``speculator.serialize()``; missing parent
            directories are created and an existing file is overwritten.
        hashtable_size: Forwarded to ``NGramSpeculator.new``.
        num_logits_per_token: Forwarded to ``NGramSpeculator.new``.
        ngram_size: Forwarded to ``NGramSpeculator.new``.
        subsample_size: Forwarded to ``train_speculator`` (may be ``None``).
        callbacks_type: Factory for progress hooks, called positionally with the six
            arguments above.
    """
    callbacks = callbacks_type(
        trace_path,
        output_path,
        hashtable_size,
        num_logits_per_token,
        ngram_size,
        subsample_size,
    )

    callbacks.started()

    def forward_progress(event: SpeculatorTrainingEvent) -> None:
        callbacks.training_progress(event.trained_tokens)

    with open(trace_path, "rb") as trace_fd:
        traces = LalamoCompletion.deserialize_many(trace_fd)
        speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
        # Train while the trace file is still open, in case ``deserialize_many``
        # reads records lazily from ``trace_fd``.
        train_speculator(speculator, traces, subsample_size, forward_progress)

    callbacks.finished_training()

    callbacks.saving_speculator()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "wb") as speculator_file:
        speculator_file.write(speculator.serialize())
    callbacks.finished_saving_speculator()
|
lalamo/common.py
CHANGED
|
@@ -15,6 +15,8 @@ __all__ = [
|
|
|
15
15
|
"ParameterTree",
|
|
16
16
|
"dummy_array",
|
|
17
17
|
"flatten_parameters",
|
|
18
|
+
"require_array",
|
|
19
|
+
"require_tree",
|
|
18
20
|
"unflatten_parameters",
|
|
19
21
|
]
|
|
20
22
|
|
|
@@ -29,6 +31,16 @@ type ParameterTree[ArrayType: ArrayLike] = (
|
|
|
29
31
|
)
|
|
30
32
|
|
|
31
33
|
|
|
34
|
+
def require_array[ArrayType: ArrayLike](value: ArrayType | ParameterTree[ArrayType]) -> ArrayType:
|
|
35
|
+
assert not isinstance(value, (Mapping, Sequence))
|
|
36
|
+
return value
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def require_tree[ArrayType: ArrayLike](value: ArrayType | ParameterTree[ArrayType]) -> ParameterTree[ArrayType]:
|
|
40
|
+
assert not isinstance(value, (Array, ShapeDtypeStruct))
|
|
41
|
+
return value
|
|
42
|
+
|
|
43
|
+
|
|
32
44
|
def dummy_array(shape: int | tuple[int, ...], dtype: DTypeLike) -> Array:
|
|
33
45
|
if isinstance(shape, int):
|
|
34
46
|
shape = (shape,)
|
|
@@ -40,9 +52,10 @@ def flatten_parameters[ArrayType: ArrayLike](nested_parameters: ParameterTree[Ar
|
|
|
40
52
|
if not isinstance(nested_parameters, Mapping):
|
|
41
53
|
nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
|
|
42
54
|
for key, value in nested_parameters.items():
|
|
55
|
+
value = cast("ArrayType | ParameterTree[ArrayType]", value)
|
|
43
56
|
key_path = ParameterPath(key)
|
|
44
57
|
if isinstance(value, (Array, ShapeDtypeStruct)):
|
|
45
|
-
result[key_path] = value
|
|
58
|
+
result[key_path] = cast("ArrayType", value)
|
|
46
59
|
else:
|
|
47
60
|
update: dict[str, ArrayType] = {
|
|
48
61
|
str(key_path / subkey): subvalue for subkey, subvalue in flatten_parameters(value).items()
|