lalamo 0.5.16__tar.gz → 0.5.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. {lalamo-0.5.16 → lalamo-0.5.17}/PKG-INFO +1 -2
  2. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/__init__.py +26 -2
  3. lalamo-0.5.17/lalamo/commands.py +377 -0
  4. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/main.py +239 -214
  5. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/common.py +4 -2
  6. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/models/common.py +3 -3
  7. lalamo-0.5.17/lalamo/safetensors.py +97 -0
  8. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/speculator/estimator.py +9 -2
  9. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/utils.py +0 -13
  10. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo.egg-info/PKG-INFO +1 -2
  11. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo.egg-info/SOURCES.txt +2 -0
  12. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo.egg-info/requires.txt +0 -1
  13. {lalamo-0.5.16 → lalamo-0.5.17}/pyproject.toml +0 -1
  14. lalamo-0.5.17/tests/test_huggingface_model_conversion.py +109 -0
  15. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_models.py +5 -0
  16. lalamo-0.5.16/tests/test_huggingface_model_conversion.py +0 -101
  17. {lalamo-0.5.16 → lalamo-0.5.17}/LICENSE +0 -0
  18. {lalamo-0.5.16 → lalamo-0.5.17}/README.md +0 -0
  19. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/common.py +0 -0
  20. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/data/__init__.py +0 -0
  21. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/data/huggingface_message.py +0 -0
  22. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/data/lalamo_completions.py +0 -0
  23. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/data/utils.py +0 -0
  24. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/message_processor.py +0 -0
  25. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/__init__.py +0 -0
  26. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/common.py +0 -0
  27. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  28. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/common.py +0 -0
  29. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  30. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  31. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  32. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  33. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
  34. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  35. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
  36. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  37. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  38. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  39. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  40. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  41. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  42. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/huggingface_generation_config.py +0 -0
  43. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  44. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/loaders/__init__.py +0 -0
  45. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/loaders/common.py +0 -0
  46. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/loaders/executorch.py +0 -0
  47. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/loaders/huggingface.py +0 -0
  48. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/loaders/utils.py +0 -0
  49. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/__init__.py +0 -0
  50. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/deepseek.py +0 -0
  51. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  52. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/gemma.py +0 -0
  53. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  54. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/huggingface.py +0 -0
  55. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/lfm2.py +0 -0
  56. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/llama.py +0 -0
  57. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/llamba.py +0 -0
  58. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/mirai.py +0 -0
  59. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/mistral.py +0 -0
  60. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/pleias.py +0 -0
  61. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/polaris.py +0 -0
  62. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/qwen.py +0 -0
  63. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/model_import/model_specs/reka.py +0 -0
  64. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/models/__init__.py +0 -0
  65. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/models/classifier.py +0 -0
  66. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/models/language_model.py +0 -0
  67. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/__init__.py +0 -0
  68. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/activations.py +0 -0
  69. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/classifier.py +0 -0
  70. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/common.py +0 -0
  71. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/decoder.py +0 -0
  72. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/embedding.py +0 -0
  73. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/linear.py +0 -0
  74. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/mlp.py +0 -0
  75. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/mlx_interop.py +0 -0
  76. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/normalization.py +0 -0
  77. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/rope.py +0 -0
  78. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/__init__.py +0 -0
  79. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/attention.py +0 -0
  80. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/common.py +0 -0
  81. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/mamba.py +0 -0
  82. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/short_conv.py +0 -0
  83. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  84. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/state/common.py +0 -0
  85. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  86. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  87. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
  88. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/torch_interop.py +0 -0
  89. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/transformer.py +0 -0
  90. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/transformer_layer.py +0 -0
  91. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/modules/utils.py +0 -0
  92. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/quantization.py +0 -0
  93. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/registry_abc.py +0 -0
  94. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/sampling.py +0 -0
  95. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/speculator/__init__.py +0 -0
  96. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/speculator/common.py +0 -0
  97. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/speculator/inference.py +0 -0
  98. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/speculator/ngram.py +0 -0
  99. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo/speculator/utils.py +0 -0
  100. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo.egg-info/dependency_links.txt +0 -0
  101. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo.egg-info/entry_points.txt +0 -0
  102. {lalamo-0.5.16 → lalamo-0.5.17}/lalamo.egg-info/top_level.txt +0 -0
  103. {lalamo-0.5.16 → lalamo-0.5.17}/setup.cfg +0 -0
  104. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_cartesia_mlx_models.py +0 -0
  105. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_chat_template.py +0 -0
  106. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_generation.py +0 -0
  107. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_huggingface_models.py +0 -0
  108. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_lfm2_models.py +0 -0
  109. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_mlx_models.py +0 -0
  110. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_model_spec.py +0 -0
  111. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_moe.py +0 -0
  112. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_parameter_tree.py +0 -0
  113. {lalamo-0.5.16 → lalamo-0.5.17}/tests/test_registry_abc.py +0 -0
--- lalamo-0.5.16/PKG-INFO
+++ lalamo-0.5.17/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.5.16
+Version: 0.5.17
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
@@ -19,7 +19,6 @@ Requires-Dist: rich>=14.0.0
 Requires-Dist: thefuzz>=0.22.1
 Requires-Dist: tokenizers>=0.21.2
 Requires-Dist: typer>=0.15.1
-Requires-Dist: safetensors>=0.6.2
 Requires-Dist: polars>=1.33.1
 Requires-Dist: xxhash>=3.5.0
 Provides-Extra: cpu
--- lalamo-0.5.16/lalamo/__init__.py
+++ lalamo-0.5.17/lalamo/__init__.py
@@ -1,4 +1,14 @@
-from lalamo.main import collect_traces, convert, estimate_batchsize, train
+from lalamo.commands import (
+    CollectTracesCallbacks,
+    ConversionCallbacks,
+    EstimateBatchsizeCallbacks,
+    Precision,
+    TrainCallbacks,
+    collect_traces,
+    convert,
+    estimate_batchsize,
+    train,
+)
 from lalamo.message_processor import (
     AssistantMessage,
     ContentBlock,
@@ -9,27 +19,41 @@ from lalamo.message_processor import (
     UserMessage,
 )
 from lalamo.model_import import ModelSpec, import_model
+from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec, ModelType, UseCase, WeightsType
 from lalamo.models import ClassifierModel, LanguageModel
+from lalamo.quantization import QuantizationMode
 from lalamo.speculator import (
     CollectTracesEvent,
     SpeculatorTrainingEvent,
 )
 
-__version__ = "0.5.16"
+__version__ = "0.5.17"
 
 __all__ = [
     "AssistantMessage",
     "ClassifierModel",
+    "CollectTracesCallbacks",
     "CollectTracesEvent",
+    "ConfigMap",
     "ContentBlock",
+    "ConversionCallbacks",
+    "EstimateBatchsizeCallbacks",
+    "FileSpec",
     "Image",
+    "JSONFieldSpec",
     "LanguageModel",
     "Message",
     "ModelSpec",
+    "ModelType",
+    "Precision",
+    "QuantizationMode",
     "SpeculatorTrainingEvent",
     "SystemMessage",
     "ToolSchema",
+    "TrainCallbacks",
+    "UseCase",
     "UserMessage",
+    "WeightsType",
     "collect_traces",
     "convert",
     "estimate_batchsize",
--- /dev/null
+++ lalamo-0.5.17/lalamo/commands.py
@@ -0,0 +1,377 @@
+import json
+from collections.abc import Callable
+from dataclasses import dataclass
+from enum import Enum
+from itertools import chain
+from pathlib import Path
+
+from jaxtyping import DTypeLike
+
+from lalamo.common import flatten_parameters
+from lalamo.data import import_hf_parquet
+from lalamo.data.lalamo_completions import LalamoCompletion
+from lalamo.message_processor import UserMessage
+from lalamo.model_import import ModelMetadata, ModelSpec, import_model
+from lalamo.model_import.common import (
+    DownloadingFileEvent,
+    FileSpec,
+    FinishedDownloadingFileEvent,
+    FinishedInitializingModelEvent,
+    InitializingModelEvent,
+    StatusEvent,
+)
+from lalamo.models import LanguageModelConfig
+from lalamo.modules import config_converter
+from lalamo.safetensors import safe_write
+from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
+from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+from lalamo.speculator.ngram import NGramSpeculator
+from lalamo.speculator.utils import SpeculatorTrainingEvent, train_speculator
+
+
+class Precision(Enum):
+    FLOAT32 = "float32"
+    FLOAT16 = "float16"
+    BFLOAT16 = "bfloat16"
+
+
+@dataclass
+class ConversionCallbacks:
+    model_spec: ModelSpec
+    output_dir: Path
+    precision: Precision | None
+    context_length: int | None
+    include_traces: bool
+    message_for_trace: str | None
+
+    def started(self) -> None:
+        pass
+
+    def output_dir_exists(self) -> None:
+        raise RuntimeError(f"{self.output_dir=} already exists, refusing to overwrite!")
+
+    def downloading(self, file_spec: FileSpec) -> None:
+        pass
+
+    def finished_downloading(self, file_spec: FileSpec) -> None:
+        pass
+
+    def initializing_model(self) -> None:
+        pass
+
+    def finished_initializing_model(self) -> None:
+        pass
+
+    def saving_model(self) -> None:
+        pass
+
+    def finished_saving_model(self) -> None:
+        pass
+
+
+def convert(
+    model_spec: ModelSpec,
+    output_dir: Path,
+    precision: Precision | None = None,
+    context_length: int | None = None,
+    include_traces: bool = False,
+    message_for_trace: str | None = None,
+    callbacks_type: Callable[
+        [
+            ModelSpec,
+            Path,
+            Precision | None,
+            int | None,
+            bool,
+            str | None,
+        ],
+        ConversionCallbacks,
+    ] = ConversionCallbacks,
+) -> None:
+    callbacks = callbacks_type(
+        model_spec,
+        output_dir,
+        precision,
+        context_length,
+        include_traces,
+        message_for_trace,
+    )
+
+    if precision is not None:
+        precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
+    else:
+        precision_dtype = None
+
+    if output_dir.exists():
+        callbacks.output_dir_exists()
+
+    callbacks.started()
+
+    def progress_callback(event: StatusEvent) -> None:
+        match event:
+            case DownloadingFileEvent(file_spec):
+                callbacks.downloading(file_spec)
+            case FinishedDownloadingFileEvent(file_spec):
+                callbacks.finished_downloading(file_spec)
+            case InitializingModelEvent():
+                callbacks.initializing_model()
+            case FinishedInitializingModelEvent():
+                callbacks.finished_initializing_model()
+
+    model, metadata = import_model(
+        model_spec,
+        precision=precision_dtype,
+        context_length=context_length,
+        progress_callback=progress_callback,
+    )
+    callbacks.saving_model()
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    if include_traces:
+        message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
+        result = model.record_trace(message)
+        traces = flatten_parameters(result.export())
+        with Path(output_dir / "traces.safetensors").open("wb") as fd:
+            safe_write(fd, traces)
+
+    model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
+    weights = flatten_parameters(model.export_weights())
+    del model
+
+    with Path(output_dir / "model.safetensors").open("wb") as fd:
+        safe_write(fd, weights)
+
+    config_json = config_converter.unstructure(metadata, ModelMetadata)
+    with open(output_dir / "config.json", "w") as file:
+        json.dump(config_json, file, indent=4)
+
+    callbacks.finished_saving_model()
+
+
+@dataclass
+class EstimateBatchsizeCallbacks:
+    model_path: Path
+    max_input_length: int
+    max_output_length: int
+    num_logits_per_token: int
+    mem: int
+
+    def loading_model(self) -> None:
+        pass
+
+    def finished_loading_model(self) -> None:
+        pass
+
+    def estimating_batchsize(self, lo: int, hi: int | None) -> None:
+        pass
+
+    def finished_estimating_batchsize(self, batchsize: int) -> None:
+        pass
+
+
+def estimate_batchsize(
+    model_path: Path,
+    mem: int,
+    max_input_length: int = 1024,
+    max_output_length: int = 1024,
+    num_logits_per_token: int = 8,
+    callbacks_type: Callable[
+        [
+            Path,
+            int,
+            int,
+            int,
+            int,
+        ],
+        EstimateBatchsizeCallbacks,
+    ] = EstimateBatchsizeCallbacks,
+) -> int:
+    callbacks = callbacks_type(model_path, max_input_length, max_output_length, num_logits_per_token, mem)
+
+    callbacks.loading_model()
+    model = LanguageModelConfig.load_model(model_path)
+    callbacks.finished_loading_model()
+
+    def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
+        callbacks.estimating_batchsize(event.lo, event.hi)
+
+    bs = estimate_batchsize_from_memory(
+        model,
+        max_input_length,
+        max_output_length,
+        num_logits_per_token,
+        mem,
+        progress_callback,
+    )
+
+    callbacks.finished_estimating_batchsize(bs)
+    return bs
+
+
+@dataclass
+class CollectTracesCallbacks:
+    model_path: Path
+    dataset_path: Path
+    output_path: Path
+    num_logits_per_token: int
+    max_input_length: int
+    max_output_length: int
+    batch_size: int
+    num_tokens_to_generate: int | None
+
+    def loading_model(self) -> None:
+        pass
+
+    def finished_loading_model(self) -> None:
+        pass
+
+    def loading_dataset(self) -> None:
+        pass
+
+    def finished_loading_dataset(self) -> None:
+        pass
+
+    def inference_progress(self, tokens_generated: int) -> None:
+        pass
+
+    def finished_inference(self) -> None:
+        pass
+
+
+def collect_traces(
+    model_path: Path,
+    dataset_path: Path,
+    output_path: Path,
+    num_logits_per_token: int = 8,
+    max_input_length: int = 1024,
+    max_output_length: int = 1024,
+    batch_size: int = 1,
+    num_tokens_to_generate: int | None = None,
+    callbacks_type: Callable[
+        [
+            Path,
+            Path,
+            Path,
+            int,
+            int,
+            int,
+            int,
+            int | None,
+        ],
+        CollectTracesCallbacks,
+    ] = CollectTracesCallbacks,
+) -> None:
+    callbacks = callbacks_type(
+        model_path,
+        dataset_path,
+        output_path,
+        num_logits_per_token,
+        max_input_length,
+        max_output_length,
+        batch_size,
+        num_tokens_to_generate,
+    )
+
+    callbacks.loading_model()
+    model = LanguageModelConfig.load_model(model_path)
+    callbacks.finished_loading_model()
+
+    callbacks.loading_dataset()
+    dataset = iter(import_hf_parquet(dataset_path))
+    dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
+    callbacks.finished_loading_dataset()
+
+    def progress_callback(event: CollectTracesEvent) -> None:
+        callbacks.inference_progress(event.tokens_generated)
+
+    traces = inference_collect_traces(
+        model,
+        dataset,
+        num_logits_per_token,
+        batch_size,
+        max_input_length,
+        max_output_length,
+        num_tokens_to_generate,
+        progress_callback,
+    )
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "wb") as output_fd:
+        for trace in traces:
+            blob = trace.serialize()
+            output_fd.write(blob)
+
+    callbacks.finished_inference()
+
+
+@dataclass
+class TrainCallbacks:
+    trace_path: Path
+    output_path: Path
+    hashtable_size: int
+    num_logits_per_token: int
+    ngram_size: int
+    subsample_size: int | None
+
+    def started(self) -> None:
+        pass
+
+    def training_progress(self, trained_tokens: int) -> None:
+        pass
+
+    def finished_training(self) -> None:
+        pass
+
+    def saving_speculator(self) -> None:
+        pass
+
+    def finished_saving_speculator(self) -> None:
+        pass
+
+
+def train(
+    trace_path: Path,
+    output_path: Path,
+    hashtable_size: int = 65536,
+    num_logits_per_token: int = 8,
+    ngram_size: int = 2,
+    subsample_size: int | None = None,
+    callbacks_type: Callable[
+        [
+            Path,
+            Path,
+            int,
+            int,
+            int,
+            int | None,
+        ],
+        TrainCallbacks,
+    ] = TrainCallbacks,
+) -> None:
+    callbacks = callbacks_type(
+        trace_path,
+        output_path,
+        hashtable_size,
+        num_logits_per_token,
+        ngram_size,
+        subsample_size,
+    )
+
+    callbacks.started()
+
+    with open(trace_path, "rb") as trace_fd:
+        traces = LalamoCompletion.deserialize_many(trace_fd)
+        speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
+
+        def progress_callback(event: SpeculatorTrainingEvent) -> None:
+            callbacks.training_progress(event.trained_tokens)
+
+        train_speculator(speculator, traces, subsample_size, progress_callback)
+
+    callbacks.finished_training()
+
+    callbacks.saving_speculator()
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "wb") as fd:
+        fd.write(speculator.serialize())
+    callbacks.finished_saving_speculator()
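
The new lalamo/commands.py appears to factor the conversion, batch-size estimation, trace-collection, and speculator-training logic out of lalamo/main.py into plain functions whose progress hooks are supplied via the accompanying Callbacks dataclasses, with callbacks_type defaulting to the base class whose hooks are mostly no-ops. Below is a minimal sketch of driving this API programmatically, based only on the signatures above; the ModelSpec lookup is left as a placeholder, and treating mem as a byte count is an assumption, not something stated in this diff:

    from pathlib import Path

    from lalamo import ConversionCallbacks, Precision, convert, estimate_batchsize


    class LoggingConversionCallbacks(ConversionCallbacks):
        # Override a few of the no-op hooks to report progress.
        def downloading(self, file_spec) -> None:
            print(f"downloading {file_spec}")

        def saving_model(self) -> None:
            print(f"saving converted model to {self.output_dir}")

        def finished_saving_model(self) -> None:
            print("done")


    # `spec` stands in for a concrete ModelSpec, e.g. one of the presets under
    # lalamo.model_import.model_specs; the lookup API is not part of this diff.
    spec = ...
    output_dir = Path("exported-model")

    convert(
        spec,
        output_dir,
        precision=Precision.BFLOAT16,
        context_length=4096,
        callbacks_type=LoggingConversionCallbacks,
    )

    # The exported directory can then be fed to the other commands, e.g. picking
    # a batch size for a given memory budget (assuming `mem` is in bytes).
    batch_size = estimate_batchsize(output_dir, mem=8 * 1024**3)
    print(batch_size)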