lalamo 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +6 -1
- lalamo/main.py +18 -4
- lalamo/speculator/estimator.py +32 -4
- lalamo/speculator/inference.py +6 -1
- {lalamo-0.6.2.dist-info → lalamo-0.6.4.dist-info}/METADATA +1 -1
- {lalamo-0.6.2.dist-info → lalamo-0.6.4.dist-info}/RECORD +10 -10
- {lalamo-0.6.2.dist-info → lalamo-0.6.4.dist-info}/WHEEL +0 -0
- {lalamo-0.6.2.dist-info → lalamo-0.6.4.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.2.dist-info → lalamo-0.6.4.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.2.dist-info → lalamo-0.6.4.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
# Must run before importing jax / tensorflow, this hides the XLA optimization logs
|
|
4
|
+
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
|
|
5
|
+
|
|
1
6
|
from lalamo.commands import (
|
|
2
7
|
CollectTracesCallbacks,
|
|
3
8
|
ConversionCallbacks,
|
|
@@ -27,7 +32,7 @@ from lalamo.speculator import (
|
|
|
27
32
|
SpeculatorTrainingEvent,
|
|
28
33
|
)
|
|
29
34
|
|
|
30
|
-
__version__ = "0.6.
|
|
35
|
+
__version__ = "0.6.4"
|
|
31
36
|
|
|
32
37
|
__all__ = [
|
|
33
38
|
"AssistantMessage",
|
lalamo/main.py
CHANGED
|
@@ -49,7 +49,10 @@ from lalamo.message_processor import UserMessage
|
|
|
49
49
|
from lalamo.model_import import REPO_TO_MODEL, ModelSpec
|
|
50
50
|
from lalamo.model_import.common import FileSpec
|
|
51
51
|
from lalamo.models import ClassifierModelConfig, LanguageModelConfig
|
|
52
|
-
from lalamo.speculator.estimator import
|
|
52
|
+
from lalamo.speculator.estimator import (
|
|
53
|
+
get_default_device_bytes,
|
|
54
|
+
get_usable_memory_from_bytes,
|
|
55
|
+
)
|
|
53
56
|
from lalamo.speculator.ngram import NGramSpeculator
|
|
54
57
|
from lalamo.speculator.utils import test_speculator
|
|
55
58
|
|
|
@@ -384,6 +387,7 @@ class CliTraceCallbacks(TraceCallbacks):
|
|
|
384
387
|
self.stack.close()
|
|
385
388
|
console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")
|
|
386
389
|
|
|
390
|
+
|
|
387
391
|
@app.command(help="Trace a model.")
|
|
388
392
|
def trace(
|
|
389
393
|
model_path: Annotated[
|
|
@@ -557,14 +561,24 @@ def estimate_batchsize(
|
|
|
557
561
|
] = None,
|
|
558
562
|
) -> None:
|
|
559
563
|
if vram_gb is not None:
|
|
560
|
-
|
|
561
|
-
|
|
564
|
+
# note that in practice GPUs use GiB in their docs, e.g. H100 actually has 85GB of memory
|
|
565
|
+
mem_bytes = vram_gb * 1000 * 1000 * 1000
|
|
566
|
+
elif (mem_bytes := get_default_device_bytes()) is None:
|
|
562
567
|
err_console.print("Cannot get the default device's memory stats, use --vram-gb")
|
|
563
568
|
raise Exit(1)
|
|
564
569
|
|
|
570
|
+
usable_mem = get_usable_memory_from_bytes(mem_bytes)
|
|
571
|
+
|
|
565
572
|
callbacks_type = CliEstimateBatchsizeCallbacks
|
|
566
573
|
|
|
567
|
-
_estimate_batchsize(
|
|
574
|
+
_estimate_batchsize(
|
|
575
|
+
model_path,
|
|
576
|
+
usable_mem,
|
|
577
|
+
max_input_length,
|
|
578
|
+
max_output_length,
|
|
579
|
+
num_logits_per_token,
|
|
580
|
+
callbacks_type,
|
|
581
|
+
)
|
|
568
582
|
|
|
569
583
|
|
|
570
584
|
@dataclass
|
lalamo/speculator/estimator.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import itertools
|
|
3
|
+
import os
|
|
3
4
|
from collections.abc import Callable
|
|
4
5
|
from typing import NamedTuple
|
|
5
6
|
|
|
@@ -9,11 +10,38 @@ import jax.numpy as jnp
|
|
|
9
10
|
from lalamo.models import LanguageModel
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
def
|
|
13
|
+
def get_default_device_bytes() -> int | None:
|
|
14
|
+
dynamic_allocate = False
|
|
15
|
+
|
|
16
|
+
preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
|
|
17
|
+
dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
|
|
18
|
+
|
|
19
|
+
allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
|
|
20
|
+
dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
|
|
21
|
+
|
|
22
|
+
if dynamic_allocate:
|
|
23
|
+
return None
|
|
24
|
+
|
|
13
25
|
memory_stats = jax.local_devices()[0].memory_stats()
|
|
14
26
|
if memory_stats is None or "bytes_limit" not in memory_stats:
|
|
15
27
|
return None
|
|
16
|
-
|
|
28
|
+
|
|
29
|
+
mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
|
|
30
|
+
try:
|
|
31
|
+
mem_fraction = float(mem_fraction_raw)
|
|
32
|
+
except ValueError:
|
|
33
|
+
mem_fraction = 0.75 # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
|
|
34
|
+
|
|
35
|
+
# 500mb is seemingly the usually observed overhead; this tries to match the actual capacity of the gpu
|
|
36
|
+
# so it should correspond to something you'd see in nvidia-smi
|
|
37
|
+
memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
|
|
38
|
+
|
|
39
|
+
return get_usable_memory_from_bytes(memory_limit)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_usable_memory_from_bytes(limit_bytes: int) -> int:
|
|
43
|
+
# JAX allocates a bit more than it needs, so we discount it by some safety factor
|
|
44
|
+
return int(limit_bytes * 0.95)
|
|
17
45
|
|
|
18
46
|
|
|
19
47
|
def estimate_memory_from_batchsize(
|
|
@@ -30,14 +58,14 @@ def estimate_memory_from_batchsize(
|
|
|
30
58
|
max_output_length=max_output_length,
|
|
31
59
|
num_top_logits_to_return=num_logits_per_token,
|
|
32
60
|
),
|
|
33
|
-
backend="cpu", # cuda backend tries to allocate in .compile() and ooms
|
|
34
61
|
)
|
|
35
62
|
.lower(
|
|
36
63
|
model,
|
|
37
64
|
prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
|
|
38
65
|
prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
|
|
39
66
|
)
|
|
40
|
-
.
|
|
67
|
+
# disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
|
|
68
|
+
.compile(compiler_options={"xla_gpu_autotune_level": "0"})
|
|
41
69
|
.memory_analysis()
|
|
42
70
|
)
|
|
43
71
|
|
lalamo/speculator/inference.py
CHANGED
|
@@ -40,7 +40,12 @@ def inference_collect_traces(
|
|
|
40
40
|
prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
|
|
41
41
|
prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
|
|
42
42
|
)
|
|
43
|
-
.
|
|
43
|
+
# the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
|
|
44
|
+
# 0 - no autotune, gpu shouldn't be touched
|
|
45
|
+
# 1 - basic level, gpu should be touched veeery little
|
|
46
|
+
# 2,3 - gpu touched more and more
|
|
47
|
+
# 4 (default) - gpu might allocate more memory than the run would require!
|
|
48
|
+
.compile(compiler_options={"xla_gpu_autotune_level": "2"})
|
|
44
49
|
)
|
|
45
50
|
|
|
46
51
|
prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
lalamo/__init__.py,sha256=
|
|
1
|
+
lalamo/__init__.py,sha256=RDkf5Hhglc-fLZ-CmI4R-th6UgJKYmN-1hdbCzTiVx8,1532
|
|
2
2
|
lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
|
|
3
3
|
lalamo/common.py,sha256=WaNJx20eUX4CBF50aym9lniGAiX-SzBJzDzO5Jh6zXA,4312
|
|
4
|
-
lalamo/main.py,sha256=
|
|
4
|
+
lalamo/main.py,sha256=f1zHYQpX_OndAguOE0wqIOkzjzUolUC7w3_1ndtMC4Y,27655
|
|
5
5
|
lalamo/message_processor.py,sha256=PMKte9YijT3h9N7DjTNp8H4V45A_qlDqJaubqFevLX8,5924
|
|
6
6
|
lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
|
|
7
7
|
lalamo/registry_abc.py,sha256=qTikqviqqeseNzkjqoyQvL4dEWJYWzN0rI05T-JNTmo,2187
|
|
@@ -83,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
|
|
|
83
83
|
lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
|
|
84
84
|
lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
|
|
85
85
|
lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
|
|
86
|
-
lalamo/speculator/estimator.py,sha256=
|
|
87
|
-
lalamo/speculator/inference.py,sha256=
|
|
86
|
+
lalamo/speculator/estimator.py,sha256=6T8NdmDdhvP0BPg7vdkB_pxAkfgpu4WktNpUHtFuyiE,3833
|
|
87
|
+
lalamo/speculator/inference.py,sha256=uEv33Qqcpa2xqEKdIzmPzkAzRsZOlb8TPeEG6TP6fjo,4071
|
|
88
88
|
lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
|
|
89
89
|
lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
|
|
90
|
-
lalamo-0.6.
|
|
91
|
-
lalamo-0.6.
|
|
92
|
-
lalamo-0.6.
|
|
93
|
-
lalamo-0.6.
|
|
94
|
-
lalamo-0.6.
|
|
95
|
-
lalamo-0.6.
|
|
90
|
+
lalamo-0.6.4.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
|
|
91
|
+
lalamo-0.6.4.dist-info/METADATA,sha256=oS1EAJBl3jBtvZU0Rd-UcjnL2Trngree7Syn2L16Rx8,3112
|
|
92
|
+
lalamo-0.6.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
93
|
+
lalamo-0.6.4.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
|
|
94
|
+
lalamo-0.6.4.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
|
|
95
|
+
lalamo-0.6.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|