fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/pixtral/pixtral.py
@@ -0,0 +1,208 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import InputEmbeddingsFeatures
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ class LlavaMultiModalProjector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.linear_1 = nn.Linear(
+             config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
+         )
+         self.gelu = nn.GELU()
+         self.linear_2 = nn.Linear(
+             config.text_config.hidden_size, config.text_config.hidden_size, bias=True
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.linear_1(x)
+         x = self.gelu(x)
+         x = self.linear_2(x)
+         return x
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vision_tower = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.multi_modal_projector = LlavaMultiModalProjector(config)
+         self.vision_feature_layer = config.vision_feature_layer
+         self.vision_feature_select_strategy = config.vision_feature_select_strategy
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         # Get the input embeddings from the language model
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         # Get the output hidden states from the vision model
+         if isinstance(pixel_values, list):
+             pixel_values = mx.concatenate(
+                 [mx.array(pv)[None, ...] for pv in pixel_values], axis=0
+             )
+         if pixel_values.ndim == 3:
+             pixel_values = pixel_values[None, ...]
+
+         # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding
+         # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21
+         # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85
+         *_, hidden_states = self.vision_tower(
+             pixel_values.transpose(0, 2, 3, 1),
+             output_hidden_states=True,
+         )
+         # Select the hidden states from the desired layer
+         selected_image_feature = hidden_states[self.vision_feature_layer]
+
+         # Pass image features through the multi-modal projector
+         image_features = self.multi_modal_projector(selected_image_feature)
+
+         # Insert special image tokens in the input_ids
+         final_inputs_embeds = self.merge_input_ids_with_image_features(
+             self.config.image_token_index, image_features, inputs_embeds, input_ids
+         )
+         return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+     @staticmethod
+     def merge_input_ids_with_image_features(
+         image_token_index, image_features, inputs_embeds, input_ids
+     ):
+         """Merge image features into input embeddings at image token positions.
+
+         Args:
+             image_token_index: Token ID for image placeholder
+             image_features: Vision features from the projector [1, num_features, hidden_dim]
+             inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
+             input_ids: Input token IDs [batch_size, seq_len]
+
+         Returns:
+             Updated input embeddings with image features inserted
+         """
+         # Remove the extra batch dimension from image_features if present
+         if image_features.ndim == 3 and image_features.shape[0] == 1:
+             image_features = image_features.squeeze(0)  # [num_features, hidden_dim]
+
+         # Positions of <image> tokens in input_ids
+         image_positions = input_ids == image_token_index
+
+         # Get dimensions
+         batch_size, seq_len = input_ids.shape
+
+         # Process each batch item
+         batch_outputs = []
+         feature_start_idx = 0
+
+         for batch_idx in range(batch_size):
+             # Get mask for this batch
+             image_mask = image_positions[batch_idx]
+             num_positions = mx.sum(image_mask).item()
+
+             if num_positions > 0:
+                 # Extract features for this batch
+                 batch_features = image_features[
+                     feature_start_idx : feature_start_idx + num_positions
+                 ]
+
+                 # Validate we have the right number of features
+                 if batch_features.shape[0] != num_positions:
+                     raise ValueError(
+                         f"Number of image token positions ({num_positions}) does not match "
+                         f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
+                     )
+
+                 # Create indices for gathering
+                 cumsum = mx.cumsum(image_mask.astype(mx.int32))
+                 feature_indices = mx.where(image_mask, cumsum - 1, 0)
+
+                 # Gather features
+                 gathered_features = batch_features[feature_indices]
+
+                 # Combine with original embeddings
+                 image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
+                 batch_output = mx.where(
+                     image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
+                 )
+
+                 feature_start_idx += num_positions
+             else:
+                 # No image tokens in this batch item
+                 batch_output = inputs_embeds[batch_idx]
+
+             batch_outputs.append(batch_output)
+
+         # Stack all batch outputs
+         return mx.stack(batch_outputs, axis=0)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: mx.array,
+         cache=None,
+         **kwargs,
+     ):
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, **kwargs
+         )
+         logits = self.language_model(
+             input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+         )
+         return logits
+
+     def sanitize(self, weights):
+         def transform_key(key):
+             if "vision_tower" in key and "vision_model" not in key:
+                 if "transformer" in key:
+                     key = key.replace("vision_tower", "vision_tower.vision_model")
+                 if "patch_conv" in key:
+                     key = key.replace("vision_tower", "vision_tower.vision_model")
+                 if "ln_pre" in key:
+                     key = key.replace("vision_tower", "vision_tower.vision_model")
+
+             elif "vision_encoder" in key and "vision_tower" not in key:
+                 if "transformer" in key:
+                     key = key.replace(
+                         "model.vision_encoder", "vision_tower.vision_model"
+                     )
+                 if "patch_conv" in key:
+                     key = key.replace(
+                         "model.vision_encoder", "vision_tower.vision_model"
+                     )
+                 if "ln_pre" in key:
+                     key = key.replace(
+                         "model.vision_encoder", "vision_tower.vision_model"
+                     )
+
+             elif "model.language_model" in key and "language_model.model" not in key:
+                 key = key.replace("model.language_model", "language_model.model")
+
+             elif "lm_head" in key and "language_model" not in key:
+                 key = key.replace("lm_head", "language_model.lm_head")
+
+             elif "model.vision_projection" in key:
+                 key = key.replace("model.vision_projection", "multi_modal_projector")
+
+             return key
+
+         return {transform_key(k): v for k, v in weights.items()}
mlx_vlm/models/pixtral/vision.py
@@ -0,0 +1,293 @@
+ from typing import List, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ def position_ids_in_meshgrid(patch_embeds_list, max_width):
+     positions = []
+     for patch in patch_embeds_list:
+         height, width = patch.shape[0], patch.shape[1]
+         h_grid, v_grid = mx.meshgrid(mx.arange(height), mx.arange(width), indexing="ij")
+         h_grid = h_grid.reshape(-1, 1)
+         v_grid = v_grid.reshape(-1, 1)
+         ids = h_grid * max_width + v_grid
+         positions.append(ids.flatten())
+     return mx.concatenate(positions)
+
+
+ def generate_block_attention_mask(patch_embeds_list, tensor):
+     seq_len = tensor.shape[1]
+     d_min = -1e9  # Using a large negative value as MLX doesn't have finfo
+
+     causal_mask = mx.full((seq_len, seq_len), vals=d_min)
+
+     block_end_idx = mx.cumsum(mx.array(patch_embeds_list))
+     block_start_idx = mx.concatenate([mx.array([0]), mx.array(patch_embeds_list[:-1])])
+     block_start_idx = mx.cumsum(block_start_idx)
+
+     for start, end in zip(block_start_idx, block_end_idx):
+         start, end = int(start), int(end)  # Convert to integers for indexing
+         causal_mask[start:end, start:end] = 0
+
+     causal_mask = mx.broadcast_to(
+         causal_mask[None, None, :, :], (tensor.shape[0], 1, seq_len, seq_len)
+     )
+     return causal_mask.astype(tensor.dtype)
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate((-x2, x1), axis=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+     cos = mx.expand_dims(cos, axis=unsqueeze_dim)
+     sin = mx.expand_dims(sin, axis=unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dims: int,
+         num_heads: int,
+         query_input_dims: Optional[int] = None,
+         key_input_dims: Optional[int] = None,
+         value_input_dims: Optional[int] = None,
+         value_dims: Optional[int] = None,
+         value_output_dims: Optional[int] = None,
+         bias: bool = False,
+     ):
+         super().__init__()
+
+         if (dims % num_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({dims} % {num_heads}) != 0"
+             )
+
+         query_input_dims = query_input_dims or dims
+         key_input_dims = key_input_dims or dims
+         value_input_dims = value_input_dims or key_input_dims
+         value_dims = value_dims or dims
+         value_output_dims = value_output_dims or dims
+
+         self.embed_dim = dims
+         self.num_heads = num_heads
+         self.head_dim = self.embed_dim // self.num_heads
+
+         self.scale = self.head_dim**-0.5
+
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+         self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+
+     def __call__(self, queries, keys, values, position_embeddings, mask=None):
+         queries = self.q_proj(queries)
+         keys = self.k_proj(keys)
+         values = self.v_proj(values)
+
+         num_heads = self.num_heads
+         B, L, D = queries.shape
+         _, S, _ = keys.shape
+         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+         cos, sin = position_embeddings
+         queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, unsqueeze_dim=0)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         dim = config.hidden_size
+         hidden_dim = config.intermediate_size
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.attention = Attention(
+             config.hidden_size, config.num_attention_heads, bias=True
+         )
+         self.attention_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+         self.feed_forward = MLP(config)
+         self.ffn_norm = nn.RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         x: mx.array,
+         position_embeddings: mx.array,
+         mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         y = self.attention_norm(x)
+         y = self.attention(y, y, y, position_embeddings, mask)
+         x = x + y
+         y = self.ffn_norm(x)
+         y = self.feed_forward(y)
+         return x + y
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+
+ class PixtralRotaryEmbedding:
+     def __init__(self, config):
+         self.dim = config.head_dim
+         self.base = config.rope_theta
+         max_patches_per_side = config.image_size // config.patch_size
+         freqs = 1.0 / (
+             self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+         )
+
+         h = mx.arange(max_patches_per_side)
+         w = mx.arange(max_patches_per_side)
+
+         freqs_h = mx.outer(h, freqs[::2]).astype(mx.float32)
+         freqs_w = mx.outer(w, freqs[1::2]).astype(mx.float32)
+         inv_freq = mx.concatenate(
+             [
+                 mx.tile(freqs_h[:, None, :], (1, max_patches_per_side, 1)),
+                 mx.tile(freqs_w[None, :, :], (max_patches_per_side, 1, 1)),
+             ],
+             axis=-1,
+         ).reshape(-1, self.dim // 2)
+
+         self.inv_freq = mx.concatenate((inv_freq, inv_freq), axis=-1)
+
+     def __call__(self, x, position_ids):
+         freqs = self.inv_freq[position_ids]
+         emb = freqs
+         cos = mx.cos(emb)
+         sin = mx.sin(emb)
+         return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+ class PixtralVisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.patch_conv = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=config.hidden_size,
+             kernel_size=config.patch_size,
+             stride=config.patch_size,
+             bias=False,
+         )
+         self.ln_pre = nn.RMSNorm(config.hidden_size)
+         self.transformer = Encoder(config)
+         self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+
+     def __call__(
+         self,
+         x: List[mx.array],
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+
+         if x.dtype != self.patch_conv.weight.dtype:
+             x = x.astype(self.patch_conv.weight.dtype)
+
+         patch_embeds_list = self.patch_conv(x)
+         patch_embeds = patch_embeds_list.reshape(1, -1, patch_embeds_list.shape[-1])
+
+         patch_embeds = self.ln_pre(patch_embeds)
+
+         position_ids = position_ids_in_meshgrid(
+             patch_embeds_list,
+             max_width=self.config.image_size // self.config.patch_size,
+         )
+
+         position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
+
+         mask = generate_block_attention_mask(
+             [p.shape[1] * p.shape[0] for p in patch_embeds_list], patch_embeds
+         )
+
+         encoder_states = (patch_embeds,) if output_hidden_states else None
+
+         for l in self.transformer.layers:
+             patch_embeds = l(
+                 patch_embeds, mask=mask, position_embeddings=position_embedding
+             )
+             if output_hidden_states:
+                 encoder_states = encoder_states + (patch_embeds,)
+
+         return patch_embeds, encoder_states
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+
+         self.model_type = config.model_type
+         if self.model_type not in ["clip_vision_model", "pixtral"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.vision_model = PixtralVisionModel(config)
+
+     def __call__(
+         self, x: List[mx.array], output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         return self.vision_model(x, output_hidden_states)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             elif "patch_conv.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
mlx_vlm/models/qwen2_5_vl/__init__.py
@@ -0,0 +1,2 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .qwen2_5_vl import LanguageModel, Model, VisionModel
mlx_vlm/models/qwen2_5_vl/config.py
@@ -0,0 +1,90 @@
+ import inspect
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional, Union
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str = "qwen2_5_vl"
+     depth: int = 32
+     hidden_size: int = 1280
+     intermediate_size: int = 3420
+     out_hidden_size: int = 1536
+     num_heads: int = 16
+     image_size: int = 384
+     patch_size: int = 14
+     vocab_size: int = 32000
+     mlp_ratio: float = 4.0
+     in_channels: int = 3
+     layer_norm_eps: float = 1e-6
+     spatial_patch_size: int = 14
+     spatial_merge_size: int = 2
+     tokens_per_second: int = 2
+     temporal_patch_size: int = 2
+     window_size: int = 112
+     patch_size: int = 14
+     fullatt_block_indexes: list[int] = field(default_factory=lambda: [7, 15, 23, 31])
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int
+     num_hidden_layers: int
+     intermediate_size: int
+     num_attention_heads: int
+     rms_norm_eps: float
+     vocab_size: int
+     num_key_value_heads: Optional[int] = None
+     max_position_embeddings: Optional[int] = 128000
+     rope_theta: float = 1000000.0
+     rope_traditional: bool = False
+     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+     tie_word_embeddings: bool = True
+
+     def __post_init__(self):
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+         if self.rope_scaling:
+             required_keys = {"mrope_section", "type"}
+             if not all(key in self.rope_scaling for key in required_keys):
+                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+             if not self.rope_scaling["type"] in ["mrope", "default"]:
+                 raise ValueError(f"rope_scaling type must be 'mrope' or 'default'")
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str
+     ignore_index: int = -100
+     image_token_id: int = 151655
+     video_token_id: int = 151656
+     vision_start_token_id: int = 151652
+     vision_end_token_id: int = 151653
+     vision_token_id: int = 151654
+     vision_feature_select_strategy: str = "default"
+     vision_feature_layer: int = -2
+     vocab_size: int = 32000
+     eos_token_id: Optional[List[int]] = None
+
+     @classmethod
+     def from_dict(cls, params):
+         # Copy text config parameters from root level
+         excluded_keys = {"vision_config"}
+         params["text_config"] = dict(
+             filter(lambda x: x[0] not in excluded_keys, params.items())
+         )
+
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
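
The wheel installs the mlx_vlm top-level package along with its generation and prompt utilities (see mlx_vlm/__init__.py, mlx_vlm/generate.py, and mlx_vlm/prompt_utils.py in the file list above). Assuming this repackaged build preserves the upstream mlx-vlm Python API, a minimal usage sketch for the Pixtral/Qwen2.5-VL model code shown in the hunks above might look like the following; the checkpoint path and image path are placeholders, and the exact generate() signature may differ between releases.

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Placeholder: any MLX-converted VLM checkpoint (e.g. a Pixtral or Qwen2.5-VL
# conversion) should work if the upstream loading conventions are unchanged.
model_path = "mlx-community/pixtral-12b-4bit"

model, processor = load(model_path)
config = load_config(model_path)

images = ["path/to/image.jpg"]  # placeholder image path
prompt = "Describe this image."

# Format the prompt with the processor's chat template before generation.
formatted_prompt = apply_chat_template(processor, config, prompt, num_images=len(images))

output = generate(model, processor, formatted_prompt, images, verbose=False)
print(output)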