lalamo 0.6.3__tar.gz → 0.6.4__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry they were published to. It is provided for informational purposes only.
Files changed (100)
  1. {lalamo-0.6.3 → lalamo-0.6.4}/PKG-INFO +1 -1
  2. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/__init__.py +1 -1
  3. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/main.py +18 -4
  4. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/speculator/estimator.py +11 -9
  5. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo.egg-info/PKG-INFO +1 -1
  6. {lalamo-0.6.3 → lalamo-0.6.4}/LICENSE +0 -0
  7. {lalamo-0.6.3 → lalamo-0.6.4}/README.md +0 -0
  8. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/commands.py +0 -0
  9. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/common.py +0 -0
  10. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/data/__init__.py +0 -0
  11. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/data/huggingface_message.py +0 -0
  12. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/data/lalamo_completions.py +0 -0
  13. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/data/utils.py +0 -0
  14. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/message_processor.py +0 -0
  15. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/__init__.py +0 -0
  16. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/common.py +0 -0
  17. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  18. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/common.py +0 -0
  19. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  20. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  21. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  22. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  23. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
  24. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  25. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
  26. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  27. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  28. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  29. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  30. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  31. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  32. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/huggingface_generation_config.py +0 -0
  33. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  34. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/loaders/__init__.py +0 -0
  35. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/loaders/common.py +0 -0
  36. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/loaders/executorch.py +0 -0
  37. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/loaders/huggingface.py +0 -0
  38. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/loaders/utils.py +0 -0
  39. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/__init__.py +0 -0
  40. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/common.py +0 -0
  41. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/deepseek.py +0 -0
  42. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  43. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/gemma.py +0 -0
  44. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  45. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/huggingface.py +0 -0
  46. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/lfm2.py +0 -0
  47. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/llama.py +0 -0
  48. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/llamba.py +0 -0
  49. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/mirai.py +0 -0
  50. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/mistral.py +0 -0
  51. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/pleias.py +0 -0
  52. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/polaris.py +0 -0
  53. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/qwen.py +0 -0
  54. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/model_import/model_specs/reka.py +0 -0
  55. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/models/__init__.py +0 -0
  56. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/models/classifier.py +0 -0
  57. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/models/common.py +0 -0
  58. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/models/language_model.py +0 -0
  59. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/__init__.py +0 -0
  60. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/activations.py +0 -0
  61. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/classifier.py +0 -0
  62. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/common.py +0 -0
  63. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/decoder.py +0 -0
  64. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/embedding.py +0 -0
  65. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/linear.py +0 -0
  66. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/mlp.py +0 -0
  67. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/mlx_interop.py +0 -0
  68. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/normalization.py +0 -0
  69. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/rope.py +0 -0
  70. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/__init__.py +0 -0
  71. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/attention.py +0 -0
  72. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/common.py +0 -0
  73. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/mamba.py +0 -0
  74. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/short_conv.py +0 -0
  75. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  76. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/common.py +0 -0
  77. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  78. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  79. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
  80. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/torch_interop.py +0 -0
  81. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/transformer.py +0 -0
  82. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/transformer_layer.py +0 -0
  83. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/modules/utils.py +0 -0
  84. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/quantization.py +0 -0
  85. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/registry_abc.py +0 -0
  86. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/safetensors.py +0 -0
  87. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/sampling.py +0 -0
  88. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/speculator/__init__.py +0 -0
  89. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/speculator/common.py +0 -0
  90. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/speculator/inference.py +0 -0
  91. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/speculator/ngram.py +0 -0
  92. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/speculator/utils.py +0 -0
  93. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo/utils.py +0 -0
  94. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo.egg-info/SOURCES.txt +0 -0
  95. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo.egg-info/dependency_links.txt +0 -0
  96. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo.egg-info/entry_points.txt +0 -0
  97. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo.egg-info/requires.txt +0 -0
  98. {lalamo-0.6.3 → lalamo-0.6.4}/lalamo.egg-info/top_level.txt +0 -0
  99. {lalamo-0.6.3 → lalamo-0.6.4}/pyproject.toml +0 -0
  100. {lalamo-0.6.3 → lalamo-0.6.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.6.3
3
+ Version: 0.6.4
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -32,7 +32,7 @@ from lalamo.speculator import (
32
32
  SpeculatorTrainingEvent,
33
33
  )
34
34
 
35
- __version__ = "0.6.3"
35
+ __version__ = "0.6.4"
36
36
 
37
37
  __all__ = [
38
38
  "AssistantMessage",
@@ -49,7 +49,10 @@ from lalamo.message_processor import UserMessage
49
49
  from lalamo.model_import import REPO_TO_MODEL, ModelSpec
50
50
  from lalamo.model_import.common import FileSpec
51
51
  from lalamo.models import ClassifierModelConfig, LanguageModelConfig
52
- from lalamo.speculator.estimator import get_default_device_memory
52
+ from lalamo.speculator.estimator import (
53
+ get_default_device_bytes,
54
+ get_usable_memory_from_bytes,
55
+ )
53
56
  from lalamo.speculator.ngram import NGramSpeculator
54
57
  from lalamo.speculator.utils import test_speculator
55
58
 
@@ -384,6 +387,7 @@ class CliTraceCallbacks(TraceCallbacks):
384
387
  self.stack.close()
385
388
  console.print(f"💾 Trace saved to [cyan]{self.output_path}[/cyan]")
386
389
 
390
+
387
391
  @app.command(help="Trace a model.")
388
392
  def trace(
389
393
  model_path: Annotated[
@@ -557,14 +561,24 @@ def estimate_batchsize(
557
561
  ] = None,
558
562
  ) -> None:
559
563
  if vram_gb is not None:
560
- mem = vram_gb * 1024 * 1024 * 1024
561
- elif (mem := get_default_device_memory()) is None:
564
+ # note that in practice GPUs use GiB in their docs, e.g. H100 actually has 85GB of memory
565
+ mem_bytes = vram_gb * 1000 * 1000 * 1000
566
+ elif (mem_bytes := get_default_device_bytes()) is None:
562
567
  err_console.print("Cannot get the default device's memory stats, use --vram-gb")
563
568
  raise Exit(1)
564
569
 
570
+ usable_mem = get_usable_memory_from_bytes(mem_bytes)
571
+
565
572
  callbacks_type = CliEstimateBatchsizeCallbacks
566
573
 
567
- _estimate_batchsize(model_path, mem, max_input_length, max_output_length, num_logits_per_token, callbacks_type)
574
+ _estimate_batchsize(
575
+ model_path,
576
+ usable_mem,
577
+ max_input_length,
578
+ max_output_length,
579
+ num_logits_per_token,
580
+ callbacks_type,
581
+ )
568
582
 
569
583
 
570
584
  @dataclass
@@ -10,7 +10,7 @@ import jax.numpy as jnp
10
10
  from lalamo.models import LanguageModel
11
11
 
12
12
 
13
- def get_default_device_memory() -> int | None:
13
+ def get_default_device_bytes() -> int | None:
14
14
  dynamic_allocate = False
15
15
 
16
16
  preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
@@ -26,20 +26,22 @@ def get_default_device_memory() -> int | None:
26
26
  if memory_stats is None or "bytes_limit" not in memory_stats:
27
27
  return None
28
28
 
29
- # discount by 0.98 because if you try to allocate exactly bytes_limit
30
- # jax will _try_ to allocate a bit more, and then throws an OOM
31
- bytes_limit = memory_stats["bytes_limit"] * 0.98
32
-
33
29
  mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
34
30
  try:
35
31
  mem_fraction = float(mem_fraction_raw)
36
32
  except ValueError:
37
33
  mem_fraction = 0.75 # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
38
34
 
39
- # JAX usually can't allocate more than 98%-ish percent memory; idk why
40
- # Besides we use _some_ autotuning during runtime, so we add 0.96 safety margin here
41
- bytes_limit = int(max(bytes_limit, bytes_limit / min(mem_fraction, 1.0)) * 0.96)
42
- return bytes_limit
35
+ # 500mb is seemingly the usually observed overhead; this tries to match the actual capacity of the gpu
36
+ # so it should correspond to something you'd see in nvidia-smi
37
+ memory_limit = memory_stats["bytes_limit"] / min(mem_fraction, 1.0) + (500 * 1000 * 1000)
38
+
39
+ return get_usable_memory_from_bytes(memory_limit)
40
+
41
+
42
+ def get_usable_memory_from_bytes(limit_bytes: int) -> int:
43
+ # JAX allocates a bit more than it needs, so we discount it by some safety factor
44
+ return int(limit_bytes * 0.95)
43
45
 
44
46
 
45
47
  def estimate_memory_from_batchsize(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.6.3
3
+ Version: 0.6.4
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes