mlx-raclate 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
mlx_raclate/utils/utils.py
@@ -0,0 +1,625 @@
# Copyright © 2023-2024 Apple Inc.

import contextlib
import copy
import glob
import importlib
import json
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_reduce
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer

# Local imports
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
# Training imports
from mlx_raclate.tuner.utils import nparams  # , load_adapters ### removing adapters for now

PIPELINES = [
    "embeddings",
    "masked-lm",
    "text-classification",
    "token-classification",
    "sentence-transformers",
    "zero-shot-classification",
    "sentence-similarity",
]

# Map common string representations to MLX dtypes
STR_TO_DTYPE = {
    "float32": mx.float32,
    "fp32": mx.float32,
    "float16": mx.float16,
    "fp16": mx.float16,
    "half": mx.float16,
    "bfloat16": mx.bfloat16,
    "bf16": mx.bfloat16,
    # Less common but possible
    "float64": mx.float32,  # Map double to single precision (usually sufficient)
    "double": mx.float32,
}

HF_ARCH_TO_PIPELINE_MAPPING = {
    "ForSequenceClassification": "text-classification",
    "ForMaskedLM": "masked-lm",
    "ForTokenClassification": "token-classification",
}

MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
}

MAX_FILE_SIZE_GB = 5


class ModelNotFoundError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


def _determine_model_dtype(config: dict, loaded_weights: dict) -> mx.Dtype:
    """
    Robustly determine the target dtype for the model.
    1. If 'quantization' is in config -> default to float16 (standard MLX format).
    2. Else check config['torch_dtype'].
    3. Else 'auto' -> infer from loaded weights.
    """
    # MLX quantized models:
    # If the model is quantized, the non-quantized layers (norms, etc.)
    # should usually be float16. We ignore torch_dtype here because
    # converted configs often retain the original model's 'float32' tag.
    if config.get("quantization", None) is not None:
        return mx.float16

    # Check torch config
    dtype_entry = config.get("torch_dtype", "auto")

    if isinstance(dtype_entry, str):
        dtype_entry = dtype_entry.lower()
        if dtype_entry in STR_TO_DTYPE:
            return STR_TO_DTYPE[dtype_entry]

        if dtype_entry == "auto":
            # Infer from the first float-like weight we find
            for v in loaded_weights.values():
                if v.dtype in [mx.float16, mx.bfloat16]:
                    return v.dtype
            return mx.float32

    return mx.float32
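
# A minimal sketch of how the resolution order above plays out (the configs below are
# hypothetical, not taken from a real checkpoint):
#   _determine_model_dtype({"quantization": {"bits": 4, "group_size": 64}}, {})  # -> mx.float16
#   _determine_model_dtype({"torch_dtype": "bfloat16"}, {})                      # -> mx.bfloat16
#   _determine_model_dtype({}, {"w": mx.zeros((1,), dtype=mx.float16)})          # -> mx.float16 via "auto"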

def _get_pipeline_from_config(arch: str):
    """
    Retrieve the pipeline type based on the model configuration.

    Args:
        arch: first item of "architectures" from the model configuration.

    Returns:
        str: The pipeline type, or None if no mapping matches.
    """
    if arch is not None:
        for k, v in HF_ARCH_TO_PIPELINE_MAPPING.items():
            if k in arch:
                return v
    return None


def _get_classes(config: dict, pipeline: Optional[str] = 'masked-lm'):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.
        pipeline (str, optional): The pipeline type. Defaults to 'masked-lm'.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    if pipeline not in PIPELINES:
        raise ValueError(f"Pipeline {pipeline} not supported. Supported pipelines: {PIPELINES}")

    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"mlx_raclate.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    if pipeline == "masked-lm":
        return arch.ModelForMaskedLM, arch.ModelArgs

    if pipeline == "text-classification":
        return arch.ModelForSequenceClassification, arch.ModelArgs

    if pipeline == "token-classification":
        return arch.ModelForTokenClassification, arch.ModelArgs

    if pipeline == "embeddings":
        return arch.Model, arch.ModelArgs

    if pipeline == "sentence-transformers":
        return arch.ModelForSentenceTransformers, arch.ModelArgs

    if pipeline == "zero-shot-classification":
        # using the MaskedLM pipeline for now (see models/modernbert.py comment for class ModelForZeroShotClassification)
        # return arch.ModelForZeroShotClassification, arch.ModelArgs
        return arch.ModelForMaskedLM, arch.ModelArgs

    if pipeline == "sentence-similarity":
        return arch.ModelForSentenceSimilarity, arch.ModelArgs

    ### should not reach here
    return arch.Model, arch.ModelArgs
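
# A minimal sketch of the class resolution (hypothetical call; "modernbert" is used here
# only because models/modernbert.py ships with this package):
#   cls, args_cls = _get_classes({"model_type": "modernbert"}, pipeline="text-classification")
#   # -> (modernbert.ModelForSequenceClassification, modernbert.ModelArgs)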

def _initialize_head_weights(model: nn.Module, loaded_weights: dict, config: Any, target_dtype: mx.Dtype = mx.float32):
    """
    If we are in training mode and missing head weights, we generate them
    using the specific distribution required (e.g., Normal 0.02) rather
    than relying on default initialization.
    """
    # Flatten the model so we know the shape and dtype of every expected parameter
    model_params = dict(tree_flatten(model.parameters()))

    # Keywords that identify a 'Head' or 'Classifier' layer in the supported architectures
    head_keywords = ["classifier", "score", "head", "decoder", "dense"]

    initializer_range = getattr(config, "initializer_range", 0.02)

    initialized_count = 0

    for key, param in model_params.items():
        # If the parameter is missing from the loaded checkpoint
        if key not in loaded_weights:
            # and it belongs to a prediction head
            if any(x in key for x in head_keywords):

                # 1. Initialize biases to zero
                if "bias" in key:
                    print(f"[INFO] Initializing missing bias {key} to Zeros ({target_dtype})")
                    loaded_weights[key] = mx.zeros(param.shape, dtype=target_dtype)

                # 2. Initialize weights
                elif "weight" in key:
                    # Norm weights (gamma) should be 1.0
                    if "norm" in key:
                        print(f"[INFO] Initializing missing normalization weight {key} to Ones ({target_dtype})")
                        loaded_weights[key] = mx.ones(param.shape, dtype=target_dtype)
                    # Other weights to Normal (std=0.02)
                    else:
                        print(f"[INFO] Initializing missing weight {key} with Normal(0.0, {initializer_range}) ({target_dtype})")
                        loaded_weights[key] = mx.random.normal(
                            param.shape,
                            scale=initializer_range,
                            dtype=target_dtype,
                        )

                initialized_count += 1

    if initialized_count > 0:
        print(f"[INFO] Explicitly initialized {initialized_count} missing parameters for transfer learning.")


def _verify_weights(model: nn.Module, loaded_weights: dict, train_mode: bool):
    """
    Ensures safety.
    - Inference: CRASH if head weights are missing.
    - Training: PASS (we will initialize them next).
    """
    model_params = dict(tree_flatten(model.parameters()))
    missing_keys = [k for k in model_params.keys() if k not in loaded_weights]
    extra_keys = [k for k in loaded_weights.keys() if k not in model_params]

    head_keywords = ['classifier', 'score', 'head', 'decoder']
    missing_head_keys = [k for k in missing_keys if any(x in k for x in head_keywords)]

    if missing_head_keys:
        if not train_mode:
            # CRASH: the user wants inference but loaded a base model
            raise ValueError(
                f"Weights missing for pipeline head: {missing_head_keys[:3]}...\n"
                f"You are trying to run inference using a checkpoint that lacks the "
                f"classifier/decoder layers (likely a base model).\n"
                f"Set `train=True` if you intend to finetune this model."
                f" Extra keys found in loaded weights: {extra_keys[:3]}..."
            )


def compute_bits_per_weight(model):
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    model_params = sum(nparams(m) for _, m in leaf_modules)
    return model_bytes * 8 / model_params
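
# Rough arithmetic for the metric above (illustrative numbers, not measured): a
# 100M-parameter model held in float16 is ~200 MB of arrays, so the function returns
# ~16.0 bits/weight; after 4-bit group quantization the result lands somewhat above 4.0
# because group scales and biases stay in higher precision.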

def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        try:
            model_path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    revision=revision,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                        "*.txt",
                    ],
                )
            )
        except:
            raise ModelNotFoundError(
                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
                "Please make sure you specified the local path or Hugging Face"
                " repo id correctly.\nIf you are trying to access a private or"
                " gated Hugging Face repo, make sure you are authenticated:\n"
                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
            ) from None
    return model_path
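
# Minimal usage sketch (the repo id is just an example of a Hub checkpoint):
#   path = get_model_path("answerdotai/ModernBERT-base")  # local dir or downloaded snapshot
#   config = load_config(path)                            # defined below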

def load_config(model_path: Path) -> dict:
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config


def load_model(
    model_path: Path,
    lazy: bool = False,
    model_config: dict = {},
    get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
    pipeline: Optional[str] = None,
    train: bool = False,
) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If False, eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        model_config (dict, optional): Configuration parameters for the model.
            Defaults to an empty dictionary.
        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
            A function that returns the model class and model args class given a config.
            Defaults to the _get_classes function.
        pipeline (str, optional): The pipeline type. If None, it will be inferred
            from the model configuration. Defaults to None.
        train (bool, optional): Whether the model is being loaded for training.
            In training mode, models can be loaded from a different pipeline and
            some weights can be initialized accordingly. Defaults to False.

    Returns:
        nn.Module: The loaded and initialized model.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """

    # Check if model_path/config_sentence_transformers.json exists
    is_sentence_transformer = (model_path / "config_sentence_transformers.json").exists()

    config = load_config(model_path)
    if 'is_encoder_decoder' in config and config.get('encoder', None):
        model_type = config['model_type']
        print(f"[INFO] Detected {model_type} model, merging encoder config.")
        # Merge encoder config for main models
        encoder_config = config.get('encoder', {})
        encoder_config['model_type'] = model_type + '_encoder'
        config.update(encoder_config)

    config.update(model_config)

    arch = config.get("architectures", None)
    if arch is not None:
        model_arch = _get_pipeline_from_config(arch[0])

        if model_arch is not None:
            if pipeline is None:
                pipeline = model_arch
                print(f"[INFO] Using pipeline {pipeline} based on model architecture {model_arch}")
            elif pipeline != model_arch:
                print(
                    f"[INFO] Using pipeline {pipeline} based on user input, ignoring model architecture {model_arch}"
                )

    if is_sentence_transformer:
        if pipeline not in ["sentence-transformers", "embeddings", "sentence-similarity"]:
            if not train:
                raise ValueError(
                    f"Pipeline '{pipeline}' cannot be used with a Sentence Transformer model in inference mode. "
                    f"These models only support embeddings/similarity."
                )
            else:
                print(f"[INFO] Adaptation: Loading Sentence Transformer base into {pipeline} pipeline for training.")
        else:
            pipeline = "sentence-transformers"
            print(f"[INFO] Using pipeline {pipeline} based on Sentence Transformer config file.")

    weights = {}
    modules_file = model_path / "modules.json"

    # Sentence Transformer weights may be loaded from subfolders;
    # prefix keys are added so sanitize() can identify them
    if is_sentence_transformer and modules_file.exists():
        with open(modules_file, "r") as f:
            modules = json.load(f)

        for module in modules:
            sub_path = module.get("path", "")
            module_dir = model_path / sub_path

            module_weights = glob.glob(str(module_dir / "model*.safetensors"))
            if not module_weights:
                # Fallback for older naming conventions
                module_weights = glob.glob(str(module_dir / "weight*.safetensors"))

            for wf in module_weights:
                sub_weights = mx.load(wf)
                for k, v in sub_weights.items():
                    # Prefix the key: 'linear.weight' -> '1_Dense.linear.weight'.
                    # This allows the regex in sanitize() (r"\d+_Dense\.linear") to match.
                    if sub_path:
                        weights[f"{sub_path}.{k}"] = v
                    else:
                        # Root module (Transformer), load keys as is
                        weights[k] = v

    # Load weights from safetensors at the root of model_path
    # (typically for non-Sentence Transformer models)
    if not weights:
        weight_files = glob.glob(str(model_path / "model*.safetensors"))
        if not weight_files:
            # Try "weight*" for back-compat
            weight_files = glob.glob(str(model_path / "weight*.safetensors"))

        if not weight_files:
            logging.error(f"No safetensors found in {model_path}")
            raise FileNotFoundError(f"No safetensors found in {model_path}")

        for wf in weight_files:
            weights.update(mx.load(wf))

    target_dtype = _determine_model_dtype(config, weights)
    print(f"[INFO] Model initialized with precision: {target_dtype}")

    model_class, model_args_class = get_model_classes(config=config, pipeline=pipeline)
    model_args = model_args_class.from_dict(config)

    # Instantiate the model (random init)
    model = model_class(model_args)
    # Use set_dtype to update all floating-point parameters recursively.
    # The default predicate ensures we don't accidentally cast integer params.
    model.set_dtype(target_dtype)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    _verify_weights(model, weights, train_mode=train)

    if train:
        _initialize_head_weights(model, weights, model_args, target_dtype=target_dtype)

    model.load_weights(list(weights.items()))

    if (quantization := config.get("quantization", None)) is not None:
        # Handle legacy models which may not have everything quantized
        def class_predicate(p, m):
            if not hasattr(m, "to_quantized"):
                return False
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            **quantization,
            class_predicate=class_predicate,
        )

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model, config

def load(
    path_or_hf_repo: str,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,  ## for now, disabling adapter loading
    lazy: bool = False,
    pipeline: Optional[str] = None,
    train: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model and tokenizer from a given path or a Hugging Face repository.

    Args:
        path_or_hf_repo (Path): The path or the Hugging Face repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If False, eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        pipeline (str, optional): The pipeline type. If None, it will be inferred
            from the model configuration. Defaults to None.
        train (bool, optional): Whether the model is being loaded for training.
            In training mode, models can be loaded from a different pipeline and
            some weights can be initialized accordingly. Defaults to False.

    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model, config = load_model(model_path, lazy, model_config, pipeline=pipeline, train=train)
    ### disabling adapter for encoders
    # if adapter_path is not None:
    #     model = load_adapters(model, adapter_path)
    #     model.eval()
    tokenizer = load_tokenizer(model_path, tokenizer_config)

    return model, tokenizer
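
# Minimal usage sketch (hypothetical checkpoint id; the pipeline can often be left as None
# and inferred from the checkpoint's "architectures" entry):
#   model, tokenizer = load("answerdotai/ModernBERT-base", pipeline="masked-lm")
#
# To adapt a base checkpoint that lacks a task head, pass train=True so the missing
# classifier weights are initialized instead of raising:
#   model, tokenizer = load("answerdotai/ModernBERT-base",
#                           pipeline="text-classification", train=True)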

def fetch_from_hub(
    model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
    model, config = load_model(model_path, lazy)
    tokenizer = load_tokenizer(
        model_path, eos_token_ids=config.get("eos_token_id", None)
    )
    return model, config, tokenizer


def quantize_model(
    model: nn.Module,
    config: dict,
    q_group_size: int = 64,
    q_bits: int = 4,
    quant_predicate: Optional[
        Callable[[str, nn.Module, dict], Union[bool, dict]]
    ] = None,
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.
        quant_predicate (Callable): A callable that decides how
            to quantize each layer based on the path.
            Accepts the layer `path`, the `module` and the model `config`.
            Returns either a bool to signify quantize/no quantize or
            a dict of quantization parameters to pass to `to_quantized`.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}

    # Add any custom quantization parameters to the config as we go
    def _class_predicate(p, m):
        bool_or_params = quant_predicate(p, m, config)
        quantized_config["quantization"][p] = bool_or_params
        return bool_or_params

    nn.quantize(
        model,
        q_group_size,
        q_bits,
        class_predicate=_class_predicate if quant_predicate else None,
    )
    # support hf model tree #957
    quantized_config["quantization_config"] = quantized_config["quantization"]
    quantized_weights = dict(tree_flatten(model.parameters()))

    bpw = compute_bits_per_weight(model)
    print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")

    return quantized_weights, quantized_config
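
# Sketch of a typical quantization flow (argument values and the repo id are placeholders):
#   model, config = load_model(get_model_path("some-org/some-encoder"), lazy=True)
#   weights, q_config = quantize_model(model, config, q_group_size=64, q_bits=4)
#   mx.save_safetensors("model.safetensors", weights, metadata={"format": "mlx"})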

### Conversion should not be needed if we work with safetensors
### Kept here for reference, in case we need to re-implement it later

# def convert(
#     hf_path: str,
#     mlx_path: str = "mlx_model",
#     quantize: bool = False,
#     q_group_size: int = 64,
#     q_bits: int = 4,
#     dtype: str = "float16",
#     upload_repo: str = None,
#     revision: Optional[str] = None,
#     dequantize: bool = False,
#     quant_predicate: Optional[
#         Callable[[str, nn.Module, dict], Union[bool, dict]]
#     ] = None,
# ):
#     # Check the save path is empty
#     if isinstance(mlx_path, str):
#         mlx_path = Path(mlx_path)
#
#     if mlx_path.exists():
#         raise ValueError(
#             f"Cannot save to the path {mlx_path} as it already exists."
#             " Please delete the file/directory or specify a new path to save to."
#         )
#
#     print("[INFO] Loading")
#     model_path = get_model_path(hf_path, revision=revision)
#     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
#
#     weights = dict(tree_flatten(model.parameters()))
#     dtype = getattr(mx, dtype)
#     weights = {k: v.astype(dtype) for k, v in weights.items()}
#
#     if quantize and dequantize:
#         raise ValueError("Choose either quantize or dequantize, not both.")
#
#     if quantize:
#         print("[INFO] Quantizing")
#         model.load_weights(list(weights.items()))
#         weights, config = quantize_model(
#             model, config, q_group_size, q_bits, quant_predicate=quant_predicate
#         )
#
#     if dequantize:
#         print("[INFO] Dequantizing")
#         model = dequantize_model(model)
#         weights = dict(tree_flatten(model.parameters()))
#
#     del model
#     save_weights(mlx_path, weights, donate_weights=True)
#
#     py_files = glob.glob(str(model_path / "*.py"))
#     for file in py_files:
#         shutil.copy(file, mlx_path)
#
#     tokenizer.save_pretrained(mlx_path)
#
#     save_config(config, config_path=mlx_path / "config.json")
#
#     if upload_repo is not None:
#         upload_to_hub(mlx_path, upload_repo, hf_path)