lalamo 0.6.2__tar.gz → 0.6.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100) hide show
  1. {lalamo-0.6.2 → lalamo-0.6.4}/PKG-INFO +1 -1
  2. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/__init__.py +6 -1
  3. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/main.py +18 -4
  4. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/speculator/estimator.py +32 -4
  5. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/speculator/inference.py +6 -1
  6. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo.egg-info/PKG-INFO +1 -1
  7. {lalamo-0.6.2 → lalamo-0.6.4}/LICENSE +0 -0
  8. {lalamo-0.6.2 → lalamo-0.6.4}/README.md +0 -0
  9. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/commands.py +0 -0
  10. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/common.py +0 -0
  11. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/data/__init__.py +0 -0
  12. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/data/huggingface_message.py +0 -0
  13. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/data/lalamo_completions.py +0 -0
  14. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/data/utils.py +0 -0
  15. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/message_processor.py +0 -0
  16. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/__init__.py +0 -0
  17. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/common.py +0 -0
  18. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  19. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/common.py +0 -0
  20. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  21. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  22. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  23. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  24. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
  25. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  26. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
  27. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  28. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  29. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  30. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  31. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  32. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  33. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/huggingface_generation_config.py +0 -0
  34. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  35. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/loaders/__init__.py +0 -0
  36. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/loaders/common.py +0 -0
  37. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/loaders/executorch.py +0 -0
  38. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/loaders/huggingface.py +0 -0
  39. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/loaders/utils.py +0 -0
  40. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/__init__.py +0 -0
  41. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/common.py +0 -0
  42. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/deepseek.py +0 -0
  43. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  44. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/gemma.py +0 -0
  45. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  46. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/huggingface.py +0 -0
  47. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/lfm2.py +0 -0
  48. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/llama.py +0 -0
  49. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/llamba.py +0 -0
  50. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/mirai.py +0 -0
  51. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/mistral.py +0 -0
  52. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/pleias.py +0 -0
  53. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/polaris.py +0 -0
  54. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/qwen.py +0 -0
  55. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/model_import/model_specs/reka.py +0 -0
  56. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/models/__init__.py +0 -0
  57. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/models/classifier.py +0 -0
  58. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/models/common.py +0 -0
  59. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/models/language_model.py +0 -0
  60. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/__init__.py +0 -0
  61. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/activations.py +0 -0
  62. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/classifier.py +0 -0
  63. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/common.py +0 -0
  64. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/decoder.py +0 -0
  65. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/embedding.py +0 -0
  66. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/linear.py +0 -0
  67. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/mlp.py +0 -0
  68. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/mlx_interop.py +0 -0
  69. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/normalization.py +0 -0
  70. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/rope.py +0 -0
  71. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/__init__.py +0 -0
  72. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/attention.py +0 -0
  73. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/common.py +0 -0
  74. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/mamba.py +0 -0
  75. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/short_conv.py +0 -0
  76. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  77. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/common.py +0 -0
  78. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  79. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  80. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
  81. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/torch_interop.py +0 -0
  82. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/transformer.py +0 -0
  83. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/transformer_layer.py +0 -0
  84. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/modules/utils.py +0 -0
  85. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/quantization.py +0 -0
  86. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/registry_abc.py +0 -0
  87. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/safetensors.py +0 -0
  88. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/sampling.py +0 -0
  89. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/speculator/__init__.py +0 -0
  90. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/speculator/common.py +0 -0
  91. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/speculator/ngram.py +0 -0
  92. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/speculator/utils.py +0 -0
  93. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo/utils.py +0 -0
  94. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo.egg-info/SOURCES.txt +0 -0
  95. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo.egg-info/dependency_links.txt +0 -0
  96. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo.egg-info/entry_points.txt +0 -0
  97. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo.egg-info/requires.txt +0 -0
  98. {lalamo-0.6.2 → lalamo-0.6.4}/lalamo.egg-info/top_level.txt +0 -0
  99. {lalamo-0.6.2 → lalamo-0.6.4}/pyproject.toml +0 -0
  100. {lalamo-0.6.2 → lalamo-0.6.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.6.2
3
+ Version: 0.6.4
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,3 +1,8 @@
1
+ import os
2
+
3
+ # Must run before importing jax / tensorflow; this hides the XLA optimization logs.
4
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
5
+
1
6
  from lalamo.commands import (
2
7
  CollectTracesCallbacks,
3
8
  ConversionCallbacks,
@@ -27,7 +32,7 @@ from lalamo.speculator import (
27
32
  SpeculatorTrainingEvent,
28
33
  )
29
34
 
30
- __version__ = "0.6.2"
35
+ __version__ = "0.6.4"
31
36
 
32
37
  __all__ = [
33
38
  "AssistantMessage",
@@ -49,7 +49,10 @@ from lalamo.message_processor import UserMessage
49
49
  from lalamo.model_import import REPO_TO_MODEL, ModelSpec
50
50
  from lalamo.model_import.common import FileSpec
51
51
  from lalamo.models import ClassifierModelConfig, LanguageModelConfig
52
- from lalamo.speculator.estimator import get_default_device_memory
52
+ from lalamo.speculator.estimator import (
53
+ get_default_device_bytes,
54
+ get_usable_memory_from_bytes,
55
+ )
53
56
  from lalamo.speculator.ngram import NGramSpeculator
54
57
  from lalamo.speculator.utils import test_speculator
55
58
 
@@ -384,6 +387,7 @@ class CliTraceCallbacks(TraceCallbacks):
384
387
  self.stack.close()
385
388
  console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")
386
389
 
390
+
387
391
  @app.command(help="Trace a model.")
388
392
  def trace(
389
393
  model_path: Annotated[
@@ -557,14 +561,24 @@ def estimate_batchsize(
557
561
  ] = None,
558
562
  ) -> None:
559
563
  if vram_gb is not None:
560
- mem = vram_gb * 1024 * 1024 * 1024
561
- elif (mem := get_default_device_memory()) is None:
564
+ # note: --vram-gb is read as decimal GB, while GPU vendors quote GiB in their docs, e.g. an "80GB" H100 actually has about 85 GB of memory
565
+ mem_bytes = vram_gb * 1000 * 1000 * 1000
566
+ elif (mem_bytes := get_default_device_bytes()) is None:
562
567
  err_console.print("Cannot get the default device's memory stats, use --vram-gb")
563
568
  raise Exit(1)
564
569
 
570
+ usable_mem = get_usable_memory_from_bytes(mem_bytes)
571
+
565
572
  callbacks_type = CliEstimateBatchsizeCallbacks
566
573
 
567
- _estimate_batchsize(model_path, mem, max_input_length, max_output_length, num_logits_per_token, callbacks_type)
574
+ _estimate_batchsize(
575
+ model_path,
576
+ usable_mem,
577
+ max_input_length,
578
+ max_output_length,
579
+ num_logits_per_token,
580
+ callbacks_type,
581
+ )
568
582
 
569
583
 
570
584
  @dataclass
@@ -1,5 +1,6 @@
1
1
  import functools
2
2
  import itertools
3
+ import os
3
4
  from collections.abc import Callable
4
5
  from typing import NamedTuple
5
6
 
@@ -9,11 +10,38 @@ import jax.numpy as jnp
9
10
  from lalamo.models import LanguageModel
10
11
 
11
12
 
12
- def get_default_device_memory() -> int | None:
13
+ def get_default_device_bytes() -> int | None:
14
+ dynamic_allocate = False
15
+
16
+ preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
17
+ dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
18
+
19
+ allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
20
+ dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
21
+
22
+ if dynamic_allocate:
23
+ return None
24
+
13
25
  memory_stats = jax.local_devices()[0].memory_stats()
14
26
  if memory_stats is None or "bytes_limit" not in memory_stats:
15
27
  return None
16
- return memory_stats["bytes_limit"]
28
+
29
+ mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
30
+ try:
31
+ mem_fraction = float(mem_fraction_raw)
32
+ except ValueError:
33
+ mem_fraction = 0.75 # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
34
+
35
+ # 500 MB seems to be the typically observed overhead; this tries to match the actual capacity of the GPU
36
+ # so it should correspond to something you'd see in nvidia-smi
37
+ memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
38
+
39
+ return get_usable_memory_from_bytes(memory_limit)
40
+
41
+
42
+ def get_usable_memory_from_bytes(limit_bytes: int) -> int:
43
+ # JAX allocates a bit more than it needs, so we discount it by some safety factor
44
+ return int(limit_bytes * 0.95)
17
45
 
18
46
 
19
47
  def estimate_memory_from_batchsize(
@@ -30,14 +58,14 @@ def estimate_memory_from_batchsize(
30
58
  max_output_length=max_output_length,
31
59
  num_top_logits_to_return=num_logits_per_token,
32
60
  ),
33
- backend="cpu", # cuda backend tries to allocate in .compile() and ooms
34
61
  )
35
62
  .lower(
36
63
  model,
37
64
  prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
38
65
  prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
39
66
  )
40
- .compile()
67
+ # disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
68
+ .compile(compiler_options={"xla_gpu_autotune_level": "0"})
41
69
  .memory_analysis()
42
70
  )
43
71
 
@@ -40,7 +40,12 @@ def inference_collect_traces(
40
40
  prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
41
41
  prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
42
42
  )
43
- .compile()
43
+ # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
44
+ # 0 - no autotune, gpu shouldn't be touched
45
+ # 1 - basic level, gpu should be touched very little
46
+ # 2,3 - gpu touched more and more
47
+ # 4 (default) - gpu might allocate more memory than the run would require!
48
+ .compile(compiler_options={"xla_gpu_autotune_level": "2"})
44
49
  )
45
50
 
46
51
  prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.6.2
3
+ Version: 0.6.4
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes