lalamo 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py CHANGED
@@ -32,7 +32,7 @@ from lalamo.speculator import (
32
32
  SpeculatorTrainingEvent,
33
33
  )
34
34
 
35
- __version__ = "0.6.4"
35
+ __version__ = "0.6.5"
36
36
 
37
37
  __all__ = [
38
38
  "AssistantMessage",
lalamo/common.py CHANGED
@@ -1,9 +1,11 @@
1
+ import warnings
1
2
  from collections import defaultdict
2
- from collections.abc import Mapping, Sequence
3
+ from collections.abc import Callable, Iterable, Mapping, Sequence
3
4
  from typing import cast
4
5
 
5
6
  import jax.numpy as jnp
6
7
  from jax._src.api import ShapeDtypeStruct
8
+ from jax.errors import JaxRuntimeError
7
9
  from jaxtyping import Array, DTypeLike
8
10
 
9
11
  from lalamo.utils import MapDictValues, MapSequence
@@ -11,8 +13,10 @@ from lalamo.utils import MapDictValues, MapSequence
11
13
  __all__ = [
12
14
  "DEFAULT_PRECISION",
13
15
  "ArrayLike",
16
+ "LalamoWarning",
14
17
  "ParameterPath",
15
18
  "ParameterTree",
19
+ "decrease_batchsize_on_oom",
16
20
  "dummy_array",
17
21
  "flatten_parameters",
18
22
  "require_array",
@@ -23,6 +27,10 @@ __all__ = [
23
27
  DEFAULT_PRECISION: DTypeLike = jnp.bfloat16
24
28
 
25
29
 
30
class LalamoWarning(UserWarning):
    """Warning category used for warnings emitted by Lalamo itself.

    Deriving from ``UserWarning`` keeps the standard warning behaviour, while a
    dedicated category lets callers filter Lalamo's warnings on their own, e.g.
    ``warnings.simplefilter("ignore", LalamoWarning)``.
    """
32
+
33
+
26
34
  type ArrayLike = Array | ShapeDtypeStruct
27
35
 
28
36
 
@@ -121,3 +129,49 @@ class ParameterPath(str):
121
129
  if not self:
122
130
  return ParameterPath(str(other))
123
131
  return ParameterPath(self + "." + str(other))
132
+
133
+
134
+ def decrease_batchsize_on_oom[T](
135
+ fn: Callable[[int], Iterable[T]],
136
+ starting_batch_size: int,
137
+ ) -> Iterable[T]:
138
+ """
139
+ Execute fn(batch_size) with automatic batch size reduction on OOM.
140
+ Only reduces batch size if OOM happened on the first batch.
141
+
142
+ Args:
143
+ fn: Function that takes batch_size and returns an iterable
144
+ starting_batch_size: Initial batch size to try
145
+
146
+ Yields:
147
+ Results from fn(batch_size)
148
+
149
+ Raises:
150
+ JaxRuntimeError: If OOM occurs after first batch completes or at batch_size=1
151
+ """
152
+ first_batch_completed = False
153
+ effective_batch_size = starting_batch_size
154
+
155
+ while True:
156
+ try:
157
+ for result in fn(effective_batch_size):
158
+ yield result
159
+
160
+ # as soon as we yielded we are not allowed to retry anymore
161
+ # to make sure we don't ever miss/duplicate outputs
162
+ first_batch_completed = True
163
+ break
164
+ except JaxRuntimeError:
165
+ if first_batch_completed:
166
+ raise
167
+ # because OOM's sometimes generate stuff that won't be garbage collected,
168
+ # we need to be very aggressive with decreasing batchsize here
169
+ new_bs = max(int(0.7 * effective_batch_size - 1), 1)
170
+ if new_bs == 1 and effective_batch_size == 1:
171
+ raise
172
+ warnings.warn(
173
+ f"OOM detected. Reducing batch size {effective_batch_size} -> {new_bs}.",
174
+ LalamoWarning,
175
+ stacklevel=3,
176
+ )
177
+ effective_batch_size = new_bs
@@ -41,7 +41,7 @@ def get_default_device_bytes() -> int | None:
41
41
 
42
42
  def get_usable_memory_from_bytes(limit_bytes: int) -> int:
43
43
  # JAX allocates a bit more than it needs, so we discount it by some safety factor
44
- return int(limit_bytes * 0.95)
44
+ return int(limit_bytes * 0.93)
45
45
 
46
46
 
47
47
  def estimate_memory_from_batchsize(
@@ -1,11 +1,13 @@
1
1
  import functools
2
2
  from collections.abc import Callable, Iterable
3
- from itertools import batched, chain
3
+ from itertools import batched, chain, islice
4
4
  from typing import NamedTuple
5
5
 
6
6
  import jax
7
7
  import jax.numpy as jnp
8
+ from jax._src.stages import Compiled
8
9
 
10
+ from lalamo.common import decrease_batchsize_on_oom
9
11
  from lalamo.data.lalamo_completions import LalamoCompletion
10
12
  from lalamo.data.utils import get_prefixes_ending_in_user_message
11
13
  from lalamo.message_processor import Message
@@ -27,75 +29,84 @@ def inference_collect_traces(
27
29
  tokens_to_generate: int | None = None,
28
30
  progress_callback: Callable[[CollectTracesEvent], None] | None = None,
29
31
  ) -> Iterable[LalamoCompletion]:
30
- generate_tokens_compiled = (
31
- jax.jit(
32
- functools.partial(
33
- LanguageModel.generate_tokens,
34
- max_output_length=max_output_length,
35
- num_top_logits_to_return=num_top_logits_to_collect,
36
- ),
32
+ def make_generate_tokens_compiled(batch_size: int) -> Compiled:
33
+ return (
34
+ jax.jit(
35
+ functools.partial(
36
+ LanguageModel.generate_tokens,
37
+ max_output_length=max_output_length,
38
+ num_top_logits_to_return=num_top_logits_to_collect,
39
+ ),
40
+ )
41
+ .lower(
42
+ model,
43
+ prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
44
+ prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
45
+ )
46
+ # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
47
+ # 0 - no autotune, gpu shouldn't be touched
48
+ # 1 - basic level, gpu should be touched veeery little
49
+ # 2,3 - gpu touched more and more
50
+ # 4 (default) - gpu might allocate more memory than the run would require!
51
+ .compile(compiler_options={"xla_gpu_autotune_level": "0"})
37
52
  )
38
- .lower(
39
- model,
40
- prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
41
- prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
42
- )
43
- # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
44
- # 0 - no autotune, gpu shouldn't be touched
45
- # 1 - basic level, gpu should be touched veeery little
46
- # 2,3 - gpu touched more and more
47
- # 4 (default) - gpu might allocate more memory than the run would require!
48
- .compile(compiler_options={"xla_gpu_autotune_level": "2"})
49
- )
50
53
 
51
54
  prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
52
55
 
53
56
  tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
54
57
  filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
55
58
 
56
- tokens_generated, sequences_processed = 0, 0
59
+ test_batch = list(islice(filtered_prefixes, batch_size))
57
60
 
58
- for real_batch in batched(filtered_prefixes, n=batch_size):
59
- batch_padding = batch_size - len(real_batch)
60
- batch = (*real_batch, *(([0],) * batch_padding))
61
+ def collect_traces_body(batch_size: int) -> Iterable[LalamoCompletion]:
62
+ tokens_generated, sequences_processed = 0, 0
63
+ generate_tokens_compiled = make_generate_tokens_compiled(batch_size)
64
+ for real_batch in batched(chain(test_batch, filtered_prefixes), n=batch_size):
65
+ batch_padding = batch_size - len(real_batch)
66
+ batch = (*real_batch, *(([0],) * batch_padding))
61
67
 
62
- length_without_padding = jnp.array(list(map(len, batch)))
68
+ length_without_padding = jnp.array(list(map(len, batch)))
63
69
 
64
- padded = jnp.array(
65
- [jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0) for tokens in batch],
66
- )
70
+ padded = jnp.array(
71
+ [
72
+ jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0)
73
+ for tokens in batch
74
+ ],
75
+ )
67
76
 
68
- generated = generate_tokens_compiled(
69
- model,
70
- prompt_token_ids=padded,
71
- prompt_lengths_without_padding=length_without_padding,
72
- )
77
+ generated = generate_tokens_compiled(
78
+ model,
79
+ prompt_token_ids=padded,
80
+ prompt_lengths_without_padding=length_without_padding,
81
+ )
73
82
 
74
- assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
83
+ assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
75
84
 
76
- for conv_idx in range(len(real_batch)):
77
- token_ids = generated.token_ids[conv_idx].tolist()
78
- seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
79
- if tokens_to_generate is not None:
80
- seqlen = min(seqlen, tokens_to_generate - tokens_generated)
81
- tokens_generated += seqlen
82
- sequences_processed += 1
85
+ for conv_idx in range(len(real_batch)):
86
+ token_ids = generated.token_ids[conv_idx].tolist()
87
+ seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
88
+ if tokens_to_generate is not None:
89
+ seqlen = min(seqlen, tokens_to_generate - tokens_generated)
90
+ tokens_generated += seqlen
91
+ sequences_processed += 1
83
92
 
84
- token_ids = token_ids[:seqlen]
85
- token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
86
- token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
87
- token_logits = [
88
- dict(zip(keys, values, strict=True))
89
- for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
90
- ]
93
+ token_ids = token_ids[:seqlen]
94
+ token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
95
+ token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
96
+ token_logits = [
97
+ dict(zip(keys, values, strict=True))
98
+ for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
99
+ ]
91
100
 
92
- yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
101
+ yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
102
+
103
+ if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
104
+ break
105
+
106
+ if progress_callback is not None:
107
+ progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
93
108
 
94
109
  if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
95
110
  break
96
111
 
97
- if progress_callback is not None:
98
- progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
99
-
100
- if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
101
- break
112
+ yield from decrease_batchsize_on_oom(collect_traces_body, batch_size)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.6.4
3
+ Version: 0.6.5
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
- lalamo/__init__.py,sha256=RDkf5Hhglc-fLZ-CmI4R-th6UgJKYmN-1hdbCzTiVx8,1532
1
+ lalamo/__init__.py,sha256=RpKc5sKIQHI8tPVwzH7lIJJWE7tJy6FZauEhabEp2Hg,1532
2
2
  lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
3
- lalamo/common.py,sha256=WaNJx20eUX4CBF50aym9lniGAiX-SzBJzDzO5Jh6zXA,4312
3
+ lalamo/common.py,sha256=ddGIPlFCgo6Q683v8uP8G2dh8nsCJe9woZL8A_7_Rt4,6124
4
4
  lalamo/main.py,sha256=f1zHYQpX_OndAguOE0wqIOkzjzUolUC7w3_1ndtMC4Y,27655
5
5
  lalamo/message_processor.py,sha256=PMKte9YijT3h9N7DjTNp8H4V45A_qlDqJaubqFevLX8,5924
6
6
  lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
@@ -83,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
83
83
  lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
84
84
  lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
85
85
  lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
86
- lalamo/speculator/estimator.py,sha256=6T8NdmDdhvP0BPg7vdkB_pxAkfgpu4WktNpUHtFuyiE,3833
87
- lalamo/speculator/inference.py,sha256=uEv33Qqcpa2xqEKdIzmPzkAzRsZOlb8TPeEG6TP6fjo,4071
86
+ lalamo/speculator/estimator.py,sha256=WPG3rxKq4iLro8QwcePF766ageexHc17ANiF5rKAlKU,3833
87
+ lalamo/speculator/inference.py,sha256=47TUiLV0Dkk3dbf1-IkdlWbHCICFw6IDwKZ73FYQUQo,4802
88
88
  lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
89
89
  lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
90
- lalamo-0.6.4.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
91
- lalamo-0.6.4.dist-info/METADATA,sha256=oS1EAJBl3jBtvZU0Rd-UcjnL2Trngree7Syn2L16Rx8,3112
92
- lalamo-0.6.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
93
- lalamo-0.6.4.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
94
- lalamo-0.6.4.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
95
- lalamo-0.6.4.dist-info/RECORD,,
90
+ lalamo-0.6.5.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
91
+ lalamo-0.6.5.dist-info/METADATA,sha256=EWI8eHaPSj7tXrW7xW9BPNpeDRjboNNvtGbq3hRELzU,3112
92
+ lalamo-0.6.5.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
93
+ lalamo-0.6.5.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
94
+ lalamo-0.6.5.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
95
+ lalamo-0.6.5.dist-info/RECORD,,
File without changes