lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,198 @@
+import functools
+import itertools
+import warnings
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.errors import JaxRuntimeError
+from jaxtyping import DTypeLike
+
+from lalamo.common import LalamoWarning, get_usable_memory_from_bytes
+from lalamo.models.common import InferenceConfig
+
+type TokenSequence = list[int] | np.ndarray | jnp.ndarray
+
+__all__ = [
+    "BatchSizeEstimatingEvent",
+    "decrease_batchsize_on_oom",
+    "estimate_batchsize_from_bytes",
+    "estimate_batchsizes_from_vram",
+    "merge_small_buckets",
+    "pad_keys_to_size",
+    "pad_sequences",
+]
+
+
+@dataclass(frozen=True)
+class BatchSizeEstimatingEvent:
+    lo: int
+    hi: int | None
+
+
+def _assert_sorted(values: list[int]) -> None:
+    assert all(values[i] <= values[i + 1] for i in range(len(values) - 1)), "expected sorted inputs"
+
+
+def merge_small_buckets[T: TokenSequence](
+    buckets: dict[int, list[tuple[int, T]]],
+    batch_size_for_length: dict[int, int],
+    min_batches: int = 4,
+) -> dict[int, list[tuple[int, T]]]:
+    # Merge buckets that are too small into larger buckets.
+    # Buckets smaller than min_batches * batch_size are merged into the next larger bucket.
+    # The last bucket absorbs all overflow.
+    sorted_lengths = sorted(buckets.keys())
+    merged: dict[int, list[tuple[int, T]]] = {}
+    overflow: list[tuple[int, T]] = []
+
+    for padded_len in sorted_lengths:
+        batch_size = batch_size_for_length.get(padded_len, 1)  # longer padded lengths generally get smaller batch sizes
+        items = overflow + buckets[padded_len]
+
+        if len(items) < min_batches * batch_size:
+            # the bucket is too small; push its items into the next, larger one
+            overflow = items
+        else:
+            # the bucket is big enough: keep all of its items and move on
+            # keeping every item here avoids spilling short sequences into a much longer context bucket
+            merged[padded_len] = items
+            overflow = []
+
+    if overflow:
+        # any leftover items go into the largest bucket
+        largest_len = sorted_lengths[-1]
+        merged.setdefault(largest_len, []).extend(overflow)
+
+    return merged
+
+
+def pad_sequences(
+    sequences: Iterable[TokenSequence],
+    shape: tuple[int, int],
+    *,
+    dtype: DTypeLike,
+    pad_value: int = 0,
+) -> jnp.ndarray:
+    batch_size, seq_len = shape
+    sequences_list = list(sequences)
+    if len(sequences_list) > batch_size:
+        raise ValueError(f"Expected at most {batch_size} sequences, got {len(sequences_list)}")
+
+    if len(sequences_list) < batch_size:
+        sequences_list.extend([jnp.array([pad_value], dtype=dtype)] * (batch_size - len(sequences_list)))
+
+    padded = np.full((batch_size, seq_len), pad_value, dtype=dtype)
+    for i, seq in enumerate(sequences_list):
+        seq_arr = np.asarray(seq, dtype=dtype)
+        if seq_arr.size > seq_len:
+            raise ValueError(f"Sequence length {seq_arr.size} exceeds target length {seq_len}")
+        padded[i, : seq_arr.size] = seq_arr
+
+    return jnp.array(padded)
+
+
+def pad_keys_to_size(keys: Iterable, size: int, *, seed: int = 0) -> jnp.ndarray:
+    keys_list = list(keys)
+    if len(keys_list) > size:
+        raise ValueError(f"Expected at most {size} keys, got {len(keys_list)}")
+    if len(keys_list) == size:
+        return jnp.array(keys_list)
+    dummy_keys = jax.random.split(jax.random.key(seed), size - len(keys_list))
+    return jnp.concatenate([jnp.array(keys_list), dummy_keys])
+
+
+def estimate_batchsizes_from_vram(
+    memory_consumption_callback: Callable[[InferenceConfig], int],
+    sorted_lengths: list[int],
+    vram_bytes: int,
+    inference_config: InferenceConfig,
+) -> dict[int, int]:
+    _assert_sorted(sorted_lengths)
+    assert len(sorted_lengths) > 0
+    usable_memory = get_usable_memory_from_bytes(vram_bytes)
+
+    def memory_consumption(bs: int, seq_len: int) -> int:
+        config = InferenceConfig(
+            max_output_length=inference_config.max_output_length,
+            padded_length=seq_len,
+            num_top_logits_to_return=inference_config.num_top_logits_to_return,
+            batch_size=bs,
+        )
+        return memory_consumption_callback(config)
+
+    result: dict[int, int] = {}
+
+    first_seq_len = sorted_lengths[0]
+    bs = estimate_batchsize_from_bytes(functools.partial(memory_consumption, seq_len=sorted_lengths[0]), usable_memory)
+    result[first_seq_len] = bs
+    for seq_len in sorted_lengths[1:]:
+        while bs > 1 and memory_consumption(bs, seq_len) > usable_memory:
+            bs = max(1, int(bs * 0.8))
+
+        result[seq_len] = bs
+
+    return result
+
+
+def estimate_batchsize_from_bytes(
+    memory_per_batchsize_callback: Callable[[int], int],
+    target_mem_bytes: int,
+    progress: Callable[[BatchSizeEstimatingEvent], None] | None = None,
+) -> int:
+    lo = 0
+    hi = 0
+    for candidate_exp in itertools.count():
+        lo = hi
+        hi = 4**candidate_exp
+
+        if progress is not None:
+            progress(BatchSizeEstimatingEvent(lo, None))
+        if target_mem_bytes < memory_per_batchsize_callback(hi):
+            break
+
+    while hi - lo > 1:
+        mid = (lo + hi) // 2
+
+        if progress is not None:
+            progress(BatchSizeEstimatingEvent(lo, hi))
+        if target_mem_bytes < memory_per_batchsize_callback(mid):
+            hi = mid
+        else:
+            lo = mid
+
+    return lo
+
+
+def decrease_batchsize_on_oom[T](
+    fn: Callable[[int], Iterable[T]],
+    starting_batch_size: int,
+) -> Iterable[T]:
+    first_batch_completed = False
+    effective_batch_size = starting_batch_size
+
+    while True:
+        try:
+            for result in fn(effective_batch_size):
+                yield result
+
+                # once anything has been yielded we must not retry,
+                # otherwise outputs could be missed or duplicated
+                first_batch_completed = True
+            break
+        except JaxRuntimeError:
+            if first_batch_completed:
+                raise
+            # because OOMs can leave allocations behind that are not garbage collected,
+            # we decrease the batch size aggressively here
+            new_bs = max(int(0.7 * effective_batch_size - 1), 1)
+            if new_bs == 1 and effective_batch_size == 1:
+                raise
+            warnings.warn(
+                f"OOM detected. Reducing batch size {effective_batch_size} -> {new_bs}.",
+                LalamoWarning,
+                stacklevel=3,
+            )
+            effective_batch_size = new_bs
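
For context, here is a rough sketch of how these pieces could fit together: estimate a batch size against a memory budget, merge undersized length buckets, then pad each bucket into a dense array. The toy memory model, the token data, and the bucket/batch-size choices are illustrative assumptions, and the new module's import path is not visible in this diff.

```python
# Illustrative only; the memory model and data are made up.
# from lalamo.<new_module> import estimate_batchsize_from_bytes, merge_small_buckets, pad_sequences
import numpy as np


def toy_memory_model(batch_size: int) -> int:
    # Pretend each sequence at this padded length costs 32 MiB of activations/KV cache.
    return batch_size * 32 * 1024**2


usable_bytes = 8 * 1024**3  # pretend 8 GiB of VRAM are usable
batch_size = estimate_batchsize_from_bytes(toy_memory_model, usable_bytes)  # -> 256 for this toy model

# Bucket (index, tokens) pairs by padded length; buckets that would yield fewer
# than `min_batches` full batches are merged into the next larger bucket.
buckets = {
    32: [(0, np.arange(5)), (1, np.arange(7)), (2, np.arange(30))],
    256: [(3, np.arange(200)), (4, np.arange(190))],
}
merged = merge_small_buckets(
    buckets,
    batch_size_for_length={32: batch_size, 256: batch_size // 8},
    min_batches=1,
)
# With so few sequences both buckets are undersized, so everything ends up in the largest (256) bucket.

# Pad each surviving bucket into a dense (num_items, padded_len) token array.
for padded_len, items in merged.items():
    tokens = pad_sequences((seq for _, seq in items), shape=(len(items), padded_len), dtype=np.int32)
```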
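And a sketch of the OOM fallback wrapper: `fn` is expected to run the whole workload at a given batch size and yield per-item results; if a `JaxRuntimeError` is raised before anything has been yielded, the wrapper retries with a roughly 30% smaller batch size. The `run_all_batches` function below is a stand-in, not lalamo code.

```python
# Illustrative stand-in for a real batched-inference loop.
def run_all_batches(batch_size: int):
    inputs = list(range(10))
    for start in range(0, len(inputs), batch_size):
        chunk = inputs[start : start + batch_size]
        # A real implementation would pad `chunk` and call the model here;
        # a too-large batch_size would raise jax.errors.JaxRuntimeError on OOM.
        yield from (x * 2 for x in chunk)


# Starts at batch size 64; if the very first batch OOMs, retries at 43, 29, ...
# Once any result has been yielded, errors propagate instead of retrying.
outputs = list(decrease_batchsize_on_oom(run_all_batches, starting_batch_size=64))
```
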
lalamo/modules/decoder.py CHANGED
@@ -121,6 +121,10 @@ class Decoder(LalamoModule[DecoderConfig]):
     embedding: EmbeddingBase
     transformer: Transformer
 
+    @property
+    def vocab_size(self) -> int:
+        return self.embedding.vocab_size
+
     @property
     def activation_precision(self) -> DTypeLike:
         return self.embedding.activation_precision