lalamo 0.2.1-py3-none-any.whl → 0.2.2-py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- lalamo/__init__.py +1 -1
- lalamo/model_import/__init__.py +8 -0
- lalamo/model_import/common.py +111 -0
- lalamo/model_import/configs/__init__.py +23 -0
- lalamo/model_import/configs/common.py +62 -0
- lalamo/model_import/configs/executorch.py +166 -0
- lalamo/model_import/configs/huggingface/__init__.py +18 -0
- lalamo/model_import/configs/huggingface/common.py +72 -0
- lalamo/model_import/configs/huggingface/gemma2.py +122 -0
- lalamo/model_import/configs/huggingface/gemma3.py +187 -0
- lalamo/model_import/configs/huggingface/llama.py +155 -0
- lalamo/model_import/configs/huggingface/mistral.py +132 -0
- lalamo/model_import/configs/huggingface/qwen2.py +144 -0
- lalamo/model_import/configs/huggingface/qwen3.py +142 -0
- lalamo/model_import/loaders/__init__.py +7 -0
- lalamo/model_import/loaders/common.py +45 -0
- lalamo/model_import/loaders/executorch.py +223 -0
- lalamo/model_import/loaders/huggingface.py +304 -0
- lalamo/model_import/model_specs/__init__.py +38 -0
- lalamo/model_import/model_specs/common.py +118 -0
- lalamo/model_import/model_specs/deepseek.py +28 -0
- lalamo/model_import/model_specs/gemma.py +76 -0
- lalamo/model_import/model_specs/huggingface.py +28 -0
- lalamo/model_import/model_specs/llama.py +101 -0
- lalamo/model_import/model_specs/mistral.py +59 -0
- lalamo/model_import/model_specs/pleias.py +28 -0
- lalamo/model_import/model_specs/polaris.py +22 -0
- lalamo/model_import/model_specs/qwen.py +336 -0
- lalamo/model_import/model_specs/reka.py +28 -0
- lalamo/modules/__init__.py +85 -0
- lalamo/modules/activations.py +30 -0
- lalamo/modules/attention.py +326 -0
- lalamo/modules/common.py +133 -0
- lalamo/modules/decoder.py +244 -0
- lalamo/modules/decoder_layer.py +240 -0
- lalamo/modules/embedding.py +299 -0
- lalamo/modules/kv_cache.py +196 -0
- lalamo/modules/linear.py +603 -0
- lalamo/modules/mlp.py +79 -0
- lalamo/modules/normalization.py +77 -0
- lalamo/modules/rope.py +255 -0
- lalamo/modules/utils.py +13 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
- lalamo-0.2.2.dist-info/RECORD +53 -0
- lalamo-0.2.1.dist-info/RECORD +0 -12
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/model_import/loaders/huggingface.py
@@ -0,0 +1,304 @@
+import jax.numpy as jnp
+from einops import rearrange
+from jaxtyping import Array
+
+from lalamo.common import ParameterPath
+from lalamo.modules import (
+    MLP,
+    Attention,
+    Decoder,
+    DecoderLayer,
+    FullPrecisionLinear,
+    GroupQuantizedLinear,
+    LinearBase,
+    RMSNorm,
+    TiedEmbedding,
+    UntiedEmbedding,
+)
+from lalamo.quantization import QuantizationMode
+
+from .common import load_parameters
+
+__all__ = ["load_huggingface"]
+
+
+AWQ_REVERSE_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7], dtype=jnp.int32)
+
+
+def _reverse_uint4_awq_order(array: Array) -> Array:
+    """Reverses the AWQ packing order to get the logical order of channels for INT4."""
+    pack_factor = 32 // 4
+    *_, last_dim = array.shape
+    if last_dim % pack_factor != 0:
+        return array
+
+    array_reshaped = rearrange(array, "... (group pack_factor) -> ... group pack_factor", pack_factor=pack_factor)
+    array_reordered = array_reshaped[..., AWQ_REVERSE_ORDER]
+    return rearrange(array_reordered, "... group pack_factor -> ... (group pack_factor)")
+
+
+def unpack_int32(packed_weights: Array, mode: QuantizationMode) -> Array:
+    assert packed_weights.dtype == jnp.int32, (
+        f"Expected packed_weights to be of dtype jnp.int32, got {packed_weights.dtype}"
+    )
+    assert 32 % mode.bits == 0
+
+    shifts = jnp.arange(0, 32, mode.bits)
+    mask = (2**mode.bits) - 1
+    unpacked = jnp.bitwise_and(jnp.right_shift(packed_weights[:, :, None], shifts[None, None, :]), mask)
+    unpacked = rearrange(
+        unpacked,
+        "out_channels packed_groups packed_values -> out_channels (packed_groups packed_values)",
+    )
+
+    return unpacked
+
+
+def _process_quantized_tensors(
+    qweights: Array,
+    qzeros: Array,
+    scales: Array,
+    module: GroupQuantizedLinear,
+) -> tuple[Array, Array, Array]:
+    """Unpacks, recenters, transposes, and casts quantized tensors to the correct dtype."""
+    mode = module.config.weight_quantization_mode
+    assert qweights.dtype == jnp.int32
+    unpacked_weights = unpack_int32(qweights, mode)
+    if mode == QuantizationMode.UINT4:
+        unpacked_weights = _reverse_uint4_awq_order(unpacked_weights)
+
+    assert qzeros.dtype == jnp.int32
+    unpacked_zero_points = unpack_int32(qzeros, mode)
+    if mode == QuantizationMode.UINT4:
+        unpacked_zero_points = _reverse_uint4_awq_order(unpacked_zero_points)
+
+    weights = unpacked_weights.astype(module.config.activation_precision)
+    zero_points = unpacked_zero_points.astype(module.config.activation_precision)
+    processed_scales = scales.astype(module.config.activation_precision)
+
+    return weights.transpose(), zero_points.transpose(), processed_scales.transpose()
+
+
+def _fuse_full_precision_weights(
+    weights_dict: dict[str, Array],
+    path: ParameterPath,
+    sublayers_to_fuse: list[str] | None,
+) -> Array:
+    if sublayers_to_fuse is None:
+        return weights_dict[path / "weight"]
+
+    weights = [weights_dict[path / layer_name / "weight"] for layer_name in sublayers_to_fuse]
+    return jnp.concatenate(weights, axis=0)
+
+
+def _fuse_quantized_weights(
+    weights_dict: dict[str, Array],
+    path: ParameterPath,
+    sublayers_to_fuse: list[str] | None,
+) -> tuple[Array, Array, Array]:
+    # Note that AWQ quantized weights are stored transposed relative to full-precision weights
+
+    if sublayers_to_fuse is None:
+        qweights = weights_dict[path / "qweight"]
+        qzeros = weights_dict[path / "qzeros"]
+        scales = weights_dict[path / "scales"]
+        return qweights, qzeros, scales
+
+    qweights = [weights_dict[path / layer_name / "qweight"] for layer_name in sublayers_to_fuse]
+    qzeros = [weights_dict[path / layer_name / "qzeros"] for layer_name in sublayers_to_fuse]
+    scales = [weights_dict[path / layer_name / "scales"] for layer_name in sublayers_to_fuse]
+
+    fused_qweights = jnp.concatenate(qweights, axis=1)
+    fused_qzeros = jnp.concatenate(qzeros, axis=1)
+    fused_scales = jnp.concatenate(scales, axis=1)
+
+    return fused_qweights, fused_qzeros, fused_scales
+
+
+def load_linear(
+    module: LinearBase,
+    weights_dict: dict[str, Array],
+    path: ParameterPath,
+    sublayers_to_fuse: list[str] | None = None,
+) -> LinearBase:
+    """Loads a linear layer, optionally fusing weights from sublayers."""
+    if not module.has_biases:
+        if sublayers_to_fuse:
+            paths_to_check = [path / proj / "bias" for proj in sublayers_to_fuse]
+        else:
+            paths_to_check = [path / "bias"]
+        for p in paths_to_check:
+            if p in weights_dict:
+                raise ValueError(f"Bias tensor found at {p} but module does not support it.")
+        bias = None
+    elif sublayers_to_fuse is None:
+        bias = weights_dict[path / "bias"]
+    else:
+        bias = jnp.concatenate(
+            [weights_dict[path / proj_name / "bias"] for proj_name in sublayers_to_fuse],
+            axis=0,
+        )
+
+    if isinstance(module, FullPrecisionLinear):
+        weights = _fuse_full_precision_weights(weights_dict, path, sublayers_to_fuse)
+        return load_parameters(lambda m: (m.weights, m.biases), module, (weights, bias))
+
+    if isinstance(module, GroupQuantizedLinear):
+        qweights, qzeros, scales = _fuse_quantized_weights(weights_dict, path, sublayers_to_fuse)
+
+        weights, zero_points, scales = _process_quantized_tensors(
+            qweights,
+            qzeros,
+            scales,
+            module,
+        )
+
+        return load_parameters(
+            lambda m: (m.weights, m.scales, m.zero_points, m.biases),
+            module,
+            (weights, scales, zero_points, bias),
+        )
+
+    raise TypeError(f"Unsupported module type for loading: {type(module)}")
+
+
+def load_mlp(module: MLP, weights_dict: dict[str, Array], path: ParameterPath) -> MLP:
+    up_projection = load_linear(module.up_projection, weights_dict, path, sublayers_to_fuse=["up_proj", "gate_proj"])
+    down_projection = load_linear(module.down_projection, weights_dict, path / "down_proj")
+    return load_parameters(lambda m: (m.up_projection, m.down_projection), module, (up_projection, down_projection))
+
+
+def load_rmsnorm(
+    module: RMSNorm,
+    weights_dict: dict[str, Array],
+    path: ParameterPath,
+) -> RMSNorm:
+    scales = weights_dict[path / "weight"]
+    return load_parameters(lambda m: (m.scales,), module, (scales,))
+
+
+def load_attention(
+    module: Attention,
+    weights_dict: dict[str, Array],
+    path: ParameterPath,
+) -> Attention:
+    qkv_projection = load_linear(
+        module.qkv_projection,
+        weights_dict,
+        path,
+        sublayers_to_fuse=["q_proj", "k_proj", "v_proj"],
+    )
+    out_projection = load_linear(module.out_projection, weights_dict, path / "o_proj")
+
+    if module.query_norm is not None:
+        query_norm = load_rmsnorm(module.query_norm, weights_dict, path / "q_norm")
+    else:
+        query_norm = None
+
+    if module.key_norm is not None:
+        key_norm = load_rmsnorm(module.key_norm, weights_dict, path / "k_norm")
+    else:
+        key_norm = None
+
+    return load_parameters(
+        lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm),
+        module,
+        (qkv_projection, out_projection, query_norm, key_norm),
+    )
+
+
+def load_decoder_layer(
+    module: DecoderLayer,
+    weights_dict: dict[str, Array],
+    path: ParameterPath,
+) -> DecoderLayer:
+    pre_attention_norm = load_rmsnorm(
+        module.pre_attention_norm,
+        weights_dict,
+        path / "input_layernorm",
+    )
+    attention = load_attention(module.attention, weights_dict, path / "self_attn")
+    if module.post_attention_norm is not None:
+        post_attention_norm = load_rmsnorm(
+            module.post_attention_norm,
+            weights_dict,
+            path / "post_attention_layernorm",
+        )
+
+        pre_mlp_norm = load_rmsnorm(
+            module.pre_mlp_norm,
+            weights_dict,
+            path / "pre_feedforward_layernorm",
+        )
+    else:
+        post_attention_norm = None
+
+        pre_mlp_norm = load_rmsnorm(
+            module.pre_mlp_norm,
+            weights_dict,
+            path / "post_attention_layernorm",
+        )
+
+    mlp = load_mlp(module.mlp, weights_dict, path / "mlp")
+    if module.post_mlp_norm is not None:
+        post_mlp_norm = load_rmsnorm(
+            module.post_mlp_norm,
+            weights_dict,
+            path / "post_feedforward_layernorm",
+        )
+    else:
+        post_mlp_norm = None
+    return load_parameters(
+        lambda m: (m.pre_attention_norm, m.attention, m.post_attention_norm, m.pre_mlp_norm, m.mlp, m.post_mlp_norm),
+        module,
+        (pre_attention_norm, attention, post_attention_norm, pre_mlp_norm, mlp, post_mlp_norm),
+    )
+
+
+def load_tied_embedding(
+    module: TiedEmbedding,
+    weights_dict: dict[str, Array],
+    decoder_path: ParameterPath,
+) -> TiedEmbedding:
+    weights = weights_dict[decoder_path / "embed_tokens" / "weight"]
+    return load_parameters(lambda m: (m.weights,), module, (weights,))
+
+
+def load_untied_embedding(
+    module: UntiedEmbedding,
+    weights_dict: dict[str, Array],
+    decoder_path: ParameterPath,
+    lm_head_path: ParameterPath,
+) -> UntiedEmbedding:
+    input_weights = weights_dict[decoder_path / "embed_tokens" / "weight"]
+    output_weights = weights_dict[lm_head_path / "weight"]
+    return load_parameters(lambda m: (m.input_weights, m.output_weights), module, (input_weights, output_weights))
+
+
+def load_huggingface(
+    module: Decoder,
+    weights_dict: dict[str, Array],
+) -> Decoder:
+    if any(key.startswith("language_model.") for key in weights_dict):
+        base_path = ParameterPath("language_model")
+    else:
+        base_path = ParameterPath()
+
+    decoder_path = base_path / "model"
+    lm_head_path = base_path / "lm_head"
+
+    if isinstance(module.embedding, TiedEmbedding):
+        embedding = load_tied_embedding(module.embedding, weights_dict, decoder_path)
+    elif isinstance(module.embedding, UntiedEmbedding):
+        embedding = load_untied_embedding(module.embedding, weights_dict, decoder_path, lm_head_path)
+    else:
+        raise TypeError(f"Unsupported embedding type: {type(module.embedding)}")
+    decoder_layers = tuple(
+        load_decoder_layer(layer, weights_dict, decoder_path / "layers" / i) for i, layer in enumerate(module.layers)
+    )
+    output_norm = load_rmsnorm(module.output_norm, weights_dict, decoder_path / "norm")
+    return load_parameters(
+        lambda m: (m.embedding, m.layers, m.output_norm),
+        module,
+        (embedding, decoder_layers, output_norm),
+    )
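For context on the loader above: AWQ checkpoints pack eight 4-bit values into each int32 in the interleaved physical order [0, 2, 4, 6, 1, 3, 5, 7], which is why the code shifts nibbles out and then applies AWQ_REVERSE_ORDER. The following standalone sketch (not part of the package; it assumes 4-bit mode and only needs jax.numpy) mirrors that round trip on a single packed group:

import jax.numpy as jnp

# Logical 4-bit values 0..7 for one packed group (kept below 8 so the int32 never overflows).
logical = jnp.arange(8, dtype=jnp.int32)

# AWQ writes nibbles in the interleaved physical order [0, 2, 4, 6, 1, 3, 5, 7].
awq_pack_order = jnp.array([0, 2, 4, 6, 1, 3, 5, 7], dtype=jnp.int32)
packed = jnp.sum(logical[awq_pack_order] << (4 * jnp.arange(8)), dtype=jnp.int32)

# Same bit-twiddling as unpack_int32 with bits=4: shift each nibble down and mask it off.
shifts = jnp.arange(0, 32, 4)
nibbles = jnp.bitwise_and(jnp.right_shift(packed, shifts), (2**4) - 1)

# Undo the interleaving, which is what _reverse_uint4_awq_order does per group of 8 channels.
restored = nibbles[jnp.array([0, 4, 1, 5, 2, 6, 3, 7])]
assert (restored == logical).all()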
lalamo/model_import/model_specs/__init__.py
@@ -0,0 +1,38 @@
+from .common import awq_model_spec, build_quantized_models, ModelSpec, UseCase
+from .deepseek import DEEPSEEK_MODELS
+from .gemma import GEMMA_MODELS
+from .huggingface import HUGGINGFACE_MODELS
+from .llama import LLAMA_MODELS
+from .mistral import MISTRAL_MODELS
+from .pleias import PLEIAS_MODELS
+from .polaris import POLARIS_MODELS
+from .qwen import QWEN_MODELS
+from .reka import REKA_MODELS
+
+__all__ = [
+    "ALL_MODELS",
+    "REPO_TO_MODEL",
+    "ModelSpec",
+    "UseCase",
+]
+
+
+ALL_MODEL_LISTS = [
+    LLAMA_MODELS,
+    DEEPSEEK_MODELS,
+    GEMMA_MODELS,
+    HUGGINGFACE_MODELS,
+    MISTRAL_MODELS,
+    PLEIAS_MODELS,
+    POLARIS_MODELS,
+    QWEN_MODELS,
+    REKA_MODELS,
+]
+
+
+ALL_MODELS = [model for model_list in ALL_MODEL_LISTS for model in model_list]
+
+
+QUANTIZED_MODELS = build_quantized_models(ALL_MODELS)
+ALL_MODELS = ALL_MODELS + QUANTIZED_MODELS
+REPO_TO_MODEL = {model.repo: model for model in ALL_MODELS}
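As a quick orientation (a usage sketch, assuming the package is installed and the module is importable under the path shown above), the registry built here can be queried by repo name, with the AWQ variants generated by build_quantized_models sitting alongside the originals:

from lalamo.model_import.model_specs import REPO_TO_MODEL

# Base specs are keyed by their upstream repos...
base = REPO_TO_MODEL["meta-llama/Llama-3.2-3B-Instruct"]

# ...while derived AWQ specs are keyed by the trymirai mirror repos that
# build_quantized_models constructs ("trymirai/<name>-AWQ").
quantized = REPO_TO_MODEL["trymirai/Llama-3.2-3B-Instruct-AWQ"]

print(base.name)       # "Llama-3.2-3B-Instruct"
print(quantized.name)  # "Llama-3.2-3B-Instruct-AWQ"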
lalamo/model_import/model_specs/common.py
@@ -0,0 +1,118 @@
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+
+import jax.numpy as jnp
+import torch
+from jaxtyping import Array, DTypeLike
+from safetensors.flax import load_file as load_safetensors
+
+from lalamo.model_import.configs import ForeignConfig
+from lalamo.quantization import QuantizationMode
+from lalamo.utils import torch_to_jax
+
+__all__ = [
+    "HUGGINFACE_GENERATION_CONFIG_FILE",
+    "HUGGINGFACE_TOKENIZER_FILES",
+    "ModelSpec",
+    "TokenizerFileSpec",
+    "UseCase",
+    "huggingface_weight_files",
+    "awq_model_spec",
+    "build_quantized_models",
+]
+
+
+def cast_if_float(array: Array, cast_to: DTypeLike) -> Array:
+    if array.dtype in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]:
+        return array.astype(cast_to)
+    return array
+
+
+class WeightsType(Enum):
+    SAFETENSORS = "safetensors"
+    TORCH = "torch"
+
+    def load(self, filename: Path | str, float_dtype: DTypeLike) -> dict[str, jnp.ndarray]:
+        if self == WeightsType.SAFETENSORS:
+            return {k: cast_if_float(v, float_dtype) for k, v in load_safetensors(filename).items()}
+        torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
+        return {k: cast_if_float(torch_to_jax(v), float_dtype) for k, v in torch_weights.items()}
+
+
+class UseCase(Enum):
+    CODE = "code"
+
+
+@dataclass(frozen=True)
+class TokenizerFileSpec:
+    repo: str | None
+    filename: str
+
+
+@dataclass(frozen=True)
+class ModelSpec:
+    vendor: str
+    family: str
+    name: str
+    size: str
+    quantization: QuantizationMode | None
+    repo: str
+    config_type: type[ForeignConfig]
+    config_file_name: str
+    weights_file_names: tuple[str, ...]
+    weights_type: WeightsType
+    tokenizer_files: tuple[TokenizerFileSpec, ...] = tuple()
+    use_cases: tuple[UseCase, ...] = tuple()
+
+
+def huggingface_weight_files(num_shards: int) -> tuple[str, ...]:
+    if num_shards == 1:
+        return ("model.safetensors",)
+    return tuple(f"model-{i:05d}-of-{num_shards:05d}.safetensors" for i in range(1, num_shards + 1))
+
+
+def awq_model_spec(model_spec: ModelSpec, repo: str, quantization: QuantizationMode = QuantizationMode.UINT4) -> ModelSpec:
+    return ModelSpec(
+        vendor=model_spec.vendor,
+        family=model_spec.family,
+        name="{}-AWQ".format(model_spec.name),
+        size=model_spec.size,
+        quantization=quantization,
+        repo=repo,
+        config_type=model_spec.config_type,
+        config_file_name=model_spec.config_file_name,
+        weights_file_names=huggingface_weight_files(1),
+        weights_type=model_spec.weights_type,
+        tokenizer_files=model_spec.tokenizer_files,
+        use_cases=model_spec.use_cases,
+    )
+
+
+def build_quantized_models(model_specs: list[ModelSpec]):
+    quantization_compatible_repos: list[str] = [
+        "Qwen/Qwen2.5-3B-Instruct",
+        "Qwen/Qwen2.5-7B-Instruct",
+        "Qwen/Qwen2.5-Coder-3B-Instruct",
+        "Qwen/Qwen2.5-Coder-7B-Instruct",
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+        "meta-llama/Llama-3.2-3B-Instruct",
+    ]
+
+    quantized_model_specs: list[ModelSpec] = []
+    for model_spec in model_specs:
+        if model_spec.repo not in quantization_compatible_repos:
+            continue
+        quantized_repo = "trymirai/{}-AWQ".format(model_spec.repo.split("/")[-1])
+        quantized_model_spec = awq_model_spec(model_spec, quantized_repo)
+        quantized_model_specs.append(quantized_model_spec)
+    return quantized_model_specs
+
+
+HUGGINGFACE_TOKENIZER_FILES = (
+    TokenizerFileSpec(repo=None, filename="tokenizer.json"),
+    TokenizerFileSpec(repo=None, filename="tokenizer_config.json"),
+)
+
+HUGGINFACE_GENERATION_CONFIG_FILE = TokenizerFileSpec(repo=None, filename="generation_config.json")
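To make the shard-count arguments used throughout the model specs concrete, here is what huggingface_weight_files resolves to (a small usage sketch, assuming the import path matches the file location shown above):

from lalamo.model_import.model_specs.common import huggingface_weight_files

# A single-shard checkpoint uses the plain file name.
assert huggingface_weight_files(1) == ("model.safetensors",)

# Sharded checkpoints follow the standard Hugging Face "NNNNN-of-NNNNN" naming.
assert huggingface_weight_files(2) == (
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
)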
lalamo/model_import/model_specs/deepseek.py
@@ -0,0 +1,28 @@
+from lalamo.model_import.configs import HFQwen2Config
+
+from .common import (
+    HUGGINFACE_GENERATION_CONFIG_FILE,
+    HUGGINGFACE_TOKENIZER_FILES,
+    ModelSpec,
+    WeightsType,
+    huggingface_weight_files,
+)
+
+__all__ = ["DEEPSEEK_MODELS"]
+
+DEEPSEEK_MODELS = [
+    ModelSpec(
+        vendor="DeepSeek",
+        family="R1-Distill-Qwen",
+        name="R1-Distill-Qwen-1.5B",
+        size="1.5B",
+        quantization=None,
+        repo="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        config_type=HFQwen2Config,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(1),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+]
lalamo/model_import/model_specs/gemma.py
@@ -0,0 +1,76 @@
+from lalamo.model_import.configs import HFGemma2Config, HFGemma3Config, HFGemma3TextConfig
+
+from .common import (
+    HUGGINFACE_GENERATION_CONFIG_FILE,
+    HUGGINGFACE_TOKENIZER_FILES,
+    ModelSpec,
+    WeightsType,
+    huggingface_weight_files,
+)
+
+__all__ = ["GEMMA_MODELS"]
+
+GEMMA2 = [
+    ModelSpec(
+        vendor="Google",
+        family="Gemma-2",
+        name="Gemma-2-2B-Instruct",
+        size="2B",
+        quantization=None,
+        repo="google/gemma-2-2b-it",
+        config_type=HFGemma2Config,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(2),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+]
+
+GEMMA3 = [
+    ModelSpec(
+        vendor="Google",
+        family="Gemma-3",
+        name="Gemma-3-1B-Instruct",
+        size="1B",
+        quantization=None,
+        repo="google/gemma-3-1b-it",
+        config_type=HFGemma3TextConfig,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(1),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+    ModelSpec(
+        vendor="Google",
+        family="Gemma-3",
+        name="Gemma-3-4B-Instruct",
+        size="4B",
+        quantization=None,
+        repo="google/gemma-3-4b-it",
+        config_type=HFGemma3Config,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(2),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+    ModelSpec(
+        vendor="Google",
+        family="Gemma-3",
+        name="Gemma-3-27B-Instruct",
+        size="27B",
+        quantization=None,
+        repo="google/gemma-3-27b-it",
+        config_type=HFGemma3Config,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(12),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+]
+
+
+GEMMA_MODELS = GEMMA2 + GEMMA3
lalamo/model_import/model_specs/huggingface.py
@@ -0,0 +1,28 @@
+from lalamo.model_import.configs import HFLlamaConfig
+
+from .common import (
+    HUGGINFACE_GENERATION_CONFIG_FILE,
+    HUGGINGFACE_TOKENIZER_FILES,
+    ModelSpec,
+    WeightsType,
+    huggingface_weight_files,
+)
+
+__all__ = ["HUGGINGFACE_MODELS"]
+
+HUGGINGFACE_MODELS = [
+    ModelSpec(
+        vendor="HuggingFace",
+        family="SmolLM2",
+        name="SmolLM2-1.7B-Instruct",
+        size="1.7B",
+        quantization=None,
+        repo="HuggingFaceTB/SmolLM2-1.7B-Instruct",
+        config_type=HFLlamaConfig,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(1),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+]
lalamo/model_import/model_specs/llama.py
@@ -0,0 +1,101 @@
+from dataclasses import replace
+
+from lalamo.model_import.configs import ETLlamaConfig, HFLlamaConfig
+from lalamo.quantization import QuantizationMode
+
+from .common import (
+    HUGGINFACE_GENERATION_CONFIG_FILE,
+    HUGGINGFACE_TOKENIZER_FILES,
+    ModelSpec,
+    TokenizerFileSpec,
+    WeightsType,
+    huggingface_weight_files,
+)
+
+__all__ = ["LLAMA_MODELS"]
+
+LLAMA31 = [
+    ModelSpec(
+        vendor="Meta",
+        family="Llama-3.1",
+        name="Llama-3.1-8B-Instruct",
+        size="8B",
+        quantization=None,
+        repo="meta-llama/Llama-3.1-8B-Instruct",
+        config_type=HFLlamaConfig,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(4),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+]
+
+
+def _tokenizer_files_from_another_repo(repo: str) -> tuple[TokenizerFileSpec, ...]:
+    return tuple(
+        replace(spec, repo=repo) for spec in (*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE)
+    )
+
+
+LLAMA32 = [
+    # LLAMA
+    ModelSpec(
+        vendor="Meta",
+        family="Llama-3.2",
+        name="Llama-3.2-1B-Instruct",
+        size="1B",
+        quantization=None,
+        repo="meta-llama/Llama-3.2-1B-Instruct",
+        config_type=HFLlamaConfig,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(1),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+    ModelSpec(
+        vendor="Meta",
+        family="Llama-3.2",
+        name="Llama-3.2-1B-Instruct-QLoRA",
+        size="1B",
+        quantization=QuantizationMode.UINT4,
+        repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
+        config_type=ETLlamaConfig,
+        config_file_name="params.json",
+        weights_file_names=("consolidated.00.pth",),
+        weights_type=WeightsType.TORCH,
+        tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-1B-Instruct"),
+        use_cases=tuple(),
+    ),
+    ModelSpec(
+        vendor="Meta",
+        family="Llama-3.2",
+        name="Llama-3.2-3B-Instruct",
+        size="3B",
+        quantization=None,
+        repo="meta-llama/Llama-3.2-3B-Instruct",
+        config_type=HFLlamaConfig,
+        config_file_name="config.json",
+        weights_file_names=huggingface_weight_files(2),
+        weights_type=WeightsType.SAFETENSORS,
+        tokenizer_files=(*HUGGINGFACE_TOKENIZER_FILES, HUGGINFACE_GENERATION_CONFIG_FILE),
+        use_cases=tuple(),
+    ),
+    ModelSpec(
+        vendor="Meta",
+        family="Llama-3.2",
+        name="Llama-3.2-3B-Instruct-QLoRA",
+        size="3B",
+        quantization=QuantizationMode.UINT4,
+        repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
+        config_type=ETLlamaConfig,
+        config_file_name="params.json",
+        weights_file_names=("consolidated.00.pth",),
+        tokenizer_files=_tokenizer_files_from_another_repo("meta-llama/Llama-3.2-3B-Instruct"),
+        weights_type=WeightsType.TORCH,
+        use_cases=tuple(),
+    ),
+]
+
+LLAMA_MODELS = LLAMA31 + LLAMA32
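The QLoRA entries above point their tokenizer files at the corresponding full-precision repos, presumably because the QLORA_INT4_EO8 ExecuTorch repos do not ship the Hugging Face tokenizer JSONs at the expected paths. A small sketch of what _tokenizer_files_from_another_repo produces (assuming the import path below; repo=None in the base specs appears to mean "fetch from the model's own repo"):

from dataclasses import replace

from lalamo.model_import.model_specs.common import (
    HUGGINGFACE_TOKENIZER_FILES,
    TokenizerFileSpec,
)

# Rewrite the repo field of every tokenizer file spec, as the helper does.
borrowed = tuple(
    replace(spec, repo="meta-llama/Llama-3.2-1B-Instruct") for spec in HUGGINGFACE_TOKENIZER_FILES
)
assert borrowed[0] == TokenizerFileSpec(
    repo="meta-llama/Llama-3.2-1B-Instruct",
    filename="tokenizer.json",
)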