lalamo 0.6.2__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +6 -1
- lalamo/speculator/estimator.py +29 -3
- lalamo/speculator/inference.py +6 -1
- {lalamo-0.6.2.dist-info → lalamo-0.6.3.dist-info}/METADATA +1 -1
- {lalamo-0.6.2.dist-info → lalamo-0.6.3.dist-info}/RECORD +9 -9
- {lalamo-0.6.2.dist-info → lalamo-0.6.3.dist-info}/WHEEL +0 -0
- {lalamo-0.6.2.dist-info → lalamo-0.6.3.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.2.dist-info → lalamo-0.6.3.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.2.dist-info → lalamo-0.6.3.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
# Must run before importing jax / tensorflow, this hides the XLA optimization logs
|
|
4
|
+
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
|
|
5
|
+
|
|
1
6
|
from lalamo.commands import (
|
|
2
7
|
CollectTracesCallbacks,
|
|
3
8
|
ConversionCallbacks,
|
|
@@ -27,7 +32,7 @@ from lalamo.speculator import (
|
|
|
27
32
|
SpeculatorTrainingEvent,
|
|
28
33
|
)
|
|
29
34
|
|
|
30
|
-
__version__ = "0.6.2"
|
|
35
|
+
__version__ = "0.6.3"
|
|
31
36
|
|
|
32
37
|
__all__ = [
|
|
33
38
|
"AssistantMessage",
|
lalamo/speculator/estimator.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import itertools
|
|
3
|
+
import os
|
|
3
4
|
from collections.abc import Callable
|
|
4
5
|
from typing import NamedTuple
|
|
5
6
|
|
|
@@ -10,10 +11,35 @@ from lalamo.models import LanguageModel
|
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
def get_default_device_memory() -> int | None:
|
|
14
|
+
dynamic_allocate = False
|
|
15
|
+
|
|
16
|
+
preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
|
|
17
|
+
dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
|
|
18
|
+
|
|
19
|
+
allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
|
|
20
|
+
dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
|
|
21
|
+
|
|
22
|
+
if dynamic_allocate:
|
|
23
|
+
return None
|
|
24
|
+
|
|
13
25
|
memory_stats = jax.local_devices()[0].memory_stats()
|
|
14
26
|
if memory_stats is None or "bytes_limit" not in memory_stats:
|
|
15
27
|
return None
|
|
16
|
-
|
|
28
|
+
|
|
29
|
+
# discount by 0.98 because if you try to allocate exactly bytes_limit
|
|
30
|
+
# jax will _try_ to allocate a bit more, and then throws an OOM
|
|
31
|
+
bytes_limit = memory_stats["bytes_limit"] * 0.98
|
|
32
|
+
|
|
33
|
+
mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
|
|
34
|
+
try:
|
|
35
|
+
mem_fraction = float(mem_fraction_raw)
|
|
36
|
+
except ValueError:
|
|
37
|
+
mem_fraction = 0.75 # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
|
|
38
|
+
|
|
39
|
+
# JAX usually can't allocate more than 98%-ish percent memory; idk why
|
|
40
|
+
# Besides we use _some_ autotuning during runtime, so we add 0.96 safety margin here
|
|
41
|
+
bytes_limit = int(max(bytes_limit, bytes_limit / min(mem_fraction, 1.0)) * 0.96)
|
|
42
|
+
return bytes_limit
|
|
17
43
|
|
|
18
44
|
|
|
19
45
|
def estimate_memory_from_batchsize(
|
|
@@ -30,14 +56,14 @@ def estimate_memory_from_batchsize(
|
|
|
30
56
|
max_output_length=max_output_length,
|
|
31
57
|
num_top_logits_to_return=num_logits_per_token,
|
|
32
58
|
),
|
|
33
|
-
backend="cpu", # cuda backend tries to allocate in .compile() and ooms
|
|
34
59
|
)
|
|
35
60
|
.lower(
|
|
36
61
|
model,
|
|
37
62
|
prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
|
|
38
63
|
prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
|
|
39
64
|
)
|
|
40
|
-
.compile()
|
|
65
|
+
# disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
|
|
66
|
+
.compile(compiler_options={"xla_gpu_autotune_level": "0"})
|
|
41
67
|
.memory_analysis()
|
|
42
68
|
)
|
|
43
69
|
|
lalamo/speculator/inference.py
CHANGED
|
@@ -40,7 +40,12 @@ def inference_collect_traces(
|
|
|
40
40
|
prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
|
|
41
41
|
prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
|
|
42
42
|
)
|
|
43
|
-
.compile()
|
|
43
|
+
# the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
|
|
44
|
+
# 0 - no autotune, gpu shouldn't be touched
|
|
45
|
+
# 1 - basic level, gpu should be touched veeery little
|
|
46
|
+
# 2,3 - gpu touched more and more
|
|
47
|
+
# 4 (default) - gpu might allocate more memory than the run would require!
|
|
48
|
+
.compile(compiler_options={"xla_gpu_autotune_level": "2"})
|
|
44
49
|
)
|
|
45
50
|
|
|
46
51
|
prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
lalamo/__init__.py,sha256=
|
|
1
|
+
lalamo/__init__.py,sha256=MBRNUkooL0qiMz4B9VqvRblslw0VXAODjzg3BtY7uhU,1532
|
|
2
2
|
lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
|
|
3
3
|
lalamo/common.py,sha256=WaNJx20eUX4CBF50aym9lniGAiX-SzBJzDzO5Jh6zXA,4312
|
|
4
4
|
lalamo/main.py,sha256=Tez84CtMxUi1ySuRSqQElu4Zr1UWs_Gw6HX1xtCZknQ,27383
|
|
@@ -83,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
|
|
|
83
83
|
lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
|
|
84
84
|
lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
|
|
85
85
|
lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
|
|
86
|
-
lalamo/speculator/estimator.py,sha256=
|
|
87
|
-
lalamo/speculator/inference.py,sha256=
|
|
86
|
+
lalamo/speculator/estimator.py,sha256=xkvrNsoIWUbotv__e2ms-QP3NUnj92W_Fcpse8rvQ_o,3797
|
|
87
|
+
lalamo/speculator/inference.py,sha256=uEv33Qqcpa2xqEKdIzmPzkAzRsZOlb8TPeEG6TP6fjo,4071
|
|
88
88
|
lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
|
|
89
89
|
lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
|
|
90
|
-
lalamo-0.6.
|
|
91
|
-
lalamo-0.6.
|
|
92
|
-
lalamo-0.6.
|
|
93
|
-
lalamo-0.6.
|
|
94
|
-
lalamo-0.6.
|
|
95
|
-
lalamo-0.6.
|
|
90
|
+
lalamo-0.6.3.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
|
|
91
|
+
lalamo-0.6.3.dist-info/METADATA,sha256=tvl435CXfVSezsey_FhFzZ5YaY64GOjZQBaDY5vEAPU,3112
|
|
92
|
+
lalamo-0.6.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
93
|
+
lalamo-0.6.3.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
|
|
94
|
+
lalamo-0.6.3.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
|
|
95
|
+
lalamo-0.6.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|