lalamo 0.5.8.tar.gz → 0.5.9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. {lalamo-0.5.8 → lalamo-0.5.9}/PKG-INFO +1 -1
  2. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/__init__.py +1 -1
  3. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/common.py +2 -0
  4. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +31 -9
  5. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/loaders/huggingface.py +1 -1
  6. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/__init__.py +2 -0
  7. lalamo-0.5.9/lalamo/model_import/model_specs/essential_ai.py +17 -0
  8. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/huggingface.py +1 -1
  9. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/utils.py +7 -0
  10. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo.egg-info/PKG-INFO +1 -1
  11. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo.egg-info/SOURCES.txt +1 -0
  12. {lalamo-0.5.8 → lalamo-0.5.9}/LICENSE +0 -0
  13. {lalamo-0.5.8 → lalamo-0.5.9}/README.md +0 -0
  14. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/common.py +0 -0
  15. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/data/__init__.py +0 -0
  16. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/data/huggingface_message.py +0 -0
  17. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/data/lalamo_completions.py +0 -0
  18. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/data/utils.py +0 -0
  19. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/main.py +0 -0
  20. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/message_processor.py +0 -0
  21. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/__init__.py +0 -0
  22. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  23. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/common.py +0 -0
  24. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  25. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  26. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  27. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  28. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  29. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/llama.py +0 -0
  30. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  31. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  32. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  33. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  34. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +0 -0
  35. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/huggingface_generation_config.py +0 -0
  36. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  37. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/loaders/__init__.py +0 -0
  38. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/loaders/common.py +0 -0
  39. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/loaders/executorch.py +0 -0
  40. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/loaders/utils.py +0 -0
  41. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/common.py +0 -0
  42. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/deepseek.py +0 -0
  43. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/gemma.py +0 -0
  44. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  45. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/llama.py +0 -0
  46. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/llamba.py +0 -0
  47. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/mirai.py +0 -0
  48. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/mistral.py +0 -0
  49. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/pleias.py +0 -0
  50. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/polaris.py +0 -0
  51. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/qwen.py +0 -0
  52. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/model_import/model_specs/reka.py +0 -0
  53. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/models/__init__.py +0 -0
  54. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/models/classifier.py +0 -0
  55. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/models/common.py +0 -0
  56. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/models/language_model.py +0 -0
  57. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/__init__.py +0 -0
  58. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/activations.py +0 -0
  59. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/classifier.py +0 -0
  60. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/common.py +0 -0
  61. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/decoder.py +0 -0
  62. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/embedding.py +0 -0
  63. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/linear.py +0 -0
  64. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/mlp.py +0 -0
  65. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/mlx_interop.py +0 -0
  66. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/normalization.py +0 -0
  67. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/rope.py +0 -0
  68. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/token_mixers/__init__.py +0 -0
  69. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/token_mixers/attention.py +0 -0
  70. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/token_mixers/common.py +0 -0
  71. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/token_mixers/mamba.py +0 -0
  72. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  73. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/token_mixers/state/common.py +0 -0
  74. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  75. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  76. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/torch_interop.py +0 -0
  77. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/transformer.py +0 -0
  78. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/transformer_layer.py +0 -0
  79. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/modules/utils.py +0 -0
  80. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/quantization.py +0 -0
  81. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/registry_abc.py +0 -0
  82. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/sampling.py +0 -0
  83. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/speculator/__init__.py +0 -0
  84. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/speculator/common.py +0 -0
  85. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/speculator/estimator.py +0 -0
  86. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/speculator/inference.py +0 -0
  87. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/speculator/ngram.py +0 -0
  88. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo/speculator/utils.py +0 -0
  89. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo.egg-info/dependency_links.txt +0 -0
  90. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo.egg-info/entry_points.txt +0 -0
  91. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo.egg-info/requires.txt +0 -0
  92. {lalamo-0.5.8 → lalamo-0.5.9}/lalamo.egg-info/top_level.txt +0 -0
  93. {lalamo-0.5.8 → lalamo-0.5.9}/pyproject.toml +0 -0
  94. {lalamo-0.5.8 → lalamo-0.5.9}/setup.cfg +0 -0
  95. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_cartesia_mlx_models.py +0 -0
  96. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_chat_template.py +0 -0
  97. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_generation.py +0 -0
  98. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_huggingface_model_conversion.py +0 -0
  99. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_huggingface_models.py +0 -0
  100. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_mlx_models.py +0 -0
  101. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_model_spec.py +0 -0
  102. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_models.py +0 -0
  103. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_moe.py +0 -0
  104. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_parameter_tree.py +0 -0
  105. {lalamo-0.5.8 → lalamo-0.5.9}/tests/test_registry_abc.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.5.8
3
+ Version: 0.5.9
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -15,7 +15,7 @@ from lalamo.speculator import (
15
15
  SpeculatorTrainingEvent,
16
16
  )
17
17
 
18
- __version__ = "0.5.8"
18
+ __version__ = "0.5.9"
19
19
 
20
20
  __all__ = [
21
21
  "AssistantMessage",
@@ -17,6 +17,7 @@ from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
17
17
  from lalamo.models import ClassifierModel, ClassifierModelConfig, GenerationConfig, LanguageModel, LanguageModelConfig
18
18
  from lalamo.modules import Classifier, Decoder, LalamoModule
19
19
  from lalamo.quantization import QuantizationMode
20
+ from lalamo.utils import process_chat_template
20
21
 
21
22
  from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
22
23
  from .huggingface_generation_config import HFGenerationConfig
@@ -154,6 +155,7 @@ def import_message_processor(
154
155
  if model_spec.configs.chat_template is not None:
155
156
  raise ValueError("Conflicting chat template specifications.")
156
157
  prompt_template = tokenizer_config.chat_template
158
+ prompt_template = process_chat_template(prompt_template)
157
159
  tokenizer = Tokenizer.from_file(str(tokenizer_file))
158
160
 
159
161
  added_tokens = tokenizer_config.added_tokens()
@@ -10,7 +10,7 @@ from lalamo.modules.activations import GELU
10
10
  from lalamo.modules.linear import FullPrecisionLinearConfig
11
11
  from lalamo.modules.mlp import DenseMLPConfig
12
12
  from lalamo.modules.normalization import NormalizationConfig, UpcastMode
13
- from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
13
+ from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig, YARNRoPEConfig
14
14
  from lalamo.modules.token_mixers.attention import AttentionConfig
15
15
  from lalamo.modules.transformer_layer import TransformerLayerConfig
16
16
 
@@ -19,9 +19,6 @@ from .common import HuggingFaceLMConfig
19
19
  __all__ = ["HFGemma3Config", "HFGemma3TextConfig"]
20
20
 
21
21
 
22
- NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER = 6
23
-
24
-
25
22
  def _round_to_bfloat16(x: float) -> float:
26
23
  return jnp.asarray(x).astype(jnp.bfloat16).item()
27
24
 
@@ -32,6 +29,16 @@ class GemmaRoPEScalingConfig:
32
29
  rope_type: Literal["linear"]
33
30
 
34
31
 
32
+ @dataclass(frozen=True)
33
+ class YarnRopeScalingConfig:
34
+ factor: float
35
+ beta_fast: float
36
+ beta_slow: float
37
+ original_max_position_embeddings: int
38
+ rope_type: Literal["yarn"]
39
+ truncate: bool = False
40
+
41
+
35
42
  @dataclass(frozen=True)
36
43
  class HFGemma3TextConfigRaw:
37
44
  hidden_size: int
@@ -39,6 +46,7 @@ class HFGemma3TextConfigRaw:
39
46
  model_type: Literal["gemma3_text"]
40
47
  num_hidden_layers: int
41
48
  sliding_window: int
49
+ sliding_window_pattern: int
42
50
  rms_norm_eps: float = 1e-06
43
51
  query_pre_attn_scalar: float = 256.0
44
52
  attention_bias: bool = False
@@ -49,7 +57,7 @@ class HFGemma3TextConfigRaw:
49
57
  max_position_embeddings: int = 131072
50
58
  rope_theta: float = 1000000.0
51
59
  rope_local_base_freq: float = 10000.0
52
- rope_scaling: GemmaRoPEScalingConfig | None = None
60
+ rope_scaling: GemmaRoPEScalingConfig | YarnRopeScalingConfig | None = None
53
61
  final_logit_softcapping: float | None = None
54
62
  vocab_size: int = 262208
55
63
 
@@ -57,7 +65,7 @@ class HFGemma3TextConfigRaw:
57
65
  def sliding_window_sizes(self) -> list[int | None]:
58
66
  result = []
59
67
  for i in range(self.num_hidden_layers):
60
- if (i + 1) % NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER == 0:
68
+ if (i + 1) % self.sliding_window_pattern == 0:
61
69
  result.append(None)
62
70
  else:
63
71
  result.append(self.sliding_window)
@@ -74,7 +82,7 @@ class HFGemma3TextConfigRaw:
74
82
  attention_scale = self.query_pre_attn_scalar**-0.5
75
83
  embedding_config = TiedEmbeddingConfig(
76
84
  input_scale=input_scale,
77
- logit_soft_cap=None,
85
+ logit_soft_cap=self.final_logit_softcapping,
78
86
  precision=activation_precision,
79
87
  )
80
88
  rms_norm_config = NormalizationConfig(
@@ -86,19 +94,33 @@ class HFGemma3TextConfigRaw:
86
94
  subtract_mean=False,
87
95
  )
88
96
 
89
- if self.rope_scaling is not None:
97
+ if isinstance(self.rope_scaling, GemmaRoPEScalingConfig):
90
98
  global_rope_config = LinearScalingRoPEConfig(
91
99
  precision=activation_precision,
92
100
  base=self.rope_theta,
93
101
  max_sequence_length=self.max_position_embeddings,
94
102
  scaling_factor=self.rope_scaling.factor,
95
103
  )
96
- else:
104
+ elif isinstance(self.rope_scaling, YarnRopeScalingConfig):
105
+ global_rope_config = YARNRoPEConfig(
106
+ precision=activation_precision,
107
+ base=self.rope_theta,
108
+ scaling_factor=self.rope_scaling.factor,
109
+ max_sequence_length=self.max_position_embeddings,
110
+ original_context_length=self.rope_scaling.original_max_position_embeddings,
111
+ beta_fast=self.rope_scaling.beta_fast,
112
+ beta_slow=self.rope_scaling.beta_slow,
113
+ truncate=self.rope_scaling.truncate,
114
+ )
115
+ elif self.rope_scaling is None:
97
116
  global_rope_config = UnscaledRoPEConfig(
98
117
  precision=activation_precision,
99
118
  base=self.rope_theta,
100
119
  max_sequence_length=context_length or self.max_position_embeddings,
101
120
  )
121
+ else:
122
+ raise ValueError("Invalid rope scaling configuration")
123
+
102
124
  local_rope_config = UnscaledRoPEConfig(
103
125
  precision=activation_precision,
104
126
  base=self.rope_local_base_freq,
@@ -300,7 +300,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
300
300
  down_w = rearrange(down_w, "e o ib ie -> e o (ib ie)")
301
301
  down_b = weights_dict[experts_path / "down_proj_bias"]
302
302
  if down_b.ndim == 1:
303
- down_b = jnp.broadcast_to(down_b, down_w.shape[:-1] + (down_b.shape[0],))
303
+ down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))
304
304
 
305
305
  down_projection = load_parameters(
306
306
  lambda m: (m.weights, m.biases), # type: ignore
@@ -1,5 +1,6 @@
1
1
  from .common import FileSpec, ModelSpec, ModelType, UseCase, build_quantized_models
2
2
  from .deepseek import DEEPSEEK_MODELS
3
+ from .essential_ai import RNJ_MODELS
3
4
  from .gemma import GEMMA_MODELS
4
5
  from .gpt_oss import GPT_OSS_MODELS
5
6
  from .huggingface import HUGGINGFACE_MODELS
@@ -36,6 +37,7 @@ ALL_MODEL_LISTS = [
36
37
  QWEN_MODELS,
37
38
  REKA_MODELS,
38
39
  MIRAI_CLASSIFIER_MODELS,
40
+ RNJ_MODELS,
39
41
  ]
40
42
 
41
43
  ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]
@@ -0,0 +1,17 @@
1
+ from lalamo.model_import.decoder_configs.huggingface import HFGemma3TextConfig
2
+
3
+ from .common import ModelSpec
4
+
5
+ __all__ = ["RNJ_MODELS"]
6
+
7
+ RNJ_MODELS = [
8
+ ModelSpec(
9
+ vendor="EssentialAI",
10
+ family="Rnj-1",
11
+ name="Rnj-1-Instruct",
12
+ size="8B",
13
+ quantization=None,
14
+ repo="EssentialAI/rnj-1-instruct",
15
+ config_type=HFGemma3TextConfig,
16
+ ),
17
+ ]
@@ -14,5 +14,5 @@ HUGGINGFACE_MODELS = [
14
14
  repo="HuggingFaceTB/SmolLM2-1.7B-Instruct",
15
15
  config_type=HFLlamaConfig,
16
16
  use_cases=tuple(),
17
- )
17
+ ),
18
18
  ]
@@ -24,6 +24,7 @@ __all__ = [
24
24
  "MapSequence",
25
25
  "jax_uint4_to_packed_uint8",
26
26
  "open_safetensors",
27
+ "process_chat_template",
27
28
  ]
28
29
 
29
30
 
@@ -159,3 +160,9 @@ def jax_uint8_to_unpacked_uint4(array: Array) -> Array:
159
160
  )
160
161
 
161
162
  return unpacked.astype(jnp.uint4)
163
+
164
+
165
+ def process_chat_template(template: str) -> str:
166
+ template = template.replace("{% generation %}", "")
167
+ template = template.replace("{%- endgeneration -%}", "")
168
+ return template
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.5.8
3
+ Version: 0.5.9
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -45,6 +45,7 @@ lalamo/model_import/loaders/utils.py
45
45
  lalamo/model_import/model_specs/__init__.py
46
46
  lalamo/model_import/model_specs/common.py
47
47
  lalamo/model_import/model_specs/deepseek.py
48
+ lalamo/model_import/model_specs/essential_ai.py
48
49
  lalamo/model_import/model_specs/gemma.py
49
50
  lalamo/model_import/model_specs/gpt_oss.py
50
51
  lalamo/model_import/model_specs/huggingface.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes