lalamo 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +26 -2
- lalamo/commands.py +429 -0
- lalamo/common.py +14 -1
- lalamo/main.py +375 -229
- lalamo/message_processor.py +4 -1
- lalamo/model_import/common.py +8 -17
- lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
- lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
- lalamo/model_import/huggingface_generation_config.py +21 -3
- lalamo/model_import/loaders/executorch.py +2 -2
- lalamo/model_import/loaders/huggingface.py +3 -3
- lalamo/model_import/model_specs/common.py +8 -4
- lalamo/model_import/model_specs/lfm2.py +41 -9
- lalamo/models/common.py +3 -3
- lalamo/models/language_model.py +7 -6
- lalamo/modules/activations.py +1 -1
- lalamo/modules/classifier.py +11 -24
- lalamo/modules/common.py +4 -1
- lalamo/modules/decoder.py +5 -11
- lalamo/modules/embedding.py +25 -62
- lalamo/modules/linear.py +19 -33
- lalamo/modules/mlp.py +9 -19
- lalamo/modules/mlx_interop.py +1 -1
- lalamo/modules/rope.py +1 -1
- lalamo/modules/token_mixers/__init__.py +1 -1
- lalamo/modules/token_mixers/attention.py +9 -27
- lalamo/modules/token_mixers/mamba.py +9 -24
- lalamo/modules/token_mixers/short_conv.py +5 -12
- lalamo/modules/transformer.py +10 -20
- lalamo/modules/transformer_layer.py +8 -20
- lalamo/registry_abc.py +4 -4
- lalamo/safetensors.py +97 -0
- lalamo/sampling.py +14 -0
- lalamo/speculator/estimator.py +11 -4
- lalamo/speculator/ngram.py +1 -1
- lalamo/utils.py +0 -13
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -2
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/RECORD +43 -41
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/modules/embedding.py
CHANGED
@@ -9,7 +9,7 @@ import jax.numpy as jnp
 from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree, dummy_array
+from lalamo.common import ParameterTree, dummy_array, require_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
@@ -355,21 +355,15 @@ class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
             "scales": self.scales,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-        stored_weights = weights["weights"]
-
+        stored_weights = require_array(weights["weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
             stored_weights = jax_uint8_to_unpacked_uint4(stored_weights)
-
         return replace(
             self,
             weights=stored_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
+            scales=require_array(weights["scales"]),
         )
 
 
@@ -472,25 +466,16 @@ class MLXQuantizedTiedEmbedding(EmbeddingBase[MLXQuantizedTiedEmbeddingConfig]):
             "biases": self.biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-        assert isinstance(weights["scales"], Array)
-        assert isinstance(weights["biases"], Array)
-
-        unpacked_weights = weights["weights"]
-
+        unpacked_weights = require_array(weights["weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
-            biases=weights["biases"],
+            scales=require_array(weights["scales"]),
+            biases=require_array(weights["biases"]),
         )
 
 
@@ -630,33 +615,21 @@ class MLXQuantizedUntiedEmbedding(EmbeddingBase[MLXQuantizedUntiedEmbeddingConfi
             "output_biases": self.output_biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-
-        assert isinstance(weights["input_biases"], Array)
-        assert isinstance(weights["output_weights"], Array)
-        assert isinstance(weights["output_scales"], Array)
-        assert isinstance(weights["output_biases"], Array)
-
-        unpacked_input_weights = weights["input_weights"]
-        unpacked_output_weights = weights["output_weights"]
-
+        unpacked_input_weights = require_array(weights["input_weights"])
+        unpacked_output_weights = require_array(weights["output_weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_input_weights = jax_uint8_to_unpacked_uint4(
-            unpacked_output_weights = jax_uint8_to_unpacked_uint4(
-
+            unpacked_input_weights = jax_uint8_to_unpacked_uint4(unpacked_input_weights)
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(unpacked_output_weights)
         return replace(
             self,
             input_weights=unpacked_input_weights.astype(self.input_weights.dtype),
-            input_scales=weights["input_scales"],
-            input_biases=weights["input_biases"],
+            input_scales=require_array(weights["input_scales"]),
+            input_biases=require_array(weights["input_biases"]),
             output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
-            output_scales=weights["output_scales"],
-            output_biases=weights["output_biases"],
+            output_scales=require_array(weights["output_scales"]),
+            output_biases=require_array(weights["output_biases"]),
         )
 
 
@@ -765,27 +738,17 @@ class MLXSemiQuantizedUntiedEmbedding(EmbeddingBase[MLXSemiQuantizedUntiedEmbedd
             "output_biases": self.output_biases,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-        assert isinstance(weights["output_weights"], Array)
-        assert isinstance(weights["output_scales"], Array)
-        assert isinstance(weights["output_biases"], Array)
-
-        unpacked_output_weights = weights["output_weights"]
-
+        unpacked_output_weights = require_array(weights["output_weights"])
         if self.config.embedding_quantization_mode == QuantizationMode.UINT4:
-            unpacked_output_weights = jax_uint8_to_unpacked_uint4(
-
+            unpacked_output_weights = jax_uint8_to_unpacked_uint4(unpacked_output_weights)
         return replace(
             self,
-            input_weights=weights["input_weights"],
+            input_weights=require_array(weights["input_weights"]),
             output_weights=unpacked_output_weights.astype(self.output_weights.dtype),
-            output_scales=weights["output_scales"],
-            output_biases=weights["output_biases"],
+            output_scales=require_array(weights["output_scales"]),
+            output_biases=require_array(weights["output_biases"]),
         )
 
 
@@ -799,4 +762,4 @@ EmbeddingConfig = (
 )
 
 
-register_config_union(EmbeddingConfig)
+register_config_union(EmbeddingConfig)
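Throughout this release, import_weights implementations swap ad-hoc isinstance assertions for require_array and require_tree imported from lalamo.common. The lalamo/common.py hunk itself (+14 -1 in the list above) is not expanded on this page, so the following is only a minimal sketch of what such narrowing helpers plausibly look like; the names match the diff, but the bodies and signatures are assumptions.

```python
# Hypothetical sketch: the real helpers live in lalamo/common.py, whose diff is
# not shown here. Assumed behaviour: narrow a parameter-tree node to either a
# leaf array or a nested tree so the call sites above read cleanly without asserts.
from collections.abc import Mapping, Sequence

from jax import Array


def require_array(node: object) -> Array:
    """Assumed behaviour: return the node if it is a leaf array, otherwise raise."""
    if not isinstance(node, Array):
        raise TypeError(f"expected an array leaf, got {type(node).__name__}")
    return node


def require_tree(node: object) -> Mapping | Sequence:
    """Assumed behaviour: return the node if it is a nested mapping/sequence of parameters."""
    if isinstance(node, Array) or not isinstance(node, (Mapping, Sequence)):
        raise TypeError(f"expected a nested parameter tree, got {type(node).__name__}")
    return node
```

The call sites only need the helpers to hand back a correctly typed leaf or subtree, so any implementation with that contract would fit the pattern seen in the hunks.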
lalamo/modules/linear.py
CHANGED
@@ -2,7 +2,7 @@ import math
 from abc import ABC, abstractmethod
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass, replace
-from typing import Self
+from typing import Self, cast
 
 import equinox as eqx
 import jax
@@ -10,7 +10,7 @@ import jax.numpy as jnp
 from einops import rearrange
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree, dummy_array
+from lalamo.common import ParameterTree, dummy_array, require_array
 from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights
 from lalamo.utils import jax_uint4_to_packed_uint8, jax_uint8_to_unpacked_uint4
 
@@ -464,7 +464,7 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLin
 
         return packed
 
-    def __post_init__(self) -> None:
+    def __post_init__(self) -> None:
         if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
@@ -572,26 +572,19 @@ class GroupQuantizedLinearBase[ConfigT: GroupQuantizedLinearConfig](QuantizedLin
             result["biases"] = self.biases
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-
-        unpacked_weights = weights["weights"]
-        unpacked_zero_points = weights["zero_points"]
-
+        unpacked_weights = require_array(weights["weights"])
+        unpacked_zero_points = require_array(weights["zero_points"])
         if self.config.weight_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(
-            unpacked_zero_points = jax_uint8_to_unpacked_uint4(
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
+            unpacked_zero_points = jax_uint8_to_unpacked_uint4(unpacked_zero_points)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
+            scales=require_array(weights["scales"]),
             zero_points=unpacked_zero_points.astype(self.zero_points.dtype),
-            biases=weights["biases"] if self.has_biases else None,
+            biases=require_array(weights["biases"]) if self.has_biases else None,
         )
 
 
@@ -740,7 +733,7 @@ class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearB
 
         return packed
 
-    def __post_init__(self) -> None:
+    def __post_init__(self) -> None:
         if self.weights.dtype != self.config.activation_precision:
             raise ValueError(
                 f"Weight dtype ({self.weights.dtype}) is not equal to specified activation precision"
@@ -847,24 +840,17 @@ class MLXQuantizedLinearBase[ConfigT: MLXQuantizedLinearConfig](QuantizedLinearB
             result["biases"] = self.biases
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-
-
-        unpacked_weights = weights["weights"]
-
+        unpacked_weights = require_array(weights["weights"])
         if self.config.weight_quantization_mode == QuantizationMode.UINT4:
-            unpacked_weights = jax_uint8_to_unpacked_uint4(
-
+            unpacked_weights = jax_uint8_to_unpacked_uint4(unpacked_weights)
         return replace(
             self,
             weights=unpacked_weights.astype(self.weights.dtype),
-            scales=weights["scales"],
-            deq_biases=weights["deq_biases"],
-            biases=weights["biases"] if self.has_biases else None,
+            scales=require_array(weights["scales"]),
+            deq_biases=require_array(weights["deq_biases"]),
+            biases=require_array(weights["biases"]) if self.has_biases else None,
         )
 
 
@@ -1113,7 +1099,7 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
         self,
         weights: ParameterTree[Array],
     ) -> Self:
-        base = super().import_weights(weights)
+        base = cast("Self", super().import_weights(weights))  # ty bug
         assert isinstance(weights, Mapping)
         assert isinstance(weights["up_weights"], Sequence)
         return replace(
@@ -1126,4 +1112,4 @@ class QLoRALinear(GroupQuantizedLinearBase[QLoRALinearConfig]):
 LinearConfig = FullPrecisionLinearConfig | GroupQuantizedLinearConfig | MLXQuantizedLinearConfig | QLoRALinearConfig
 
 
-register_config_union(LinearConfig)
+register_config_union(LinearConfig)
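Several of these hunks also collapse a multi-line jax_uint8_to_unpacked_uint4(...) call onto one line. That helper and its packing counterpart come from lalamo/utils.py and are not shown on this page; as context only, the sketch below assumes the common "two 4-bit values per uint8" layout that names like jax_uint4_to_packed_uint8 suggest, and is not the library's actual implementation.

```python
# Illustrative sketch only: assumed two-nibbles-per-byte uint4 packing along the
# last axis. The real lalamo/utils.py helpers are not part of this diff.
import jax
import jax.numpy as jnp


def pack_uint4_pairs(values: jax.Array) -> jax.Array:
    """Pack pairs of 4-bit values (stored one per uint8) into single bytes."""
    low, high = values[..., 0::2], values[..., 1::2]
    return (low | (high << 4)).astype(jnp.uint8)


def unpack_uint4_pairs(packed: jax.Array) -> jax.Array:
    """Invert pack_uint4_pairs: split each byte back into two 4-bit values."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return jnp.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)


nibbles = jnp.arange(8, dtype=jnp.uint8) % 16
assert bool(jnp.all(unpack_uint4_pairs(pack_uint4_pairs(nibbles)) == nibbles))
```

Separately, the QLoRALinear hunk wraps the inherited call as cast("Self", super().import_weights(weights)) with a "# ty bug" comment, which reads as a type-checker workaround rather than a behavioural change.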
lalamo/modules/mlp.py
CHANGED
@@ -12,7 +12,7 @@ from einops import rearrange
 from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules.utils import vmap_twice
 
 from .activations import Activation
@@ -242,17 +242,12 @@ class DenseMLP(MLPBase[DenseMLPConfig]):
             "down_projection": self.down_projection.export_weights(),
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["up_projection"], Mapping)
-        assert isinstance(weights["down_projection"], Mapping)
         return replace(
             self,
-            up_projection=self.up_projection.import_weights(weights["up_projection"]),
-            down_projection=self.down_projection.import_weights(weights["down_projection"]),
+            up_projection=self.up_projection.import_weights(require_tree(weights["up_projection"])),
+            down_projection=self.down_projection.import_weights(require_tree(weights["down_projection"])),
         )
 
 
@@ -285,7 +280,7 @@ class SoftmaxRouting(RoutingFunctionBase):
 RoutingFunction = SoftmaxRouting | DummyUnionMember
 
 
-register_config_union(RoutingFunction)
+register_config_union(RoutingFunction)
 
 
 @dataclass(frozen=True)
@@ -486,21 +481,16 @@ class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
             "experts": self.experts.export_weights(),
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["router"], Mapping)
-        assert isinstance(weights["experts"], Mapping)
         return replace(
             self,
-            router=self.router.import_weights(weights["router"]),
-            experts=self.experts.import_weights(weights["experts"]),
+            router=self.router.import_weights(require_tree(weights["router"])),
+            experts=self.experts.import_weights(require_tree(weights["experts"])),
        )
 
 
 MLPConfig = DenseMLPConfig | MixtureOfExpertsConfig
 
 
-register_config_union(MLPConfig)
+register_config_union(MLPConfig)
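As in embedding.py and linear.py, the MLP modules now narrow each sub-tree with require_tree before delegating to the nested module's own import_weights. A hypothetical usage sketch of the layout DenseMLP appears to expect follows; the top-level keys are taken from the diff, while the inner "weights" key, the array shapes, and the dense_mlp instance are assumptions.

```python
# Hypothetical usage sketch for DenseMLP.import_weights. Only the
# "up_projection"/"down_projection" keys are confirmed by the diff; the rest
# (inner keys, shapes, the dense_mlp object) is illustrative.
import jax.numpy as jnp

mlp_weights = {
    "up_projection": {"weights": jnp.zeros((2048, 512))},
    "down_projection": {"weights": jnp.zeros((512, 2048))},
}

# new_mlp = dense_mlp.import_weights(mlp_weights)
# require_tree(mlp_weights["up_projection"]) narrows the sub-dict before it is
# handed to the nested linear module's import_weights.
```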
lalamo/modules/mlx_interop.py
CHANGED
lalamo/modules/rope.py
CHANGED
lalamo/modules/token_mixers/attention.py
CHANGED
@@ -10,7 +10,7 @@ from jax import vmap
 from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import dummy_array
-from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector
+from lalamo.modules.common import ParameterTree, PositionalEmbeddingSelector, require_array, require_tree
 from lalamo.modules.linear import LinearBase, LinearConfig
 from lalamo.modules.normalization import Normalization, NormalizationConfig
 from lalamo.modules.rope import PositionalEmbeddings
@@ -433,33 +433,15 @@ class Attention(TokenMixerBase[AttentionConfig, KVCacheLayer]):
             result["sinks"] = self.sinks
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["qkv_projection"], Mapping)
-        assert isinstance(weights["out_projection"], Mapping)
-        if self.query_norm is not None:
-            assert isinstance(weights["query_norm"], Mapping)
-            query_norm = self.query_norm.import_weights(weights["query_norm"])
-        else:
-            query_norm = None
-        if self.key_norm is not None:
-            assert isinstance(weights["key_norm"], Mapping)
-            key_norm = self.key_norm.import_weights(weights["key_norm"])
-        else:
-            key_norm = None
-        if self.sinks is not None:
-            assert isinstance(weights["sinks"], Array)
-            sinks = weights["sinks"]
-        else:
-            sinks = None
         return replace(
             self,
-            qkv_projection=self.qkv_projection.import_weights(weights["qkv_projection"]),
-            out_projection=self.out_projection.import_weights(weights["out_projection"]),
-            query_norm=query_norm
-
-
+            qkv_projection=self.qkv_projection.import_weights(require_tree(weights["qkv_projection"])),
+            out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
+            query_norm=self.query_norm.import_weights(require_tree(weights["query_norm"]))
+            if self.query_norm
+            else None,
+            key_norm=self.key_norm.import_weights(require_tree(weights["key_norm"])) if self.key_norm else None,
+            sinks=require_array(weights["sinks"]) if self.sinks is not None else None,
         )
lalamo/modules/token_mixers/mamba.py
CHANGED
@@ -10,7 +10,7 @@ from einops import einsum, rearrange
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree, dummy_array
+from lalamo.common import ParameterTree, dummy_array, require_array, require_tree
 from lalamo.modules.activations import Activation
 from lalamo.modules.common import LalamoModule, PositionalEmbeddingSelector
 from lalamo.modules.linear import LinearBase, LinearConfig
@@ -149,16 +149,10 @@ class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):
 
     def import_weights(self, weights: ParameterTree[Array]) -> "SeparableCausalConv":
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["weights"], Array)
-        if self.biases is not None:
-            assert isinstance(weights["biases"], Array)
-            biases = weights["biases"]
-        else:
-            biases = None
         return replace(
             self,
-            weights=weights["weights"],
-            biases=biases,
+            weights=require_array(weights["weights"]),
+            biases=require_array(weights["biases"]) if self.biases is not None else None,
         )
 
 
@@ -532,22 +526,13 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
             "gate_bias": self.gate_bias,
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["in_projection"], Mapping)
-        assert isinstance(weights["out_projection"], Mapping)
-        assert isinstance(weights["conv"], Mapping)
-        assert isinstance(weights["skip_connection_weight"], Array)
-        assert isinstance(weights["gate_bias"], Array)
-
         return replace(
             self,
-            in_projection=self.in_projection.import_weights(weights["in_projection"]),
-            out_projection=self.out_projection.import_weights(weights["out_projection"]),
-            conv=self.conv.import_weights(weights["conv"]),
-            skip_connection_weight=weights["skip_connection_weight"],
-            gate_bias=weights["gate_bias"],
+            in_projection=self.in_projection.import_weights(require_tree(weights["in_projection"])),
+            out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
+            conv=self.conv.import_weights(require_tree(weights["conv"])),
+            skip_connection_weight=require_array(weights["skip_connection_weight"]),
+            gate_bias=require_array(weights["gate_bias"]),
         )
lalamo/modules/token_mixers/short_conv.py
CHANGED
@@ -6,7 +6,7 @@ import equinox as eqx
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules.common import PositionalEmbeddingSelector
 from lalamo.modules.linear import LinearBase, LinearConfig
 from lalamo.modules.rope import PositionalEmbeddings
@@ -151,18 +151,11 @@ class ShortConv(TokenMixerBase[ShortConvConfig, ShortConvStateLayer]):
             "out_projection": self.out_projection.export_weights(),
         }
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["in_projection"], Mapping)
-        assert isinstance(weights["conv"], Mapping)
-        assert isinstance(weights["out_projection"], Mapping)
-
         return replace(
             self,
-            in_projection=self.in_projection.import_weights(weights["in_projection"]),
-            conv=self.conv.import_weights(weights["conv"]),
-            out_projection=self.out_projection.import_weights(weights["out_projection"]),
+            in_projection=self.in_projection.import_weights(require_tree(weights["in_projection"])),
+            conv=self.conv.import_weights(require_tree(weights["conv"])),
+            out_projection=self.out_projection.import_weights(require_tree(weights["out_projection"])),
         )
lalamo/modules/transformer.py
CHANGED
@@ -7,7 +7,7 @@ import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules.token_mixers import AttentionConfig
 from lalamo.modules.utils import vmap_twice
 
@@ -182,7 +182,8 @@ class Transformer(LalamoModule[TransformerConfig]):
     ) -> TransformerResult:
         if inner_features.ndim != 3:
             raise ValueError(
-
+                "inner_features must be a 3D array of size (batch_size, sequence_length, hidden_dim),"
+                f" got {inner_features.shape}",
             )
         if token_positions.ndim != 2:
             raise ValueError(
@@ -251,35 +252,24 @@ class Transformer(LalamoModule[TransformerConfig]):
             result["local_rope"] = self.local_rope.export_weights()
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["layers"], Sequence)
-        assert isinstance(weights["output_norm"], Mapping)
-
         if self.global_rope:
-
-            global_rope = self.global_rope.import_weights(weights["global_rope"])
+            global_rope = self.global_rope.import_weights(require_tree(weights["global_rope"]))
         else:
             global_rope = None
-
         if self.local_rope:
-
-            local_rope = self.local_rope.import_weights(weights["local_rope"])
+            local_rope = self.local_rope.import_weights(require_tree(weights["local_rope"]))
         else:
             local_rope = None
-
-
-
-            assert isinstance(layer_weights, Mapping)
-            layers.append(layer.import_weights(layer_weights))
-
+        layers = [
+            layer.import_weights(require_tree(lw)) for layer, lw in zip(self.layers, weights["layers"], strict=True)
+        ]
         return replace(
             self,
             global_rope=global_rope,
             layers=tuple(layers),
-            output_norm=self.output_norm.import_weights(weights["output_norm"]),
+            output_norm=self.output_norm.import_weights(require_tree(weights["output_norm"])),
             local_rope=local_rope,
         )
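The rewritten Transformer.import_weights builds its layer list with a comprehension over zip(self.layers, weights["layers"], strict=True). The strict flag is plain Python 3.10+ behaviour, illustrated standalone below; it guards against a checkpoint that provides a different number of per-layer weight trees than the model has layers.

```python
# Standard-library behaviour the new layer-import code relies on: with strict=True,
# zip() raises instead of silently truncating when the two sequences differ in length.
layers = ["layer_0", "layer_1", "layer_2"]
layer_weight_trees = [{"mixer": {}}, {"mixer": {}}]  # one entry short

try:
    pairs = list(zip(layers, layer_weight_trees, strict=True))
except ValueError as error:
    print(error)  # zip() argument 2 is shorter than argument 1
```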
lalamo/modules/transformer_layer.py
CHANGED
@@ -9,7 +9,7 @@ import jax.numpy as jnp
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 
 from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
 from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
@@ -293,38 +293,26 @@ class TransformerLayer(LalamoModule[TransformerLayerConfig]):
             result["post_mlp_norm"] = self.post_mlp_norm.export_weights()
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["mixer"], Mapping)
-        assert isinstance(weights["mlp"], Mapping)
-        assert isinstance(weights["pre_mlp_norm"], Mapping)
-
         if self.post_mixer_norm is not None:
-
-            post_mixer_norm = self.post_mixer_norm.import_weights(
-                weights["post_mixer_norm"],
-            )
+            post_mixer_norm = self.post_mixer_norm.import_weights(require_tree(weights["post_mixer_norm"]))
         else:
             post_mixer_norm = None
         if self.post_mlp_norm is not None:
-
-            post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
+            post_mlp_norm = self.post_mlp_norm.import_weights(require_tree(weights["post_mlp_norm"]))
         else:
             post_mlp_norm = None
         if self.pre_mixer_norm is not None:
-
-            pre_mixer_norm = self.pre_mixer_norm.import_weights(weights["pre_mixer_norm"])
+            pre_mixer_norm = self.pre_mixer_norm.import_weights(require_tree(weights["pre_mixer_norm"]))
         else:
             pre_mixer_norm = None
         return replace(
             self,
             pre_mixer_norm=pre_mixer_norm,
-            mixer=self.mixer.import_weights(weights["mixer"]),
+            mixer=self.mixer.import_weights(require_tree(weights["mixer"])),
             post_mixer_norm=post_mixer_norm,
-            pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
-            mlp=self.mlp.import_weights(weights["mlp"]),
+            pre_mlp_norm=self.pre_mlp_norm.import_weights(require_tree(weights["pre_mlp_norm"])),
+            mlp=self.mlp.import_weights(require_tree(weights["mlp"])),
             post_mlp_norm=post_mlp_norm,
         )
lalamo/registry_abc.py
CHANGED
@@ -1,5 +1,5 @@
 from abc import ABC, ABCMeta
-from typing import Any
+from typing import Any, Self
 from weakref import WeakSet
 
 __all__ = ["RegistryABC", "RegistryMeta"]
@@ -29,7 +29,7 @@ class RegistryMeta(ABCMeta):
 
         # Detect and remember the root exactly once
         if RegistryMeta._ROOT is None and name == "RegistryABC":
-            RegistryMeta._ROOT = cls
+            RegistryMeta._ROOT = cls
             return
 
         root = RegistryMeta._ROOT
@@ -58,6 +58,6 @@ class RegistryABC(ABC, metaclass=RegistryMeta):
         """
 
     @classmethod
-    def __descendants__(cls) -> tuple[type, ...]:
-        reg: WeakSet[type] = getattr(cls, RegistryMeta._REG_ATTR)  # noqa: SLF001
+    def __descendants__(cls) -> tuple[type[Self], ...]:
+        reg: WeakSet[type[Self]] = getattr(cls, RegistryMeta._REG_ATTR)  # noqa: SLF001
         return tuple(reg)