lalamo 0.6.5__tar.gz → 0.6.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lalamo-0.6.5 → lalamo-0.6.6}/PKG-INFO +1 -1
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/__init__.py +1 -1
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/commands.py +247 -14
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/common.py +27 -48
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/data/__init__.py +3 -2
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/data/huggingface_message.py +4 -5
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/main.py +274 -9
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/message_processor.py +19 -1
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/common.py +17 -1
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/mistral.py +5 -0
- lalamo-0.6.6/lalamo/model_import/remote_registry.py +44 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/models/__init__.py +3 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/models/common.py +22 -0
- lalamo-0.6.6/lalamo/models/compile_helpers.py +58 -0
- lalamo-0.6.6/lalamo/models/language_model.py +638 -0
- lalamo-0.6.6/lalamo/models/lm_helpers.py +198 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/decoder.py +4 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/mamba.py +345 -105
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/speculator/__init__.py +0 -2
- lalamo-0.6.6/lalamo/speculator/inference.py +75 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo.egg-info/PKG-INFO +1 -1
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo.egg-info/SOURCES.txt +3 -1
- {lalamo-0.6.5 → lalamo-0.6.6}/pyproject.toml +3 -3
- lalamo-0.6.5/lalamo/models/language_model.py +0 -352
- lalamo-0.6.5/lalamo/speculator/estimator.py +0 -127
- lalamo-0.6.5/lalamo/speculator/inference.py +0 -112
- {lalamo-0.6.5 → lalamo-0.6.6}/LICENSE +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/README.md +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/data/lalamo_completions.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/data/utils.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/__init__.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/__init__.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/common.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/executorch.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/huggingface_generation_config.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/loaders/__init__.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/loaders/common.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/loaders/executorch.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/loaders/huggingface.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/loaders/utils.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/__init__.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/common.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/deepseek.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/essential_ai.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/gemma.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/huggingface.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/lfm2.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/llama.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/llamba.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/mirai.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/pleias.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/polaris.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/qwen.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/model_import/model_specs/reka.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/models/classifier.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/__init__.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/activations.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/classifier.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/common.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/embedding.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/linear.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/mlp.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/mlx_interop.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/normalization.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/rope.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/__init__.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/attention.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/common.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/short_conv.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/__init__.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/common.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/torch_interop.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/transformer.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/transformer_layer.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/modules/utils.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/quantization.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/registry_abc.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/safetensors.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/sampling.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/speculator/common.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/speculator/ngram.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/speculator/utils.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo/utils.py +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo.egg-info/dependency_links.txt +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo.egg-info/entry_points.txt +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo.egg-info/requires.txt +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/lalamo.egg-info/top_level.txt +0 -0
- {lalamo-0.6.5 → lalamo-0.6.6}/setup.cfg +0 -0
{lalamo-0.6.5 → lalamo-0.6.6}/lalamo/commands.py

@@ -1,16 +1,22 @@
 import json
+import shutil
+import tempfile
 from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import Enum
 from itertools import chain
 from pathlib import Path

+import polars as pl
+import requests
+import thefuzz.process
 from jaxtyping import DTypeLike

-from lalamo.common import flatten_parameters
-from lalamo.data import
+from lalamo.common import flatten_parameters, get_default_device_bytes
+from lalamo.data import load_hf_parquet, shuffle_dataset
+from lalamo.data.huggingface_message import HFMessage
 from lalamo.data.lalamo_completions import LalamoCompletion
-from lalamo.message_processor import Message
+from lalamo.message_processor import AssistantMessage, Message
 from lalamo.model_import import ModelMetadata, ModelSpec, import_model
 from lalamo.model_import.common import (
     DownloadingFileEvent,
@@ -20,15 +26,107 @@ from lalamo.model_import.common import (
     InitializingModelEvent,
     StatusEvent,
 )
+from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile
 from lalamo.models import LanguageModelConfig
+from lalamo.models.common import BatchSizesComputedEvent, InferenceConfig
+from lalamo.models.lm_helpers import estimate_batchsize_from_bytes
 from lalamo.modules import config_converter
 from lalamo.safetensors import safe_write
-from lalamo.speculator.estimator import EstimateBatchsizeFromMemoryEvent, estimate_batchsize_from_memory
 from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
 from lalamo.speculator.ngram import NGramSpeculator
 from lalamo.speculator.utils import SpeculatorTrainingEvent, train_speculator


+@dataclass
+class PullCallbacks:
+    model_spec: RegistryModel
+    output_dir: Path
+    overwrite: bool
+
+    def started(self) -> None:
+        pass
+
+    def output_dir_exists(self) -> None:
+        raise RuntimeError(f"{self.output_dir=} already exists, refusing to overwrite!")
+
+    def downloading(self, file_spec: RegistryModelFile) -> None:
+        pass
+
+    def finished_downloading(self, file_spec: RegistryModelFile) -> None:
+        pass
+
+    def finished(self) -> None:
+        pass
+
+
+def _download_file(url: str, dest_path: Path) -> None:
+    response = requests.get(url, stream=True, timeout=60)
+    response.raise_for_status()
+
+    with open(dest_path, "wb") as f:
+        for chunk in response.iter_content(chunk_size=8192):
+            if chunk:
+                f.write(chunk)
+
+
+def _suggest_similar_models(query: str, available_models: list[RegistryModel], limit: int = 3) -> list[str]:
+    repo_ids = [m.repo_id for m in available_models]
+    matches = thefuzz.process.extract(query, repo_ids, limit=limit)
+    return [match[0] for match in matches if match[1] >= 50]
+
+
+def pull(
+    model_spec: RegistryModel,
+    output_dir: Path,
+    callbacks_type: Callable[
+        [
+            RegistryModel,
+            Path,
+            bool,
+        ],
+        PullCallbacks,
+    ] = PullCallbacks,
+    overwrite: bool = False,
+) -> None:
+    callbacks = callbacks_type(model_spec, output_dir, overwrite)
+
+    if output_dir.exists():
+        callbacks.output_dir_exists()
+
+    callbacks.started()
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+
+        for file_spec in model_spec.files:
+            callbacks.downloading(file_spec)
+
+            # Security: validate filename to prevent path traversal attacks
+            safe_name = Path(file_spec.name).name
+            if not safe_name or safe_name != file_spec.name:
+                raise RuntimeError(
+                    f"Invalid filename from registry: {file_spec.name!r}. "
+                    f"Filenames must not contain path separators or traversal sequences.",
+                )
+
+            file_path = temp_path / safe_name
+            try:
+                _download_file(file_spec.url, file_path)
+            except requests.RequestException as e:
+                raise RuntimeError(f"Failed to download {safe_name}: {e}") from e
+
+            callbacks.finished_downloading(file_spec)
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        for file_spec in model_spec.files:
+            safe_name = Path(file_spec.name).name
+            src = temp_path / safe_name
+            dst = output_dir / safe_name
+            shutil.move(str(src), str(dst))
+
+    callbacks.finished()
+
+
 class Precision(Enum):
     FLOAT32 = "float32"
     FLOAT16 = "float16"
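The new pull command routes every stage through an overridable PullCallbacks dataclass, so front-ends can render progress however they like. A minimal sketch of wiring in custom hooks; PrintingPullCallbacks and the output path are illustrative, and the RegistryModel value is assumed to come from the remote registry added in this release:

from pathlib import Path

from lalamo.commands import PullCallbacks, pull
from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile


class PrintingPullCallbacks(PullCallbacks):
    def downloading(self, file_spec: RegistryModelFile) -> None:
        print(f"downloading {file_spec.name} from {file_spec.url}")

    def finished_downloading(self, file_spec: RegistryModelFile) -> None:
        print(f"done: {file_spec.name}")


model_spec: RegistryModel = ...  # obtained from the remote registry, e.g. matched by repo_id
pull(model_spec, Path("models/my-model"), callbacks_type=PrintingPullCallbacks)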
@@ -244,16 +342,19 @@ def estimate_batchsize(
     model = LanguageModelConfig.load_model(model_path)
     callbacks.finished_loading_model()

-    def
-
-
-
-
-
-
-
+    def memory_per_batchsize(batch_size: int) -> int:
+        inference_config = InferenceConfig(
+            max_output_length=max_output_length,
+            padded_length=max_input_length,
+            num_top_logits_to_return=num_logits_per_token,
+            batch_size=batch_size,
+        )
+        return model.estimate_memory_consumption(inference_config=inference_config)
+
+    bs = estimate_batchsize_from_bytes(
+        memory_per_batchsize,
         mem,
-
+        lambda event: callbacks.estimating_batchsize(event.lo, event.hi),
     )

     callbacks.finished_estimating_batchsize(bs)
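estimate_batchsize_from_bytes lives in the new lm_helpers.py, which this diff does not include, so its exact algorithm is not visible here. From the call site (a per-batch-size memory model, a byte budget, and a callback receiving lo/hi bounds) it plausibly brackets the largest batch size whose predicted memory fits; the sketch below shows that shape, with SearchEvent as a stand-in for the real event type:

from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class SearchEvent:  # stand-in for whatever event carries .lo/.hi to the callback
    lo: int
    hi: int | None


def estimate_batchsize_from_bytes_sketch(
    memory_per_batchsize: Callable[[int], int],
    budget_bytes: int,
    callback: Callable[[SearchEvent], None],
) -> int:
    # Phase 1: exponential growth until the memory model exceeds the budget.
    lo, candidate = 1, 2
    if memory_per_batchsize(lo) > budget_bytes:
        return 1  # even a single sequence does not fit; caller decides what to do
    while memory_per_batchsize(candidate) <= budget_bytes:
        lo, candidate = candidate, candidate * 2
        callback(SearchEvent(lo=lo, hi=None))
    hi = candidate
    # Phase 2: binary search; invariant: lo fits, hi does not.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        lo, hi = (mid, hi) if memory_per_batchsize(mid) <= budget_bytes else (lo, mid)
        callback(SearchEvent(lo=lo, hi=hi))
    return lo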
@@ -329,7 +430,11 @@ def collect_traces(
     callbacks.finished_loading_model()

     callbacks.loading_dataset()
-
+    dataframe = shuffle_dataset(load_hf_parquet(dataset_path))
+    conversations = dataframe.get_column("conversation")
+    dataset = iter(
+        [HFMessage.from_dict(message).as_message() for message in conversation] for conversation in conversations
+    )
     dataset = chain([next(dataset)], dataset)  # iterator is lazy, force it to actually open the file
     callbacks.finished_loading_dataset()

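The chain([next(dataset)], dataset) idiom deserves a note: a generator does no work until its first element is pulled, so forcing one element up front surfaces file and schema errors during the loading_dataset phase instead of mid-run. A self-contained demonstration:

from itertools import chain


def rows():
    print("opening file")  # side effect runs only on the first next()
    yield from [1, 2, 3]


lazy = rows()
forced = chain([next(lazy)], lazy)  # "opening file" prints here; no element is lost
assert list(forced) == [1, 2, 3]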
@@ -427,3 +532,131 @@ def train(
     with open(output_path, "wb") as fd:
         fd.write(speculator.serialize())
     callbacks.finished_saving_speculator()
+
+
+@dataclass
+class GenerateRepliesCallbacks:
+    model_path: Path
+    dataset_path: Path
+    output_path: Path
+    max_vram: int | None
+    batch_size: int | None
+    total_rows: int
+
+    def loading_model(self) -> None:
+        pass
+
+    def finished_loading_model(self) -> None:
+        pass
+
+    def loading_dataset(self) -> None:
+        pass
+
+    def finished_loading_dataset(self) -> None:
+        pass
+
+    def estimating_batchsize(self, sequence_length: int, lo: int, hi: int | None) -> None:
+        pass
+
+    def batch_sizes_estimated(self) -> None:
+        pass
+
+    def batch_sizes_computed(self, event: BatchSizesComputedEvent) -> None:
+        pass
+
+    def generation_progress(self, rows_processed: int) -> None:
+        pass
+
+    def finished_generation(self) -> None:
+        pass
+
+
+def generate_replies(
+    model_path: Path,
+    dataset_path: Path,
+    output_path: Path,
+    max_vram: int | None,
+    max_output_length: int = 8192,
+    batch_size: int | None = None,
+    callbacks_type: Callable[
+        [
+            Path,
+            Path,
+            Path,
+            int | None,
+            int | None,
+            int,
+        ],
+        GenerateRepliesCallbacks,
+    ] = GenerateRepliesCallbacks,
+) -> None:
+    # figure out max_vram if neither batch_size nor max_vram is set
+    if max_vram is None and batch_size is None:
+        max_vram = get_default_device_bytes()
+        if max_vram is None:
+            raise ValueError(
+                "Unable to determine default device memory capacity; please specify either --vram-gb or --batch-size",
+            )
+
+    # Count rows without loading full dataset
+    total_rows = pl.scan_parquet(dataset_path).select(pl.len()).collect().item()
+
+    callbacks = callbacks_type(
+        model_path,
+        dataset_path,
+        output_path,
+        max_vram,
+        batch_size,
+        total_rows,
+    )
+
+    callbacks.loading_model()
+    model = LanguageModelConfig.load_model(model_path)
+    callbacks.finished_loading_model()
+
+    callbacks.loading_dataset()
+    dataframe = load_hf_parquet(dataset_path).collect()
+    conversations = dataframe.get_column("conversation")
+    dataset = iter(
+        [HFMessage.from_dict(message).as_message() for message in conversation] for conversation in conversations
+    )
+    try:
+        first_row = next(dataset)
+    except StopIteration:
+        callbacks.finished_loading_dataset()
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        pl.DataFrame({"response": [], "chain_of_thought": []}).write_parquet(output_path)
+        return
+    dataset = chain([first_row], dataset)  # iterator is lazy, force it to actually open the file
+    callbacks.finished_loading_dataset()
+
+    inference_config = InferenceConfig(max_output_length=max_output_length, batch_size=batch_size)
+
+    callbacks.batch_sizes_estimated()
+
+    replies: list[tuple[int, AssistantMessage]] = []
+    for rows_processed, (idx, reply) in enumerate(
+        model.reply_many(
+            dataset,
+            inference_config=inference_config,
+            vram_bytes=max_vram,
+            batch_sizes_callback=callbacks.batch_sizes_computed,
+        ),
+    ):
+        replies.append((idx, reply))
+        callbacks.generation_progress(rows_processed)
+
+    # Sort by original index to restore input order
+    replies.sort(key=lambda x: x[0])
+
+    df = pl.DataFrame(
+        {
+            "response": [reply.response for _, reply in replies],
+            "chain_of_thought": [reply.chain_of_thought for _, reply in replies],
+        },
+    )
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    df.write_parquet(output_path)
+
+    callbacks.finished_generation()
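A hypothetical invocation of the new command; the paths are illustrative, and the input parquet is assumed to carry a "conversation" column of HuggingFace-style message dicts, matching the loader above:

from pathlib import Path

from lalamo.commands import generate_replies

generate_replies(
    model_path=Path("models/my-model"),        # directory produced by lalamo import/pull
    dataset_path=Path("data/prompts.parquet"),
    output_path=Path("out/replies.parquet"),
    max_vram=None,     # with batch_size also None, falls back to get_default_device_bytes()
    batch_size=None,
)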
{lalamo-0.6.5 → lalamo-0.6.6}/lalamo/common.py

@@ -1,11 +1,11 @@
-import
+import os
 from collections import defaultdict
-from collections.abc import
+from collections.abc import Mapping, Sequence
 from typing import cast

+import jax
 import jax.numpy as jnp
 from jax._src.api import ShapeDtypeStruct
-from jax.errors import JaxRuntimeError
 from jaxtyping import Array, DTypeLike

 from lalamo.utils import MapDictValues, MapSequence
@@ -16,7 +16,6 @@ __all__ = [
     "LalamoWarning",
     "ParameterPath",
     "ParameterTree",
-    "decrease_batchsize_on_oom",
     "dummy_array",
     "flatten_parameters",
     "require_array",
@@ -131,47 +130,27 @@ class ParameterPath(str):
         return ParameterPath(self + "." + str(other))


-def
-
-
-
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-    """
-
-
-
-
-
-
-            yield result
-
-            # as soon as we yielded we are not allowed to retry anymore
-            # to make sure we don't ever miss/duplicate outputs
-            first_batch_completed = True
-            break
-        except JaxRuntimeError:
-            if first_batch_completed:
-                raise
-            # because OOM's sometimes generate stuff that won't be garbage collected,
-            # we need to be very aggressive with decreasing batchsize here
-            new_bs = max(int(0.7 * effective_batch_size - 1), 1)
-            if new_bs == 1 and effective_batch_size == 1:
-                raise
-            warnings.warn(
-                f"OOM detected. Reducing batch size {effective_batch_size} -> {new_bs}.",
-                LalamoWarning,
-                stacklevel=3,
-            )
-            effective_batch_size = new_bs
+def get_default_device_bytes() -> int | None:
+    dynamic_allocate = False
+
+    preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
+    dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
+
+    allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
+    dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
+
+    if dynamic_allocate:
+        return None
+
+    memory_stats = jax.local_devices()[0].memory_stats()
+    if memory_stats is None or "bytes_limit" not in memory_stats:
+        return None
+
+    # 500mb is seemingly the usually observed overhead
+    memory_limit = memory_stats["bytes_limit"] - (500 * 1000 * 1000)
+
+    return memory_limit
+
+
+def get_usable_memory_from_bytes(limit_bytes: int) -> int:
+    return int(limit_bytes * 0.95)
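get_default_device_bytes only inspects the two XLA allocator environment variables and the first local device's memory_stats(), so its behavior is easy to probe. A small sketch; actual byte values depend on the device and backend:

import os

from lalamo.common import get_default_device_bytes

# With a dynamic allocator the budget is unknowable up front, so the helper bails out.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
assert get_default_device_bytes() is None

# With default preallocation it reports bytes_limit minus the ~500 MB safety margin,
# or None on backends that expose no memory_stats() (e.g. some CPU builds).
del os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]
print(get_default_device_bytes())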
{lalamo-0.6.5 → lalamo-0.6.6}/lalamo/data/__init__.py

@@ -1,7 +1,8 @@
-from .huggingface_message import
+from .huggingface_message import load_hf_parquet, shuffle_dataset
 from .utils import get_prefixes_ending_in_user_message

 __all__ = [
     "get_prefixes_ending_in_user_message",
-    "
+    "load_hf_parquet",
+    "shuffle_dataset",
 ]
{lalamo-0.6.5 → lalamo-0.6.6}/lalamo/data/huggingface_message.py

@@ -1,4 +1,3 @@
-from collections.abc import Iterable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import ClassVar, Self
@@ -30,10 +29,10 @@ class HFMessage:
         raise ValueError(f"Cannot convert {other} message")


-def
+def load_hf_parquet(path: Path | str) -> pl.LazyFrame:
     path = Path(path)
+    return pl.scan_parquet(path)

-    dataframe = pl.scan_parquet(path).collect()

-
-
+def shuffle_dataset(frame: pl.LazyFrame, seed: int = 1337) -> pl.DataFrame:
+    return frame.collect().sample(fraction=1.0, shuffle=True, seed=seed)
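A minimal round trip through the two new data helpers; the parquet path is hypothetical, and the "conversation" column name matches the call sites in commands.py:

from lalamo.data import load_hf_parquet, shuffle_dataset

frame = load_hf_parquet("data/conversations.parquet")  # lazy pl.LazyFrame, nothing is read yet
shuffled = shuffle_dataset(frame)  # collects, then shuffles deterministically (seed=1337)
print(shuffled.get_column("conversation")[0])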