fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/utils.py ADDED
@@ -0,0 +1,1339 @@
1
+ import glob
2
+ import importlib
3
+ import inspect
4
+ import json
5
+ import logging
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from textwrap import dedent
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import mlx.core as mx
12
+ import mlx.nn as nn
13
+ import numpy as np
14
+ import requests
15
+ import soundfile as sf
16
+ from huggingface_hub import snapshot_download
17
+ from mlx.utils import tree_flatten, tree_map
18
+ from PIL import Image, ImageOps
19
+ from transformers import AutoProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast
20
+
21
+ from .models.base import BaseImageProcessor
22
+ from .tokenizer_utils import load_tokenizer
23
+ from .trainer import apply_lora_layers
24
+
25
+ # Constants
26
+ MODEL_REMAPPING = {
27
+ "llava_qwen2": "fastvlm", # Apple's FastVLM, note it's different to the one below
28
+ "llava-qwen2": "llava_bunny",
29
+ "bunny-llama": "llava_bunny",
30
+ "lfm2-vl": "lfm2_vl",
31
+ "cohere2_vision": "aya_vision",
32
+ "jvlm": "jina_vlm",
33
+ "moondream1": "moondream2",
34
+ }
35
+
36
+ MAX_FILE_SIZE_GB = 5
37
+
38
+ MODEL_CONVERSION_DTYPES = ["float16", "bfloat16", "float32"]
39
+
40
+
41
+ def skip_multimodal_module(path: str) -> bool:
42
+ """
43
+ Check if a multimodal module (vision/audio) should skip quantization.
44
+
45
+ Args:
46
+ path: The module path to check
47
+
48
+ Returns:
49
+ bool: True if the module is multimodal and should skip quantization, False otherwise
50
+ """
51
+ return (
52
+ "vision_model" in path
53
+ or "vision_tower" in path
54
+ or "vl_connector" in path
55
+ or "sam_model" in path
56
+ or "audio_model" in path
57
+ or "audio_tower" in path
58
+ or "code_predictor" in path
59
+ )
60
+
61
+
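For reference, a minimal sketch of how this predicate behaves on typical parameter paths (the example paths are illustrative, not taken from a specific checkpoint):

```python
from mlx_vlm.utils import skip_multimodal_module

# Vision/audio submodules are flagged so quantization can skip them.
print(skip_multimodal_module("vision_tower.blocks.0.attn.qkv"))    # True
print(skip_multimodal_module("audio_tower.conv1"))                 # True
# Plain language-model weights are not flagged.
print(skip_multimodal_module("language_model.layers.0.mlp.gate"))  # False
```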
62
+ def get_model_and_args(config: dict):
63
+ """
64
+ Retrieve the model object based on the configuration.
65
+
66
+ Args:
67
+ config (dict): The model configuration.
68
+
69
+ Returns:
70
+ A tuple containing the Model class and the ModelArgs class.
71
+ """
72
+ model_type = config["model_type"].lower()
73
+
74
+ model_type = MODEL_REMAPPING.get(model_type, model_type)
75
+
76
+ try:
77
+ arch = importlib.import_module(f"mlx_vlm.models.{model_type}")
78
+ except ImportError as e:
79
+ msg = f"Model type {model_type} not supported. Error: {e}"
80
+ logging.error(msg)
81
+ raise ValueError(msg)
82
+
83
+ return arch, model_type
84
+
85
+
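A small sketch of the remapping step: the config's `model_type` is lowercased, routed through `MODEL_REMAPPING`, and the matching `mlx_vlm.models` submodule is imported. The config dicts below are illustrative, and the example assumes the package and its dependencies (mlx, etc.) are installed.

```python
from mlx_vlm.utils import get_model_and_args

# "llava_qwen2" (Apple's FastVLM) is remapped to the fastvlm implementation.
arch, model_type = get_model_and_args({"model_type": "llava_qwen2"})
print(model_type)  # fastvlm

# An unknown model type surfaces as a ValueError.
try:
    get_model_and_args({"model_type": "not_a_real_model"})
except ValueError as e:
    print(e)
```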
86
+ def get_model_path(
87
+ path_or_hf_repo: str, revision: Optional[str] = None, force_download: bool = False
88
+ ) -> Path:
89
+ """
90
+ Ensures the model is available locally. If the path does not exist locally,
91
+ it is downloaded from the Hugging Face Hub.
92
+
93
+ Args:
94
+ path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
95
+ revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
96
+
97
+ Returns:
98
+ Path: The path to the model.
99
+ """
100
+ model_path = Path(path_or_hf_repo)
101
+ if not model_path.exists():
102
+ model_path = Path(
103
+ snapshot_download(
104
+ repo_id=path_or_hf_repo,
105
+ revision=revision,
106
+ allow_patterns=[
107
+ "*.json",
108
+ "*.safetensors",
109
+ "*.py",
110
+ "*.model",
111
+ "*.tiktoken",
112
+ "*.txt",
113
+ "*.jinja",
114
+ ],
115
+ force_download=force_download,
116
+ )
117
+ )
118
+ return model_path
119
+
120
+
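A hedged usage sketch: a path that already exists locally is returned unchanged, while anything else is treated as a Hugging Face repo id and resolved with `snapshot_download` (which needs network access and a valid repo).

```python
import tempfile
from pathlib import Path

from mlx_vlm.utils import get_model_path

# An existing directory is returned as-is, with no Hub access.
with tempfile.TemporaryDirectory() as tmp:
    resolved = get_model_path(tmp)
    assert resolved == Path(tmp)

# A non-existent path would instead be downloaded from the Hub
# (configs, safetensors shards, tokenizer files) into the HF cache.
```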
121
+ def load_model(model_path: Path, lazy: bool = False, **kwargs) -> nn.Module:
122
+ """
123
+ Load and initialize the model from a given path.
124
+
125
+ Args:
126
+ model_path (Path): The path to load the model from.
127
+ lazy (bool): If False eval the model parameters to make sure they are
128
+ loaded in memory before returning, otherwise they will be loaded
129
+ when needed. Default: ``False``
130
+ revision (str, optional): A revision id which can be a branch name,
131
+ a tag, or a commit hash. Default: ``None``.
132
+
133
+ Returns:
134
+ nn.Module: The loaded and initialized model.
135
+
136
+ Raises:
137
+ FileNotFoundError: If the weight files (.safetensors) are not found.
138
+ ValueError: If the model class or args class are not found or cannot be instantiated.
139
+ """
140
+ config = load_config(model_path, **kwargs)
141
+ quantization = config.get("quantization", None)
142
+
143
+ # Find all .safetensors files in the model_path, excluding consolidated model weights
144
+ weight_files = [
145
+ wf
146
+ for wf in glob.glob(str(model_path / "*.safetensors"))
147
+ if not wf.endswith("consolidated.safetensors")
148
+ ]
149
+
150
+ if not weight_files:
151
+ logging.error(f"No safetensors found in {model_path}")
152
+ message = f"""
153
+ No safetensors found in {model_path}
154
+ Create safetensors using the following code:
155
+ ```
156
+ from transformers import AutoModelForCausalLM, AutoProcessor
157
+
158
+ model_id= "<huggingface_model_id>"
159
+ model = AutoModelForCausalLM.from_pretrained(model_id)
160
+ processor = AutoProcessor.from_pretrained(model_id)
161
+
162
+ model.save_pretrained("<local_dir>")
163
+ processor.save_pretrained("<local_dir>")
164
+ ```
165
+ Then use the <local_dir> as the --hf-path in the convert script.
166
+ ```
167
+ python -m mlx_vlm.convert --hf-path <local_dir> --mlx-path <mlx_dir>
168
+ ```
169
+ """
170
+ raise FileNotFoundError(message)
171
+
172
+ weights = {}
173
+ for wf in weight_files:
174
+ weights.update(mx.load(wf))
175
+
176
+ import safetensors
177
+
178
+ with safetensors.safe_open(weight_files[0], framework="np") as f:
179
+ is_mlx_format = f.metadata() and f.metadata().get("format") == "mlx"
180
+
181
+ model_class, _ = get_model_and_args(config=config)
182
+
183
+ # Initialize text and vision configs if not present
184
+ config.setdefault("text_config", {})
185
+ config.setdefault("vision_config", {})
186
+ config.setdefault("audio_config", {})
187
+
188
+ # Initialize model config and update it with module configs
189
+ model_config = model_class.ModelConfig.from_dict(config)
190
+ modules = ["text", "vision", "perceiver", "projector", "audio"]
191
+ model_config = update_module_configs(model_config, model_class, config, modules)
192
+
193
+ model = model_class.Model(model_config)
194
+
195
+ # #region agent log
196
+ import json
197
+ log_file = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"
198
+ def log_weights(location, message, data, hypothesis_id):
199
+ try:
200
+ with open(log_file, "a") as f:
201
+ f.write(json.dumps({"sessionId": "debug-session", "runId": "load_model", "hypothesisId": hypothesis_id, "location": location, "message": message, "data": data, "timestamp": __import__("time").time_ns() // 1000000}) + "\n")
202
+ except: pass
203
+
204
+ # Get all model parameter keys
205
+ model_params = dict(tree_flatten(model.parameters()))
206
+ model_param_keys = sorted(model_params.keys())
207
+ log_weights("utils.py:load_model_params", "Model parameter keys before load_weights", {"n_params": len(model_param_keys), "sample_keys": model_param_keys[:10], "all_keys": model_param_keys}, "H3")
208
+ # #endregion
209
+
210
+ if not is_mlx_format:
211
+ # #region agent log
212
+ pre_sanitize_keys = sorted(weights.keys())
213
+ log_weights("utils.py:pre_sanitize", "Weights before sanitize", {"n_weights": len(weights), "sample_keys": pre_sanitize_keys[:10], "all_keys": pre_sanitize_keys}, "H1")
214
+ # #endregion
215
+
216
+ # Sanitize weights
217
+ weights = sanitize_weights(model, weights)
218
+
219
+ if hasattr(model, "thinker") and hasattr(model.thinker, "sanitize"):
220
+ weights = sanitize_weights(model.thinker, weights)
221
+ weights = sanitize_weights(model.thinker.vision_tower, weights)
222
+ weights = sanitize_weights(model.thinker.audio_tower, weights)
223
+ weights = sanitize_weights(model.thinker.language_model, weights)
224
+ weights = sanitize_weights(model.code2wav, weights)
225
+ weights = sanitize_weights(model.talker, weights)
226
+ else:
227
+ weights = sanitize_weights(
228
+ model_class.VisionModel, weights, model_config.vision_config
229
+ )
230
+ weights = sanitize_weights(
231
+ model_class.LanguageModel, weights, model_config.text_config
232
+ )
233
+ if hasattr(model_class, "AudioModel"):
234
+ weights = sanitize_weights(
235
+ model_class.AudioModel, weights, model_config.audio_config
236
+ )
237
+
238
+ # #region agent log
239
+ post_sanitize_keys = sorted(weights.keys())
240
+ log_weights("utils.py:post_sanitize", "Weights after sanitize", {"n_weights": len(weights), "sample_keys": post_sanitize_keys[:10], "all_keys": post_sanitize_keys}, "H1")
241
+ # #endregion
242
+
243
+ if (quantization := config.get("quantization", None)) is not None:
244
+ # Handle legacy models which may or may not have vision quantized
245
+ # TODO: Re-upload the models with the new quantization config and remove this
246
+ skip_vision = config.get("vision_config", {}).get("skip_vision", False)
247
+
248
+ def get_class_predicate(p, m):
249
+ # Skip vision and audio modules only when the legacy skip_vision flag is set
250
+ if skip_multimodal_module(p) and skip_vision:
251
+ return False
252
+ # Handle custom per layer quantizations
253
+ if p in config["quantization"]:
254
+ return config["quantization"][p]
255
+ if not hasattr(m, "to_quantized"):
256
+ return False
257
+ # Skip layers not divisible by 64
258
+ if hasattr(m, "weight") and m.weight.size % 64 != 0:
259
+ return False
260
+ # Handle legacy models which may not have everything quantized
261
+ return f"{p}.scales" in weights
262
+
263
+ nn.quantize(
264
+ model,
265
+ group_size=quantization["group_size"],
266
+ bits=quantization["bits"],
267
+ mode=quantization.get("mode", "affine"),
268
+ class_predicate=get_class_predicate,
269
+ )
270
+
271
+ # #region agent log
272
+ weights_to_load = sorted([k for k, v in weights.items()])
273
+ log_weights("utils.py:before_load_weights", "Weights being passed to load_weights", {"n_weights": len(weights_to_load), "sample_keys": weights_to_load[:10], "all_keys": weights_to_load}, "H2")
274
+ # #endregion
275
+
276
+ model.load_weights(list(weights.items()))
277
+
278
+ # #region agent log
279
+ # Get model parameters after load_weights to see what was actually loaded
280
+ loaded_params = dict(tree_flatten(model.parameters()))
281
+ loaded_param_keys = sorted(loaded_params.keys())
282
+
283
+ # Find which weights from sanitize matched model params
284
+ matched_keys = [k for k in weights_to_load if k in loaded_param_keys]
285
+ unmatched_weights = [k for k in weights_to_load if k not in loaded_param_keys]
286
+ unmatched_params = [k for k in loaded_param_keys if k not in weights_to_load]
287
+
288
+ # Categorize by subsystem
289
+ vision_weights = [k for k in matched_keys if k.startswith("vision_")]
290
+ projection_weights = [k for k in matched_keys if k.startswith("vision_projection")]
291
+ language_weights = [k for k in matched_keys if k.startswith("language_model")]
292
+
293
+ log_weights("utils.py:after_load_weights", "Weight loading results", {
294
+ "total_model_params": len(loaded_param_keys),
295
+ "total_sanitized_weights": len(weights_to_load),
296
+ "matched_weights": len(matched_keys),
297
+ "unmatched_weights_from_sanitize": len(unmatched_weights),
298
+ "unmatched_params_in_model": len(unmatched_params),
299
+ "vision_weights_loaded": len(vision_weights),
300
+ "projection_weights_loaded": len(projection_weights),
301
+ "language_weights_loaded": len(language_weights),
302
+ "matched_keys": matched_keys,
303
+ "unmatched_weights": unmatched_weights,
304
+ "unmatched_params": unmatched_params
305
+ }, "H2,H3,H4")
306
+ # #endregion
307
+
308
+ if not lazy:
309
+ mx.eval(model.parameters())
310
+
311
+ model.eval()
312
+ return model
313
+
314
+
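A minimal sketch of calling `load_model` directly, assuming an already-converted MLX checkpoint exists on disk; the path below is a placeholder, and most callers go through the higher-level `load()` further down instead.

```python
from mlx_vlm.utils import get_model_path, load_model

# Placeholder: a directory containing config.json and *.safetensors shards
# produced by mlx_vlm.convert.
model_dir = get_model_path("path/to/converted-mlx-model")

# lazy=False (the default) evaluates all parameters up front;
# lazy=True defers materialization until the weights are first used.
model = load_model(model_dir, lazy=False)
print(type(model).__name__)
```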
315
+ def sanitize_weights(model_obj, weights, config=None):
316
+ """Helper function to sanitize weights if the model has a sanitize method"""
317
+ if hasattr(model_obj, "sanitize"):
318
+ if config is not None:
319
+ model_obj = model_obj(config)
320
+ weights = model_obj.sanitize(weights)
321
+ return weights
322
+
323
+
324
+ def update_module_configs(model_config, model_class, config, modules):
325
+ """Updates configuration for model modules like text and vision modules.
326
+
327
+ Args:
328
+ model_config: The model configuration object that will be updated
329
+ model_class: The model class containing component config classes
330
+ config: Dictionary containing configuration parameters
331
+ modules: List of module names to update configs for (e.g. ["text", "vision"])
332
+
333
+ Returns:
334
+ The updated model_config object
335
+ """
336
+ for config_name in modules:
337
+ config_attr = f"{config_name}_config"
338
+ if hasattr(model_config, config_attr):
339
+ config_class = getattr(model_class, f"{config_name.title()}Config")
340
+ setattr(
341
+ model_config, config_attr, config_class.from_dict(config[config_attr])
342
+ )
343
+ return model_config
344
+
345
+
346
+ def load(
347
+ path_or_hf_repo: str,
348
+ adapter_path: Optional[str] = None,
349
+ lazy: bool = False,
350
+ revision: Optional[str] = None,
351
+ **kwargs,
352
+ ) -> Tuple[nn.Module, Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
353
+ """
354
+ Load the model and tokenizer from a given path or a huggingface repository.
355
+
356
+ Args:
357
+ path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
358
+ tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
359
+ Defaults to an empty dictionary.
360
+ adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
361
+ to the model. Default: ``None``.
362
+ lazy (bool): If False eval the model parameters to make sure they are
363
+ loaded in memory before returning, otherwise they will be loaded
364
+ when needed. Default: ``False``
365
+ revision (str, optional): A revision id which can be a branch name,
366
+ a tag, or a commit hash. Default: ``None``.
367
+ Returns:
368
+ Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
369
+
370
+ Raises:
371
+ FileNotFoundError: If config file or safetensors are not found.
372
+ ValueError: If model class or args class are not found.
373
+ """
374
+ force_download = kwargs.get("force_download", False)
375
+ model_path = get_model_path(
376
+ path_or_hf_repo, force_download=force_download, revision=revision
377
+ )
378
+ model = load_model(model_path, lazy, **kwargs)
379
+ if adapter_path is not None:
380
+ model = apply_lora_layers(model, adapter_path)
381
+ model.eval()
382
+
383
+ image_processor = load_image_processor(model_path, **kwargs)
384
+
385
+ # Get the eos_token_id from the model config
386
+ eos_token_id = getattr(model.config, "eos_token_id", None)
387
+
388
+ processor = load_processor(model_path, True, eos_token_ids=eos_token_id, **kwargs)
389
+
390
+ if image_processor is not None:
391
+ processor.image_processor = image_processor
392
+
393
+ return model, processor
394
+
395
+
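The typical end-to-end entry point is `load()`, which resolves the path, loads the weights, optionally applies LoRA adapters, and returns the model together with its processor. A hedged sketch; the repo id is a placeholder for any MLX-converted vision-language model:

```python
from mlx_vlm.utils import load

# Placeholder repo id; substitute an MLX-converted VLM checkpoint.
model, processor = load("some-org/some-vlm-mlx-4bit")

# LoRA adapters trained with mlx_vlm.lora can be applied at load time:
# model, processor = load("some-org/some-vlm-mlx-4bit", adapter_path="adapters/")
```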
396
+ def load_config(model_path: Union[str, Path], **kwargs) -> dict:
397
+ """Load model configuration from a path or Hugging Face repo.
398
+
399
+ Args:
400
+ model_path: Local path or Hugging Face repo ID to load config from
401
+ **kwargs: Additional keyword arguments to pass to the config loader
402
+
403
+ Returns:
404
+ dict: Model configuration
405
+
406
+ Raises:
407
+ FileNotFoundError: If config.json is not found at the path
408
+ """
409
+ if isinstance(model_path, str):
410
+ model_path = get_model_path(model_path)
411
+
412
+ try:
413
+ with open(model_path / "config.json", encoding="utf-8") as f:
414
+ config = json.load(f)
415
+
416
+ generation_config_file = model_path / "generation_config.json"
417
+ if generation_config_file.exists():
418
+ generation_config = {}
419
+ try:
420
+ with open(generation_config_file, "r") as f:
421
+ generation_config = json.load(f)
422
+ except json.JSONDecodeError:
423
+ pass
424
+
425
+ if eos_token_id := generation_config.get("eos_token_id", False):
426
+ config["eos_token_id"] = eos_token_id
427
+
428
+ return config
429
+
430
+ except FileNotFoundError as exc:
431
+ raise FileNotFoundError(f"Config not found at {model_path}") from exc
432
+
433
+
434
+ def load_image_processor(model_path: Union[str, Path], **kwargs) -> BaseImageProcessor:
435
+ if isinstance(model_path, str):
436
+ model_path = get_model_path(model_path)
437
+
438
+ if not kwargs:
439
+ config = load_config(model_path, trust_remote_code=True)
440
+ else:
441
+ config = load_config(model_path, **kwargs)
442
+
443
+ model_class, _ = get_model_and_args(config)
444
+ image_processor = None
445
+
446
+ if hasattr(model_class, "ImageProcessor"):
447
+ init_signature = inspect.signature(model_class.ImageProcessor.__init__)
448
+
449
+ if "config" in init_signature.parameters:
450
+ image_processor = model_class.ImageProcessor(config=config)
451
+ else:
452
+ image_processor = model_class.ImageProcessor()
453
+
454
+ return image_processor
455
+
456
+
457
+ def load_processor(
458
+ model_path, add_detokenizer=True, eos_token_ids=None, **kwargs
459
+ ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
460
+
461
+ processor = AutoProcessor.from_pretrained(model_path, use_fast=True, **kwargs)
462
+ if add_detokenizer:
463
+ detokenizer_class = load_tokenizer(model_path, return_tokenizer=False)
464
+
465
+ # Get the tokenizer object
466
+ tokenizer_obj = (
467
+ processor.tokenizer if hasattr(processor, "tokenizer") else processor
468
+ )
469
+
470
+ # Instantiate the detokenizer
471
+ processor.detokenizer = detokenizer_class(tokenizer_obj)
472
+
473
+ # Determine the EOS token IDs, prioritizing the function argument
474
+ final_eos_token_ids = (
475
+ eos_token_ids if eos_token_ids is not None else tokenizer_obj.eos_token_ids
476
+ )
477
+
478
+ # Create and assign the StoppingCriteria
479
+ criteria = StoppingCriteria(final_eos_token_ids, tokenizer_obj)
480
+ if hasattr(processor, "tokenizer"):
481
+ processor.tokenizer.stopping_criteria = criteria
482
+ else:
483
+ processor.stopping_criteria = criteria
484
+
485
+ return processor
486
+
487
+
488
+ def fetch_from_hub(
489
+ model_path: Path, lazy: bool = False, **kwargs
490
+ ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
491
+ model = load_model(model_path, lazy, **kwargs)
492
+ config = load_config(model_path, **kwargs)
493
+ processor = load_processor(
494
+ model_path,
495
+ add_detokenizer=False,
496
+ eos_token_ids=config.get("eos_token_id", None),
497
+ **kwargs,
498
+ )
499
+ return model, config, processor
500
+
501
+
502
+ def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
503
+ """
504
+ Splits the weights into smaller shards.
505
+
506
+ Args:
507
+ weights (dict): Model weights.
508
+ max_file_size_gb (int): Maximum size of each shard in gigabytes.
509
+
510
+ Returns:
511
+ list: List of weight shards.
512
+ """
513
+ max_file_size_bytes = max_file_size_gb << 30
514
+ shards = []
515
+ shard, shard_size = {}, 0
516
+ for k, v in weights.items():
517
+ if shard_size + v.nbytes > max_file_size_bytes:
518
+ shards.append(shard)
519
+ shard, shard_size = {}, 0
520
+ shard[k] = v
521
+ shard_size += v.nbytes
522
+ shards.append(shard)
523
+ return shards
524
+
525
+
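`make_shards` greedily packs tensors into shards of at most `max_file_size_gb` (note the limit is `max_file_size_gb << 30` bytes, i.e. GiB). A small worked example with the default limit and hypothetical weight names:

```python
import numpy as np

from mlx_vlm.utils import MAX_FILE_SIZE_GB, make_shards

# Size bookkeeping: the default limit is 5 GiB per shard.
limit_bytes = MAX_FILE_SIZE_GB << 30
per_tensor = 4096 * 4096 * 2           # one float16 (4096, 4096) matrix = 32 MiB
print(limit_bytes // per_tensor)       # 160 such tensors fit in one shard

# With small arrays everything lands in a single shard.
weights = {f"layer{i}.weight": np.zeros((8, 8), dtype=np.float16) for i in range(3)}
print(len(make_shards(weights)))       # 1
```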
526
+ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
527
+ """
528
+ Uploads the model to Hugging Face hub.
529
+
530
+ Args:
531
+ path (str): Local path to the model.
532
+ upload_repo (str): Name of the HF repo to upload to.
533
+ hf_path (str): Path to the original Hugging Face model.
534
+ """
535
+ import os
536
+
537
+ from huggingface_hub import HfApi, ModelCard, logging
538
+
539
+ from . import __version__
540
+
541
+ card = ModelCard.load(hf_path)
542
+ card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
543
+ card.text = dedent(
544
+ f"""
545
+ # {upload_repo}
546
+ This model was converted to MLX format from [`{hf_path}`]() using mlx-vlm version **{__version__}**.
547
+ Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
548
+ ## Use with mlx
549
+
550
+ ```bash
551
+ pip install -U mlx-vlm
552
+ ```
553
+
554
+ ```bash
555
+ python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temperature 0.0 --prompt "Describe this image." --image <path_to_image>
556
+ ```
557
+ """
558
+ )
559
+ card.save(os.path.join(path, "README.md"))
560
+
561
+ logging.set_verbosity_info()
562
+
563
+ api = HfApi()
564
+ api.create_repo(repo_id=upload_repo, exist_ok=True)
565
+ api.upload_folder(
566
+ folder_path=path,
567
+ repo_id=upload_repo,
568
+ repo_type="model",
569
+ )
570
+ print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
571
+
572
+
573
+ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
574
+ """
575
+ Apply repetition penalty to specific logits based on the given context.
576
+
577
+ Paper: https://arxiv.org/abs/1909.05858
578
+
579
+ Args:
580
+ logits (mx.array): The logits produced by the language model.
581
+ generated_tokens (any): A list of N previous tokens.
582
+ penalty (float): The repetition penalty factor to be applied.
583
+
584
+ Returns:
585
+ logits (mx.array): Logits with repetition penalty applied to generated tokens.
586
+ """
587
+ if len(generated_tokens) > 0:
588
+ indices = mx.array([token for token in generated_tokens])
589
+ selected_logits = logits[:, indices]
590
+ selected_logits = mx.where(
591
+ selected_logits < 0, selected_logits * penalty, selected_logits / penalty
592
+ )
593
+ logits[:, indices] = selected_logits
594
+ return logits
595
+
596
+
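A worked example of the penalty from the CTRL paper referenced above: logits of previously generated tokens are divided by the penalty when positive and multiplied by it when negative, pushing repeated tokens down. Token ids and values here are arbitrary.

```python
import mlx.core as mx

from mlx_vlm.utils import apply_repetition_penalty

logits = mx.array([[2.0, -1.0, 0.5, 3.0]])
generated_tokens = [0, 1]  # tokens already emitted
out = apply_repetition_penalty(logits, generated_tokens, penalty=2.0)
print(out)  # [[1.0, -2.0, 0.5, 3.0]] -> token 0 halved, token 1 pushed further down
```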
597
+ def save_weights(
598
+ save_path: Union[str, Path],
599
+ model: nn.Module,
600
+ *,
601
+ donate_weights: bool = False,
602
+ ) -> None:
603
+ """Save model weights into specified directory."""
604
+ if isinstance(save_path, str):
605
+ save_path = Path(save_path)
606
+
607
+ weights = dict(tree_flatten(model.parameters()))
608
+
609
+ save_path.mkdir(parents=True, exist_ok=True)
610
+
611
+ shards = make_shards(weights)
612
+ shards_count = len(shards)
613
+ shard_file_format = (
614
+ "model-{:05d}-of-{:05d}.safetensors"
615
+ if shards_count > 1
616
+ else "model.safetensors"
617
+ )
618
+
619
+ total_size = sum(v.nbytes for v in weights.values())
620
+ index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
621
+
622
+ # Write the weights and make sure no references are kept other than the
623
+ # necessary ones
624
+ if donate_weights:
625
+ model.update(tree_map(lambda _: mx.array([]), model.parameters()))
626
+
627
+ weights.clear()
628
+ del weights
629
+
630
+ for i in range(len(shards)):
631
+ shard = shards[i]
632
+ shards[i] = None
633
+ shard_name = shard_file_format.format(i + 1, shards_count)
634
+ shard_path = save_path / shard_name
635
+
636
+ mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})
637
+
638
+ for weight_name in shard.keys():
639
+ index_data["weight_map"][weight_name] = shard_name
640
+ del shard
641
+
642
+ index_data["weight_map"] = {
643
+ k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
644
+ }
645
+
646
+ with open(save_path / "model.safetensors.index.json", "w") as f:
647
+ json.dump(
648
+ index_data,
649
+ f,
650
+ indent=4,
651
+ )
652
+
653
+
654
+ def save_config(
655
+ config: dict,
656
+ config_path: Union[str, Path],
657
+ ) -> None:
658
+ """Save the model configuration to the ``config_path``.
659
+
660
+ The final configuration will be sorted before saving for better readability.
661
+
662
+ Args:
663
+ config (dict): The model configuration.
664
+ config_path (Union[str, Path]): Model configuration file path.
665
+ """
666
+ # Clean unused keys
667
+ config.pop("_name_or_path", None)
668
+ config.pop("torch_dtype", None)
669
+
670
+ # sort the config for better readability
671
+ config = dict(sorted(config.items()))
672
+
673
+ # write the updated config to the config_path (if provided)
674
+ with open(config_path, "w") as fid:
675
+ json.dump(config, fid, indent=4)
676
+
677
+
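A minimal round-trip sketch of the two save helpers on a toy module; the tiny `nn.Linear` stands in for a real VLM, and the file names follow the sharding scheme above:

```python
import tempfile
from pathlib import Path

import mlx.nn as nn

from mlx_vlm.utils import save_config, save_weights

model = nn.Linear(4, 4)  # stand-in for a converted model
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    save_weights(out, model)  # writes model.safetensors + model.safetensors.index.json
    save_config({"model_type": "demo", "hidden_size": 4}, out / "config.json")
    print(sorted(p.name for p in out.iterdir()))
```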
678
+ def load_image(image_source: Union[str, Path, BytesIO], timeout: int = 10):
679
+ """
680
+ Helper function to load an image from either a URL or file.
681
+ """
682
+ if (
683
+ isinstance(image_source, BytesIO)
684
+ or (isinstance(image_source, str) and image_source.startswith("data:image/"))
685
+ or Path(image_source).is_file()
686
+ ):
687
+ # for base64 encoded images
688
+ try:
689
+ if isinstance(image_source, str) and image_source.startswith("data:image/"):
690
+ import base64
691
+
692
+ if "," not in image_source:
693
+ raise ValueError(
694
+ "Invalid data URI format - missing comma separator"
695
+ )
696
+
697
+ _, data = image_source.split(",", 1)
698
+ image_source = BytesIO(base64.b64decode(data))
699
+
700
+ image = Image.open(image_source)
701
+ except IOError as e:
702
+ raise ValueError(
703
+ f"Failed to load image from {image_source} with error: {e}"
704
+ ) from e
705
+ elif image_source.startswith(("http://", "https://")):
706
+ try:
707
+ response = requests.get(image_source, stream=True, timeout=timeout)
708
+ response.raise_for_status()
709
+ image = Image.open(response.raw)
710
+ except Exception as e:
711
+ raise ValueError(
712
+ f"Failed to load image from URL: {image_source} with error {e}"
713
+ ) from e
714
+ else:
715
+ raise ValueError(
716
+ f"The image {image_source} must be a valid URL or existing file."
717
+ )
718
+
719
+ image = ImageOps.exif_transpose(image)
720
+ image = image.convert("RGB")
721
+ return image
722
+
723
+
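`load_image` accepts a file path, an http(s) URL, a `data:image/...;base64` URI, or a `BytesIO` object, applies EXIF orientation, and returns an RGB PIL image. A self-contained sketch using a throwaway file:

```python
import tempfile
from pathlib import Path

from PIL import Image

from mlx_vlm.utils import load_image

# Write a tiny PNG to disk and load it back through the helper.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "pixel.png"
    Image.new("RGB", (32, 16), (200, 30, 30)).save(path)
    img = load_image(str(path))
    print(img.size, img.mode)  # (32, 16) RGB

# URLs (network required) and data URIs go through the same function.
```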
724
+ def resize_image(img, max_size):
725
+
726
+ ratio = min(max_size[0] / img.width, max_size[1] / img.height)
727
+ new_size = (int(img.width * ratio), int(img.height * ratio))
728
+ return img.resize(new_size)
729
+
730
+
731
+ def process_image(img, resize_shape, image_processor):
732
+ if isinstance(img, str):
733
+ img = load_image(img)
734
+ if resize_shape is not None and not isinstance(image_processor, BaseImageProcessor):
735
+ img = resize_image(img, resize_shape)
736
+ return img
737
+
738
+
739
+ def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
740
+ """Resample audio using linear interpolation."""
741
+ if orig_sr == target_sr:
742
+ return audio
743
+
744
+ # Calculate the resampling ratio
745
+ ratio = target_sr / orig_sr
746
+
747
+ # Handle different audio shapes
748
+ if audio.ndim == 1:
749
+ # Mono audio - simple case
750
+ new_length = int(len(audio) * ratio)
751
+ old_indices = np.arange(len(audio))
752
+ new_indices = np.linspace(0, len(audio) - 1, new_length)
753
+ resampled = np.interp(new_indices, old_indices, audio)
754
+
755
+ elif audio.ndim == 2:
756
+ # Multi-channel audio - transpose to (samples, channels) if needed
757
+ if audio.shape[0] < audio.shape[1]:
758
+ audio = audio.T
759
+
760
+ # Resample each channel
761
+ n_samples, n_channels = audio.shape
762
+ new_length = int(n_samples * ratio)
763
+ old_indices = np.arange(n_samples)
764
+ new_indices = np.linspace(0, n_samples - 1, new_length)
765
+
766
+ resampled = np.zeros((new_length, n_channels))
767
+ for i in range(n_channels):
768
+ resampled[:, i] = np.interp(new_indices, old_indices, audio[:, i])
769
+ else:
770
+ raise ValueError(f"Audio array has unsupported shape: {audio.shape}")
771
+
772
+ return resampled
773
+
774
+
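A small worked example of the linear-interpolation resampler: a one-second 44.1 kHz signal becomes about 16,000 samples at 16 kHz (the exact length is `int(len * target_sr / orig_sr)`).

```python
import numpy as np

from mlx_vlm.utils import resample_audio

sr_in, sr_out = 44_100, 16_000
t = np.linspace(0, 1, sr_in, endpoint=False)
mono = np.sin(2 * np.pi * 440 * t)            # 1 s of a 440 Hz tone

resampled = resample_audio(mono, sr_in, sr_out)
print(len(mono), len(resampled))              # 44100 -> ~16000

# (samples, channels) input is resampled per channel.
stereo = np.stack([mono, mono], axis=1)
print(resample_audio(stereo, sr_in, sr_out).shape)  # (~16000, 2)
```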
775
+ def load_audio(
776
+ file: str,
777
+ sr: int,
778
+ timeout: int = 10,
779
+ ):
780
+ """
781
+ Helper function to load audio from either a URL or file.
782
+ """
783
+ if file.startswith(("http://", "https://")):
784
+ try:
785
+ response = requests.get(file, stream=True, timeout=timeout)
786
+ response.raise_for_status()
787
+ audio, sample_rate = sf.read(BytesIO(response.content), always_2d=True)
788
+ except Exception as e:
789
+ raise ValueError(
790
+ f"Failed to load audio from URL: {file} with error {e}"
791
+ ) from e
792
+ else:
793
+ audio, sample_rate = sf.read(file, always_2d=True)
794
+
795
+ if sample_rate != sr:
796
+ audio = resample_audio(audio, sample_rate, sr)
797
+ return np.array(audio).mean(axis=1)
798
+
799
+
800
+ def normalize_audio_features(features: mx.array) -> mx.array:
801
+ """Normalize mel spectrogram features for lossy audio formats (e.g., MP3)."""
802
+ return (features - mx.mean(features)) / (mx.std(features) + 1e-6)
803
+
804
+
805
+ def process_inputs(
806
+ processor,
807
+ prompts,
808
+ images=None,
809
+ audio=None,
810
+ add_special_tokens=False,
811
+ padding=True,
812
+ padding_side="left",
813
+ return_tensors="mlx",
814
+ **kwargs,
815
+ ):
816
+ # Get the process method from the processor
817
+ process_method = getattr(processor, "process", processor)
818
+ parameters = inspect.signature(process_method).parameters
819
+
820
+ # Prepare arguments
821
+ args = {
822
+ "text": prompts,
823
+ "images": images,
824
+ "padding": padding,
825
+ "return_tensors": return_tensors,
826
+ }
827
+ if "padding_side" in parameters:
828
+ args["padding_side"] = padding_side
829
+
830
+ # Add special tokens if supported
831
+ if "add_special_tokens" in parameters:
832
+ args["add_special_tokens"] = add_special_tokens
833
+
834
+ for param in parameters.keys():
835
+ if param in kwargs.keys():
836
+ args[param] = kwargs.get(param, None)
837
+
838
+ # Add audio if provided and supported
839
+ if audio is not None and len(audio) > 0:
840
+ if "audio" in parameters:
841
+ args["audio"] = audio
842
+ else:
843
+ raise ValueError(f"Processor {processor} does not support audio parameter")
844
+
845
+ return process_method(**args)
846
+
847
+
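The key trick in `process_inputs` is introspecting the processor's signature so that optional arguments (`padding_side`, `add_special_tokens`, `audio`, extra kwargs) are only forwarded when the processor actually accepts them. A stripped-down illustration of that pattern with a stand-in callable, not a real Hugging Face processor:

```python
import inspect

def fake_processor(text, images=None, padding=True, return_tensors="mlx"):
    """Stand-in for an AutoProcessor __call__; just echoes what it received."""
    return {"text": text, "images": images, "padding": padding}

parameters = inspect.signature(fake_processor).parameters
args = {"text": "Describe this image.", "images": ["img"], "padding": True}

# Only pass optional knobs the callable actually declares.
if "padding_side" in parameters:
    args["padding_side"] = "left"       # skipped: fake_processor has no padding_side
if "add_special_tokens" in parameters:
    args["add_special_tokens"] = False  # skipped as well

print(fake_processor(**args))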
848
+ def process_inputs_with_fallback(
849
+ processor,
850
+ prompts,
851
+ images,
852
+ audio,
853
+ add_special_tokens=False,
854
+ return_tensors="mlx",
855
+ **kwargs,
856
+ ):
857
+ # First attempt with specified return_tensors
858
+ try:
859
+ return process_inputs(
860
+ processor,
861
+ prompts=prompts,
862
+ images=images,
863
+ audio=audio,
864
+ add_special_tokens=add_special_tokens,
865
+ return_tensors=return_tensors,
866
+ **kwargs,
867
+ )
868
+ except Exception as e:
869
+ # Fallback to PyTorch tensors if MLX fails
870
+ if return_tensors != "pt":
871
+ try:
872
+ return process_inputs(
873
+ processor,
874
+ prompts=prompts,
875
+ images=images,
876
+ audio=audio,
877
+ add_special_tokens=add_special_tokens,
878
+ return_tensors="pt",
879
+ **kwargs,
880
+ )
881
+ except Exception as fallback_error:
882
+ raise ValueError(
883
+ f"Failed to process inputs with error: {fallback_error}"
884
+ ) from fallback_error
885
+
886
+ raise ValueError(f"Failed to process inputs with error: {e}")
887
+
888
+
889
+ def prepare_inputs(
890
+ processor,
891
+ images=None,
892
+ audio=None,
893
+ prompts=None,
894
+ image_token_index=None,
895
+ resize_shape=None,
896
+ add_special_tokens=False,
897
+ padding=True,
898
+ padding_side="left",
899
+ pad_to_uniform_size=False,
900
+ **kwargs,
901
+ ):
902
+
903
+ if not images and not audio:
904
+ tokenizer = (
905
+ processor.tokenizer if hasattr(processor, "tokenizer") else processor
906
+ )
907
+ # Ensure pad_token exists when padding text-only inputs
908
+ if padding and tokenizer.pad_token is None:
909
+ tokenizer.pad_token = tokenizer.eos_token
910
+ inputs = tokenizer(
911
+ prompts,
912
+ add_special_tokens=add_special_tokens,
913
+ padding=padding,
914
+ padding_side=padding_side,
915
+ )
916
+ input_ids = mx.array([inputs.input_ids])
917
+ mask = mx.array([inputs.attention_mask])
918
+ return {
919
+ "input_ids": input_ids,
920
+ "attention_mask": mask,
921
+ }
922
+
923
+ # Process images
924
+ if images is not None:
925
+ if not isinstance(images, list):
926
+ images = [images]
927
+
928
+ image_processor = (
929
+ processor.image_processor if hasattr(processor, "image_processor") else None
930
+ )
931
+ images = [process_image(img, resize_shape, image_processor) for img in images]
932
+
933
+ # For batching, we need uniform image sizes. Instead of padding to the
934
+ # largest image (which adds white borders that hurt accuracy), we resize
935
+ # all images to the model's expected input size.
936
+ if len(images) > 1 and pad_to_uniform_size:
937
+ # Get target size from image processor if available
938
+ target_size = None
939
+ if image_processor is not None and hasattr(image_processor, "size"):
940
+ size = image_processor.size
941
+ if isinstance(size, tuple):
942
+ target_size = size
943
+ elif isinstance(size, dict):
944
+ target_size = (size.get("height", 384), size.get("width", 384))
945
+ elif isinstance(size, int):
946
+ target_size = (size, size)
947
+
948
+ if target_size is not None:
949
+ # Resize all images to the target size
950
+ resized_images = []
951
+ for img in images:
952
+ if img.size != (
953
+ target_size[1],
954
+ target_size[0],
955
+ ): # PIL uses (width, height)
956
+ img = img.resize(
957
+ (target_size[1], target_size[0]), Image.Resampling.BICUBIC
958
+ )
959
+ resized_images.append(img)
960
+ images = resized_images
961
+ else:
962
+ # Fallback: pad to largest size (original behavior)
963
+ max_width = max(img.width for img in images)
964
+ max_height = max(img.height for img in images)
965
+
966
+ padded_images = []
967
+ for img in images:
968
+ if img.width != max_width or img.height != max_height:
969
+ padded_img = Image.new(
970
+ "RGB", (max_width, max_height), (255, 255, 255)
971
+ )
972
+ x_offset = (max_width - img.width) // 2
973
+ y_offset = (max_height - img.height) // 2
974
+ padded_img.paste(img, (x_offset, y_offset))
975
+ padded_images.append(padded_img)
976
+ else:
977
+ padded_images.append(img)
978
+ images = padded_images
979
+
980
+ # Process audio
981
+ audio_inputs = None
982
+ audio_feature_lengths = None
983
+ is_qwen3_omni_moe = False
984
+ processor_class_name = (
985
+ processor.__class__.__name__ if hasattr(processor, "__class__") else ""
986
+ )
987
+ if (
988
+ "qwen3" in processor_class_name.lower()
989
+ and "omni" in processor_class_name.lower()
990
+ ):
991
+ is_qwen3_omni_moe = True
992
+
993
+ is_lossy_audio = False
994
+ if audio is not None and len(audio) > 0:
995
+ if not isinstance(audio, list):
996
+ audio = [audio]
997
+
998
+ # Check if any audio file is a lossy format (MP3, AAC, OGG, etc.)
999
+ lossy_extensions = {".mp3", ".m4a"}
1000
+ is_lossy_audio = any(
1001
+ str(f).lower().endswith(tuple(lossy_extensions)) for f in audio
1002
+ )
1003
+
1004
+ if len(audio) > 1:
1005
+ print(
1006
+ "\033[33mWarning\033[0m: Single prompt with multiple audio files is not supported yet. Using the first audio file.\n"
1007
+ )
1008
+ audio = audio[:1]
1009
+
1010
+ if is_qwen3_omni_moe:
1011
+ audio_arrays = [
1012
+ load_audio(audio_file, sr=processor.feature_extractor.sampling_rate)
1013
+ for audio_file in audio
1014
+ ]
1015
+ audio_arrays = [
1016
+ audio_array.astype(np.float32) for audio_array in audio_arrays
1017
+ ]
1018
+
1019
+ feature_extractor = getattr(processor, "feature_extractor", None)
1020
+ if feature_extractor is None:
1021
+ raise ValueError("Processor missing feature_extractor for audio prep.")
1022
+
1023
+ audio_inputs = feature_extractor(
1024
+ audio_arrays,
1025
+ sampling_rate=feature_extractor.sampling_rate,
1026
+ padding=True,
1027
+ return_attention_mask=True,
1028
+ )
1029
+
1030
+ audio_feature_lengths = np.sum(
1031
+ audio_inputs["attention_mask"], axis=-1, dtype=np.int32
1032
+ )
1033
+ else:
1034
+ feature_extractor = getattr(processor, "feature_extractor", None)
1035
+ if feature_extractor is not None:
1036
+ audio = [
1037
+ load_audio(audio_file, sr=feature_extractor.sampling_rate)
1038
+ for audio_file in audio
1039
+ ]
1040
+ else:
1041
+ audio = [
1042
+ load_audio(audio_file, sr=processor.feature_extractor.sampling_rate)
1043
+ for audio_file in audio
1044
+ ]
1045
+
1046
+ model_inputs = {}
1047
+
1048
+ if hasattr(processor, "image_processor") and isinstance(
1049
+ processor.image_processor, BaseImageProcessor
1050
+ ):
1051
+ if not isinstance(prompts, list):
1052
+ prompts = [prompts]
1053
+
1054
+ if processor.pad_token is None:
1055
+ processor.pad_token = processor.eos_token
1056
+
1057
+ # Moondream expects image patch tokens immediately after BOS. Its
1058
+ # prompting is string-based, so we ignore literal "<image>" placement
1059
+ # and always insert the image token block after the first token.
1060
+ if processor.__class__.__name__ == "MoondreamProcessor":
1061
+ # Clean up prompts: strip <image> markers and add generation suffix
1062
+ cleaned_prompts = []
1063
+ for prompt in prompts:
1064
+ clean = prompt.replace("<image>", "").strip()
1065
+ # Add the generation prompt suffix moondream expects
1066
+ if not clean.endswith("Answer:") and not clean.endswith("Answer: "):
1067
+ clean = clean + "\n\nAnswer:"
1068
+ cleaned_prompts.append(clean)
1069
+
1070
+ token_ids_per_prompt = [
1071
+ processor(prompt, add_special_tokens=True).input_ids for prompt in cleaned_prompts
1072
+ ]
1073
+ text_chunks = []
1074
+ for ids in token_ids_per_prompt:
1075
+ if not ids:
1076
+ ids = [processor.bos_token_id]
1077
+ if ids[0] != processor.bos_token_id:
1078
+ ids = [processor.bos_token_id] + ids
1079
+ # Represent as [bos], [rest]
1080
+ text_chunks.append([[ids[0]], ids[1:]])
1081
+ else:
1082
+ text_chunks = [
1083
+ [processor(chunk).input_ids for chunk in prompt.split("<image>")]
1084
+ for prompt in prompts
1085
+ ]
1086
+
1087
+ # Normalize chunks to a 2-part [before, after] representation.
1088
+ # - If prompt has no "<image>", we treat it as [full_prompt, ""]
1089
+ # - If prompt has multiple "<image>", we only insert one image token and
1090
+ # concatenate the remaining text parts into the "after" section.
1091
+ normalized_chunks = []
1092
+ for chunks in text_chunks:
1093
+ if len(chunks) == 1:
1094
+ before = chunks[0]
1095
+ after = []
1096
+ elif len(chunks) >= 2:
1097
+ before = chunks[0]
1098
+ after = []
1099
+ for part in chunks[1:]:
1100
+ after += part
1101
+ else:
1102
+ before = []
1103
+ after = []
1104
+ normalized_chunks.append([before, after])
1105
+ text_chunks = normalized_chunks
1106
+
1107
+ # Find the maximum length for padding.
1108
+ # Note: for MoondreamProcessor we expand a single "<image>" marker into
1109
+ # 729 patch tokens.
1110
+ if processor.__class__.__name__ == "MoondreamProcessor":
1111
+ max_length = max(
1112
+ sum(len(chunk) for chunk in chunks) + 729 for chunks in text_chunks
1113
+ )
1114
+ else:
1115
+ max_length = max(
1116
+ sum(len(chunk) for chunk in chunks) + 1 for chunks in text_chunks
1117
+ )
1118
+
1119
+ # Pad and create input_ids
1120
+ input_ids = []
1121
+ for chunks in text_chunks:
1122
+ # Moondream2 uses a block of patch tokens (729) rather than a single
1123
+ # placeholder token. Keep this model-specific to avoid impacting
1124
+ # other multimodal models.
1125
+ if processor.__class__.__name__ == "MoondreamProcessor":
1126
+ image_tokens = [image_token_index] * 729
1127
+ else:
1128
+ image_tokens = [image_token_index]
1129
+
1130
+ ids = chunks[0] + image_tokens + chunks[1]
1131
+ padding = [processor.pad_token_id] * (max_length - len(ids))
1132
+ input_ids.append(mx.array(ids + padding))
1133
+
1134
+ model_inputs["input_ids"] = mx.array(input_ids)
1135
+
1136
+ # Handle Moondream's multi-crop preprocessing which returns
1137
+ # (crops_list, crop_counts, tilings) instead of just pixel_values
1138
+ if processor.__class__.__name__ == "MoondreamProcessor":
1139
+ crops_list, crop_counts, tilings = processor.image_processor.preprocess(
1140
+ images=images
1141
+ )
1142
+ # Concatenate all crops for batch processing
1143
+ all_crops = np.concatenate(crops_list, axis=0)
1144
+ model_inputs["pixel_values"] = mx.array(all_crops)
1145
+ model_inputs["crop_counts"] = crop_counts
1146
+ model_inputs["tilings"] = tilings
1147
+ else:
1148
+ pixel_values = processor.image_processor.preprocess(images=images)
1149
+ model_inputs["pixel_values"] = mx.array(np.stack(pixel_values))
1150
+
1151
+ model_inputs["attention_mask"] = mx.array(
1152
+ [(ids != processor.pad_token_id) for ids in input_ids]
1153
+ ).astype(mx.int32)
1154
+
1155
+ else:
1156
+ if hasattr(processor, "tokenizer") and processor.tokenizer.pad_token is None:
1157
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
1158
+
1159
+ inputs = process_inputs_with_fallback(
1160
+ processor,
1161
+ images=images,
1162
+ audio=audio,
1163
+ prompts=prompts,
1164
+ add_special_tokens=add_special_tokens,
1165
+ **kwargs,
1166
+ )
1167
+
1168
+ if "images" in inputs:
1169
+ inputs["pixel_values"] = inputs["images"]
1170
+ inputs.pop("images")
1171
+
1172
+ model_inputs["attention_mask"] = (
1173
+ mx.array(inputs["attention_mask"]) if "attention_mask" in inputs else None
1174
+ )
1175
+
1176
+ # Convert inputs to model_inputs with mx.array if present
1177
+ for key, value in inputs.items():
1178
+ if key not in model_inputs:
1179
+ if isinstance(value, (str, list, mx.array)):
1180
+ model_inputs[key] = value
1181
+ else:
1182
+ model_inputs[key] = mx.array(value)
1183
+
1184
+ if audio_inputs is not None:
1185
+ model_inputs["input_features"] = mx.array(audio_inputs["input_features"])
1186
+ model_inputs["feature_attention_mask"] = mx.array(
1187
+ audio_inputs["attention_mask"]
1188
+ ).astype(mx.int32)
1189
+ model_inputs["audio_feature_lengths"] = mx.array(
1190
+ audio_feature_lengths, dtype=mx.int32
1191
+ )
1192
+
1193
+ if is_lossy_audio and "input_features" in model_inputs:
1194
+ f = model_inputs["input_features"]
1195
+ if isinstance(f, list):
1196
+ model_inputs["input_features"] = [
1197
+ normalize_audio_features(mx.array(x)) for x in f
1198
+ ]
1199
+ else:
1200
+ model_inputs["input_features"] = normalize_audio_features(f)
1201
+
1202
+ return model_inputs
1203
+
1204
+
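Batching requires uniform image sizes; `prepare_inputs` either resizes everything to the processor's expected size or, as a fallback, centers each image on a white canvas matching the largest one. The padding fallback in isolation, using plain PIL and no processor:

```python
from PIL import Image

images = [Image.new("RGB", (400, 300), "blue"), Image.new("RGB", (800, 600), "green")]

max_width = max(img.width for img in images)
max_height = max(img.height for img in images)

padded = []
for img in images:
    if (img.width, img.height) != (max_width, max_height):
        canvas = Image.new("RGB", (max_width, max_height), (255, 255, 255))
        canvas.paste(img, ((max_width - img.width) // 2, (max_height - img.height) // 2))
        img = canvas
    padded.append(img)

print([im.size for im in padded])  # [(800, 600), (800, 600)]
```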
1205
+ def group_images_by_shape(
1206
+ images: List[Image.Image],
1207
+ disable_grouping: bool = False,
1208
+ ) -> Tuple[Dict[Tuple[int, int], List[Image.Image]], Dict[Tuple[int, int], List[int]]]:
1209
+ """
1210
+ Group images by their dimensions for efficient batch processing.
1211
+
1212
+ Images with the same dimensions can be stacked and processed together,
1213
+ which is much faster than processing individually (especially on GPU).
1214
+
1215
+ Args:
1216
+ images: List of PIL images to group
1217
+ disable_grouping: If True, each image gets its own group (useful for debugging)
1218
+
1219
+ Returns:
1220
+ grouped_images: Dict mapping shape -> list of images with that shape
1221
+ grouped_indices: Dict mapping shape -> list of original indices
1222
+
1223
+ Example:
1224
+ >>> images = [img_400x300, img_800x600, img_400x300_2]
1225
+ >>> grouped, indices = group_images_by_shape(images)
1226
+ >>> grouped
1227
+ {(300, 400): [img_400x300, img_400x300_2], (600, 800): [img_800x600]}
1228
+ >>> indices
1229
+ {(300, 400): [0, 2], (600, 800): [1]}
1230
+ """
1231
+ if disable_grouping:
1232
+ # Each image in its own group
1233
+ grouped_images = {}
1234
+ grouped_indices = {}
1235
+ for i, img in enumerate(images):
1236
+ shape = (img.height, img.width)
1237
+ # Make each shape unique by adding index
1238
+ unique_shape = (img.height, img.width, i)
1239
+ grouped_images[unique_shape] = [img]
1240
+ grouped_indices[unique_shape] = [i]
1241
+ return grouped_images, grouped_indices
1242
+
1243
+ grouped_images: Dict[Tuple[int, int], List[Image.Image]] = {}
1244
+ grouped_indices: Dict[Tuple[int, int], List[int]] = {}
1245
+
1246
+ for i, img in enumerate(images):
1247
+ shape = (img.height, img.width)
1248
+ if shape not in grouped_images:
1249
+ grouped_images[shape] = []
1250
+ grouped_indices[shape] = []
1251
+ grouped_images[shape].append(img)
1252
+ grouped_indices[shape].append(i)
1253
+
1254
+ return grouped_images, grouped_indices
1255
+
1256
+
1257
+ class StoppingCriteria:
1258
+ def __init__(self, eos_token_ids: List[int], tokenizer=None):
1259
+
1260
+ if isinstance(eos_token_ids, int):
1261
+ self.eos_token_ids = [eos_token_ids]
1262
+ else:
1263
+ self.eos_token_ids = eos_token_ids
1264
+
1265
+ self.tokenizer = tokenizer
1266
+
1267
+ def add_eos_token_ids(self, new_eos_token_ids: Union[int, List[int]] = None):
1268
+ """
1269
+ Add new token IDs to the list of EOS token IDs.
1270
+
1271
+ Args:
1272
+ new_eos_token_ids: Integer, string, or list of integers/strings representing token IDs to add.
1273
+ If strings are provided, they will be converted to integers if possible.
1274
+ """
1275
+ if new_eos_token_ids is None:
1276
+ return
1277
+
1278
+ if self.tokenizer is None:
1279
+ raise ValueError("Processor is not provided")
1280
+
1281
+ if new_eos_token_ids is not None:
1282
+ if isinstance(new_eos_token_ids, str):
1283
+ new_eos_token_ids = [new_eos_token_ids]
1284
+ new_eos_token_ids = [
1285
+ self.tokenizer.encode(" " + token, add_special_tokens=False)[-1]
1286
+ for token in new_eos_token_ids
1287
+ ]
1288
+ self.eos_token_ids.extend(new_eos_token_ids)
1289
+
1290
+ def reset(self, eos_token_ids: List[int] = None):
1291
+ eos_token_ids = (
1292
+ eos_token_ids if eos_token_ids is not None else self.tokenizer.eos_token_ids
1293
+ )
1294
+
1295
+ if isinstance(eos_token_ids, int):
1296
+ eos_token_ids = [eos_token_ids]
1297
+
1298
+ if self.eos_token_ids != eos_token_ids:
1299
+ self.eos_token_ids = eos_token_ids
1300
+
1301
+ def __call__(self, input_ids: mx.array) -> bool:
1302
+ return input_ids in self.eos_token_ids
1303
+
1304
+
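A short sketch of the stopping logic used during generation: the criteria holds the EOS token ids and is called with each newly sampled token id (the ids below are arbitrary).

```python
from mlx_vlm.utils import StoppingCriteria

criteria = StoppingCriteria(eos_token_ids=2)  # a single id is promoted to a list
print(criteria(2))    # True  -> stop generating
print(criteria(17))   # False -> keep going

# When constructed with a tokenizer, extra textual stop tokens can be added,
# e.g. criteria.add_eos_token_ids(["<|im_end|>"]), which encodes them to ids first.
```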
1305
+ def print_array_report(t: mx.array, label: Optional[str]) -> dict:
1306
+ """
1307
+ Return a dictionary report of an MLX array similar to PyTorch's tensor representation.
1308
+ Args:
1309
+ t: MLX array to analyze
+ label: Optional label included in the printed report
1310
+ Returns:
1311
+ Dictionary containing shape, dtype, value representation, and statistics
1312
+ """
1313
+
1314
+ # Get basic statistics
1315
+ mean_val = mx.mean(t)
1316
+ std_val = mx.std(t)
1317
+ min_val = mx.min(t)
1318
+ max_val = mx.max(t)
1319
+
1320
+ report = {
1321
+ "shape": f"{tuple(t.shape)}",
1322
+ "dtype": str(t.dtype),
1323
+ "value": repr(t),
1324
+ "mean": f"array({mean_val}, dtype={t.dtype})",
1325
+ "std": f"array({std_val}, dtype={t.dtype})",
1326
+ "min": f"array({min_val}, dtype={t.dtype})",
1327
+ "max": f"array({max_val}, dtype={t.dtype})",
1328
+ "label": label if label else "array",
1329
+ }
1330
+
1331
+ # Print each field, handling 'value' specially
1332
+ print("{")
1333
+ for key, value in report.items():
1334
+ if key == "value":
1335
+ print(f" '{key}': {value},") # No quotes around value
1336
+ else:
1337
+ print(f" '{key}': {repr(value)},")
1338
+ print("}")
1339
+ return report
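And a quick usage sketch for the array report helper:

```python
import mlx.core as mx

from mlx_vlm.utils import print_array_report

x = mx.arange(12, dtype=mx.float32).reshape(3, 4)
report = print_array_report(x, label="demo")
print(report["shape"], report["dtype"])  # (3, 4) mlx.core.float32
```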