lalamo 0.6.1__tar.gz → 0.6.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100) hide show
  1. {lalamo-0.6.1 → lalamo-0.6.3}/PKG-INFO +1 -1
  2. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/__init__.py +6 -1
  3. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/speculator/estimator.py +29 -3
  4. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/speculator/inference.py +6 -1
  5. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/speculator/ngram.py +1 -1
  6. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo.egg-info/PKG-INFO +1 -1
  7. {lalamo-0.6.1 → lalamo-0.6.3}/LICENSE +0 -0
  8. {lalamo-0.6.1 → lalamo-0.6.3}/README.md +0 -0
  9. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/commands.py +0 -0
  10. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/common.py +0 -0
  11. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/data/__init__.py +0 -0
  12. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/data/huggingface_message.py +0 -0
  13. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/data/lalamo_completions.py +0 -0
  14. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/data/utils.py +0 -0
  15. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/main.py +0 -0
  16. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/message_processor.py +0 -0
  17. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/__init__.py +0 -0
  18. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/common.py +0 -0
  19. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  20. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/common.py +0 -0
  21. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  22. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  23. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  24. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  25. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
  26. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  27. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
  28. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  29. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  30. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  31. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  32. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  33. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  34. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/huggingface_generation_config.py +0 -0
  35. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  36. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/loaders/__init__.py +0 -0
  37. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/loaders/common.py +0 -0
  38. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/loaders/executorch.py +0 -0
  39. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/loaders/huggingface.py +0 -0
  40. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/loaders/utils.py +0 -0
  41. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/__init__.py +0 -0
  42. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/common.py +0 -0
  43. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/deepseek.py +0 -0
  44. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  45. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/gemma.py +0 -0
  46. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  47. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/huggingface.py +0 -0
  48. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/lfm2.py +0 -0
  49. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/llama.py +0 -0
  50. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/llamba.py +0 -0
  51. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/mirai.py +0 -0
  52. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/mistral.py +0 -0
  53. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/pleias.py +0 -0
  54. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/polaris.py +0 -0
  55. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/qwen.py +0 -0
  56. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/model_import/model_specs/reka.py +0 -0
  57. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/models/__init__.py +0 -0
  58. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/models/classifier.py +0 -0
  59. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/models/common.py +0 -0
  60. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/models/language_model.py +0 -0
  61. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/__init__.py +0 -0
  62. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/activations.py +0 -0
  63. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/classifier.py +0 -0
  64. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/common.py +0 -0
  65. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/decoder.py +0 -0
  66. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/embedding.py +0 -0
  67. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/linear.py +0 -0
  68. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/mlp.py +0 -0
  69. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/mlx_interop.py +0 -0
  70. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/normalization.py +0 -0
  71. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/rope.py +0 -0
  72. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/__init__.py +0 -0
  73. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/attention.py +0 -0
  74. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/common.py +0 -0
  75. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/mamba.py +0 -0
  76. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/short_conv.py +0 -0
  77. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  78. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/state/common.py +0 -0
  79. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  80. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  81. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
  82. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/torch_interop.py +0 -0
  83. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/transformer.py +0 -0
  84. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/transformer_layer.py +0 -0
  85. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/modules/utils.py +0 -0
  86. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/quantization.py +0 -0
  87. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/registry_abc.py +0 -0
  88. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/safetensors.py +0 -0
  89. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/sampling.py +0 -0
  90. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/speculator/__init__.py +0 -0
  91. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/speculator/common.py +0 -0
  92. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/speculator/utils.py +0 -0
  93. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo/utils.py +0 -0
  94. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo.egg-info/SOURCES.txt +0 -0
  95. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo.egg-info/dependency_links.txt +0 -0
  96. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo.egg-info/entry_points.txt +0 -0
  97. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo.egg-info/requires.txt +0 -0
  98. {lalamo-0.6.1 → lalamo-0.6.3}/lalamo.egg-info/top_level.txt +0 -0
  99. {lalamo-0.6.1 → lalamo-0.6.3}/pyproject.toml +0 -0
  100. {lalamo-0.6.1 → lalamo-0.6.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.6.1
3
+ Version: 0.6.3
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,3 +1,8 @@
1
+ import os
2
+
3
+ # Must run before importing jax / tensorflow, this hides the XLA optimization logs
4
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
5
+
1
6
  from lalamo.commands import (
2
7
  CollectTracesCallbacks,
3
8
  ConversionCallbacks,
@@ -27,7 +32,7 @@ from lalamo.speculator import (
27
32
  SpeculatorTrainingEvent,
28
33
  )
29
34
 
30
- __version__ = "0.6.1"
35
+ __version__ = "0.6.3"
31
36
 
32
37
  __all__ = [
33
38
  "AssistantMessage",
@@ -1,5 +1,6 @@
1
1
  import functools
2
2
  import itertools
3
+ import os
3
4
  from collections.abc import Callable
4
5
  from typing import NamedTuple
5
6
 
@@ -10,10 +11,35 @@ from lalamo.models import LanguageModel
10
11
 
11
12
 
12
13
  def get_default_device_memory() -> int | None:
14
+ dynamic_allocate = False
15
+
16
+ preallocate = os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE", "")
17
+ dynamic_allocate |= preallocate.strip().lower() in {"0", "false", "no", "off"}
18
+
19
+ allocator = os.getenv("XLA_PYTHON_CLIENT_ALLOCATOR", "")
20
+ dynamic_allocate |= allocator.strip().lower() in {"platform", "cuda_malloc_async"}
21
+
22
+ if dynamic_allocate:
23
+ return None
24
+
13
25
  memory_stats = jax.local_devices()[0].memory_stats()
14
26
  if memory_stats is None or "bytes_limit" not in memory_stats:
15
27
  return None
16
- return memory_stats["bytes_limit"]
28
+
29
+ # discount by 0.98 because if you try to allocate exactly bytes_limit
30
+ # jax will _try_ to allocate a bit more, and then throws an OOM
31
+ bytes_limit = memory_stats["bytes_limit"] * 0.98
32
+
33
+ mem_fraction_raw = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "")
34
+ try:
35
+ mem_fraction = float(mem_fraction_raw)
36
+ except ValueError:
37
+ mem_fraction = 0.75 # jax default https://docs.jax.dev/en/latest/gpu_memory_allocation.html
38
+
39
+ # JAX usually can't allocate more than 98%-ish percent memory; idk why
40
+ # Besides we use _some_ autotuning during runtime, so we add 0.96 safety margin here
41
+ bytes_limit = int(max(bytes_limit, bytes_limit / min(mem_fraction, 1.0)) * 0.96)
42
+ return bytes_limit
17
43
 
18
44
 
19
45
  def estimate_memory_from_batchsize(
@@ -30,14 +56,14 @@ def estimate_memory_from_batchsize(
30
56
  max_output_length=max_output_length,
31
57
  num_top_logits_to_return=num_logits_per_token,
32
58
  ),
33
- backend="cpu", # cuda backend tries to allocate in .compile() and ooms
34
59
  )
35
60
  .lower(
36
61
  model,
37
62
  prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
38
63
  prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
39
64
  )
40
- .compile()
65
+ # disables autotune, see https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level
66
+ .compile(compiler_options={"xla_gpu_autotune_level": "0"})
41
67
  .memory_analysis()
42
68
  )
43
69
 
@@ -40,7 +40,12 @@ def inference_collect_traces(
40
40
  prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
41
41
  prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
42
42
  )
43
- .compile()
43
+ # the autotune levels are (according to https://guides.lw1.at/all-xla-options/#--xla_gpu_autotune_level)
44
+ # 0 - no autotune, gpu shouldn't be touched
45
+ # 1 - basic level, gpu should be touched veeery little
46
+ # 2,3 - gpu touched more and more
47
+ # 4 (default) - gpu might allocate more memory than the run would require!
48
+ .compile(compiler_options={"xla_gpu_autotune_level": "2"})
44
49
  )
45
50
 
46
51
  prefixes = chain.from_iterable(map(get_prefixes_ending_in_user_message, conversations))
@@ -129,7 +129,7 @@ class NGramSpeculator(Speculator):
129
129
 
130
130
  return (
131
131
  memoryview(self.ngram_keys)[idx_start:idx_end],
132
- memoryview(self.ngram_values)[idx_start:idx_end].cast("f"), # noop cast to make typechecker happy
132
+ memoryview(self.ngram_values)[idx_start:idx_end].cast("c").cast("f"), # noop cast to make typechecker happy
133
133
  memoryview(self.ngram_counts)[seq_hash : (seq_hash + 1)],
134
134
  )
135
135
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.6.1
3
+ Version: 0.6.3
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes