lalamo 0.5.14__tar.gz → 0.5.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. {lalamo-0.5.14 → lalamo-0.5.16}/PKG-INFO +1 -1
  2. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/__init__.py +1 -1
  3. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/llama.py +32 -21
  4. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/qwen3.py +33 -22
  5. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/loaders/huggingface.py +49 -1
  6. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/qwen.py +14 -0
  7. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/__init__.py +4 -0
  8. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/embedding.py +169 -0
  9. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo.egg-info/PKG-INFO +1 -1
  10. {lalamo-0.5.14 → lalamo-0.5.16}/LICENSE +0 -0
  11. {lalamo-0.5.14 → lalamo-0.5.16}/README.md +0 -0
  12. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/common.py +0 -0
  13. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/data/__init__.py +0 -0
  14. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/data/huggingface_message.py +0 -0
  15. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/data/lalamo_completions.py +0 -0
  16. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/data/utils.py +0 -0
  17. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/main.py +0 -0
  18. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/message_processor.py +0 -0
  19. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/__init__.py +0 -0
  20. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/common.py +0 -0
  21. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/__init__.py +0 -0
  22. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/common.py +0 -0
  23. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/executorch.py +0 -0
  24. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/__init__.py +0 -0
  25. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/common.py +0 -0
  26. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/gemma2.py +0 -0
  27. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/gemma3.py +0 -0
  28. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +0 -0
  29. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/lfm2.py +0 -0
  30. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/llamba.py +0 -0
  31. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/mistral.py +0 -0
  32. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/modern_bert.py +0 -0
  33. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/qwen2.py +0 -0
  34. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/huggingface_generation_config.py +0 -0
  35. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/huggingface_tokenizer_config.py +0 -0
  36. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/loaders/__init__.py +0 -0
  37. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/loaders/common.py +0 -0
  38. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/loaders/executorch.py +0 -0
  39. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/loaders/utils.py +0 -0
  40. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/__init__.py +0 -0
  41. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/common.py +0 -0
  42. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/deepseek.py +0 -0
  43. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/essential_ai.py +0 -0
  44. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/gemma.py +0 -0
  45. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/gpt_oss.py +0 -0
  46. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/huggingface.py +0 -0
  47. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/lfm2.py +0 -0
  48. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/llama.py +0 -0
  49. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/llamba.py +0 -0
  50. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/mirai.py +0 -0
  51. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/mistral.py +0 -0
  52. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/pleias.py +0 -0
  53. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/polaris.py +0 -0
  54. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/reka.py +0 -0
  55. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/models/__init__.py +0 -0
  56. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/models/classifier.py +0 -0
  57. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/models/common.py +0 -0
  58. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/models/language_model.py +0 -0
  59. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/activations.py +0 -0
  60. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/classifier.py +0 -0
  61. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/common.py +0 -0
  62. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/decoder.py +0 -0
  63. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/linear.py +0 -0
  64. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/mlp.py +0 -0
  65. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/mlx_interop.py +0 -0
  66. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/normalization.py +0 -0
  67. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/rope.py +0 -0
  68. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/__init__.py +0 -0
  69. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/attention.py +0 -0
  70. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/common.py +0 -0
  71. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/mamba.py +0 -0
  72. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/short_conv.py +0 -0
  73. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/state/__init__.py +0 -0
  74. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/state/common.py +0 -0
  75. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/state/kv_cache.py +0 -0
  76. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/state/mamba_state.py +0 -0
  77. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/token_mixers/state/short_conv_state.py +0 -0
  78. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/torch_interop.py +0 -0
  79. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/transformer.py +0 -0
  80. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/transformer_layer.py +0 -0
  81. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/utils.py +0 -0
  82. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/quantization.py +0 -0
  83. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/registry_abc.py +0 -0
  84. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/sampling.py +0 -0
  85. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/speculator/__init__.py +0 -0
  86. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/speculator/common.py +0 -0
  87. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/speculator/estimator.py +0 -0
  88. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/speculator/inference.py +0 -0
  89. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/speculator/ngram.py +0 -0
  90. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/speculator/utils.py +0 -0
  91. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo/utils.py +0 -0
  92. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo.egg-info/SOURCES.txt +0 -0
  93. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo.egg-info/dependency_links.txt +0 -0
  94. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo.egg-info/entry_points.txt +0 -0
  95. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo.egg-info/requires.txt +0 -0
  96. {lalamo-0.5.14 → lalamo-0.5.16}/lalamo.egg-info/top_level.txt +0 -0
  97. {lalamo-0.5.14 → lalamo-0.5.16}/pyproject.toml +0 -0
  98. {lalamo-0.5.14 → lalamo-0.5.16}/setup.cfg +0 -0
  99. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_cartesia_mlx_models.py +0 -0
  100. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_chat_template.py +0 -0
  101. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_generation.py +0 -0
  102. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_huggingface_model_conversion.py +0 -0
  103. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_huggingface_models.py +0 -0
  104. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_lfm2_models.py +0 -0
  105. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_mlx_models.py +0 -0
  106. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_model_spec.py +0 -0
  107. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_models.py +0 -0
  108. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_moe.py +0 -0
  109. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_parameter_tree.py +0 -0
  110. {lalamo-0.5.14 → lalamo-0.5.16}/tests/test_registry_abc.py +0 -0

{lalamo-0.5.14 → lalamo-0.5.16}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.5.14
+Version: 0.5.16
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown

{lalamo-0.5.14 → lalamo-0.5.16}/lalamo/__init__.py
@@ -15,7 +15,7 @@ from lalamo.speculator import (
     SpeculatorTrainingEvent,
 )

-__version__ = "0.5.14"
+__version__ = "0.5.16"

 __all__ = [
     "AssistantMessage",

{lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/llama.py
@@ -13,6 +13,7 @@ from lalamo.modules import (
     LlamaRoPEConfig,
     MLXQuantizedLinearConfig,
     MLXQuantizedTiedEmbeddingConfig,
+    MLXQuantizedUntiedEmbeddingConfig,
     NormalizationConfig,
     SiLU,
     TiedEmbeddingConfig,
@@ -89,27 +90,37 @@ class HFLlamaConfig(HuggingFaceLMConfig):
     ) -> DecoderConfig:
         quantization = self.quantization or self.quantization_config
         if isinstance(quantization, MLXQuantizationConfig):
-            assert self.tie_word_embeddings, "only tied embeddings are supported"
-            embedding_config = MLXQuantizedTiedEmbeddingConfig(
-                input_scale=None,
-                logit_soft_cap=None,
-                group_size=quantization.group_size,
-                embedding_quantization_mode=QuantizationMode.from_num_bits(quantization.bits),
-                activation_quantization_mode=None,
-                activation_precision=activation_precision,
-            )
-        elif self.tie_word_embeddings:
-            embedding_config = TiedEmbeddingConfig(
-                input_scale=None,
-                logit_soft_cap=None,
-                precision=activation_precision,
-            )
-        else:
-            embedding_config = UntiedEmbeddingConfig(
-                input_scale=None,
-                logit_soft_cap=None,
-                precision=activation_precision,
-            )
+            if self.tie_word_embeddings:
+                embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    group_size=quantization.group_size,
+                    embedding_quantization_mode=QuantizationMode.from_num_bits(quantization.bits),
+                    activation_quantization_mode=None,
+                    activation_precision=activation_precision,
+                )
+            else:
+                embedding_config = MLXQuantizedUntiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    group_size=quantization.group_size,
+                    embedding_quantization_mode=QuantizationMode.from_num_bits(quantization.bits),
+                    activation_quantization_mode=None,
+                    activation_precision=activation_precision,
+                )
+        else:  # noqa: PLR5501
+            if self.tie_word_embeddings:
+                embedding_config = TiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    precision=activation_precision,
+                )
+            else:
+                embedding_config = UntiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    precision=activation_precision,
+                )
         if self.rope_scaling is None:
             rope_config = UnscaledRoPEConfig(
                 precision=activation_precision,

{lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/decoder_configs/huggingface/qwen3.py
@@ -10,6 +10,8 @@ from lalamo.modules import (
     DenseMLPConfig,
     FullPrecisionLinearConfig,
     GroupQuantizedLinearConfig,
+    MLXQuantizedTiedEmbeddingConfig,
+    MLXQuantizedUntiedEmbeddingConfig,
     NormalizationConfig,
     TiedEmbeddingConfig,
     TransformerConfig,
@@ -19,7 +21,6 @@ from lalamo.modules import (
     UpcastMode,
 )
 from lalamo.modules.activations import SiLU
-from lalamo.modules.embedding import MLXQuantizedTiedEmbeddingConfig
 from lalamo.modules.linear import MLXQuantizedLinearConfig
 from lalamo.quantization import QuantizationMode

@@ -75,27 +76,37 @@ class HFQwen3Config(HuggingFaceLMConfig):
         metadata_dict: Mapping[str, str],  # noqa: ARG002
     ) -> DecoderConfig:
         if isinstance(self.quantization_config, MLXQuantizationConfig):
-            assert self.tie_word_embeddings, "only tied embeddings are supported"
-            embedding_config = MLXQuantizedTiedEmbeddingConfig(
-                input_scale=None,
-                logit_soft_cap=None,
-                group_size=self.quantization_config.group_size,
-                embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
-                activation_quantization_mode=None,
-                activation_precision=activation_precision,
-            )
-        elif self.tie_word_embeddings:
-            embedding_config = TiedEmbeddingConfig(
-                input_scale=None,
-                logit_soft_cap=None,
-                precision=activation_precision,
-            )
-        else:
-            embedding_config = UntiedEmbeddingConfig(
-                input_scale=None,
-                logit_soft_cap=None,
-                precision=activation_precision,
-            )
+            if self.tie_word_embeddings:
+                embedding_config = MLXQuantizedTiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    group_size=self.quantization_config.group_size,
+                    embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                    activation_quantization_mode=None,
+                    activation_precision=activation_precision,
+                )
+            else:
+                embedding_config = MLXQuantizedUntiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    group_size=self.quantization_config.group_size,
+                    embedding_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                    activation_quantization_mode=None,
+                    activation_precision=activation_precision,
+                )
+        else:  # noqa: PLR5501
+            if self.tie_word_embeddings:
+                embedding_config = TiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    precision=activation_precision,
+                )
+            else:
+                embedding_config = UntiedEmbeddingConfig(
+                    input_scale=None,
+                    logit_soft_cap=None,
+                    precision=activation_precision,
+                )
         rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
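
The branches above only select which embedding config to build. For reference, a minimal sketch of constructing the new untied MLX-quantized embedding config directly; the class and field names follow the diff, while the group size, dtype, and dimensions are illustrative assumptions rather than values taken from the package:

    import jax.numpy as jnp

    from lalamo.modules import MLXQuantizedUntiedEmbeddingConfig
    from lalamo.quantization import QuantizationMode

    # Mirrors the MLXQuantizationConfig branch above when tie_word_embeddings is False.
    embedding_config = MLXQuantizedUntiedEmbeddingConfig(
        input_scale=None,
        logit_soft_cap=None,
        group_size=64,  # assumed MLX quantization group size
        embedding_quantization_mode=QuantizationMode.UINT4,
        activation_quantization_mode=None,
        activation_precision=jnp.bfloat16,  # assumed activation dtype
    )

    # empty() allocates placeholder input_*/output_* weights, scales, and biases;
    # "untied" means the embedding table and the lm_head keep separate parameters.
    module = embedding_config.empty(vocab_size=151_936, model_dim=4_096)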

{lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/loaders/huggingface.py
@@ -29,6 +29,7 @@ from lalamo.modules import (
     UntiedEmbedding,
 )
 from lalamo.modules.classifier import Classifier
+from lalamo.modules.embedding import MLXQuantizedUntiedEmbedding
 from lalamo.modules.mlp import MixtureOfExperts, MLPBase
 from lalamo.quantization import QuantizationMode

@@ -625,6 +626,51 @@ def load_mlx_quantized_tied_embedding(
     return load_parameters(lambda m: (m.weights, m.scales, m.biases), module, (weights, scales, biases))


+def load_mlx_quantized_untied_embedding(
+    module: MLXQuantizedUntiedEmbedding,
+    weights_dict: Mapping[str, Array],
+    embedding_path: ParameterPath,
+    lm_head_path: ParameterPath,
+) -> MLXQuantizedUntiedEmbedding:
+    input_qweights = weights_dict[embedding_path / "weight"]
+    input_qscales = weights_dict[embedding_path / "scales"]
+    input_qbiases = weights_dict[embedding_path / "biases"]
+    output_qweights = weights_dict[lm_head_path / "weight"]
+    output_qscales = weights_dict[lm_head_path / "scales"]
+    output_qbiases = weights_dict[lm_head_path / "biases"]
+
+    input_weights = _process_quantized_tensor(
+        input_qweights,
+        module.config.embedding_quantization_mode,
+        module.activation_precision,
+        None,
+    )
+    input_scales = input_qscales.astype(module.activation_precision)
+    input_biases = input_qbiases.astype(module.activation_precision)
+
+    output_weights = _process_quantized_tensor(
+        output_qweights,
+        module.config.embedding_quantization_mode,
+        module.activation_precision,
+        None,
+    )
+    output_scales = output_qscales.astype(module.activation_precision)
+    output_biases = output_qbiases.astype(module.activation_precision)
+
+    return load_parameters(
+        lambda m: (
+            m.input_weights,
+            m.input_scales,
+            m.input_biases,
+            m.output_weights,
+            m.output_scales,
+            m.output_biases,
+        ),
+        module,
+        (input_weights, input_scales, input_biases, output_weights, output_scales, output_biases),
+    )
+
+
 def load_mlx_semi_quantized_untied_embedding(
     module: MLXSemiQuantizedUntiedEmbedding,
     weights_dict: Mapping[str, Array],
@@ -741,6 +787,8 @@ def load_huggingface_decoder(
         embedding = load_tied_embedding(module.embedding, weights_dict, embedding_path)
     elif isinstance(module.embedding, MLXQuantizedTiedEmbedding):
         embedding = load_mlx_quantized_tied_embedding(module.embedding, weights_dict, embedding_path)
+    elif isinstance(module.embedding, MLXQuantizedUntiedEmbedding):
+        embedding = load_mlx_quantized_untied_embedding(module.embedding, weights_dict, embedding_path, lm_head_path)
     elif isinstance(module.embedding, MLXSemiQuantizedUntiedEmbedding):
         embedding = load_mlx_semi_quantized_untied_embedding(
             module.embedding,
@@ -759,7 +807,7 @@ def load_huggingface_decoder(
             weights_dict,
             decoder_path / "layers" / ((i * 2) if alternating_layers else i),
             decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-            mixer_key[type(layer.config.mixer_config)], # type: ignore
+            mixer_key[type(layer.config.mixer_config)],  # type: ignore
             mlp_key,
             pre_mixer_norm_key,
             pre_mlp_norm_key,
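
For orientation, these are the checkpoint keys the new load_mlx_quantized_untied_embedding reads, written out with hypothetical Hugging Face-style path prefixes; the real embedding_path and lm_head_path are supplied by load_huggingface_decoder and may differ:

    # Assumed layout of an MLX-quantized checkpoint; "model.embed_tokens" and
    # "lm_head" are illustrative prefixes, not taken from the diff above.
    expected_keys = [
        "model.embed_tokens.weight",  # packed quantized embedding table
        "model.embed_tokens.scales",  # per-group dequantization scales
        "model.embed_tokens.biases",  # per-group dequantization biases
        "lm_head.weight",             # packed quantized output projection
        "lm_head.scales",
        "lm_head.biases",
    ]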

{lalamo-0.5.14 → lalamo-0.5.16}/lalamo/model_import/model_specs/qwen.py
@@ -223,6 +223,20 @@ QWEN3 = [
         repo="Qwen/Qwen3-8B-AWQ",
         config_type=HFQwen3Config,
     ),
+    ModelSpec(
+        vendor="Alibaba",
+        family="Qwen3",
+        name="Qwen3-8B-MLX-4bit",
+        size="8B",
+        quantization=QuantizationMode.UINT4,
+        repo="Qwen/Qwen3-8B-MLX-4bit",
+        config_type=HFQwen3Config,
+        configs=ConfigMap(
+            tokenizer=FileSpec("tokenizer.json", "Qwen/Qwen3-8B"),
+            tokenizer_config=FileSpec("tokenizer_config.json", "Qwen/Qwen3-8B"),
+            generation_config=FileSpec("generation_config.json", "Qwen/Qwen3-8B"),
+        ),
+    ),
     ModelSpec(
         vendor="Alibaba",
         family="Qwen3",

{lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/__init__.py
@@ -18,6 +18,8 @@ from .embedding import (
     EmbeddingConfig,
     MLXQuantizedTiedEmbedding,
     MLXQuantizedTiedEmbeddingConfig,
+    MLXQuantizedUntiedEmbedding,
+    MLXQuantizedUntiedEmbeddingConfig,
     MLXSemiQuantizedUntiedEmbedding,
     MLXSemiQuantizedUntiedEmbeddingConfig,
     QuantizedTiedEmbedding,
@@ -120,6 +122,8 @@ __all__ = [
     "MLXQuantizedLinearConfig",
     "MLXQuantizedTiedEmbedding",
     "MLXQuantizedTiedEmbeddingConfig",
+    "MLXQuantizedUntiedEmbedding",
+    "MLXQuantizedUntiedEmbeddingConfig",
     "MLXSemiQuantizedUntiedEmbedding",
     "MLXSemiQuantizedUntiedEmbeddingConfig",
     "Mamba2",

{lalamo-0.5.14 → lalamo-0.5.16}/lalamo/modules/embedding.py
@@ -24,6 +24,8 @@ __all__ = [
     "EmbeddingConfig",
     "MLXQuantizedTiedEmbedding",
     "MLXQuantizedTiedEmbeddingConfig",
+    "MLXQuantizedUntiedEmbedding",
+    "MLXQuantizedUntiedEmbeddingConfig",
     "MLXSemiQuantizedUntiedEmbedding",
     "MLXSemiQuantizedUntiedEmbeddingConfig",
     "QuantizedTiedEmbedding",
@@ -492,6 +494,172 @@ class MLXQuantizedTiedEmbedding(EmbeddingBase[MLXQuantizedTiedEmbeddingConfig]):
         )


+@dataclass(frozen=True)
+class MLXQuantizedUntiedEmbeddingConfig(EmbeddingConfigBase):
+    group_size: int
+    embedding_quantization_mode: QuantizationMode
+    activation_quantization_mode: QuantizationMode | None
+    activation_precision: DTypeLike
+
+    def random_init(
+        self,
+        vocab_size: int,
+        model_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "MLXQuantizedUntiedEmbedding":
+        raise NotImplementedError
+
+    def empty(
+        self,
+        vocab_size: int,
+        model_dim: int,
+    ) -> "MLXQuantizedUntiedEmbedding":
+        assert model_dim % self.group_size == 0
+        model_groups = model_dim // self.group_size
+        return MLXQuantizedUntiedEmbedding(
+            config=self,
+            input_weights=dummy_array((vocab_size, model_dim), dtype=self.activation_precision),
+            input_scales=dummy_array((vocab_size, model_groups), dtype=self.activation_precision),
+            input_biases=dummy_array((vocab_size, model_groups), dtype=self.activation_precision),
+            output_weights=dummy_array((vocab_size, model_dim), dtype=self.activation_precision),
+            output_scales=dummy_array((vocab_size, model_groups), dtype=self.activation_precision),
+            output_biases=dummy_array((vocab_size, model_groups), dtype=self.activation_precision),
+        )
+
+
+class MLXQuantizedUntiedEmbedding(EmbeddingBase[MLXQuantizedUntiedEmbeddingConfig]):
+    input_weights: Float[Array, "vocabulary channels"]
+    input_scales: Float[Array, "vocabulary groups"]
+    input_biases: Float[Array, "vocabulary groups"]
+    output_weights: Float[Array, "vocabulary channels"]
+    output_scales: Float[Array, "vocabulary groups"]
+    output_biases: Float[Array, "vocabulary groups"]
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.config.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        _, model_dim = self.input_weights.shape
+        return model_dim
+
+    @property
+    def vocab_size(self) -> int:
+        vocab_size, _ = self.input_weights.shape
+        return vocab_size
+
+    @property
+    def int_input_weights(self) -> Int[Array, "vocabulary channels"]:
+        quantized = quantize_weights(self.input_weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    @property
+    def int_output_weights(self) -> Int[Array, "vocabulary channels"]:
+        quantized = quantize_weights(self.output_weights, self.config.embedding_quantization_mode)
+        casted = quantized.astype(self.config.embedding_quantization_mode.dtype)
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            packed = jax_uint4_to_packed_uint8(casted)
+        else:
+            packed = casted
+
+        return packed
+
+    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.input_weights, self.config.embedding_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "vocab (groups elements) -> vocab groups elements",
+            elements=self.config.group_size,
+        )
+
+        scales = rearrange(self.input_scales, "vocab groups -> vocab groups 1")
+
+        biases = rearrange(self.input_biases, "vocab groups -> vocab groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + biases
+
+        result = rearrange(
+            scaled_grouped_weights,
+            "vocab groups elements -> vocab (groups elements)",
+        )
+        return result
+
+    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
+        quantized_weights = quantize_weights(self.output_weights, self.config.embedding_quantization_mode)
+        grouped_weights = rearrange(
+            quantized_weights,
+            "vocab (groups elements) -> vocab groups elements",
+            elements=self.config.group_size,
+        )
+
+        scales = rearrange(self.output_scales, "vocab groups -> vocab groups 1")
+
+        biases = rearrange(self.output_biases, "vocab groups -> vocab groups 1")
+
+        scaled_grouped_weights = grouped_weights * scales + biases
+
+        result = rearrange(
+            scaled_grouped_weights,
+            "vocab groups elements -> vocab (groups elements)",
+        )
+        return result
+
+    @eqx.filter_jit
+    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
+        if self.config.activation_quantization_mode is not None:
+            x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
+        return super().readout(x)
+
+    def export_weights(self) -> ParameterTree:
+        return {
+            "input_weights": self.int_input_weights,
+            "input_scales": self.input_scales,
+            "input_biases": self.input_biases,
+            "output_weights": self.int_output_weights,
+            "output_scales": self.output_scales,
+            "output_biases": self.output_biases,
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["input_weights"], Array)
+        assert isinstance(weights["input_scales"], Array)
+        assert isinstance(weights["input_biases"], Array)
+        assert isinstance(weights["output_weights"], Array)
+        assert isinstance(weights["output_scales"], Array)
+        assert isinstance(weights["output_biases"], Array)
+
+        unpacked_input_weights = weights["input_weights"]
+        unpacked_output_weights = weights["output_weights"]
+
+        if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
+            unpacked_input_weights = jax_uint8_to_unpacked_uint4(weights["input_weights"])
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(weights["output_weights"])
+
+        return replace(
+            self,
+            input_weights=unpacked_input_weights.astype(self.input_weights.dtype),
+            input_scales=weights["input_scales"],
+            input_biases=weights["input_biases"],
+            output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
+            output_scales=weights["output_scales"],
+            output_biases=weights["output_biases"],
+        )
+
+
 @dataclass(frozen=True)
 class MLXSemiQuantizedUntiedEmbeddingConfig(EmbeddingConfigBase):
     group_size: int
@@ -626,6 +794,7 @@ EmbeddingConfig = (
     | UntiedEmbeddingConfig
     | QuantizedTiedEmbeddingConfig
    | MLXQuantizedTiedEmbeddingConfig
+   | MLXQuantizedUntiedEmbeddingConfig
    | MLXSemiQuantizedUntiedEmbeddingConfig
 )


{lalamo-0.5.14 → lalamo-0.5.16}/lalamo.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lalamo
-Version: 0.5.14
+Version: 0.5.16
 Summary: JAX library for optimization and export of models for use with the UZU inference engine.
 Requires-Python: <4,>=3.12
 Description-Content-Type: text/markdown