lalamo 0.5.16__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. lalamo/__init__.py +26 -2
  2. lalamo/commands.py +429 -0
  3. lalamo/common.py +14 -1
  4. lalamo/main.py +375 -229
  5. lalamo/message_processor.py +4 -1
  6. lalamo/model_import/common.py +8 -17
  7. lalamo/model_import/decoder_configs/huggingface/lfm2.py +14 -4
  8. lalamo/model_import/decoder_configs/huggingface/llamba.py +2 -2
  9. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +2 -2
  10. lalamo/model_import/huggingface_generation_config.py +21 -3
  11. lalamo/model_import/loaders/executorch.py +2 -2
  12. lalamo/model_import/loaders/huggingface.py +3 -3
  13. lalamo/model_import/model_specs/common.py +8 -4
  14. lalamo/model_import/model_specs/lfm2.py +41 -9
  15. lalamo/models/common.py +3 -3
  16. lalamo/models/language_model.py +7 -6
  17. lalamo/modules/activations.py +1 -1
  18. lalamo/modules/classifier.py +11 -24
  19. lalamo/modules/common.py +4 -1
  20. lalamo/modules/decoder.py +5 -11
  21. lalamo/modules/embedding.py +25 -62
  22. lalamo/modules/linear.py +19 -33
  23. lalamo/modules/mlp.py +9 -19
  24. lalamo/modules/mlx_interop.py +1 -1
  25. lalamo/modules/rope.py +1 -1
  26. lalamo/modules/token_mixers/__init__.py +1 -1
  27. lalamo/modules/token_mixers/attention.py +9 -27
  28. lalamo/modules/token_mixers/mamba.py +9 -24
  29. lalamo/modules/token_mixers/short_conv.py +5 -12
  30. lalamo/modules/transformer.py +10 -20
  31. lalamo/modules/transformer_layer.py +8 -20
  32. lalamo/registry_abc.py +4 -4
  33. lalamo/safetensors.py +97 -0
  34. lalamo/sampling.py +14 -0
  35. lalamo/speculator/estimator.py +11 -4
  36. lalamo/speculator/ngram.py +1 -1
  37. lalamo/utils.py +0 -13
  38. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/METADATA +1 -2
  39. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/RECORD +43 -41
  40. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/WHEEL +0 -0
  41. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/entry_points.txt +0 -0
  42. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/licenses/LICENSE +0 -0
  43. {lalamo-0.5.16.dist-info → lalamo-0.6.0.dist-info}/top_level.txt +0 -0
lalamo/message_processor.py CHANGED
@@ -169,7 +169,10 @@ class MessageProcessor:
     def __post_init__(self) -> None:
         if self.output_parser_regex is not None:
             all_fields = AssistantMessage.__dataclass_fields__
-            required_fields = {k: v for k, v in all_fields.items() if v.type == v.type | None}
+            # NOTE: str type annotations are assumed to be required
+            required_fields = {
+                k: v for k, v in all_fields.items() if isinstance(v.type, str) or v.type == (v.type | None)
+            }
             named_groups = self.output_parser_regex.groupindex
             invalid_groups = set(named_groups) - set(all_fields)
             if invalid_groups:
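The new `isinstance(v.type, str)` guard matters because `Field.type` holds a plain string whenever annotations are deferred (e.g. under `from __future__ import annotations`), and `'str | None' | None` would raise a `TypeError` instead of detecting optionality; the guard short-circuits before the union comparison. A minimal, self-contained illustration with a hypothetical `AssistantMessage` stand-in:

```python
from __future__ import annotations

from dataclasses import dataclass, fields

@dataclass
class AssistantMessage:  # hypothetical stand-in, not lalamo's actual class
    content: str
    reasoning: str | None = None

for f in fields(AssistantMessage):
    # With deferred annotations, f.type is the literal source string,
    # so optionality cannot be introspected via `f.type | None`.
    print(f.name, repr(f.type), isinstance(f.type, str))
# content 'str' True
# reasoning 'str | None' True
```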
lalamo/model_import/common.py CHANGED
@@ -3,7 +3,7 @@ import json
 from collections import ChainMap
 from collections.abc import Callable
 from contextlib import ExitStack
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from pathlib import Path
 from typing import NamedTuple
 
@@ -20,7 +20,7 @@ from lalamo.quantization import QuantizationMode
 from lalamo.utils import process_chat_template
 
 from .decoder_configs import ForeignClassifierConfig, ForeignConfig, ForeignLMConfig
-from .huggingface_generation_config import HFGenerationConfig
+from .huggingface_generation_config import HFGenerationConfig, _policy_from_hf_config
 from .huggingface_tokenizer_config import HFTokenizerConfig
 from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, ModelType, UseCase
 from .model_specs.common import JSONFieldSpec
@@ -34,6 +34,7 @@ __all__ = [
     "ModelSpec",
     "ModelType",
     "StatusEvent",
+    "download_file",
     "import_model",
 ]
 
@@ -239,24 +240,14 @@ def _import_language_model(
 
     stop_token_ids = tuple(foreign_decoder_config.eos_token_ids)
 
-    if model_spec.configs.generation_config is not None:
+    if isinstance(model_spec.configs.generation_config, GenerationConfig):
+        generation_config = replace(model_spec.configs.generation_config, stop_token_ids=stop_token_ids)
+    elif isinstance(model_spec.configs.generation_config, FileSpec):
        hf_generation_config_file = download_file(model_spec.configs.generation_config, model_spec.repo)
        hf_generation_config = HFGenerationConfig.from_json(hf_generation_config_file)
-        generation_config = GenerationConfig(
-            stop_token_ids=stop_token_ids,
-            temperature=hf_generation_config.temperature,
-            top_p=hf_generation_config.top_p,
-            top_k=hf_generation_config.top_k,
-            banned_tokens=None,
-        )
+        generation_config = _policy_from_hf_config(hf_generation_config, stop_token_ids)
     else:
-        generation_config = GenerationConfig(
-            stop_token_ids=stop_token_ids,
-            temperature=None,
-            top_p=None,
-            top_k=None,
-            banned_tokens=None,
-        )
+        generation_config = GenerationConfig(stop_token_ids)
 
     language_model_config = LanguageModelConfig(
         model_config=decoder.config,
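The spec's `generation_config` can now be a concrete `GenerationConfig`, a `FileSpec` pointing at `generation_config.json`, or absent. For the first case, `dataclasses.replace` keeps the sampling knobs pinned in the spec while substituting the stop tokens discovered during import. A self-contained miniature of that path (stand-in dataclass, not lalamo's own):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class GenerationConfig:  # stand-in mirroring the new all-default fields
    stop_token_ids: tuple[int, ...] = ()
    temperature: float | None = None
    top_k: int | None = None
    top_p: float | None = None
    min_p: float | None = None
    banned_tokens: tuple[int, ...] | None = None

pinned = GenerationConfig(temperature=0.3, min_p=0.15)  # as a ModelSpec would pin it
resolved = replace(pinned, stop_token_ids=(2,))         # stop ids come from the tokenizer
print(resolved.temperature, resolved.min_p, resolved.stop_token_ids)  # 0.3 0.15 (2,)
```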
lalamo/model_import/decoder_configs/huggingface/lfm2.py CHANGED
@@ -2,6 +2,7 @@ from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
+import jax.numpy as jnp
 from jaxtyping import DTypeLike
 
 from lalamo.modules import (
@@ -50,7 +51,6 @@ class HFLFM2Config(HuggingFaceLMConfig):
     conv_L_cache: int  # noqa: N815
     conv_bias: bool
     conv_dim: int
-    conv_dim_out: int
     conv_use_xavier_init: bool
     eos_token_id: int
     hidden_size: int
@@ -64,13 +64,15 @@ class HFLFM2Config(HuggingFaceLMConfig):
     num_key_value_heads: int
     pad_token_id: int
     rope_theta: float
-    torch_dtype: Literal["bfloat16"]
     transformers_version: str
     use_cache: bool
     use_pos_enc: bool
     vocab_size: int
 
+    dtype: Literal["bfloat16", "float16", "float32"] | None = None
+    torch_dtype: Literal["bfloat16", "float16", "float32"] | None = None
     intermediate_size: int | None = None
+    conv_dim_out: int | None = None
     layer_types: list[Literal["conv", "full_attention"]] | None = None
     full_attn_idxs: list[int] | None = None
     tie_embedding: bool = True
@@ -79,6 +81,14 @@ class HFLFM2Config(HuggingFaceLMConfig):
     quantization: QuantizationConfig | None = None
     quantization_config: QuantizationConfig | None = None
 
+    @property
+    def default_precision(self) -> DTypeLike:
+        assert self.dtype is not None or self.torch_dtype is not None, (
+            "at least one of dtype or torch_dtype must be specified"
+        )
+
+        return jnp.dtype(self.dtype or self.torch_dtype)
+
     def to_decoder_config(
         self,
         context_length: int | None,
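`default_precision` now accepts either the newer `dtype` key or the legacy `torch_dtype`, preferring `dtype` when both are present. A small sketch of the fallback (assuming JAX, whose bundled `ml_dtypes` registers `bfloat16` as a dtype name):

```python
import jax.numpy as jnp

def pick_precision(dtype: str | None, torch_dtype: str | None):
    # Mirrors the new property: at least one key must be set; `dtype` wins.
    assert dtype is not None or torch_dtype is not None
    return jnp.dtype(dtype or torch_dtype)

print(pick_precision(None, "bfloat16"))       # bfloat16 (legacy key only)
print(pick_precision("float32", "bfloat16"))  # float32 (new key wins)
```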
@@ -200,8 +210,8 @@ class HFLFM2Config(HuggingFaceLMConfig):
             subtract_mean=False,
         )
 
-        if self.intermediate_size is not None:
-            hidden_dim = self.intermediate_size
+        if not self.block_auto_adjust_ff_dim:
+            hidden_dim = self.intermediate_size or self.block_ff_dim
         else:
             hidden_dim_adjusted = self.block_ff_dim * self.block_ffn_dim_multiplier * (2 / 3)
             hidden_dim = int(
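When `block_auto_adjust_ff_dim` is set, the FF width is derived rather than taken verbatim. The hunk is cut off after `int(`, but the visible arithmetic is easy to check; a worked example of the shown part, with the final rounding step assumed to be Llama-style round-up to a multiple (the `multiple_of` granularity below is an assumption, not visible in this hunk):

```python
block_ff_dim = 12288
block_ffn_dim_multiplier = 1.0
multiple_of = 256  # assumed rounding granularity

adjusted = block_ff_dim * block_ffn_dim_multiplier * (2 / 3)  # 8192.0
hidden_dim = multiple_of * -(-int(adjusted) // multiple_of)   # round up to a multiple
print(hidden_dim)  # 8192
```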
lalamo/model_import/decoder_configs/huggingface/llamba.py CHANGED
@@ -76,7 +76,7 @@ class HFLlambaConfig(HuggingFaceLMConfig):
             logit_soft_cap=None,
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
             embedding_quantization_mode=QuantizationMode.from_num_bits(
-                int(metadata_dict["quantization_kwargs.bits"])
+                int(metadata_dict["quantization_kwargs.bits"]),
             ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
@@ -107,7 +107,7 @@ class HFLlambaConfig(HuggingFaceLMConfig):
         linear_config = MLXQuantizedLinearConfig(
             group_size=int(metadata_dict["quantization_kwargs.group_size"]),
             weight_quantization_mode=QuantizationMode.from_num_bits(
-                int(metadata_dict["quantization_kwargs.bits"])
+                int(metadata_dict["quantization_kwargs.bits"]),
             ),
             activation_quantization_mode=None,
             activation_precision=activation_precision,
lalamo/model_import/decoder_configs/huggingface/modern_bert.py CHANGED
@@ -41,7 +41,7 @@ def activation_from_str(activation: str) -> type[Activation]:
         return supported_activations[activation]
 
     raise ValueError(
-        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}"
+        f"Only activations from the following list are supported by Classifier: {supported_activations.keys()}",
     )
 
 
@@ -97,7 +97,7 @@ class ModernBERTConfig(HuggingFaceClassifierConfig):
         result = [None] * num_layers
         for index in range(len(result)):
             if index % global_attn_every_n_layers != 0:
-                result[index] = self.local_attention  # type: ignore
+                result[index] = self.local_attention
             else:
                 pass
         return tuple(result)
lalamo/model_import/huggingface_generation_config.py CHANGED
@@ -5,7 +5,9 @@ from typing import ClassVar
 
 import cattrs
 
-__all__ = ["HFGenerationConfig"]
+from lalamo.models import GenerationConfig
+
+__all__ = ["HFGenerationConfig", "_policy_from_hf_config"]
 
 
 @dataclass(frozen=True)
@@ -27,10 +29,11 @@ class HFGenerationConfig:
     cache_implementation: str | None = None  # “hybrid” for Gemma 3/2
 
     # -------- sampling strategy -------------
-    do_sample: bool | None = None
+    do_sample: bool | None = False
     temperature: float | None = None
+    min_p: float | None = None
     top_p: float | None = None
-    top_k: int | None = None
+    top_k: int | None = 50
     repetition_penalty: float | None = None
 
     # -------- length limits -----------------
@@ -42,3 +45,18 @@ class HFGenerationConfig:
         with open(json_path) as f:
             config = json.load(f)
         return cls._converter.structure(config, cls)
+
+
+def _policy_from_hf_config(
+    hf_config: HFGenerationConfig,
+    stop_token_ids: tuple[int, ...] = (),
+    banned_tokens: tuple[int, ...] | None = None,
+) -> GenerationConfig:
+    return GenerationConfig(
+        stop_token_ids=stop_token_ids,
+        temperature=hf_config.temperature,
+        top_k=hf_config.top_k,
+        top_p=hf_config.top_p,
+        min_p=hf_config.min_p,
+        banned_tokens=banned_tokens,
+    )
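Taken together, the new helper gives a one-liner from a repo's `generation_config.json` to a lalamo sampling policy, and `min_p` now survives the trip. A usage sketch (the local file path and stop-token id are hypothetical):

```python
from lalamo.model_import.huggingface_generation_config import (
    HFGenerationConfig,
    _policy_from_hf_config,
)

hf = HFGenerationConfig.from_json("generation_config.json")  # hypothetical local file
config = _policy_from_hf_config(hf, stop_token_ids=(2,))
policy = config.default_policy()  # temperature/top_k/top_p/min_p all threaded through
```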
lalamo/model_import/loaders/executorch.py CHANGED
@@ -97,7 +97,7 @@ def load_mlp(module: DenseMLP, weights_dict: Mapping[str, Array], path: Paramete
     fused_up_gate_params = merge_linear_params([up_proj_params, gate_proj_params])
 
     return load_parameters(
-        lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),  # type: ignore
+        lambda m: (*params_selector(m.up_projection), *params_selector(m.down_projection)),
         module,
         (*fused_up_gate_params, *down_proj_params),
     )
@@ -177,7 +177,7 @@ def load_attention(
 
     qkv_params = merge_linear_params([q_params, k_params, v_params])
     return load_parameters(
-        lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),  # type: ignore
+        lambda m: (*params_selector(m.qkv_projection), *params_selector(m.out_projection)),
         module,
         (*qkv_params, *out_params),
     )
lalamo/model_import/loaders/huggingface.py CHANGED
@@ -289,7 +289,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     combined_up_gate_b = jnp.concatenate([up_b + 1.0, gate_b], axis=-1)
 
     up_projection = load_parameters(
-        lambda m: (m.weights, m.biases),  # type: ignore
+        lambda m: (m.weights, m.biases),
         module.experts.up_projection,
         (combined_up_gate_w, combined_up_gate_b),
     )
@@ -309,7 +309,7 @@ def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path:
     down_b = jnp.broadcast_to(down_b, (*down_w.shape[:-1], down_b.shape[0]))
 
     down_projection = load_parameters(
-        lambda m: (m.weights, m.biases),  # type: ignore
+        lambda m: (m.weights, m.biases),
         module.experts.down_projection,
         (down_w, down_b),
     )
@@ -807,7 +807,7 @@ def load_huggingface_decoder(
         weights_dict,
         decoder_path / "layers" / ((i * 2) if alternating_layers else i),
         decoder_path / "layers" / ((i * 2 + 1) if alternating_layers else i),
-        mixer_key[type(layer.config.mixer_config)],  # type: ignore
+        mixer_key[type(layer.config.mixer_config)],
        mlp_key,
        pre_mixer_norm_key,
        pre_mlp_norm_key,
lalamo/model_import/model_specs/common.py CHANGED
@@ -7,15 +7,17 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum, StrEnum
 from pathlib import Path
-from typing import ClassVar, cast, get_args, get_origin
+from typing import Any, ClassVar, cast, get_args, get_origin
 
 import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
 
 from lalamo.model_import.decoder_configs import ForeignConfig
+from lalamo.models.language_model import GenerationConfig
 from lalamo.quantization import QuantizationMode
-from lalamo.utils import MapDictValues, open_safetensors
+from lalamo.safetensors import safe_read
+from lalamo.utils import MapDictValues
 
 __all__ = [
     "ConfigMap",
@@ -52,7 +54,8 @@ class WeightsType(Enum):
         float_dtype: DTypeLike,
     ) -> Iterator[tuple[Mapping[str, jnp.ndarray], Mapping[str, str]]]:
         if self == WeightsType.SAFETENSORS:
-            with open_safetensors(filename) as (weights_dict, metadata_dict):
+            with Path(filename).open("rb") as fd:
+                (metadata_dict, weights_dict) = safe_read(fd)
                 yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict), metadata_dict or {}
         else:
             import torch
@@ -84,7 +87,7 @@ class ConfigMap:
     model_config: FileSpec = field(default=FileSpec("config.json"))
     tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
     tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
-    generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
+    generation_config: FileSpec | GenerationConfig | None = field(default=FileSpec("generation_config.json"))
     chat_template: FileSpec | JSONFieldSpec | str | None = None
 
 
@@ -121,6 +124,7 @@ def _structure_chat_template(value: object, _type: object) -> FileSpec | JSONFie
     if isinstance(value, str):
         return value
     if isinstance(value, dict):
+        value = cast("dict[Any, Any]", value)  # ty bug??? Why is just `dict` != `dict[Any, Any]`?
         if "file_spec" in value and "field_name" in value:
             return JSONFieldSpec(
                 file_spec=FileSpec(**value["file_spec"]),
lalamo/model_import/model_specs/lfm2.py CHANGED
@@ -1,4 +1,7 @@
+from itertools import chain, product
+
 from lalamo.model_import.decoder_configs import HFLFM2Config
+from lalamo.models.language_model import GenerationConfig
 from lalamo.quantization import QuantizationMode
 
 from .common import ConfigMap, FileSpec, ModelSpec
@@ -6,26 +9,55 @@ from .common import ConfigMap, FileSpec, ModelSpec
 __all__ = ["LFM2_MODELS"]
 
 
-def _lfm2_repo(size: str, quantization: QuantizationMode | None) -> tuple[str, str]:
-    organization = "LiquidAI" if quantization is None else "mlx-community"
-    name = f"LFM2-{size}{f'-{quantization.bits}bit' if quantization is not None else ''}"
-    return (organization, name)
+def _lfm_repo(family: str, size: str, variant: str | None, quantization: QuantizationMode | None) -> tuple[str, str]:
+    return (
+        "LiquidAI" if quantization is None else "mlx-community",
+        f"{family}-{size}"
+        f"{f'-{variant}' if variant is not None else ''}"
+        f"{f'-{quantization.bits}bit' if quantization is not None else ''}",
+    )
 
 
-LFM2_MODELS = [
+_LFM20_MODELS = [
     ModelSpec(
         vendor="LiquidAI",
         family="LFM2",
-        name=_lfm2_repo(size, quantization)[1],
+        name=_lfm_repo("LFM2", size, variant, quantization)[1],
         size=size,
-        repo="/".join(_lfm2_repo(size, quantization)),
+        repo="/".join(_lfm_repo("LFM2", size, variant, quantization)),
         config_type=HFLFM2Config,
         quantization=quantization,
         configs=ConfigMap(
+            generation_config=GenerationConfig(temperature=0.3, min_p=0.15),  # , repetition_penalty=1.05
            chat_template=FileSpec("chat_template.jinja"),
        ),
        use_cases=tuple(),
    )
-    for size in ["350M", "700M", "1.2B", "2.6B"]
-    for quantization in [None, *([QuantizationMode.UINT4, QuantizationMode.UINT8] if size != "2.6B" else [])]
+    for size, variant, quantization in chain(
+        product(["350M", "700M", "1.2B"], [None], [None, QuantizationMode.UINT4, QuantizationMode.UINT8]),
+        product(["2.6B"], [None, "Exp"], [None]),
+        product(["2.6B"], ["Exp"], [QuantizationMode.UINT4, QuantizationMode.UINT8]),
+    )
 ]
+
+_LFM25_MODELS = [
+    ModelSpec(
+        vendor="LiquidAI",
+        family="LFM2.5",
+        name=_lfm_repo("LFM2.5", size, variant, quantization)[1],
+        size=size,
+        repo="/".join(_lfm_repo("LFM2.5", size, variant, quantization)),
+        config_type=HFLFM2Config,
+        quantization=quantization,
+        configs=ConfigMap(
+            generation_config=GenerationConfig(temperature=0.1, top_k=50, top_p=0.1),  # , repetition_penalty=1.05
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+        use_cases=tuple(),
+    )
+    for size, variant, quantization in chain(
+        product(["1.2B"], ["Instruct"], [None]),
+    )
+]
+
+LFM2_MODELS = _LFM20_MODELS + _LFM25_MODELS
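The `chain(product(...))` form makes the (size, variant, quantization) grid explicit. A runnable sketch of the enumeration, mirroring `_lfm_repo` with bare bit counts standing in for `QuantizationMode` so it runs standalone:

```python
from itertools import chain, product

def lfm_repo(family: str, size: str, variant: str | None, bits: int | None) -> str:
    org = "LiquidAI" if bits is None else "mlx-community"
    name = f"{family}-{size}" + (f"-{variant}" if variant else "") + (f"-{bits}bit" if bits else "")
    return f"{org}/{name}"

triples = chain(
    product(["350M", "700M", "1.2B"], [None], [None, 4, 8]),
    product(["2.6B"], [None, "Exp"], [None]),
    product(["2.6B"], ["Exp"], [4, 8]),
)
for size, variant, bits in triples:
    print(lfm_repo("LFM2", size, variant, bits))
# LiquidAI/LFM2-350M, mlx-community/LFM2-350M-4bit, ..., mlx-community/LFM2-2.6B-Exp-8bit
```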
lalamo/models/common.py CHANGED
@@ -15,7 +15,7 @@ from lalamo.message_processor import Message, MessageProcessor, MessageProcessor
 from lalamo.modules import Classifier, Decoder, LalamoModule, config_converter
 from lalamo.modules.classifier import ClassifierConfig, ClassifierResult
 from lalamo.modules.decoder import DecoderConfig, DecoderResult
-from lalamo.utils import open_safetensors
+from lalamo.safetensors import safe_read
 
 __all__ = [
     "TextModel",
@@ -42,8 +42,8 @@ class TextModelConfig[ConfigT: ClassifierConfig | DecoderConfig](ABC):
         with open(path / "config.json") as config_file:
             config_json = json.load(config_file)
         config = config_converter.structure(config_json["model_config"], cls)
-        with open_safetensors(path / "model.safetensors") as open_results:
-            weights_dict, _ = open_results
+        with Path(path / "model.safetensors").open("rb") as fd:
+            _, weights_dict = safe_read(fd)
         weights = unflatten_parameters(weights_dict)
         model = config.model_config.empty().import_weights(weights)
         tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
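Both call sites trade the context-managed `open_safetensors` for the new `lalamo.safetensors.safe_read` (the module is added in this release; its own diff isn't shown here). From the usage above, it takes an open binary file and returns `(metadata, tensors)`. A usage sketch with a hypothetical file path:

```python
from pathlib import Path

from lalamo.safetensors import safe_read

with Path("model.safetensors").open("rb") as fd:  # hypothetical checkpoint
    metadata, tensors = safe_read(fd)
print(metadata or {}, list(tensors)[:5])
```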
lalamo/models/language_model.py CHANGED
@@ -64,14 +64,15 @@ class GenerationResults(NamedTuple):
 
 @dataclass(frozen=True)
 class GenerationConfig:
-    stop_token_ids: tuple[int, ...]
-    temperature: float | None
-    top_k: int | None
-    top_p: float | None
-    banned_tokens: tuple[int, ...] | None
+    stop_token_ids: tuple[int, ...] = tuple()
+    temperature: float | None = None
+    top_k: int | None = None
+    top_p: float | None = None
+    min_p: float | None = None
+    banned_tokens: tuple[int, ...] | None = None
 
     def default_policy(self) -> SamplingPolicy:
-        return make_policy(self.temperature, self.top_k, self.top_p, self.banned_tokens)
+        return make_policy(self.temperature, self.top_k, self.top_p, self.min_p, self.banned_tokens)
 
 
 @dataclass(frozen=True)
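With every field defaulted, call sites only pass what they know, which is what makes the positional `GenerationConfig(stop_token_ids)` form in `_import_language_model` work. A short sketch (assuming lalamo is installed):

```python
from lalamo.models.language_model import GenerationConfig

config = GenerationConfig((2, 128001))  # stop_token_ids is the first field; the rest default to None
policy = config.default_policy()        # make_policy now also receives min_p
```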
lalamo/modules/activations.py CHANGED
@@ -44,4 +44,4 @@ class Identity(ActivationBase):
 Activation = SiLU | GELU | Identity
 
 
-register_config_union(Activation)  # type: ignore (pyright bug)
+register_config_union(Activation)
lalamo/modules/classifier.py CHANGED
@@ -9,7 +9,7 @@ from jax import numpy as jnp
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 from lalamo.modules import Activation
 from lalamo.modules.normalization import NormalizationConfig
 from lalamo.modules.transformer import (
@@ -67,7 +67,7 @@ class PredictionHeadConfig:
     def random_init(self, input_size: int, num_labels: int, key: PRNGKeyArray) -> "PredictionHead":
         dense_key, readout_key = jax.random.split(key)
         dense_layer = self.dense_config.random_init(
-            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key
+            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key,
         )
         norm = self.normalization_config.empty(input_size)
         readout = self.readout_config.random_init(
@@ -117,19 +117,13 @@ class PredictionHead(LalamoModule[PredictionHeadConfig]):
         )
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["dense"], Mapping)
-        assert isinstance(weights["norm"], Mapping)
-        assert isinstance(weights["readout"], Mapping)
         return replace(
             self,
-            dense=self.dense.import_weights(weights["dense"]),
-            norm=self.norm.import_weights(weights["norm"]),
-            readout=self.readout.import_weights(weights["readout"]),
+            dense=self.dense.import_weights(require_tree(weights["dense"])),
+            norm=self.norm.import_weights(require_tree(weights["norm"])),
+            readout=self.readout.import_weights(require_tree(weights["readout"])),
         )
 
 
@@ -321,19 +315,12 @@ class Classifier(LalamoModule[ClassifierConfig]):
         )
         return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["embedding_norm"], Mapping)
-        assert isinstance(weights["transformer"], Mapping)
-        assert isinstance(weights["prediction_head"], Mapping)
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]),
-            embedding_norm=self.embedding_norm.import_weights(weights["embedding_norm"]),
-            transformer=self.transformer.import_weights(weights["transformer"]),
-            prediction_head=self.prediction_head.import_weights(weights["prediction_head"]),
+            embedding=self.embedding.import_weights(require_tree(weights["embedding"])),
+            embedding_norm=self.embedding_norm.import_weights(require_tree(weights["embedding_norm"])),
+            transformer=self.transformer.import_weights(require_tree(weights["transformer"])),
+            prediction_head=self.prediction_head.import_weights(require_tree(weights["prediction_head"])),
        )
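`require_tree` (and `require_array`) live in `lalamo.common`, whose diff isn't shown in this section; the call sites suggest a narrowing helper that replaces the per-key `isinstance` asserts. A plausible minimal shape, offered purely as an assumption rather than lalamo's actual code:

```python
from collections.abc import Mapping, Sequence
from typing import Any

def require_tree(value: Any) -> Mapping[str, Any] | Sequence[Any]:
    # Hypothetical reconstruction: narrow a ParameterTree node to a subtree,
    # failing loudly when a leaf array shows up where a subtree is expected.
    if not isinstance(value, (Mapping, Sequence)):
        raise TypeError(f"expected a parameter subtree, got {type(value).__name__}")
    return value
```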
lalamo/modules/common.py CHANGED
@@ -9,15 +9,18 @@ from cattrs import Converter
 from jax import numpy as jnp
 from jaxtyping import Array, DTypeLike
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_array, require_tree
 
 __all__ = [
     "DummyUnionMember",
     "ForwardPassMode",
     "LalamoModule",
+    "ParameterTree",
     "PositionalEmbeddingSelector",
     "config_converter",
     "register_config_union",
+    "require_array",
+    "require_tree",
 ]
 
 
lalamo/modules/decoder.py CHANGED
@@ -7,7 +7,7 @@ import jax
 from jax import vmap
 from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
 
-from lalamo.common import ParameterTree
+from lalamo.common import ParameterTree, require_tree
 
 from .common import ForwardPassMode, LalamoModule
 from .embedding import EmbeddingBase, EmbeddingConfig
@@ -126,7 +126,7 @@ class Decoder(LalamoModule[DecoderConfig]):
         return self.embedding.activation_precision
 
     @eqx.filter_jit
-    def __call__(  # noqa: PLR0912
+    def __call__(
         self,
         token_ids: Int[Array, "batch suffix_tokens"],
         token_positions: Int[Array, "batch suffix_tokens"],
@@ -193,16 +193,10 @@ class Decoder(LalamoModule[DecoderConfig]):
             transformer=self.transformer.export_weights(),
         )
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
+    def import_weights(self, weights: ParameterTree[Array]) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["embedding"], Mapping)
-        assert isinstance(weights["transformer"], Mapping)
-
         return replace(
             self,
-            embedding=self.embedding.import_weights(weights["embedding"]),
-            transformer=self.transformer.import_weights(weights["transformer"]),
+            embedding=self.embedding.import_weights(require_tree(weights["embedding"])),
+            transformer=self.transformer.import_weights(require_tree(weights["transformer"])),
        )