lalamo 0.4.1__tar.gz → 0.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. {lalamo-0.4.1 → lalamo-0.5.1}/PKG-INFO +3 -2
  2. {lalamo-0.4.1 → lalamo-0.5.1}/README.md +2 -1
  3. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/__init__.py +1 -1
  4. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/language_model.py +22 -23
  5. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/main.py +2 -16
  6. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/common.py +24 -6
  7. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/__init__.py +2 -0
  8. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/common.py +4 -4
  9. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/executorch.py +17 -10
  10. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  11. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/common.py +37 -2
  12. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +33 -28
  13. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +33 -26
  14. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +36 -29
  15. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/llama.py +14 -12
  16. lalamo-0.5.1/lalamo/model_import/decoder_configs/huggingface/llamba.py +170 -0
  17. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/mistral.py +31 -30
  18. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +33 -25
  19. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +55 -28
  20. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/loaders/executorch.py +5 -4
  21. lalamo-0.5.1/lalamo/model_import/loaders/huggingface.py +653 -0
  22. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/__init__.py +2 -0
  23. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/common.py +16 -5
  24. lalamo-0.5.1/lalamo/model_import/model_specs/llamba.py +40 -0
  25. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/qwen.py +29 -1
  26. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/__init__.py +33 -6
  27. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/activations.py +9 -2
  28. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/common.py +10 -5
  29. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/decoder.py +93 -97
  30. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/decoder_layer.py +85 -103
  31. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/embedding.py +279 -5
  32. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/linear.py +335 -30
  33. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/mlp.py +6 -7
  34. lalamo-0.5.1/lalamo/modules/mlx_interop.py +19 -0
  35. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/rope.py +1 -1
  36. lalamo-0.5.1/lalamo/modules/token_mixers/__init__.py +30 -0
  37. {lalamo-0.4.1/lalamo/modules → lalamo-0.5.1/lalamo/modules/token_mixers}/attention.py +72 -70
  38. lalamo-0.5.1/lalamo/modules/token_mixers/common.py +78 -0
  39. lalamo-0.5.1/lalamo/modules/token_mixers/mamba.py +553 -0
  40. lalamo-0.5.1/lalamo/modules/token_mixers/state/__init__.py +12 -0
  41. lalamo-0.5.1/lalamo/modules/token_mixers/state/common.py +26 -0
  42. {lalamo-0.4.1/lalamo/modules → lalamo-0.5.1/lalamo/modules/token_mixers/state}/kv_cache.py +5 -16
  43. lalamo-0.5.1/lalamo/modules/token_mixers/state/mamba_state.py +51 -0
  44. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/utils.py +24 -2
  45. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo.egg-info/PKG-INFO +3 -2
  46. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo.egg-info/SOURCES.txt +13 -2
  47. {lalamo-0.4.1 → lalamo-0.5.1}/pyproject.toml +1 -1
  48. {lalamo-0.4.1 → lalamo-0.5.1}/tests/test_generation.py +4 -4
  49. lalamo-0.5.1/tests/test_huggingface_models.py +24 -0
  50. lalamo-0.5.1/tests/test_mlx_models.py +20 -0
  51. lalamo-0.5.1/tests/test_models.py +456 -0
  52. lalamo-0.4.1/lalamo/model_import/loaders/huggingface.py +0 -401
  53. lalamo-0.4.1/tests/test_huggingface_models.py +0 -87
  54. {lalamo-0.4.1 → lalamo-0.5.1}/LICENSE +0 -0
  55. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/common.py +0 -0
  56. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/data/__init__.py +0 -0
  57. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/data/huggingface_message.py +0 -0
  58. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/data/lalamo_completions.py +0 -0
  59. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/data/utils.py +0 -0
  60. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/message_processor.py +0 -0
  61. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/__init__.py +0 -0
  62. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/huggingface_generation_config.py +0 -0
  63. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  64. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/loaders/__init__.py +0 -0
  65. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/loaders/common.py +0 -0
  66. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/loaders/utils.py +0 -0
  67. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/deepseek.py +0 -0
  68. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/gemma.py +0 -0
  69. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  70. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/huggingface.py +0 -0
  71. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/llama.py +0 -0
  72. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/mistral.py +0 -0
  73. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/pleias.py +0 -0
  74. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/polaris.py +0 -0
  75. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/model_import/model_specs/reka.py +0 -0
  76. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/normalization.py +0 -0
  77. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/torch_interop.py +0 -0
  78. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/modules/utils.py +0 -0
  79. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/quantization.py +0 -0
  80. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/registry_abc.py +0 -0
  81. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/sampling.py +0 -0
  82. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/speculator/__init__.py +0 -0
  83. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/speculator/common.py +0 -0
  84. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/speculator/inference.py +0 -0
  85. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/speculator/ngram.py +0 -0
  86. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo/speculator/utils.py +0 -0
  87. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo.egg-info/dependency_links.txt +0 -0
  88. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo.egg-info/entry_points.txt +0 -0
  89. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo.egg-info/requires.txt +0 -0
  90. {lalamo-0.4.1 → lalamo-0.5.1}/lalamo.egg-info/top_level.txt +0 -0
  91. {lalamo-0.4.1 → lalamo-0.5.1}/setup.cfg +0 -0
  92. {lalamo-0.4.1 → lalamo-0.5.1}/tests/test_model_spec.py +0 -0
  93. {lalamo-0.4.1 → lalamo-0.5.1}/tests/test_moe.py +0 -0
  94. {lalamo-0.4.1 → lalamo-0.5.1}/tests/test_parameter_tree.py +0 -0
  95. {lalamo-0.4.1 → lalamo-0.5.1}/tests/test_registry_abc.py +0 -0
--- lalamo-0.4.1/PKG-INFO
+++ lalamo-0.5.1/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.4.1
+Version: 0.5.1
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown
@@ -38,7 +38,8 @@ Dynamic: license-file
 
 <a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
 <a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
-<a href="mailto:alexey@getmirai.co,dima@getmirai.co,aleksei@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
+<a href="https://discord.com/invite/trymirai"><img src="https://img.shields.io/discord/1377764166764462120?label=Discord" alt="Discord"></a>
+<a href="mailto:contact@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
 <a href="https://docs.trymirai.com/overview/lalamo"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
 [![License](https://img.shields.io/badge/License-MIT-blue)](LICENSE)
 
--- lalamo-0.4.1/README.md
+++ lalamo-0.5.1/README.md
@@ -6,7 +6,8 @@
 
 <a href="https://artifacts.trymirai.com/social/about_us.mp3"><img src="https://img.shields.io/badge/Listen-Podcast-red" alt="Listen to our podcast"></a>
 <a href="https://docsend.com/v/76bpr/mirai2025"><img src="https://img.shields.io/badge/View-Deck-red" alt="View our deck"></a>
-<a href="mailto:alexey@getmirai.co,dima@getmirai.co,aleksei@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
+<a href="https://discord.com/invite/trymirai"><img src="https://img.shields.io/discord/1377764166764462120?label=Discord" alt="Discord"></a>
+<a href="mailto:contact@getmirai.co?subject=Interested%20in%20Mirai"><img src="https://img.shields.io/badge/Send-Email-green" alt="Contact us"></a>
 <a href="https://docs.trymirai.com/overview/lalamo"><img src="https://img.shields.io/badge/Read-Docs-blue" alt="Read docs"></a>
 [![License](https://img.shields.io/badge/License-MIT-blue)](LICENSE)
 
--- lalamo-0.4.1/lalamo/__init__.py
+++ lalamo-0.5.1/lalamo/__init__.py
@@ -10,7 +10,7 @@ from lalamo.message_processor import (
 )
 from lalamo.model_import import ModelSpec, import_model
 
-__version__ = "0.4.1"
+__version__ = "0.5.1"
 
 __all__ = [
     "AssistantMessage",
--- lalamo-0.4.1/lalamo/language_model.py
+++ lalamo-0.5.1/lalamo/language_model.py
@@ -14,8 +14,7 @@ from tokenizers import Tokenizer
 
 from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
 from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
-from lalamo.modules import Decoder, DecoderConfig, KVCache, LalamoModule, config_converter
-from lalamo.modules.common import ForwardPassMode
+from lalamo.modules import Decoder, DecoderConfig, ForwardPassMode, LalamoModule, State, config_converter
 from lalamo.modules.decoder import DecoderForwardPassConfig
 from lalamo.sampling import SamplingPolicy, make_policy
 from lalamo.utils import open_safetensors
@@ -37,13 +36,13 @@ type ForwardPassConfig = DecoderForwardPassConfig
 class PrefillResults(NamedTuple):
     last_token_logits: Float[Array, "batch vocabulary"]
     last_token_indices: Int[Array, " batch"]
-    kv_cache: KVCache
+    state: State
 
 
 class DecodingState(NamedTuple):
     last_token_logits: Float[Array, "batch vocabulary"]
     last_token_indices: Int[Array, " batch"]
-    kv_cache: KVCache
+    state: State
     stop_flags: Bool[Array, " batch"]
 
 
@@ -89,7 +88,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         with open(path / "config.json") as config_file:
             config_json = json.load(config_file)
         config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
-        with open_safetensors(path / "model.safetensors") as weights_dict:
+        with open_safetensors(path / "model.safetensors") as (weights_dict, _):
             weights = unflatten_parameters(weights_dict)
             decoder = config.decoder_config.empty().import_weights(weights)
         tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
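
Note on the `open_safetensors` change above: the context manager now yields a `(weights, metadata)` pair instead of a bare weights mapping, so call sites that only need tensors unpack and discard the second element. A minimal sketch of the new calling convention; reading the second element as the safetensors header metadata (a `str -> str` mapping) and the key names below are assumptions, not confirmed by this diff:

```python
from lalamo.utils import open_safetensors

# New in 0.5.1: the context manager yields a (weights, metadata) tuple.
with open_safetensors("model.safetensors") as (weights_dict, metadata):
    tensor = weights_dict["some.parameter.name"]  # hypothetical key
    quant_hint = metadata.get("quantization")     # hypothetical metadata key
```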
@@ -124,21 +123,21 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         self,
         token_ids: Int[Array, "batch tokens"],
         lengths_without_padding: Int[Array, " batch"] | None = None,
-        kv_cache_capacity: int | None = None,
+        state_capacity: int | None = None,
         forward_pass_config: ForwardPassConfig | None = None,
     ) -> PrefillResults:
         batch_size, sequence_length = token_ids.shape
         token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
-        if kv_cache_capacity is not None:
-            kv_cache = self.decoder.init_static_kv_cache(batch_size, kv_cache_capacity)
+        if state_capacity is not None:
+            state = self.decoder.init_static_state(batch_size, state_capacity)
         else:
-            kv_cache = None
+            state = None
 
         decoder_outputs = self.decoder(
             token_ids,
             token_positions,
-            kv_cache,
-            return_updated_kv_cache=True,
+            state,
+            return_updated_state=True,
             lengths_without_padding=lengths_without_padding,
             forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
             forward_pass_config=forward_pass_config,
@@ -151,11 +150,11 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
 
         last_token_logits = vmap(lambda logits, index: logits[index])(decoder_outputs.logits, last_logits_indices)
 
-        assert decoder_outputs.updated_kv_cache is not None
+        assert decoder_outputs.updated_state is not None
         return PrefillResults(
             last_token_logits=last_token_logits,
             last_token_indices=last_logits_indices,
-            kv_cache=decoder_outputs.updated_kv_cache,
+            state=decoder_outputs.updated_state,
         )
 
     @eqx.filter_jit
@@ -187,7 +186,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         initial_state = DecodingState(
             prefill_results.last_token_logits,
             prefill_results.last_token_indices,
-            prefill_results.kv_cache,
+            prefill_results.state,
             jnp.zeros(batch_size, dtype=jnp.bool),
         )
 
@@ -224,16 +223,16 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             decoder_outputs = self.decoder(
                 next_token_ids[:, None],
                 next_token_indices[:, None],
-                state.kv_cache,
-                return_updated_kv_cache=True,
+                state.state,
+                return_updated_state=True,
                 forward_pass_mode=forward_pass_mode,
                 forward_pass_config=forward_pass_config,
             )
-            assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+            assert decoder_outputs.updated_state is not None, "updated_state should not be None"
             new_state = DecodingState(
                 decoder_outputs.logits.squeeze(1),
                 next_token_indices,
-                decoder_outputs.updated_kv_cache,
+                decoder_outputs.updated_state,
                 stop_flags,
             )
             return new_state, GenerationStepResults(next_token_ids, next_top_k_token_ids, next_top_k_token_logits)
@@ -338,7 +337,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         state = DecodingState(
             prefill_results.last_token_logits,
             prefill_results.last_token_indices,
-            prefill_results.kv_cache,
+            prefill_results.state,
             jnp.array([0], dtype=jnp.bool),
         )
 
@@ -356,14 +355,14 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
             decoder_outputs = self.decoder(
                 next_token_id.reshape(1, 1),
                 next_token_indices.reshape(1, 1),
-                state.kv_cache,
-                return_updated_kv_cache=True,
+                state.state,
+                return_updated_state=True,
                 forward_pass_config=forward_pass_config,
             )
-            assert decoder_outputs.updated_kv_cache is not None, "updated_kv_cache should not be None"
+            assert decoder_outputs.updated_state is not None, "updated_state should not be None"
             state = DecodingState(
                 decoder_outputs.logits.squeeze(1),
                 next_token_indices,
-                decoder_outputs.updated_kv_cache,
+                decoder_outputs.updated_state,
                 state.stop_flags,
             )
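
Taken together, the `language_model.py` hunks are a mechanical rename of the cache plumbing: `KVCache` becomes the more general `State` (this release adds Mamba-style token mixers, visible in the new `token_mixers/` files above, whose recurrent state is not a KV cache), and every `kv_cache`/`return_updated_kv_cache` spelling follows suit. A rough before/after sketch of caller code, assuming a loaded `LanguageModel` named `model`:

```python
import jax.numpy as jnp

prompt_ids = jnp.asarray([[1, 2, 3, 4]], dtype=jnp.int32)

# 0.4.1: model.prefill(..., kv_cache_capacity=512) -> results.kv_cache
# 0.5.1:
results = model.prefill(prompt_ids, state_capacity=512)
state = results.state

# One decode step with the renamed keyword:
outputs = model.decoder(
    prompt_ids[:, -1:],                   # (batch, 1) next token ids
    jnp.asarray([[4]], dtype=jnp.int32),  # their positions
    state,
    return_updated_state=True,            # was return_updated_kv_cache
)
state = outputs.updated_state             # was outputs.updated_kv_cache
```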
--- lalamo-0.4.1/lalamo/main.py
+++ lalamo-0.5.1/lalamo/main.py
@@ -27,7 +27,6 @@ from rich.progress import (
     TextColumn,
     TimeElapsedColumn,
     TimeRemainingColumn,
-    track,
 )
 from rich.table import Table
 from safetensors.flax import save_file
@@ -50,7 +49,6 @@ from lalamo.modules import config_converter
 from lalamo.speculator.inference import CollectTracesEvent, inference_collect_traces
 from lalamo.speculator.ngram import NGramSpeculator
 from lalamo.speculator.utils import SpeculatorTrainingEvent, test_speculator, train_speculator
-from lalamo.utils import jax_uint4_to_packed_uint8
 
 SCRIPT_NAME = Path(sys.argv[0]).name
 
@@ -109,16 +107,6 @@ def _error(message: str) -> None:
     raise Exit(1)
 
 
-def _pack_uint4_weights(weights: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
-    packed_weights = {}
-    for key, value in weights.items():
-        if value.dtype == jnp.uint4:
-            packed_weights[key] = jax_uint4_to_packed_uint8(value)
-        else:
-            packed_weights[key] = value
-    return packed_weights
-
-
 @app.command(help="Chat with a converted model.")
 def chat(
     model_path: Annotated[
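
With `_pack_uint4_weights` and its `jax_uint4_to_packed_uint8` import gone, the convert command now saves weights exactly as exported; the packing step presumably moved into the export path of the quantized modules (`linear.py` and `embedding.py` grow substantially in this release, though this diff does not show where). For reference, a NumPy sketch of what the deleted helper did per uint4 tensor; the low-nibble-first order is an assumption:

```python
import numpy as np

def pack_uint4_pairs(values: np.ndarray) -> np.ndarray:
    """Pack pairs of uint4 values (stored one per byte) into single uint8 bytes.

    Assumes an even number of elements and low-nibble-first packing.
    """
    flat = values.astype(np.uint8).reshape(-1, 2)
    return (flat[:, 0] | (flat[:, 1] << 4)).astype(np.uint8)
```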
@@ -274,7 +262,7 @@ def convert(
     result = model.decoder(
         token_ids,
         token_positions,
-        return_updated_kv_cache=True,
+        return_updated_state=True,
         return_activation_trace=True,
     )
     traces = flatten_parameters(result.export())
@@ -286,8 +274,7 @@ def convert(
     weights = flatten_parameters(model.export_weights())
     del model
 
-    packed_weights = _pack_uint4_weights(weights)
-    save_file(packed_weights, output_dir / "model.safetensors")
+    save_file(weights, output_dir / "model.safetensors")
 
     config_json = config_converter.unstructure(metadata, ModelMetadata)
     with open(output_dir / "config.json", "w") as file:
@@ -511,7 +498,6 @@ def train(
     ) as progress:
         inference_task = progress.add_task("🔮 [cyan]Training speculator...[/cyan]", total=subsample_size)
 
-
         def progress_callback(event: SpeculatorTrainingEvent) -> None:
             progress.update(inference_task, completed=event.trained_tokens)
 
--- lalamo-0.4.1/lalamo/model_import/common.py
+++ lalamo-0.5.1/lalamo/model_import/common.py
@@ -1,4 +1,5 @@
 import importlib.metadata
+import json
 from collections import ChainMap
 from collections.abc import Callable
 from contextlib import ExitStack
@@ -14,6 +15,7 @@ from tokenizers import Tokenizer
 
 from lalamo.language_model import GenerationConfig, LanguageModel, LanguageModelConfig
 from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
+from lalamo.model_import.model_specs.common import JSONFieldSpec
 from lalamo.quantization import QuantizationMode
 
 from .huggingface_generation_config import HFGenerationConfig
@@ -130,10 +132,17 @@ def import_message_processor(
     )
     tokenizer_config = HFTokenizerConfig.from_json(tokenizer_config_file)
     if tokenizer_config.chat_template is None:
-        if model_spec.configs.chat_template is None:
-            raise ValueError("Missiing chat template.")
-        chat_template_file = download_file(model_spec.configs.chat_template, model_spec.repo, output_dir)
-        prompt_template = chat_template_file.read_text()
+        match model_spec.configs.chat_template:
+            case JSONFieldSpec(file_spec, field_name):
+                json_file = download_file(file_spec, model_spec.repo, output_dir)
+                with open(json_file) as file:
+                    json_dict = json.load(file)
+                prompt_template = json_dict[field_name]
+            case FileSpec(_) as file_spec:
+                chat_template_file = download_file(file_spec, model_spec.repo, output_dir)
+                prompt_template = chat_template_file.read_text()
+            case None:
+                raise ValueError("No chat template specified.")
     else:
         if model_spec.configs.chat_template is not None:
             raise ValueError("Conflicting chat template specifications.")
@@ -180,15 +189,24 @@ def import_model(
     weights_paths = download_weights(model_spec, progress_callback=progress_callback)
     with ExitStack() as stack:
         weights_shards = []
+        metadata_shards = []
         for weights_path in weights_paths:
-            weights_shard = stack.enter_context(model_spec.weights_type.load(weights_path, precision))
+            weights_shard, metadata_shard = stack.enter_context(model_spec.weights_type.load(weights_path, precision))
             weights_shards.append(weights_shard)
+            metadata_shards.append(metadata_shard)
         weights_dict: ChainMap[str, Array] = ChainMap(*weights_shards)
+        metadata_dict: ChainMap[str, str] = ChainMap(*metadata_shards)
 
     if progress_callback is not None:
         progress_callback(InitializingModelEvent())
 
-    decoder = foreign_decoder_config.load_decoder(context_length, precision, accumulation_precision, weights_dict)
+    decoder = foreign_decoder_config.load_decoder(
+        context_length,
+        precision,
+        accumulation_precision,
+        weights_dict,
+        metadata_dict,
+    )
 
     if progress_callback is not None:
         progress_callback(FinishedInitializingModelEvent())
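
The `import_model` hunk extends the shard-loading loop: each loader now yields a `(weights, metadata)` pair, and the per-shard metadata is merged the same way the weights are, via `ChainMap`, then threaded into `load_decoder`. A toy illustration of that merge; the keys are made up:

```python
from collections import ChainMap

# Two hypothetical shards, each a (weights, metadata) pair as returned
# by a 0.5.1 weights loader.
shard_a = ({"layers.0.w": [1.0]}, {"format": "mlx"})
shard_b = ({"layers.1.w": [2.0]}, {"quantization.bits": "4"})

# ChainMap gives a single read-only view over all shards without copying.
weights_dict = ChainMap(shard_a[0], shard_b[0])
metadata_dict = ChainMap(shard_a[1], shard_b[1])

assert "layers.1.w" in weights_dict
assert metadata_dict["format"] == "mlx"
```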
--- lalamo-0.4.1/lalamo/model_import/decoder_configs/__init__.py
+++ lalamo-0.5.1/lalamo/model_import/decoder_configs/__init__.py
@@ -7,6 +7,7 @@ from .huggingface import (
     HFGemma3TextConfig,
     HFGPTOssConfig,
     HFLlamaConfig,
+    HFLlambaConfig,
     HFMistralConfig,
     HFQwen2Config,
     HFQwen3Config,
@@ -20,6 +21,7 @@ __all__ = [
     "HFGemma3Config",
     "HFGemma3TextConfig",
    "HFLlamaConfig",
+    "HFLlambaConfig",
     "HFMistralConfig",
     "HFQwen2Config",
     "HFQwen3Config",
--- lalamo-0.4.1/lalamo/model_import/decoder_configs/common.py
+++ lalamo-0.5.1/lalamo/model_import/decoder_configs/common.py
@@ -19,11 +19,9 @@ class ForeignConfig(RegistryABC):
     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
     _converter.register_structure_hook(int | list[int], lambda v, _: v)
 
-    eos_token_id: int | list[int]
-
     @property
     def eos_token_ids(self) -> list[int]:
-        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
+        raise NotImplementedError
 
     @property
     @abstractmethod
@@ -41,6 +39,7 @@ class ForeignConfig(RegistryABC):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],
     ) -> DecoderConfig:
         raise NotImplementedError
 
@@ -58,7 +57,8 @@ class ForeignConfig(RegistryABC):
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
         weights_dict: Mapping[str, Array],
+        metadata_dict: Mapping[str, str],
     ) -> Decoder:
-        config = self.to_decoder_config(context_length, activation_precision, accumulation_precision)
+        config = self.to_decoder_config(context_length, activation_precision, accumulation_precision, metadata_dict)
         model = config.empty()
         return self._load_weights(model, weights_dict)
--- lalamo-0.4.1/lalamo/model_import/decoder_configs/executorch.py
+++ lalamo-0.5.1/lalamo/model_import/decoder_configs/executorch.py
@@ -51,6 +51,12 @@ class LoraConfig:
 
 @dataclass(frozen=True)
 class ExecutorchConfig(ForeignConfig):
+    eos_token_id: int | list[int]
+
+    @property
+    def eos_token_ids(self) -> list[int]:
+        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
+
     @property
     def default_precision(self) -> DTypeLike:
         return jnp.bfloat16
@@ -89,6 +95,7 @@ class ETLlamaConfig(ExecutorchConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         if self.lora_args is None:
             raise ValueError("We only support QLoRA models for now.")
@@ -136,6 +143,12 @@
             has_sinks=False,
             has_qkv_biases=False,
             has_out_biases=False,
+            num_heads=self.n_heads,
+            num_groups=self.n_kv_heads,
+            head_dim=self.dim // self.n_heads,
+            is_causal=True,
+            scale=None,
+            sliding_window_size=None,
         )
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
@@ -146,9 +159,9 @@
             gate_clipping=None,
         )
         decoder_layer_config = DecoderLayerConfig(
-            pre_attention_norm_config=rmsnorm_config,
-            attention_config=attention_config,
-            post_attention_norm_config=None,
+            pre_mixer_norm_config=rmsnorm_config,
+            mixer_config=attention_config,
+            post_mixer_norm_config=None,
             pre_mlp_norm_config=rmsnorm_config,
             mlp_config=mlp_config,
             post_mlp_norm_config=None,
@@ -157,16 +170,10 @@
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_config=decoder_layer_config,
+            layer_configs=(decoder_layer_config,) * self.n_layers,
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.dim,
             hidden_dim=self._find_hidden_size(),
-            num_heads=self.n_heads,
-            num_groups=self.n_kv_heads,
-            head_dim=self.dim // self.n_heads,
-            attention_scale=None,
-            num_layers=self.n_layers,
-            sliding_window_sizes=None,
             context_length=context_length or MAX_SEQUENCE_LENGTH,
         )
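
These last hunks show the core 0.5.x refactor: `DecoderConfig` no longer carries decoder-wide `num_heads`/`num_layers`/`sliding_window_sizes`; those settings move into each layer's `AttentionConfig`, and the decoder takes a `layer_configs` tuple (the old `layer_config` + `num_layers` pair becomes repetition). For a homogeneous Llama-style stack the translation is one line; heterogeneous stacks, such as the alternating sliding-window layers of Gemma or the Mamba/attention mixes the new Llamba support needs, become expressible per layer. A sketch; `local_layer_config` and `global_layer_config` are hypothetical names:

```python
# Homogeneous stack, exactly as in the ETLlamaConfig hunk above:
layer_configs = (decoder_layer_config,) * n_layers

# Heterogeneous stack, now expressible: e.g. alternate a sliding-window
# layer with a global-attention layer.
layer_configs = tuple(
    local_layer_config if i % 2 == 0 else global_layer_config
    for i in range(n_layers)
)
```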
--- lalamo-0.4.1/lalamo/model_import/decoder_configs/huggingface/__init__.py
+++ lalamo-0.5.1/lalamo/model_import/decoder_configs/huggingface/__init__.py
@@ -3,6 +3,7 @@ from .gemma2 import HFGemma2Config
 from .gemma3 import HFGemma3Config, HFGemma3TextConfig
 from .gpt_oss import HFGPTOssConfig
 from .llama import HFLlamaConfig
+from .llamba import HFLlambaConfig
 from .mistral import HFMistralConfig
 from .qwen2 import HFQwen2Config
 from .qwen3 import HFQwen3Config
@@ -13,6 +14,7 @@ __all__ = [
     "HFGemma3Config",
     "HFGemma3TextConfig",
     "HFLlamaConfig",
+    "HFLlambaConfig",
     "HFMistralConfig",
     "HFQwen2Config",
     "HFQwen3Config",
--- lalamo-0.4.1/lalamo/model_import/decoder_configs/huggingface/common.py
+++ lalamo-0.5.1/lalamo/model_import/decoder_configs/huggingface/common.py
@@ -1,7 +1,8 @@
 from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Literal
+from typing import ClassVar, Literal
 
+import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
 
@@ -56,11 +57,45 @@ class GPTQQuantizationConfig:
     sym: bool
 
 
+@dataclass(frozen=True)
+class MLXQuantizationConfig:
+    group_size: int
+    bits: int
+
+
+QuantizationConfigType = AWQQuantizationConfig | GPTQQuantizationConfig | MLXQuantizationConfig | None
+
+
+def _structure_quantization_config(v: object, _: object) -> QuantizationConfigType:
+    match v:
+        case None:
+            return None
+
+        case {"quant_method": "awq", **_other}:
+            return cattrs.structure(v, AWQQuantizationConfig)
+
+        case {"quant_method": "gptq", **_other}:
+            return cattrs.structure(v, GPTQQuantizationConfig)
+
+        case {**_other}:
+            return cattrs.structure(v, MLXQuantizationConfig)
+
+        case _:
+            raise RuntimeError(f"Cannot structure {v} field")
+
+
 @dataclass(frozen=True)
 class HuggingFaceConfig(ForeignConfig):
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+    _converter.register_structure_hook(int | list[int], lambda v, _: v)
+    _converter.register_structure_hook(QuantizationConfigType, _structure_quantization_config)
+
     @property
     def eos_token_ids(self) -> list[int]:
-        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
+        if not hasattr(self, "eos_token_id"):
+            raise RuntimeError("model doesn't have eos_token_id, override eos_token_ids in model config")
+
+        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id  # type: ignore (This is a bug in pyright)
 
     @property
     def default_precision(self) -> DTypeLike:
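
The new `_structure_quantization_config` hook dispatches on the `quant_method` key: `"awq"` and `"gptq"` select their dataclasses, any other mapping is assumed to be an MLX-style `{group_size, bits}` config, and `None` passes through. A toy check of that dispatch rule, assuming the names above are importable from this module; field values are made up:

```python
assert _structure_quantization_config(None, None) is None

# No "quant_method" key -> treated as an MLX-style config.
mlx = _structure_quantization_config({"group_size": 64, "bits": 4}, None)
assert isinstance(mlx, MLXQuantizationConfig)
assert (mlx.group_size, mlx.bits) == (64, 4)
```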
--- lalamo-0.4.1/lalamo/model_import/decoder_configs/huggingface/gemma2.py
+++ lalamo-0.5.1/lalamo/model_import/decoder_configs/huggingface/gemma2.py
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
@@ -57,10 +58,8 @@ class HFGemma2Config(HuggingFaceConfig):
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
+        metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
-        sliding_window_sizes = tuple(
-            self.sliding_window if not bool(i % 2) else None for i in range(self.num_hidden_layers)
-        )
         embedding_input_scale = self.hidden_size**0.5
         attention_scale = self.query_pre_attn_scalar**-0.5
         embedding_config = TiedEmbeddingConfig(
@@ -83,16 +82,6 @@
         linear_config = FullPrecisionLinearConfig(
             precision=activation_precision,
         )
-        attention_config = AttentionConfig(
-            qkv_projection_config=linear_config,
-            out_projection_config=linear_config,
-            query_norm_config=None,
-            key_norm_config=None,
-            logit_soft_cap=self.attn_logit_softcapping,
-            has_sinks=False,
-            has_qkv_biases=self.attention_bias,
-            has_out_biases=False,
-        )
         mlp_config = DenseMLPConfig(
             linear_config=linear_config,
             activation=GELU(),
@@ -101,28 +90,44 @@
             up_clipping=None,
             gate_clipping=None,
         )
-        decoder_layer_config = DecoderLayerConfig(
-            pre_attention_norm_config=rmsnorm_config,
-            attention_config=attention_config,
-            post_attention_norm_config=rmsnorm_config,
-            pre_mlp_norm_config=rmsnorm_config,
-            mlp_config=mlp_config,
-            post_mlp_norm_config=rmsnorm_config,
-        )
+
+        layer_configs = []
+        for i in range(self.num_hidden_layers):
+            sliding_window_size = self.sliding_window if not bool(i % 2) else None
+            attention_config = AttentionConfig(
+                qkv_projection_config=linear_config,
+                out_projection_config=linear_config,
+                query_norm_config=None,
+                key_norm_config=None,
+                logit_soft_cap=self.attn_logit_softcapping,
+                has_sinks=False,
+                has_qkv_biases=self.attention_bias,
+                has_out_biases=False,
+                num_heads=self.num_attention_heads,
+                num_groups=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                is_causal=True,
+                scale=attention_scale,
+                sliding_window_size=sliding_window_size,
+            )
+            decoder_layer_config = DecoderLayerConfig(
+                pre_mixer_norm_config=rmsnorm_config,
+                mixer_config=attention_config,
+                post_mixer_norm_config=rmsnorm_config,
+                pre_mlp_norm_config=rmsnorm_config,
+                mlp_config=mlp_config,
+                post_mlp_norm_config=rmsnorm_config,
+            )
+            layer_configs.append(decoder_layer_config)
+
         return DecoderConfig(
             embedding_config=embedding_config,
             global_rope_config=rope_config,
             local_rope_config=None,
-            layer_config=decoder_layer_config,
+            layer_configs=tuple(layer_configs),
             output_norm_config=rmsnorm_config,
             vocab_size=self.vocab_size,
             model_dim=self.hidden_size,
             hidden_dim=self.intermediate_size,
-            num_heads=self.num_attention_heads,
-            num_groups=self.num_key_value_heads,
-            head_dim=self.head_dim,
-            attention_scale=attention_scale,
-            num_layers=self.num_hidden_layers,
-            sliding_window_sizes=sliding_window_sizes,
             context_length=context_length or self.max_position_embeddings,
         )