lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/commands.py +247 -14
- lalamo/common.py +33 -0
- lalamo/data/__init__.py +3 -2
- lalamo/data/huggingface_message.py +4 -5
- lalamo/main.py +274 -9
- lalamo/message_processor.py +19 -1
- lalamo/model_import/common.py +17 -1
- lalamo/model_import/model_specs/mistral.py +5 -0
- lalamo/model_import/remote_registry.py +44 -0
- lalamo/models/__init__.py +3 -0
- lalamo/models/common.py +22 -0
- lalamo/models/compile_helpers.py +58 -0
- lalamo/models/language_model.py +342 -56
- lalamo/models/lm_helpers.py +198 -0
- lalamo/modules/decoder.py +4 -0
- lalamo/modules/token_mixers/mamba.py +345 -105
- lalamo/speculator/__init__.py +0 -2
- lalamo/speculator/inference.py +35 -61
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/METADATA +1 -1
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/RECORD +25 -23
- lalamo/speculator/estimator.py +0 -127
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/WHEEL +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,198 @@
+import functools
+import itertools
+import warnings
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.errors import JaxRuntimeError
+from jaxtyping import DTypeLike
+
+from lalamo.common import LalamoWarning, get_usable_memory_from_bytes
+from lalamo.models.common import InferenceConfig
+
+type TokenSequence = list[int] | np.ndarray | jnp.ndarray
+
+__all__ = [
+    "BatchSizeEstimatingEvent",
+    "decrease_batchsize_on_oom",
+    "estimate_batchsize_from_bytes",
+    "estimate_batchsizes_from_vram",
+    "merge_small_buckets",
+    "pad_keys_to_size",
+    "pad_sequences",
+]
+
+
+@dataclass(frozen=True)
+class BatchSizeEstimatingEvent:
+    lo: int
+    hi: int | None
+
+
+def _assert_sorted(values: list[int]) -> None:
+    assert all(values[i] <= values[i + 1] for i in range(len(values) - 1)), "expected sorted inputs"
+
+
+def merge_small_buckets[T: TokenSequence](
+    buckets: dict[int, list[tuple[int, T]]],
+    batch_size_for_length: dict[int, int],
+    min_batches: int = 4,
+) -> dict[int, list[tuple[int, T]]]:
+    # Merge buckets that are too small into larger buckets.
+    # Buckets smaller than min_batches * batch_size are merged into the next larger bucket.
+    # The last bucket absorbs all overflow.
+    sorted_lengths = sorted(buckets.keys())
+    merged: dict[int, list[tuple[int, T]]] = {}
+    overflow: list[tuple[int, T]] = []
+
+    for padded_len in sorted_lengths:
+        batch_size = batch_size_for_length.get(padded_len, 1)  # note how with i's increment batch_size decreases
+        items = overflow + buckets[padded_len]
+
+        if len(items) < min_batches * batch_size:
+            # the bucket is too small, push the items into a bigger one
+            overflow = items
+        else:
+            # the bucket is big enough, keep _all_ the items and move on
+            # keeping all the items avoids a funny problem with spill over into _very_ long ctx length
+            merged[padded_len] = items
+            overflow = []
+
+    if overflow:
+        # any leftover items go into the largest bucket
+        largest_len = sorted_lengths[-1]
+        merged.setdefault(largest_len, []).extend(overflow)
+
+    return merged
+
+
+def pad_sequences(
+    sequences: Iterable[TokenSequence],
+    shape: tuple[int, int],
+    *,
+    dtype: DTypeLike,
+    pad_value: int = 0,
+) -> jnp.ndarray:
+    batch_size, seq_len = shape
+    sequences_list = list(sequences)
+    if len(sequences_list) > batch_size:
+        raise ValueError(f"Expected at most {batch_size} sequences, got {len(sequences_list)}")
+
+    if len(sequences_list) < batch_size:
+        sequences_list.extend([jnp.array([pad_value], dtype=dtype)] * (batch_size - len(sequences_list)))
+
+    padded = np.full((batch_size, seq_len), pad_value, dtype=dtype)
+    for i, seq in enumerate(sequences_list):
+        seq_arr = np.asarray(seq, dtype=dtype)
+        if seq_arr.size > seq_len:
+            raise ValueError(f"Sequence length {seq_arr.size} exceeds target length {seq_len}")
+        padded[i, : seq_arr.size] = seq_arr
+
+    return jnp.array(padded)
+
+
+def pad_keys_to_size(keys: Iterable, size: int, *, seed: int = 0) -> jnp.ndarray:
+    keys_list = list(keys)
+    if len(keys_list) > size:
+        raise ValueError(f"Expected at most {size} keys, got {len(keys_list)}")
+    if len(keys_list) == size:
+        return jnp.array(keys_list)
+    dummy_keys = jax.random.split(jax.random.key(seed), size - len(keys_list))
+    return jnp.concatenate([jnp.array(keys_list), dummy_keys])
+
+
+def estimate_batchsizes_from_vram(
+    memory_consumption_callback: Callable[[InferenceConfig], int],
+    sorted_lengths: list[int],
+    vram_bytes: int,
+    inference_config: InferenceConfig,
+) -> dict[int, int]:
+    _assert_sorted(sorted_lengths)
+    assert len(sorted_lengths) > 0
+    usable_memory = get_usable_memory_from_bytes(vram_bytes)
+
+    def memory_consumption(bs: int, seq_len: int) -> int:
+        config = InferenceConfig(
+            max_output_length=inference_config.max_output_length,
+            padded_length=seq_len,
+            num_top_logits_to_return=inference_config.num_top_logits_to_return,
+            batch_size=bs,
+        )
+        return memory_consumption_callback(config)
+
+    result: dict[int, int] = {}
+
+    first_seq_len = sorted_lengths[0]
+    bs = estimate_batchsize_from_bytes(functools.partial(memory_consumption, seq_len=sorted_lengths[0]), usable_memory)
+    result[first_seq_len] = bs
+    for seq_len in sorted_lengths[1:]:
+        while bs > 1 and memory_consumption(bs, seq_len) > usable_memory:
+            bs = max(1, int(bs * 0.8))
+
+        result[seq_len] = bs
+
+    return result
+
+
+def estimate_batchsize_from_bytes(
+    memory_per_batchsize_callback: Callable[[int], int],
+    target_mem_bytes: int,
+    progress: Callable[[BatchSizeEstimatingEvent], None] | None = None,
+) -> int:
+    lo = 0
+    hi = 0
+    for candidate_exp in itertools.count():
+        lo = hi
+        hi = 4**candidate_exp
+
+        if progress is not None:
+            progress(BatchSizeEstimatingEvent(lo, None))
+        if target_mem_bytes < memory_per_batchsize_callback(hi):
+            break
+
+    while hi - lo > 1:
+        mid = (lo + hi) // 2
+
+        if progress is not None:
+            progress(BatchSizeEstimatingEvent(lo, hi))
+        if target_mem_bytes < memory_per_batchsize_callback(mid):
+            hi = mid
+        else:
+            lo = mid
+
+    return lo
+
+
+def decrease_batchsize_on_oom[T](
+    fn: Callable[[int], Iterable[T]],
+    starting_batch_size: int,
+) -> Iterable[T]:
+    first_batch_completed = False
+    effective_batch_size = starting_batch_size
+
+    while True:
+        try:
+            for result in fn(effective_batch_size):
+                yield result

+                # as soon as we yielded we are not allowed to retry anymore
+                # to make sure we don't ever miss/duplicate outputs
+                first_batch_completed = True
+            break
+        except JaxRuntimeError:
+            if first_batch_completed:
+                raise
+            # because OOM's sometimes generate stuff that won't be garbage collected,
+            # we need to be very aggressive with decreasing batchsize here
+            new_bs = max(int(0.7 * effective_batch_size - 1), 1)
+            if new_bs == 1 and effective_batch_size == 1:
+                raise
+            warnings.warn(
+                f"OOM detected. Reducing batch size {effective_batch_size} -> {new_bs}.",
+                LalamoWarning,
+                stacklevel=3,
+            )
+            effective_batch_size = new_bs
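
The helpers above are generic building blocks: estimate_batchsize_from_bytes runs an exponential-then-binary search for the largest batch size whose predicted memory footprint fits a byte budget, and decrease_batchsize_on_oom retries a batched generator with a smaller batch size whenever JAX raises a runtime (OOM) error, but only as long as nothing has been yielded yet. A minimal usage sketch, assuming the new module is importable as lalamo.models.lm_helpers; memory_for_batchsize and run_generation are hypothetical stand-ins, not part of the package:

# Hypothetical usage sketch: only the imported helpers come from this diff.
from lalamo.models.lm_helpers import (
    decrease_batchsize_on_oom,
    estimate_batchsize_from_bytes,
)


def memory_for_batchsize(batch_size: int) -> int:
    # Toy memory model: assume each batch element costs 512 MiB of VRAM.
    return batch_size * 512 * 1024**2


def run_generation(batch_size: int):
    # Stand-in for a real batched generation loop that may run out of memory.
    yield from (f"generated with batch size {batch_size}",)


# Largest batch size whose estimated footprint fits an 8 GiB budget.
batch_size = estimate_batchsize_from_bytes(memory_for_batchsize, 8 * 1024**3)

# Shrink the batch size and retry if an OOM happens before any output is yielded.
for output in decrease_batchsize_on_oom(run_generation, batch_size):
    print(output)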
lalamo/modules/decoder.py
CHANGED
@@ -121,6 +121,10 @@ class Decoder(LalamoModule[DecoderConfig]):
     embedding: EmbeddingBase
     transformer: Transformer

+    @property
+    def vocab_size(self) -> int:
+        return self.embedding.vocab_size
+
     @property
     def activation_precision(self) -> DTypeLike:
         return self.embedding.activation_precision