lalamo 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py CHANGED
@@ -32,7 +32,7 @@ from lalamo.speculator import (
     SpeculatorTrainingEvent,
 )
 
-__version__ = "0.6.3"
+__version__ = "0.6.5"
 
 __all__ = [
     "AssistantMessage",
lalamo/common.py CHANGED
@@ -1,9 +1,11 @@
+import warnings
 from collections import defaultdict
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from typing import cast
 
 import jax.numpy as jnp
 from jax._src.api import ShapeDtypeStruct
+from jax.errors import JaxRuntimeError
 from jaxtyping import Array, DTypeLike
 
 from lalamo.utils import MapDictValues, MapSequence
@@ -11,8 +13,10 @@ from lalamo.utils import MapDictValues, MapSequence
 __all__ = [
     "DEFAULT_PRECISION",
     "ArrayLike",
+    "LalamoWarning",
     "ParameterPath",
     "ParameterTree",
+    "decrease_batchsize_on_oom",
     "dummy_array",
     "flatten_parameters",
     "require_array",
@@ -23,6 +27,10 @@ __all__ = [
 DEFAULT_PRECISION: DTypeLike = jnp.bfloat16
 
 
+class LalamoWarning(UserWarning):
+    """Custom warning class for Lalamo-specific warnings."""
+
+
 type ArrayLike = Array | ShapeDtypeStruct
 
 
@@ -121,3 +129,49 @@ class ParameterPath(str):
         if not self:
             return ParameterPath(str(other))
         return ParameterPath(self + "." + str(other))
+
+
+def decrease_batchsize_on_oom[T](
+    fn: Callable[[int], Iterable[T]],
+    starting_batch_size: int,
+) -> Iterable[T]:
+    """
+    Execute fn(batch_size) with automatic batch size reduction on OOM.
+    Only reduces the batch size if the OOM happened on the first batch.
+
+    Args:
+        fn: Function that takes batch_size and returns an iterable
+        starting_batch_size: Initial batch size to try
+
+    Yields:
+        Results from fn(batch_size)
+
+    Raises:
+        JaxRuntimeError: If OOM occurs after the first batch completes or at batch_size=1
+    """
+    first_batch_completed = False
+    effective_batch_size = starting_batch_size
+
+    while True:
+        try:
+            for result in fn(effective_batch_size):
+                yield result
+
+                # as soon as we have yielded, we are not allowed to retry anymore,
+                # to make sure we never miss or duplicate outputs
+                first_batch_completed = True
+            break
+        except JaxRuntimeError:
+            if first_batch_completed:
+                raise
+            # because OOMs sometimes leave allocations that won't be garbage collected,
+            # we need to be very aggressive about decreasing the batch size here
+            new_bs = max(int(0.7 * effective_batch_size - 1), 1)
+            if new_bs == 1 and effective_batch_size == 1:
+                raise
+            warnings.warn(
+                f"OOM detected. Reducing batch size {effective_batch_size} -> {new_bs}.",
+                LalamoWarning,
+                stacklevel=3,
+            )
+            effective_batch_size = new_bs
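
For illustration, a minimal usage sketch of the new helper (assumes lalamo 0.6.5 is installed; the `run_batches` worker here is hypothetical and stands in for any batched GPU job):

```python
from lalamo.common import decrease_batchsize_on_oom


# Hypothetical batched worker: processes `items` in chunks of `batch_size`
# and yields one result per item. On a real GPU, an oversized batch would
# raise JaxRuntimeError (OOM) from inside this iterator.
def run_batches(batch_size: int):
    items = list(range(10))
    for start in range(0, len(items), batch_size):
        chunk = items[start : start + batch_size]
        yield from (x * x for x in chunk)  # stand-in for model inference


# Tries fn(64) first; on a first-batch OOM it retries with the batch size
# shrunk to max(int(0.7 * 64 - 1), 1) = 43, then 29, and so on down to 1.
# Once anything has been yielded it re-raises instead of retrying, so
# outputs are never duplicated or dropped.
results = list(decrease_batchsize_on_oom(run_batches, starting_batch_size=64))
```
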
lalamo/main.py CHANGED
@@ -49,7 +49,10 @@ from lalamo.message_processor import UserMessage
 from lalamo.model_import import REPO_TO_MODEL, ModelSpec
 from lalamo.model_import.common import FileSpec
 from lalamo.models import ClassifierModelConfig, LanguageModelConfig
-from lalamo.speculator.estimator import get_default_device_memory
+from lalamo.speculator.estimator import (
+    get_default_device_bytes,
+    get_usable_memory_from_bytes,
+)
 from lalamo.speculator.ngram import NGramSpeculator
 from lalamo.speculator.utils import test_speculator
 
@@ -384,6 +387,7 @@ class CliTraceCallbacks(TraceCallbacks):
         self.stack.close()
         console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")
 
+
 @app.command(help="Trace a model.")
 def trace(
     model_path: Annotated[
@@ -557,14 +561,24 @@ def estimate_batchsize(
     ] = None,
 ) -> None:
     if vram_gb is not None:
+        # note that in practice GPU docs quote GiB, e.g. an "80GB" H100 actually has ~85GB of memory
-        mem = vram_gb * 1024 * 1024 * 1024
-    elif (mem := get_default_device_memory()) is None:
+        mem_bytes = vram_gb * 1000 * 1000 * 1000
+    elif (mem_bytes := get_default_device_bytes()) is None:
         err_console.print("Cannot get the default device's memory stats, use --vram-gb")
         raise Exit(1)
 
+    usable_mem = get_usable_memory_from_bytes(mem_bytes)
+
     callbacks_type = CliEstimateBatchsizeCallbacks
 
-    _estimate_batchsize(model_path, mem, max_input_length, max_output_length, num_logits_per_token, callbacks_type)
+    _estimate_batchsize(
+        model_path,
+        usable_mem,
+        max_input_length,
+        max_output_length,
+        num_logits_per_token,
+        callbacks_type,
+    )
 
 
 @dataclass
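
A quick worked example of the unit change (plain arithmetic, not lalamo API), for a GPU advertised as "80 GB":

```python
vram_gb = 80

# New interpretation of --vram-gb: decimal gigabytes, 1 GB = 10**9 bytes.
mem_bytes = vram_gb * 1000 * 1000 * 1000  # 80_000_000_000

# Old interpretation: binary gibibytes, 1 GiB = 2**30 bytes (~7.4% larger),
# matching the ~85GB that an "80 GB" H100 actually reports.
old_bytes = vram_gb * 1024 * 1024 * 1024  # 85_899_345_920

# Safety margin applied by get_usable_memory_from_bytes (see estimator.py below):
usable = int(mem_bytes * 0.93)  # 74_400_000_000
```

So the new code is deliberately conservative: a user passing the advertised figure now gets the smaller, decimal-GB budget.
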
lalamo/speculator/estimator.py CHANGED
@@ -10,7 +10,7 @@ import jax.numpy as jnp
 from lalamo.models import LanguageModel
 
 
-def get_default_device_memory() -> int | None:
+def get_default_device_bytes() -> int | None:
     dynamic_allocate = False
 
     preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
@@ -26,20 +26,22 @@ def get_default_device_memory() -> int | None:
     if memory_stats is None or "bytes_limit" not in memory_stats:
         return None
 
-    # discount by 0.98 because if you try to allocate exactly bytes_limit
-    # jax will _try_ to allocate a bit more, and then throws an OOM
-    bytes_limit = memory_stats["bytes_limit"] * 0.98
-
     mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
     try:
         mem_fraction = float(mem_fraction_raw)
     except ValueError:
         mem_fraction = 0.75  # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
 
-    # JAX usually can't allocate more than 98%-ish percent memory; idk why
-    # Besides we use _some_ autotuning during runtime, so we add 0.96 safety margin here
-    bytes_limit = int(max(bytes_limit, bytes_limit / min(mem_fraction, 1.0)) * 0.96)
-    return bytes_limit
+    # ~500MB is the typically observed overhead; this tries to match the actual capacity of the GPU,
+    # so it should correspond to what you'd see in nvidia-smi
+    memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
+
+    return get_usable_memory_from_bytes(memory_limit)
+
+
+def get_usable_memory_from_bytes(limit_bytes: int) -> int:
+    # JAX allocates a bit more than it needs, so we discount it by some safety factor
+    return int(limit_bytes * 0.93)
 
 
 def estimate_memory_from_batchsize(
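
For context, a minimal sketch of the statistics this function works from. It assumes a GPU backend where `Device.memory_stats()` returns a populated dict; on other platforms it may return None or omit keys, and the environment handling here is simplified relative to the code above:

```python
import os

import jax

stats = jax.devices()[0].memory_stats()
if stats is not None and "bytes_limit" in stats:
    # bytes_limit is the allocator's cap: physical memory already scaled by
    # XLA_PYTHON_CLIENT_MEM_FRACTION (JAX's default is 0.75).
    fraction = float(os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.75"))
    physical = stats["bytes_limit"] / min(fraction, 1.0) + 500 * 1000 * 1000
    usable = int(physical * 0.93)
    print(f"~{physical / 1e9:.1f}GB physical, ~{usable / 1e9:.1f}GB usable")
```
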
lalamo/speculator/inference.py CHANGED
@@ -1,11 +1,13 @@
 import functools
 from collections.abc import Callable, Iterable
-from itertools import batched, chain
+from itertools import batched, chain, islice
 from typing import NamedTuple
 
 import jax
 import jax.numpy as jnp
+from jax._src.stages import Compiled
 
+from lalamo.common import decrease_batchsize_on_oom
 from lalamo.data.lalamo_completions import LalamoCompletion
 from lalamo.data.utils import get_prefixes_ending_in_user_message
 from lalamo.message_processor import Message
@@ -27,75 +29,84 @@ def inference_collect_traces(
     tokens_to_generate: int | None = None,
     progress_callback: Callable[[CollectTracesEvent], None] | None = None,
 ) -> Iterable[LalamoCompletion]:
-    generate_tokens_compiled = (
-        jax.jit(
-            functools.partial(
-                LanguageModel.generate_tokens,
-                max_output_length=max_output_length,
-                num_top_logits_to_return=num_top_logits_to_collect,
-            ),
+    def make_generate_tokens_compiled(batch_size: int) -> Compiled:
+        return (
+            jax.jit(
+                functools.partial(
+                    LanguageModel.generate_tokens,
+                    max_output_length=max_output_length,
+                    num_top_logits_to_return=num_top_logits_to_collect,
+                ),
+            )
+            .lower(
+                model,
+                prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
+                prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
+            )
+            # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
+            # 0 - no autotune, gpu shouldn't be touched
+            # 1 - basic level, gpu should be touched veeery little
+            # 2,3 - gpu touched more and more
+            # 4 (default) - gpu might allocate more memory than the run would require!
+            .compile(compiler_options={"xla_gpu_autotune_level": "0"})
         )
-        .lower(
-            model,
-            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
-            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
-        )
-        # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
-        # 0 - no autotune, gpu shouldn't be touched
-        # 1 - basic level, gpu should be touched veeery little
-        # 2,3 - gpu touched more and more
-        # 4 (default) - gpu might allocate more memory than the run would require!
-        .compile(compiler_options={"xla_gpu_autotune_level": "2"})
-    )
 
     prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
 
     tokenized_prefixes = map(model.message_processor.tokenize_request, prefixes)
     filtered_prefixes = filter(lambda conv: len(conv) <= max_input_length, tokenized_prefixes)
 
-    tokens_generated, sequences_processed = 0, 0
+    test_batch = list(islice(filtered_prefixes, batch_size))
 
-    for real_batch in batched(filtered_prefixes, n=batch_size):
-        batch_padding = batch_size - len(real_batch)
-        batch = (*real_batch, *(([0],) * batch_padding))
+    def collect_traces_body(batch_size: int) -> Iterable[LalamoCompletion]:
+        tokens_generated, sequences_processed = 0, 0
+        generate_tokens_compiled = make_generate_tokens_compiled(batch_size)
+        for real_batch in batched(chain(test_batch, filtered_prefixes), n=batch_size):
+            batch_padding = batch_size - len(real_batch)
+            batch = (*real_batch, *(([0],) * batch_padding))
 
-        length_without_padding = jnp.array(list(map(len, batch)))
+            length_without_padding = jnp.array(list(map(len, batch)))
 
-        padded = jnp.array(
-            [jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0) for tokens in batch],
-        )
+            padded = jnp.array(
+                [
+                    jnp.pad(jnp.array(tokens), (0, max_input_length - len(tokens)), constant_values=0)
+                    for tokens in batch
+                ],
+            )
 
-        generated = generate_tokens_compiled(
-            model,
-            prompt_token_ids=padded,
-            prompt_lengths_without_padding=length_without_padding,
-        )
+            generated = generate_tokens_compiled(
+                model,
+                prompt_token_ids=padded,
+                prompt_lengths_without_padding=length_without_padding,
+            )
 
-        assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
+            assert generated.top_k_token_ids is not None and generated.top_k_token_logits is not None
 
-        for conv_idx in range(len(real_batch)):
-            token_ids = generated.token_ids[conv_idx].tolist()
-            seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
-            if tokens_to_generate is not None:
-                seqlen = min(seqlen, tokens_to_generate - tokens_generated)
-            tokens_generated += seqlen
-            sequences_processed += 1
+            for conv_idx in range(len(real_batch)):
+                token_ids = generated.token_ids[conv_idx].tolist()
+                seqlen = next((i + 1 for i, t in enumerate(token_ids) if t in model.stop_token_ids), len(token_ids))
+                if tokens_to_generate is not None:
+                    seqlen = min(seqlen, tokens_to_generate - tokens_generated)
+                tokens_generated += seqlen
+                sequences_processed += 1
 
-            token_ids = token_ids[:seqlen]
-            token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
-            token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
-            token_logits = [
-                dict(zip(keys, values, strict=True))
-                for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
-            ]
+                token_ids = token_ids[:seqlen]
+                token_logits_ids = generated.top_k_token_ids[conv_idx, : len(token_ids)].tolist()
+                token_logits_values = generated.top_k_token_logits[conv_idx, : len(token_ids)].tolist()
+                token_logits = [
+                    dict(zip(keys, values, strict=True))
+                    for keys, values in zip(token_logits_ids, token_logits_values, strict=True)
+                ]
 
-            yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
+                yield LalamoCompletion(batch[conv_idx], token_ids, token_logits)
+
+                if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
+                    break
+
+            if progress_callback is not None:
+                progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
 
             if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
                 break
 
-        if progress_callback is not None:
-            progress_callback(CollectTracesEvent(sequences_processed, tokens_generated))
-
-        if tokens_to_generate is not None and tokens_generated >= tokens_to_generate:
-            break
+    yield from decrease_batchsize_on_oom(collect_traces_body, batch_size)
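
The jit → lower → compile chain above is JAX's ahead-of-time compilation API. A standalone sketch of the same pattern on a toy function (assumes a GPU backend, where the `xla_gpu_autotune_level` option applies):

```python
import jax
import jax.numpy as jnp


def matmul(a, b):
    return a @ b


batch_size = 8

# Lower against ShapeDtypeStructs so no real arrays (and no device memory)
# are needed at trace time, then compile with autotuning disabled so that
# compilation itself cannot allocate more GPU memory than the run will use.
compiled = (
    jax.jit(matmul)
    .lower(
        jax.ShapeDtypeStruct((batch_size, 128), jnp.float32),
        jax.ShapeDtypeStruct((128, 64), jnp.float32),
    )
    .compile(compiler_options={"xla_gpu_autotune_level": "0"})
)

out = compiled(jnp.ones((batch_size, 128)), jnp.ones((128, 64)))
print(out.shape)  # (8, 64)
```

Note also why the first batch is materialized with `islice` into `test_batch`: `filtered_prefixes` is a lazy iterator, so the prefixes consumed by a failed first attempt would otherwise be lost; chaining `test_batch` back in front lets `decrease_batchsize_on_oom` replay them at the smaller batch size.
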
lalamo-0.6.3.dist-info/METADATA → lalamo-0.6.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.6.3
+Version: 0.6.5
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
lalamo-0.6.3.dist-info/RECORD → lalamo-0.6.5.dist-info/RECORD CHANGED
@@ -1,7 +1,7 @@
-lalamo/__init__.py,sha256=MBRNUkooL0qiMz4B9VqvRblslw0VXAODjzg3BtY7uhU,1532
+lalamo/__init__.py,sha256=RpKc5sKIQHI8tPVwzH7lIJJWE7tJy6FZauEhabEp2Hg,1532
 lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
-lalamo/common.py,sha256=WaNJx20eUX4CBF50aym9lniGAiX-SzBJzDzO5Jh6zXA,4312
-lalamo/main.py,sha256=Tez84CtMxUi1ySuRSqQElu4Zr1UWs_Gw6HX1xtCZknQ,27383
+lalamo/common.py,sha256=ddGIPlFCgo6Q683v8uP8G2dh8nsCJe9woZL8A_7_Rt4,6124
+lalamo/main.py,sha256=f1zHYQpX_OndAguOE0wqIOkzjzUolUC7w3_1ndtMC4Y,27655
 lalamo/message_processor.py,sha256=PMKte9YijT3h9N7DjTNp8H4V45A_qlDqJaubqFevLX8,5924
 lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
 lalamo/registry_abc.py,sha256=qTikqviqqeseNzkjqoyQvL4dEWJYWzN0rI05T-JNTmo,2187
@@ -83,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
 lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
 lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
 lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
-lalamo/speculator/estimator.py,sha256=xkvrNsoIWUbotv__e2ms-QP3NUnj92W_Fcpse8rvQ_o,3797
-lalamo/speculator/inference.py,sha256=uEv33Qqcpa2xqEKdIzmPzkAzRsZOlb8TPeEG6TP6fjo,4071
+lalamo/speculator/estimator.py,sha256=WPG3rxKq4iLro8QwcePF766ageexHc17ANiF5rKAlKU,3833
+lalamo/speculator/inference.py,sha256=47TUiLV0Dkk3dbf1-IkdlWbHCICFw6IDwKZ73FYQUQo,4802
 lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
 lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
-lalamo-0.6.3.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
-lalamo-0.6.3.dist-info/METADATA,sha256=tvl435CXfVSezsey_FhFzZ5YaY64GOjZQBaDY5vEAPU,3112
-lalamo-0.6.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
-lalamo-0.6.3.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
-lalamo-0.6.3.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
-lalamo-0.6.3.dist-info/RECORD,,
+lalamo-0.6.5.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
+lalamo-0.6.5.dist-info/METADATA,sha256=EWI8eHaPSj7tXrW7xW9BPNpeDRjboNNvtGbq3hRELzU,3112
+lalamo-0.6.5.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+lalamo-0.6.5.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
+lalamo-0.6.5.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
+lalamo-0.6.5.dist-info/RECORD,,