fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/molmo/processing_molmo.py
@@ -0,0 +1,763 @@
+"""
+MLX-based Molmo Processor.
+
+This module provides an MLX-native processor for Molmo models that doesn't
+require torch, torchvision, or tensorflow.
+"""
+
+import json
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import mlx.core as mx
+import numpy as np
+from PIL import Image
+from transformers import AutoTokenizer
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_processing_utils import BaseImageProcessor
+from transformers.image_utils import ImageInput, make_list_of_images
+from transformers.processing_utils import ProcessorMixin
+
+# CLIP normalization constants
+OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+def pad_to_bounding_box(
+    image: np.ndarray,
+    offset_height: int,
+    offset_width: int,
+    target_height: int,
+    target_width: int,
+    value: int = 0,
+) -> np.ndarray:
+    """Pad image to target bounding box."""
+    height, width = image.shape[:2]
+    after_padding_width = target_width - offset_width - width
+    after_padding_height = target_height - offset_height - height
+    if image.ndim == 3:
+        padding = [
+            [offset_height, after_padding_height],
+            [offset_width, after_padding_width],
+            [0, 0],
+        ]
+    else:
+        padding = [
+            [offset_height, after_padding_height],
+            [offset_width, after_padding_width],
+        ]
+    return np.pad(image, padding, constant_values=value)
+
+
+def normalize_image(
+    image: np.ndarray, offset: Tuple[float, ...], scale: Tuple[float, ...]
+) -> np.ndarray:
+    """Normalize image with mean and std."""
+    image = image.astype(np.float32)
+    image -= np.array(offset, dtype=np.float32)[None, None, :]
+    image /= np.array(scale, dtype=np.float32)[None, None, :]
+    return image
+
+
+def resize_and_pad(
+    image: np.ndarray,
+    desired_output_size: Tuple[int, int],
+    pad_value: float = 0,
+    normalize: bool = True,
+    image_mean: Tuple[float, ...] = OPENAI_CLIP_MEAN,
+    image_std: Tuple[float, ...] = OPENAI_CLIP_STD,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Resize and pad image using PIL (no torch/tensorflow)."""
+    desired_height, desired_width = desired_output_size
+    height, width = image.shape[:2]
+
+    # Calculate scale
+    image_scale_y = np.float32(desired_height) / np.float32(height)
+    image_scale_x = np.float32(desired_width) / np.float32(width)
+    image_scale = min(image_scale_x, image_scale_y)
+    scaled_height = int(np.float32(height) * image_scale)
+    scaled_width = int(np.float32(width) * image_scale)
+
+    # Use PIL for resizing (bilinear interpolation)
+    pil_image = Image.fromarray(
+        (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
+    )
+    pil_image = pil_image.resize(
+        (scaled_width, scaled_height), Image.Resampling.BILINEAR
+    )
+    image = np.array(pil_image).astype(np.float32) / 255.0
+    image = np.clip(image, 0.0, 1.0)
+
+    # Pad to desired size
+    top_pad = (desired_height - scaled_height) // 2
+    left_pad = (desired_width - scaled_width) // 2
+    padding = [
+        [top_pad, desired_height - scaled_height - top_pad],
+        [left_pad, desired_width - scaled_width - left_pad],
+        [0, 0],
+    ]
+    image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
+    image = np.pad(image, padding, constant_values=pad_value)
+
+    if normalize:
+        image = normalize_image(image, offset=image_mean, scale=image_std)
+
+    return image, image_mask
+
+
+def select_tiling(
+    h: int, w: int, patch_size: int, max_num_patches: int
+) -> Tuple[int, int]:
+    """Select best tiling for image."""
+    tilings = []
+    for i in range(1, max_num_patches + 1):
+        for j in range(1, max_num_patches + 1):
+            if i * j <= max_num_patches:
+                tilings.append((i, j))
+    tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
+    candidate_tilings = np.array(tilings, dtype=np.int32)
+    candidate_resolutions = candidate_tilings * patch_size
+
+    original_size = np.array([h, w], dtype=np.float32)
+    required_scale_d = candidate_resolutions.astype(np.float32) / original_size
+    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)
+
+    if np.all(required_scale < 1):
+        ix = np.argmax(required_scale)
+    else:
+        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
+        ix = np.argmin(required_scale)
+
+    return tuple(candidate_tilings[ix])
+
+
+def rearrange_patches(
+    patches: np.ndarray, dh: int, dw: int, h: int, w: int
+) -> np.ndarray:
+    """Rearrange patches: 'p (h dh) (w dw) c -> p (h w) (dh dw c)'"""
+    p, H, W, c = patches.shape
+    patches = patches.reshape(p, h, dh, w, dw, c)
+    patches = patches.transpose(0, 1, 3, 2, 4, 5)
+    patches = patches.reshape(p, h * w, dh * dw * c)
+    return patches
+
+
+def rearrange_mask(mask: np.ndarray, dh: int, dw: int, h: int, w: int) -> np.ndarray:
+    """Rearrange mask: 'p (h dh) (w dw) -> p (h w) (dh dw)'"""
+    p, H, W = mask.shape
+    mask = mask.reshape(p, h, dh, w, dw)
+    mask = mask.transpose(0, 1, 3, 2, 4)
+    mask = mask.reshape(p, h * w, dh * dw)
+    return mask
+
+
+def rearrange_global(image: np.ndarray, dh: int, dw: int, h: int, w: int) -> np.ndarray:
+    """Rearrange global image: '(h dh) (w dw) c -> (h w) (dh dw c)'"""
+    H, W, c = image.shape
+    image = image.reshape(h, dh, w, dw, c)
+    image = image.transpose(0, 2, 1, 3, 4)
+    image = image.reshape(h * w, dh * dw * c)
+    return image
+
+
+class MolmoImageProcessor(BaseImageProcessor):
+    """MLX-based image processor for Molmo."""
+
+    model_input_names = ["images", "image_input_idx", "image_masks"]
+
+    def __init__(
+        self,
+        max_crops: int = 12,
+        overlap_margins: List[int] = None,
+        base_image_input_size: List[int] = None,
+        image_token_length_w: int = 12,
+        image_token_length_h: int = 12,
+        image_patch_size: int = 14,
+        image_padding_mask: bool = True,
+        do_normalize: bool = True,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.max_crops = max_crops
+        self.overlap_margins = overlap_margins or [4, 4]
+        self.base_image_input_size = base_image_input_size or [336, 336]
+        self.image_token_length_w = image_token_length_w
+        self.image_token_length_h = image_token_length_h
+        self.image_patch_size = image_patch_size
+        self.image_padding_mask = image_padding_mask
+        self.do_normalize = do_normalize
+        self.image_mean = tuple(image_mean) if image_mean else OPENAI_CLIP_MEAN
+        self.image_std = tuple(image_std) if image_std else OPENAI_CLIP_STD
+
+    def image_to_patches_and_tokens(
+        self,
+        image: np.ndarray,
+        image_patch_token_id: int,
+        image_col_token_id: int,
+        image_start_token_id: int,
+        image_end_token_id: int,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+        """Convert image to patches and tokens."""
+        base_image_input_size = self.base_image_input_size
+        base_image_input_d = self.image_patch_size
+        tokens_per_image = self.image_token_length_w * self.image_token_length_h
+        image_base_patch_w = base_image_input_size[1] // base_image_input_d
+        image_base_patch_h = base_image_input_size[0] // base_image_input_d
+
+        original_image_h, original_image_w = image.shape[:2]
+        crop_size = base_image_input_size[0]
+
+        left_margin, right_margin = self.overlap_margins
+        total_margin_pixels = base_image_input_d * (right_margin + left_margin)
+        crop_patches = base_image_input_size[0] // base_image_input_d
+        crop_window_patches = crop_patches - (right_margin + left_margin)
+        crop_window_size = crop_window_patches * base_image_input_d
+
+        tiling = select_tiling(
+            original_image_h - total_margin_pixels,
+            original_image_w - total_margin_pixels,
+            crop_window_size,
+            self.max_crops,
+        )
+
+        src, img_mask = resize_and_pad(
+            image,
+            [
+                tiling[0] * crop_window_size + total_margin_pixels,
+                tiling[1] * crop_window_size + total_margin_pixels,
+            ],
+            image_mean=self.image_mean,
+            image_std=self.image_std,
+        )
+
+        patches_arr = []
+        mask_arr = []
+        patch_ordering_arr = []
+
+        on = 0
+        for i in range(tiling[0]):
+            y0 = i * crop_window_size
+            crop_y0 = 0 if i == 0 else left_margin // 2
+
+            crop_h = image_base_patch_h - (right_margin + left_margin)
+            if i == 0:
+                crop_h += left_margin
+            if i == (tiling[0] - 1):
+                crop_h += right_margin
+
+            for j in range(tiling[1]):
+                x0 = j * crop_window_size
+                crop_x0 = 0 if j == 0 else left_margin // 2
+
+                crop_w = image_base_patch_w - (right_margin + left_margin)
+                if j == 0:
+                    crop_w += left_margin
+                if j == (tiling[1] - 1):
+                    crop_w += right_margin
+
+                pooled_w = (crop_w + 1) // 2
+                pooled_h = (crop_h + 1) // 2
+
+                ordering = np.reshape(
+                    np.arange(on, on + pooled_h * pooled_w, dtype=np.int32),
+                    (pooled_h, pooled_w, 1),
+                )
+                patch_ordering_arr.append(
+                    pad_to_bounding_box(
+                        ordering,
+                        crop_y0,
+                        crop_x0,
+                        self.image_token_length_h,
+                        self.image_token_length_w,
+                        value=-1,
+                    )[:, :, 0]
+                )
+                patches_arr.append(src[y0 : y0 + crop_size, x0 : x0 + crop_size])
+                mask_arr.append(img_mask[y0 : y0 + crop_size, x0 : x0 + crop_size])
+
+                on += pooled_h * pooled_w
+
+        patches = np.stack(patches_arr)
+        patch_ordering = np.stack(patch_ordering_arr)
+        img_mask = np.stack(mask_arr)
+
+        # Rearrange patches
+        patches = rearrange_patches(
+            patches,
+            base_image_input_d,
+            base_image_input_d,
+            image_base_patch_h,
+            image_base_patch_w,
+        )
+        img_mask = rearrange_mask(
+            img_mask,
+            base_image_input_d,
+            base_image_input_d,
+            image_base_patch_h,
+            image_base_patch_w,
+        )
+
+        img_mask = img_mask.astype(np.float32).mean(axis=-1)
+        patch_ordering = np.reshape(patch_ordering, [-1])
+        valid = patch_ordering >= 0
+
+        # Transpose order
+        patch_ordering_rh = np.reshape(
+            patch_ordering,
+            [
+                tiling[0],
+                tiling[1],
+                self.image_token_length_h,
+                self.image_token_length_w,
+            ],
+        )
+        patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
+        patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
+
+        patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
+
+        # Build output tokens
+        h = tiling[0] * crop_window_patches + (right_margin + left_margin)
+        w = tiling[1] * crop_window_patches + (right_margin + left_margin)
+        per_row = np.full(((w + 1) // 2,), image_patch_token_id)
+        per_row = np.concatenate([per_row, [image_col_token_id]], 0)
+
+        joint = np.tile(per_row, [(h + 1) // 2])
+        joint = [[image_start_token_id], joint, [image_end_token_id]]
+
+        # Global image
+        resized, _ = resize_and_pad(
+            image,
+            base_image_input_size,
+            image_mean=self.image_mean,
+            image_std=self.image_std,
+        )
+        resized = rearrange_global(
+            resized,
+            base_image_input_d,
+            base_image_input_d,
+            image_base_patch_h,
+            image_base_patch_w,
+        )
+        patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
+
+        patch_ordering = np.where(
+            patch_ordering >= 0, patch_ordering + tokens_per_image, -1
+        )
+        patch_ordering = np.concatenate(
+            [np.arange(0, tokens_per_image), patch_ordering], 0
+        )
+
+        per_row = np.full((self.image_token_length_w,), image_patch_token_id)
+        per_row = np.concatenate([per_row, [image_col_token_id]], 0)
+        extra_tokens = np.tile(per_row, [self.image_token_length_h])
+        joint = [
+            [image_start_token_id],
+            extra_tokens,
+            [image_end_token_id],
+        ] + joint
+
+        joint = np.concatenate(joint, 0)
+        img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
+
+        return patches, joint, patch_ordering, img_mask
+
+    def build_image_input_idx(
+        self,
+        image_tokens: np.ndarray,
+        patch_order: np.ndarray,
+        image_patch_token_id: int,
+    ) -> np.ndarray:
+        """Build image input indices."""
+        tokens_per_image = self.image_token_length_w * self.image_token_length_h
+
+        image_input_idx = image_tokens == image_patch_token_id
+        image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
+
+        if patch_order is not None:
+            n_tokens = image_input_idx.shape[0]
+            patch_order = np.reshape(patch_order, [-1])
+
+            valid = patch_order >= 0
+            n_valid_patches = valid.sum()
+
+            sorted_patch_ixs = np.zeros([n_tokens], np.int32)
+            sorted_patch_ixs[patch_order[valid]] = np.arange(
+                n_valid_patches, dtype=np.int32
+            )
+
+            sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
+            sorted_patch_ixs_ex[valid] = sorted_patch_ixs
+
+            valid_int = (sorted_patch_ixs_ex >= 0).astype(np.int32)
+            image_input_idx = image_input_idx[sorted_patch_ixs_ex * valid_int]
+            image_input_idx = image_input_idx * valid_int - 100 * (1 - valid_int)
+            image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
+
+        return image_input_idx
+
+    def preprocess(
+        self,
+        image: np.ndarray,
+        image_patch_token_id: int,
+        image_col_token_id: int,
+        image_start_token_id: int,
+        image_end_token_id: int,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+        """Preprocess a single image."""
+        crops, image_tokens, patch_ordering, img_mask = (
+            self.image_to_patches_and_tokens(
+                image,
+                image_patch_token_id,
+                image_col_token_id,
+                image_start_token_id,
+                image_end_token_id,
+            )
+        )
+        patch_idx = self.build_image_input_idx(
+            image_tokens, patch_ordering, image_patch_token_id
+        )
+        return crops, image_tokens, patch_idx, img_mask
+
+
+class MolmoProcessor(ProcessorMixin):
+    """MLX-based processor for Molmo."""
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "MolmoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        chat_template=None,
+        **kwargs,
+    ):
+        if image_processor is None:
+            image_processor = MolmoImageProcessor()
+        # Molmo uses these specific token names
+        self.image_patch_token = "<im_patch>"
+        self.image_col_token = "<im_col>"
+        self.image_start_token = "<im_start>"
+        self.image_end_token = "<im_end>"
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[str, List[str]] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        """Process images and text for Molmo."""
+        if images is None and text is None:
+            raise ValueError("You must provide either images or text.")
+
+        # Get token IDs
+        image_patch_token_id = self.tokenizer.convert_tokens_to_ids(
+            self.image_patch_token
+        )
+        image_col_token_id = self.tokenizer.convert_tokens_to_ids(self.image_col_token)
+        image_start_token_id = self.tokenizer.convert_tokens_to_ids(
+            self.image_start_token
+        )
+        image_end_token_id = self.tokenizer.convert_tokens_to_ids(self.image_end_token)
+
+        # Validate token IDs
+        if image_patch_token_id is None:
+            raise ValueError(
+                f"Token '{self.image_patch_token}' not found in tokenizer vocabulary"
+            )
+        if image_col_token_id is None:
+            raise ValueError(
+                f"Token '{self.image_col_token}' not found in tokenizer vocabulary"
+            )
+        if image_start_token_id is None:
+            raise ValueError(
+                f"Token '{self.image_start_token}' not found in tokenizer vocabulary"
+            )
+        if image_end_token_id is None:
+            raise ValueError(
+                f"Token '{self.image_end_token}' not found in tokenizer vocabulary"
+            )
+
+        # Process images
+        if images is not None:
+            images = make_list_of_images(images)
+            # Convert PIL images to numpy arrays
+            np_images = []
+            for img in images:
+                if isinstance(img, Image.Image):
+                    img = img.convert("RGB")
+                    np_images.append(np.array(img).astype(np.float32) / 255.0)
+                elif isinstance(img, np.ndarray):
+                    if img.max() > 1.0:
+                        img = img.astype(np.float32) / 255.0
+                    np_images.append(img)
+                else:
+                    np_images.append(np.array(img).astype(np.float32) / 255.0)
+            images = np_images
+
+        # Tokenize text
+        if text is not None:
+            if isinstance(text, str):
+                text = [text]
+            tokens_list = [self.tokenizer.encode(t) for t in text]
+        else:
+            tokens_list = [[]]
+
+        # Process each image with text
+        if images is not None and len(images) > 0:
+            all_crops = []
+            all_image_idx = []
+            all_masks = []
+            all_input_ids = []
+
+            for i, (img, tokens) in enumerate(
+                zip(
+                    images,
+                    (
+                        tokens_list
+                        if len(tokens_list) == len(images)
+                        else [tokens_list[0]] * len(images)
+                    ),
+                )
+            ):
+                crops, image_tokens, patch_idx, img_mask = (
+                    self.image_processor.preprocess(
+                        img,
+                        image_patch_token_id,
+                        image_col_token_id,
+                        image_start_token_id,
+                        image_end_token_id,
+                    )
+                )
+
+                # Combine image tokens with text tokens
+                combined_tokens = np.concatenate([image_tokens, np.array(tokens)])
+
+                # Adjust patch_idx for the position in combined tokens
+                all_crops.append(crops)
+                all_image_idx.append(patch_idx)
+                all_masks.append(img_mask)
+                all_input_ids.append(combined_tokens)
+
+            # Stack results
+            pixel_values = mx.array(
+                np.concatenate(all_crops, axis=0).astype(np.float32)
+            )
+            image_input_idx = mx.array(
+                np.concatenate(all_image_idx, axis=0).astype(np.int32)
+            )
+            image_masks = mx.array(np.concatenate(all_masks, axis=0).astype(np.float32))
+
+            # Pad input_ids to same length
+            max_len = max(len(ids) for ids in all_input_ids)
+            pad_token_id = self.tokenizer.pad_token_id
+            if pad_token_id is None:
+                pad_token_id = self.tokenizer.eos_token_id or 0
+            padded_ids = []
+            for ids in all_input_ids:
+                pad_len = max_len - len(ids)
+                if pad_len > 0:
+                    ids = np.pad(ids, (0, pad_len), constant_values=pad_token_id)
+                padded_ids.append(ids.astype(np.int32))
+
+            input_ids = mx.array(np.stack(padded_ids).astype(np.int32))
+
+            return BatchFeature(
+                data={
+                    "input_ids": input_ids,
+                    "pixel_values": pixel_values,
+                    "image_input_idx": image_input_idx,
+                    "image_masks": image_masks,
+                }
+            )
+        else:
+            # Text only
+            max_len = max(len(t) for t in tokens_list)
+            pad_token_id = self.tokenizer.pad_token_id
+            if pad_token_id is None:
+                pad_token_id = self.tokenizer.eos_token_id or 0
+            padded = []
+            for t in tokens_list:
+                pad_len = max_len - len(t)
+                if pad_len > 0:
+                    t = t + [pad_token_id] * pad_len
+                padded.append(t)
+
+            return BatchFeature(data={"input_ids": mx.array(padded, dtype=mx.int32)})
+
+    def batch_decode(self, *args, **kwargs):
+        """Forward to tokenizer's batch_decode."""
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """Forward to tokenizer's decode."""
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def apply_chat_template(
+        self,
+        conversation,
+        chat_template=None,
+        add_generation_prompt=False,
+        tokenize=False,
+        **kwargs,
+    ):
+        """Apply chat template."""
+        if chat_template is None:
+            chat_template = self.chat_template
+        if chat_template is None:
+            chat_template = getattr(self.tokenizer, "chat_template", None)
+        if chat_template is None:
+            # Default Molmo chat template
+            chat_template = (
+                "{% for message in messages %}"
+                "{% if message['role'] == 'user' %}"
+                "User: {{ message['content'] }}\n"
+                "{% elif message['role'] == 'assistant' %}"
+                "Assistant: {{ message['content'] }}\n"
+                "{% endif %}"
+                "{% endfor %}"
+                "{% if add_generation_prompt %}Assistant: {% endif %}"
+            )
+
+        from jinja2 import Environment
+
+        # Use Environment with loopcontrols extension to support {% continue %} and {% break %}
+        env = Environment(extensions=["jinja2.ext.loopcontrols"])
+        template = env.from_string(chat_template)
+        rendered = template.render(
+            messages=conversation,
+            add_generation_prompt=add_generation_prompt,
+            **kwargs,
+        )
+
+        if tokenize:
+            return self.tokenizer.encode(rendered)
+        return rendered
+
+    @property
+    def model_input_names(self):
+        """Get model input names."""
+        return ["input_ids", "pixel_values", "image_input_idx", "image_masks"]
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """Load processor from pretrained model."""
+        from huggingface_hub import hf_hub_download
+
+        kwargs.pop("trust_remote_code", None)
+
+        model_path = Path(pretrained_model_name_or_path)
+        is_local = model_path.exists() and model_path.is_dir()
+
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            str(model_path) if is_local else pretrained_model_name_or_path,
+            trust_remote_code=True,
+            local_files_only=is_local,
+        )
+
+        # Load image processor config
+        image_processor_config = {}
+        try:
+            if is_local:
+                config_path = model_path / "preprocessor_config.json"
+            else:
+                config_path = Path(
+                    hf_hub_download(
+                        pretrained_model_name_or_path, "preprocessor_config.json"
+                    )
+                )
+            if config_path.exists():
+                with open(config_path, "r") as f:
+                    config = json.load(f)
+                for key in [
+                    "max_crops",
+                    "overlap_margins",
+                    "base_image_input_size",
+                    "image_token_length_w",
+                    "image_token_length_h",
+                    "image_patch_size",
+                    "image_padding_mask",
+                    "do_normalize",
+                    "image_mean",
+                    "image_std",
+                ]:
+                    if key in config:
+                        image_processor_config[key] = config[key]
+        except Exception:
+            pass
+
+        image_processor = MolmoImageProcessor(**image_processor_config)
+
+        # Load chat template
+        chat_template = getattr(tokenizer, "chat_template", None)
+        if chat_template is None:
+            try:
+                if is_local:
+                    jinja_path = model_path / "chat_template.jinja"
+                else:
+                    jinja_path = Path(
+                        hf_hub_download(
+                            pretrained_model_name_or_path, "chat_template.jinja"
+                        )
+                    )
+                if jinja_path.exists():
+                    chat_template = jinja_path.read_text(encoding="utf-8")
+                    tokenizer.chat_template = chat_template
+            except Exception:
+                pass
+
+        return cls(
+            image_processor=image_processor,
+            tokenizer=tokenizer,
+            chat_template=chat_template,
+        )
+
+
+# Patch AutoProcessor for Molmo models
+from transformers import AutoProcessor
+
+_original_auto_processor_from_pretrained = AutoProcessor.from_pretrained
+
+
+@classmethod
+def _patched_auto_processor_from_pretrained(
+    cls, pretrained_model_name_or_path, **kwargs
+):
+    """Patched from_pretrained that returns MolmoProcessor for molmo models."""
+    from huggingface_hub import hf_hub_download
+
+    model_path = Path(pretrained_model_name_or_path)
+    is_local = model_path.exists() and model_path.is_dir()
+
+    # Check if this is a molmo model
+    is_molmo = False
+    try:
+        if is_local:
+            config_path = model_path / "config.json"
+        else:
+            config_path = Path(
+                hf_hub_download(pretrained_model_name_or_path, "config.json")
+            )
+        with open(config_path, "r") as f:
+            config = json.load(f)
+        model_type = config.get("model_type", "").lower()

+        is_molmo = model_type == "molmo"
+    except Exception:
+        pass
+
+    if is_molmo:
+        return MolmoProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+    return _original_auto_processor_from_pretrained.__func__(
+        cls, pretrained_model_name_or_path, **kwargs
+    )
+
+
+AutoProcessor.from_pretrained = _patched_auto_processor_from_pretrained
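For orientation only, here is a minimal usage sketch of the MolmoProcessor defined in the hunk above; it is not part of the package contents. The checkpoint id and image path below are placeholders (any Molmo checkpoint that ships the expected preprocessor_config.json and <im_patch>/<im_col>/<im_start>/<im_end> tokens is assumed to work), and note that merely importing the module also patches transformers.AutoProcessor.from_pretrained to return this processor when config.json reports model_type == "molmo".

# Illustrative sketch only; checkpoint id and image path are assumptions, not shipped with the wheel.
from PIL import Image

from mlx_vlm.models.molmo.processing_molmo import MolmoProcessor

# Placeholder checkpoint; substitute any Molmo model repo or local directory.
processor = MolmoProcessor.from_pretrained("allenai/Molmo-7B-D-0924")

prompt = processor.apply_chat_template(
    [{"role": "user", "content": "Describe this image."}],
    add_generation_prompt=True,
)

image = Image.open("example.jpg")  # placeholder path
batch = processor(images=[image], text=prompt)

# BatchFeature holding MLX arrays: input_ids, pixel_values, image_input_idx, image_masks
print({k: v.shape for k, v in batch.items()})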