fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
--- /dev/null
+++ mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py
@@ -0,0 +1,184 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import InputEmbeddingsFeatures
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import VisionModel
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.vision_tower = VisionModel(config.vision_config)
+        self.language_model = LanguageModel(config.text_config, config)
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        image_grid_thw = kwargs.get("image_grid_thw", None)
+        video_grid_thw = kwargs.get("video_grid_thw", None)
+        mask = kwargs.get("mask", None)
+        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
+        if pixel_values is None:
+            # Reset position state for text-only generation
+            self.language_model._position_ids = None
+            self.language_model._rope_deltas = None
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        dtype = self.vision_tower.patch_embed.proj.weight.dtype
+        pixel_values = pixel_values.astype(dtype)
+
+        # Get the input embeddings from the language model
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # Get the output hidden states from the vision model
+        hidden_states = self.vision_tower(
+            pixel_values, grid_thw, output_hidden_states=False
+        )
+
+        # Insert special image tokens in the input_ids
+        final_inputs_embeds = self.merge_input_ids_with_image_features(
+            self.config.image_token_id,
+            self.config.video_token_id,
+            hidden_states,
+            inputs_embeds,
+            input_ids,
+        )
+
+        # Pre-calculate position_ids for chunked prefill
+        if image_grid_thw is not None or video_grid_thw is not None:
+            position_ids, rope_deltas = self.language_model.get_rope_index(
+                input_ids, image_grid_thw, video_grid_thw, mask
+            )
+            self.language_model._position_ids = position_ids
+            self.language_model._rope_deltas = rope_deltas
+
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    @staticmethod
+    def merge_input_ids_with_image_features(
+        image_token_id,
+        video_token_id,
+        image_features,
+        inputs_embeds,
+        input_ids,
+    ):
+        """Merge image features into input embeddings at image token positions.
+
+        Args:
+            image_token_id: The token ID for image placeholders
+            video_token_id: The token ID for video placeholders (fallback)
+            image_features: Vision features from the vision tower [num_features, hidden_dim]
+            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
+            input_ids: Input token IDs [batch_size, seq_len]
+            grid_thw: Grid dimensions for each image (optional, not used in simple case)
+
+        Returns:
+            Updated input embeddings with image features inserted
+        """
+        # Find positions of image tokens
+        image_positions = input_ids == image_token_id
+        if mx.sum(image_positions) == 0:
+            image_positions = input_ids == video_token_id
+
+        # Get dimensions
+        batch_size, seq_len = input_ids.shape
+
+        # Process each batch item
+        batch_outputs = []
+        feature_start_idx = 0
+
+        for batch_idx in range(batch_size):
+            # Get mask for this batch
+            image_mask = image_positions[batch_idx]
+            num_positions = mx.sum(image_mask).item()
+
+            if num_positions > 0:
+                # Extract features for this batch
+                batch_features = image_features[
+                    feature_start_idx : feature_start_idx + num_positions
+                ]
+
+                # Validate we have the right number of features
+                if batch_features.shape[0] != num_positions:
+                    raise ValueError(
+                        f"Number of image token positions ({num_positions}) does not match "
+                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
+                    )
+
+                # Create indices for gathering
+                cumsum = mx.cumsum(image_mask.astype(mx.int32))
+                feature_indices = mx.where(image_mask, cumsum - 1, 0)
+
+                # Gather features
+                gathered_features = batch_features[feature_indices]
+
+                # Combine with original embeddings
+                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
+                batch_output = mx.where(
+                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
+                )
+
+                feature_start_idx += num_positions
+            else:
+                # No image tokens in this batch item
+                batch_output = inputs_embeds[batch_idx]
+
+            batch_outputs.append(batch_output)
+
+        # Stack all batch outputs
+        return mx.stack(batch_outputs, axis=0)
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        **kwargs,
+    ):
+
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, **kwargs
+        )
+
+        kwargs = {
+            "pixel_values": pixel_values,
+            **kwargs,
+        }
+
+        logits = self.language_model(
+            input_ids,
+            input_embeddings_features.inputs_embeds,
+            mask=mask,
+            cache=cache,
+            **kwargs,
+        )
+
+        return logits
+
+    def sanitize(self, weights):
+        def transform_key(key):
+            if "vision_tower" not in key:
+                key = key.replace("visual", "vision_tower")
+            if "language_model" not in key:
+                if "model" in key:
+                    key = key.replace("model", "language_model.model")
+                elif "lm_head" in key:
+                    key = key.replace("lm_head", "language_model.lm_head")
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}
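
The hunk above scatters vision-tower features into the token embedding sequence wherever the image (or video) placeholder token appears. Below is a minimal sketch of exercising the static merge helper on toy data; it is not part of the package, the import path is assumed from the file listing above, and all shapes and values are made up.

import mlx.core as mx

from mlx_vlm.models.qwen2_5_vl import Model  # import path assumed from the file listing

IMAGE_TOKEN_ID = 151655  # ModelConfig.image_token_id default
VIDEO_TOKEN_ID = 151656  # ModelConfig.video_token_id default

# One sequence of 4 tokens, two of which are image placeholders.
input_ids = mx.array([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
inputs_embeds = mx.zeros((1, 4, 8))   # [batch, seq_len, hidden_dim]
image_features = mx.ones((2, 8))      # one vision feature per placeholder token

merged = Model.merge_input_ids_with_image_features(
    IMAGE_TOKEN_ID, VIDEO_TOKEN_ID, image_features, inputs_embeds, input_ids
)
print(merged.shape)        # (1, 4, 8)
print(merged[0, 1].sum())  # 8.0 -> placeholder positions now carry vision features
print(merged[0, 0].sum())  # 0.0 -> text positions keep their original embeddings
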
--- /dev/null
+++ mlx_vlm/models/qwen2_5_vl/vision.py
@@ -0,0 +1,414 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Conv3d weight tensors must have 5 dimensions
+    if len(shape) != 5:
+        return False
+
+    B, out_channels, kH, KW, t = shape
+
+    if t == 3:
+        return True
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
+    orig_dtype = tensor.dtype
+
+    cos = mx.cos(freqs)
+    sin = mx.sin(freqs)
+
+    cos = mx.expand_dims(cos, axis=1)  # Equivalent to unsqueeze(1)
+    cos = mx.tile(cos, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+    cos = mx.expand_dims(cos, axis=0)  # Equivalent to [None, ...]
+
+    sin = mx.expand_dims(sin, axis=1)  # Equivalent to unsqueeze(1)
+    sin = mx.tile(sin, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+    sin = mx.expand_dims(sin, axis=0)  # Equivalent to [None, ...]
+
+    output = (tensor * cos) + (rotate_half(tensor) * sin)
+    return output.astype(orig_dtype)
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+
+    def __call__(self, seqlen: int) -> mx.array:
+        inv_freq = 1.0 / (
+            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+        )
+        seq = mx.arange(seqlen.item(), dtype=inv_freq.dtype)
+        freqs = mx.outer(seq, inv_freq)
+        return freqs
+
+
+class PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_channels: int = 3,
+        hidden_size: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+
+        kernel_size = [temporal_patch_size, patch_size, patch_size]
+        self.proj = nn.Conv3d(
+            in_channels,
+            hidden_size,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=False,
+        )
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = hidden_states.reshape(
+            -1,
+            self.in_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        ).moveaxis(1, 4)
+
+        hidden_states = self.proj(hidden_states)
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)
+        return hidden_states
+
+
+class PatchMerger(nn.Module):
+    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.ln_q = nn.RMSNorm(context_dim, eps=1e-6)
+        self.mlp = [
+            nn.Linear(self.hidden_size, self.hidden_size),
+            nn.GELU(),
+            nn.Linear(self.hidden_size, dim),
+        ]
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.ln_q(x).reshape(-1, self.hidden_size)
+        for layer in self.mlp:
+            x = layer(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 16) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def __call__(
+        self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
+    ) -> mx.array:
+        seq_length = x.shape[0]
+        qkv = (
+            self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
+        )
+        q, k, v = mx.split(qkv, 3)
+
+        q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
+        k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
+        attention_mask = mx.full(
+            (1, seq_length, seq_length), mx.finfo(q.dtype).min, dtype=q.dtype
+        )
+
+        for i in range(1, len(cu_seqlens)):
+            start = int(cu_seqlens[i - 1])
+            end = int(cu_seqlens[i])
+            attention_mask[..., start:end, start:end] = 0
+
+        q = q.transpose(0, 2, 1, 3)
+        k = k.transpose(0, 2, 1, 3)
+        v = v.transpose(0, 2, 1, 3)
+
+        output = mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=self.scale, mask=attention_mask
+        )
+        output = output.transpose(0, 2, 1, 3)
+        output = output.reshape(seq_length, -1)
+        return self.proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim)
+        self.up_proj = nn.Linear(dim, hidden_dim)
+        self.down_proj = nn.Linear(hidden_dim, dim)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Qwen2VLVisionBlock(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.norm1 = nn.RMSNorm(config.hidden_size, eps=1e-6)
+        self.norm2 = nn.RMSNorm(config.hidden_size, eps=1e-6)
+
+        self.attn = Attention(dim=config.hidden_size, num_heads=config.num_heads)
+        self.mlp = MLP(dim=config.hidden_size, hidden_dim=config.intermediate_size)
+
+    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionModel(nn.Module):
+
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        if self.model_type != "qwen2_5_vl":
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.patch_embed = PatchEmbed(
+            patch_size=config.patch_size,
+            temporal_patch_size=config.temporal_patch_size,
+            in_channels=config.in_channels,
+            hidden_size=config.hidden_size,
+        )
+
+        self.window_size = config.window_size
+        self.patch_size = config.patch_size
+        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+        self.fullatt_block_indexes = config.fullatt_block_indexes
+        head_dim = config.hidden_size // config.num_heads
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
+        self.merger = PatchMerger(
+            dim=config.out_hidden_size, context_dim=config.hidden_size
+        )
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = []
+
+        for t, h, w in grid_thw.tolist():
+            hpos_ids = mx.expand_dims(mx.arange(h), 1)
+            hpos_ids = mx.repeat(hpos_ids, w, axis=1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = mx.transpose(hpos_ids, (0, 2, 1, 3))
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = mx.expand_dims(mx.arange(w), 0)
+            wpos_ids = mx.repeat(wpos_ids, h, axis=0)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = mx.transpose(wpos_ids, (0, 2, 1, 3))
+            wpos_ids = wpos_ids.flatten()
+
+            stacked_pos_ids = mx.stack([hpos_ids, wpos_ids], axis=-1)
+            pos_ids.append(mx.tile(stacked_pos_ids, (t, 1)))
+
+        pos_ids = mx.concatenate(pos_ids, axis=0)
+        max_grid_size = mx.max(grid_thw[:, 1:])
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids]
+
+        return rotary_pos_emb.reshape(pos_ids.shape[0], -1)
+
+    def get_window_index(self, grid_thw):
+        window_index = []
+        cu_window_seqlens = [0]
+        window_index_id = 0
+        vit_merger_window_size = (
+            self.window_size // self.spatial_merge_size // self.patch_size
+        )
+
+        for grid_t, grid_h, grid_w in grid_thw.tolist():
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+
+            index = mx.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w
+            )
+
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+
+            # Pad with mx.pad (equivalent of torch F.pad)
+            index_padded = mx.pad(
+                index,
+                ((0, 0), (0, pad_h), (0, pad_w)),
+                mode="constant",
+                constant_values=-100,
+            )
+
+            index_padded = index_padded.reshape(
+                grid_t,
+                num_windows_h,
+                vit_merger_window_size,
+                num_windows_w,
+                vit_merger_window_size,
+            )
+
+            # mx.transpose replaces torch permute
+            index_padded = mx.transpose(index_padded, (0, 1, 3, 2, 4)).reshape(
+                grid_t,
+                num_windows_h * num_windows_w,
+                vit_merger_window_size,
+                vit_merger_window_size,
+            )
+
+            # Count the valid (non-padding) patches in each window
+            seqlens = mx.sum(index_padded != -100, axis=(2, 3)).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index = np.where(index_padded != -100)[
+                0
+            ].tolist()  # [i for i, x in enumerate(index_padded) if x != -100]
+            index_new = index_padded[index]
+
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = (
+                mx.cumsum(seqlens, axis=0) * self.spatial_merge_unit
+                + cu_window_seqlens[-1]
+            )
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += int(grid_t * llm_grid_h * llm_grid_w)
+
+        # mx.concatenate replaces torch.cat
+        window_index = mx.concatenate(window_index, axis=0)
+        cu_window_seqlens = mx.array(cu_window_seqlens)
+
+        return window_index, cu_window_seqlens
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        grid_thw: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+
+        hidden_states = self.patch_embed(hidden_states)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+
+        # Get indices of first occurrence of each unique value
+        seen = set()
+        idx = []
+        for i, x in enumerate(cu_window_seqlens):
+            if x not in seen:
+                seen.add(x)
+                idx.append(i)
+
+        idx = mx.array(idx, dtype=mx.int32)
+        cu_window_seqlens = cu_window_seqlens[idx]
+
+        seq_len, _ = hidden_states.shape
+        hidden_states = hidden_states.reshape(
+            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+        )
+        hidden_states = hidden_states[window_index, :, :]
+        hidden_states = hidden_states.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(
+            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+        )
+        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+
+        # Assuming grid_thw has shape (batch_size, 3)
+        batch_size = grid_thw.shape[0]
+
+        # Calculate cu_seqlens for each item in the batch
+        cu_seqlens = []
+        for i in range(batch_size):
+            seq_len = grid_thw[i, 1] * grid_thw[i, 2]
+            cu_seqlens.append(mx.repeat(seq_len, grid_thw[i, 0]))
+
+        # Concatenate the cu_seqlens for all items in the batch
+        cu_seqlens = mx.concatenate(cu_seqlens)
+
+        cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0)
+        cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)
+
+        encoder_states = (hidden_states,) if output_hidden_states else None
+
+        for layer_num, blk in enumerate(self.blocks):
+            if layer_num in self.fullatt_block_indexes:
+                cu_seqlens_now = cu_seqlens
+            else:
+                cu_seqlens_now = cu_window_seqlens
+
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb
+            )
+
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+
+        hidden_states = self.merger(hidden_states)
+        reverse_indices = mx.argsort(window_index, axis=0)
+        hidden_states = hidden_states[reverse_indices, :]
+        return hidden_states
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embed.proj.weight" in k:
+                # PyTorch Conv3d weight tensors have shape:
+                #   [out_channels, in_channels, kT, kH, kW]
+                # MLX Conv3d expects the weight to be of shape:
+                #   [out_channels, kT, kH, kW, in_channels]
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
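
One detail worth calling out from the hunk above: get_window_index returns cumulative window boundaries (cu_window_seqlens), and Attention.__call__ turns such boundaries into a block-diagonal additive mask so patches only attend within their own window (the designated full-attention blocks use per-frame boundaries instead). A standalone sketch of that mask construction, illustrative only and with made-up sizes:

import mlx.core as mx

# Cumulative boundaries for two windows: patches 0-3 and patches 4-5.
cu_seqlens = [0, 4, 6]
seq_length = 6

# Start from "everything disallowed" (large negative bias), then open up each
# [start:end, start:end] block, mirroring the loop in Attention.__call__.
mask = mx.full((1, seq_length, seq_length), mx.finfo(mx.float32).min, dtype=mx.float32)
for i in range(1, len(cu_seqlens)):
    start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    mask[..., start:end, start:end] = 0

print((mask[0] == 0).sum())  # 4*4 + 2*2 = 20 positions may attend to each other
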
--- /dev/null
+++ mlx_vlm/models/qwen2_vl/__init__.py
@@ -0,0 +1,2 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .qwen2_vl import LanguageModel, Model, VisionModel
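
These two re-export lines are what make the qwen2_vl classes importable directly from the subpackage. A minimal usage sketch, assuming the wheel is installed (the top-level package is mlx_vlm per top_level.txt); everything below is illustrative:

from mlx_vlm.models.qwen2_vl import Model, ModelConfig, TextConfig, VisionConfig

# VisionConfig has defaults for every field, so it can be built standalone.
print(VisionConfig().model_type)  # "qwen2_vl"
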
--- /dev/null
+++ mlx_vlm/models/qwen2_vl/config.py
@@ -0,0 +1,86 @@
+import inspect
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str = "qwen2_vl"
+    depth: int = 32
+    embed_dim: int = 1280
+    hidden_size: int = 1536
+    num_heads: int = 16
+    image_size: int = 384
+    patch_size: int = 14
+    vocab_size: int = 32000
+    mlp_ratio: float = 4.0
+    in_channels: int = 3
+    layer_norm_eps: float = 1e-6
+    spatial_patch_size: int = 14
+    spatial_merge_size: int = 2
+    temporal_patch_size: int = 2
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    num_key_value_heads: Optional[int] = 8
+    max_position_embeddings: Optional[int] = 40960
+    rope_theta: float = 1000000.0
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False
+    sliding_window: int = 32768
+    use_sliding_window: bool = False
+    use_cache: bool = True
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+        if self.rope_scaling:
+            required_keys = {"mrope_section", "type"}
+            if not all(key in self.rope_scaling for key in required_keys):
+                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+            if self.rope_scaling["type"] not in ["mrope", "default"]:
+                raise ValueError("rope_scaling type must be 'mrope' or 'default'")
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig
+    vision_config: VisionConfig
+    model_type: str
+    ignore_index: int = -100
+    image_token_id: int = 151655
+    video_token_id: int = 151656
+    vision_start_token_id: int = 151652
+    vision_feature_select_strategy: str = "default"
+    vision_feature_layer: int = -2
+    vocab_size: int = 32000
+    eos_token_id: Optional[List[int]] = None
+
+    @classmethod
+    def from_dict(cls, params):
+        # Copy text config parameters from root level
+        excluded_keys = {"vision_config"}
+        params["text_config"] = dict(
+            filter(lambda x: x[0] not in excluded_keys, params.items())
+        )
+
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
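
ModelConfig.from_dict above reflects how a Hugging Face style config.json for Qwen2-VL keeps the text-model hyperparameters at the root level: it copies the root keys into params["text_config"] and then keeps only the keys that ModelConfig itself declares. A small sketch of that behaviour, illustrative only, with made-up values and an import path assumed from the file listing; note that text_config and vision_config are left as plain dicts at this stage:

from mlx_vlm.models.qwen2_vl.config import ModelConfig  # path assumed from the file listing

params = {
    "model_type": "qwen2_vl",
    "hidden_size": 1536,
    "num_hidden_layers": 28,
    "intermediate_size": 8960,
    "num_attention_heads": 12,
    "rms_norm_eps": 1e-6,
    "vocab_size": 151936,
    "vision_config": {"model_type": "qwen2_vl"},
}

config = ModelConfig.from_dict(params)
print(config.model_type)      # "qwen2_vl"
print(config.image_token_id)  # 151655 (class default, not present in params)
print(sorted(config.text_config))  # root-level keys copied into the text config
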