lalamo 0.5.16__py3-none-any.whl → 0.5.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py CHANGED
@@ -1,4 +1,14 @@
- from lalamo.main import collect_traces, convert, estimate_batchsize, train
+ from lalamo.commands import (
+     CollectTracesCallbacks,
+     ConversionCallbacks,
+     EstimateBatchsizeCallbacks,
+     Precision,
+     TrainCallbacks,
+     collect_traces,
+     convert,
+     estimate_batchsize,
+     train,
+ )
  from lalamo.message_processor import (
      AssistantMessage,
      ContentBlock,
@@ -9,27 +19,41 @@ from lalamo.message_processor import (
      UserMessage,
  )
  from lalamo.model_import import ModelSpec, import_model
+ from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec, ModelType, UseCase, WeightsType
  from lalamo.models import ClassifierModel, LanguageModel
+ from lalamo.quantization import QuantizationMode
  from lalamo.speculator import (
      CollectTracesEvent,
      SpeculatorTrainingEvent,
  )

- __version__ = "0.5.16"
+ __version__ = "0.5.17"

  __all__ = [
      "AssistantMessage",
      "ClassifierModel",
+     "CollectTracesCallbacks",
      "CollectTracesEvent",
+     "ConfigMap",
      "ContentBlock",
+     "ConversionCallbacks",
+     "EstimateBatchsizeCallbacks",
+     "FileSpec",
      "Image",
+     "JSONFieldSpec",
      "LanguageModel",
      "Message",
      "ModelSpec",
+     "ModelType",
+     "Precision",
+     "QuantizationMode",
      "SpeculatorTrainingEvent",
      "SystemMessage",
      "ToolSchema",
+     "TrainCallbacks",
+     "UseCase",
      "UserMessage",
+     "WeightsType",
      "collect_traces",
      "convert",
      "estimate_batchsize",
lalamo/commands.py ADDED
@@ -0,0 +1,377 @@
+ import json
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from enum import Enum
+ from itertools import chain
+ from pathlib import Path
+
+ from jaxtyping import DTypeLike
+
+ from lalamo.common import flatten_parameters
+ from lalamo.data import import_hf_parquet
+ from lalamo.data.lalamo_completions import LalamoCompletion
+ from lalamo.message_processor import UserMessage
+ from lalamo.model_import import ModelMetadata, ModelSpec, import_model
+ from lalamo.model_import.common import (
+     DownloadingFileEvent,
+     FileSpec,
+     FinishedDownloadingFileEvent,
+     FinishedInitializingModelEvent,
+     InitializingModelEvent,
+     StatusEvent,
+ )
+ from lalamo.models import LanguageModelConfig
+ from lalamo.modules import config_converter
+ from lalamo.safetensors import safe_write
+ from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
+ from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+ from lalamo.speculator.ngram import NGramSpeculator
+ from lalamo.speculator.utils import SpeculatorTrainingEvent, train_speculator
+
+
+ class Precision(Enum):
+     FLOAT32 = "float32"
+     FLOAT16 = "float16"
+     BFLOAT16 = "bfloat16"
+
+
+ @dataclass
+ class ConversionCallbacks:
+     model_spec: ModelSpec
+     output_dir: Path
+     precision: Precision | None
+     context_length: int | None
+     include_traces: bool
+     message_for_trace: str | None
+
+     def started(self) -> None:
+         pass
+
+     def output_dir_exists(self) -> None:
+         raise RuntimeError(f"{self.output_dir=} already exists, refusing to overwrite!")
+
+     def downloading(self, file_spec: FileSpec) -> None:
+         pass
+
+     def finished_downloading(self, file_spec: FileSpec) -> None:
+         pass
+
+     def initializing_model(self) -> None:
+         pass
+
+     def finished_initializing_model(self) -> None:
+         pass
+
+     def saving_model(self) -> None:
+         pass
+
+     def finished_saving_model(self) -> None:
+         pass
+
+
+ def convert(
+     model_spec: ModelSpec,
+     output_dir: Path,
+     precision: Precision | None = None,
+     context_length: int | None = None,
+     include_traces: bool = False,
+     message_for_trace: str | None = None,
+     callbacks_type: Callable[
+         [
+             ModelSpec,
+             Path,
+             Precision | None,
+             int | None,
+             bool,
+             str | None,
+         ],
+         ConversionCallbacks,
+     ] = ConversionCallbacks,
+ ) -> None:
+     callbacks = callbacks_type(
+         model_spec,
+         output_dir,
+         precision,
+         context_length,
+         include_traces,
+         message_for_trace,
+     )
+
+     if precision is not None:
+         precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
+     else:
+         precision_dtype = None
+
+     if output_dir.exists():
+         callbacks.output_dir_exists()
+
+     callbacks.started()
+
+     def progress_callback(event: StatusEvent) -> None:
+         match event:
+             case DownloadingFileEvent(file_spec):
+                 callbacks.downloading(file_spec)
+             case FinishedDownloadingFileEvent(file_spec):
+                 callbacks.finished_downloading(file_spec)
+             case InitializingModelEvent():
+                 callbacks.initializing_model()
+             case FinishedInitializingModelEvent():
+                 callbacks.finished_initializing_model()
+
+     model, metadata = import_model(
+         model_spec,
+         precision=precision_dtype,
+         context_length=context_length,
+         progress_callback=progress_callback,
+     )
+     callbacks.saving_model()
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     if include_traces:
+         message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
+         result = model.record_trace(message)
+         traces = flatten_parameters(result.export())
+         with Path(output_dir / "traces.safetensors").open("wb") as fd:
+             safe_write(fd, traces)
+
+     model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
+     weights = flatten_parameters(model.export_weights())
+     del model
+
+     with Path(output_dir / "model.safetensors").open("wb") as fd:
+         safe_write(fd, weights)
+
+     config_json = config_converter.unstructure(metadata, ModelMetadata)
+     with open(output_dir / "config.json", "w") as file:
+         json.dump(config_json, file, indent=4)
+
+     callbacks.finished_saving_model()
+
+
+ @dataclass
+ class EstimateBatchsizeCallbacks:
+     model_path: Path
+     max_input_length: int
+     max_output_length: int
+     num_logits_per_token: int
+     mem: int
+
+     def loading_model(self) -> None:
+         pass
+
+     def finished_loading_model(self) -> None:
+         pass
+
+     def estimating_batchsize(self, lo: int, hi: int | None) -> None:
+         pass
+
+     def finished_estimating_batchsize(self, batchsize: int) -> None:
+         pass
+
+
+ def estimate_batchsize(
+     model_path: Path,
+     mem: int,
+     max_input_length: int = 1024,
+     max_output_length: int = 1024,
+     num_logits_per_token: int = 8,
+     callbacks_type: Callable[
+         [
+             Path,
+             int,
+             int,
+             int,
+             int,
+         ],
+         EstimateBatchsizeCallbacks,
+     ] = EstimateBatchsizeCallbacks,
+ ) -> int:
+     callbacks = callbacks_type(model_path, max_input_length, max_output_length, num_logits_per_token, mem)
+
+     callbacks.loading_model()
+     model = LanguageModelConfig.load_model(model_path)
+     callbacks.finished_loading_model()
+
+     def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
+         callbacks.estimating_batchsize(event.lo, event.hi)
+
+     bs = estimate_batchsize_from_memory(
+         model,
+         max_input_length,
+         max_output_length,
+         num_logits_per_token,
+         mem,
+         progress_callback,
+     )
+
+     callbacks.finished_estimating_batchsize(bs)
+     return bs
+
+
+ @dataclass
+ class CollectTracesCallbacks:
+     model_path: Path
+     dataset_path: Path
+     output_path: Path
+     num_logits_per_token: int
+     max_input_length: int
+     max_output_length: int
+     batch_size: int
+     num_tokens_to_generate: int | None
+
+     def loading_model(self) -> None:
+         pass
+
+     def finished_loading_model(self) -> None:
+         pass
+
+     def loading_dataset(self) -> None:
+         pass
+
+     def finished_loading_dataset(self) -> None:
+         pass
+
+     def inference_progress(self, tokens_generated: int) -> None:
+         pass
+
+     def finished_inference(self) -> None:
+         pass
+
+
+ def collect_traces(
+     model_path: Path,
+     dataset_path: Path,
+     output_path: Path,
+     num_logits_per_token: int = 8,
+     max_input_length: int = 1024,
+     max_output_length: int = 1024,
+     batch_size: int = 1,
+     num_tokens_to_generate: int | None = None,
+     callbacks_type: Callable[
+         [
+             Path,
+             Path,
+             Path,
+             int,
+             int,
+             int,
+             int,
+             int | None,
+         ],
+         CollectTracesCallbacks,
+     ] = CollectTracesCallbacks,
+ ) -> None:
+     callbacks = callbacks_type(
+         model_path,
+         dataset_path,
+         output_path,
+         num_logits_per_token,
+         max_input_length,
+         max_output_length,
+         batch_size,
+         num_tokens_to_generate,
+     )
+
+     callbacks.loading_model()
+     model = LanguageModelConfig.load_model(model_path)
+     callbacks.finished_loading_model()
+
+     callbacks.loading_dataset()
+     dataset = iter(import_hf_parquet(dataset_path))
+     dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
+     callbacks.finished_loading_dataset()
+
+     def progress_callback(event: CollectTracesEvent) -> None:
+         callbacks.inference_progress(event.tokens_generated)
+
+     traces = inference_collect_traces(
+         model,
+         dataset,
+         num_logits_per_token,
+         batch_size,
+         max_input_length,
+         max_output_length,
+         num_tokens_to_generate,
+         progress_callback,
+     )
+
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(output_path, "wb") as output_fd:
+         for trace in traces:
+             blob = trace.serialize()
+             output_fd.write(blob)
+
+     callbacks.finished_inference()
+
+
+ @dataclass
+ class TrainCallbacks:
+     trace_path: Path
+     output_path: Path
+     hashtable_size: int
+     num_logits_per_token: int
+     ngram_size: int
+     subsample_size: int | None
+
+     def started(self) -> None:
+         pass
+
+     def training_progress(self, trained_tokens: int) -> None:
+         pass
+
+     def finished_training(self) -> None:
+         pass
+
+     def saving_speculator(self) -> None:
+         pass
+
+     def finished_saving_speculator(self) -> None:
+         pass
+
+
+ def train(
+     trace_path: Path,
+     output_path: Path,
+     hashtable_size: int = 65536,
+     num_logits_per_token: int = 8,
+     ngram_size: int = 2,
+     subsample_size: int | None = None,
+     callbacks_type: Callable[
+         [
+             Path,
+             Path,
+             int,
+             int,
+             int,
+             int | None,
+         ],
+         TrainCallbacks,
+     ] = TrainCallbacks,
+ ) -> None:
+     callbacks = callbacks_type(
+         trace_path,
+         output_path,
+         hashtable_size,
+         num_logits_per_token,
+         ngram_size,
+         subsample_size,
+     )
+
+     callbacks.started()
+
+     with open(trace_path, "rb") as trace_fd:
+         traces = LalamoCompletion.deserialize_many(trace_fd)
+     speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
+
+     def progress_callback(event: SpeculatorTrainingEvent) -> None:
+         callbacks.training_progress(event.trained_tokens)
+
+     train_speculator(speculator, traces, subsample_size, progress_callback)
+
+     callbacks.finished_training()
+
+     callbacks.saving_speculator()
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+     with open(output_path, "wb") as fd:
+         fd.write(speculator.serialize())
+     callbacks.finished_saving_speculator()
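
The new module splits each CLI command into a headless function plus a callbacks dataclass whose hooks default to no-ops, so the same logic can be driven without Rich or Typer. A minimal sketch of a custom callbacks subclass, assuming only the definitions above (the registry key is again a hypothetical placeholder):

    from pathlib import Path

    from lalamo.commands import ConversionCallbacks, convert
    from lalamo.model_import import REPO_TO_MODEL
    from lalamo.model_import.common import FileSpec


    class LoggingConversionCallbacks(ConversionCallbacks):
        # The dataclass fields (model_spec, output_dir, ...) are inherited;
        # only the hooks we care about are overridden.
        def downloading(self, file_spec: FileSpec) -> None:
            print(f"downloading {file_spec.filename}")

        def saving_model(self) -> None:
            print(f"saving to {self.output_dir}")


    convert(
        REPO_TO_MODEL["Qwen/Qwen2.5-0.5B-Instruct"],  # hypothetical key
        Path("models/out"),
        callbacks_type=LoggingConversionCallbacks,
    )

Note that the default output_dir_exists raises RuntimeError, so the library path refuses to overwrite an existing directory unless a subclass (like the CLI's CliConversionCallbacks below) decides otherwise.
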
lalamo/main.py CHANGED
@@ -1,20 +1,19 @@
- import json
  import random
  import re
  import shutil
  import sys
- from enum import Enum
- from itertools import chain, islice
+ from contextlib import ExitStack
+ from dataclasses import dataclass, field
+ from functools import partial
+ from itertools import islice
  from pathlib import Path
  from typing import Annotated

- import jax
  import jax.profiler
  import thefuzz.process
  from click import Context as ClickContext
  from click import Parameter as ClickParameter
  from click import ParamType
- from jaxtyping import DTypeLike
  from rich import box
  from rich.console import Console
  from rich.live import Live
@@ -23,48 +22,40 @@ from rich.progress import (
      MofNCompleteColumn,
      Progress,
      SpinnerColumn,
+     TaskID,
      TextColumn,
      TimeElapsedColumn,
      TimeRemainingColumn,
  )
+ from rich.prompt import Confirm
  from rich.table import Table
- from safetensors.flax import save_file
  from typer import Argument, Context, Exit, Option, Typer

- from lalamo.common import flatten_parameters
- from lalamo.data import import_hf_parquet
+ from lalamo.commands import (
+     CollectTracesCallbacks,
+     ConversionCallbacks,
+     EstimateBatchsizeCallbacks,
+     Precision,
+     TrainCallbacks,
+ )
+ from lalamo.commands import collect_traces as _collect_traces
+ from lalamo.commands import convert as _convert
+ from lalamo.commands import estimate_batchsize as _estimate_batchsize
+ from lalamo.commands import train as _train
  from lalamo.data.lalamo_completions import LalamoCompletion
  from lalamo.message_processor import UserMessage
- from lalamo.model_import import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
- from lalamo.model_import.common import (
-     DownloadingFileEvent,
-     FinishedDownloadingFileEvent,
-     FinishedInitializingModelEvent,
-     InitializingModelEvent,
-     StatusEvent,
- )
+ from lalamo.model_import import REPO_TO_MODEL, ModelSpec
+ from lalamo.model_import.common import FileSpec
  from lalamo.models import ClassifierModelConfig, LanguageModelConfig
- from lalamo.modules import config_converter
- from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
- from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
+ from lalamo.speculator.estimator import get_default_device_memory
  from lalamo.speculator.ngram import NGramSpeculator
- from lalamo.speculator.utils import (
-     SpeculatorTrainingEvent,
-     test_speculator,
-     train_speculator,
- )
+ from lalamo.speculator.utils import test_speculator

  SCRIPT_NAME = Path(sys.argv[0]).name

  DEFAULT_OUTPUT_DIR = Path("models")


- class Precision(Enum):
-     FLOAT32 = "float32"
-     FLOAT16 = "float16"
-     BFLOAT16 = "bfloat16"
-
-
  console = Console()
  err_console = Console(stderr=True)
  app = Typer(
@@ -182,6 +173,68 @@ def classify(
      console.print()


+ @dataclass
+ class CliConversionCallbacks(ConversionCallbacks):
+     overwrite: bool = False
+
+     stack: ExitStack = field(default_factory=ExitStack)
+     downloading_tasks: dict[FileSpec, TaskID] = field(default_factory=dict)
+     initializing_task: TaskID | None = None
+     saving_task: TaskID | None = None
+
+     def started(self) -> None:
+         conversion_strs = [
+             f"🚀 Converting [cyan]{self.model_spec.name}[/cyan] by [cyan]{self.model_spec.vendor}[/cyan]",
+         ]
+         if self.precision is not None:
+             conversion_strs.append(
+                 f" and converting floating-point weights into [cyan]{self.precision.name.lower()}[/cyan] precision",
+             )
+         conversion_strs.append(".")
+         console.print("".join(conversion_strs))
+
+         self.progress = self.stack.enter_context(
+             Progress(
+                 SpinnerColumn(),
+                 TextColumn("[progress.description]{task.description}"),
+                 transient=True,
+             ),
+         )
+
+     def output_dir_exists(self) -> None:
+         if not self.overwrite and not Confirm().ask(
+             rf"⚠️ Output directory [cyan]{self.output_dir}[/cyan] already exists."
+             r" Do you want to overwrite it?",
+         ):
+             raise Exit
+
+         shutil.rmtree(self.output_dir)
+
+     def downloading(self, file_spec: FileSpec) -> None:
+         self.downloading_tasks[file_spec] = self.progress.add_task(f"Retrieving {file_spec.filename}...")
+
+     def finished_downloading(self, file_spec: FileSpec) -> None:
+         self.progress.remove_task(self.downloading_tasks[file_spec])
+
+     def initializing_model(self) -> None:
+         self.initializing_task = self.progress.add_task("Initializing model...")
+
+     def finished_initializing_model(self) -> None:
+         assert self.initializing_task is not None
+
+         self.progress.remove_task(self.initializing_task)
+
+     def saving_model(self) -> None:
+         self.saving_task = self.progress.add_task(f"💾 Saving the model to {self.output_dir}")
+
+     def finished_saving_model(self) -> None:
+         assert self.saving_task is not None
+
+         self.progress.remove_task(self.saving_task)
+         self.stack.close()
+         console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{self.output_dir}`[/cyan]!")
+
+
  @app.command(help="Convert the model for use with the Uzu inference engine.")
  def convert(
      model_repo: Annotated[
@@ -238,85 +291,18 @@ def convert(
          ),
      ] = None,
  ) -> None:
-     if precision is not None:
-         precision_dtype = config_converter.structure(precision.value, DTypeLike)  # type: ignore
-     else:
-         precision_dtype = None
-
      if output_dir is None:
          output_dir = DEFAULT_OUTPUT_DIR / model_repo.name

-     conversion_strs = [f"🚀 Converting [cyan]{model_repo.name}[/cyan] by [cyan]{model_repo.vendor}[/cyan]"]
-     if precision is not None:
-         conversion_strs.append(
-             f" and converting floating-point weights into [cyan]{precision.name.lower()}[/cyan] precision",
-         )
-     conversion_strs.append(".")
-     console.print("".join(conversion_strs))
-
-     if output_dir.exists() and not overwrite:
-         answer = console.input(
-             rf"⚠️ Output directory [cyan]{output_dir}[/cyan] already exists."
-             r" Do you want to overwrite it? [cyan]\[y/n][/cyan]: ",
-         )
-         while answer.lower() not in ["y", "n", "yes", "no"]:
-             answer = console.input("Please enter 'y' or 'n': ")
-         if answer.lower() in ["y", "yes"]:
-             shutil.rmtree(output_dir)
-         else:
-             console.print("Exiting...")
-             raise Exit
-
-     message = None if message_for_trace is None else [UserMessage(content=message_for_trace)]
-
-     with Progress(
-         SpinnerColumn(),
-         TextColumn("[progress.description]{task.description}"),
-         transient=True,
-     ) as progress:
-         event_to_task = {}
-
-         def progress_callback(event: StatusEvent) -> None:
-             match event:
-                 case DownloadingFileEvent(file_spec):
-                     event_to_task[event] = progress.add_task(f"Retrieving {file_spec.filename}...")
-                 case FinishedDownloadingFileEvent(file_spec):
-                     progress.remove_task(event_to_task[event])
-                 case InitializingModelEvent():
-                     event_to_task[event] = progress.add_task("Initializing model...")
-                 case FinishedInitializingModelEvent():
-                     progress.remove_task(event_to_task[event])
-
-         main_task = progress.add_task("👨‍🍳 Cooking...")
-         model, metadata = import_model(
-             model_repo,
-             precision=precision_dtype,
-             context_length=context_length,
-             progress_callback=progress_callback,
-         )
-         save_task = progress.add_task(f"💾 Saving the model to {output_dir}")
-         output_dir.mkdir(parents=True, exist_ok=True)
-
-         if include_traces:
-             trace_task = progress.add_task("🚁 Generating traces...")
-             result = model.record_trace(message)
-             traces = flatten_parameters(result.export())
-             save_file(traces, output_dir / "traces.safetensors")
-             progress.remove_task(trace_task)
-         progress.remove_task(main_task)
-
-         model.message_processor.tokenizer.save(str(output_dir / "tokenizer.json"))
-         weights = flatten_parameters(model.export_weights())
-         del model
-
-         save_file(weights, output_dir / "model.safetensors")
-
-         config_json = config_converter.unstructure(metadata, ModelMetadata)
-         with open(output_dir / "config.json", "w") as file:
-             json.dump(config_json, file, indent=4)
-         progress.remove_task(save_task)
-
-     console.print(f"🧑‍🍳 Model successfully cooked and saved to [cyan]`{output_dir}`[/cyan]!")
+     _convert(
+         model_repo,
+         output_dir,
+         precision,
+         context_length,
+         include_traces,
+         message_for_trace,
+         partial(CliConversionCallbacks, overwrite=overwrite),
+     )


  def _model_size_string_to_int(
@@ -385,6 +371,41 @@ speculator_app = Typer()
  app.add_typer(speculator_app, name="speculator", help="Train a speculator for a model.")


+ @dataclass
+ class CliEstimateBatchsizeCallbacks(EstimateBatchsizeCallbacks):
+     stack: ExitStack = field(default_factory=ExitStack)
+     loading_task: TaskID | None = None
+     estimating_task: TaskID | None = None
+
+     def loading_model(self) -> None:
+         self.progress = self.stack.enter_context(
+             Progress(
+                 SpinnerColumn(),
+                 TextColumn("[progress.description]{task.description}"),
+                 transient=True,
+             ),
+         )
+         self.loading_task = self.progress.add_task("[cyan]Loading model...[/cyan]")
+
+     def finished_loading_model(self) -> None:
+         assert self.loading_task is not None
+         self.progress.remove_task(self.loading_task)
+
+     def estimating_batchsize(self, lo: int, hi: int | None) -> None:
+         hi_str = str(hi) if hi is not None else "?"
+         description = f"[cyan]Estimating batch size... ({lo}..{hi_str})[/cyan]"
+         if self.estimating_task is None:
+             self.estimating_task = self.progress.add_task(description)
+         else:
+             self.progress.update(self.estimating_task, description=description)
+
+     def finished_estimating_batchsize(self, batchsize: int) -> None:
+         if self.estimating_task is not None:
+             self.progress.remove_task(self.estimating_task)
+         self.stack.close()
+         console.print(f"Found maximum batch size: [cyan]{batchsize}[/cyan]")
+
+
  @speculator_app.command(help="Estimate maximum batch size at which a model can be run.")
  def estimate_batchsize(
      model_path: Annotated[
@@ -416,44 +437,64 @@ def estimate_batchsize(
  ) -> None:
      if vram_gb is not None:
          mem = vram_gb * 1024 * 1024 * 1024
-     else:
-         memory_stats = jax.local_devices()[0].memory_stats()
-         if memory_stats is None:
-             err_console.print("Cannot get the default device's memory stats, use --vram-gb")
-             raise Exit(1)
-         if "bytes_limit" not in memory_stats:
-             err_console.print("Cannot get the default device's bytes limit, use --vram-gb")
-             raise Exit(1)
-         mem = memory_stats["bytes_limit"]
+     elif (mem := get_default_device_memory()) is None:
+         err_console.print("Cannot get the default device's memory stats, use --vram-gb")
+         raise Exit(1)

-     with Progress(
-         SpinnerColumn(),
-         TextColumn("[progress.description]{task.description}"),
-         transient=True,
-     ) as progress:
-         loading_model_task = progress.add_task("[cyan]Loading model...[/cyan]")
-         model = LanguageModelConfig.load_model(model_path)
-         progress.remove_task(loading_model_task)
-
-         estimating_batchsize_task = progress.add_task("[cyan]Estimating batch size...[/cyan]")
-
-         def progress_callback(event: EstimateBatchsizeFromMemoryEvent) -> None:
-             lo = str(event.lo)
-             hi = str(event.hi) if event.hi is not None else "?"
-             description = f"[cyan]Estimating batch size... ({lo}..{hi})[/cyan]"
-             progress.update(estimating_batchsize_task, description=description)
-
-         bs = estimate_batchsize_from_memory(
-             model,
-             max_input_length,
-             max_output_length,
-             num_logits_per_token,
-             mem,
-             progress_callback,
-         )
-         progress.remove_task(estimating_batchsize_task)
-
-     console.print(f"Found maximum batch size: [cyan]{bs}[/cyan]")
+     callbacks_type = CliEstimateBatchsizeCallbacks
+
+     _estimate_batchsize(model_path, mem, max_input_length, max_output_length, num_logits_per_token, callbacks_type)
+
+
+ @dataclass
+ class CliCollectTracesCallbacks(CollectTracesCallbacks):
+     stack: ExitStack = field(default_factory=ExitStack)
+     live: Live | None = None
+     loading_task: TaskID | None = None
+     inference_task: TaskID | None = None
+
+     def loading_model(self) -> None:
+         self.live = self.stack.enter_context(Live(refresh_per_second=10))
+         self.progress = Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             transient=True,
+         )
+         self.live.update(self.progress, refresh=True)
+         self.loading_task = self.progress.add_task("🧠 [cyan]Loading model...[/cyan]")
+
+     def finished_loading_model(self) -> None:
+         assert self.loading_task is not None
+         self.progress.remove_task(self.loading_task)
+
+     def loading_dataset(self) -> None:
+         self.loading_task = self.progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
+
+     def finished_loading_dataset(self) -> None:
+         assert self.loading_task is not None
+         assert self.live is not None
+         self.progress.remove_task(self.loading_task)
+         self.progress = Progress(
+             SpinnerColumn(),
+             TextColumn("[progress.description]{task.description}"),
+             MofNCompleteColumn(),
+             TimeElapsedColumn(),
+             TimeRemainingColumn(),
+         )
+         self.live.update(self.progress, refresh=True)
+         self.inference_task = self.progress.add_task(
+             "🔮 [cyan]Running inference...[/cyan]",
+             total=self.num_tokens_to_generate,
+         )
+
+     def inference_progress(self, tokens_generated: int) -> None:
+         assert self.inference_task is not None
+         self.progress.update(self.inference_task, completed=tokens_generated)
+
+     def finished_inference(self) -> None:
+         assert self.inference_task is not None
+         self.progress.update(self.inference_task, description="✅ Completed")
+         self.stack.close()


  @speculator_app.command(help="Run model inference and collect traces for speculator training")
@@ -503,55 +544,17 @@ def collect_traces(
          ),
      ] = None,
  ) -> None:
-     with Live(refresh_per_second=10) as live:
-         with Progress(
-             SpinnerColumn(),
-             TextColumn("[progress.description]{task.description}"),
-             transient=True,
-             disable=True,
-         ) as progress:
-             live.update(progress, refresh=True)
-             loading_model_task = progress.add_task("🧠 [cyan]Loading model...[/cyan]")
-             model = LanguageModelConfig.load_model(model_path)
-             progress.remove_task(loading_model_task)
-
-             loading_dataset_task = progress.add_task("🗂️ [cyan]Loading dataset...[/cyan]")
-             dataset = iter(import_hf_parquet(dataset_path))
-             dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
-             progress.remove_task(loading_dataset_task)
-
-         with Progress(
-             SpinnerColumn(),
-             TextColumn("[progress.description]{task.description}"),
-             MofNCompleteColumn(),
-             TimeElapsedColumn(),
-             TimeRemainingColumn(),
-             disable=True,
-         ) as progress:
-             live.update(progress, refresh=True)
-             inference_task = progress.add_task("🔮 [cyan]Running inference...[/cyan]", total=num_tokens_to_generate)
-
-             def progress_callback(event: CollectTracesEvent) -> None:
-                 progress.update(inference_task, completed=event.tokens_generated)
-
-             traces = inference_collect_traces(
-                 model,
-                 dataset,
-                 num_logits_per_token,
-                 batch_size,
-                 max_input_length,
-                 max_output_length,
-                 num_tokens_to_generate,
-                 progress_callback,
-             )
-
-             output_path.parent.mkdir(parents=True, exist_ok=True)
-             with open(output_path, "wb+") as output_fd:
-                 for trace in traces:
-                     blob = trace.serialize()
-                     output_fd.write(blob)
-
-             progress.update(inference_task, description="✅ Completed")
+     _collect_traces(
+         model_path,
+         dataset_path,
+         output_path,
+         num_logits_per_token,
+         max_input_length,
+         max_output_length,
+         batch_size,
+         num_tokens_to_generate,
+         CliCollectTracesCallbacks,
+     )


  @speculator_app.command(help="View model inference traces")
@@ -597,6 +600,43 @@ def view_traces(
      console.print(table)


+ @dataclass
+ class CliTrainCallbacks(TrainCallbacks):
+     stack: ExitStack = field(default_factory=ExitStack)
+     training_task: TaskID | None = None
+
+     def started(self) -> None:
+         self.progress = self.stack.enter_context(
+             Progress(
+                 SpinnerColumn(),
+                 TextColumn("[progress.description]{task.description}"),
+                 MofNCompleteColumn(),
+                 TimeElapsedColumn(),
+                 TimeRemainingColumn(),
+             ),
+         )
+         self.training_task = self.progress.add_task(
+             "🔮 [cyan]Training speculator...[/cyan]",
+             total=self.subsample_size,
+         )
+
+     def training_progress(self, trained_tokens: int) -> None:
+         assert self.training_task is not None
+         self.progress.update(self.training_task, completed=trained_tokens)
+
+     def finished_training(self) -> None:
+         assert self.training_task is not None
+         self.progress.update(self.training_task, description="✅ Completed")
+         self.progress.remove_task(self.training_task)
+         self.stack.close()
+
+     def saving_speculator(self) -> None:
+         pass
+
+     def finished_saving_speculator(self) -> None:
+         console.print(f"💾 Speculator saved to [cyan]{self.output_path}[/cyan]")
+
+
  @speculator_app.command(help="Train a speculator from inference traces")
  def train(
      trace_path: Annotated[
@@ -633,30 +673,15 @@ def train(
          ),
      ] = None,
  ) -> None:
-     with open(trace_path, "rb") as trace_fd:
-         traces = LalamoCompletion.deserialize_many(trace_fd)
-
-     speculator = NGramSpeculator.new(hashtable_size, num_logits_per_token, ngram_size)
-
-     with Progress(
-         SpinnerColumn(),
-         TextColumn("[progress.description]{task.description}"),
-         MofNCompleteColumn(),
-         TimeElapsedColumn(),
-         TimeRemainingColumn(),
-     ) as progress:
-         inference_task = progress.add_task("🔮 [cyan]Training speculator...[/cyan]", total=subsample_size)
-
-         def progress_callback(event: SpeculatorTrainingEvent) -> None:
-             progress.update(inference_task, completed=event.trained_tokens)
-
-         train_speculator(speculator, traces, subsample_size, progress_callback)
-
-         progress.update(inference_task, description="✅ Completed")
-
-     output_path.parent.mkdir(parents=True, exist_ok=True)
-     with open(output_path, "wb+") as fd:
-         fd.write(speculator.serialize())
+     _train(
+         trace_path,
+         output_path,
+         hashtable_size,
+         num_logits_per_token,
+         ngram_size,
+         subsample_size,
+         CliTrainCallbacks,
+     )


  @speculator_app.command(help="Run speculator as an autoregressive llm")
lalamo/model_import/model_specs/common.py CHANGED
@@ -15,7 +15,8 @@ from jaxtyping import Array, DTypeLike

  from lalamo.model_import.decoder_configs import ForeignConfig
  from lalamo.quantization import QuantizationMode
- from lalamo.utils import MapDictValues, open_safetensors
+ from lalamo.safetensors import safe_read
+ from lalamo.utils import MapDictValues

  __all__ = [
      "ConfigMap",
@@ -52,7 +53,8 @@ class WeightsType(Enum):
          float_dtype: DTypeLike,
      ) -> Iterator[tuple[Mapping[str, jnp.ndarray], Mapping[str, str]]]:
          if self == WeightsType.SAFETENSORS:
-             with open_safetensors(filename) as (weights_dict, metadata_dict):
+             with Path(filename).open("rb") as fd:
+                 (metadata_dict, weights_dict) = safe_read(fd)
                  yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
          else:
              import torch
lalamo/models/common.py CHANGED
@@ -15,7 +15,7 @@ from lalamo.message_processor import Message, MessageProcessor, MessageProcessor
  from lalamo.modules import Classifier, Decoder, LalamoModule, config_converter
  from lalamo.modules.classifier import ClassifierConfig, ClassifierResult
  from lalamo.modules.decoder import DecoderConfig, DecoderResult
- from lalamo.utils import open_safetensors
+ from lalamo.safetensors import safe_read

  __all__ = [
      "TextModel",
@@ -42,8 +42,8 @@ class TextModelConfig[ConfigT: ClassifierConfig | DecoderConfig](ABC):
          with open(path / "config.json") as config_file:
              config_json = json.load(config_file)
          config = config_converter.structure(config_json["model_config"], cls)
-         with open_safetensors(path / "model.safetensors") as open_results:
-             weights_dict, _ = open_results
+         with Path(path / "model.safetensors").open("rb") as fd:
+             _, weights_dict = safe_read(fd)
          weights = unflatten_parameters(weights_dict)
          model = config.model_config.empty().import_weights(weights)
          tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
lalamo/safetensors.py ADDED
@@ -0,0 +1,97 @@
+ import json
+ import struct
+ from collections.abc import Mapping
+ from dataclasses import dataclass
+ from io import BufferedReader, BufferedWriter
+ from typing import Any, ClassVar, Self
+
+ import cattrs
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from jaxtyping import Array
+
+ from lalamo.utils import LazyDict
+
+ SF2J = {
+     "BOOL": jnp.dtype(jnp.bool_),
+     "U8": jnp.dtype(jnp.uint8),
+     "I8": jnp.dtype(jnp.int8),
+     "I16": jnp.dtype(jnp.int16),
+     "U16": jnp.dtype(jnp.uint16),
+     "F16": jnp.dtype(jnp.float16),
+     "BF16": jnp.dtype(jnp.bfloat16),
+     "I32": jnp.dtype(jnp.int32),
+     "U32": jnp.dtype(jnp.uint32),
+     "F32": jnp.dtype(jnp.float32),
+     "C64": jnp.dtype(jnp.complex64),
+     "F64": jnp.dtype(jnp.float64),
+     "I64": jnp.dtype(jnp.int64),
+     "U64": jnp.dtype(jnp.uint64),
+ }
+
+ J2SF = {v: k for k, v in SF2J.items()}
+
+
+ @dataclass(frozen=True)
+ class SFTensorInfo:
+     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+     _converter.register_structure_hook(jnp.dtype, lambda x, _: SF2J[x])
+     _converter.register_unstructure_hook(jnp.dtype, lambda x: J2SF[x])
+
+     dtype: jnp.dtype
+     shape: tuple[int, ...]
+     data_offsets: tuple[int, int]
+
+     @property
+     def start(self) -> int:
+         return self.data_offsets[0]
+
+     @property
+     def end(self) -> int:
+         return self.data_offsets[1]
+
+     @property
+     def size(self) -> int:
+         return self.end - self.start
+
+     @classmethod
+     def from_dict(cls, obj: dict) -> Self:
+         return cls._converter.structure(obj, cls)
+
+     def to_dict(self) -> dict:
+         return self._converter.unstructure(self)
+
+
+ def safe_read(fd: BufferedReader) -> tuple[dict[str, str] | None, LazyDict[str, Array]]:
+     header_size = struct.unpack("<Q", fd.read(8))[0]
+     header: dict[str, dict[str, Any]] = json.loads(fd.read(header_size))
+     metadata: dict[str, str] | None = header.pop("__metadata__", None)
+     data_offset = fd.tell()
+
+     def _load_tensor(key: str) -> Array:
+         info = SFTensorInfo.from_dict(header[key])
+         fd.seek(data_offset + info.start)
+         return jnp.asarray(np.fromfile(fd, info.dtype, info.size // info.dtype.itemsize)).reshape(info.shape)
+
+     lazy_tensors = LazyDict(set(header.keys()), _load_tensor)
+     return (metadata, lazy_tensors)
+
+
+ def safe_write(fd: BufferedWriter, tensors: Mapping[str, Array]) -> None:
+     sorted_tensors = dict(sorted(tensors.items(), key=lambda x: (-x[1].dtype.alignment, x[0])))
+
+     header_dict = {}
+     offset = 0
+     for key, tensor in sorted_tensors.items():
+         assert offset % tensor.dtype.alignment == 0
+         header_dict[key] = SFTensorInfo(tensor.dtype, tensor.shape, (offset, offset + tensor.nbytes)).to_dict()
+         offset += tensor.nbytes
+
+     data_alignment = max(8, next((t.dtype.alignment for t in sorted_tensors.values()), 1))
+     header = json.dumps(header_dict).encode()
+     header += b" " * (-len(header) % data_alignment)
+     fd.write(struct.pack("<Q", len(header)) + header)
+
+     for tensor in sorted_tensors.values():
+         jax.device_get(tensor).tofile(fd)
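
safe_read returns the optional __metadata__ block plus a LazyDict whose values seek into the still-open file on first access, so tensors must be materialized before the descriptor is closed. A small round-trip sketch using only the two functions above:

    from pathlib import Path

    import jax.numpy as jnp

    from lalamo.safetensors import safe_read, safe_write

    path = Path("/tmp/demo.safetensors")

    with path.open("wb") as fd:
        safe_write(fd, {"w": jnp.ones((2, 3), dtype=jnp.bfloat16)})

    with path.open("rb") as fd:
        metadata, tensors = safe_read(fd)
        w = tensors["w"]  # lazy: must be read while fd is still open

    print(metadata, w.shape)  # metadata is None; safe_write emits no __metadata__ block

Sorting tensors by descending dtype alignment before writing is what lets safe_write assert that every data offset stays aligned without inserting padding between tensors.
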
lalamo/speculator/estimator.py CHANGED
@@ -9,6 +9,13 @@ import jax.numpy as jnp
  from lalamo.models import LanguageModel


+ def get_default_device_memory() -> int | None:
+     memory_stats = jax.local_devices()[0].memory_stats()
+     if memory_stats is None or "bytes_limit" not in memory_stats:
+         return None
+     return memory_stats["bytes_limit"]
+
+
  def estimate_memory_from_batchsize(
      model: LanguageModel,
      max_input_length: int,
@@ -23,7 +30,7 @@ def estimate_memory_from_batchsize(
              max_output_length=max_output_length,
              num_top_logits_to_return=num_logits_per_token,
          ),
-         backend="cpu", # cuda backend tries to allocate in .compile() and ooms
+         backend="cpu",  # cuda backend tries to allocate in .compile() and ooms
      )
      .lower(
          model,
@@ -41,7 +48,7 @@ def estimate_memory_from_batchsize(
      return (
          memory_analysis.argument_size_in_bytes  # type: ignore (pyright bug)
          + memory_analysis.output_size_in_bytes  # type: ignore (pyright bug)
-         + memory_analysis.temp_size_in_bytes # type: ignore (pyright bug)
+         + memory_analysis.temp_size_in_bytes  # type: ignore (pyright bug)
      )
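The device-memory probe that previously lived inline in the CLI is now a reusable helper, so the --vram-gb fallback reduces to a couple of lines. A sketch of the same pattern outside Typer, with vram_gb standing in for the parsed CLI option:

    from lalamo.speculator.estimator import get_default_device_memory

    vram_gb: int | None = None  # e.g. parsed from --vram-gb

    if vram_gb is not None:
        mem = vram_gb * 1024 * 1024 * 1024
    elif (mem := get_default_device_memory()) is None:
        raise SystemExit("Cannot get the default device's memory stats, use --vram-gb")

    print(f"memory budget: {mem} bytes")
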
lalamo/utils.py CHANGED
@@ -9,21 +9,17 @@ from collections.abc import (
      Sequence,
      ValuesView,
  )
- from contextlib import contextmanager
  from dataclasses import dataclass
- from pathlib import Path
  from typing import overload

  import einops
  import jax.numpy as jnp
  from jaxtyping import Array
- from safetensors import safe_open

  __all__ = [
      "MapDictValues",
      "MapSequence",
      "jax_uint4_to_packed_uint8",
-     "open_safetensors",
      "process_chat_template",
  ]

@@ -45,15 +41,6 @@ class LazyDict[K, V](Mapping[K, V]):
          return len(self.stored_keys)


- @contextmanager
- def open_safetensors(filename: Path | str) -> Iterator[tuple[Mapping[str, Array], Mapping[str, str]]]:
-     with safe_open(filename, framework="flax") as safetensors_nonsense:
-         yield (
-             LazyDict(set(safetensors_nonsense.keys()), safetensors_nonsense.get_tensor),
-             safetensors_nonsense.metadata(),
-         )
-
-
  @dataclass(frozen=True)
  class MapIterable[OldT, NewT](Iterable[NewT]):
      map_func: Callable[[OldT], NewT]
{lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lalamo
- Version: 0.5.16
+ Version: 0.5.17
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
  Requires-Python: <4,>=3.12
  Description-Content-Type: text/markdown
@@ -19,7 +19,6 @@ Requires-Dist: rich>=14.0.0
  Requires-Dist: thefuzz>=0.22.1
  Requires-Dist: tokenizers>=0.21.2
  Requires-Dist: typer>=0.15.1
- Requires-Dist: safetensors>=0.6.2
  Requires-Dist: polars>=1.33.1
  Requires-Dist: xxhash>=3.5.0
  Provides-Extra: cpu
{lalamo-0.5.16.dist-info → lalamo-0.5.17.dist-info}/RECORD RENAMED
@@ -1,11 +1,13 @@
- lalamo/__init__.py,sha256=FjfGsBVSl14mNsDoFJEwXMRUq1-Kg_lessRzlJNG3KM,815
+ lalamo/__init__.py,sha256=asVMPmQ7BUt7bYlcuNZ7SnOSJDJUiN9QhlU5lRUehSo,1387
+ lalamo/commands.py,sha256=rU9T8Mx6s7itpk-dj5ToQ4PUpGPfdmmKlrF02l2kIS0,9967
  lalamo/common.py,sha256=5NUFD26yQgOnEEk3LaQnce8n-VwJxILkEpFesHZhtQU,3820
- lalamo/main.py,sha256=GgUT7lT48-XQuAEH7qzsDKG8Lx9iBf-sYBIRhZL9q7E,23978
+ lalamo/main.py,sha256=dE7Us9L6sfz9bp5rUSzGHUkG0Uon4xdju9dGGtXidZI,23888
  lalamo/message_processor.py,sha256=bSUAQg7CemLTnBV4LtPxJBicAalruDCA-JXjkTYPZ8U,5797
  lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
  lalamo/registry_abc.py,sha256=ENjXiD_wEH100fNjG-W5Em1L_EQ0Lf0pdRhRGvf3qZk,2197
+ lalamo/safetensors.py,sha256=kUiTSgx2zhfD1hxV_AA1DOLaKAKzjRd_vOYZCFf0em0,3048
  lalamo/sampling.py,sha256=g_dNiJyZrRqoQIiLid4cr6nRT9N5tSz3GtHr8Bt4n-E,3404
- lalamo/utils.py,sha256=QwATVXAeHBsQEDyt_31SHgxFphFVZYHpv3ZaklXks9Y,4585
+ lalamo/utils.py,sha256=c88IP110gHZJ6hYDq7p36A9u-vLRM_YdavFom56gsNQ,4111
  lalamo/data/__init__.py,sha256=exfhBLxHrg7BWutM0tAln5QuIWlNQmOhaG2noFYxfPI,189
  lalamo/data/huggingface_message.py,sha256=-7lN9eIcETQzt1Pnx3d4d8p3_I7WYMNf4mp1P91N7fI,1115
  lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
@@ -35,7 +37,7 @@ lalamo/model_import/loaders/executorch.py,sha256=t2Ey_mBMNC8bTSTdYWjuGXdPTRoohFl
  lalamo/model_import/loaders/huggingface.py,sha256=qWdzoSvHvb_3prn2kwfxgnYPW2bVB0Q49m_wyRYha8Q,34677
  lalamo/model_import/loaders/utils.py,sha256=eiX3WKFRrAfBY-dugodscNInl5o5w3KmVcgma4atpGY,2456
  lalamo/model_import/model_specs/__init__.py,sha256=JISqwJkloQkGD2jvi1MakNEWapIwlNXXVi5giZyXB74,1275
- lalamo/model_import/model_specs/common.py,sha256=RLySCIkmGiA1IVZgLeemssMBMo4hMYMpmBjV0cRwBb4,6586
+ lalamo/model_import/model_specs/common.py,sha256=8ALKxHrt8uK4XiqjK25NwZj1CC7DM7jlYcFVZPGkFrw,6643
  lalamo/model_import/model_specs/deepseek.py,sha256=Umef93_ZBuq93yYsejIRNwj3udoln1gHfrv3SK5jyMo,417
  lalamo/model_import/model_specs/essential_ai.py,sha256=xbHcwRpAWhR9gOgypVzcgunFspoUEk3iNsw-46CVR4o,390
  lalamo/model_import/model_specs/gemma.py,sha256=dwKwOHU1sBJNLFAwtEyydsRUF9QENN3SHtjbfqtOSic,3876
@@ -52,7 +54,7 @@ lalamo/model_import/model_specs/qwen.py,sha256=HvN080ILpOwkqJbRLMqCa8Z8ImlLfTwiE
  lalamo/model_import/model_specs/reka.py,sha256=dOUYbEMMvovQdzQuBO_DCsjGI39syhoKCvnxLkNEDCw,423
  lalamo/models/__init__.py,sha256=Vn5PcvSqKppIchkSZwQVTn_GpRvOOzZVxo5PUeDl6N8,283
  lalamo/models/classifier.py,sha256=LvL54crCVi4HVSIXuoaSLB_5jtcx74GL7kgdy2Y16Zc,2094
- lalamo/models/common.py,sha256=PDteofGxjSBWYw_mPxbN1DTUba70aOURrAIjl13SSHc,2954
+ lalamo/models/common.py,sha256=uU6eCHtIqMeC_aRGVo09NdpAtvQ6RKSbm6pumVvL8pc,2943
  lalamo/models/language_model.py,sha256=QPeVEyhutSze7fSNhvOvwSoYt24QMk-dtTJkos38amY,13465
  lalamo/modules/__init__.py,sha256=OHIQn08jx2c3L2KIQA-7SJ4yVb2E5m6T6FqTHFJTDdM,4006
  lalamo/modules/activations.py,sha256=U3qTQtZawPAUcoqbkIJnmTYcaNiQuSPMLcBeJ398GhI,1022
@@ -81,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
  lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
  lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
  lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
- lalamo/speculator/estimator.py,sha256=4D8dPZCWsrpORb7y8pQ6VsiIg1Cblvvxe6gXCoYtcD4,2530
+ lalamo/speculator/estimator.py,sha256=j-zmhy3RxYDmQ7W0FMTmDk3i275r_Vg1s4NCaS4c_SQ,2760
  lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
  lalamo/speculator/ngram.py,sha256=95mdfAWhx4d5XOnOwhyhElnvcy6nlUjYhcbJzqDs414,5875
  lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
- lalamo-0.5.16.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
- lalamo-0.5.16.dist-info/METADATA,sha256=dcs0vT9RULTxt4cxJJmfjP-4UJi7ZkrifXAaSMAgKeU,3147
- lalamo-0.5.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- lalamo-0.5.16.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
- lalamo-0.5.16.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
- lalamo-0.5.16.dist-info/RECORD,,
+ lalamo-0.5.17.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+ lalamo-0.5.17.dist-info/METADATA,sha256=16-W1J0wiwrmgMTgqiE9r3vxKRmZbGgZ-zS7bNACwTA,3113
+ lalamo-0.5.17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ lalamo-0.5.17.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+ lalamo-0.5.17.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+ lalamo-0.5.17.dist-info/RECORD,,