lalamo 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/common.py +55 -1
- lalamo/speculator/estimator.py +1 -1
- lalamo/speculator/inference.py +65 -54
- {lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/METADATA +1 -1
- {lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/RECORD +10 -10
- {lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/WHEEL +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
lalamo/common.py
CHANGED
@@ -1,9 +1,11 @@
+import warnings
 from collections import defaultdict
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from typing import cast

 import jax.numpy as jnp
 from jax._src.api import ShapeDtypeStruct
+from jax.errors import JaxRuntimeError
 from jaxtyping import Array, DTypeLike

 from lalamo.utils import MapDictValues, MapSequence
@@ -11,8 +13,10 @@ from lalamo.utils import MapDictValues, MapSequence
 __all__ = [
     "DEFAULT_PRECISION",
     "ArrayLike",
+    "LalamoWarning",
     "ParameterPath",
     "ParameterTree",
+    "decrease_batchsize_on_oom",
     "dummy_array",
     "flatten_parameters",
     "require_array",
@@ -23,6 +27,10 @@ __all__ = [
 DEFAULT_PRECISION: DTypeLike = jnp.bfloat16


+class LalamoWarning(UserWarning):
+    """Custom warning class for Lalamo-specific warnings."""
+
+
 type ArrayLike = Array | ShapeDtypeStruct


@@ -121,3 +129,49 @@ class ParameterPath(str):
         if not self:
             return ParameterPath(str(other))
         return ParameterPath(self + "." + str(other))
+
+
+def decrease_batchsize_on_oom[T](
+    fn: Callable[[int], Iterable[T]],
+    starting_batch_size: int,
+) -> Iterable[T]:
+    """
+    Execute fn(batch_size) with automatic batch size reduction on OOM.
+    Only reduces batch size if OOM happened on the first batch.
+
+    Args:
+        fn: Function that takes batch_size and returns an iterable
+        starting_batch_size: Initial batch size to try
+
+    Yields:
+        Results from fn(batch_size)
+
+    Raises:
+        JaxRuntimeError: If OOM occurs after first batch completes or at batch_size=1
+    """
+    first_batch_completed = False
+    effective_batch_size = starting_batch_size
+
+    while True:
+        try:
+            for result in fn(effective_batch_size):
+                yield result
+
+                # as soon as we yielded we are not allowed to retry anymore
+                # to make sure we don't ever miss/duplicate outputs
+                first_batch_completed = True
+            break
+        except JaxRuntimeError:
+            if first_batch_completed:
+                raise
+            # because OOM's sometimes generate stuff that won't be garbage collected,
+            # we need to be very aggressive with decreasing batchsize here
+            new_bs = max(int(0.7 * effective_batch_size - 1), 1)
+            if new_bs == 1 and effective_batch_size == 1:
+                raise
+            warnings.warn(
+                f"OOM detected. Reducing batch size {effective_batch_size} -> {new_bs}.",
+                LalamoWarning,
+                stacklevel=3,
+            )
+            effective_batch_size = new_bs
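The new `decrease_batchsize_on_oom` helper wraps any generator that is parameterized by batch size: if a `JaxRuntimeError` (typically an OOM) escapes before the first element has been yielded, it retries `fn` with an aggressively smaller batch; once anything has been yielded, errors propagate so outputs are never missed or duplicated. A minimal usage sketch (the `batched_upper` caller and its string workload are hypothetical, not from the package):

from collections.abc import Iterable

from lalamo.common import decrease_batchsize_on_oom


def batched_upper(items: list[str], starting_batch_size: int = 64) -> Iterable[str]:
    def body(batch_size: int) -> Iterable[str]:
        # In a real workload, peak memory scales with batch_size; a
        # JaxRuntimeError raised before the first yield triggers a retry
        # with a smaller batch, while any later failure is re-raised.
        for start in range(0, len(items), batch_size):
            yield from (item.upper() for item in items[start : start + batch_size])

    # On repeated OOMs the batch size follows new_bs = max(int(0.7 * bs - 1), 1):
    # 64 -> 43 -> 29 -> 19 -> 12 -> 7 -> 3 -> 1, and a failure at 1 propagates.
    return decrease_batchsize_on_oom(body, starting_batch_size)

Because a retry restarts `body` from scratch, the wrapped generator must be able to reproduce its first batch; the inference.py change below materializes `test_batch` with `islice` for exactly that reason before handing the rest of the iterator to `collect_traces_body`.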
lalamo/speculator/estimator.py
CHANGED
@@ -41,7 +41,7 @@ def get_default_device_bytes() -> int | None:

 def get_usable_memory_from_bytes(limit_bytes: int) -> int:
     # JAX allocates a bit more than it needs, so we discount it by some safety factor
-    return int(limit_bytes * 0.
+    return int(limit_bytes * 0.93)


 def estimate_memory_from_batchsize(
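The safety factor now leaves 7% of the detected device memory as headroom for JAX's over-allocation. A quick worked example (the 16 GiB device limit is illustrative):

limit_bytes = 16 * 2**30          # hypothetical 16 GiB device limit
usable = int(limit_bytes * 0.93)  # 15_977_278_341 bytes, about 14.88 GiB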
lalamo/speculator/inference.py
CHANGED
@@ -1,11 +1,13 @@
 import functools
 from collections.abc import Callable, Iterable
-from itertools import batched, chain
+from itertools import batched, chain, islice
 from typing import NamedTuple

 import jax
 import jax.numpy as jnp
+from jax._src.stages import Compiled

+from lalamo.common import decrease_batchsize_on_oom
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.data.utils import get_prefixes_ending_in_user_message
 from lalamo.message_processor import Message
@@ -27,75 +29,84 @@ def inference_collect_traces(
     tokens_to_generate: int | None = None,
     progress_callback: Callable[[CollectTracesEvent], None] | None = None,
 ) -> Iterable[LalamoCompletion]:
-
-
-
-
-
-
-
+    def make_generate_tokens_compiled(batch_size: int) -> Compiled:
+        return (
+            jax.jit(
+                functools.partial(
+                    LanguageModel.generate_tokens,
+                    max_output_length=max_output_length,
+                    num_top_logits_to_return=num_top_logits_to_collect,
+                ),
+            )
+            .lower(
+                model,
+                prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
+                prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
+            )
+            # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
+            # 0 - no autotune, gpu shouldn't be touched
+            # 1 - basic level, gpu should be touched veeery little
+            # 2,3 - gpu touched more and more
+            # 4 (default) - gpu might allocate more memory than the run would require!
+            .compile(compiler_options={"xla_gpu_autotune_level": "0"})
         )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
-        # 0 - no autotune, gpu shouldn't be touched
-        # 1 - basic level, gpu should be touched veeery little
-        # 2,3 - gpu touched more and more
-        # 4 (default) - gpu might allocate more memory than the run would require!
-        .compile(compiler_options={"xla_gpu_autotune_level": "2"})
-    )

     prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))

     tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
     filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)

-
+    test_batch = list(islice(filtered_prefixes, batch_size))

-
-
-
+    def collect_traces_body(batch_size: int) -> Iterable[LalamoCompletion]:
+        tokens_generated, sequences_processed = 0, 0
+        generate_tokens_compiled = make_generate_tokens_compiled(batch_size)
+        for real_batch in batched(chain(test_batch, filtered_prefixes), n=batch_size):
+            batch_padding = batch_size - len(real_batch)
+            batch = (*real_batch, *(([0],) * batch_padding))

-
+            length_without_padding = jnp.array(list(map(len, batch)))

-
-
-
+            padded = jnp.array(
+                [
+                    jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0)
+                    for tokens in batch
+                ],
+            )

-
-
-
-
-
+            generated = generate_tokens_compiled(
+                model,
+                prompt_token_ids=padded,
+                prompt_lengths_without_padding=length_without_padding,
+            )

-
+            assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None

-
-
-
-
-
-
-
+            for conv_idx in range(len(real_batch)):
+                token_ids = generated.token_ids[conv_idx].tolist()
+                seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
+                if tokens_to_generate is not None:
+                    seqlen = min(seqlen, tokens_to_generate - tokens_generated)
+                tokens_generated += seqlen
+                sequences_processed += 1

-
-
-
-
-
-
-
+                token_ids = token_ids[:seqlen]
+                token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
+                token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
+                token_logits = [
+                    dict(zip(keys, values, strict=True))
+                    for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
+                ]

-
+                yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
+
+                if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
+                    break
+
+            if progress_callback is not None:
+                progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))

             if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
                 break

-
-        progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
-
-        if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-            break
+    yield from decrease_batchsize_on_oom(collect_traces_body, batch_size)
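The restructuring moves all generation into `collect_traces_body`, which `decrease_batchsize_on_oom` can restart with a smaller batch, and compiles `generate_tokens` ahead of time per batch size by lowering against `jax.ShapeDtypeStruct` placeholders. Below is a self-contained sketch of that jit/lower/compile pattern; the `padded_sum` toy function and its shapes are illustrative stand-ins, and only the plumbing and the `xla_gpu_autotune_level` option mirror the diff:

import jax
import jax.numpy as jnp


def make_compiled_sum(batch_size: int, length: int):
    def padded_sum(token_ids: jax.Array, lengths: jax.Array) -> jax.Array:
        # Mask out padding positions before reducing, a stand-in for how
        # generate_tokens consumes prompt_lengths_without_padding.
        mask = jnp.arange(token_ids.shape[1])[None, :] < lengths[:, None]
        return jnp.where(mask, token_ids, 0).sum(axis=1)

    return (
        jax.jit(padded_sum)
        # Lowering against ShapeDtypeStructs fixes shapes and dtypes without
        # materializing any input arrays.
        .lower(
            jax.ShapeDtypeStruct((batch_size, length), jnp.int32),
            jax.ShapeDtypeStruct((batch_size,), jnp.int32),
        )
        # Autotune level 0 keeps XLA from touching the GPU (and from
        # over-allocating memory) while compiling, per the comment in the diff.
        .compile(compiler_options={"xla_gpu_autotune_level": "0"})
    )


compiled = make_compiled_sum(batch_size=4, length=8)
print(compiled(jnp.ones((4, 8), jnp.int32), jnp.array([8, 5, 3, 0], jnp.int32)))
# [8 5 3 0]

Compiling for a fixed shape is also why the body pads the final, shorter batch up to `batch_size` with dummy `[0]` prompts and then iterates only over `range(len(real_batch))`: every call hits the same compiled executable instead of triggering a retrace.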
{lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
-lalamo/__init__.py,sha256=
+lalamo/__init__.py,sha256=RpKc5sKIQHI8tPVwzH7lIJJWE7tJy6FZauEhabEp2Hg,1532
 lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
-lalamo/common.py,sha256=
+lalamo/common.py,sha256=ddGIPlFCgo6Q683v8uP8G2dh8nsCJe9woZL8A_7_Rt4,6124
 lalamo/main.py,sha256=f1zHYQpX_OndAguOE0wqIOkzjzUolUC7w3_1ndtMC4Y,27655
 lalamo/message_processor.py,sha256=PMKte9YijT3h9N7DjTNp8H4V45A_qlDqJaubqFevLX8,5924
 lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
@@ -83,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
 lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
 lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
 lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
-lalamo/speculator/estimator.py,sha256=
-lalamo/speculator/inference.py,sha256=
+lalamo/speculator/estimator.py,sha256=WPG3rxKq4iLro8QwcePF766ageexHc17ANiF5rKAlKU,3833
+lalamo/speculator/inference.py,sha256=47TUiLV0Dkk3dbf1-IkdlWbHCICFw6IDwKZ73FYQUQo,4802
 lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
 lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
-lalamo-0.6.
+lalamo-0.6.5.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.6.5.dist-info/METADATA,sha256=EWI8eHaPSj7tXrW7xW9BPNpeDRjboNNvtGbq3hRELzU,3112
+lalamo-0.6.5.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+lalamo-0.6.5.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.6.5.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.6.5.dist-info/RECORD,,
{lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/WHEEL
File without changes
{lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/entry_points.txt
File without changes
{lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/licenses/LICENSE
File without changes
{lalamo-0.6.4.dist-info → lalamo-0.6.5.dist-info}/top_level.txt
File without changes