fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,166 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import InputEmbeddingsFeatures
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import VisionModel
+
+
+def masked_scatter(
+    final_embedding: mx.array,
+    image_mask_expanded: mx.array,
+    scaled_image_features: mx.array,
+):
+    # Reshape the tensors to 1D
+    final_embedding_shape = final_embedding.shape
+    scaled_image_features_flattened = mx.flatten(scaled_image_features)
+    final_embedding_flattened = mx.flatten(final_embedding)
+    image_mask_expanded_flattened = mx.flatten(image_mask_expanded)
+
+    # Scatter the scaled image features into the special image token positions
+    image_positions = mx.array(np.where(image_mask_expanded_flattened)[0], mx.uint32)
+    final_embedding_flattened[image_positions] = scaled_image_features_flattened
+
+    # Reshape back to the original shape
+    final_embedding = mx.reshape(final_embedding_flattened, final_embedding_shape)
+
+    return final_embedding
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.vision_tower = VisionModel(config.vision_config)
+        self.language_model = LanguageModel(config.text_config, config)
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        image_grid_thw = kwargs.get("image_grid_thw", None)
+        video_grid_thw = kwargs.get("video_grid_thw", None)
+        mask = kwargs.get("mask", None)
+        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
+
+        if pixel_values is None:
+            # Reset position state for text-only generation
+            self.language_model._position_ids = None
+            self.language_model._rope_deltas = None
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        dtype = self.vision_tower.patch_embed.proj.weight.dtype
+        pixel_values = pixel_values.astype(dtype)
+
+        # Get the input embeddings from the language model
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # Get the output hidden states from the vision model
+        hidden_states, deepstack_image_embeds = self.vision_tower(
+            pixel_values, grid_thw
+        )
+
+        visual_pos_masks = None
+        deepstack_visual_embeds = None
+
+        # Insert special image tokens in the input_ids
+        inputs_embeds, image_mask = self.merge_input_ids_with_image_features(
+            hidden_states,
+            inputs_embeds,
+            input_ids,
+            self.config.image_token_index,
+            self.config.video_token_index,
+        )
+
+        image_mask = image_mask[..., 0]
+        visual_pos_masks = image_mask
+        deepstack_visual_embeds = mx.eval(deepstack_image_embeds)
+
+        # Pre-calculate position_ids for chunked prefill
+        if image_grid_thw is not None or video_grid_thw is not None:
+            position_ids, rope_deltas = self.language_model.get_rope_index(
+                input_ids, image_grid_thw, video_grid_thw, mask
+            )
+            self.language_model._position_ids = position_ids
+            self.language_model._rope_deltas = rope_deltas
+
+        return InputEmbeddingsFeatures(
+            inputs_embeds=inputs_embeds,
+            visual_pos_masks=visual_pos_masks,
+            deepstack_visual_embeds=deepstack_visual_embeds,
+        )
+
+    @staticmethod
+    def merge_input_ids_with_image_features(
+        image_features, inputs_embeds, input_ids, image_token_index, video_token_index
+    ):
+        special_image_mask = input_ids == image_token_index
+        special_video_mask = input_ids == video_token_index
+        special_image_mask = special_image_mask | special_video_mask
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask[..., None]
+        special_image_mask = mx.broadcast_to(special_image_mask, inputs_embeds.shape)
+
+        n_image_features = image_features.shape[0]
+        n_image_mask_elements = special_image_mask.sum()
+        if n_image_mask_elements != image_features.size:
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+
+        inputs_embeds = masked_scatter(
+            inputs_embeds, special_image_mask, image_features
+        )
+
+        return inputs_embeds, special_image_mask
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        **kwargs,
+    ):
+
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, **kwargs
+        )
+
+        kwargs.update(
+            {
+                "pixel_values": pixel_values,
+                **input_embeddings_features.to_dict(),
+            }
+        )
+
+        logits = self.language_model(input_ids, mask=mask, cache=cache, **kwargs)
+        return logits
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for key, value in weights.items():
+            if "model" in key:
+                if "model.language_model" in key:
+                    key = key.replace("model.language_model", "language_model.model")
+
+                elif "model.visual" in key:
+                    key = key.replace("model.visual", "vision_tower")
+            elif "lm_head" in key:
+                key = key.replace("lm_head", "language_model.lm_head")
+
+            sanitized_weights[key] = value
+
+        return sanitized_weights
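For readers skimming the model file above, here is a minimal standalone sketch of the flatten, scatter, reshape pattern that masked_scatter implements. It is not part of the package; the toy shapes (one sequence of four positions, hidden size two, two image tokens) are invented for illustration, and it assumes mlx and numpy are installed.

import mlx.core as mx
import numpy as np

# Toy inputs: (batch, seq, hidden) embeddings, a per-token image mask,
# and one feature row per image token (shapes are made up for this sketch).
embeds = mx.zeros((1, 4, 2))
mask = mx.array([False, True, True, False])
mask = mx.broadcast_to(mask[None, :, None], embeds.shape)
features = mx.ones((2, 2))

# Same flatten -> scatter -> reshape pattern as masked_scatter above.
flat = mx.flatten(embeds)
positions = mx.array(np.where(mx.flatten(mask))[0], mx.uint32)
flat[positions] = mx.flatten(features)
result = mx.reshape(flat, embeds.shape)
print(result[0])  # rows 1 and 2 become ones; rows 0 and 3 stay zero

The same idea scales up unchanged: the mask is broadcast to the embedding shape, so each selected token position receives one row of image features.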
@@ -0,0 +1,441 @@
+from itertools import accumulate
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 or 5 dimensions
+    if len(shape) not in [4, 5]:
+        return False
+
+    B, out_channels, kH, KW, t = shape
+
+    if t == 3:
+        return True
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
+    orig_dtype = tensor.dtype
+
+    cos = mx.cos(freqs)
+    sin = mx.sin(freqs)
+
+    cos = mx.expand_dims(cos, axis=1)
+    cos = mx.tile(cos, (1, 1, 2))
+    cos = mx.expand_dims(cos, axis=0)
+
+    sin = mx.expand_dims(sin, axis=1)
+    sin = mx.tile(sin, (1, 1, 2))
+    sin = mx.expand_dims(sin, axis=0)
+
+    output = (tensor * cos) + (rotate_half(tensor) * sin)
+    return output.astype(orig_dtype)
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+
+    def __call__(self, seqlen: int) -> mx.array:
+        inv_freq = 1.0 / (
+            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+        )
+        seq = mx.arange(seqlen, dtype=inv_freq.dtype)
+        freqs = mx.outer(seq, inv_freq)
+        return freqs
+
+
+class PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_channels: int = 3,
+        hidden_size: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+
+        kernel_size = [temporal_patch_size, patch_size, patch_size]
+        self.proj = nn.Conv3d(
+            in_channels,
+            hidden_size,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=True,
+        )
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = hidden_states.reshape(
+            -1,
+            self.in_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        ).moveaxis(1, 4)
+
+        hidden_states = self.proj(hidden_states)
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)
+        return hidden_states
+
+
+class PatchMerger(nn.Module):
+    def __init__(self, config: VisionConfig, use_postshuffle_norm=False) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+        self.use_postshuffle_norm = use_postshuffle_norm
+        self.norm = nn.LayerNorm(
+            self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6
+        )
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+        self.act_fn = nn.GELU()
+        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.norm(
+            x.reshape(-1, self.hidden_size) if self.use_postshuffle_norm else x
+        ).reshape(-1, self.hidden_size)
+        x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 16) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def __call__(
+        self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
+    ) -> mx.array:
+        seq_length = x.shape[0]
+        qkv = (
+            self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
+        )
+        q, k, v = mx.split(qkv, 3)
+
+        q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
+        k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
+
+        q = q.transpose(0, 2, 1, 3)
+        k = k.transpose(0, 2, 1, 3)
+        v = v.transpose(0, 2, 1, 3)
+
+        splits = [
+            mx.split(tensor, cu_seqlens[1:-1].tolist(), axis=2) for tensor in (q, k, v)
+        ]
+
+        attn_outputs = []
+        for q, k, v in zip(*splits):
+            output = mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=self.scale, mask=None
+            )
+            attn_outputs.append(output)
+
+        output = mx.concatenate(attn_outputs, axis=2)
+        output = output.transpose(0, 2, 1, 3).reshape(seq_length, -1)
+        return self.proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.linear_fc1 = nn.Linear(dim, hidden_dim, bias=True)
+        self.linear_fc2 = nn.Linear(hidden_dim, dim, bias=True)
+        self.act_fn = nn.GELU(approx="tanh")
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+
+
+class Qwen3VLMoEVisionBlock(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+
+        self.attn = Attention(dim=config.hidden_size, num_heads=config.num_heads)
+        self.mlp = MLP(dim=config.hidden_size, hidden_dim=config.intermediate_size)
+
+    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+
+        if self.model_type != "qwen3_vl":
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.patch_embed = PatchEmbed(
+            patch_size=config.patch_size,
+            temporal_patch_size=config.temporal_patch_size,
+            in_channels=config.in_channels,
+            hidden_size=config.hidden_size,
+        )
+
+        head_dim = config.hidden_size // config.num_heads
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        self.pos_embed = nn.Embedding(
+            config.num_position_embeddings, config.hidden_size
+        )
+        self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+
+        self.blocks = [Qwen3VLMoEVisionBlock(config) for _ in range(config.depth)]
+        self.merger = PatchMerger(config=config, use_postshuffle_norm=False)
+
+        self.deepstack_visual_indexes = config.deepstack_visual_indexes
+        self.deepstack_merger_list = [
+            PatchMerger(
+                config=config,
+                use_postshuffle_norm=True,
+            )
+            for _ in range(len(config.deepstack_visual_indexes))
+        ]
+
+    def rot_pos_emb(self, grid_thw: mx.array) -> mx.array:
+        merge_size = self.spatial_merge_size
+
+        # Get max grid size for frequency table
+        max_hw = int(mx.max(grid_thw[:, 1:]).item())
+        freq_table = self.rotary_pos_emb(max_hw)  # Shape: (max_hw, dim // 2)
+
+        pos_ids = []
+
+        for num_frames, height, width in grid_thw.tolist():
+            num_frames, height, width = int(num_frames), int(height), int(width)
+            merged_h, merged_w = height // merge_size, width // merge_size
+
+            # Create block indices
+            block_rows = mx.arange(merged_h)
+            block_cols = mx.arange(merged_w)
+
+            # Create intra-block indices
+            intra_row = mx.arange(merge_size)
+            intra_col = mx.arange(merge_size)
+
+            # Compute full-resolution positions
+            row_idx = (
+                block_rows[:, None, None, None] * merge_size
+                + intra_row[None, None, :, None]
+            )
+            col_idx = (
+                block_cols[None, :, None, None] * merge_size
+                + intra_col[None, None, None, :]
+            )
+
+            # Broadcast and flatten
+            row_idx = mx.broadcast_to(
+                row_idx, (merged_h, merged_w, merge_size, merge_size)
+            ).reshape(-1)
+            col_idx = mx.broadcast_to(
+                col_idx, (merged_h, merged_w, merge_size, merge_size)
+            ).reshape(-1)
+
+            # Stack into coordinate pairs
+            coords = mx.stack([row_idx, col_idx], axis=-1)
+
+            # Repeat for temporal dimension
+            if num_frames > 1:
+                coords = mx.tile(coords, (num_frames, 1))
+
+            pos_ids.append(coords)
+
+        # Concatenate all position IDs - shape: (total_tokens, 2)
+        pos_ids = mx.concatenate(pos_ids, axis=0)
+
+        # Lookup embeddings: freq_table[h_pos] and freq_table[w_pos]
+        # pos_ids[:, 0] = height positions, pos_ids[:, 1] = width positions
+        h_embeddings = freq_table[pos_ids[:, 0]]  # (total_tokens, dim // 2)
+        w_embeddings = freq_table[pos_ids[:, 1]]  # (total_tokens, dim // 2)
+
+        # Concatenate height and width embeddings
+        embeddings = mx.concatenate([h_embeddings, w_embeddings], axis=-1)
+
+        return embeddings
+
+    def fast_pos_embed_interpolate(self, grid_thw):
+        grid_thw_list = grid_thw.tolist()
+        idx_list = [[] for _ in range(4)]
+        weight_list = [[] for _ in range(4)]
+
+        for t, h, w in grid_thw_list:
+            h = int(h)
+            w = int(w)
+            t = int(t)
+
+            h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
+            w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
+
+            h_idxs_floor = h_idxs.astype(mx.int32)
+            w_idxs_floor = w_idxs.astype(mx.int32)
+            h_idxs_ceil = mx.minimum(h_idxs_floor + 1, self.num_grid_per_side - 1)
+            w_idxs_ceil = mx.minimum(w_idxs_floor + 1, self.num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor.astype(mx.float32)
+            dw = w_idxs - w_idxs_floor.astype(mx.float32)
+
+            base_h = h_idxs_floor * self.num_grid_per_side
+            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+            indices = [
+                (base_h[:, None] + w_idxs_floor[None, :]).flatten(),
+                (base_h[:, None] + w_idxs_ceil[None, :]).flatten(),
+                (base_h_ceil[:, None] + w_idxs_floor[None, :]).flatten(),
+                (base_h_ceil[:, None] + w_idxs_ceil[None, :]).flatten(),
+            ]
+
+            weights = [
+                ((1 - dh)[:, None] * (1 - dw)[None, :]).flatten(),
+                ((1 - dh)[:, None] * dw[None, :]).flatten(),
+                (dh[:, None] * (1 - dw)[None, :]).flatten(),
+                (dh[:, None] * dw[None, :]).flatten(),
+            ]
+
+            for i in range(4):
+                idx_list[i].extend(indices[i].tolist())
+                weight_list[i].extend(weights[i].tolist())
+
+        idx_tensor = mx.array(idx_list, dtype=mx.int32)
+        weight_tensor = mx.array(weight_list, dtype=self.pos_embed.weight.dtype)
+
+        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+        split_sizes = [int(h * w) for t, h, w in grid_thw_list]
+        if len(split_sizes) > 1:
+            split_indices = list(accumulate(split_sizes[:-1]))
+            patch_pos_embeds_split = mx.split(patch_pos_embeds, split_indices, axis=0)
+        else:
+            patch_pos_embeds_split = [patch_pos_embeds]
+
+        patch_pos_embeds_permute = []
+        merge_size = self.config.spatial_merge_size
+
+        for pos_embed, (t, h, w) in zip(patch_pos_embeds_split, grid_thw_list):
+            t, h, w = int(t), int(h), int(w)
+            feature_dim = pos_embed.shape[-1]
+            pos_embed = mx.tile(pos_embed, (t, 1))
+            pos_embed = pos_embed.reshape(t, h, w, feature_dim)
+            pos_embed = (
+                pos_embed.reshape(
+                    t,
+                    h // merge_size,
+                    merge_size,
+                    w // merge_size,
+                    merge_size,
+                    feature_dim,
+                )
+                .transpose(0, 1, 3, 2, 4, 5)
+                .reshape(-1, feature_dim)
+            )
+            patch_pos_embeds_permute.append(pos_embed)
+
+        patch_pos_embeds = mx.concatenate(patch_pos_embeds_permute)
+        return patch_pos_embeds
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        grid_thw: mx.array,
+        **kwargs,
+    ) -> mx.array:
+
+        hidden_states = self.patch_embed(hidden_states)
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        hidden_states = hidden_states + pos_embeds
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        seq_len = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+
+        # Assuming grid_thw has shape (batch_size, 3)
+        batch_size = grid_thw.shape[0]
+
+        # Calculate cu_seqlens for each item in the batch
+        cu_seqlens = []
+        for i in range(batch_size):
+            seq_len = grid_thw[i, 1] * grid_thw[i, 2]
+            cu_seqlens.append(mx.repeat(seq_len, grid_thw[i, 0]))
+
+        # Concatenate the cu_seqlens for all items in the batch
+        cu_seqlens = mx.concatenate(cu_seqlens)
+
+        cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0)
+        cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)
+
+        deepstack_feature_lists = []
+        for layer_num, blk in enumerate(self.blocks):
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb=rotary_pos_emb,
+            )
+            if layer_num in self.deepstack_visual_indexes:
+                deepstack_feature = self.deepstack_merger_list[
+                    self.deepstack_visual_indexes.index(layer_num)
+                ](hidden_states)
+                deepstack_feature_lists.append(deepstack_feature)
+
+        hidden_states = self.merger(hidden_states)
+
+        return hidden_states, deepstack_feature_lists
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embed.proj.weight" in k:
+                # PyTorch Conv3d weight tensors have shape:
+                # [out_channels, in_channels, kT, kH, kW]
+                # MLX Conv3d expects the weight to be of shape:
+                # [out_channels, kT, kH, kW, in_channels]
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
438
+ else:
439
+ sanitized_weights[k] = v
440
+
441
+ return sanitized_weights
@@ -0,0 +1,2 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .qwen3_vl_moe import LanguageModel, Model, VisionModel
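One detail worth calling out from the vision file's sanitize method above: PyTorch stores Conv3d weights channels-first, while MLX expects channels-last, which is what the transpose(0, 2, 3, 4, 1) accomplishes. A minimal sketch of that layout change, using the PatchEmbed defaults from the diff (hidden_size 1152, in_channels 3, temporal_patch_size 2, patch_size 14) purely as example dimensions:

import mlx.core as mx

# PyTorch-style Conv3d weight: [out_channels, in_channels, kT, kH, kW]
torch_style = mx.zeros((1152, 3, 2, 14, 14))

# MLX-style Conv3d weight: [out_channels, kT, kH, kW, in_channels]
mlx_style = torch_style.transpose(0, 2, 3, 4, 1)
print(mlx_style.shape)  # (1152, 2, 14, 14, 3)

check_array_shape in the diff uses the position of the small channel dimension (3) to decide whether a checkpoint weight is already in the MLX layout and can be kept as-is.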