lalamo 0.5.8__py3-none-any.whl → 0.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lalamo/__init__.py CHANGED
@@ -15,7 +15,7 @@ from lalamo.speculator import (
15
15
  SpeculatorTrainingEvent,
16
16
  )
17
17
 
18
- __version__ = "0.5.8"
18
+ __version__ = "0.5.9"
19
19
 
20
20
  __all__ = [
21
21
  "AssistantMessage",
@@ -17,6 +17,7 @@ from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
17
17
  from lalamo.models import ClassifierModel, ClassifierModelConfig, GenerationConfig, LanguageModel, LanguageModelConfig
18
18
  from lalamo.modules import Classifier, Decoder, LalamoModule
19
19
  from lalamo.quantization import QuantizationMode
20
+ from lalamo.utils import process_chat_template
20
21
 
21
22
  from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
22
23
  from .huggingface_generation_config import HFGenerationConfig
@@ -154,6 +155,7 @@ def import_message_processor(
154
155
  if model_spec.configs.chat_template is not None:
155
156
  raise ValueError("Conflicting chat template specifications.")
156
157
  prompt_template = tokenizer_config.chat_template
158
+ prompt_template = process_chat_template(prompt_template)
157
159
  tokenizer = Tokenizer.from_file(str(tokenizer_file))
158
160
 
159
161
  added_tokens = tokenizer_config.added_tokens()
@@ -10,7 +10,7 @@ from lalamo.modules.activations import GELU
10
10
  from lalamo.modules.linear import FullPrecisionLinearConfig
11
11
  from lalamo.modules.mlp import DenseMLPConfig
12
12
  from lalamo.modules.normalization import NormalizationConfig, UpcastMode
13
- from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
13
+ from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig, YARNRoPEConfig
14
14
  from lalamo.modules.token_mixers.attention import AttentionConfig
15
15
  from lalamo.modules.transformer_layer import TransformerLayerConfig
16
16
 
@@ -19,9 +19,6 @@ from .common import HuggingFaceLMConfig
19
19
  __all__ = ["HFGemma3Config", "HFGemma3TextConfig"]
20
20
 
21
21
 
22
- NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER = 6
23
-
24
-
25
22
  def _round_to_bfloat16(x: float) -> float:
26
23
  return jnp.asarray(x).astype(jnp.bfloat16).item()
27
24
 
@@ -32,6 +29,16 @@ class GemmaRoPEScalingConfig:
32
29
  rope_type: Literal["linear"]
33
30
 
34
31
 
32
+ @dataclass(frozen=True)
33
+ class YarnRopeScalingConfig:
34
+ factor: float
35
+ beta_fast: float
36
+ beta_slow: float
37
+ original_max_position_embeddings: int
38
+ rope_type: Literal["yarn"]
39
+ truncate: bool = False
40
+
41
+
35
42
  @dataclass(frozen=True)
36
43
  class HFGemma3TextConfigRaw:
37
44
  hidden_size: int
@@ -39,6 +46,7 @@ class HFGemma3TextConfigRaw:
39
46
  model_type: Literal["gemma3_text"]
40
47
  num_hidden_layers: int
41
48
  sliding_window: int
49
+ sliding_window_pattern: int
42
50
  rms_norm_eps: float = 1e-06
43
51
  query_pre_attn_scalar: float = 256.0
44
52
  attention_bias: bool = False
@@ -49,7 +57,7 @@ class HFGemma3TextConfigRaw:
49
57
  max_position_embeddings: int = 131072
50
58
  rope_theta: float = 1000000.0
51
59
  rope_local_base_freq: float = 10000.0
52
- rope_scaling: GemmaRoPEScalingConfig | None = None
60
+ rope_scaling: GemmaRoPEScalingConfig | YarnRopeScalingConfig | None = None
53
61
  final_logit_softcapping: float | None = None
54
62
  vocab_size: int = 262208
55
63
 
@@ -57,7 +65,7 @@ class HFGemma3TextConfigRaw:
57
65
  def sliding_window_sizes(self) -> list[int | None]:
58
66
  result = []
59
67
  for i in range(self.num_hidden_layers):
60
- if (i + 1) % NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER == 0:
68
+ if (i + 1) % self.sliding_window_pattern == 0:
61
69
  result.append(None)
62
70
  else:
63
71
  result.append(self.sliding_window)
@@ -74,7 +82,7 @@ class HFGemma3TextConfigRaw:
74
82
  attention_scale = self.query_pre_attn_scalar**-0.5
75
83
  embedding_config = TiedEmbeddingConfig(
76
84
  input_scale=input_scale,
77
- logit_soft_cap=None,
85
+ logit_soft_cap=self.final_logit_softcapping,
78
86
  precision=activation_precision,
79
87
  )
80
88
  rms_norm_config = NormalizationConfig(
@@ -86,19 +94,33 @@ class HFGemma3TextConfigRaw:
86
94
  subtract_mean=False,
87
95
  )
88
96
 
89
- if self.rope_scaling is not None:
97
+ if isinstance(self.rope_scaling, GemmaRoPEScalingConfig):
90
98
  global_rope_config = LinearScalingRoPEConfig(
91
99
  precision=activation_precision,
92
100
  base=self.rope_theta,
93
101
  max_sequence_length=self.max_position_embeddings,
94
102
  scaling_factor=self.rope_scaling.factor,
95
103
  )
96
- else:
104
+ elif isinstance(self.rope_scaling, YarnRopeScalingConfig):
105
+ global_rope_config = YARNRoPEConfig(
106
+ precision=activation_precision,
107
+ base=self.rope_theta,
108
+ scaling_factor=self.rope_scaling.factor,
109
+ max_sequence_length=self.max_position_embeddings,
110
+ original_context_length=self.rope_scaling.original_max_position_embeddings,
111
+ beta_fast=self.rope_scaling.beta_fast,
112
+ beta_slow=self.rope_scaling.beta_slow,
113
+ truncate=self.rope_scaling.truncate,
114
+ )
115
+ elif self.rope_scaling is None:
97
116
  global_rope_config = UnscaledRoPEConfig(
98
117
  precision=activation_precision,
99
118
  base=self.rope_theta,
100
119
  max_sequence_length=context_length or self.max_position_embeddings,
101
120
  )
121
+ else:
122
+ raise ValueError("Invalid rope scaling configuration")
123
+
102
124
  local_rope_config = UnscaledRoPEConfig(
103
125
  precision=activation_precision,
104
126
  base=self.rope_local_base_freq,
@@ -300,7 +300,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
300
300
  down_w = rearrange(down_w, "e o ib ie -> e o (ib ie)")
301
301
  down_b = weights_dict[experts_path / "down_proj_bias"]
302
302
  if down_b.ndim == 1:
303
- down_b = jnp.broadcast_to(down_b, down_w.shape[:-1] + (down_b.shape[0],))
303
+ down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))
304
304
 
305
305
  down_projection = load_parameters(
306
306
  lambda m: (m.weights, m.biases), # type: ignore
@@ -1,5 +1,6 @@
1
1
  from .common import FileSpec, ModelSpec, ModelType, UseCase, build_quantized_models
2
2
  from .deepseek import DEEPSEEK_MODELS
3
+ from .essential_ai import RNJ_MODELS
3
4
  from .gemma import GEMMA_MODELS
4
5
  from .gpt_oss import GPT_OSS_MODELS
5
6
  from .huggingface import HUGGINGFACE_MODELS
@@ -36,6 +37,7 @@ ALL_MODEL_LISTS = [
36
37
  QWEN_MODELS,
37
38
  REKA_MODELS,
38
39
  MIRAI_CLASSIFIER_MODELS,
40
+ RNJ_MODELS,
39
41
  ]
40
42
 
41
43
  ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]
@@ -0,0 +1,17 @@
1
+ from lalamo.model_import.decoder_configs.huggingface import HFGemma3TextConfig
2
+
3
+ from .common import ModelSpec
4
+
5
+ __all__ = ["RNJ_MODELS"]
6
+
7
+ RNJ_MODELS = [
8
+ ModelSpec(
9
+ vendor="EssentialAI",
10
+ family="Rnj-1",
11
+ name="Rnj-1-Instruct",
12
+ size="8B",
13
+ quantization=None,
14
+ repo="EssentialAI/rnj-1-instruct",
15
+ config_type=HFGemma3TextConfig,
16
+ ),
17
+ ]
@@ -14,5 +14,5 @@ HUGGINGFACE_MODELS = [
14
14
  repo="HuggingFaceTB/SmolLM2-1.7B-Instruct",
15
15
  config_type=HFLlamaConfig,
16
16
  use_cases=tuple(),
17
- )
17
+ ),
18
18
  ]
lalamo/utils.py CHANGED
@@ -24,6 +24,7 @@ __all__ = [
24
24
  "MapSequence",
25
25
  "jax_uint4_to_packed_uint8",
26
26
  "open_safetensors",
27
+ "process_chat_template",
27
28
  ]
28
29
 
29
30
 
@@ -159,3 +160,9 @@ def jax_uint8_to_unpacked_uint4(array: Array) -> Array:
159
160
  )
160
161
 
161
162
  return unpacked.astype(jnp.uint4)
163
+
164
+
165
+ def process_chat_template(template: str) -> str:
166
+ template = template.replace("{% generation %}", "")
167
+ template = template.replace("{%- endgeneration -%}", "")
168
+ return template
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lalamo
3
- Version: 0.5.8
3
+ Version: 0.5.9
4
4
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
5
5
  Requires-Python: <4,>=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,17 +1,17 @@
1
- lalamo/__init__.py,sha256=ZJ5Cjq4OoGVrjba9zUYIYnFGRKZkCkhBLaakdt4D008,814
1
+ lalamo/__init__.py,sha256=ANgYnkcN0qtWyEPNfJb_rcAmghdwvBrHUKE2WNN0zn4,814
2
2
  lalamo/common.py,sha256=5NUFD26yQgOnEEk3LaQnce8n-VwJxILkEpFesHZhtQU,3820
3
3
  lalamo/main.py,sha256=GgUT7lT48-XQuAEH7qzsDKG8Lx9iBf-sYBIRhZL9q7E,23978
4
4
  lalamo/message_processor.py,sha256=bSUAQg7CemLTnBV4LtPxJBicAalruDCA-JXjkTYPZ8U,5797
5
5
  lalamo/quantization.py,sha256=8o6ryIZLzzDYQuvBTboPfaVVdfijAKGpTxOcg3GKVD8,2752
6
6
  lalamo/registry_abc.py,sha256=ENjXiD_wEH100fNjG-W5Em1L_EQ0Lf0pdRhRGvf3qZk,2197
7
7
  lalamo/sampling.py,sha256=g_dNiJyZrRqoQIiLid4cr6nRT9N5tSz3GtHr8Bt4n-E,3404
8
- lalamo/utils.py,sha256=9kg5P19eaqGrSyAiNSbdfOwrv4s1PJZTHYdiNctlBSY,4368
8
+ lalamo/utils.py,sha256=QwATVXAeHBsQEDyt_31SHgxFphFVZYHpv3ZaklXks9Y,4585
9
9
  lalamo/data/__init__.py,sha256=exfhBLxHrg7BWutM0tAln5QuIWlNQmOhaG2noFYxfPI,189
10
10
  lalamo/data/huggingface_message.py,sha256=-7lN9eIcETQzt1Pnx3d4d8p3_I7WYMNf4mp1P91N7fI,1115
11
11
  lalamo/data/lalamo_completions.py,sha256=U_m3UNSJASUFz3rJq_taZOtL_U4B8Oj-ndkTF-JH-v4,1509
12
12
  lalamo/data/utils.py,sha256=B96gLaULyStKYuR8wjFdTpFc6YIDC8EEvGh1eiMe_Ec,338
13
13
  lalamo/model_import/__init__.py,sha256=Z8pS9rbKKx1QgUy7KZtHxiNWlZhII3mdovT9d37vAxg,168
14
- lalamo/model_import/common.py,sha256=tdZsteRsxL6DVUFwHw_1eeNLckflOdAaIm7Wm9eJzxM,12311
14
+ lalamo/model_import/common.py,sha256=wvyGD-iLut_Pm3HjDMI05upqdtCW3HWeoeB0YmiFeqk,12419
15
15
  lalamo/model_import/huggingface_generation_config.py,sha256=mot6VQ6ezCtEhN6VjhnvaU-nR5P5T2BuBUgpFNnWJxU,1495
16
16
  lalamo/model_import/huggingface_tokenizer_config.py,sha256=xvwdmio7b9nhn2H3uMBVligiYj58JaCFCvHY3-8dBvM,2502
17
17
  lalamo/model_import/decoder_configs/__init__.py,sha256=1ZqMcEHvCJjMIZ9iNyY31XMXOaFxB-NbqIU01BtmcEk,641
@@ -20,7 +20,7 @@ lalamo/model_import/decoder_configs/executorch.py,sha256=fTEG_j-7d8riR3Fu_H5tHDj
20
20
  lalamo/model_import/decoder_configs/huggingface/__init__.py,sha256=3H7GPTFNNahEvI8D1SGg2mGBgPhsIdZ213MglwbGDlE,645
21
21
  lalamo/model_import/decoder_configs/huggingface/common.py,sha256=YYIDEQy8x7lqL2qtxUHrNqfjZEiizBZ_26sTqOzjRtQ,3792
22
22
  lalamo/model_import/decoder_configs/huggingface/gemma2.py,sha256=g8LH_GlSNyL04WWi596zI0rWsD3ahnfNjDk-9zZNcDE,4759
23
- lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=KlhL7y6lW_cUgsT2JjvlQbsuKZggI8DG5wazZZBk0zM,7415
23
+ lalamo/model_import/decoder_configs/huggingface/gemma3.py,sha256=aSZ0TtpgDYA10rHi8eD0C_Jsn48siM_HXqfZ4O7nh94,8372
24
24
  lalamo/model_import/decoder_configs/huggingface/gpt_oss.py,sha256=MBCoPbuWyzbJiBRtHOtpaPHJjQ1UVCAYcVrfIejTnlQ,7446
25
25
  lalamo/model_import/decoder_configs/huggingface/llama.py,sha256=UPeQiz2Dix8YaZYRxn9z44OZJ6c4xBQmcUZcM0Ymvh4,6934
26
26
  lalamo/model_import/decoder_configs/huggingface/llamba.py,sha256=ANB-vQK8U-zVFubZSTDXXt2S70T5SVOGzf7eOVvPzIQ,5773
@@ -31,14 +31,15 @@ lalamo/model_import/decoder_configs/huggingface/qwen3.py,sha256=lySVO-TvusAYUjDn
31
31
  lalamo/model_import/loaders/__init__.py,sha256=3THc1wQ4EPBzQkL_4EaKCa7Ev5Z7oczcvc4AHy9v5EI,228
32
32
  lalamo/model_import/loaders/common.py,sha256=kkugV-bMQlN1zvGHoj3uc7z0FbXKoMtXEBTvyu4KxK4,1844
33
33
  lalamo/model_import/loaders/executorch.py,sha256=t2Ey_mBMNC8bTSTdYWjuGXdPTRoohFlYrqtWyNkBU_8,9219
34
- lalamo/model_import/loaders/huggingface.py,sha256=ITA0Y_kCDFL4Tanuvd1NWUvV77WEn0VEzkcX5Whlwys,29835
34
+ lalamo/model_import/loaders/huggingface.py,sha256=QURyxD3C4Nzwa8k9iHVx32hQHV-aMWjb29W5_U99-WA,29834
35
35
  lalamo/model_import/loaders/utils.py,sha256=eiX3WKFRrAfBY-dugodscNInl5o5w3KmVcgma4atpGY,2456
36
- lalamo/model_import/model_specs/__init__.py,sha256=V7S5Uo3GVBUG7KD0czMtmWZcQ-FJgryTZlxC7Abn_c0,1175
36
+ lalamo/model_import/model_specs/__init__.py,sha256=8RxLEZUxpsBtTwrTUqGIwhQ-8QzOxUdx-EL__cbcTjg,1228
37
37
  lalamo/model_import/model_specs/common.py,sha256=RVPlNWHG_5OvU1W3YcOpqYz59Dh8plDmd7z1xNrqmaY,6585
38
38
  lalamo/model_import/model_specs/deepseek.py,sha256=Umef93_ZBuq93yYsejIRNwj3udoln1gHfrv3SK5jyMo,417
39
+ lalamo/model_import/model_specs/essential_ai.py,sha256=xbHcwRpAWhR9gOgypVzcgunFspoUEk3iNsw-46CVR4o,390
39
40
  lalamo/model_import/model_specs/gemma.py,sha256=irWgylL-pc7y3Gn5DK3fjKoCT9kJWH3B7mTa-1Gmxqc,1306
40
41
  lalamo/model_import/model_specs/gpt_oss.py,sha256=PLo0QGrXKdX61ReTRdyOaP_EH3Dmj5lp3fpJjZRwRVA,542
41
- lalamo/model_import/model_specs/huggingface.py,sha256=eF8ItF5reFrFkjYxwiAJcFwUAlN6CpXfM-aQ8a92ItM,430
42
+ lalamo/model_import/model_specs/huggingface.py,sha256=TEkU8y95_hmUWyF-Q5hn0dE2SvXbApghAsQwhWRu4D0,431
42
43
  lalamo/model_import/model_specs/llama.py,sha256=Ml-xvRGlXBT9NJhmEpwgNo6C84oBSMYgA1_PrCYGcAw,990
43
44
  lalamo/model_import/model_specs/llamba.py,sha256=Ic3sWTv34FLJ4fG6OR_Mc5goGJQR6fa5b2WbVXbn9FA,1471
44
45
  lalamo/model_import/model_specs/mirai.py,sha256=eifYVV5-fABiLH6rr82_DiVFtDyqpW0vbvXCYsQQzto,617
@@ -80,9 +81,9 @@ lalamo/speculator/estimator.py,sha256=4D8dPZCWsrpORb7y8pQ6VsiIg1Cblvvxe6gXCoYtcD
80
81
  lalamo/speculator/inference.py,sha256=5GntUgj0HQLeLn3HIHnVX8EEO0EBzmKeP5-_U7kdFAM,3670
81
82
  lalamo/speculator/ngram.py,sha256=95mdfAWhx4d5XOnOwhyhElnvcy6nlUjYhcbJzqDs414,5875
82
83
  lalamo/speculator/utils.py,sha256=0wZoMMIzzk0Q-3zq5H5f-JBplePNHxywndkrNtOJOyo,1697
83
- lalamo-0.5.8.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
84
- lalamo-0.5.8.dist-info/METADATA,sha256=miYVR0hj7X-d1X09Bwaqf9-zKUqmljZ2qrhkV1rLICQ,3146
85
- lalamo-0.5.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
86
- lalamo-0.5.8.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
87
- lalamo-0.5.8.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
88
- lalamo-0.5.8.dist-info/RECORD,,
84
+ lalamo-0.5.9.dist-info/licenses/LICENSE,sha256=diHRfjSEJHD1nnEeMIfMRCjR3UERf8bT3eseD6b1ayA,1072
85
+ lalamo-0.5.9.dist-info/METADATA,sha256=573oeEuYV14_hFpPmW2CNVZWciVS4_V85597oKOvjpo,3146
86
+ lalamo-0.5.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
87
+ lalamo-0.5.9.dist-info/entry_points.txt,sha256=qli7qTfnBk5WP10rOGXXEckHMtt-atJMDWd8jN89Uks,43
88
+ lalamo-0.5.9.dist-info/top_level.txt,sha256=VHvWL5JN5XRG36NsN_MieJ7EwRihEOrEjyDaTdFJ-aI,7
89
+ lalamo-0.5.9.dist-info/RECORD,,
File without changes