lalamo 0.6.2__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py CHANGED
@@ -1,3 +1,8 @@
1
+ import os
2
+
3
+ # Must run before importing jax / tensorflow, this hides the XLA optimization logs
4
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
5
+
1
6
  from lalamo.commands import (
2
7
  CollectTracesCallbacks,
3
8
  ConversionCallbacks,
@@ -27,7 +32,7 @@ from lalamo.speculator import (
27
32
  SpeculatorTrainingEvent,
28
33
  )
29
34
 
30
- __version__ = "0.6.2"
35
+ __version__ = "0.6.3"
31
36
 
32
37
  __all__ = [
33
38
  "AssistantMessage",
@@ -1,5 +1,6 @@
1
1
  import functools
2
2
  import itertools
3
+ import os
3
4
  from collections.abc import Callable
4
5
  from typing import NamedTuple
5
6
 
@@ -10,10 +11,35 @@ from lalamo.models import LanguageModel
10
11
 
11
12
 
12
13
  def get_default_device_memory() -> int | None:
14
+ dynamic_allocate = False
15
+
16
+ preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
17
+ dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
18
+
19
+ allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
20
+ dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
21
+
22
+ if dynamic_allocate:
23
+ return None
24
+
13
25
  memory_stats = jax.local_devices()[0].memory_stats()
14
26
  if memory_stats is None or "bytes_limit" not in memory_stats:
15
27
  return None
16
- return memory_stats["bytes_limit"]
28
+
29
+ # discount by 0.98 because if you try to allocate exactly bytes_limit
30
+ # jax will _try_ to allocate a bit more, and then throws an OOM
31
+ bytes_limit = memory_stats["bytes_limit"] * 0.98
32
+
33
+ mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
34
+ try:
35
+ mem_fraction = float(mem_fraction_raw)
36
+ except ValueError:
37
+ mem_fraction = 0.75 # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
38
+
39
+ # JAX usually can't allocate more than 98%-ish percent memory; idk why
40
+ # Besides we use _some_ autotuning during runtime, so we add 0.96 safety margin here
41
+ bytes_limit = int(max(bytes_limit, bytes_limit / min(mem_fraction, 1.0)) * 0.96)
42
+ return bytes_limit
17
43
 
18
44
 
19
45
  def estimate_memory_from_batchsize(
@@ -30,14 +56,14 @@ def estimate_memory_from_batchsize(
30
56
  max_output_length=max_output_length,
31
57
  num_top_logits_to_return=num_logits_per_token,
32
58
  ),
33
- backend="cpu", # cuda backend tries to allocate in .compile() and ooms
34
59
  )
35
60
  .lower(
36
61
  model,
37
62
  prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
38
63
  prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
39
64
  )
40
- .compile()
65
+ # disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
66
+ .compile(compiler_options={"xla_gpu_autotune_level": "0"})
41
67
  .memory_analysis()
42
68
  )
43
69
 
@@ -40,7 +40,12 @@ def inference_collect_traces(
40
40
  prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
41
41
  prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
42
42
  )
43
- .compile()
43
+ # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
44
+ # 0 - no autotune, gpu shouldn't be touched
45
+ # 1 - basic level, gpu should be touched veeery little
46
+ # 2,3 - gpu touched more and more
47
+ # 4 (default) - gpu might allocate more memory than the run would require!
48
+ .compile(compiler_options={"xla_gpu_autotune_level": "2"})
44
49
  )
45
50
 
46
51
  prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.6.2
3
+ Version: 0.6.3
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,4 +1,4 @@
1
- lalamo/__init__.py,sha256=LlHkzLyEJUp3M_Zvcwqc2CTzLhhyqnHGqk2bs6hV3fY,1386
1
+ lalamo/__init__.py,sha256=MBRNUkooL0qiMz4B9VqvRblslw0VXAODjzg3BtY7uhU,1532
2
2
  lalamo/commands.py,sha256=zXyyrLTHhP9wouwtpX4RUZeEF6No-_9ee-y_GWGhw7k,10972
3
3
  lalamo/common.py,sha256=WaNJx20eUX4CBF50aym9lniGAiX-SzBJzDzO5Jh6zXA,4312
4
4
  lalamo/main.py,sha256=Tez84CtMxUi1ySuRSqQElu4Zr1UWs_Gw6HX1xtCZknQ,27383
@@ -83,13 +83,13 @@ lalamo/modules/token_mixers/state/mamba_state.py,sha256=LHzJvNE6MkB7nrsZSNto6pxb
83
83
  lalamo/modules/token_mixers/state/short_conv_state.py,sha256=osjcDHoeFWQaUoOROzeJe8F1qC8rvqunimGD4CuIDHo,895
84
84
  lalamo/speculator/__init__.py,sha256=9-tmZcbCom_lIGpJYn6xLlnEahFLFidpqmgkafmu--k,456
85
85
  lalamo/speculator/common.py,sha256=PudF_gkpe5_nQ-57sAC-foE1xCy_H2Axh5KwRoA86lo,587
86
- lalamo/speculator/estimator.py,sha256=S_TRwMnjWg5qt9le2AYua_Vmo6QkIT-0Si7TjCfC7xc,2670
87
- lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
86
+ lalamo/speculator/estimator.py,sha256=xkvrNsoIWUbotv__e2ms-QP3NUnj92W_Fcpse8rvQ_o,3797
87
+ lalamo/speculator/inference.py,sha256=uEv33Qqcpa2xqEKdIzmPzkAzRsZOlb8TPeEG6TP6fjo,4071
88
88
  lalamo/speculator/ngram.py,sha256=2eqInIieJPaQHCvIfnCIDtwMa8PGEtiND_NkG7plE34,5899
89
89
  lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
90
- lalamo-0.6.2.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
91
- lalamo-0.6.2.dist-info/METADATA,sha256=ZxR_Z-Q90tm45WUk2Wh1e_SpjKT0oW-FvkqmNXAqdvA,3112
92
- lalamo-0.6.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
93
- lalamo-0.6.2.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
94
- lalamo-0.6.2.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
95
- lalamo-0.6.2.dist-info/RECORD,,
90
+ lalamo-0.6.3.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
91
+ lalamo-0.6.3.dist-info/METADATA,sha256=tvl435CXfVSezsey_FhFzZ5YaY64GOjZQBaDY5vEAPU,3112
92
+ lalamo-0.6.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
93
+ lalamo-0.6.3.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
94
+ lalamo-0.6.3.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
95
+ lalamo-0.6.3.dist-info/RECORD,,
File without changes