lalamo 0.2.7__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. lalamo/__init__.py +1 -1
  2. lalamo/common.py +79 -29
  3. lalamo/language_model.py +106 -83
  4. lalamo/main.py +91 -18
  5. lalamo/message_processor.py +170 -0
  6. lalamo/model_import/common.py +159 -43
  7. lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
  8. lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
  9. lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
  10. lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
  11. lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
  12. lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
  13. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
  14. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
  15. lalamo/model_import/huggingface_generation_config.py +44 -0
  16. lalamo/model_import/huggingface_tokenizer_config.py +85 -0
  17. lalamo/model_import/loaders/common.py +2 -1
  18. lalamo/model_import/loaders/huggingface.py +12 -10
  19. lalamo/model_import/model_specs/__init__.py +3 -2
  20. lalamo/model_import/model_specs/common.py +32 -34
  21. lalamo/model_import/model_specs/deepseek.py +1 -10
  22. lalamo/model_import/model_specs/gemma.py +2 -25
  23. lalamo/model_import/model_specs/huggingface.py +2 -12
  24. lalamo/model_import/model_specs/llama.py +2 -58
  25. lalamo/model_import/model_specs/mistral.py +9 -19
  26. lalamo/model_import/model_specs/pleias.py +3 -13
  27. lalamo/model_import/model_specs/polaris.py +5 -7
  28. lalamo/model_import/model_specs/qwen.py +12 -111
  29. lalamo/model_import/model_specs/reka.py +4 -13
  30. lalamo/modules/__init__.py +2 -1
  31. lalamo/modules/attention.py +90 -10
  32. lalamo/modules/common.py +51 -4
  33. lalamo/modules/decoder.py +90 -8
  34. lalamo/modules/decoder_layer.py +85 -8
  35. lalamo/modules/embedding.py +95 -29
  36. lalamo/modules/kv_cache.py +3 -3
  37. lalamo/modules/linear.py +170 -130
  38. lalamo/modules/mlp.py +40 -7
  39. lalamo/modules/normalization.py +24 -6
  40. lalamo/modules/rope.py +24 -6
  41. lalamo/sampling.py +99 -0
  42. lalamo/utils.py +86 -1
  43. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/METADATA +6 -6
  44. lalamo-0.3.0.dist-info/RECORD +58 -0
  45. lalamo-0.2.7.dist-info/RECORD +0 -54
  46. /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
  47. /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
  48. /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
  49. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/WHEEL +0 -0
  50. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/entry_points.txt +0 -0
  51. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/licenses/LICENSE +0 -0
  52. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/top_level.txt +0 -0

lalamo/model_import/huggingface_tokenizer_config.py
@@ -0,0 +1,85 @@
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import ClassVar
+
+ import cattrs
+ from tokenizers import AddedToken
+
+ __all__ = ["HFAddedToken", "HFTokenizerConfig"]
+
+
+ @dataclass(frozen=True)
+ class HFAddedToken:
+     content: str
+     single_word: bool
+     normalized: bool
+     special: bool
+     lstrip: bool
+     rstrip: bool
+
+     def to_added_token(self) -> AddedToken:
+         return AddedToken(
+             self.content,
+             single_word=self.single_word,
+             normalized=self.normalized,
+             special=self.special,
+             lstrip=self.lstrip,
+             rstrip=self.rstrip,
+         )
+
+
+ @dataclass(frozen=True)
+ class HFTokenizerConfig:
+     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+     _converter.register_structure_hook(int | list[int], lambda v, _: v)
+
+     # ---------- core identity ----------
+     tokenizer_class: str | None = None
+     model_max_length: int | None = None
+     padding_side: str | None = None  # "left" | "right"
+     truncation_side: str | None = None  # "left" | "right"
+     legacy: bool | None = None
+     use_fast: bool | None = None
+     clean_up_tokenization_spaces: bool | None = None
+
+     # ---------- behaviour flags ----------
+     add_bos_token: bool | None = None
+     add_eos_token: bool | None = None
+     add_prefix_space: bool | None = None
+     use_default_system_prompt: bool | None = None
+     spaces_between_special_tokens: bool | None = None
+     do_lower_case: bool | None = None
+
+     # ---------- special tokens ----------
+     bos_token: str | None = None
+     eos_token: str | None = None
+     unk_token: str | None = None
+     pad_token: str | None = None
+     sep_token: str | None = None
+     cls_token: str | None = None
+     mask_token: str | None = None
+     added_tokens_decoder: dict[str, HFAddedToken] | None = None
+
+     # ---------- chat / SentencePiece ----------
+     chat_template: str | None = None
+     sp_model_kwargs: dict | None = None
+
+     # ---------- extras ----------
+     language: str | None = None
+     task: str | None = None
+
+     def added_tokens(self) -> list[AddedToken]:
+         if self.added_tokens_decoder is None:
+             return []
+         return [
+             AddedToken(content=token.content, single_word=token.single_word, normalized=token.normalized)
+             for token in self.added_tokens_decoder.values()
+         ]
+
+     @classmethod
+     def from_json(cls, json_path: Path | str) -> "HFTokenizerConfig":
+         json_path = Path(json_path)
+         with open(json_path) as f:
+             config = json.load(f)
+         return cls._converter.structure(config, cls)
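
The new HFTokenizerConfig mirrors the keys of a Hugging Face tokenizer_config.json as a typed, frozen dataclass. A minimal usage sketch, assuming the import path implied by the file list above (the local file path is hypothetical):

from pathlib import Path

from lalamo.model_import.huggingface_tokenizer_config import HFTokenizerConfig

# Parse a downloaded tokenizer_config.json into the dataclass defined above.
config = HFTokenizerConfig.from_json(Path("downloads/SmolLM2/tokenizer_config.json"))

# Keys missing from the JSON stay None rather than being guessed.
print(config.tokenizer_class, config.chat_template is not None)

# Extra vocabulary entries become `tokenizers.AddedToken` objects.
for token in config.added_tokens():
    print(token)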

lalamo/model_import/loaders/common.py
@@ -1,6 +1,7 @@
  from collections.abc import Callable, Iterable

  import equinox as eqx
+ from jax._src.api import ShapeDtypeStruct
  from jax.tree import leaves_with_path
  from jax.tree_util import keystr
  from jaxtyping import Array, PyTree
@@ -18,7 +19,7 @@ def _get_name(leaf: PyTree, tree: PyTree) -> str:


  def _check_compatible(old_value: PyTree, new_value: PyTree, module: eqx.Module) -> None:
-     if isinstance(old_value, Array) and isinstance(new_value, Array):
+     if isinstance(old_value, (Array, ShapeDtypeStruct)) and isinstance(new_value, Array):
          name = _get_name(old_value, module)
          if old_value.shape != new_value.shape:
              raise ValueError(f"Expected parameter {name} to have shape {old_value.shape}, got {new_value.shape}")

lalamo/model_import/loaders/huggingface.py
@@ -1,3 +1,5 @@
+ from collections.abc import Mapping
+
  import jax.numpy as jnp
  from einops import rearrange
  from jaxtyping import Array
@@ -80,7 +82,7 @@ def _process_quantized_tensors(


  def _fuse_full_precision_weights(
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
      path: ParameterPath,
      sublayers_to_fuse: list[str] | None,
  ) -> Array:
@@ -92,7 +94,7 @@ def _fuse_full_precision_weights(


  def _fuse_quantized_weights(
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
      path: ParameterPath,
      sublayers_to_fuse: list[str] | None,
  ) -> tuple[Array, Array, Array]:
@@ -117,7 +119,7 @@ def _fuse_quantized_weights(

  def load_linear(
      module: LinearBase,
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
      path: ParameterPath,
      sublayers_to_fuse: list[str] | None = None,
  ) -> LinearBase:
@@ -162,7 +164,7 @@ def load_linear(
      raise TypeError(f"Unsupported module type for loading: {type(module)}")


- def load_mlp(module: MLP, weights_dict: dict[str, Array], path: ParameterPath) -> MLP:
+ def load_mlp(module: MLP, weights_dict: Mapping[str, Array], path: ParameterPath) -> MLP:
      up_projection = load_linear(module.up_projection, weights_dict, path, sublayers_to_fuse=["up_proj", "gate_proj"])
      down_projection = load_linear(module.down_projection, weights_dict, path / "down_proj")
      return load_parameters(lambda m: (m.up_projection, m.down_projection), module, (up_projection, down_projection))
@@ -170,7 +172,7 @@ def load_mlp(module: MLP, weights_dict: dict[str, Array], path: ParameterPath) -

  def load_rmsnorm(
      module: RMSNorm,
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
      path: ParameterPath,
  ) -> RMSNorm:
      scales = weights_dict[path / "weight"]
@@ -179,7 +181,7 @@

  def load_attention(
      module: Attention,
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
      path: ParameterPath,
  ) -> Attention:
      qkv_projection = load_linear(
@@ -209,7 +211,7 @@ def load_attention(

  def load_decoder_layer(
      module: DecoderLayer,
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
      path: ParameterPath,
  ) -> DecoderLayer:
      pre_attention_norm = load_rmsnorm(
@@ -257,7 +259,7 @@ def load_decoder_layer(

  def load_tied_embedding(
      module: TiedEmbedding,
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
      decoder_path: ParameterPath,
  ) -> TiedEmbedding:
      weights = weights_dict[decoder_path / "embed_tokens" / "weight"]
@@ -266,7 +268,7 @@ def load_tied_embedding(

  def load_untied_embedding(
      module: UntiedEmbedding,
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
      decoder_path: ParameterPath,
      lm_head_path: ParameterPath,
  ) -> UntiedEmbedding:
@@ -277,7 +279,7 @@ def load_untied_embedding(

  def load_huggingface(
      module: Decoder,
-     weights_dict: dict[str, Array],
+     weights_dict: Mapping[str, Array],
  ) -> Decoder:
      if any(key.startswith("language_model.") for key in weights_dict):
          base_path = ParameterPath("language_model")

lalamo/model_import/model_specs/__init__.py
@@ -1,4 +1,4 @@
- from .common import awq_model_spec, build_quantized_models, ModelSpec, UseCase
+ from .common import FileSpec, ModelSpec, UseCase, build_quantized_models
  from .deepseek import DEEPSEEK_MODELS
  from .gemma import GEMMA_MODELS
  from .huggingface import HUGGINGFACE_MODELS
@@ -12,6 +12,7 @@ from .reka import REKA_MODELS
  __all__ = [
      "ALL_MODELS",
      "REPO_TO_MODEL",
+     "FileSpec",
      "ModelSpec",
      "UseCase",
  ]
@@ -23,7 +24,7 @@ ALL_MODEL_LISTS = [
      GEMMA_MODELS,
      HUGGINGFACE_MODELS,
      MISTRAL_MODELS,
-     PLEIAS_MODELS,
+     # PLEIAS_MODELS, # TODO(norpadon): Add chat template
      POLARIS_MODELS,
      QWEN_MODELS,
      REKA_MODELS,

lalamo/model_import/model_specs/common.py
@@ -1,4 +1,7 @@
- from dataclasses import dataclass
+ from collections.abc import (
+     Mapping,
+ )
+ from dataclasses import dataclass, field
  from enum import Enum
  from pathlib import Path

@@ -6,18 +9,19 @@ import jax.numpy as jnp
  from jaxtyping import Array, DTypeLike
  from safetensors.flax import load_file as load_safetensors

- from lalamo.model_import.configs import ForeignConfig
+ from lalamo.model_import.decoder_configs import ForeignConfig
+ from lalamo.modules.torch_interop import torch_to_jax
  from lalamo.quantization import QuantizationMode
+ from lalamo.utils import MapDictValues

  __all__ = [
-     "HUGGINFACE_GENERATION_CONFIG_FILE",
-     "HUGGINGFACE_TOKENIZER_FILES",
+     "ConfigMap",
+     "FileSpec",
      "ModelSpec",
-     "TokenizerFileSpec",
      "UseCase",
+     "WeightsType",
      "awq_model_spec",
      "build_quantized_models",
-     "huggingface_weight_files",
  ]


@@ -31,16 +35,14 @@ class WeightsType(Enum):
      SAFETENSORS = "safetensors"
      TORCH = "torch"

-     def load(self, filename: Path | str, float_dtype: DTypeLike) -> dict[str, jnp.ndarray]:
+     def load(self, filename: Path | str, float_dtype: DTypeLike) -> Mapping[str, jnp.ndarray]:
          if self == WeightsType.SAFETENSORS:
-             return {k: cast_if_float(v, float_dtype) for k, v in load_safetensors(filename).items()}
+             return MapDictValues(lambda v: cast_if_float(v, float_dtype), load_safetensors(filename))

          import torch

-         from lalamo.modules.torch_interop import torch_to_jax
-
          torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
-         return {k: cast_if_float(torch_to_jax(v), float_dtype) for k, v in torch_weights.items()}
+         return MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights)


  class UseCase(Enum):
@@ -48,9 +50,18 @@ class UseCase(Enum):


  @dataclass(frozen=True)
- class TokenizerFileSpec:
-     repo: str | None
+ class FileSpec:
      filename: str
+     repo: str | None = None
+
+
+ @dataclass(frozen=True)
+ class ConfigMap:
+     model_config: FileSpec = field(default=FileSpec("config.json"))
+     tokenizer: FileSpec = field(default=FileSpec("tokenizer.json"))
+     tokenizer_config: FileSpec = field(default=FileSpec("tokenizer_config.json"))
+     generation_config: FileSpec | None = field(default=FileSpec("generation_config.json"))
+     chat_template: FileSpec | None = None


  @dataclass(frozen=True)
@@ -62,19 +73,16 @@ class ModelSpec:
      quantization: QuantizationMode | None
      repo: str
      config_type: type[ForeignConfig]
-     config_file_name: str
-     weights_file_names: tuple[str, ...]
-     weights_type: WeightsType
-     tokenizer_files: tuple[TokenizerFileSpec, ...] = tuple()
+     output_parser_regex: str | None = None
+     system_role_name: str = "system"
+     user_role_name: str = "user"
+     assistant_role_name: str = "assistant"
+     tool_role_name: str = "tool"
+     weights_type: WeightsType = WeightsType.SAFETENSORS
+     configs: ConfigMap = field(default=ConfigMap())
      use_cases: tuple[UseCase, ...] = tuple()


- def huggingface_weight_files(num_shards: int) -> tuple[str, ...]:
-     if num_shards == 1:
-         return ("model.safetensors",)
-     return tuple(f"model-{i:05d}-of-{num_shards:05d}.safetensors" for i in range(1, num_shards + 1))
-
-
  def awq_model_spec(
      model_spec: ModelSpec,
      repo: str,
@@ -88,10 +96,8 @@
          quantization=quantization,
          repo=repo,
          config_type=model_spec.config_type,
-         config_file_name=model_spec.config_file_name,
-         weights_file_names=huggingface_weight_files(1),
+         configs=model_spec.configs,
          weights_type=model_spec.weights_type,
-         tokenizer_files=model_spec.tokenizer_files,
          use_cases=model_spec.use_cases,
      )

@@ -115,11 +121,3 @@ def build_quantized_models(model_specs: list[ModelSpec]) -> list[ModelSpec]:
          quantized_model_spec = awq_model_spec(model_spec, quantized_repo)
          quantized_model_specs.append(quantized_model_spec)
      return quantized_model_specs
-
-
- HUGGINGFACE_TOKENIZER_FILES = (
-     TokenizerFileSpec(repo=None, filename="tokenizer.json"),
-     TokenizerFileSpec(repo=None, filename="tokenizer_config.json"),
- )
-
- HUGGINFACE_GENERATION_CONFIG_FILE = TokenizerFileSpec(repo=None, filename="generation_config.json")
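
WeightsType.load now returns a lazy Mapping view instead of an eagerly converted dict, which is why the loader signatures earlier in this diff switch from dict[str, Array] to Mapping[str, Array]. The real MapDictValues lives in lalamo/utils.py (+86 lines, not shown in this excerpt); a minimal sketch consistent with how it is called here, offered only as an assumption about its behaviour:

from collections.abc import Iterator, Mapping
from typing import Callable, TypeVar

K = TypeVar("K")
V = TypeVar("V")
R = TypeVar("R")


class MapDictValues(Mapping[K, R]):
    """Read-only Mapping that applies fn to values lazily, on lookup."""

    def __init__(self, fn: Callable[[V], R], data: Mapping[K, V]) -> None:
        self._fn = fn
        self._data = data

    def __getitem__(self, key: K) -> R:
        # Values (e.g. weight tensors) are converted only when accessed,
        # so a checkpoint is not duplicated in memory up front.
        return self._fn(self._data[key])

    def __iter__(self) -> Iterator[K]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)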

lalamo/model_import/model_specs/deepseek.py
@@ -1,11 +1,7 @@
- from lalamo.model_import.configs import HFQwen2Config
+ from lalamo.model_import.decoder_configs import HFQwen2Config

  from .common import (
-     HUGGINFACE_GENERATION_CONFIG_FILE,
-     HUGGINGFACE_TOKENIZER_FILES,
      ModelSpec,
-     WeightsType,
-     huggingface_weight_files,
  )

  __all__ = ["DEEPSEEK_MODELS"]
@@ -19,10 +15,5 @@ DEEPSEEK_MODELS = [
          quantization=None,
          repo="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
          config_type=HFQwen2Config,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(1),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-         use_cases=tuple(),
      ),
  ]

lalamo/model_import/model_specs/gemma.py
@@ -1,12 +1,6 @@
- from lalamo.model_import.configs import HFGemma2Config, HFGemma3Config, HFGemma3TextConfig
+ from lalamo.model_import.decoder_configs import HFGemma2Config, HFGemma3Config, HFGemma3TextConfig

- from .common import (
-     HUGGINFACE_GENERATION_CONFIG_FILE,
-     HUGGINGFACE_TOKENIZER_FILES,
-     ModelSpec,
-     WeightsType,
-     huggingface_weight_files,
- )
+ from .common import ModelSpec, WeightsType

  __all__ = ["GEMMA_MODELS"]

@@ -19,11 +13,6 @@ GEMMA2 = [
          quantization=None,
          repo="google/gemma-2-2b-it",
          config_type=HFGemma2Config,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(2),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-         use_cases=tuple(),
      ),
  ]

@@ -36,11 +25,7 @@ GEMMA3 = [
          quantization=None,
          repo="google/gemma-3-1b-it",
          config_type=HFGemma3TextConfig,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(1),
          weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-         use_cases=tuple(),
      ),
      ModelSpec(
          vendor="Google",
@@ -50,11 +35,7 @@ GEMMA3 = [
          quantization=None,
          repo="google/gemma-3-4b-it",
          config_type=HFGemma3Config,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(2),
          weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-         use_cases=tuple(),
      ),
      ModelSpec(
          vendor="Google",
@@ -64,11 +45,7 @@ GEMMA3 = [
          quantization=None,
          repo="google/gemma-3-27b-it",
          config_type=HFGemma3Config,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(12),
          weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-         use_cases=tuple(),
      ),
  ]
 

lalamo/model_import/model_specs/huggingface.py
@@ -1,12 +1,6 @@
- from lalamo.model_import.configs import HFLlamaConfig
+ from lalamo.model_import.decoder_configs import HFLlamaConfig

- from .common import (
-     HUGGINFACE_GENERATION_CONFIG_FILE,
-     HUGGINGFACE_TOKENIZER_FILES,
-     ModelSpec,
-     WeightsType,
-     huggingface_weight_files,
- )
+ from .common import ModelSpec

  __all__ = ["HUGGINGFACE_MODELS"]

@@ -19,10 +13,6 @@ HUGGINGFACE_MODELS = [
          quantization=None,
          repo="HuggingFaceTB/SmolLM2-1.7B-Instruct",
          config_type=HFLlamaConfig,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(1),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
          use_cases=tuple(),
      ),
  ]

lalamo/model_import/model_specs/llama.py
@@ -1,15 +1,6 @@
- from dataclasses import replace
+ from lalamo.model_import.decoder_configs import HFLlamaConfig

- from lalamo.model_import.configs import HFLlamaConfig
-
- from .common import (
-     HUGGINFACE_GENERATION_CONFIG_FILE,
-     HUGGINGFACE_TOKENIZER_FILES,
-     ModelSpec,
-     TokenizerFileSpec,
-     WeightsType,
-     huggingface_weight_files,
- )
+ from .common import ModelSpec

  __all__ = ["LLAMA_MODELS"]

@@ -22,23 +13,12 @@ LLAMA31 = [
          quantization=None,
          repo="meta-llama/Llama-3.1-8B-Instruct",
          config_type=HFLlamaConfig,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(4),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
          use_cases=tuple(),
      ),
  ]


- def _tokenizer_files_from_another_repo(repo: str) -> tuple[TokenizerFileSpec, ...]:
-     return tuple(
-         replace(spec, repo=repo) for spec in (*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE)
-     )
-
-
  LLAMA32 = [
-     # LLAMA
      ModelSpec(
          vendor="Meta",
          family="Llama-3.2",
@@ -47,26 +27,8 @@ LLAMA32 = [
          quantization=None,
          repo="meta-llama/Llama-3.2-1B-Instruct",
          config_type=HFLlamaConfig,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(1),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
          use_cases=tuple(),
      ),
-     # ModelSpec(
-     #     vendor="Meta",
-     #     family="Llama-3.2",
-     #     name="Llama-3.2-1B-Instruct-QLoRA",
-     #     size="1B",
-     #     quantization=QuantizationMode.UINT4,
-     #     repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
-     #     config_type=ETLlamaConfig,
-     #     config_file_name="params.json",
-     #     weights_file_names=("consolidated.00.pth",),
-     #     weights_type=WeightsType.TORCH,
-     #     tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-1B-Instruct"),
-     #     use_cases=tuple(),
-     # ),
      ModelSpec(
          vendor="Meta",
          family="Llama-3.2",
@@ -75,26 +37,8 @@ LLAMA32 = [
          quantization=None,
          repo="meta-llama/Llama-3.2-3B-Instruct",
          config_type=HFLlamaConfig,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(2),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
          use_cases=tuple(),
      ),
-     # ModelSpec(
-     #     vendor="Meta",
-     #     family="Llama-3.2",
-     #     name="Llama-3.2-3B-Instruct-QLoRA",
-     #     size="3B",
-     #     quantization=QuantizationMode.UINT4,
-     #     repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
-     #     config_type=ETLlamaConfig,
-     #     config_file_name="params.json",
-     #     weights_file_names=("consolidated.00.pth",),
-     #     tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-3B-Instruct"),
-     #     weights_type=WeightsType.TORCH,
-     #     use_cases=tuple(),
-     # ),
  ]

  LLAMA_MODELS = LLAMA31 + LLAMA32

lalamo/model_import/model_specs/mistral.py
@@ -1,15 +1,11 @@
- from dataclasses import replace
-
- from lalamo.model_import.configs import HFMistralConfig
+ from lalamo.model_import.decoder_configs import HFMistralConfig

  from .common import (
-     HUGGINFACE_GENERATION_CONFIG_FILE,
-     HUGGINGFACE_TOKENIZER_FILES,
+     ConfigMap,
+     FileSpec,
      ModelSpec,
-     TokenizerFileSpec,
      UseCase,
      WeightsType,
-     huggingface_weight_files,
  )

  __all__ = ["MISTRAL_MODELS"]
@@ -23,20 +19,13 @@ CODESTRAL = [
          quantization=None,
          repo="mistral-community/Codestral-22B-v0.1",
          config_type=HFMistralConfig,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(9),
          weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
          use_cases=(UseCase.CODE,),
      ),
  ]


- def _tokenizer_files_from_another_repo(repo: str) -> tuple[TokenizerFileSpec, ...]:
-     return tuple(
-         replace(spec, repo=repo) for spec in (*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE)
-     )
-
+ DEVSTRAL_TOKENIZER_REPO = "mistralai/Mistral-Small-3.1-24B-Base-2503"

  DEVSTRAL = [
      ModelSpec(
@@ -47,10 +36,11 @@ DEVSTRAL = [
          quantization=None,
          repo="mistralai/Devstral-Small-2505",
          config_type=HFMistralConfig,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(10),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=_tokenizer_files_from_another_repo("mistralai/Mistral-Small-3.1-24B-Base-2503"),
+         configs=ConfigMap(
+             tokenizer=FileSpec(repo=DEVSTRAL_TOKENIZER_REPO, filename="tokenizer.json"),
+             tokenizer_config=FileSpec(repo=DEVSTRAL_TOKENIZER_REPO, filename="tokenizer_config.json"),
+             generation_config=FileSpec(repo=DEVSTRAL_TOKENIZER_REPO, filename="generation_config.json"),
+         ),
          use_cases=(UseCase.CODE,),
      ),
  ]

lalamo/model_import/model_specs/pleias.py
@@ -1,15 +1,10 @@
- from lalamo.model_import.configs import HFLlamaConfig
+ from lalamo.model_import.decoder_configs import HFLlamaConfig

- from .common import (
-     HUGGINFACE_GENERATION_CONFIG_FILE,
-     HUGGINGFACE_TOKENIZER_FILES,
-     ModelSpec,
-     WeightsType,
-     huggingface_weight_files,
- )
+ from .common import ModelSpec

  __all__ = ["PLEIAS_MODELS"]

+
  PLEIAS_MODELS = [
      ModelSpec(
          vendor="PleIAs",
@@ -19,10 +14,5 @@ PLEIAS_MODELS = [
          quantization=None,
          repo="PleIAs/Pleias-RAG-1B",
          config_type=HFLlamaConfig,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(1),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
-         use_cases=tuple(),
      ),
  ]

lalamo/model_import/model_specs/polaris.py
@@ -1,6 +1,6 @@
- from lalamo.model_import.configs import HFQwen3Config
+ from lalamo.model_import.decoder_configs import HFQwen3Config

- from .common import HUGGINGFACE_TOKENIZER_FILES, ModelSpec, TokenizerFileSpec, WeightsType, huggingface_weight_files
+ from .common import ConfigMap, FileSpec, ModelSpec

  __all__ = ["POLARIS_MODELS"]

@@ -13,10 +13,8 @@ POLARIS_MODELS = [
          quantization=None,
          repo="POLARIS-Project/Polaris-4B-Preview",
          config_type=HFQwen3Config,
-         config_file_name="config.json",
-         weights_file_names=huggingface_weight_files(2),
-         weights_type=WeightsType.SAFETENSORS,
-         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, TokenizerFileSpec(repo=None, filename="chat_template.jinja")),
-         use_cases=tuple(),
+         configs=ConfigMap(
+             chat_template=FileSpec("chat_template.jinja"),
+         ),
      ),
  ]
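
Across the model spec files above, the same pattern repeats: per-model config and weight file lists are dropped in favor of the defaults baked into ModelSpec and ConfigMap. A hypothetical new-style entry (vendor, repo, and names invented purely for illustration; imports follow the file layout shown in this diff) would only spell out its deviations from those defaults:

from lalamo.model_import.decoder_configs import HFQwen3Config
from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, ModelSpec

EXAMPLE_MODELS = [
    ModelSpec(
        vendor="ExampleVendor",
        family="Example",
        name="Example-4B-Instruct",
        size="4B",
        quantization=None,
        repo="example-org/Example-4B-Instruct",
        config_type=HFQwen3Config,
        # weights_type defaults to WeightsType.SAFETENSORS, and configs defaults to the
        # standard config.json / tokenizer.json / tokenizer_config.json / generation_config.json set;
        # only a standalone chat template needs to be declared explicitly.
        configs=ConfigMap(chat_template=FileSpec("chat_template.jinja")),
    ),
]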