lalamo 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. lalamo/__init__.py +26 -2
  2. lalamo/commands.py +429 -0
  3. lalamo/common.py +14 -1
  4. lalamo/main.py +375 -229
  5. lalamo/message_processor.py +4 -1
  6. lalamo/model_import/common.py +8 -17
  7. lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
  8. lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
  9. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
  10. lalamo/model_import/huggingface_generation_config.py +21 -3
  11. lalamo/model_import/loaders/executorch.py +2 -2
  12. lalamo/model_import/loaders/huggingface.py +3 -3
  13. lalamo/model_import/model_specs/common.py +8 -4
  14. lalamo/model_import/model_specs/lfm2.py +41 -9
  15. lalamo/models/common.py +3 -3
  16. lalamo/models/language_model.py +7 -6
  17. lalamo/modules/activations.py +1 -1
  18. lalamo/modules/classifier.py +11 -24
  19. lalamo/modules/common.py +4 -1
  20. lalamo/modules/decoder.py +5 -11
  21. lalamo/modules/embedding.py +25 -62
  22. lalamo/modules/linear.py +19 -33
  23. lalamo/modules/mlp.py +9 -19
  24. lalamo/modules/mlx_interop.py +1 -1
  25. lalamo/modules/rope.py +1 -1
  26. lalamo/modules/token_mixers/__init__.py +1 -1
  27. lalamo/modules/token_mixers/attention.py +9 -27
  28. lalamo/modules/token_mixers/mamba.py +9 -24
  29. lalamo/modules/token_mixers/short_conv.py +5 -12
  30. lalamo/modules/transformer.py +10 -20
  31. lalamo/modules/transformer_layer.py +8 -20
  32. lalamo/registry_abc.py +4 -4
  33. lalamo/safetensors.py +97 -0
  34. lalamo/sampling.py +14 -0
  35. lalamo/speculator/estimator.py +11 -4
  36. lalamo/speculator/ngram.py +1 -1
  37. lalamo/utils.py +0 -13
  38. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -2
  39. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/RECORD +43 -41
  40. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
  41. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
  42. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/__init__.py CHANGED
@@ -1,4 +1,14 @@
1
- from lalamo.main import collect_traces, convert, estimate_batchsize, train
1
+ from lalamo.commands import (
2
+ CollectTracesCallbacks,
3
+ ConversionCallbacks,
4
+ EstimateBatchsizeCallbacks,
5
+ Precision,
6
+ TrainCallbacks,
7
+ collect_traces,
8
+ convert,
9
+ estimate_batchsize,
10
+ train,
11
+ )
2
12
  from lalamo.message_processor import (
3
13
  AssistantMessage,
4
14
  ContentBlock,
@@ -9,27 +19,41 @@ from lalamo.message_processor import (
9
19
  UserMessage,
10
20
  )
11
21
  from lalamo.model_import import ModelSpec, import_model
22
+ from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec, ModelType, UseCase, WeightsType
12
23
  from lalamo.models import ClassifierModel, LanguageModel
24
+ from lalamo.quantization import QuantizationMode
13
25
  from lalamo.speculator import (
14
26
  CollectTracesEvent,
15
27
  SpeculatorTrainingEvent,
16
28
  )
17
29
 
18
- __version__ = "0.5.16"
30
+ __version__ = "0.6.0"
19
31
 
20
32
  __all__ = [
21
33
  "AssistantMessage",
22
34
  "ClassifierModel",
35
+ "CollectTracesCallbacks",
23
36
  "CollectTracesEvent",
37
+ "ConfigMap",
24
38
  "ContentBlock",
39
+ "ConversionCallbacks",
40
+ "EstimateBatchsizeCallbacks",
41
+ "FileSpec",
25
42
  "Image",
43
+ "JSONFieldSpec",
26
44
  "LanguageModel",
27
45
  "Message",
28
46
  "ModelSpec",
47
+ "ModelType",
48
+ "Precision",
49
+ "QuantizationMode",
29
50
  "SpeculatorTrainingEvent",
30
51
  "SystemMessage",
31
52
  "ToolSchema",
53
+ "TrainCallbacks",
54
+ "UseCase",
32
55
  "UserMessage",
56
+ "WeightsType",
33
57
  "collect_traces",
34
58
  "convert",
35
59
  "estimate_batchsize",
lalamo/commands.py ADDED
@@ -0,0 +1,429 @@
1
+ import json
2
+ from collections.abc import Callable, Iterable
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from itertools import chain
6
+ from pathlib import Path
7
+
8
+ from jaxtyping import DTypeLike
9
+
10
+ from lalamo.common import flatten_parameters
11
+ from lalamo.data import import_hf_parquet
12
+ from lalamo.data.lalamo_completions import LalamoCompletion
13
+ from lalamo.message_processor import Message
14
+ from lalamo.model_import import ModelMetadata, ModelSpec, import_model
15
+ from lalamo.model_import.common import (
16
+ DownloadingFileEvent,
17
+ FileSpec,
18
+ FinishedDownloadingFileEvent,
19
+ FinishedInitializingModelEvent,
20
+ InitializingModelEvent,
21
+ StatusEvent,
22
+ )
23
+ from lalamo.models import LanguageModelConfig
24
+ from lalamo.modules import config_converter
25
+ from lalamo.safetensors import safe_write
26
+ from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
27
+ from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
28
+ from lalamo.speculator.ngram import NGramSpeculator
29
+ from lalamo.speculator.utils import SpeculatorTrainingEvent, train_speculator
30
+
31
+
32
+ class Precision(Enum):
33
+ FLOAT32 = "float32"
34
+ FLOAT16 = "float16"
35
+ BFLOAT16 = "bfloat16"
36
+
37
+
38
+ @dataclass
39
+ class ConversionCallbacks:
40
+ model_spec: ModelSpec
41
+ output_dir: Path
42
+ precision: Precision | None
43
+ context_length: int | None
44
+
45
+ def started(self) -> None:
46
+ pass
47
+
48
+ def output_dir_exists(self) -> None:
49
+ raise RuntimeError(f"{self.output_dir=} already exists, refusing to overwrite!")
50
+
51
+ def downloading(self, file_spec: FileSpec) -> None:
52
+ pass
53
+
54
+ def finished_downloading(self, file_spec: FileSpec) -> None:
55
+ pass
56
+
57
+ def initializing_model(self) -> None:
58
+ pass
59
+
60
+ def finished_initializing_model(self) -> None:
61
+ pass
62
+
63
+ def saving_model(self) -> None:
64
+ pass
65
+
66
+ def finished_saving_model(self) -> None:
67
+ pass
68
+
69
+
70
+ def convert(
71
+ model_spec: ModelSpec,
72
+ output_dir: Path,
73
+ precision: Precision | None = None,
74
+ context_length: int | None = None,
75
+ callbacks_type: Callable[
76
+ [
77
+ ModelSpec,
78
+ Path,
79
+ Precision | None,
80
+ int | None,
81
+ ],
82
+ ConversionCallbacks,
83
+ ] = ConversionCallbacks,
84
+ ) -> None:
85
+ callbacks = callbacks_type(
86
+ model_spec,
87
+ output_dir,
88
+ precision,
89
+ context_length,
90
+ )
91
+
92
+ if precision is not None:
93
+ precision_dtype = config_converter.structure(precision.value, DTypeLike) # type: ignore
94
+ else:
95
+ precision_dtype = None
96
+
97
+ if output_dir.exists():
98
+ callbacks.output_dir_exists()
99
+
100
+ callbacks.started()
101
+
102
+ def progress_callback(event: StatusEvent) -> None:
103
+ match event:
104
+ case DownloadingFileEvent(file_spec):
105
+ callbacks.downloading(file_spec)
106
+ case FinishedDownloadingFileEvent(file_spec):
107
+ callbacks.finished_downloading(file_spec)
108
+ case InitializingModelEvent():
109
+ callbacks.initializing_model()
110
+ case FinishedInitializingModelEvent():
111
+ callbacks.finished_initializing_model()
112
+
113
+ model, metadata = import_model(
114
+ model_spec,
115
+ precision=precision_dtype,
116
+ context_length=context_length,
117
+ progress_callback=progress_callback,
118
+ )
119
+ callbacks.saving_model()
120
+ output_dir.mkdir(parents=True, exist_ok=True)
121
+
122
+ model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
123
+ weights = flatten_parameters(model.export_weights())
124
+ del model
125
+
126
+ with Path(output_dir / "model.safetensors").open("wb") as fd:
127
+ safe_write(fd, weights)
128
+
129
+ config_json = config_converter.unstructure(metadata, ModelMetadata)
130
+ with open(output_dir / "config.json", "w") as file:
131
+ json.dump(config_json, file, indent=4)
132
+
133
+ callbacks.finished_saving_model()
134
+
135
+
136
+ @dataclass
137
+ class TraceCallbacks:
138
+ model_path: Path
139
+ output_path: Path
140
+ messages: Iterable[Message] | None
141
+
142
+ def output_exists(self) -> None:
143
+ raise RuntimeError(f"{self.output_path=} already exists, refusing to overwrite!")
144
+
145
+ def started(self) -> None:
146
+ pass
147
+
148
+ def loading_model(self) -> None:
149
+ pass
150
+
151
+ def finished_loading_model(self) -> None:
152
+ pass
153
+
154
+ def tracing_model(self) -> None:
155
+ pass
156
+
157
+ def finished_tracing_model(self) -> None:
158
+ pass
159
+
160
+ def saving_trace(self) -> None:
161
+ pass
162
+
163
+ def finished_saving_trace(self) -> None:
164
+ pass
165
+
166
+
167
+ def trace(
168
+ model_path: Path,
169
+ output_path: Path,
170
+ messages: Iterable[Message] | None = None,
171
+ callbacks_type: Callable[
172
+ [
173
+ Path,
174
+ Path,
175
+ Iterable[Message] | None,
176
+ ],
177
+ TraceCallbacks,
178
+ ] = TraceCallbacks,
179
+ ) -> None:
180
+ callbacks = callbacks_type(model_path, output_path, messages)
181
+
182
+ if output_path.exists():
183
+ callbacks.output_exists()
184
+
185
+ callbacks.started()
186
+
187
+ callbacks.loading_model()
188
+ model = LanguageModelConfig.load_model(model_path)
189
+ callbacks.finished_loading_model()
190
+
191
+ callbacks.tracing_model()
192
+ result = model.record_trace(messages)
193
+ callbacks.finished_tracing_model()
194
+
195
+ callbacks.saving_trace()
196
+ traces = flatten_parameters(result.export())
197
+ output_path.parent.mkdir(parents=True, exist_ok=True)
198
+ with Path(output_path).open("wb") as fd:
199
+ safe_write(fd, traces)
200
+ callbacks.finished_saving_trace()
201
+
202
+
203
+ @dataclass
204
+ class EstimateBatchsizeCallbacks:
205
+ model_path: Path
206
+ max_input_length: int
207
+ max_output_length: int
208
+ num_logits_per_token: int
209
+ mem: int
210
+
211
+ def loading_model(self) -> None:
212
+ pass
213
+
214
+ def finished_loading_model(self) -> None:
215
+ pass
216
+
217
+ def estimating_batchsize(self, lo: int, hi: int | None) -> None:
218
+ pass
219
+
220
+ def finished_estimating_batchsize(self, batchsize: int) -> None:
221
+ pass
222
+
223
+
224
+ def estimate_batchsize(
225
+ model_path: Path,
226
+ mem: int,
227
+ max_input_length: int = 1024,
228
+ max_output_length: int = 1024,
229
+ num_logits_per_token: int = 8,
230
+ callbacks_type: Callable[
231
+ [
232
+ Path,
233
+ int,
234
+ int,
235
+ int,
236
+ int,
237
+ ],
238
+ EstimateBatchsizeCallbacks,
239
+ ] = EstimateBatchsizeCallbacks,
240
+ ) -> int:
241
+ callbacks = callbacks_type(model_path, max_input_length, max_output_length, num_logits_per_token, mem)
242
+
243
+ callbacks.loading_model()
244
+ model = LanguageModelConfig.load_model(model_path)
245
+ callbacks.finished_loading_model()
246
+
247
+ def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
248
+ callbacks.estimating_batchsize(event.lo, event.hi)
249
+
250
+ bs = estimate_batchsize_from_memory(
251
+ model,
252
+ max_input_length,
253
+ max_output_length,
254
+ num_logits_per_token,
255
+ mem,
256
+ progress_callback,
257
+ )
258
+
259
+ callbacks.finished_estimating_batchsize(bs)
260
+ return bs
261
+
262
+
263
+ @dataclass
264
+ class CollectTracesCallbacks:
265
+ model_path: Path
266
+ dataset_path: Path
267
+ output_path: Path
268
+ num_logits_per_token: int
269
+ max_input_length: int
270
+ max_output_length: int
271
+ batch_size: int
272
+ num_tokens_to_generate: int | None
273
+
274
+ def loading_model(self) -> None:
275
+ pass
276
+
277
+ def finished_loading_model(self) -> None:
278
+ pass
279
+
280
+ def loading_dataset(self) -> None:
281
+ pass
282
+
283
+ def finished_loading_dataset(self) -> None:
284
+ pass
285
+
286
+ def inference_progress(self, tokens_generated: int) -> None:
287
+ pass
288
+
289
+ def finished_inference(self) -> None:
290
+ pass
291
+
292
+
293
+ def collect_traces(
294
+ model_path: Path,
295
+ dataset_path: Path,
296
+ output_path: Path,
297
+ num_logits_per_token: int = 8,
298
+ max_input_length: int = 1024,
299
+ max_output_length: int = 1024,
300
+ batch_size: int = 1,
301
+ num_tokens_to_generate: int | None = None,
302
+ callbacks_type: Callable[
303
+ [
304
+ Path,
305
+ Path,
306
+ Path,
307
+ int,
308
+ int,
309
+ int,
310
+ int,
311
+ int | None,
312
+ ],
313
+ CollectTracesCallbacks,
314
+ ] = CollectTracesCallbacks,
315
+ ) -> None:
316
+ callbacks = callbacks_type(
317
+ model_path,
318
+ dataset_path,
319
+ output_path,
320
+ num_logits_per_token,
321
+ max_input_length,
322
+ max_output_length,
323
+ batch_size,
324
+ num_tokens_to_generate,
325
+ )
326
+
327
+ callbacks.loading_model()
328
+ model = LanguageModelConfig.load_model(model_path)
329
+ callbacks.finished_loading_model()
330
+
331
+ callbacks.loading_dataset()
332
+ dataset = iter(import_hf_parquet(dataset_path))
333
+ dataset = chain([next(dataset)], dataset) # iterator is lazy, force it to actually open the file
334
+ callbacks.finished_loading_dataset()
335
+
336
+ def progress_callback(event: CollectTracesEvent) -> None:
337
+ callbacks.inference_progress(event.tokens_generated)
338
+
339
+ traces = inference_collect_traces(
340
+ model,
341
+ dataset,
342
+ num_logits_per_token,
343
+ batch_size,
344
+ max_input_length,
345
+ max_output_length,
346
+ num_tokens_to_generate,
347
+ progress_callback,
348
+ )
349
+
350
+ output_path.parent.mkdir(parents=True, exist_ok=True)
351
+ with open(output_path, "wb") as output_fd:
352
+ for trace in traces:
353
+ blob = trace.serialize()
354
+ output_fd.write(blob)
355
+
356
+ callbacks.finished_inference()
357
+
358
+
359
+ @dataclass
360
+ class TrainCallbacks:
361
+ trace_path: Path
362
+ output_path: Path
363
+ hashtable_size: int
364
+ num_logits_per_token: int
365
+ ngram_size: int
366
+ subsample_size: int | None
367
+
368
+ def started(self) -> None:
369
+ pass
370
+
371
+ def training_progress(self, trained_tokens: int) -> None:
372
+ pass
373
+
374
+ def finished_training(self) -> None:
375
+ pass
376
+
377
+ def saving_speculator(self) -> None:
378
+ pass
379
+
380
+ def finished_saving_speculator(self) -> None:
381
+ pass
382
+
383
+
384
+ def train(
385
+ trace_path: Path,
386
+ output_path: Path,
387
+ hashtable_size: int = 65536,
388
+ num_logits_per_token: int = 8,
389
+ ngram_size: int = 2,
390
+ subsample_size: int | None = None,
391
+ callbacks_type: Callable[
392
+ [
393
+ Path,
394
+ Path,
395
+ int,
396
+ int,
397
+ int,
398
+ int | None,
399
+ ],
400
+ TrainCallbacks,
401
+ ] = TrainCallbacks,
402
+ ) -> None:
403
+ callbacks = callbacks_type(
404
+ trace_path,
405
+ output_path,
406
+ hashtable_size,
407
+ num_logits_per_token,
408
+ ngram_size,
409
+ subsample_size,
410
+ )
411
+
412
+ callbacks.started()
413
+
414
+ with open(trace_path, "rb") as trace_fd:
415
+ traces = LalamoCompletion.deserialize_many(trace_fd)
416
+ speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
417
+
418
+ def progress_callback(event: SpeculatorTrainingEvent) -> None:
419
+ callbacks.training_progress(event.trained_tokens)
420
+
421
+ train_speculator(speculator, traces, subsample_size, progress_callback)
422
+
423
+ callbacks.finished_training()
424
+
425
+ callbacks.saving_speculator()
426
+ output_path.parent.mkdir(parents=True, exist_ok=True)
427
+ with open(output_path, "wb") as fd:
428
+ fd.write(speculator.serialize())
429
+ callbacks.finished_saving_speculator()
lalamo/common.py CHANGED
@@ -15,6 +15,8 @@ __all__ = [
15
15
  "ParameterTree",
16
16
  "dummy_array",
17
17
  "flatten_parameters",
18
+ "require_array",
19
+ "require_tree",
18
20
  "unflatten_parameters",
19
21
  ]
20
22
 
@@ -29,6 +31,16 @@ type ParameterTree[ArrayType: ArrayLike] = (
29
31
  )
30
32
 
31
33
 
34
+ def require_array[ArrayType: ArrayLike](value: ArrayType | ParameterTree[ArrayType]) -> ArrayType:
35
+ assert not isinstance(value, (Mapping, Sequence))
36
+ return value
37
+
38
+
39
+ def require_tree[ArrayType: ArrayLike](value: ArrayType | ParameterTree[ArrayType]) -> ParameterTree[ArrayType]:
40
+ assert not isinstance(value, (Array, ShapeDtypeStruct))
41
+ return value
42
+
43
+
32
44
  def dummy_array(shape: int | tuple[int, ...], dtype: DTypeLike) -> Array:
33
45
  if isinstance(shape, int):
34
46
  shape = (shape,)
@@ -40,9 +52,10 @@ def flatten_parameters[ArrayType: ArrayLike](nested_parameters: ParameterTree[Ar
40
52
  if not isinstance(nested_parameters, Mapping):
41
53
  nested_parameters = {str(i): value for i, value in enumerate(nested_parameters)}
42
54
  for key, value in nested_parameters.items():
55
+ value = cast("ArrayType | ParameterTree[ArrayType]", value)
43
56
  key_path = ParameterPath(key)
44
57
  if isinstance(value, (Array, ShapeDtypeStruct)):
45
- result[key_path] = value
58
+ result[key_path] = cast("ArrayType", value)
46
59
  else:
47
60
  update: dict[str, ArrayType] = {
48
61
  str(key_path / subkey): subvalue for subkey, subvalue in flatten_parameters(value).items()