fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,180 @@
1
+ from typing import Optional
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+
6
+ from ..base import InputEmbeddingsFeatures
7
+ from .config import ModelConfig
8
+ from .language import LanguageModel
9
+ from .vision import VisionModel
10
+
11
+
12
class Model(nn.Module):
    """Qwen2-VL model: a vision tower whose features are merged into a language model.

    Vision features produced by ``vision_tower`` replace the embeddings of the
    image (or video) placeholder tokens before the sequence is passed to
    ``language_model``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config, config)

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):
        """Embed ``input_ids`` and, if pixels are given, splice in vision features.

        Args:
            input_ids: Token IDs; the merge step indexes them as ``[batch, seq_len]``.
            pixel_values: Inputs for the vision tower, or ``None`` for text-only input.
            **kwargs: Optional ``image_grid_thw``, ``video_grid_thw`` and ``mask``.
                ``mask`` is forwarded to ``language_model.get_rope_index``.

        Returns:
            ``InputEmbeddingsFeatures`` whose ``inputs_embeds`` holds the merged embeddings.

        Side effects:
            Writes ``language_model._position_ids`` and ``_rope_deltas``. They are
            cleared for text-only input and precomputed when a grid is supplied.
        """
        image_grid_thw = kwargs.get("image_grid_thw", None)
        video_grid_thw = kwargs.get("video_grid_thw", None)
        mask = kwargs.get("mask", None)
        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw

        if pixel_values is None:
            # Reset position state for text-only generation
            self.language_model._position_ids = None
            self.language_model._rope_deltas = None
            return InputEmbeddingsFeatures(
                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
            )

        # Feed the vision tower in the dtype of its own patch-embedding weights.
        dtype = self.vision_tower.patch_embed.proj.weight.dtype
        pixel_values = pixel_values.astype(dtype)

        # Get the input embeddings from the language model
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        # Get the output hidden states from the vision model
        hidden_states = self.vision_tower(
            pixel_values, grid_thw, output_hidden_states=False
        )

        # Replace the embeddings at the image/video token positions
        final_inputs_embeds = self.merge_input_ids_with_image_features(
            self.config.image_token_id,
            self.config.video_token_id,
            hidden_states,
            inputs_embeds,
            input_ids,
        )

        # Pre-calculate position_ids for chunked prefill
        if image_grid_thw is not None or video_grid_thw is not None:
            position_ids, rope_deltas = self.language_model.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, mask
            )
            self.language_model._position_ids = position_ids
            self.language_model._rope_deltas = rope_deltas

        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)

    @staticmethod
    def merge_input_ids_with_image_features(
        image_token_id,
        video_token_id,
        image_features,
        inputs_embeds,
        input_ids,
    ):
        """Merge image features into input embeddings at image token positions.

        Positions equal to ``image_token_id`` are used. If the whole batch has
        none, positions equal to ``video_token_id`` are used instead. Features
        are consumed in order across batch items.

        Args:
            image_token_id: Token ID marking image placeholders.
            video_token_id: Token ID marking video placeholders (fallback).
            image_features: Vision features from the vision tower [num_features, hidden_dim]
            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
            input_ids: Input token IDs [batch_size, seq_len]

        Returns:
            Updated input embeddings with image features inserted

        Raises:
            ValueError: If a batch item has more placeholder tokens than remaining features.
        """

        # Positions of <image> tokens in input_ids
        image_positions = input_ids == image_token_id
        if mx.sum(image_positions) == 0:
            image_positions = input_ids == video_token_id

        batch_size, seq_len = input_ids.shape

        batch_outputs = []
        feature_start_idx = 0

        for batch_idx in range(batch_size):
            image_mask = image_positions[batch_idx]
            num_positions = mx.sum(image_mask).item()

            if num_positions > 0:
                # Extract features for this batch
                batch_features = image_features[
                    feature_start_idx : feature_start_idx + num_positions
                ]

                # Validate we have the right number of features
                if batch_features.shape[0] != num_positions:
                    raise ValueError(
                        f"Number of image token positions ({num_positions}) does not match "
                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
                    )

                # The k-th placeholder (0-based) takes feature k; non-placeholder
                # positions gather index 0 but are discarded by the where() below.
                cumsum = mx.cumsum(image_mask.astype(mx.int32))
                feature_indices = mx.where(image_mask, cumsum - 1, 0)

                gathered_features = batch_features[feature_indices]

                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
                batch_output = mx.where(
                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
                )

                feature_start_idx += num_positions
            else:
                # No image tokens in this batch item
                batch_output = inputs_embeds[batch_idx]

            batch_outputs.append(batch_output)

        return mx.stack(batch_outputs, axis=0)

    @property
    def layers(self):
        """The decoder layers of the underlying language model."""
        return self.language_model.model.layers

    def __call__(
        self,
        input_ids: mx.array,
        pixel_values: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
        **kwargs,
    ):
        """Embed inputs (merging vision features) and run the language model.

        Returns:
            Whatever ``language_model`` returns for these inputs.
        """
        # ``mask`` is bound to this method's own parameter, so it can never be
        # inside ``kwargs``. Forward it explicitly: get_input_embeddings reads
        # ``mask`` from its kwargs when precomputing RoPE position ids.
        input_embeddings_features = self.get_input_embeddings(
            input_ids, pixel_values, mask=mask, **kwargs
        )
        kwargs = {
            "pixel_values": pixel_values,
            **kwargs,
        }
        logits = self.language_model(
            input_ids,
            input_embeddings_features.inputs_embeds,
            mask=mask,
            cache=cache,
            **kwargs,
        )
        return logits

    def sanitize(self, weights):
        """Rename checkpoint keys to this module's attribute layout.

        - Keys without ``vision_tower`` have ``visual`` rewritten to ``vision_tower``.
        - Keys without ``language_model`` that contain ``model`` have it rewritten
          to ``language_model.model``.
        - Otherwise, such keys containing ``lm_head`` have it rewritten to
          ``language_model.lm_head``.

        ``str.replace`` rewrites every occurrence of the substring.
        """

        def transform_key(key):
            if "vision_tower" not in key:
                key = key.replace("visual", "vision_tower")
            if "language_model" not in key:
                if "model" in key:
                    key = key.replace("model", "language_model.model")
                elif "lm_head" in key:
                    key = key.replace("lm_head", "language_model.lm_head")
            return key

        return {transform_key(k): v for k, v in weights.items()}
@@ -0,0 +1,308 @@
1
+ from typing import Optional
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+
6
+ from .config import VisionConfig
7
+
8
+
9
def check_array_shape(arr):
    """Heuristically decide whether a conv weight is already in MLX (channels-last) layout.

    5-D (Conv3d) weights:
        PyTorch layout is ``[out, in, t, kH, kW]``; MLX layout is
        ``[out, t, kH, kW, in]``. A trailing dimension of 3 is taken to mean
        MLX layout (assumes RGB input, i.e. ``in_channels == 3``). Otherwise
        the weight counts as MLX layout when dimension 1 is at least as large
        as dimensions 2 and 3, and dimensions 2 and 3 are equal.
    4-D (Conv2d) weights:
        Read as ``[out, kH, kW, in]``. They count as MLX layout when ``out`` is
        at least as large as both kernel dims and the kernel is square.

    Any other rank returns ``False``. Previously a 4-D shape was accepted and
    then unpacked into five names, which raised ``ValueError``.

    Args:
        arr: Any object with a ``shape`` tuple.

    Returns:
        ``True`` if the weight appears to already be in MLX layout.
    """
    shape = arr.shape

    if len(shape) == 5:
        _, dim1, kH, KW, last = shape
        if last == 3:
            return True
        return bool((dim1 >= kH) and (dim1 >= KW) and (kH == KW))

    if len(shape) == 4:
        out_channels, kH, KW, _ = shape
        return bool((out_channels >= kH) and (out_channels >= KW) and (kH == KW))

    return False
26
+
27
+
28
def rotate_half(x):
    """Rotate the last axis by half: ``[a, b] -> [-b, a]``.

    The last axis is split at ``x.shape[-1] // 2``. The second half is negated
    and placed in front of the first half.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return mx.concatenate([-second, first], axis=-1)
33
+
34
+
35
def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
    """Apply rotary position embeddings derived from ``freqs`` to ``tensor``.

    The cos/sin tables of ``freqs`` get a new axis 1, are tiled twice along the
    last axis, and gain a leading axis. They are then broadcast against
    ``tensor``. The result is cast back to ``tensor``'s original dtype.
    """
    source_dtype = tensor.dtype

    def _expand(table):
        # unsqueeze(1) -> repeat(1, 1, 2) -> [None, ...]
        table = mx.expand_dims(table, axis=1)
        table = mx.tile(table, (1, 1, 2))
        return mx.expand_dims(table, axis=0)

    cos_table = _expand(mx.cos(freqs))
    sin_table = _expand(mx.sin(freqs))

    rotated = (tensor * cos_table) + (rotate_half(tensor) * sin_table)
    return rotated.astype(source_dtype)
51
+
52
+
53
class VisionRotaryEmbedding(nn.Module):
    """Table of rotary angles (position * inverse frequency) for vision tokens.

    Args:
        dim: Rotary dimension. Inverse frequencies are taken at ``0, 2, 4, ... < dim``.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta

    def __call__(self, seqlen: int) -> mx.array:
        """Return the ``[seqlen, len(inv_freq)]`` outer product of positions and inv_freq.

        ``seqlen`` may be a Python int or a scalar array (for example the result
        of ``mx.max``). Previously only the array form worked, because
        ``.tolist()`` was called unconditionally.
        """
        inv_freq = 1.0 / (
            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
        )
        # Unwrap scalar arrays; plain ints pass through unchanged.
        length = seqlen.tolist() if hasattr(seqlen, "tolist") else seqlen
        seq = mx.arange(length, dtype=inv_freq.dtype)
        freqs = mx.outer(seq, inv_freq)
        return freqs
66
+
67
+
68
class PatchEmbed(nn.Module):
    """Project flattened spatio-temporal patches with a strided Conv3d.

    Each input row is reshaped to ``(in_channels, temporal_patch_size,
    patch_size, patch_size)``. It is moved to channels-last and projected to
    ``embed_dim`` by a bias-free convolution whose stride equals its kernel.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_channels, embed_dim, kernel_size=kernel, stride=kernel, bias=False
        )

    def __call__(self, hidden_states: mx.array) -> mx.array:
        patch_shape = (
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        patches = hidden_states.reshape(-1, *patch_shape)
        # Channel axis (1) goes last, matching the Conv3d input layout used here.
        patches = patches.moveaxis(1, 4)
        embedded = self.proj(patches)
        return embedded.reshape(-1, self.embed_dim)
103
+
104
+
105
class PatchMerger(nn.Module):
    """LayerNorm, then fold ``spatial_merge_size**2`` tokens together, then a 2-layer MLP to ``dim``."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        merged_width = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_width
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
        self.mlp = [
            nn.Linear(merged_width, merged_width),
            nn.GELU(),
            nn.Linear(merged_width, dim),
        ]

    def __call__(self, x: mx.array) -> mx.array:
        out = self.ln_q(x).reshape(-1, self.hidden_size)
        for stage in self.mlp:
            out = stage(out)
        return out
121
+
122
+
123
class Attention(nn.Module):
    """Multi-head self-attention over tokens packed along axis 0.

    ``cu_seqlens`` holds cumulative segment boundaries. A boolean mask lets
    each token attend only to tokens inside its own segment.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def __call__(
        self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
    ) -> mx.array:
        n_tokens = x.shape[0]
        fused = self.qkv(x).reshape(n_tokens, 3, self.num_heads, -1)
        q, k, v = mx.split(fused.transpose(1, 0, 2, 3), 3)

        q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
        k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]

        # Block-diagonal mask: True where query and key lie in the same segment.
        allowed = mx.zeros((n_tokens, n_tokens), dtype=mx.bool_)
        bounds = [int(cu_seqlens[i]) for i in range(len(cu_seqlens))]
        for lo, hi in zip(bounds[:-1], bounds[1:]):
            allowed[lo:hi, lo:hi] = True

        q, k, v = (t.transpose(0, 2, 1, 3) for t in (q, k, v))

        attended = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=allowed
        )
        attended = attended.transpose(0, 2, 1, 3).reshape(n_tokens, -1)
        return self.proj(attended)
160
+
161
+
162
class MLP(nn.Module):
    """Feed-forward block: fc1, then fast-approximate GELU, then fc2."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.activation_fn = nn.GELU(approx="fast")
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def __call__(self, x: mx.array) -> mx.array:
        return self.fc2(self.activation_fn(self.fc1(x)))
173
+
174
+
175
class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm transformer block: attention then MLP, each added residually."""

    def __init__(self, config: VisionConfig) -> None:
        super().__init__()
        width = config.embed_dim
        self.norm1 = nn.LayerNorm(width, eps=1e-6)
        self.norm2 = nn.LayerNorm(width, eps=1e-6)
        self.attn = Attention(dim=width, num_heads=config.num_heads)
        self.mlp = MLP(dim=width, hidden_dim=int(width * config.mlp_ratio))

    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
193
+
194
+
195
class VisionModel(nn.Module):
    """Qwen2-VL vision encoder: patch embedding, transformer blocks, patch merger.

    Raises:
        ValueError: If ``config.model_type`` is not ``"qwen2_vl"``.
    """

    def __init__(self, config: VisionConfig) -> None:
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        if self.model_type != "qwen2_vl":
            raise ValueError(f"Unsupported model type: {self.model_type}")
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )

        head_dim = config.embed_dim // config.num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
        # Pass the configured merge size. Otherwise PatchMerger falls back to its
        # default of 2, which disagrees with rot_pos_emb's window ordering when
        # config.spatial_merge_size != 2.
        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
            spatial_merge_size=self.spatial_merge_size,
        )

    def rot_pos_emb(self, grid_thw):
        """Return per-patch rotary embeddings in merge-window order.

        For each ``(t, h, w)`` row of ``grid_thw``:
        - Build the (row, col) coordinates of every patch.
        - Reorder them so each ``spatial_merge_size x spatial_merge_size``
          window is contiguous.
        - Repeat the result ``t`` times.

        Row and column frequencies are then looked up and concatenated per patch.
        """
        m = self.spatial_merge_size
        pos_ids = []

        for t, h, w in grid_thw:
            t, h, w = int(t), int(h), int(w)  # Ensure plain Python integers

            hpos_ids = mx.repeat(mx.expand_dims(mx.arange(h), 1), w, axis=1)
            hpos_ids = hpos_ids.reshape(h // m, m, w // m, m)
            hpos_ids = mx.transpose(hpos_ids, (0, 2, 1, 3)).flatten()

            wpos_ids = mx.repeat(mx.expand_dims(mx.arange(w), 0), h, axis=0)
            wpos_ids = wpos_ids.reshape(h // m, m, w // m, m)
            wpos_ids = mx.transpose(wpos_ids, (0, 2, 1, 3)).flatten()

            stacked_pos_ids = mx.stack([hpos_ids, wpos_ids], axis=-1)
            pos_ids.append(mx.tile(stacked_pos_ids, (t, 1)))

        pos_ids = mx.concatenate(pos_ids, axis=0)
        max_grid_size = mx.max(grid_thw[:, 1:])
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)

        rotary_pos_emb_full = rotary_pos_emb_full[pos_ids]

        return rotary_pos_emb_full.reshape(pos_ids.shape[0], -1)

    def __call__(
        self,
        hidden_states: mx.array,
        grid_thw: mx.array,
        output_hidden_states: Optional[bool] = None,
    ) -> mx.array:
        """Encode patches and return the merged vision features.

        Args:
            hidden_states: Flattened patches for ``PatchEmbed``.
            grid_thw: One ``(t, h, w)`` row per image/video (assumed shape
                ``(N, 3)``; TODO confirm against the processor).
            output_hidden_states: Accepted for interface compatibility. Intermediate
                states are not returned.

        Returns:
            The ``PatchMerger`` output for the final hidden states.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        batch_size = grid_thw.shape[0]

        # One segment of h*w tokens per temporal slice of each item.
        cu_seqlens = []
        for i in range(batch_size):
            seq_len = grid_thw[i, 1] * grid_thw[i, 2]
            cu_seqlens.append(mx.repeat(seq_len, grid_thw[i, 0]))

        cu_seqlens = mx.concatenate(cu_seqlens)

        # Cumulative boundaries with a leading 0.
        cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0)
        cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            )

        return self.merger(hidden_states)

    def sanitize(self, weights):
        """Drop ``position_ids`` entries and convert the patch-embed weight to MLX layout."""
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue
            elif "patch_embed.proj.weight" in k:
                # PyTorch Conv3d weight: [out_channels, in_channels, t, kH, kW]
                # MLX Conv3d weight:     [out_channels, t, kH, kW, in_channels]
                if check_array_shape(v):
                    sanitized_weights[k] = v
                else:
                    sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
            else:
                sanitized_weights[k] = v

        return sanitized_weights
@@ -0,0 +1,29 @@
1
+ from .audio import AudioModel
2
+ from .config import (
3
+ AudioConfig,
4
+ Code2WavConfig,
5
+ CodePredictorConfig,
6
+ ModelConfig,
7
+ TalkerConfig,
8
+ TextConfig,
9
+ ThinkerConfig,
10
+ VisionConfig,
11
+ )
12
+ from .language import LanguageModel
13
+ from .qwen3_omni_moe import Model
14
+ from .vision import VisionModel
15
+
16
# Names exported by ``from <package> import *``: the top-level Model, its
# sub-models, and the configuration classes imported above.
__all__ = [
    "Model",
    "ModelConfig",
    "LanguageModel",
    "VisionModel",
    "AudioModel",
    "TextConfig",
    "VisionConfig",
    "AudioConfig",
    "ThinkerConfig",
    "TalkerConfig",
    "CodePredictorConfig",
    "Code2WavConfig",
]