mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,625 @@
1
+ # Copyright © 2023-2024 Apple Inc.
2
+
3
+ import contextlib
4
+ import copy
5
+ import glob
6
+ import importlib
7
+ import json
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
11
+
12
+ import mlx.core as mx
13
+ import mlx.nn as nn
14
+ from mlx.utils import tree_flatten, tree_reduce
15
+ from huggingface_hub import snapshot_download
16
+ from transformers import PreTrainedTokenizer
17
+
18
+ # Local imports
19
+ from .tokenizer_utils import TokenizerWrapper, load_tokenizer
20
+ # Training imports
21
+ from mlx_raclate.tuner.utils import nparams  # load_adapters not imported; adapter support is disabled for now
22
+
23
+ PIPELINES = [
24
+ "embeddings",
25
+ "masked-lm",
26
+ "text-classification",
27
+ "token-classification",
28
+ "sentence-transformers",
29
+ "zero-shot-classification",
30
+ "sentence-similarity"
31
+ ]
32
+
33
+ # Map common string representations to MLX dtypes
34
+ STR_TO_DTYPE = {
35
+ "float32": mx.float32,
36
+ "fp32": mx.float32,
37
+ "float16": mx.float16,
38
+ "fp16": mx.float16,
39
+ "half": mx.float16,
40
+ "bfloat16": mx.bfloat16,
41
+ "bf16": mx.bfloat16,
42
+ # Less common but possible
43
+ "float64": mx.float32, # Map double to single precision (usually sufficient)
44
+ "double": mx.float32,
45
+ }
46
+
47
+ HF_ARCH_TO_PIPELINE_MAPPING = {
48
+ "ForSequenceClassification": "text-classification",
49
+ "ForMaskedLM": "masked-lm",
50
+ "ForTokenClassification": "token-classification",
51
+ }
52
+
53
+ MODEL_REMAPPING = {
54
+ "mistral": "llama", # mistral is compatible with llama
55
+ "phi-msft": "phixtral"
56
+ }
57
+
58
+ MAX_FILE_SIZE_GB = 5
59
+
60
+ class ModelNotFoundError(Exception):
61
+ def __init__(self, message):
62
+ self.message = message
63
+ super().__init__(self.message)
64
+
65
+ def _determine_model_dtype(config: dict, loaded_weights: dict) -> mx.Dtype:
66
+ """
67
+ Robustly determine the target dtype for the model.
68
+ 1. If 'quantization' is in config -> Default to float16 (Standard MLX format).
69
+ 2. Else check config['torch_dtype'].
70
+ 3. Else 'auto' -> infer from loaded weights.
71
+ """
72
+ # MLX Quantized Models
73
+ # If the model is quantized, the non-quantized layers (norms, etc.)
74
+ # should usually be float16. We ignore torch_dtype here because
75
+ # converted configs often retain the original model's 'float32' tag.
76
+ if config.get("quantization", None) is not None:
77
+ return mx.float16
78
+
79
+ # Check Torch Config
80
+ dtype_entry = config.get("torch_dtype", "auto")
81
+
82
+ if isinstance(dtype_entry, str):
83
+ dtype_entry = dtype_entry.lower()
84
+ if dtype_entry in STR_TO_DTYPE:
85
+ return STR_TO_DTYPE[dtype_entry]
86
+
87
+ if dtype_entry == "auto":
88
+ # Infer the dtype from the first half-precision (float16/bfloat16) weight found
89
+ for v in loaded_weights.values():
90
+ if v.dtype in [mx.float16, mx.bfloat16]:
91
+ return v.dtype
92
+ return mx.float32
93
+
94
+ return mx.float32
95
+
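+ # Illustrative resolution examples for the helper above (hypothetical minimal
+ # configs; `fp16_weights` stands for any weight dict containing a float16 array):
+ #   _determine_model_dtype({"quantization": {"bits": 4}}, {})      -> mx.float16
+ #   _determine_model_dtype({"torch_dtype": "bfloat16"}, {})        -> mx.bfloat16
+ #   _determine_model_dtype({"torch_dtype": "auto"}, fp16_weights)  -> mx.float16
+ #   _determine_model_dtype({}, {})                                 -> mx.float32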
96
+ def _get_pipeline_from_config(arch: str) -> Optional[str]:
97
+ """
98
+ Retrieve the pipeline type based on the model configuration.
99
+
100
+ Args:
101
+ arch: first item of architectures from the model configuration.
102
+
103
+ Returns:
104
+ Optional[str]: The pipeline type, or None if the architecture does not map to a supported pipeline.
105
+ """
106
+ if arch is not None:
107
+ for k,v in HF_ARCH_TO_PIPELINE_MAPPING.items():
108
+ if k in arch:
109
+ return v
110
+ return None
111
+
112
+
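+ # Example of the substring matching above (architecture names are illustrative;
+ # any HF-style class name containing a mapped suffix resolves the same way):
+ #   _get_pipeline_from_config("ModernBertForSequenceClassification")  -> "text-classification"
+ #   _get_pipeline_from_config("BertForMaskedLM")                      -> "masked-lm"
+ #   _get_pipeline_from_config("BertModel")                            -> None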
113
+ def _get_classes(config: dict, pipeline: Optional[str] = 'masked-lm'):
114
+ """
115
+ Retrieve the model and model args classes based on the configuration.
116
+
117
+ Args:
118
+ config (dict): The model configuration.
+ pipeline (str, optional): The pipeline to retrieve classes for. Defaults to 'masked-lm'.
119
+
120
+ Returns:
121
+ A tuple containing the Model class and the ModelArgs class.
122
+ """
123
+ if pipeline not in PIPELINES:
124
+ raise ValueError(f"Pipeline {pipeline} not supported. Supported pipelines: {PIPELINES}")
125
+
126
+ model_type = config["model_type"]
127
+ model_type = MODEL_REMAPPING.get(model_type, model_type)
128
+ try:
129
+ arch = importlib.import_module(f"mlx_raclate.models.{model_type}")
130
+ except ImportError:
131
+ msg = f"Model type {model_type} not supported."
132
+ logging.error(msg)
133
+ raise ValueError(msg)
134
+
135
+ if pipeline == "masked-lm":
136
+ return arch.ModelForMaskedLM, arch.ModelArgs
137
+
138
+ if pipeline == "text-classification":
139
+ return arch.ModelForSequenceClassification, arch.ModelArgs
140
+
141
+ if pipeline == "token-classification":
142
+ return arch.ModelForTokenClassification, arch.ModelArgs
143
+
144
+ if pipeline == "embeddings":
145
+ return arch.Model, arch.ModelArgs
146
+
147
+ if pipeline == "sentence-transformers":
148
+ return arch.ModelForSentenceTransformers, arch.ModelArgs
149
+
150
+ if pipeline == "zero-shot-classification":
151
+ return arch.ModelForMaskedLM, arch.ModelArgs
152
+ # using the MaskedLM pipeline for now (see the comment on ModelForZeroShotClassification in models/modernbert.py)
153
+ # return arch.ModelForZeroShotClassification, arch.ModelArgs
154
+
155
+ if pipeline == "sentence-similarity":
156
+ return arch.ModelForSentenceSimilarity, arch.ModelArgs
157
+
158
+ ### Fallback; should not be reached, since unsupported pipelines raise a ValueError above
159
+ return arch.Model, arch.ModelArgs
160
+
161
+
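+ # Sketch of the dispatch above for a ModernBERT-style config. The module name is
+ # taken from models/modernbert.py referenced in the comment above; treat the exact
+ # class names as an assumption:
+ #   config = {"model_type": "modernbert", "architectures": ["ModernBertForSequenceClassification"]}
+ #   model_class, args_class = _get_classes(config, pipeline="text-classification")
+ #   # -> mlx_raclate.models.modernbert.ModelForSequenceClassification, ModelArgs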
162
+ def _initialize_head_weights(model: nn.Module, loaded_weights: dict, config: Any, target_dtype: mx.Dtype = mx.float32):
163
+ """
164
+ If we are in training mode and missing head weights, we generate them
165
+ using the specific distribution required (e.g., Normal 0.02) rather
166
+ than relying on default initialization.
167
+ """
168
+ # Flattens the model so we know the shape and dtype of every expected parameter
169
+ model_params = dict(tree_flatten(model.parameters()))
170
+
171
+ # Keywords that identify a 'Head' or 'Classifier' layer in your architectures
172
+ head_keywords = ["classifier", "score", "head", "decoder", "dense"]
173
+
174
+ initializer_range = getattr(config, "initializer_range", 0.02)
175
+
176
+ initialized_count = 0
177
+
178
+ for key, param in model_params.items():
179
+ # If the parameter is missing from the loaded checkpoint
180
+ if key not in loaded_weights:
181
+ # And it belongs to a prediction head
182
+ if any(x in key for x in head_keywords):
183
+
184
+ # 1. Initialize Biases to Zero
185
+ if "bias" in key:
186
+ print(f"[INFO] Initializing missing bias {key} to Zeros ({target_dtype})")
187
+ loaded_weights[key] = mx.zeros(param.shape, dtype=target_dtype)
188
+
189
+ # 2. Initialize Weights
190
+ elif "weight" in key:
191
+ # Norm weights (Gamma) should be 1.0
192
+ if "norm" in key:
193
+ print(f"[INFO] Initializing missing normalization weight {key} to Ones ({target_dtype})")
194
+ loaded_weights[key] = mx.ones(param.shape, dtype=target_dtype)
195
+ # Other weights to Normal (std=0.02)
196
+ else:
197
+ print(f"[INFO] Initializing missing weight {key} with Normal(0.0, {initializer_range}) ({target_dtype})")
198
+ loaded_weights[key] = mx.random.normal(
199
+ param.shape,
200
+ scale=initializer_range,
201
+ dtype=target_dtype
202
+ )
203
+
204
+ initialized_count += 1
205
+
206
+ if initialized_count > 0:
207
+ print(f"[INFO] Explicitly initialized {initialized_count} missing parameters for transfer learning.")
208
+
209
+
210
+ def _verify_weights(model: nn.Module, loaded_weights: dict, train_mode: bool):
211
+ """
212
+ Verifies that the loaded checkpoint covers the model's pipeline head.
213
+ - Inference: CRASH if head weights are missing.
214
+ - Training: PASS (we will initialize them next).
215
+ """
216
+ model_params = dict(tree_flatten(model.parameters()))
217
+ missing_keys = [k for k in model_params.keys() if k not in loaded_weights]
218
+ extra_keys = [k for k in loaded_weights.keys() if k not in model_params]
219
+
220
+ head_keywords = ['classifier', 'score', 'head', 'decoder']
221
+ missing_head_keys = [k for k in missing_keys if any(x in k for x in head_keywords)]
222
+
223
+ if missing_head_keys:
224
+ if not train_mode:
225
+ # CRASH: User wants inference but loaded a base model
226
+ raise ValueError(
227
+ f"Weights missing for pipeline head: {missing_head_keys[:3]}...\n"
228
+ f"You are trying to run Inference using a checkpoint that lacks the "
229
+ f"classifier/decoder layers (likely a base model).\n"
230
+ f"Set `train=True` if you intend to finetune this model."
231
+ f" Extra keys found in loaded weights: {extra_keys[:3]}..."
232
+ )
233
+
234
+
235
+ def compute_bits_per_weight(model):
236
+ model_bytes = tree_reduce(
237
+ lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
238
+ )
239
+ leaf_modules = tree_flatten(
240
+ model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
241
+ )
242
+ model_params = sum(nparams(m) for _, m in leaf_modules)
243
+ return model_bytes * 8 / model_params
244
+
245
+
246
+ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
247
+ """
248
+ Ensures the model is available locally. If the path does not exist locally,
249
+ it is downloaded from the Hugging Face Hub.
250
+
251
+ Args:
252
+ path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
253
+ revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
254
+
255
+ Returns:
256
+ Path: The path to the model.
257
+ """
258
+ model_path = Path(path_or_hf_repo)
259
+ if not model_path.exists():
260
+ try:
261
+ model_path = Path(
262
+ snapshot_download(
263
+ repo_id=path_or_hf_repo,
264
+ revision=revision,
265
+ allow_patterns=[
266
+ "*.json",
267
+ "*.safetensors",
268
+ "*.py",
269
+ "tokenizer.model",
270
+ "*.tiktoken",
271
+ "*.txt",
272
+ ],
273
+ )
274
+ )
275
+ except Exception:
276
+ raise ModelNotFoundError(
277
+ f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
278
+ "Please make sure you specified the local path or Hugging Face"
279
+ " repo id correctly.\nIf you are trying to access a private or"
280
+ " gated Hugging Face repo, make sure you are authenticated:\n"
281
+ "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
282
+ ) from None
283
+ return model_path
284
+
285
+
286
+ def load_config(model_path: Path) -> dict:
287
+ try:
288
+ with open(model_path / "config.json", "r") as f:
289
+ config = json.load(f)
290
+ except FileNotFoundError:
291
+ logging.error(f"Config file not found in {model_path}")
292
+ raise
293
+ return config
294
+
295
+
296
+ def load_model(
297
+ model_path: Path,
298
+ lazy: bool = False,
299
+ model_config: dict = {},
300
+ get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
301
+ pipeline: Optional[str] = None,
302
+ train: bool = False,
303
+ ) -> Tuple[nn.Module, dict]:
304
+ """
305
+ Load and initialize the model from a given path.
306
+
307
+ Args:
308
+ model_path (Path): The path to load the model from.
309
+ lazy (bool): If False eval the model parameters to make sure they are
310
+ loaded in memory before returning, otherwise they will be loaded
311
+ when needed. Default: ``False``
312
+ model_config (dict, optional): Configuration parameters for the model.
313
+ Defaults to an empty dictionary.
314
+ get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
315
+ A function that returns the model class and model args class given a config.
316
+ Defaults to the _get_classes function.
317
+ pipeline (str, optional): The pipeline type. If None, it will be inferred
318
+ from the model configuration. Defaults to None.
319
+ train (bool, optional): Whether the model is being loaded for training.
320
+ In training mode, models can be loaded from a different pipeline and
321
+ some weights can be initialized accordingly. Defaults to False.
322
+
323
+ Returns:
324
+ Tuple[nn.Module, dict]: The loaded and initialized model together with its configuration.
325
+
326
+ Raises:
327
+ FileNotFoundError: If the weight files (.safetensors) are not found.
328
+ ValueError: If the model class or args class are not found or cannot be instantiated.
329
+ """
330
+
331
+ # check if model_path/config_sentence_transformers.json exists
332
+ is_sentence_transformer = (model_path / "config_sentence_transformers.json").exists()
333
+
334
+ config = load_config(model_path)
335
+ if 'is_encoder_decoder' in config and config.get('encoder', None):
336
+ model_type = config['model_type']
337
+ print(f"[INFO] Detected {model_type} model, merging encoder config.")
338
+ # merge encoder config for main models
339
+ encoder_config = config.get('encoder', {})
340
+ encoder_config['model_type'] = model_type + '_encoder'
341
+ config.update(encoder_config)
342
+
343
+ config.update(model_config)
344
+
345
+ arch = config.get("architectures", None)
346
+ if arch is not None:
347
+ model_arch = _get_pipeline_from_config(arch[0])
348
+
349
+ if model_arch is not None:
350
+ if pipeline is None:
351
+ pipeline = model_arch
352
+ print(f"[INFO] Using pipeline {pipeline} based on model architecture {model_arch}")
353
+ elif pipeline != model_arch:
354
+ print(
355
+ f"[INFO] Using pipeline {pipeline} based on user input, ignoring model architecture {model_arch}"
356
+ )
357
+
358
+ if is_sentence_transformer:
359
+ if pipeline not in ["sentence-transformers", "embeddings", "sentence-similarity"]:
360
+ if not train:
361
+ raise ValueError(
362
+ f"Pipeline '{pipeline}' cannot be used with a Sentence Transformer model in Inference mode. "
363
+ f"These models only support embeddings/similarity."
364
+ )
365
+ else:
366
+ print(f"[INFO] Adaptation: Loading Sentence Transformer base into {pipeline} pipeline for training.")
367
+ else:
368
+ pipeline = "sentence-transformers"
369
+ print(f"[INFO] Using pipeline {pipeline} based on Sentence Transformer config file.")
370
+
371
+ weights = {}
372
+ modules_file = model_path / "modules.json"
373
+
374
+ # Sentence Transformer weights may be loaded from subfolders
375
+ # keys are prefixed with the module sub-path so sanitize() can identify them
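+ # A typical Sentence Transformers modules.json looks roughly like the following
+ # (illustrative; only the "path" field is used here):
+ # [
+ #   {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
+ #   {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
+ #   {"idx": 2, "name": "2", "path": "2_Dense", "type": "sentence_transformers.models.Dense"}
+ # ]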
376
+ if is_sentence_transformer and modules_file.exists():
377
+ with open(modules_file, "r") as f:
378
+ modules = json.load(f)
379
+
380
+ for module in modules:
381
+ sub_path = module.get("path", "")
382
+ module_dir = model_path / sub_path
383
+
384
+ module_weights = glob.glob(str(module_dir / "model*.safetensors"))
385
+ if not module_weights:
386
+ # Fallback for older naming conventions
387
+ module_weights = glob.glob(str(module_dir / "weight*.safetensors"))
388
+
389
+ for wf in module_weights:
390
+ sub_weights = mx.load(wf)
391
+ for k, v in sub_weights.items():
392
+ # prefix the key 'linear.weight' -> '1_Dense.linear.weight'
393
+ # This allows the regex in sanitize() (r"\d+_Dense\.linear") to match them.
394
+ if sub_path:
395
+ weights[f"{sub_path}.{k}"] = v
396
+ else:
397
+ # Root module (Transformer), load keys as is
398
+ weights[k] = v
399
+
400
+ # Load weights from safetensors at the root of model_path
401
+ # Typically for non-Sentence Transformer models
402
+ if not weights:
403
+ weight_files = glob.glob(str(model_path / "model*.safetensors"))
404
+ if not weight_files:
405
+ # Fall back to the legacy 'weight*' naming for backwards compatibility
406
+ weight_files = glob.glob(str(model_path / "weight*.safetensors"))
407
+
408
+ if not weight_files:
409
+ logging.error(f"No safetensors found in {model_path}")
410
+ raise FileNotFoundError(f"No safetensors found in {model_path}")
411
+
412
+ for wf in weight_files:
413
+ weights.update(mx.load(wf))
414
+
415
+ target_dtype = _determine_model_dtype(config, weights)
416
+ print(f"[INFO] Model initialized with precision: {target_dtype}")
417
+
418
+ model_class, model_args_class = get_model_classes(config=config, pipeline=pipeline)
419
+ model_args = model_args_class.from_dict(config)
420
+
421
+ # Instantiate the model (random init)
422
+ model = model_class(model_args)
423
+ # Use set_dtype to update all floating-point parameters recursively.
424
+ # The default predicate ensures we don't accidentally cast integer params.
425
+ model.set_dtype(target_dtype)
426
+
427
+ if hasattr(model, "sanitize"):
428
+ weights = model.sanitize(weights)
429
+
430
+ _verify_weights(model, weights, train_mode=train)
431
+
432
+ if train:
433
+ _initialize_head_weights(model, weights, model_args, target_dtype=target_dtype)
434
+
435
+ model.load_weights(list(weights.items()))
436
+
437
+ if (quantization := config.get("quantization", None)) is not None:
438
+ # Handle legacy models which may not have everything quantized
439
+ def class_predicate(p, m):
440
+ if not hasattr(m, "to_quantized"):
441
+ return False
442
+ return f"{p}.scales" in weights
443
+
444
+ nn.quantize(
445
+ model,
446
+ **quantization,
447
+ class_predicate=class_predicate,
448
+ )
449
+
450
+ if not lazy:
451
+ mx.eval(model.parameters())
452
+
453
+ model.eval()
454
+ return model, config
455
+
456
+
457
+ def load(
458
+ path_or_hf_repo: str,
459
+ tokenizer_config={},
460
+ model_config={},
461
+ adapter_path: Optional[str] = None,  # adapter loading is currently disabled
462
+ lazy: bool = False,
463
+ pipeline: Optional[str] = None,
464
+ train: bool = False
465
+ ) -> Tuple[nn.Module, TokenizerWrapper]:
466
+ """
467
+ Load the model and tokenizer from a given path or a huggingface repository.
468
+
469
+ Args:
470
+ path_or_hf_repo (str): The path or the Hugging Face repository to load the model from.
471
+ tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
472
+ Defaults to an empty dictionary.
473
+ model_config (dict, optional): Configuration parameters specifically for the model.
474
+ Defaults to an empty dictionary.
475
+ adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
476
+ to the model. Default: ``None``.
477
+ lazy (bool): If False eval the model parameters to make sure they are
478
+ loaded in memory before returning, otherwise they will be loaded
479
+ when needed. Default: ``False``
480
+ pipeline (str, optional): The pipeline type. If None, it will be inferred
481
+ from the model configuration. Defaults to None.
482
+ train (bool, optional): Whether the model is being loaded for training.
483
+ In training mode, models can be loaded from a different pipeline and
484
+ some weights can be initialized accordingly. Defaults to False.
485
+ Returns:
486
+ Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
487
+
488
+ Raises:
489
+ FileNotFoundError: If config file or safetensors are not found.
490
+ ValueError: If model class or args class are not found.
491
+ """
492
+ model_path = get_model_path(path_or_hf_repo)
493
+
494
+ model, config = load_model(model_path, lazy, model_config, pipeline=pipeline, train=train)
495
+ ### disabling adapter for encoders
496
+ # if adapter_path is not None:
497
+ # model = load_adapters(model, adapter_path)
498
+ # model.eval()
499
+ tokenizer = load_tokenizer(model_path, tokenizer_config)
500
+
501
+ return model, tokenizer
502
+
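+ # Illustrative usage sketch. The repo id is a placeholder; the keyword arguments
+ # simply mirror the parameters documented above:
+ #
+ #   model, tokenizer = load(
+ #       "my-org/my-modernbert-classifier",   # local path or Hugging Face repo id
+ #       pipeline="text-classification",      # optional; inferred from config when omitted
+ #       train=False,                         # inference: head weights must be present
+ #   )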
503
+ def fetch_from_hub(
504
+ model_path: Path, lazy: bool = False
505
+ ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
506
+ model, config = load_model(model_path, lazy)
507
+ tokenizer = load_tokenizer(
508
+ model_path, eos_token_ids=config.get("eos_token_id", None)
509
+ )
510
+ return model, config, tokenizer
511
+
512
+
513
+ def quantize_model(
514
+ model: nn.Module,
515
+ config: dict,
516
+ q_group_size: int = 64,
517
+ q_bits: int = 4,
518
+ quant_predicate: Optional[
519
+ Callable[[str, nn.Module, dict], Union[bool, dict]]
520
+ ] = None,
521
+ ) -> Tuple:
522
+ """
523
+ Applies quantization to the model weights.
524
+
525
+ Args:
526
+ model (nn.Module): The model to be quantized.
527
+ config (dict): Model configuration.
528
+ q_group_size (int): Group size for quantization.
529
+ q_bits (int): Bits per weight for quantization.
530
+ quant_predicate (Callable): A callable that decides how
531
+ to quantize each layer based on the path.
532
+ Accepts the layer `path`, the `module` and the model `config`.
533
+ Returns either a bool to signify quantize/no quantize or
534
+ a dict of quantization parameters to pass to `to_quantized`.
535
+
536
+ Returns:
537
+ Tuple: Tuple containing quantized weights and config.
538
+ """
539
+ quantized_config = copy.deepcopy(config)
540
+ quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
541
+
542
+ # Add any custom quantization parameters to the config as we go
543
+ def _class_predicate(p, m):
544
+ bool_or_params = quant_predicate(p, m, config)
545
+ quantized_config["quantization"][p] = bool_or_params
546
+ return bool_or_params
547
+
548
+ nn.quantize(
549
+ model,
550
+ q_group_size,
551
+ q_bits,
552
+ class_predicate=_class_predicate if quant_predicate else None,
553
+ )
554
+ # also expose the settings under "quantization_config" to support the HF model tree (#957)
555
+ quantized_config["quantization_config"] = quantized_config["quantization"]
556
+ quantized_weights = dict(tree_flatten(model.parameters()))
557
+
558
+ bpw = compute_bits_per_weight(model)
559
+ print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.")
560
+
561
+ return quantized_weights, quantized_config
562
+
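+ # Minimal sketch of a quant_predicate (the policy below is hypothetical): skip
+ # norm/classifier layers and use finer settings for embeddings.
+ #
+ #   def my_quant_predicate(path, module, config):
+ #       if "norm" in path or "classifier" in path:
+ #           return False                           # leave these layers unquantized
+ #       if "embed" in path:
+ #           return {"group_size": 32, "bits": 8}   # per-layer quantization parameters
+ #       return True                                # use quantize_model's q_group_size / q_bits
+ #
+ #   weights, q_config = quantize_model(model, config, quant_predicate=my_quant_predicate)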
563
+ ### Conversion should not be needed if we work with safetensors
564
+ ### Kept here for reference, in case we need to re-implement it later
565
+
566
+ # def convert(
567
+ # hf_path: str,
568
+ # mlx_path: str = "mlx_model",
569
+ # quantize: bool = False,
570
+ # q_group_size: int = 64,
571
+ # q_bits: int = 4,
572
+ # dtype: str = "float16",
573
+ # upload_repo: str = None,
574
+ # revision: Optional[str] = None,
575
+ # dequantize: bool = False,
576
+ # quant_predicate: Optional[
577
+ # Callable[[str, nn.Module, dict], Union[bool, dict]]
578
+ # ] = None,
579
+ # ):
580
+ # # Check the save path is empty
581
+ # if isinstance(mlx_path, str):
582
+ # mlx_path = Path(mlx_path)
583
+
584
+ # if mlx_path.exists():
585
+ # raise ValueError(
586
+ # f"Cannot save to the path {mlx_path} as it already exists."
587
+ # " Please delete the file/directory or specify a new path to save to."
588
+ # )
589
+
590
+ # print("[INFO] Loading")
591
+ # model_path = get_model_path(hf_path, revision=revision)
592
+ # model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
593
+
594
+ # weights = dict(tree_flatten(model.parameters()))
595
+ # dtype = getattr(mx, dtype)
596
+ # weights = {k: v.astype(dtype) for k, v in weights.items()}
597
+
598
+ # if quantize and dequantize:
599
+ # raise ValueError("Choose either quantize or dequantize, not both.")
600
+
601
+ # if quantize:
602
+ # print("[INFO] Quantizing")
603
+ # model.load_weights(list(weights.items()))
604
+ # weights, config = quantize_model(
605
+ # model, config, q_group_size, q_bits, quant_predicate=quant_predicate
606
+ # )
607
+
608
+ # if dequantize:
609
+ # print("[INFO] Dequantizing")
610
+ # model = dequantize_model(model)
611
+ # weights = dict(tree_flatten(model.parameters()))
612
+
613
+ # del model
614
+ # save_weights(mlx_path, weights, donate_weights=True)
615
+
616
+ # py_files = glob.glob(str(model_path / "*.py"))
617
+ # for file in py_files:
618
+ # shutil.copy(file, mlx_path)
619
+
620
+ # tokenizer.save_pretrained(mlx_path)
621
+
622
+ # save_config(config, config_path=mlx_path / "config.json")
623
+
624
+ # if upload_repo is not None:
625
+ # upload_to_hub(mlx_path, upload_repo, hf_path)