lalamo 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (43)
  1. lalamo/__init__.py +26 -2
  2. lalamo/commands.py +429 -0
  3. lalamo/common.py +14 -1
  4. lalamo/main.py +375 -229
  5. lalamo/message_processor.py +4 -1
  6. lalamo/model_import/common.py +8 -17
  7. lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
  8. lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
  9. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
  10. lalamo/model_import/huggingface_generation_config.py +21 -3
  11. lalamo/model_import/loaders/executorch.py +2 -2
  12. lalamo/model_import/loaders/huggingface.py +3 -3
  13. lalamo/model_import/model_specs/common.py +8 -4
  14. lalamo/model_import/model_specs/lfm2.py +41 -9
  15. lalamo/models/common.py +3 -3
  16. lalamo/models/language_model.py +7 -6
  17. lalamo/modules/activations.py +1 -1
  18. lalamo/modules/classifier.py +11 -24
  19. lalamo/modules/common.py +4 -1
  20. lalamo/modules/decoder.py +5 -11
  21. lalamo/modules/embedding.py +25 -62
  22. lalamo/modules/linear.py +19 -33
  23. lalamo/modules/mlp.py +9 -19
  24. lalamo/modules/mlx_interop.py +1 -1
  25. lalamo/modules/rope.py +1 -1
  26. lalamo/modules/token_mixers/__init__.py +1 -1
  27. lalamo/modules/token_mixers/attention.py +9 -27
  28. lalamo/modules/token_mixers/mamba.py +9 -24
  29. lalamo/modules/token_mixers/short_conv.py +5 -12
  30. lalamo/modules/transformer.py +10 -20
  31. lalamo/modules/transformer_layer.py +8 -20
  32. lalamo/registry_abc.py +4 -4
  33. lalamo/safetensors.py +97 -0
  34. lalamo/sampling.py +14 -0
  35. lalamo/speculator/estimator.py +11 -4
  36. lalamo/speculator/ngram.py +1 -1
  37. lalamo/utils.py +0 -13
  38. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -2
  39. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/RECORD +43 -41
  40. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
  41. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
  42. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/modules/embedding.py CHANGED
@@ -9,7 +9,7 @@ import jax.numpy as jnp
  from einops import rearrange
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

- from lalamo.common import ParameterTree, dummy_array
+ from lalamo.common import ParameterTree, dummy_array, require_array
  from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
  from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4

@@ -355,21 +355,15 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
  "scales": self.scales,
  }

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["weights"], Array)
- stored_weights = weights["weights"]
-
+ stored_weights = require_array(weights["weights"])
  if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
  stored_weights = jax_uint8_to_unpacked_uint4(stored_weights)
-
  return replace(
  self,
  weights=stored_weights.astype(self.weights.dtype),
- scales=weights["scales"],
+ scales=require_array(weights["scales"]),
  )


@@ -472,25 +466,16 @@ class MLXQuantizedTiedEmbedding(EmbeddingBase[MLXQuantizedTiedEmbeddingConfig]):
  "biases": self.biases,
  }

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["weights"], Array)
- assert isinstance(weights["scales"], Array)
- assert isinstance(weights["biases"], Array)
-
- unpacked_weights = weights["weights"]
-
+ unpacked_weights = require_array(weights["weights"])
  if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
- unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
-
+ unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
  return replace(
  self,
  weights=unpacked_weights.astype(self.weights.dtype),
- scales=weights["scales"],
- biases=weights["biases"],
+ scales=require_array(weights["scales"]),
+ biases=require_array(weights["biases"]),
  )


@@ -630,33 +615,21 @@ class MLXQuantizedUntiedEmbedding(EmbeddingBase[MLXQuantizedUntiedEmbeddingConfi
  "output_biases": self.output_biases,
  }

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["input_weights"], Array)
- assert isinstance(weights["input_scales"], Array)
- assert isinstance(weights["input_biases"], Array)
- assert isinstance(weights["output_weights"], Array)
- assert isinstance(weights["output_scales"], Array)
- assert isinstance(weights["output_biases"], Array)
-
- unpacked_input_weights = weights["input_weights"]
- unpacked_output_weights = weights["output_weights"]
-
+ unpacked_input_weights = require_array(weights["input_weights"])
+ unpacked_output_weights = require_array(weights["output_weights"])
  if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
- unpacked_input_weights = jax_uint8_to_unpacked_uint4(weights["input_weights"])
- unpacked_output_weights = jax_uint8_to_unpacked_uint4(weights["output_weights"])
-
+ unpacked_input_weights = jax_uint8_to_unpacked_uint4(unpacked_input_weights)
+ unpacked_output_weights = jax_uint8_to_unpacked_uint4(unpacked_output_weights)
  return replace(
  self,
  input_weights=unpacked_input_weights.astype(self.input_weights.dtype),
- input_scales=weights["input_scales"],
- input_biases=weights["input_biases"],
+ input_scales=require_array(weights["input_scales"]),
+ input_biases=require_array(weights["input_biases"]),
  output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
- output_scales=weights["output_scales"],
- output_biases=weights["output_biases"],
+ output_scales=require_array(weights["output_scales"]),
+ output_biases=require_array(weights["output_biases"]),
  )


@@ -765,27 +738,17 @@ class MLXSemiQuantizedUntiedEmbedding(EmbeddingBase[MLXSemiQuantizedUntiedEmbedd
  "output_biases": self.output_biases,
  }

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["input_weights"], Array)
- assert isinstance(weights["output_weights"], Array)
- assert isinstance(weights["output_scales"], Array)
- assert isinstance(weights["output_biases"], Array)
-
- unpacked_output_weights = weights["output_weights"]
-
+ unpacked_output_weights = require_array(weights["output_weights"])
  if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
- unpacked_output_weights = jax_uint8_to_unpacked_uint4(weights["output_weights"])
-
+ unpacked_output_weights = jax_uint8_to_unpacked_uint4(unpacked_output_weights)
  return replace(
  self,
- input_weights=weights["input_weights"],
+ input_weights=require_array(weights["input_weights"]),
  output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
- output_scales=weights["output_scales"],
- output_biases=weights["output_biases"],
+ output_scales=require_array(weights["output_scales"]),
+ output_biases=require_array(weights["output_biases"]),
  )


@@ -799,4 +762,4 @@ EmbeddingConfig = (
  )


- register_config_union(EmbeddingConfig) # type: ignore (pyright bug)
+ register_config_union(EmbeddingConfig)
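Across these hunks the inline `assert isinstance(..., Array)` checks are replaced by `require_array` (and, in later files, `require_tree`) imported from `lalamo.common`, whose own diff (+14 -1 above) is not included in this view. The sketch below is only a guess at what those helpers look like, inferred from how the call sites use them; the actual 0.6.0 implementation may differ.

```python
# Hypothetical sketch of the validation helpers, NOT the code shipped in
# lalamo/common.py (that hunk is not shown here). Inferred behaviour: narrow a
# parameter-tree node to either a leaf Array or a nested container, raising a
# readable error instead of relying on a bare assert.
from collections.abc import Mapping, Sequence

from jax import Array


def require_array(node: object) -> Array:
    """Return ``node`` if it is a leaf jax Array, otherwise raise."""
    if not isinstance(node, Array):
        raise TypeError(f"expected an Array leaf, got {type(node).__name__}")
    return node


def require_tree(node: object) -> Mapping[str, object] | Sequence[object]:
    """Return ``node`` if it is a nested Mapping/Sequence of parameters."""
    if isinstance(node, Array) or not isinstance(node, Mapping | Sequence):
        raise TypeError(f"expected a nested parameter tree, got {type(node).__name__}")
    return node
```

At the call sites this collapses the old assert-then-index pattern into a single expression, e.g. `stored_weights = require_array(weights["weights"])`.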
lalamo/modules/linear.py CHANGED
@@ -2,7 +2,7 @@ import math
  from abc import ABC, abstractmethod
  from collections.abc import Mapping, Sequence
  from dataclasses import dataclass, replace
- from typing import Self
+ from typing import Self, cast

  import equinox as eqx
  import jax
@@ -10,7 +10,7 @@ import jax.numpy as jnp
  from einops import rearrange
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

- from lalamo.common import ParameterTree, dummy_array
+ from lalamo.common import ParameterTree, dummy_array, require_array
  from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
  from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4

@@ -464,7 +464,7 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLin

  return packed

- def __post_init__(self) -> None: # noqa: PLR0912
+ def __post_init__(self) -> None:
  if self.weights.dtype != self.config.activation_precision:
  raise ValueError(
  f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
@@ -572,26 +572,19 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLin
  result["biases"] = self.biases
  return result

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["weights"], Array)
- assert isinstance(weights["zero_points"], Array)
- unpacked_weights = weights["weights"]
- unpacked_zero_points = weights["zero_points"]
-
+ unpacked_weights = require_array(weights["weights"])
+ unpacked_zero_points = require_array(weights["zero_points"])
  if self.config.weight_quantization_mode == QuantizationMode.UINT4:
- unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
- unpacked_zero_points = jax_uint8_to_unpacked_uint4(weights["zero_points"])
-
+ unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
+ unpacked_zero_points = jax_uint8_to_unpacked_uint4(unpacked_zero_points)
  return replace(
  self,
  weights=unpacked_weights.astype(self.weights.dtype),
- scales=weights["scales"],
+ scales=require_array(weights["scales"]),
  zero_points=unpacked_zero_points.astype(self.zero_points.dtype),
- biases=weights["biases"] if self.has_biases else None,
+ biases=require_array(weights["biases"]) if self.has_biases else None,
  )


@@ -740,7 +733,7 @@ class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearB

  return packed

- def __post_init__(self) -> None: # noqa: PLR0912
+ def __post_init__(self) -> None:
  if self.weights.dtype != self.config.activation_precision:
  raise ValueError(
  f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
@@ -847,24 +840,17 @@ class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearB
  result["biases"] = self.biases
  return result

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["weights"], Array)
-
- unpacked_weights = weights["weights"]
-
+ unpacked_weights = require_array(weights["weights"])
  if self.config.weight_quantization_mode == QuantizationMode.UINT4:
- unpacked_weights = jax_uint8_to_unpacked_uint4(weights["weights"])
-
+ unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
  return replace(
  self,
  weights=unpacked_weights.astype(self.weights.dtype),
- scales=weights["scales"],
- deq_biases=weights["deq_biases"],
- biases=weights["biases"] if self.has_biases else None,
+ scales=require_array(weights["scales"]),
+ deq_biases=require_array(weights["deq_biases"]),
+ biases=require_array(weights["biases"]) if self.has_biases else None,
  )


@@ -1113,7 +1099,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
  self,
  weights: ParameterTree[Array],
  ) -> Self:
- base = super().import_weights(weights)
+ base = cast("Self", super().import_weights(weights)) # ty bug
  assert isinstance(weights, Mapping)
  assert isinstance(weights["up_weights"], Sequence)
  return replace(
@@ -1126,4 +1112,4 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
  LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | MLXQuantizedLinearConfig | QLoRALinearConfig


- register_config_union(LinearConfig) # type: ignore (pyright bug)
+ register_config_union(LinearConfig)
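One change in this file is not part of the `require_array` migration: `QLoRALinear.import_weights` now wraps `super().import_weights(weights)` in `cast("Self", ...)`, annotated `# ty bug`. Below is a standalone illustration of the pattern being worked around; the classes are made up, not lalamo code.

```python
# Illustration only: a subclass that post-processes a Self-returning inherited
# method sometimes has to cast the result back to Self to satisfy a type
# checker that resolves super()'s return type to the base class.
from dataclasses import dataclass, replace
from typing import Self, cast


@dataclass(frozen=True)
class Base:
    scale: float

    def with_scale(self, scale: float) -> Self:
        return replace(self, scale=scale)


@dataclass(frozen=True)
class Derived(Base):
    bias: float

    def with_scale(self, scale: float) -> Self:
        # The runtime type is already Derived; the cast only pacifies the
        # checker so that replace() may set Derived-only fields below.
        base = cast("Self", super().with_scale(scale))
        return replace(base, bias=self.bias * scale)
```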
lalamo/modules/mlp.py CHANGED
@@ -12,7 +12,7 @@ from einops import rearrange
  from jax import vmap
  from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray

- from lalamo.common import ParameterTree
+ from lalamo.common import ParameterTree, require_tree
  from lalamo.modules.utils import vmap_twice

  from .activations import Activation
@@ -242,17 +242,12 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
  "down_projection": self.down_projection.export_weights(),
  }

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["up_projection"], Mapping)
- assert isinstance(weights["down_projection"], Mapping)
  return replace(
  self,
- up_projection=self.up_projection.import_weights(weights["up_projection"]),
- down_projection=self.down_projection.import_weights(weights["down_projection"]),
+ up_projection=self.up_projection.import_weights(require_tree(weights["up_projection"])),
+ down_projection=self.down_projection.import_weights(require_tree(weights["down_projection"])),
  )


@@ -285,7 +280,7 @@ class SoftmaxRouting(RoutingFunctionBase):
  RoutingFunction = SoftmaxRouting | DummyUnionMember


- register_config_union(RoutingFunction) # type: ignore (pyright bug)
+ register_config_union(RoutingFunction)


  @dataclass(frozen=True)
@@ -486,21 +481,16 @@ class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
  "experts": self.experts.export_weights(),
  }

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["router"], Mapping)
- assert isinstance(weights["experts"], Mapping)
  return replace(
  self,
- router=self.router.import_weights(weights["router"]),
- experts=self.experts.import_weights(weights["experts"]),
+ router=self.router.import_weights(require_tree(weights["router"])),
+ experts=self.experts.import_weights(require_tree(weights["experts"])),
  )


  MLPConfig = DenseMLPConfig | MixtureOfExpertsConfig


- register_config_union(MLPConfig) # type: ignore (pyright bug)
+ register_config_union(MLPConfig)
lalamo/modules/mlx_interop.py CHANGED
@@ -1,5 +1,5 @@
  import jax.numpy as jnp
- import mlx.core as mx
+ import mlx.core as mx # type: ignore
  from jaxtyping import Array

  __all__ = ["jax_to_mlx", "mlx_to_jax"]
lalamo/modules/rope.py CHANGED
@@ -281,4 +281,4 @@ class LinearScalingRoPEConfig(RoPEConfigBase):

  RoPEConfig = UnscaledRoPEConfig | LlamaRoPEConfig | YARNRoPEConfig | LinearScalingRoPEConfig

- register_config_union(RoPEConfig) # type: ignore (pyright bug)
+ register_config_union(RoPEConfig)
lalamo/modules/token_mixers/__init__.py CHANGED
@@ -16,7 +16,7 @@ from .state import (

  TokenMixerConfig = AttentionConfig | Mamba2Config | ShortConvConfig

- register_config_union(TokenMixerConfig) # type: ignore (pyright bug)
+ register_config_union(TokenMixerConfig)

  __all__ = [
  "Attention",
lalamo/modules/token_mixers/attention.py CHANGED
@@ -10,7 +10,7 @@ from jax import vmap
  from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray

  from lalamo.common import dummy_array
- from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector
+ from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector, require_array, require_tree
  from lalamo.modules.linear import LinearBase, LinearConfig
  from lalamo.modules.normalization import Normalization, NormalizationConfig
  from lalamo.modules.rope import PositionalEmbeddings
@@ -433,33 +433,15 @@ class Attention(TokenMixerBase[AttentionConfig, KVCacheLayer]):
  result["sinks"] = self.sinks
  return result

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["qkv_projection"], Mapping)
- assert isinstance(weights["out_projection"], Mapping)
- if self.query_norm is not None:
- assert isinstance(weights["query_norm"], Mapping)
- query_norm = self.query_norm.import_weights(weights["query_norm"])
- else:
- query_norm = None
- if self.key_norm is not None:
- assert isinstance(weights["key_norm"], Mapping)
- key_norm = self.key_norm.import_weights(weights["key_norm"])
- else:
- key_norm = None
- if self.sinks is not None:
- assert isinstance(weights["sinks"], Array)
- sinks = weights["sinks"]
- else:
- sinks = None
  return replace(
  self,
- qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"]),
- out_projection=self.out_projection.import_weights(weights["out_projection"]),
- query_norm=query_norm,
- key_norm=key_norm,
- sinks=sinks,
+ qkv_projection=self.qkv_projection.import_weights(require_tree(weights["qkv_projection"])),
+ out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
+ query_norm=self.query_norm.import_weights(require_tree(weights["query_norm"]))
+ if self.query_norm
+ else None,
+ key_norm=self.key_norm.import_weights(require_tree(weights["key_norm"])) if self.key_norm else None,
+ sinks=require_array(weights["sinks"]) if self.sinks is not None else None,
  )
lalamo/modules/token_mixers/mamba.py CHANGED
@@ -10,7 +10,7 @@ from einops import einsum, rearrange
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

- from lalamo.common import ParameterTree, dummy_array
+ from lalamo.common import ParameterTree, dummy_array, require_array, require_tree
  from lalamo.modules.activations import Activation
  from lalamo.modules.common import LalamoModule, PositionalEmbeddingSelector
  from lalamo.modules.linear import LinearBase, LinearConfig
@@ -149,16 +149,10 @@ class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):

  def import_weights(self, weights: ParameterTree[Array]) -> "SeparableCausalConv":
  assert isinstance(weights, Mapping)
- assert isinstance(weights["weights"], Array)
- if self.biases is not None:
- assert isinstance(weights["biases"], Array)
- biases = weights["biases"]
- else:
- biases = None
  return replace(
  self,
- weights=weights["weights"],
- biases=biases,
+ weights=require_array(weights["weights"]),
+ biases=require_array(weights["biases"]) if self.biases is not None else None,
  )


@@ -532,22 +526,13 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
  "gate_bias": self.gate_bias,
  }

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["in_projection"], Mapping)
- assert isinstance(weights["out_projection"], Mapping)
- assert isinstance(weights["conv"], Mapping)
- assert isinstance(weights["skip_connection_weight"], Array)
- assert isinstance(weights["gate_bias"], Array)
-
  return replace(
  self,
- in_projection=self.in_projection.import_weights(weights["in_projection"]),
- out_projection=self.out_projection.import_weights(weights["out_projection"]),
- conv=self.conv.import_weights(weights["conv"]),
- skip_connection_weight=weights["skip_connection_weight"],
- gate_bias=weights["gate_bias"],
+ in_projection=self.in_projection.import_weights(require_tree(weights["in_projection"])),
+ out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
+ conv=self.conv.import_weights(require_tree(weights["conv"])),
+ skip_connection_weight=require_array(weights["skip_connection_weight"]),
+ gate_bias=require_array(weights["gate_bias"]),
  )
lalamo/modules/token_mixers/short_conv.py CHANGED
@@ -6,7 +6,7 @@ import equinox as eqx
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

- from lalamo.common import ParameterTree
+ from lalamo.common import ParameterTree, require_tree
  from lalamo.modules.common import PositionalEmbeddingSelector
  from lalamo.modules.linear import LinearBase, LinearConfig
  from lalamo.modules.rope import PositionalEmbeddings
@@ -151,18 +151,11 @@ class ShortConv(TokenMixerBase[ShortConvConfig, ShortConvStateLayer]):
  "out_projection": self.out_projection.export_weights(),
  }

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["in_projection"], Mapping)
- assert isinstance(weights["conv"], Mapping)
- assert isinstance(weights["out_projection"], Mapping)
-
  return replace(
  self,
- in_projection=self.in_projection.import_weights(weights["in_projection"]),
- conv=self.conv.import_weights(weights["conv"]),
- out_projection=self.out_projection.import_weights(weights["out_projection"]),
+ in_projection=self.in_projection.import_weights(require_tree(weights["in_projection"])),
+ conv=self.conv.import_weights(require_tree(weights["conv"])),
+ out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
  )
lalamo/modules/transformer.py CHANGED
@@ -7,7 +7,7 @@ import jax
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

- from lalamo.common import ParameterTree
+ from lalamo.common import ParameterTree, require_tree
  from lalamo.modules.token_mixers import AttentionConfig
  from lalamo.modules.utils import vmap_twice

@@ -182,7 +182,8 @@ class Transformer(LalamoModule[TransformerConfig]):
  ) -> TransformerResult:
  if inner_features.ndim != 3:
  raise ValueError(
- f"inner_features must be a 3D array of size (batch_size, sequence_length, hidden_dim), got {inner_features.shape}",
+ "inner_features must be a 3D array of size (batch_size, sequence_length, hidden_dim),"
+ f" got {inner_features.shape}",
  )
  if token_positions.ndim != 2:
  raise ValueError(
@@ -251,35 +252,24 @@ class Transformer(LalamoModule[TransformerConfig]):
  result["local_rope"] = self.local_rope.export_weights()
  return result

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
  assert isinstance(weights["layers"], Sequence)
- assert isinstance(weights["output_norm"], Mapping)
-
  if self.global_rope:
- assert isinstance(weights["global_rope"], Mapping)
- global_rope = self.global_rope.import_weights(weights["global_rope"])
+ global_rope = self.global_rope.import_weights(require_tree(weights["global_rope"]))
  else:
  global_rope = None
-
  if self.local_rope:
- assert isinstance(weights["local_rope"], Mapping)
- local_rope = self.local_rope.import_weights(weights["local_rope"])
+ local_rope = self.local_rope.import_weights(require_tree(weights["local_rope"]))
  else:
  local_rope = None
-
- layers = []
- for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
- assert isinstance(layer_weights, Mapping)
- layers.append(layer.import_weights(layer_weights))
-
+ layers = [
+ layer.import_weights(require_tree(lw)) for layer, lw in zip(self.layers, weights["layers"], strict=True)
+ ]
  return replace(
  self,
  global_rope=global_rope,
  layers=tuple(layers),
- output_norm=self.output_norm.import_weights(weights["output_norm"]),
+ output_norm=self.output_norm.import_weights(require_tree(weights["output_norm"])),
  local_rope=local_rope,
  )
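A side note on the new list comprehension above: `zip(..., strict=True)` preserves the old loop's behaviour of failing when the checkpoint's layer count does not match the model's, instead of silently truncating. A standalone illustration with toy data (not lalamo structures):

```python
# zip(strict=True) raises instead of silently dropping entries when the two
# sequences disagree in length, e.g. a checkpoint that is missing a layer.
layers = ("layer_0", "layer_1", "layer_2")
layer_weights = [{"w": 0}, {"w": 1}]  # one entry short

try:
    imported = [f"{name}<-{lw['w']}" for name, lw in zip(layers, layer_weights, strict=True)]
except ValueError as error:
    print(error)  # zip() argument 2 is shorter than argument 1
```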
lalamo/modules/transformer_layer.py CHANGED
@@ -9,7 +9,7 @@ import jax.numpy as jnp
  from jax import vmap
  from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

- from lalamo.common import ParameterTree
+ from lalamo.common import ParameterTree, require_tree

  from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
  from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
@@ -293,38 +293,26 @@ class TransformerLayer(LalamoModule[TransformerLayerConfig]):
  result["post_mlp_norm"] = self.post_mlp_norm.export_weights()
  return result

- def import_weights(
- self,
- weights: ParameterTree[Array],
- ) -> Self:
+ def import_weights(self, weights: ParameterTree[Array]) -> Self:
  assert isinstance(weights, Mapping)
- assert isinstance(weights["mixer"], Mapping)
- assert isinstance(weights["mlp"], Mapping)
- assert isinstance(weights["pre_mlp_norm"], Mapping)
-
  if self.post_mixer_norm is not None:
- assert isinstance(weights["post_mixer_norm"], Mapping)
- post_mixer_norm = self.post_mixer_norm.import_weights(
- weights["post_mixer_norm"],
- )
+ post_mixer_norm = self.post_mixer_norm.import_weights(require_tree(weights["post_mixer_norm"]))
  else:
  post_mixer_norm = None
  if self.post_mlp_norm is not None:
- assert isinstance(weights["post_mlp_norm"], Mapping)
- post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
+ post_mlp_norm = self.post_mlp_norm.import_weights(require_tree(weights["post_mlp_norm"]))
  else:
  post_mlp_norm = None
  if self.pre_mixer_norm is not None:
- assert isinstance(weights["pre_mixer_norm"], Mapping)
- pre_mixer_norm = self.pre_mixer_norm.import_weights(weights["pre_mixer_norm"])
+ pre_mixer_norm = self.pre_mixer_norm.import_weights(require_tree(weights["pre_mixer_norm"]))
  else:
  pre_mixer_norm = None
  return replace(
  self,
  pre_mixer_norm=pre_mixer_norm,
- mixer=self.mixer.import_weights(weights["mixer"]),
+ mixer=self.mixer.import_weights(require_tree(weights["mixer"])),
  post_mixer_norm=post_mixer_norm,
- pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
- mlp=self.mlp.import_weights(weights["mlp"]),
+ pre_mlp_norm=self.pre_mlp_norm.import_weights(require_tree(weights["pre_mlp_norm"])),
+ mlp=self.mlp.import_weights(require_tree(weights["mlp"])),
  post_mlp_norm=post_mlp_norm,
  )
lalamo/registry_abc.py CHANGED
@@ -1,5 +1,5 @@
  from abc import ABC, ABCMeta
- from typing import Any
+ from typing import Any, Self
  from weakref import WeakSet

  __all__ = ["RegistryABC", "RegistryMeta"]
@@ -29,7 +29,7 @@ class RegistryMeta(ABCMeta):

  # Detect and remember the root exactly once
  if RegistryMeta._ROOT is None and name == "RegistryABC":
- RegistryMeta._ROOT = cls # type: ignore[assignment]
+ RegistryMeta._ROOT = cls
  return

  root = RegistryMeta._ROOT
@@ -58,6 +58,6 @@ class RegistryABC(ABC, metaclass=RegistryMeta):
  """

  @classmethod
- def __descendants__(cls) -> tuple[type, ...]:
- reg: WeakSet[type] = getattr(cls, RegistryMeta._REG_ATTR) # noqa: SLF001
+ def __descendants__(cls) -> tuple[type[Self], ...]:
+ reg: WeakSet[type[Self]] = getattr(cls, RegistryMeta._REG_ATTR) # noqa: SLF001
  return tuple(reg)
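The `__descendants__` change is purely a typing improvement: the registry contents are unchanged, but callers now get `tuple[type[Self], ...]` back instead of `tuple[type, ...]`. A hypothetical usage sketch follows; the subclasses are made up for illustration, and exactly which classes the registry returns is governed by `RegistryMeta`, which is only partially shown in this diff.

```python
from lalamo.registry_abc import RegistryABC


class ConverterBase(RegistryABC):
    """Made-up base class, used only for this illustration."""

    name: str = "base"


class FastConverter(ConverterBase):
    name = "fast"


class SlowConverter(ConverterBase):
    name = "slow"


# With the 0.6.0 signature, `cls` is typed as type[ConverterBase] rather than
# a bare `type`, so attribute access like `cls.name` type-checks without casts.
for cls in ConverterBase.__descendants__():
    print(cls.name)
```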