lalamo 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +24 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +100 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/METADATA +1 -1
  44. lalamo-0.2.3.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/top_level.txt +0 -0
lalamo/model_import/loaders/huggingface.py
@@ -0,0 +1,304 @@
+ import jax.numpy as jnp
+ from einops import rearrange
+ from jaxtyping import Array
+
+ from lalamo.common import ParameterPath
+ from lalamo.modules import (
+     MLP,
+     Attention,
+     Decoder,
+     DecoderLayer,
+     FullPrecisionLinear,
+     GroupQuantizedLinear,
+     LinearBase,
+     RMSNorm,
+     TiedEmbedding,
+     UntiedEmbedding,
+ )
+ from lalamo.quantization import QuantizationMode
+
+ from .common import load_parameters
+
+ __all__ = ["load_huggingface"]
+
+
+ AWQ_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)
+
+
+ def _reverse_uint4_awq_order(array: Array) -> Array:
+     """Reverses the AWQ packing order to get the logical order of channels for INT4."""
+     pack_factor = 32 // 4
+     *_, last_dim = array.shape
+     if last_dim % pack_factor != 0:
+         return array
+
+     array_reshaped = rearrange(array, "... (group pack_factor) -> ... group pack_factor", pack_factor=pack_factor)
+     array_reordered = array_reshaped[..., AWQ_REVERSE_ORDER]
+     return rearrange(array_reordered, "... group pack_factor -> ... (group pack_factor)")
+
+
+ def unpack_int32(packed_weights: Array, mode: QuantizationMode) -> Array:
+     assert packed_weights.dtype == jnp.int32, (
+         f"Expected packed_weights to be of dtype jnp.int32, got {packed_weights.dtype}"
+     )
+     assert 32 % mode.bits == 0
+
+     shifts = jnp.arange(0, 32, mode.bits)
+     mask = (2**mode.bits) - 1
+     unpacked = jnp.bitwise_and(jnp.right_shift(packed_weights[:, :, None], shifts[None, None, :]), mask)
+     unpacked = rearrange(
+         unpacked,
+         "out_channels packed_groups packed_values -> out_channels (packed_groups packed_values)",
+     )
+
+     return unpacked
+
+
+ def _process_quantized_tensors(
+     qweights: Array,
+     qzeros: Array,
+     scales: Array,
+     module: GroupQuantizedLinear,
+ ) -> tuple[Array, Array, Array]:
+     """Unpacks, recenters, transposes, and casts quantized tensors to the correct dtype."""
+     mode = module.config.weight_quantization_mode
+     assert qweights.dtype == jnp.int32
+     unpacked_weights = unpack_int32(qweights, mode)
+     if mode == QuantizationMode.UINT4:
+         unpacked_weights = _reverse_uint4_awq_order(unpacked_weights)
+
+     assert qzeros.dtype == jnp.int32
+     unpacked_zero_points = unpack_int32(qzeros, mode)
+     if mode == QuantizationMode.UINT4:
+         unpacked_zero_points = _reverse_uint4_awq_order(unpacked_zero_points)
+
+     weights = unpacked_weights.astype(module.config.activation_precision)
+     zero_points = unpacked_zero_points.astype(module.config.activation_precision)
+     processed_scales = scales.astype(module.config.activation_precision)
+
+     return weights.transpose(), zero_points.transpose(), processed_scales.transpose()
+
+
+ def _fuse_full_precision_weights(
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+     sublayers_to_fuse: list[str] | None,
+ ) -> Array:
+     if sublayers_to_fuse is None:
+         return weights_dict[path / "weight"]
+
+     weights = [weights_dict[path / layer_name / "weight"] for layer_name in sublayers_to_fuse]
+     return jnp.concatenate(weights, axis=0)
+
+
+ def _fuse_quantized_weights(
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+     sublayers_to_fuse: list[str] | None,
+ ) -> tuple[Array, Array, Array]:
+     # Note that AWQ quantized weights are stored transposed relative to full-precision weights
+
+     if sublayers_to_fuse is None:
+         qweights = weights_dict[path / "qweight"]
+         qzeros = weights_dict[path / "qzeros"]
+         scales = weights_dict[path / "scales"]
+         return qweights, qzeros, scales
+
+     qweights = [weights_dict[path / layer_name / "qweight"] for layer_name in sublayers_to_fuse]
+     qzeros = [weights_dict[path / layer_name / "qzeros"] for layer_name in sublayers_to_fuse]
+     scales = [weights_dict[path / layer_name / "scales"] for layer_name in sublayers_to_fuse]
+
+     fused_qweights = jnp.concatenate(qweights, axis=1)
+     fused_qzeros = jnp.concatenate(qzeros, axis=1)
+     fused_scales = jnp.concatenate(scales, axis=1)
+
+     return fused_qweights, fused_qzeros, fused_scales
+
+
+ def load_linear(
+     module: LinearBase,
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+     sublayers_to_fuse: list[str] | None = None,
+ ) -> LinearBase:
+     """Loads a linear layer, optionally fusing weights from sublayers."""
+     if not module.has_biases:
+         if sublayers_to_fuse:
+             paths_to_check = [path / proj / "bias" for proj in sublayers_to_fuse]
+         else:
+             paths_to_check = [path / "bias"]
+         for p in paths_to_check:
+             if p in weights_dict:
+                 raise ValueError(f"Bias tensor found at {p} but module does not support it.")
+         bias = None
+     elif sublayers_to_fuse is None:
+         bias = weights_dict[path / "bias"]
+     else:
+         bias = jnp.concatenate(
+             [weights_dict[path / proj_name / "bias"] for proj_name in sublayers_to_fuse],
+             axis=0,
+         )
+
+     if isinstance(module, FullPrecisionLinear):
+         weights = _fuse_full_precision_weights(weights_dict, path, sublayers_to_fuse)
+         return load_parameters(lambda m: (m.weights, m.biases), module, (weights, bias))
+
+     if isinstance(module, GroupQuantizedLinear):
+         qweights, qzeros, scales = _fuse_quantized_weights(weights_dict, path, sublayers_to_fuse)
+
+         weights, zero_points, scales = _process_quantized_tensors(
+             qweights,
+             qzeros,
+             scales,
+             module,
+         )
+
+         return load_parameters(
+             lambda m: (m.weights, m.scales, m.zero_points, m.biases),
+             module,
+             (weights, scales, zero_points, bias),
+         )
+
+     raise TypeError(f"Unsupported module type for loading: {type(module)}")
+
+
+ def load_mlp(module: MLP, weights_dict: dict[str, Array], path: ParameterPath) -> MLP:
+     up_projection = load_linear(module.up_projection, weights_dict, path, sublayers_to_fuse=["up_proj", "gate_proj"])
+     down_projection = load_linear(module.down_projection, weights_dict, path / "down_proj")
+     return load_parameters(lambda m: (m.up_projection, m.down_projection), module, (up_projection, down_projection))
+
+
+ def load_rmsnorm(
+     module: RMSNorm,
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+ ) -> RMSNorm:
+     scales = weights_dict[path / "weight"]
+     return load_parameters(lambda m: (m.scales,), module, (scales,))
+
+
+ def load_attention(
+     module: Attention,
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+ ) -> Attention:
+     qkv_projection = load_linear(
+         module.qkv_projection,
+         weights_dict,
+         path,
+         sublayers_to_fuse=["q_proj", "k_proj", "v_proj"],
+     )
+     out_projection = load_linear(module.out_projection, weights_dict, path / "o_proj")
+
+     if module.query_norm is not None:
+         query_norm = load_rmsnorm(module.query_norm, weights_dict, path / "q_norm")
+     else:
+         query_norm = None
+
+     if module.key_norm is not None:
+         key_norm = load_rmsnorm(module.key_norm, weights_dict, path / "k_norm")
+     else:
+         key_norm = None
+
+     return load_parameters(
+         lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm),
+         module,
+         (qkv_projection, out_projection, query_norm, key_norm),
+     )
+
+
+ def load_decoder_layer(
+     module: DecoderLayer,
+     weights_dict: dict[str, Array],
+     path: ParameterPath,
+ ) -> DecoderLayer:
+     pre_attention_norm = load_rmsnorm(
+         module.pre_attention_norm,
+         weights_dict,
+         path / "input_layernorm",
+     )
+     attention = load_attention(module.attention, weights_dict, path / "self_attn")
+     if module.post_attention_norm is not None:
+         post_attention_norm = load_rmsnorm(
+             module.post_attention_norm,
+             weights_dict,
+             path / "post_attention_layernorm",
+         )
+
+         pre_mlp_norm = load_rmsnorm(
+             module.pre_mlp_norm,
+             weights_dict,
+             path / "pre_feedforward_layernorm",
+         )
+     else:
+         post_attention_norm = None
+
+         pre_mlp_norm = load_rmsnorm(
+             module.pre_mlp_norm,
+             weights_dict,
+             path / "post_attention_layernorm",
+         )
+
+     mlp = load_mlp(module.mlp, weights_dict, path / "mlp")
+     if module.post_mlp_norm is not None:
+         post_mlp_norm = load_rmsnorm(
+             module.post_mlp_norm,
+             weights_dict,
+             path / "post_feedforward_layernorm",
+         )
+     else:
+         post_mlp_norm = None
+     return load_parameters(
+         lambda m: (m.pre_attention_norm, m.attention, m.post_attention_norm, m.pre_mlp_norm, m.mlp, m.post_mlp_norm),
+         module,
+         (pre_attention_norm, attention, post_attention_norm, pre_mlp_norm, mlp, post_mlp_norm),
+     )
+
+
+ def load_tied_embedding(
+     module: TiedEmbedding,
+     weights_dict: dict[str, Array],
+     decoder_path: ParameterPath,
+ ) -> TiedEmbedding:
+     weights = weights_dict[decoder_path / "embed_tokens" / "weight"]
+     return load_parameters(lambda m: (m.weights,), module, (weights,))
+
+
+ def load_untied_embedding(
+     module: UntiedEmbedding,
+     weights_dict: dict[str, Array],
+     decoder_path: ParameterPath,
+     lm_head_path: ParameterPath,
+ ) -> UntiedEmbedding:
+     input_weights = weights_dict[decoder_path / "embed_tokens" / "weight"]
+     output_weights = weights_dict[lm_head_path / "weight"]
+     return load_parameters(lambda m: (m.input_weights, m.output_weights), module, (input_weights, output_weights))
+
+
+ def load_huggingface(
+     module: Decoder,
+     weights_dict: dict[str, Array],
+ ) -> Decoder:
+     if any(key.startswith("language_model.") for key in weights_dict):
+         base_path = ParameterPath("language_model")
+     else:
+         base_path = ParameterPath()
+
+     decoder_path = base_path / "model"
+     lm_head_path = base_path / "lm_head"
+
+     if isinstance(module.embedding, TiedEmbedding):
+         embedding = load_tied_embedding(module.embedding, weights_dict, decoder_path)
+     elif isinstance(module.embedding, UntiedEmbedding):
+         embedding = load_untied_embedding(module.embedding, weights_dict, decoder_path, lm_head_path)
+     else:
+         raise TypeError(f"Unsupported embedding type: {type(module.embedding)}")
+     decoder_layers = tuple(
+         load_decoder_layer(layer, weights_dict, decoder_path / "layers" / i) for i, layer in enumerate(module.layers)
+     )
+     output_norm = load_rmsnorm(module.output_norm, weights_dict, decoder_path / "norm")
+     return load_parameters(
+         lambda m: (m.embedding, m.layers, m.output_norm),
+         module,
+         (embedding, decoder_layers, output_norm),
+     )
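A note on the unpacking logic above: each int32 in an AWQ checkpoint holds 32 // 4 = 8 unsigned 4-bit values, and unpack_int32 recovers them by shifting and masking, after which AWQ_REVERSE_ORDER restores the logical channel order. A minimal, self-contained JAX sketch of the same shift-and-mask scheme (illustrative values only, not part of the package):

import jax.numpy as jnp

bits = 4
shifts = jnp.arange(0, 32, bits)                                # 0, 4, 8, ..., 28
nibbles = jnp.array([9, 1, 2, 3, 4, 5, 6, 7], dtype=jnp.int32)  # made-up 4-bit values

# Pack: nibble i occupies bits [4 * i, 4 * i + 4) of a single int32.
packed = jnp.sum(nibbles << shifts).astype(jnp.int32)

# Unpack the same way unpack_int32 does: shift each nibble down, mask off the rest.
mask = (1 << bits) - 1
unpacked = jnp.bitwise_and(jnp.right_shift(packed, shifts), mask)
assert bool((unpacked == nibbles).all())

# For real AWQ tensors the unpacked nibbles are interleaved; indexing with
# AWQ_REVERSE_ORDER ([0, 4, 1, 5, 2, 6, 3, 7]) puts them back in logical order.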
lalamo/model_import/model_specs/__init__.py
@@ -0,0 +1,38 @@
+ from .common import awq_model_spec, build_quantized_models, ModelSpec, UseCase
+ from .deepseek import DEEPSEEK_MODELS
+ from .gemma import GEMMA_MODELS
+ from .huggingface import HUGGINGFACE_MODELS
+ from .llama import LLAMA_MODELS
+ from .mistral import MISTRAL_MODELS
+ from .pleias import PLEIAS_MODELS
+ from .polaris import POLARIS_MODELS
+ from .qwen import QWEN_MODELS
+ from .reka import REKA_MODELS
+
+ __all__ = [
+     "ALL_MODELS",
+     "REPO_TO_MODEL",
+     "ModelSpec",
+     "UseCase",
+ ]
+
+
+ ALL_MODEL_LISTS = [
+     LLAMA_MODELS,
+     DEEPSEEK_MODELS,
+     GEMMA_MODELS,
+     HUGGINGFACE_MODELS,
+     MISTRAL_MODELS,
+     PLEIAS_MODELS,
+     POLARIS_MODELS,
+     QWEN_MODELS,
+     REKA_MODELS,
+ ]
+
+
+ ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]
+
+
+ QUANTIZED_MODELS = build_quantized_models(ALL_MODELS)
+ ALL_MODELS = ALL_MODELS + QUANTIZED_MODELS
+ REPO_TO_MODEL = {model.repo: model for model in ALL_MODELS}
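This module wires the per-vendor spec lists into a single registry: build_quantized_models derives trymirai/…-AWQ variants for the whitelisted repos, and REPO_TO_MODEL keys every spec, base and quantized, by its repository id. A hypothetical lookup, assuming the package imports as published:

from lalamo.model_import.model_specs import REPO_TO_MODEL
from lalamo.quantization import QuantizationMode

# Base entry, declared directly in llama.py.
base = REPO_TO_MODEL["meta-llama/Llama-3.2-3B-Instruct"]
print(base.vendor, base.family, base.quantization)   # Meta Llama-3.2 None

# Derived AWQ entry, generated by build_quantized_models.
awq = REPO_TO_MODEL["trymirai/Llama-3.2-3B-Instruct-AWQ"]
assert awq.quantization is QuantizationMode.UINT4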
lalamo/model_import/model_specs/common.py
@@ -0,0 +1,118 @@
+ from dataclasses import dataclass
+ from enum import Enum
+ from pathlib import Path
+
+ import jax.numpy as jnp
+ import torch
+ from jaxtyping import Array, DTypeLike
+ from safetensors.flax import load_file as load_safetensors
+
+ from lalamo.model_import.configs import ForeignConfig
+ from lalamo.quantization import QuantizationMode
+ from lalamo.utils import torch_to_jax
+
+ __all__ = [
+     "HUGGINFACE_GENERATION_CONFIG_FILE",
+     "HUGGINGFACE_TOKENIZER_FILES",
+     "ModelSpec",
+     "TokenizerFileSpec",
+     "UseCase",
+     "huggingface_weight_files",
+     "awq_model_spec",
+     "build_quantized_models",
+ ]
+
+
+ def cast_if_float(array: Array, cast_to: DTypeLike) -> Array:
+     if array.dtype in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]:
+         return array.astype(cast_to)
+     return array
+
+
+ class WeightsType(Enum):
+     SAFETENSORS = "safetensors"
+     TORCH = "torch"
+
+     def load(self, filename: Path | str, float_dtype: DTypeLike) -> dict[str, jnp.ndarray]:
+         if self == WeightsType.SAFETENSORS:
+             return {k: cast_if_float(v, float_dtype) for k, v in load_safetensors(filename).items()}
+         torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
+         return {k: cast_if_float(torch_to_jax(v), float_dtype) for k, v in torch_weights.items()}
+
+
+ class UseCase(Enum):
+     CODE = "code"
+
+
+ @dataclass(frozen=True)
+ class TokenizerFileSpec:
+     repo: str | None
+     filename: str
+
+
+ @dataclass(frozen=True)
+ class ModelSpec:
+     vendor: str
+     family: str
+     name: str
+     size: str
+     quantization: QuantizationMode | None
+     repo: str
+     config_type: type[ForeignConfig]
+     config_file_name: str
+     weights_file_names: tuple[str, ...]
+     weights_type: WeightsType
+     tokenizer_files: tuple[TokenizerFileSpec, ...] = tuple()
+     use_cases: tuple[UseCase, ...] = tuple()
+
+
+ def huggingface_weight_files(num_shards: int) -> tuple[str, ...]:
+     if num_shards == 1:
+         return ("model.safetensors",)
+     return tuple(f"model-{i:05d}-of-{num_shards:05d}.safetensors" for i in range(1, num_shards + 1))
+
+
+ def awq_model_spec(model_spec: ModelSpec, repo: str, quantization: QuantizationMode = QuantizationMode.UINT4) -> ModelSpec:
+     return ModelSpec(
+         vendor=model_spec.vendor,
+         family=model_spec.family,
+         name="{}-AWQ".format(model_spec.name),
+         size=model_spec.size,
+         quantization=quantization,
+         repo=repo,
+         config_type=model_spec.config_type,
+         config_file_name=model_spec.config_file_name,
+         weights_file_names=huggingface_weight_files(1),
+         weights_type=model_spec.weights_type,
+         tokenizer_files=model_spec.tokenizer_files,
+         use_cases=model_spec.use_cases,
+     )
+
+
+ def build_quantized_models(model_specs: list[ModelSpec]) -> list[ModelSpec]:
+     quantization_compatible_repos: list[str] = [
+         "Qwen/Qwen2.5-3B-Instruct",
+         "Qwen/Qwen2.5-7B-Instruct",
+         "Qwen/Qwen2.5-Coder-3B-Instruct",
+         "Qwen/Qwen2.5-Coder-7B-Instruct",
+         "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+         "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+         "meta-llama/Llama-3.2-3B-Instruct",
+     ]
+
+     quantized_model_specs: list[ModelSpec] = []
+     for model_spec in model_specs:
+         if model_spec.repo not in quantization_compatible_repos:
+             continue
+         quantized_repo = "trymirai/{}-AWQ".format(model_spec.repo.split("/")[-1])
+         quantized_model_spec = awq_model_spec(model_spec, quantized_repo)
+         quantized_model_specs.append(quantized_model_spec)
+     return quantized_model_specs
+
+
+ HUGGINGFACE_TOKENIZER_FILES = (
+     TokenizerFileSpec(repo=None, filename="tokenizer.json"),
+     TokenizerFileSpec(repo=None, filename="tokenizer_config.json"),
+ )
+
+ HUGGINFACE_GENERATION_CONFIG_FILE = TokenizerFileSpec(repo=None, filename="generation_config.json")
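One detail worth calling out from common.py: huggingface_weight_files encodes the standard Hugging Face safetensors shard naming, so each spec only records its shard count. A quick illustration of the expected output (assuming the module imports as published):

from lalamo.model_import.model_specs.common import huggingface_weight_files

print(huggingface_weight_files(1))
# ('model.safetensors',)
print(huggingface_weight_files(2))
# ('model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors')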
lalamo/model_import/model_specs/deepseek.py
@@ -0,0 +1,28 @@
+ from lalamo.model_import.configs import HFQwen2Config
+
+ from .common import (
+     HUGGINFACE_GENERATION_CONFIG_FILE,
+     HUGGINGFACE_TOKENIZER_FILES,
+     ModelSpec,
+     WeightsType,
+     huggingface_weight_files,
+ )
+
+ __all__ = ["DEEPSEEK_MODELS"]
+
+ DEEPSEEK_MODELS = [
+     ModelSpec(
+         vendor="DeepSeek",
+         family="R1-Distill-Qwen",
+         name="R1-Distill-Qwen-1.5B",
+         size="1.5B",
+         quantization=None,
+         repo="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+         config_type=HFQwen2Config,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(1),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+ ]
lalamo/model_import/model_specs/gemma.py
@@ -0,0 +1,76 @@
+ from lalamo.model_import.configs import HFGemma2Config, HFGemma3Config, HFGemma3TextConfig
+
+ from .common import (
+     HUGGINFACE_GENERATION_CONFIG_FILE,
+     HUGGINGFACE_TOKENIZER_FILES,
+     ModelSpec,
+     WeightsType,
+     huggingface_weight_files,
+ )
+
+ __all__ = ["GEMMA_MODELS"]
+
+ GEMMA2 = [
+     ModelSpec(
+         vendor="Google",
+         family="Gemma-2",
+         name="Gemma-2-2B-Instruct",
+         size="2B",
+         quantization=None,
+         repo="google/gemma-2-2b-it",
+         config_type=HFGemma2Config,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(2),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+ ]
+
+ GEMMA3 = [
+     ModelSpec(
+         vendor="Google",
+         family="Gemma-3",
+         name="Gemma-3-1B-Instruct",
+         size="1B",
+         quantization=None,
+         repo="google/gemma-3-1b-it",
+         config_type=HFGemma3TextConfig,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(1),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+     ModelSpec(
+         vendor="Google",
+         family="Gemma-3",
+         name="Gemma-3-4B-Instruct",
+         size="4B",
+         quantization=None,
+         repo="google/gemma-3-4b-it",
+         config_type=HFGemma3Config,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(2),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+     ModelSpec(
+         vendor="Google",
+         family="Gemma-3",
+         name="Gemma-3-27B-Instruct",
+         size="27B",
+         quantization=None,
+         repo="google/gemma-3-27b-it",
+         config_type=HFGemma3Config,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(12),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+ ]
+
+
+ GEMMA_MODELS = GEMMA2 + GEMMA3
lalamo/model_import/model_specs/huggingface.py
@@ -0,0 +1,28 @@
+ from lalamo.model_import.configs import HFLlamaConfig
+
+ from .common import (
+     HUGGINFACE_GENERATION_CONFIG_FILE,
+     HUGGINGFACE_TOKENIZER_FILES,
+     ModelSpec,
+     WeightsType,
+     huggingface_weight_files,
+ )
+
+ __all__ = ["HUGGINGFACE_MODELS"]
+
+ HUGGINGFACE_MODELS = [
+     ModelSpec(
+         vendor="HuggingFace",
+         family="SmolLM2",
+         name="SmolLM2-1.7B-Instruct",
+         size="1.7B",
+         quantization=None,
+         repo="HuggingFaceTB/SmolLM2-1.7B-Instruct",
+         config_type=HFLlamaConfig,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(1),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+ ]
lalamo/model_import/model_specs/llama.py
@@ -0,0 +1,100 @@
+ from dataclasses import replace
+
+ from lalamo.model_import.configs import HFLlamaConfig
+
+ from .common import (
+     HUGGINFACE_GENERATION_CONFIG_FILE,
+     HUGGINGFACE_TOKENIZER_FILES,
+     ModelSpec,
+     TokenizerFileSpec,
+     WeightsType,
+     huggingface_weight_files,
+ )
+
+ __all__ = ["LLAMA_MODELS"]
+
+ LLAMA31 = [
+     ModelSpec(
+         vendor="Meta",
+         family="Llama-3.1",
+         name="Llama-3.1-8B-Instruct",
+         size="8B",
+         quantization=None,
+         repo="meta-llama/Llama-3.1-8B-Instruct",
+         config_type=HFLlamaConfig,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(4),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+ ]
+
+
+ def _tokenizer_files_from_another_repo(repo: str) -> tuple[TokenizerFileSpec, ...]:
+     return tuple(
+         replace(spec, repo=repo) for spec in (*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE)
+     )
+
+
+ LLAMA32 = [
+     # LLAMA
+     ModelSpec(
+         vendor="Meta",
+         family="Llama-3.2",
+         name="Llama-3.2-1B-Instruct",
+         size="1B",
+         quantization=None,
+         repo="meta-llama/Llama-3.2-1B-Instruct",
+         config_type=HFLlamaConfig,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(1),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+     # ModelSpec(
+     #     vendor="Meta",
+     #     family="Llama-3.2",
+     #     name="Llama-3.2-1B-Instruct-QLoRA",
+     #     size="1B",
+     #     quantization=QuantizationMode.UINT4,
+     #     repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
+     #     config_type=ETLlamaConfig,
+     #     config_file_name="params.json",
+     #     weights_file_names=("consolidated.00.pth",),
+     #     weights_type=WeightsType.TORCH,
+     #     tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-1B-Instruct"),
+     #     use_cases=tuple(),
+     # ),
+     ModelSpec(
+         vendor="Meta",
+         family="Llama-3.2",
+         name="Llama-3.2-3B-Instruct",
+         size="3B",
+         quantization=None,
+         repo="meta-llama/Llama-3.2-3B-Instruct",
+         config_type=HFLlamaConfig,
+         config_file_name="config.json",
+         weights_file_names=huggingface_weight_files(2),
+         weights_type=WeightsType.SAFETENSORS,
+         tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+         use_cases=tuple(),
+     ),
+     # ModelSpec(
+     #     vendor="Meta",
+     #     family="Llama-3.2",
+     #     name="Llama-3.2-3B-Instruct-QLoRA",
+     #     size="3B",
+     #     quantization=QuantizationMode.UINT4,
+     #     repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
+     #     config_type=ETLlamaConfig,
+     #     config_file_name="params.json",
+     #     weights_file_names=("consolidated.00.pth",),
+     #     tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-3B-Instruct"),
+     #     weights_type=WeightsType.TORCH,
+     #     use_cases=tuple(),
+     # ),
+ ]
+
+ LLAMA_MODELS = LLAMA31 + LLAMA32