fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/moondream2/moondream2.py
@@ -0,0 +1,522 @@
+ import re
+ from functools import partial, reduce
+ from typing import List, Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from PIL import Image
+ from transformers.image_transforms import (
+     convert_to_rgb,
+     resize,
+     to_channel_dimension_format,
+ )
+ from transformers.image_utils import PILImageResampling, to_numpy_array
+
+ from ..base import BaseImageProcessor, InputEmbeddingsFeatures
+ from .config import ModelConfig, VisionConfig
+ from .image_crops import adaptive_avg_pool2d, overlap_crop_image, reconstruct_from_crops
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ class ImageProcessor(BaseImageProcessor):
+     """Moondream image processor with multi-crop support."""
+
+     def __init__(self, max_crops: int = 12, overlap_margin: int = 4):
+         super().__init__(
+             image_mean=(0.5, 0.5, 0.5),
+             image_std=(0.5, 0.5, 0.5),
+             size=(378, 378),
+             resample=PILImageResampling.BICUBIC,
+             rescale_factor=1 / 255,
+         )
+         self.max_crops = max_crops
+         self.overlap_margin = overlap_margin
+
+     def preprocess(
+         self, images
+     ) -> Tuple[List[np.ndarray], List[int], List[Tuple[int, int]]]:
+         """
+         Preprocess images with multi-crop support.
+
+         Args:
+             images: Single PIL Image or list of PIL Images
+
+         Returns:
+             crops_list: List of [n_crops, C, H, W] arrays per image
+             crop_counts: Number of crops per image
+             tilings: (h_tiles, w_tiles) per image
+         """
+         if isinstance(images, Image.Image):
+             images = [images]
+         else:
+             assert isinstance(images, list)
+
+         crops_list = []
+         crop_counts = []
+         tilings = []
+
+         for image in images:
+             # Convert to RGB numpy array
+             image = convert_to_rgb(image)
+             image_np = to_numpy_array(image)
+
+             # Get multi-crop decomposition
+             crops, tiling = overlap_crop_image(
+                 image_np,
+                 max_crops=self.max_crops,
+                 overlap_margin=self.overlap_margin,
+                 base_size=self.size,
+                 patch_size=14,
+             )
+             # crops is [n_crops, H, W, C] in range [0, 255]
+
+             # Normalize each crop: (pixel/255 - 0.5) / 0.5 = [-1, 1]
+             crops = crops.astype(np.float32) * self.rescale_factor  # [0, 1]
+             crops = (crops - 0.5) / 0.5  # [-1, 1]
+
+             # Convert to channel-first format: [n_crops, H, W, C] -> [n_crops, C, H, W]
+             crops = np.transpose(crops, (0, 3, 1, 2))
+
+             crops_list.append(crops)
+             crop_counts.append(crops.shape[0])
+             tilings.append(tiling)
+
+         return crops_list, crop_counts, tilings
+
+
+ class VisionProjection(nn.Module):
+     """
+     2-layer MLP projector from vision to language space.
+
+     Projects concatenated [global, reconstructed] features (2304D) to language
+     model dimension (2048D). The input is the concatenation of:
+     - Global features: [B, 729, 1152] from full image
+     - Reconstructed features: [B, 729, 1152] pooled from local crops
+
+     Reference: moondream2/vision.py:77-89
+     """
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         # Input is concatenation of global and reconstructed: 1152 * 2 = 2304
+         vision_dim = config.vision_config.hidden_size * 2  # 2304
+         inner_dim = config.proj_inner_dim  # 8192
+         output_dim = config.text_config.hidden_size  # 2048
+
+         self.fc1 = nn.Linear(vision_dim, inner_dim, bias=True)
+         self.fc2 = nn.Linear(inner_dim, output_dim, bias=True)
+         self.activation = nn.GELU(approx="precise")
+
+     def __call__(
+         self, global_features: mx.array, reconstructed_features: mx.array
+     ) -> mx.array:
+         """
+         Project concatenated vision features to language model dimension.
+
+         Args:
+             global_features: [B, 729, 1152] features from global crop
+             reconstructed_features: [B, 729, 1152] features reconstructed from local crops
+
+         Returns:
+             [B, 729, 2048] projected features
+         """
+         # Concatenate along feature dimension: [B, 729, 2304]
+         x = mx.concatenate([global_features, reconstructed_features], axis=-1)
+         x = self.activation(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class Model(nn.Module):
+     """Moondream 2 model for visual question answering."""
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.config = config
+
+         self.vision_encoder = VisionModel(config.vision_config)
+         self.vision_projection = VisionProjection(config)
+         self.language_model = LanguageModel(config.text_config)
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         crop_counts: Optional[List[int]] = None,
+         tilings: Optional[List[Tuple[int, int]]] = None,
+         **kwargs,
+     ):
+         """
+         Get input embeddings with multi-crop image features.
+
+         Full pipeline:
+         1. Encode ALL crops through vision_encoder: [total_crops, 729, 1152]
+         2. For each image:
+            a. global_features = features[0]  # [729, 1152]
+            b. local_features = features[1:].reshape(n_local, 27, 27, 1152)
+            c. reconstructed = reconstruct_from_crops(local_features, tiling)
+            d. reconstructed = adaptive_avg_pool2d(reconstructed, (27, 27))
+            e. reconstructed = reconstructed.reshape(729, 1152)
+            f. projected = vision_projection(global, reconstructed)  # [729, 2048]
+         3. Insert projected features into embeddings
+
+         Args:
+             input_ids: Token IDs [B, seq_len]
+             pixel_values: Concatenated crops [total_crops, C, H, W]
+             crop_counts: Number of crops per image (list of ints)
+             tilings: (h_tiles, w_tiles) per image (list of tuples)
+         """
+         # #region agent log
+         import json
+         log_file = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"
+         def log_embed(location, message, data, hypothesis_id):
+             try:
+                 with open(log_file, "a") as f:
+                     f.write(json.dumps({"sessionId": "debug-session", "runId": "inference", "hypothesisId": hypothesis_id, "location": location, "message": message, "data": data, "timestamp": __import__("time").time_ns() // 1000000}) + "\n")
+             except: pass
+
+         log_embed("moondream2.py:get_input_embeddings_entry", "Entry to get_input_embeddings", {
+             "input_ids_shape": str(input_ids.shape) if input_ids is not None else None,
+             "pixel_values_shape": str(pixel_values.shape) if pixel_values is not None else None,
+             "crop_counts": crop_counts,
+             "tilings": tilings,
+             "input_ids_sample": input_ids[0, :20].tolist() if input_ids is not None else None
+         }, "H7,H9")
+         # #endregion
+
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         # Get text embeddings
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         # #region agent log
+         log_embed("moondream2.py:text_embeddings", "Text embeddings from embed_tokens", {
+             "shape": str(inputs_embeds.shape),
+             "dtype": str(inputs_embeds.dtype),
+             "mean": float(mx.mean(inputs_embeds)),
+             "std": float(mx.std(inputs_embeds)),
+             "min": float(mx.min(inputs_embeds)),
+             "max": float(mx.max(inputs_embeds))
+         }, "H9")
+         # #endregion
+
+         # Encode ALL crops through vision encoder at once
+         # pixel_values is [total_crops, C, H, W]
+         all_features = self.vision_encoder(pixel_values)  # [total_crops, 729, 1152]
+
+         # #region agent log
+         log_embed("moondream2.py:vision_features", "Vision encoder output", {
+             "shape": str(all_features.shape),
+             "dtype": str(all_features.dtype),
+             "mean": float(mx.mean(all_features)),
+             "std": float(mx.std(all_features)),
+             "min": float(mx.min(all_features)),
+             "max": float(mx.max(all_features))
+         }, "H6,H8")
+         # #endregion
+
+         # Process each image's crops
+         batch_size = len(crop_counts) if crop_counts is not None else 1
+         projected_features_list = []
+
+         crop_offset = 0
+         for b in range(batch_size):
+             n_crops = crop_counts[b] if crop_counts is not None else all_features.shape[0]
+             tiling = tilings[b] if tilings is not None else (1, 1)
+
+             # Extract features for this image
+             img_features = all_features[crop_offset : crop_offset + n_crops]  # [n_crops, 729, 1152]
+             crop_offset += n_crops
+
+             # Global features from first crop
+             global_features = img_features[0]  # [729, 1152]
+
+             # Local crop features
+             n_local = n_crops - 1
+             if n_local > 0:
+                 local_features = img_features[1:]  # [n_local, 729, 1152]
+
+                 # Reshape to spatial grid: [n_local, 729, 1152] -> [n_local, 27, 27, 1152]
+                 local_features = local_features.reshape(n_local, 27, 27, -1)
+
+                 # Reconstruct unified feature map from local crops
+                 reconstructed = reconstruct_from_crops(
+                     local_features, tiling, overlap_margin=self.config.vision_config.overlap_margin
+                 )  # [H, W, 1152]
+
+                 # Pool back to 27x27 to match global features
+                 reconstructed = adaptive_avg_pool2d(reconstructed, (27, 27))  # [27, 27, 1152]
+
+                 # Flatten to [729, 1152]
+                 reconstructed_flat = reconstructed.reshape(729, -1)
+             else:
+                 # No local crops, duplicate global for reconstruction
+                 reconstructed_flat = global_features
+
+             # Add batch dimension and project
+             global_batch = global_features[None, :, :]  # [1, 729, 1152]
+             reconstructed_batch = reconstructed_flat[None, :, :]  # [1, 729, 1152]
+
+             # #region agent log
+             log_embed("moondream2.py:before_projection", "Features before projection", {
+                 "global_shape": str(global_batch.shape),
+                 "global_mean": float(mx.mean(global_batch)),
+                 "global_std": float(mx.std(global_batch)),
+                 "reconstructed_shape": str(reconstructed_batch.shape),
+                 "reconstructed_mean": float(mx.mean(reconstructed_batch)),
+                 "reconstructed_std": float(mx.std(reconstructed_batch)),
+                 "n_local_crops": n_local
+             }, "H6,H8")
+             # #endregion
+
+             # Project concatenated features: [1, 729, 2304] -> [1, 729, 2048]
+             projected = self.vision_projection(global_batch, reconstructed_batch)
+
+             # #region agent log
+             log_embed("moondream2.py:after_projection", "Projected vision features", {
+                 "shape": str(projected.shape),
+                 "dtype": str(projected.dtype),
+                 "mean": float(mx.mean(projected)),
+                 "std": float(mx.std(projected)),
+                 "min": float(mx.min(projected)),
+                 "max": float(mx.max(projected))
+             }, "H6")
+             # #endregion
+
+             projected_features_list.append(projected)
+
+         # Concatenate all projected features
+         image_features = mx.concatenate(projected_features_list, axis=0)  # [B, 729, 2048]
+
+         # #region agent log
+         log_embed("moondream2.py:concatenated_image_features", "Concatenated image features", {
+             "shape": str(image_features.shape),
+             "dtype": str(image_features.dtype),
+             "mean": float(mx.mean(image_features)),
+             "std": float(mx.std(image_features)),
+             "min": float(mx.min(image_features)),
+             "max": float(mx.max(image_features))
+         }, "H6")
+         # #endregion
+
+         # Replace 729-token image placeholder in input_ids with vision features
+         # prepare_inputs() creates input_ids as: [BOS, <img_token>*729, <text_tokens>]
+         patch_count = image_features.shape[1]  # expected 729
+         if inputs_embeds.shape[1] >= 1 + patch_count:
+             # #region agent log
+             log_embed("moondream2.py:before_replacement", "Before replacing image tokens", {
+                 "inputs_embeds_shape": str(inputs_embeds.shape),
+                 "patch_count": patch_count,
+                 "replacement_range": f"[1:{1 + patch_count}]",
+                 "image_features_dtype": str(image_features.dtype),
+                 "inputs_embeds_dtype": str(inputs_embeds.dtype),
+                 "embeds_before_mean": float(mx.mean(inputs_embeds[:, 1 : 1 + patch_count, :])),
+                 "embeds_before_std": float(mx.std(inputs_embeds[:, 1 : 1 + patch_count, :]))
+             }, "H7,H9")
+             # #endregion
+
+             # Replace positions [1 : 1+patch_count] (right after BOS)
+             inputs_embeds[:, 1 : 1 + patch_count, :] = image_features.astype(
+                 inputs_embeds.dtype
+             )
+
+             # #region agent log
+             log_embed("moondream2.py:after_replacement", "After replacing image tokens", {
+                 "embeds_after_mean": float(mx.mean(inputs_embeds[:, 1 : 1 + patch_count, :])),
+                 "embeds_after_std": float(mx.std(inputs_embeds[:, 1 : 1 + patch_count, :])),
+                 "text_tokens_mean": float(mx.mean(inputs_embeds[:, 1 + patch_count:, :])) if inputs_embeds.shape[1] > 1 + patch_count else None,
+                 "text_tokens_std": float(mx.std(inputs_embeds[:, 1 + patch_count:, :])) if inputs_embeds.shape[1] > 1 + patch_count else None
+             }, "H7,H9")
+             # #endregion
+
+             final_embeddings = inputs_embeds
+         else:
+             # Fallback: original behavior (prepend image embeddings)
+             batch_size = inputs_embeds.shape[0]
+             final_embeddings = []
+             for b in range(batch_size):
+                 bos_embed = inputs_embeds[b : b + 1, :1, :]
+                 text_embed = inputs_embeds[b : b + 1, 1:, :]
+                 img_embed = image_features[b : b + 1]
+                 combined = mx.concatenate([bos_embed, img_embed, text_embed], axis=1)
+                 final_embeddings.append(combined)
+             final_embeddings = mx.concatenate(final_embeddings, axis=0)
+
+         # #region agent log
+         log_embed("moondream2.py:final_embeddings", "Final embeddings output", {
+             "shape": str(final_embeddings.shape),
+             "dtype": str(final_embeddings.dtype),
+             "mean": float(mx.mean(final_embeddings)),
+             "std": float(mx.std(final_embeddings)),
+             "min": float(mx.min(final_embeddings)),
+             "max": float(mx.max(final_embeddings)),
+             "has_nan": bool(mx.any(mx.isnan(final_embeddings))),
+             "has_inf": bool(mx.any(mx.isinf(final_embeddings)))
+         }, "H9")
+         # #endregion
+
+         return InputEmbeddingsFeatures(inputs_embeds=final_embeddings)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Tuple[mx.array, mx.array]] = None,
+         crop_counts: Optional[List[int]] = None,
+         tilings: Optional[List[Tuple[int, int]]] = None,
+         **kwargs,
+     ):
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, crop_counts=crop_counts, tilings=tilings, **kwargs
+         )
+
+         logits = self.language_model(
+             inputs=input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+             mask=mask,
+         )
+         return logits
+
+     def sanitize(self, weights):
+         """
+         Map HuggingFace weights to MLX model structure.
+
+         HF Weight Structure (from moondream2/vision.py + text.py):
+         - model.vision.patch_emb.* -> vision_encoder.patch_emb.*
+         - model.vision.pos_emb -> vision_encoder.position_embedding
+         - model.vision.blocks.{i}.ln1.* -> vision_encoder.encoder.layers.{i}.ln1.*
+         - model.vision.blocks.{i}.attn.* -> vision_encoder.encoder.layers.{i}.attn.*
+         - model.vision.blocks.{i}.ln2.* -> vision_encoder.encoder.layers.{i}.ln2.*
+         - model.vision.blocks.{i}.mlp.* -> vision_encoder.encoder.layers.{i}.mlp.*
+         - model.vision.post_ln.* -> vision_encoder.post_layernorm.*
+         - model.vision.proj_mlp.* -> vision_projection.*
+         - model.text.wte -> language_model.model.embed_tokens.weight
+         - model.text.blocks.{i}.ln.* -> language_model.model.layers.{i}.input_layernorm.*
+         - model.text.blocks.{i}.attn.qkv.* -> language_model.model.layers.{i}.self_attn.qkv_proj.*
+         - model.text.blocks.{i}.attn.proj.* -> language_model.model.layers.{i}.self_attn.o_proj.*
+         - model.text.blocks.{i}.mlp.* -> language_model.model.layers.{i}.mlp.*
+         - model.text.post_ln.* -> language_model.model.norm.*
+         - model.text.lm_head.* -> language_model.lm_head.*
+         - model.region.* -> (skip, not needed for VQA)
+         """
+         # #region agent log
+         import json
+         log_file = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"
+         def log(location, message, data, hypothesis_id):
+             try:
+                 with open(log_file, "a") as f:
+                     f.write(json.dumps({"sessionId": "debug-session", "runId": "sanitize", "hypothesisId": hypothesis_id, "location": location, "message": message, "data": data, "timestamp": __import__("time").time_ns() // 1000000}) + "\n")
+             except: pass
+         # #endregion
+
+         new_weights = {}
+         n_skipped_region = 0
+         n_changed = 0
+         n_unchanged = 0
+
+         # #region agent log
+         original_keys = sorted(weights.keys())
+         log("moondream2.py:sanitize_entry", "Sanitize entry - input weights", {"n_weights": len(weights), "sample_keys": original_keys[:10], "all_keys": original_keys}, "H1")
+         # #endregion
+
+         for k, v in weights.items():
+             # Skip region model weights (not needed for VQA)
+             if k.startswith("model.region."):
+                 n_skipped_region += 1
+                 continue
+
+             new_key = k
+
+             # Vision encoder: patch embedding
+             if k.startswith("model.vision.patch_emb."):
+                 new_key = k.replace(
+                     "model.vision.patch_emb.",
+                     "vision_encoder.patch_emb.",
+                 )
+
+             # Vision encoder: positional embedding
+             elif k == "model.vision.pos_emb":
+                 new_key = "vision_encoder.position_embedding"
+
+             # Vision encoder: blocks
+             elif k.startswith("model.vision.blocks."):
+                 # Extract block number and rest
+                 match = re.match(r"model\.vision\.blocks\.(\d+)\.(.+)", k)
+                 if match:
+                     block_num = match.group(1)
+                     suffix = match.group(2)
+                     # Keep the structure: ln1, attn, ln2, mlp
+                     new_key = f"vision_encoder.encoder.layers.{block_num}.{suffix}"
+
+             # Vision encoder: post layer norm
+             elif k.startswith("model.vision.post_ln."):
+                 new_key = k.replace(
+                     "model.vision.post_ln.",
+                     "vision_encoder.post_layernorm.",
+                 )
+
+             # Vision projection MLP
+             elif k.startswith("model.vision.proj_mlp."):
+                 new_key = k.replace(
+                     "model.vision.proj_mlp.",
+                     "vision_projection.",
+                 )
+
+             # Text model: embedding
+             elif k == "model.text.wte":
+                 new_key = "language_model.model.embed_tokens.weight"
+
+             # Text model: transformer blocks
+             elif k.startswith("model.text.blocks."):
+                 # Extract block number and rest
+                 match = re.match(r"model\.text\.blocks\.(\d+)\.(.+)", k)
+                 if match:
+                     block_num = match.group(1)
+                     suffix = match.group(2)
+
+                     # Map the suffixes
+                     if suffix.startswith("ln."):
+                         new_suffix = suffix.replace("ln.", "input_layernorm.")
+                     elif suffix.startswith("attn.qkv."):
+                         new_suffix = suffix.replace("attn.qkv.", "self_attn.qkv_proj.")
+                     elif suffix.startswith("attn.proj."):
+                         new_suffix = suffix.replace("attn.proj.", "self_attn.o_proj.")
+                     elif suffix.startswith("mlp."):
+                         new_suffix = suffix
+                     else:
+                         new_suffix = suffix
+
+                     new_key = f"language_model.model.layers.{block_num}.{new_suffix}"
+
+             # Text model: final layer norm
+             elif k.startswith("model.text.post_ln."):
+                 new_key = k.replace("model.text.post_ln.", "language_model.model.norm.")
+
+             # Text model: lm head
+             elif k.startswith("model.text.lm_head."):
+                 new_key = k.replace("model.text.lm_head.", "language_model.lm_head.")
+
+             if new_key == k:
+                 n_unchanged += 1
+             else:
+                 n_changed += 1
+             new_weights[new_key] = v
+
+         # #region agent log
+         sanitized_keys = sorted(new_weights.keys())
+         log("moondream2.py:sanitize_exit", "Sanitize exit - output weights", {"n_weights": len(new_weights), "n_changed": n_changed, "n_unchanged": n_unchanged, "n_skipped": n_skipped_region, "sample_keys": sanitized_keys[:10], "all_keys": sanitized_keys}, "H1")
+         # #endregion
+
+         return new_weights
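
Note on the multi-crop path: the get_input_embeddings docstring above already spells out the pipeline. The snippet below is only a minimal shape-level sketch, not the package's code: the helpers overlap_crop_image, reconstruct_from_crops, and adaptive_avg_pool2d are stood in for by a plain 2x2 tiling and a 2x2 mean pool, and the dimensions (one global plus four local crops, 729 patches per crop, 1152-d vision features, an 8192-d projector hidden layer, 2048-d text embeddings) are the ones quoted in the comments above.

    import mlx.core as mx
    import mlx.nn as nn

    # Assumed setup: 1 global + 4 local crops on a 2x2 tiling, 27x27 = 729 patches per crop.
    n_local, patches, side, vdim, ldim = 4, 729, 27, 1152, 2048
    features = mx.random.normal((1 + n_local, patches, vdim))  # stand-in for vision_encoder output

    global_feats = features[0]                               # [729, 1152]
    local = features[1:].reshape(n_local, side, side, vdim)  # [4, 27, 27, 1152]

    # Stand-in for reconstruct_from_crops: tile the 2x2 grid of local crops into one map
    # (the real helper also trims the overlap_margin between neighbouring crops).
    rows = [mx.concatenate([local[r * 2 + c] for c in range(2)], axis=1) for r in range(2)]
    recon = mx.concatenate(rows, axis=0)                     # [54, 54, 1152]

    # Stand-in for adaptive_avg_pool2d(recon, (27, 27)): a plain 2x2 mean pool.
    recon = recon.reshape(side, 2, side, 2, vdim).mean(axis=(1, 3))  # [27, 27, 1152]
    recon = recon.reshape(patches, vdim)                     # [729, 1152]

    # Same shape as VisionProjection above: concat(2304) -> 8192 -> 2048.
    fc1, fc2 = nn.Linear(2 * vdim, 8192), nn.Linear(8192, ldim)
    projected = fc2(nn.gelu(fc1(mx.concatenate([global_feats, recon], axis=-1))))
    print(projected.shape)                                   # (729, 2048)
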
mlx_vlm/models/moondream2/processing_moondream.py
@@ -0,0 +1,144 @@
+ """
+ Moondream2 processor for mlx-vlm.
+ """
+
+ from typing import List, Optional, Union
+
+ from PIL import Image
+ from transformers import AutoTokenizer
+ from transformers.processing_utils import ProcessorMixin
+
+
+ class MoondreamProcessor(ProcessorMixin):
+     """
+     Processor for Moondream2 model.
+
+     Wraps the tokenizer and provides compatibility with mlx-vlm's generation flow.
+     Image processing is handled separately by the model's ImageProcessor class.
+     """
+
+     tokenizer_class = "AutoTokenizer"
+     attributes = ["tokenizer"]
+
+     def __init__(self, tokenizer, chat_template: Optional[str] = None, **kwargs):
+         self.tokenizer = tokenizer
+
+         # Set up chat template for moondream
+         if chat_template is None:
+             # Moondream uses a simple format: <image>\n\nQuestion: {question}\n\nAnswer:
+             chat_template = (
+                 "{% for message in messages %}"
+                 "{% if message['role'] == 'user' %}"
+                 "{{ message['content'] }}\n\n"
+                 "{% elif message['role'] == 'assistant' %}"
+                 "{{ message['content'] }}"
+                 "{% endif %}"
+                 "{% endfor %}"
+                 "{% if add_generation_prompt %}Answer: {% endif %}"
+             )
+
+         self.chat_template = chat_template
+         tokenizer.chat_template = chat_template
+         super().__init__(tokenizer, **kwargs)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         """Load processor from pretrained model path."""
+         # Convert Path to string if needed
+         if hasattr(pretrained_model_name_or_path, "__fspath__"):
+             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+         # Pop kwargs that are not valid for AutoTokenizer
+         trust_remote_code = kwargs.pop("trust_remote_code", True)
+
+         # Moondream2 uses a custom tokenizer (starmie-v1), not the GPT-2
+         # tokenizer files shipped in the model repo.
+         tokenizer = AutoTokenizer.from_pretrained(
+             "moondream/starmie-v1",
+             trust_remote_code=trust_remote_code,
+         )
+
+         # starmie-v1 doesn't define special token roles; set them to
+         # <|endoftext|> (ID 0) which moondream uses as BOS/EOS/PAD.
+         tokenizer.eos_token = "<|endoftext|>"
+         tokenizer.bos_token = "<|endoftext|>"
+         tokenizer.pad_token = "<|endoftext|>"
+
+         return cls(tokenizer=tokenizer)
+
+     def __call__(
+         self,
+         text: Optional[Union[str, List[str]]] = None,
+         images: Optional[Union[Image.Image, List[Image.Image]]] = None,
+         **kwargs,
+     ):
+         """
+         Process text and images for the model.
+
+         Note: Image processing is handled by the model's ImageProcessor,
+         this processor mainly handles tokenization.
+         """
+         if text is None:
+             raise ValueError("Text input is required")
+
+         # Tokenize text
+         encoding = self.tokenizer(text, **kwargs)
+
+         return encoding
+
+     def batch_decode(self, *args, **kwargs):
+         """Decode token ids to text."""
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         """Decode token ids to text."""
+         return self.tokenizer.decode(*args, **kwargs)
+
+     def apply_chat_template(self, messages, add_generation_prompt=True, **kwargs):
+         """Apply chat template to messages."""
+         return self.tokenizer.apply_chat_template(
+             messages,
+             chat_template=self.chat_template,
+             add_generation_prompt=add_generation_prompt,
+             tokenize=kwargs.get("tokenize", False),
+             **{k: v for k, v in kwargs.items() if k != "tokenize"}
+         )
+
+     # Token properties delegated to tokenizer
+     @property
+     def pad_token(self):
+         return self.tokenizer.pad_token
+
+     @pad_token.setter
+     def pad_token(self, value):
+         self.tokenizer.pad_token = value
+
+     @property
+     def pad_token_id(self):
+         return self.tokenizer.pad_token_id
+
+     @pad_token_id.setter
+     def pad_token_id(self, value):
+         self.tokenizer.pad_token_id = value
+
+     @property
+     def eos_token(self):
+         return self.tokenizer.eos_token
+
+     @property
+     def eos_token_id(self):
+         return self.tokenizer.eos_token_id
+
+     @property
+     def bos_token(self):
+         return self.tokenizer.bos_token
+
+     @property
+     def bos_token_id(self):
+         return self.tokenizer.bos_token_id
+
+
+ # Install the AutoProcessor patch for moondream1 model type
+ from ..base import install_auto_processor_patch
+
+ install_auto_processor_patch("moondream1", MoondreamProcessor)
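
Note on the chat template: the template hard-coded above reduces to the user content, a blank line, then "Answer: ". A quick way to see the rendered prompt without downloading the starmie-v1 tokenizer is to run the same template string through jinja2 directly; the question text below is only an illustrative example, and rendering it through the tokenizer's apply_chat_template should give the same string.

    from jinja2 import Template

    chat_template = (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{{ message['content'] }}\n\n"
        "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}Answer: {% endif %}"
    )

    messages = [{"role": "user", "content": "Question: What is in this image?"}]
    prompt = Template(chat_template).render(messages=messages, add_generation_prompt=True)
    print(repr(prompt))  # 'Question: What is in this image?\n\nAnswer: '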