lalamo 0.5.17__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. lalamo/__init__.py +1 -1
  2. lalamo/commands.py +69 -17
  3. lalamo/common.py +14 -1
  4. lalamo/main.py +148 -27
  5. lalamo/message_processor.py +4 -1
  6. lalamo/model_import/common.py +8 -17
  7. lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
  8. lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
  9. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
  10. lalamo/model_import/huggingface_generation_config.py +21 -3
  11. lalamo/model_import/loaders/executorch.py +2 -2
  12. lalamo/model_import/loaders/huggingface.py +3 -3
  13. lalamo/model_import/model_specs/common.py +4 -2
  14. lalamo/model_import/model_specs/lfm2.py +41 -9
  15. lalamo/models/language_model.py +7 -6
  16. lalamo/modules/activations.py +1 -1
  17. lalamo/modules/classifier.py +11 -24
  18. lalamo/modules/common.py +4 -1
  19. lalamo/modules/decoder.py +5 -11
  20. lalamo/modules/embedding.py +25 -62
  21. lalamo/modules/linear.py +19 -33
  22. lalamo/modules/mlp.py +9 -19
  23. lalamo/modules/mlx_interop.py +1 -1
  24. lalamo/modules/rope.py +1 -1
  25. lalamo/modules/token_mixers/__init__.py +1 -1
  26. lalamo/modules/token_mixers/attention.py +9 -27
  27. lalamo/modules/token_mixers/mamba.py +9 -24
  28. lalamo/modules/token_mixers/short_conv.py +5 -12
  29. lalamo/modules/transformer.py +10 -20
  30. lalamo/modules/transformer_layer.py +8 -20
  31. lalamo/registry_abc.py +4 -4
  32. lalamo/sampling.py +14 -0
  33. lalamo/speculator/estimator.py +3 -3
  34. lalamo/speculator/ngram.py +1 -1
  35. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -1
  36. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/RECORD +40 -40
  37. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
  38. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
  39. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
  40. {lalamo-0.5.17.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
@@ -289,7 +289,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     combined_up_gate_b = jnp.concatenate([up_b + 1.0, gate_b], axis=-1)
 
     up_projection = load_parameters(
-        lambda m: (m.weights, m.biases),  # type: ignore
+        lambda m: (m.weights, m.biases),
         module.experts.up_projection,
         (combined_up_gate_w, combined_up_gate_b),
     )
@@ -309,7 +309,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))
 
     down_projection = load_parameters(
-        lambda m: (m.weights, m.biases),  # type: ignore
+        lambda m: (m.weights, m.biases),
         module.experts.down_projection,
         (down_w, down_b),
     )
@@ -807,7 +807,7 @@ def load_huggingface_decoder(
         weights_dict,
         decoder_path / "layers" / ((i * 2) if alternating_layers else i),
         decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-        mixer_key[type(layer.config.mixer_config)],  # type: ignore
+        mixer_key[type(layer.config.mixer_config)],
         mlp_key,
         pre_mixer_norm_key,
         pre_mlp_norm_key,
@@ -7,13 +7,14 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum, StrEnum
 from pathlib import Path
-from typing import ClassVar, cast, get_args, get_origin
+from typing import Any, ClassVar, cast, get_args, get_origin
 
 import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
 
 from lalamo.model_import.decoder_configs import ForeignConfig
+from lalamo.models.language_model import GenerationConfig
 from lalamo.quantization import QuantizationMode
 from lalamo.safetensors import safe_read
 from lalamo.utils import MapDictValues
@@ -86,7 +87,7 @@ class ConfigMap:
     model_config: FileSpec = field(default=FileSpec("config.json"))
     tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
     tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
-    generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
+    generation_config: FileSpec | GenerationConfig | None = field(default=FileSpec("generation_config.json"))
     chat_template: FileSpec | JSONFieldSpec | str | None = None
 
 
@@ -123,6 +124,7 @@ def _structure_chat_template(value: object, _type: object) -> FileSpec | JSONFie
     if isinstance(value, str):
         return value
     if isinstance(value, dict):
+        value = cast("dict[Any, Any]", value)  # ty bug??? Why is just `dict` != `dict[Any, Any]`?
         if "file_spec" in value and "field_name" in value:
             return JSONFieldSpec(
                 file_spec=FileSpec(**value["file_spec"]),
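
Taken together, the two hunks above mean a ModelSpec's ConfigMap can now carry sampling defaults inline instead of pointing at a JSON file. A minimal sketch of both forms, reusing the ConfigMap, FileSpec, and GenerationConfig names from this diff (field values are illustrative only):

    # Previous behaviour: read sampling defaults from the repo's generation_config.json.
    configs = ConfigMap(generation_config=FileSpec("generation_config.json"))

    # New in 0.6.0: pin the defaults directly in the model spec, bypassing the JSON file.
    configs = ConfigMap(generation_config=GenerationConfig(temperature=0.3, min_p=0.15))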
@@ -1,4 +1,7 @@
+from itertools import chain, product
+
 from lalamo.model_import.decoder_configs import HFLFM2Config
+from lalamo.models.language_model import GenerationConfig
 from lalamo.quantization import QuantizationMode
 
 from .common import ConfigMap, FileSpec, ModelSpec
@@ -6,26 +9,55 @@ from .common import ConfigMap, FileSpec, ModelSpec
 __all__ = ["LFM2_MODELS"]
 
 
-def _lfm2_repo(size: str, quantization: QuantizationMode | None) -> tuple[str, str]:
-    organization = "LiquidAI" if quantization is None else "mlx-community"
-    name = f"LFM2-{size}{f'-{quantization.bits}bit' if quantization is not None else ''}"
-    return (organization, name)
+def _lfm_repo(family: str, size: str, variant: str | None, quantization: QuantizationMode | None) -> tuple[str, str]:
+    return (
+        "LiquidAI" if quantization is None else "mlx-community",
+        f"{family}-{size}"
+        f"{f'-{variant}' if variant is not None else ''}"
+        f"{f'-{quantization.bits}bit' if quantization is not None else ''}",
+    )
 
 
-LFM2_MODELS = [
+_LFM20_MODELS = [
     ModelSpec(
         vendor="LiquidAI",
         family="LFM2",
-        name=_lfm2_repo(size, quantization)[1],
+        name=_lfm_repo("LFM2", size, variant, quantization)[1],
         size=size,
-        repo="/".join(_lfm2_repo(size, quantization)),
+        repo="/".join(_lfm_repo("LFM2", size, variant, quantization)),
         config_type=HFLFM2Config,
         quantization=quantization,
         configs=ConfigMap(
+            generation_config=GenerationConfig(temperature=0.3, min_p=0.15),  # , repetition_penalty=1.05
            chat_template=FileSpec("chat_template.jinja"),
        ),
        use_cases=tuple(),
    )
-    for size in ["350M", "700M", "1.2B", "2.6B"]
-    for quantization in [None, *([QuantizationMode.UINT4, QuantizationMode.UINT8] if size != "2.6B" else [])]
+    for size, variant, quantization in chain(
+        product(["350M", "700M", "1.2B"], [None], [None, QuantizationMode.UINT4, QuantizationMode.UINT8]),
+        product(["2.6B"], [None, "Exp"], [None]),
+        product(["2.6B"], ["Exp"], [QuantizationMode.UINT4, QuantizationMode.UINT8]),
+    )
 ]
+
+_LFM25_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2.5",
+        name=_lfm_repo("LFM2.5", size, variant, quantization)[1],
+        size=size,
+        repo="/".join(_lfm_repo("LFM2.5", size, variant, quantization)),
+        config_type=HFLFM2Config,
+        quantization=quantization,
+        configs=ConfigMap(
+            generation_config=GenerationConfig(temperature=0.1, top_k=50, top_p=0.1),  # , repetition_penalty=1.05
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    )
+    for size, variant, quantization in chain(
+        product(["1.2B"], ["Instruct"], [None]),
+    )
+]
+
+LFM2_MODELS = _LFM20_MODELS + _LFM25_MODELS
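
For reference, the chain(product(...)) comprehension above expands to 13 LFM2 specs; a quick sketch using the same itertools calls (quantization modes abbreviated to strings purely for illustration):

    from itertools import chain, product

    combos = list(chain(
        product(["350M", "700M", "1.2B"], [None], [None, "4bit", "8bit"]),  # 9 specs
        product(["2.6B"], [None, "Exp"], [None]),                           # LFM2-2.6B and LFM2-2.6B-Exp, full precision
        product(["2.6B"], ["Exp"], ["4bit", "8bit"]),                       # quantized builds of the Exp variant only
    ))
    assert len(combos) == 13

The LFM2.5 comprehension currently yields a single spec (LFM2.5-1.2B-Instruct, unquantized), so LFM2_MODELS grows from 10 entries in 0.5.17 to 14 in 0.6.0.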
@@ -64,14 +64,15 @@ class GenerationResults(NamedTuple):
 
 @dataclass(frozen=True)
 class GenerationConfig:
-    stop_token_ids: tuple[int, ...]
-    temperature: float | None
-    top_k: int | None
-    top_p: float | None
-    banned_tokens: tuple[int, ...] | None
+    stop_token_ids: tuple[int, ...] = tuple()
+    temperature: float | None = None
+    top_k: int | None = None
+    top_p: float | None = None
+    min_p: float | None = None
+    banned_tokens: tuple[int, ...] | None = None
 
     def default_policy(self) -> SamplingPolicy:
-        return make_policy(self.temperature, self.top_k, self.top_p, self.banned_tokens)
+        return make_policy(self.temperature, self.top_k, self.top_p, self.min_p, self.banned_tokens)
 
 
 @dataclass(frozen=True)
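
The new min_p field is threaded through make_policy, and lalamo/sampling.py gains 14 lines in this release (not shown here). The filter implementation itself is not part of this diff; as a reminder of the technique, min-p sampling is conventionally implemented roughly like this sketch (generic JAX, not lalamo's exact code):

    import jax
    import jax.numpy as jnp

    def min_p_filter(logits: jax.Array, min_p: float) -> jax.Array:
        # Keep tokens whose probability is at least min_p times the top token's
        # probability; mask everything else so it can never be sampled.
        probs = jax.nn.softmax(logits, axis=-1)
        threshold = min_p * jnp.max(probs, axis=-1, keepdims=True)
        return jnp.where(probs >= threshold, logits, -jnp.inf)

The relaxed defaults (every field optional, stop_token_ids defaulting to empty) are what allow the inline GenerationConfig(temperature=..., min_p=...) literals in the LFM2 model specs above.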
@@ -44,4 +44,4 @@ class Identity(ActivationBase):
 Activation = SiLU | GELU | Identity
 
 
-register_config_union(Activation)  # type: ignore (pyright bug)
+register_config_union(Activation)
@@ -9,7 +9,7 @@ from jax import numpy as jnp
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules import Activation
 from lalamo.modules.normalization import NormalizationConfig
 from lalamo.modules.transformer import (
@@ -67,7 +67,7 @@ class PredictionHeadConfig:
     def random_init(self, input_size: int, num_labels: int, key: PRNGKeyArray) -> "PredictionHead":
         dense_key, readout_key = jax.random.split(key)
         dense_layer = self.dense_config.random_init(
-            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key
+            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key,
         )
         norm = self.normalization_config.empty(input_size)
         readout = self.readout_config.random_init(
@@ -117,19 +117,13 @@ class PredictionHead(LalamoModule[PredictionHeadConfig]):
         )
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["dense"], Mapping)
-        assert isinstance(weights["norm"], Mapping)
-        assert isinstance(weights["readout"], Mapping)
         return replace(
             self,
-            dense=self.dense.import_weights(weights["dense"]),
-            norm=self.norm.import_weights(weights["norm"]),
-            readout=self.readout.import_weights(weights["readout"]),
+            dense=self.dense.import_weights(require_tree(weights["dense"])),
+            norm=self.norm.import_weights(require_tree(weights["norm"])),
+            readout=self.readout.import_weights(require_tree(weights["readout"])),
         )
 
 
@@ -321,19 +315,12 @@ class Classifier(LalamoModule[ClassifierConfig]):
         )
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["embedding_norm"], Mapping)
-        assert isinstance(weights["transformer"], Mapping)
-        assert isinstance(weights["prediction_head"], Mapping)
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]),
-            embedding_norm=self.embedding_norm.import_weights(weights["embedding_norm"]),
-            transformer=self.transformer.import_weights(weights["transformer"]),
-            prediction_head=self.prediction_head.import_weights(weights["prediction_head"]),
+            embedding=self.embedding.import_weights(require_tree(weights["embedding"])),
+            embedding_norm=self.embedding_norm.import_weights(require_tree(weights["embedding_norm"])),
+            transformer=self.transformer.import_weights(require_tree(weights["transformer"])),
+            prediction_head=self.prediction_head.import_weights(require_tree(weights["prediction_head"])),
         )
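
This import_weights pattern repeats across decoder.py, embedding.py, linear.py, and mlp.py below: hand-written per-key isinstance asserts are replaced by require_tree/require_array helpers from lalamo.common (re-exported from lalamo.modules.common in the next hunk). Their definitions are not part of this diff; the stand-ins below are only a sketch of the runtime narrowing the new call sites appear to rely on, and the real signatures may differ:

    from collections.abc import Mapping, Sequence

    from jax import Array

    # Hypothetical stand-ins for lalamo.common's helpers, shown for orientation only.
    type ParameterTree[T] = T | Mapping[str, ParameterTree[T]] | Sequence[ParameterTree[T]]

    def require_array(value: ParameterTree[Array]) -> Array:
        # Narrow a tree node to a leaf array.
        if not isinstance(value, Array):
            raise TypeError(f"Expected an Array leaf, got {type(value).__name__}")
        return value

    def require_tree(value: ParameterTree[Array]) -> Mapping[str, ParameterTree[Array]] | Sequence[ParameterTree[Array]]:
        # Narrow a tree node to a nested mapping/sequence of parameters.
        if isinstance(value, Array):
            raise TypeError("Expected a nested parameter tree, got an Array leaf")
        return value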
lalamo/modules/common.py CHANGED
@@ -9,15 +9,18 @@ from cattrs import Converter
 from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_array, require_tree
 
 __all__ = [
     "DummyUnionMember",
     "ForwardPassMode",
     "LalamoModule",
+    "ParameterTree",
     "PositionalEmbeddingSelector",
     "config_converter",
     "register_config_union",
+    "require_array",
+    "require_tree",
 ]
 
 
lalamo/modules/decoder.py CHANGED
@@ -7,7 +7,7 @@ import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 
 from .common import ForwardPassMode, LalamoModule
 from .embedding import EmbeddingBase, EmbeddingConfig
@@ -126,7 +126,7 @@ class Decoder(LalamoModule[DecoderConfig]):
         return self.embedding.activation_precision
 
     @eqx.filter_jit
-    def __call__(  # noqa: PLR0912
+    def __call__(
         self,
         token_ids: Int[Array, "batch suffix_tokens"],
         token_positions: Int[Array, "batch suffix_tokens"],
@@ -193,16 +193,10 @@ class Decoder(LalamoModule[DecoderConfig]):
             transformer=self.transformer.export_weights(),
         )
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["transformer"], Mapping)
-
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]),
-            transformer=self.transformer.import_weights(weights["transformer"]),
+            embedding=self.embedding.import_weights(require_tree(weights["embedding"])),
+            transformer=self.transformer.import_weights(require_tree(weights["transformer"])),
         )
@@ -9,7 +9,7 @@ import jax.numpy as jnp
 from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree, dummy_array
+from lalamo.common import ParameterTree, dummy_array, require_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
@@ -355,21 +355,15 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
             "scales": self.scales,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["weights"], Array)
-        stored_weights = weights["weights"]
-
+        stored_weights = require_array(weights["weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
             stored_weights = jax_uint8_to_unpacked_uint4(stored_weights)
-
         return replace(
             self,
             weights=stored_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
+            scales=require_array(weights["scales"]),
         )
 
 
@@ -472,25 +466,16 @@ class MLXQuantizedTiedEmbedding(EmbeddingBase[MLXQuantizedTiedEmbeddingConfig]):
             "biases": self.biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["weights"], Array)
-        assert isinstance(weights["scales"], Array)
-        assert isinstance(weights["biases"], Array)
-
-        unpacked_weights = weights["weights"]
-
+        unpacked_weights = require_array(weights["weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
-            biases=weights["biases"],
+            scales=require_array(weights["scales"]),
+            biases=require_array(weights["biases"]),
         )
 
 
@@ -630,33 +615,21 @@ class MLXQuantizedUntiedEmbedding(EmbeddingBase[MLXQuantizedUntiedEmbeddingConfi
             "output_biases": self.output_biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["input_weights"], Array)
-        assert isinstance(weights["input_scales"], Array)
-        assert isinstance(weights["input_biases"], Array)
-        assert isinstance(weights["output_weights"], Array)
-        assert isinstance(weights["output_scales"], Array)
-        assert isinstance(weights["output_biases"], Array)
-
-        unpacked_input_weights = weights["input_weights"]
-        unpacked_output_weights = weights["output_weights"]
-
+        unpacked_input_weights = require_array(weights["input_weights"])
+        unpacked_output_weights = require_array(weights["output_weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_input_weights = jax_uint8_to_unpacked_uint4(weights["input_weights"])
-            unpacked_output_weights = jax_uint8_to_unpacked_uint4(weights["output_weights"])
-
+            unpacked_input_weights = jax_uint8_to_unpacked_uint4(unpacked_input_weights)
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(unpacked_output_weights)
         return replace(
             self,
             input_weights=unpacked_input_weights.astype(self.input_weights.dtype),
-            input_scales=weights["input_scales"],
-            input_biases=weights["input_biases"],
+            input_scales=require_array(weights["input_scales"]),
+            input_biases=require_array(weights["input_biases"]),
             output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
-            output_scales=weights["output_scales"],
-            output_biases=weights["output_biases"],
+            output_scales=require_array(weights["output_scales"]),
+            output_biases=require_array(weights["output_biases"]),
         )
 
 
@@ -765,27 +738,17 @@ class MLXSemiQuantizedUntiedEmbedding(EmbeddingBase[MLXSemiQuantizedUntiedEmbedd
             "output_biases": self.output_biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["input_weights"], Array)
-        assert isinstance(weights["output_weights"], Array)
-        assert isinstance(weights["output_scales"], Array)
-        assert isinstance(weights["output_biases"], Array)
-
-        unpacked_output_weights = weights["output_weights"]
-
+        unpacked_output_weights = require_array(weights["output_weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_output_weights = jax_uint8_to_unpacked_uint4(weights["output_weights"])
-
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(unpacked_output_weights)
         return replace(
             self,
-            input_weights=weights["input_weights"],
+            input_weights=require_array(weights["input_weights"]),
             output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
-            output_scales=weights["output_scales"],
-            output_biases=weights["output_biases"],
+            output_scales=require_array(weights["output_scales"]),
+            output_biases=require_array(weights["output_biases"]),
         )
 
 
@@ -799,4 +762,4 @@ EmbeddingConfig = (
 )
 
 
-register_config_union(EmbeddingConfig)  # type: ignore (pyright bug)
+register_config_union(EmbeddingConfig)
lalamo/modules/linear.py CHANGED
@@ -2,7 +2,7 @@ import math
 from abc import ABC, abstractmethod
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass, replace
-from typing import Self
+from typing import Self, cast
 
 import equinox as eqx
 import jax
@@ -10,7 +10,7 @@ import jax.numpy as jnp
 from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree, dummy_array
+from lalamo.common import ParameterTree, dummy_array, require_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
@@ -464,7 +464,7 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLin
 
         return packed
 
-    def __post_init__(self) -> None:  # noqa: PLR0912
+    def __post_init__(self) -> None:
         if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
@@ -572,26 +572,19 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLin
             result["biases"] = self.biases
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["weights"], Array)
-        assert isinstance(weights["zero_points"], Array)
-        unpacked_weights = weights["weights"]
-        unpacked_zero_points = weights["zero_points"]
-
+        unpacked_weights = require_array(weights["weights"])
+        unpacked_zero_points = require_array(weights["zero_points"])
         if self.config.weight_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
-            unpacked_zero_points = jax_uint8_to_unpacked_uint4(weights["zero_points"])
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
+            unpacked_zero_points = jax_uint8_to_unpacked_uint4(unpacked_zero_points)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
+            scales=require_array(weights["scales"]),
             zero_points=unpacked_zero_points.astype(self.zero_points.dtype),
-            biases=weights["biases"] if self.has_biases else None,
+            biases=require_array(weights["biases"]) if self.has_biases else None,
         )
 
 
@@ -740,7 +733,7 @@ class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearB
 
         return packed
 
-    def __post_init__(self) -> None:  # noqa: PLR0912
+    def __post_init__(self) -> None:
         if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
@@ -847,24 +840,17 @@ class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearB
             result["biases"] = self.biases
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["weights"], Array)
-
-        unpacked_weights = weights["weights"]
-
+        unpacked_weights = require_array(weights["weights"])
         if self.config.weight_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
-            deq_biases=weights["deq_biases"],
-            biases=weights["biases"] if self.has_biases else None,
+            scales=require_array(weights["scales"]),
+            deq_biases=require_array(weights["deq_biases"]),
+            biases=require_array(weights["biases"]) if self.has_biases else None,
         )
 
 
@@ -1113,7 +1099,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
         self,
         weights: ParameterTree[Array],
     ) -> Self:
-        base = super().import_weights(weights)
+        base = cast("Self", super().import_weights(weights))  # ty bug
         assert isinstance(weights, Mapping)
         assert isinstance(weights["up_weights"], Sequence)
         return replace(
@@ -1126,4 +1112,4 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
 LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | MLXQuantizedLinearConfig | QLoRALinearConfig
 
 
-register_config_union(LinearConfig)  # type: ignore (pyright bug)
+register_config_union(LinearConfig)
lalamo/modules/mlp.py CHANGED
@@ -12,7 +12,7 @@ from einops import rearrange
 from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules.utils import vmap_twice
 
 from .activations import Activation
@@ -242,17 +242,12 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
             "down_projection": self.down_projection.export_weights(),
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["up_projection"], Mapping)
-        assert isinstance(weights["down_projection"], Mapping)
         return replace(
             self,
-            up_projection=self.up_projection.import_weights(weights["up_projection"]),
-            down_projection=self.down_projection.import_weights(weights["down_projection"]),
+            up_projection=self.up_projection.import_weights(require_tree(weights["up_projection"])),
+            down_projection=self.down_projection.import_weights(require_tree(weights["down_projection"])),
         )
 
 
@@ -285,7 +280,7 @@ class SoftmaxRouting(RoutingFunctionBase):
 RoutingFunction = SoftmaxRouting | DummyUnionMember
 
 
-register_config_union(RoutingFunction)  # type: ignore (pyright bug)
+register_config_union(RoutingFunction)
 
 
 @dataclass(frozen=True)
@@ -486,21 +481,16 @@ class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
             "experts": self.experts.export_weights(),
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["router"], Mapping)
-        assert isinstance(weights["experts"], Mapping)
         return replace(
             self,
-            router=self.router.import_weights(weights["router"]),
-            experts=self.experts.import_weights(weights["experts"]),
+            router=self.router.import_weights(require_tree(weights["router"])),
+            experts=self.experts.import_weights(require_tree(weights["experts"])),
         )
 
 
 MLPConfig = DenseMLPConfig | MixtureOfExpertsConfig
 
 
-register_config_union(MLPConfig)  # type: ignore (pyright bug)
+register_config_union(MLPConfig)
@@ -1,5 +1,5 @@
 import jax.numpy as jnp
-import mlx.core as mx
+import mlx.core as mx  # type: ignore
 from jaxtyping import Array
 
 __all__ = ["jax_to_mlx", "mlx_to_jax"]
lalamo/modules/rope.py CHANGED
@@ -281,4 +281,4 @@ class LinearScalingRoPEConfig(RoPEConfigBase):
 
 RoPEConfig = UnscaledRoPEConfig | LlamaRoPEConfig | YARNRoPEConfig | LinearScalingRoPEConfig
 
-register_config_union(RoPEConfig)  # type: ignore (pyright bug)
+register_config_union(RoPEConfig)
@@ -16,7 +16,7 @@ from .state import (
 
 TokenMixerConfig = AttentionConfig | Mamba2Config | ShortConvConfig
 
-register_config_union(TokenMixerConfig)  # type: ignore (pyright bug)
+register_config_union(TokenMixerConfig)
 
 __all__ = [
     "Attention",