fount-vlm-nell-02 0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/convert.py
ADDED
@@ -0,0 +1,284 @@
import argparse
import glob
import shutil
from pathlib import Path
from typing import Callable, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map_with_path
from mlx_lm.utils import dequantize_model, quantize_model

from .utils import (
    MODEL_CONVERSION_DTYPES,
    fetch_from_hub,
    get_model_path,
    save_config,
    save_weights,
    skip_multimodal_module,
    upload_to_hub,
)

QUANT_RECIPES = [
    "mixed_2_6",
    "mixed_3_4",
    "mixed_3_5",
    "mixed_3_6",
    "mixed_3_8",
    "mixed_4_6",
    "mixed_4_8",
]


def mixed_quant_predicate_builder(
    recipe: str, model: nn.Module
) -> Callable[[str, nn.Module], Union[bool, dict]]:
    group_size = 64

    recipe_config = {
        "mixed_2_6": (2, 6),
        "mixed_3_4": (3, 4),
        "mixed_3_5": (3, 5),
        "mixed_3_6": (3, 6),
        "mixed_3_8": (3, 8),
        "mixed_4_6": (4, 6),
        "mixed_4_8": (4, 8),
    }

    if recipe not in recipe_config:
        raise ValueError(f"Invalid quant recipe {recipe}")

    low_bits, high_bits = recipe_config[recipe]

    down_keys = [k for k, _ in model.named_modules() if "down_proj" in k]
    if len(down_keys) == 0:
        raise ValueError("Model does not have expected keys for mixed quant.")

    # Look for the layer index location in the path:
    for layer_location, k in enumerate(down_keys[0].split(".")):
        if k.isdigit():
            break
    num_layers = len(model.layers)

    def mixed_quant_predicate(
        path: str,
        module: nn.Module,
    ) -> Union[bool, dict]:
        """Implements mixed quantization predicates with similar choices to, for example, llama.cpp's Q4_K_M.
        Ref: https://github.com/ggerganov/llama.cpp/blob/917786f43d0f29b7c77a0c56767c0fa4df68b1c5/src/llama.cpp#L5265
        By Alex Barron: https://gist.github.com/barronalex/84addb8078be21969f1690c1454855f3
        """

        if skip_multimodal_module(path):
            return False
        if not hasattr(module, "to_quantized"):
            return False
        if module.weight.shape[1] % group_size != 0:
            return False

        path_parts = path.split(".")
        index = 0

        if len(path_parts) > layer_location:
            element = path_parts[layer_location]
            if element.isdigit():
                index = int(element)

        # First and last eighth of layers, plus every third layer in between,
        # qualify for the higher bit width.
        use_more_bits = (
            index < num_layers // 8
            or index >= 7 * num_layers // 8
            or (index - num_layers // 8) % 3 == 2
        )

        if use_more_bits and ("v_proj" in path or "down_proj" in path):
            return {"group_size": group_size, "bits": high_bits}

        if "lm_head" in path or "embed_tokens" in path:
            return {"group_size": group_size, "bits": high_bits}

        return {"group_size": group_size, "bits": low_bits}

    return mixed_quant_predicate


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    q_mode: str = "affine",
    dtype: Optional[str] = None,
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
    trust_remote_code: bool = True,
    quant_predicate: Optional[str] = None,
):
    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model, config, processor = fetch_from_hub(
        model_path, lazy=True, trust_remote_code=trust_remote_code
    )

    def base_quant_predicate(path, module):
        if skip_multimodal_module(path):
            return False
        return True

    if isinstance(quant_predicate, str):
        quant_predicate = mixed_quant_predicate_builder(quant_predicate, model)

    quant_predicate = quant_predicate or base_quant_predicate

    if dtype is None:
        dtype = config.get("torch_dtype", None)
    if dtype in MODEL_CONVERSION_DTYPES:
        print("[INFO] Using dtype:", dtype)
        dtype = getattr(mx, dtype)
        cast_predicate = getattr(model, "cast_predicate", lambda _: True)

        def set_dtype(k, v):
            if cast_predicate(k) and mx.issubdtype(v.dtype, mx.floating):
                return v.astype(dtype)
            else:
                return v

        model.update(tree_map_with_path(set_dtype, model.parameters()))

    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    if quantize:
        print("[INFO] Quantizing")
        config.setdefault("vision_config", {})
        model, config = quantize_model(
            model,
            config,
            q_group_size,
            q_bits,
            mode=q_mode,
            quant_predicate=quant_predicate,
        )

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)

    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    save_weights(mlx_path, model, donate_weights=True)

    # Copy Python and JSON files from the model path to the MLX path
    for pattern in ["*.py", "*.json"]:
        files = glob.glob(str(model_path / pattern))
        for file in files:
            # Skip the index file - save_weights() already generated the correct one
            if Path(file).name == "model.safetensors.index.json":
                continue
            shutil.copy(file, mlx_path)

    processor.save_pretrained(mlx_path)

    save_config(config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)


def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )
    parser.add_argument(
        "--hf-path",
        "--model",
        type=str,
        help="Path to the model. This can be a local path or a Hugging Face Hub model identifier.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        help="Hugging Face revision (branch), when converting a model from the Hub.",
        default=None,
    )
    parser.add_argument(
        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
    )
    parser.add_argument(
        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
    )
    parser.add_argument(
        "--q-group-size",
        help="Group size for quantization.",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--q-bits",
        help="Bits per weight for quantization.",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--q-mode",
        help="The quantization mode.",
        type=str,
        choices=["affine", "mxfp4", "nvfp4", "mxfp8"],
        default="affine",
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the parameters in. Defaults to config.json's `torch_dtype` or the current model weights' dtype.",
        type=str,
        choices=MODEL_CONVERSION_DTYPES,
        default=None,
    )
    parser.add_argument(
        "--quant-predicate",
        help="Mixed-bit quantization recipe.",
        choices=QUANT_RECIPES,
        type=str,
        required=False,
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-d",
        "--dequantize",
        help="Dequantize a quantized model.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--trust-remote-code",
        help="Trust remote code.",
        action="store_true",
        default=False,
    )
    return parser


def main():
    parser = configure_parser()
    args = parser.parse_args()
    convert(**vars(args))


if __name__ == "__main__":
    print(
        "Calling `python -m mlx_vlm.convert ...` directly is deprecated."
        " Use `mlx_vlm.convert ...` or `python -m mlx_vlm convert ...` instead."
    )
    main()
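For orientation, a minimal usage sketch (not part of the package) of the converter defined above; the Hub repo id is a placeholder, and the keyword names mirror the `convert()` signature. The same conversion is exposed on the command line via `mlx_vlm.convert` or `python -m mlx_vlm convert`, per the deprecation notice in the `__main__` guard.

# Hypothetical usage sketch; "org/some-vlm" is a placeholder repo id.
from mlx_vlm.convert import convert

convert(
    hf_path="org/some-vlm",       # any compatible Hugging Face VLM checkpoint
    mlx_path="mlx_model",         # output directory for the converted weights
    quantize=True,
    q_bits=4,
    q_group_size=64,
    quant_predicate="mixed_4_6",  # a QUANT_RECIPES entry, expanded internally
                                  # by mixed_quant_predicate_builder()
)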
mlx_vlm/deprecation.py
ADDED
@@ -0,0 +1,52 @@
import functools
import warnings
from typing import Callable, Optional


def deprecate(
    remove_version: str,
    message: str,
    instead: Optional[str] = None,
    since: Optional[str] = None,
) -> Callable:
    """
    Mark a function or method as deprecated.

    Args:
        remove_version: Version when this will be removed
        message: Deprecation message
        instead: What to use instead
        since: Version when this was deprecated

    Example:
        @deprecate(
            remove_version="2.0.0",
            message="Legacy API function",
            instead="new_api()",
            since="1.0.0"
        )
        def old_function():
            pass
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            msg = f"`{func.__name__}` is deprecated"

            if since:
                msg += f" since v{since}"

            msg += f". {message}"

            if instead:
                msg += f" Use `{instead}` instead."

            msg += f" Will be removed in v{remove_version}."

            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
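For illustration (the version strings and function names below are hypothetical), the fragments above compose into a single warning message like so:

# Hypothetical sketch of the decorator in use and the message it emits.
import warnings

from mlx_vlm.deprecation import deprecate

@deprecate(
    remove_version="0.4.0",
    message="Legacy entry point.",
    instead="new_entry()",
    since="0.3.11",
)
def old_entry():
    return "ok"

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_entry()

print(caught[0].message)
# `old_entry` is deprecated since v0.3.11. Legacy entry point.
# Use `new_entry()` instead. Will be removed in v0.4.0.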