fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/deepseekocr_2/deepseekocr_2.py
@@ -0,0 +1,297 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from transformers import AutoProcessor
+
+ from ..base import InputEmbeddingsFeatures
+ from ..deepseekocr.language import LanguageModel
+ from ..deepseekocr.sam import SAMEncoder
+ from .config import ModelConfig, SAMViTConfig
+ from .processing_deepseekocr import DeepseekOCR2Processor
+ from .vision import VisionModel
+
+ AutoProcessor.register("deepseekocr_2", DeepseekOCR2Processor)
+
+
+ class MlpProjector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+
+         if config.projector_config.projector_type == "linear":
+             self.layers = nn.Linear(
+                 config.projector_config.input_dim, config.projector_config.n_embed
+             )
+         else:
+             raise ValueError(
+                 f"Unknown projector type: {config.projector_config.projector_type}"
+             )
+
+     def __call__(self, x):
+         return self.layers(x)
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vision_model = VisionModel(config.vision_config)
+         sam_config = SAMViTConfig()
+         self.sam_model = SAMEncoder(
+             img_size=sam_config.image_size,
+             patch_size=sam_config.patch_size,
+             embed_dim=sam_config.width,
+             depth=sam_config.layers,
+             num_heads=sam_config.heads,
+             window_size=sam_config.window_size,
+             global_attn_indexes=sam_config.global_attn_indexes,
+             final_out_chans=896,  # OCR-2 uses 896 output channels (vs 1024 in OCR)
+         )
+         self.language_model = LanguageModel(config.text_config)
+         self.projector = MlpProjector(config)
+
+         self.tile_tag = config.tile_tag
+         self.global_view_pos = config.global_view_pos
+
+         # view_separator is loaded from model weights (mapped from view_seperator)
+         # Initialize with zeros - will be overwritten when weights are loaded
+         if self.tile_tag == "2D":
+             # <|view_separator|> - marks end of image features
+             # Note: This must be defined as an mx.array for weight loading to work
+             self.view_separator = mx.zeros((config.projector_config.n_embed,))
+         else:
+             raise ValueError(
+                 f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
+             )
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         images_spatial_crop: Optional[mx.array] = None,
+         images_seq_mask: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         input_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(inputs_embeds=input_embeds)
+
+         # pixel_values is a list: [patches, global_images]
+         if isinstance(pixel_values, list):
+             patches, global_images = pixel_values
+         else:
+             patches = None
+             global_images = pixel_values
+
+         # Check if we have valid pixel values
+         if mx.sum(global_images).item() == 0:
+             return InputEmbeddingsFeatures(inputs_embeds=input_embeds)
+
+         # Process images through SAM -> Qwen2 -> Projector pipeline
+         batch_size = input_ids.shape[0]
+
+         for idx in range(batch_size):
+             all_features = []
+
+             # Check if we have valid patches (non-zero)
+             has_patches = patches is not None and mx.sum(patches).item() != 0
+
+             if has_patches:
+                 # Get spatial crop info for this batch item
+                 if (
+                     images_spatial_crop is not None
+                     and idx < images_spatial_crop.shape[0]
+                 ):
+                     rows, cols = int(images_spatial_crop[idx, 0].item()), int(
+                         images_spatial_crop[idx, 1].item()
+                     )
+                     num_patches = rows * cols
+                 else:
+                     num_patches = patches.shape[0]
+
+                 # Process each patch through SAM -> Qwen2 -> Projector
+                 # patches shape: (num_patches, C, H, W) where H=W=768
+                 for patch_idx in range(num_patches):
+                     if patch_idx >= patches.shape[0]:
+                         break
+
+                     patch = patches[patch_idx : patch_idx + 1]  # (1, C, H, W)
+
+                     # Transpose to (B, H, W, C) for MLX conv2d
+                     patch_hwc = patch.transpose(0, 2, 3, 1)
+
+                     # SAM encoder: (1, 768, 768, 3) -> (1, 12, 12, 896)
+                     sam_features = self.sam_model(patch_hwc)
+
+                     # Qwen2 encoder: (1, 12, 12, 896) -> (1, 144, 896)
+                     # Uses query_768 automatically based on 144 input tokens
+                     vision_features = self.vision_model(patch_hwc, sam_features)
+
+                     # Linear projector: (1, 144, 896) -> (1, 144, 1280)
+                     vision_features = self.projector(vision_features)
+
+                     # Remove batch dimension: (144, 1280)
+                     all_features.append(vision_features[0])
+
+             # Process global view through SAM -> Qwen2 -> Projector
+             # global_images is (N, C, H, W) where H=W=1024
+             global_image = global_images[idx : idx + 1]  # (1, C, H, W)
+
+             # Transpose to (B, H, W, C) for MLX conv2d
+             global_hwc = global_image.transpose(0, 2, 3, 1)
+
+             # SAM encoder: (1, 1024, 1024, 3) -> (1, 16, 16, 896)
+             sam_features = self.sam_model(global_hwc)
+
+             # Qwen2 encoder: (1, 16, 16, 896) -> (1, 256, 896)
+             # Uses query_1024 automatically based on 256 input tokens
+             global_features = self.vision_model(global_hwc, sam_features)
+
+             # Linear projector: (1, 256, 896) -> (1, 256, 1280)
+             global_features = self.projector(global_features)
+
+             # Remove batch dimension: (256, 1280)
+             all_features.append(global_features[0])
+
+             # Add view_separator
+             all_features.append(self.view_separator[None, :])
+
+             # Concatenate all features: [local_patches..., global, view_sep]
+             # Shape: (num_patches * 144 + 256 + 1, 1280)
+             vision_features = mx.concatenate(all_features, axis=0)
+
+             # Find positions where images should be placed
+             if images_seq_mask is not None:
+                 image_indices = np.where(images_seq_mask[idx])[0].tolist()
+                 # Assign image features to those positions
+                 if len(image_indices) > 0:
+                     num_positions = len(image_indices)
+                     if vision_features.shape[0] >= num_positions:
+                         input_embeds[idx, image_indices] = vision_features[
+                             :num_positions
+                         ]
+                     else:
+                         # Fewer features than positions: fill only the leading positions
+                         input_embeds[idx, image_indices[: vision_features.shape[0]]] = (
+                             vision_features
+                         )
+
+         return InputEmbeddingsFeatures(inputs_embeds=input_embeds)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         images_spatial_crop = kwargs.get("images_spatial_crop", None)
+         images_seq_mask = kwargs.get("images_seq_mask", None)
+
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, images_spatial_crop, images_seq_mask
+         )
+
+         logits = self.language_model(
+             input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+         )
+         return logits
+
+     @staticmethod
+     def sanitize(weights):
+         def transform_key(key):
+             # Handle Qwen2 encoder weights from HuggingFace format
+             # HuggingFace: model.qwen2_model.model.model.layers.X...
+             # MLX: vision_model.qwen2_encoder.layers.X...
+             if "qwen2_model.model.model.layers" in key:
+                 key = key.replace(
+                     "model.qwen2_model.model.model.layers",
+                     "vision_model.qwen2_encoder.layers",
+                 )
+
+             # Handle Qwen2 encoder norm
+             if "qwen2_model.model.model.norm" in key:
+                 key = key.replace(
+                     "model.qwen2_model.model.model.norm",
+                     "vision_model.qwen2_encoder.norm",
+                 )
+
+             # Handle query weights (learnable queries for Qwen2 encoder)
+             # For 1024x1024 images, SAM outputs 16x16=256 features, so use query_1024
+             # query_1024: (256, 896) - used for 1024x1024 images
+             # query_768: (144, 896) - used for 768x768 images
+             if "model.qwen2_model.query_1024.weight" in key:
+                 key = key.replace(
+                     "model.qwen2_model.query_1024.weight",
+                     "vision_model.qwen2_encoder.query_1024",
+                 )
+             elif "model.qwen2_model.query_1024" in key:
+                 key = key.replace(
+                     "model.qwen2_model.query_1024",
+                     "vision_model.qwen2_encoder.query_1024",
+                 )
+             # Also handle query_768 for smaller images (not currently used but keep for future)
+             if "model.qwen2_model.query_768.weight" in key:
+                 key = key.replace(
+                     "model.qwen2_model.query_768.weight",
+                     "vision_model.qwen2_encoder.query_768",
+                 )
+             elif "model.qwen2_model.query_768" in key:
+                 key = key.replace(
+                     "model.qwen2_model.query_768",
+                     "vision_model.qwen2_encoder.query_768",
+                 )
+
+             # Language model layers
+             if (
+                 "model.layers" in key
+                 and "language_model" not in key
+                 and "qwen2" not in key
+             ):
+                 key = key.replace("model.layers", "language_model.model.layers")
+
+             if (
+                 "model.embed_tokens" in key
+                 and "language_model" not in key
+                 and "qwen2" not in key
+             ):
+                 key = key.replace(
+                     "model.embed_tokens", "language_model.model.embed_tokens"
+                 )
+
+             if (
+                 "model.norm" in key
+                 and "language_model" not in key
+                 and "qwen2" not in key
+             ):
+                 key = key.replace("model.norm", "language_model.model.norm")
+
+             if "model.vision_model" in key:
+                 key = key.replace("model.vision_model", "vision_model")
+
+             if "model.sam_model" in key:
+                 key = key.replace("model.sam_model", "sam_model")
+
+             if "model.projector" in key:
+                 key = key.replace("model.projector", "projector")
+
+             # Note: HuggingFace has typo "view_seperator" (e instead of a)
+             if "model.view_seperator" in key:
+                 key = key.replace("model.view_seperator", "view_separator")
+
+             if "lm_head.weight" in key and "language_model" not in key:
+                 key = key.replace("lm_head.weight", "language_model.lm_head.weight")
+
+             return key
+
+         return {transform_key(k): v for k, v in weights.items()}
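
For orientation, the shape comments in get_input_embeddings above imply a fixed per-image token budget: each 768x768 local patch contributes 144 projected vision tokens, the 1024x1024 global view contributes 256, and one position is reserved for the view_separator embedding. The short sketch below is illustrative only (the helper name expected_vision_tokens is hypothetical and not part of the packaged code); it reproduces that arithmetic for a rows x cols crop grid, i.e. the number of positions that images_seq_mask is expected to mark when local patches are present.

# Illustration only: reproduces the token-count arithmetic documented in the
# comments of get_input_embeddings; not part of the fount-vlm-nell-02 package.
LOCAL_TOKENS_PER_PATCH = 144  # 768x768 patch -> (144, 1280) after the projector
GLOBAL_TOKENS = 256           # 1024x1024 global view -> (256, 1280)
SEPARATOR_TOKENS = 1          # single learned view_separator embedding


def expected_vision_tokens(rows: int, cols: int) -> int:
    """Sequence positions occupied by image features for a rows x cols crop grid."""
    num_patches = rows * cols
    return num_patches * LOCAL_TOKENS_PER_PATCH + GLOBAL_TOKENS + SEPARATOR_TOKENS


# A 2x3 grid: 6 * 144 + 256 + 1 = 1121 positions must be set in images_seq_mask.
assert expected_vision_tokens(2, 3) == 1121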