fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
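Note on the listing above: the dist-info entries (RECORD, entry_points.txt, top_level.txt) together with the module paths indicate that this distribution, although published as fount-vlm-nell-02, installs a single top-level import package named mlx_vlm. A minimal sketch (not part of the package, Python 3.10+ assumed, values shown only as what the metadata above suggests) of checking that mapping after installing the wheel:

import importlib.metadata as md

# The distribution is looked up by its published project name...
print(md.version("fount-vlm-nell-02"))              # expected: 0.3.11
# ...while the import package it provides is mlx_vlm.
print(md.packages_distributions().get("mlx_vlm"))   # expected to include 'fount-vlm-nell-02'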
mlx_vlm/models/llama4/llama4.py
@@ -0,0 +1,146 @@
+ from typing import List, Optional, Union
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import InputEmbeddingsFeatures
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import Llama4MultiModalProjector, VisionModel
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vision_model = VisionModel(config.vision_config)
+         self.multi_modal_projector = Llama4MultiModalProjector(config)
+         self.language_model = LanguageModel(config.text_config)
+         self.vocab_size = config.text_config.vocab_size
+
+     def set_input_embeddings(self, value):
+         self.language_model.set_input_embeddings(value)
+
+     def get_output_embeddings(self):
+         return self.language_model.get_output_embeddings()
+
+     def set_output_embeddings(self, new_embeddings):
+         self.language_model.set_output_embeddings(new_embeddings)
+
+     def set_decoder(self, decoder):
+         self.language_model.set_decoder(decoder)
+
+     def get_decoder(self):
+         return self.language_model.get_decoder()
+
+     def get_image_features(
+         self,
+         pixel_values: mx.array,
+         vision_feature_layer: Union[int, List[int]],
+         vision_feature_select_strategy: str,
+         **kwargs,
+     ):
+         if vision_feature_select_strategy not in ["default", "full"]:
+             raise ValueError(
+                 f"Unexpected select feature strategy: {vision_feature_select_strategy}"
+             )
+         kwargs = {k: v for k, v in kwargs.items() if v is not None}
+         hidden_state = self.vision_model(
+             pixel_values, output_hidden_states=False, **kwargs
+         )
+         return hidden_state
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         # Get the input embeddings from the language model
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         image_features = self.get_image_features(
+             pixel_values=pixel_values,
+             vision_feature_layer=kwargs.get("vision_feature_layer", -1),
+             vision_feature_select_strategy=kwargs.get(
+                 "vision_feature_select_strategy", "default"
+             ),
+         )
+
+         vision_flat = image_features.reshape(-1, image_features.shape[-1])
+         projected_vision_flat = self.multi_modal_projector(vision_flat)
+
+         # Insert special image tokens in the input_ids
+         final_inputs_embeds = self._prepare_inputs_for_multimodal(
+             projected_vision_flat, inputs_embeds, input_ids
+         )
+         return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+     def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
+         image_token_index = self.config.image_token_index
+
+         # Find positions of <image> tokens
+         image_mask = input_ids == image_token_index
+
+         batch_size, seq_len = input_ids.shape
+
+         # Process each batch item
+         batch_outputs = []
+         feature_start_idx = 0
+
+         for batch_idx in range(batch_size):
+             batch_mask = image_mask[batch_idx]
+             num_positions = mx.sum(batch_mask).item()
+
+             if num_positions > 0:
+                 batch_features = image_features[
+                     feature_start_idx : feature_start_idx + num_positions
+                 ]
+
+                 # Create indices for gathering
+                 cumsum = mx.cumsum(batch_mask.astype(mx.int32))
+                 feature_indices = mx.where(batch_mask, cumsum - 1, 0)
+
+                 # Gather features
+                 gathered_features = batch_features[feature_indices]
+
+                 # Combine with original embeddings
+                 batch_mask_expanded = mx.expand_dims(batch_mask, axis=-1)
+                 batch_output = mx.where(
+                     batch_mask_expanded, gathered_features, inputs_embeds[batch_idx]
+                 )
+
+                 feature_start_idx += num_positions
+             else:
+                 batch_output = inputs_embeds[batch_idx]
+
+             batch_outputs.append(batch_output)
+
+         return mx.stack(batch_outputs, axis=0)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         cache=None,
+         **kwargs,
+     ):
+
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, **kwargs
+         )
+         logits = self.language_model(
+             inputs=input_ids,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+             cache=cache,
+         )
+         return logits
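The _prepare_inputs_for_multimodal method above splices the projected vision features into the token embedding sequence: it builds a boolean mask of image-token positions, turns a cumulative sum of that mask into per-position feature indices, gathers the features, and selects between image features and text embeddings with mx.where. A minimal, self-contained sketch of that indexing trick on toy arrays (the token id and values are made up for illustration; this is not code from the package):

import mlx.core as mx

IMAGE_TOKEN = 99                                  # hypothetical image token id
ids = mx.array([7, 99, 99, 5])                    # one sequence with two image slots
embeds = mx.zeros((4, 3))                         # stand-in text embeddings (seq_len=4, dim=3)
feats = mx.array([[1.0, 1, 1], [2.0, 2, 2]])      # two projected image features

mask = ids == IMAGE_TOKEN                         # [False, True, True, False]
cumsum = mx.cumsum(mask.astype(mx.int32))         # [0, 1, 2, 2]
gather_idx = mx.where(mask, cumsum - 1, 0)        # [0, 0, 1, 0]
merged = mx.where(mask[:, None], feats[gather_idx], embeds)
print(merged)                                     # rows 1 and 2 now hold feats[0] and feats[1]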
mlx_vlm/models/llama4/vision.py
@@ -0,0 +1,526 @@
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import pixel_shuffle
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class Llama4MultiModalProjector(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.linear_1 = nn.Linear(
+             config.vision_config.vision_output_dim,
+             config.text_config.hidden_size,
+             bias=False,
+         )
+
+     def __call__(self, image_features):
+         hidden_states = self.linear_1(image_features)
+         return hidden_states
+
+
+ class Llama4VisionPixelShuffleMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
+         self.inner_dim = int(
+             config.projector_input_dim // (self.pixel_shuffle_ratio**2)
+         )
+         self.output_dim = config.projector_output_dim
+         self.mlp = Llama4VisionMLP(config, bias=False, is_projector=True)
+
+     def __call__(self, encoded_patches: mx.array) -> mx.array:
+         encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
+         return self.mlp(encoded_patches)
+
+
+ # TODO there is a different RoPE for vision encoder, defined as below
+ def reshape_for_broadcast(freqs_ci: mx.array, query: mx.array):
+     ndim = query.ndim
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)]
+     return freqs_ci.reshape(*shape)
+
+
+ def view_as_complex(x):
+     """
+     Convert a tensor with shape (..., 2) to a complex tensor with shape (...).
+
+     Args:
+         x: A real tensor with last dimension of size 2.
+
+     Returns:
+         A complex tensor with size one less than the input.
+     """
+     # Ensure the last dimension is size 2
+     assert x.shape[-1] == 2, f"Last dimension must be 2, got {x.shape[-1]}"
+
+     # Get real and imaginary parts
+     real, imag = x[..., 0], x[..., 1]
+
+     # Create complex tensor
+     return real + 1j * imag
+
+
+ def view_as_real(x):
+     """
+     Convert a complex tensor with shape (...) to a real tensor with shape (..., 2).
+
+     Args:
+         x: A complex tensor.
+
+     Returns:
+         A real tensor with an extra dimension of size 2.
+     """
+     # Get real and imaginary parts
+     real = mx.real(x)
+     imag = mx.imag(x)
+
+     # Combine into a tensor with last dimension 2
+     return mx.stack([real, imag], axis=-1)
+
+
+ def vision_apply_rotary_emb(
+     query: mx.array,
+     key: mx.array,
+     freqs_ci: mx.array,
+ ) -> Tuple[mx.array, mx.array]:
+
+     query_ = view_as_complex(query.astype(mx.float32).reshape(*query.shape[:-1], -1, 2))
+     key_ = view_as_complex(key.astype(mx.float32).reshape(*key.shape[:-1], -1, 2))
+     freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_)
+     query_out = view_as_real(query_ * freqs_ci).flatten(3)
+     key_out = view_as_real(key_ * freqs_ci).flatten(3)
+     return query_out.astype(query.dtype), key_out.astype(key.dtype)
+
+
+ class Llama4VisionAttention(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = config.hidden_size // config.num_attention_heads
+         self.num_key_value_groups = 1
+         self.scale = self.head_dim**-0.5
+
+         self.q_proj = nn.Linear(
+             self.embed_dim, self.num_heads * self.head_dim, bias=True
+         )
+         self.k_proj = nn.Linear(
+             self.embed_dim, self.num_heads * self.head_dim, bias=True
+         )
+         self.v_proj = nn.Linear(
+             self.embed_dim, self.num_heads * self.head_dim, bias=True
+         )
+         self.o_proj = nn.Linear(
+             self.num_heads * self.head_dim, self.embed_dim, bias=True
+         )
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         freqs_ci: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[mx.array] = None,
+     ):
+         B, L, D = hidden_states.shape
+
+         query_states = self.q_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+         key_states = self.k_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+         value_states = self.v_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+
+         query_states, key_states = vision_apply_rotary_emb(
+             query_states, key_states, freqs_ci=freqs_ci
+         )
+
+         query_states = query_states.transpose(0, 2, 1, 3)
+         key_states = key_states.transpose(0, 2, 1, 3)
+         value_states = value_states.transpose(0, 2, 1, 3)
+
+         attn_output = mx.fast.scaled_dot_product_attention(
+             query_states, key_states, value_states, scale=self.scale
+         )
+
+         attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         attn_output = self.o_proj(attn_output)
+         return attn_output
+
+
+ class Llama4VisionMLP(nn.Module):
+     def __init__(self, config, bias=True, is_projector=False):
+         super().__init__()
+         self.config = config
+         self.activation_fn = nn.GELU(approx="fast")  # ACT2FN[config.hidden_act]
+         self.is_projector = is_projector
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         # Determine dimensions for first linear layer based on whether this is a projector
+         fc1_input_dim = self.intermediate_size if is_projector else self.hidden_size
+         fc1_output_dim = (
+             config.projector_input_dim if is_projector else self.intermediate_size
+         )
+
+         self.fc1 = nn.Linear(fc1_input_dim, fc1_output_dim, bias=bias)
+
+         # Determine dimensions for second linear layer
+         fc2_input_dim = (
+             config.projector_output_dim if is_projector else self.intermediate_size
+         )
+         fc2_output_dim = (
+             config.projector_output_dim if is_projector else self.hidden_size
+         )
+
+         self.fc2 = nn.Linear(fc2_input_dim, fc2_output_dim, bias=bias)
+
+         self.is_projector = is_projector
+
+     def __call__(self, hidden_states: mx.array) -> mx.array:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+
+         if self.is_projector:
+             return self.activation_fn(self.fc2(hidden_states))
+
+         return self.fc2(hidden_states)
+
+
+ class Llama4VisionEncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = Llama4VisionAttention(config)
+         self.mlp = Llama4VisionMLP(config)
+
+         self.input_layernorm = nn.LayerNorm(config.hidden_size)
+         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
+
+     def __call__(
+         self,
+         hidden_state: mx.array,
+         freqs_ci: mx.array,
+         mask: Optional[mx.array] = None,
+     ):
+         # Self Attention
+         residual = hidden_state
+
+         hidden_state = self.input_layernorm(hidden_state)
+
+         hidden_state = self.self_attn(
+             hidden_state,
+             freqs_ci=freqs_ci,
+             mask=mask,
+         )
+         hidden_state = residual + hidden_state
+
+         # Feed forward
+         residual = hidden_state
+         hidden_state = self.post_attention_layernorm(hidden_state)
+         hidden_state = self.mlp(hidden_state)
+         hidden_state = residual + hidden_state
+         return hidden_state
+
+
+ class Llama4VisionEncoder(nn.Module):
+     """
+     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+     [`Llama4VisionEncoderLayer`].
+
+     Args:
+         config: VisionConfig
+     """
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.layers = [
+             Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
+         ]
+         self.config = config
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         freqs_ci: mx.array,  # TODO move this to an attribute instead of keeping it around
+         mask: Optional[mx.array] = None,
+     ):
+
+         for i, encoder_layer in enumerate(self.layers):
+             hidden_states = encoder_layer(
+                 hidden_state=hidden_states,
+                 mask=mask,
+                 freqs_ci=freqs_ci,
+             )
+
+         return hidden_states
+
+
+ class Llama4UnfoldConvolution(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         kernel_size = config.patch_size
+         if isinstance(kernel_size, int):
+             kernel_size = (kernel_size, kernel_size)
+         self.kernel_size = kernel_size
+         self.stride = config.patch_size
+         self.linear = nn.Linear(
+             config.num_channels * kernel_size[0] * kernel_size[1],
+             config.hidden_size,
+             bias=False,
+         )
+
+     def _pair(self, x):
+         """Convert input to a pair of values."""
+         if isinstance(x, (list, tuple)):
+             return tuple(x)
+         return (x, x)
+
+     def unfold(self, input_tensor):
+         """
+         Extract sliding local blocks from a batched input tensor (MLX implementation).
+
+         This is equivalent to PyTorch's nn.functional.unfold or im2col operation.
+
+         Args:
+             input_tensor: Input tensor of shape (B, C, H, W)
+
+         Returns:
+             Unfolded tensor of shape (B, C*kernel_height*kernel_width, L)
+             where L is the number of blocks
+         """
+         # Convert to pairs
+         kernel_size = self._pair(self.kernel_size)
+         stride = self._pair(self.stride)
+         padding = (0, 0)  # No padding in the original code
+         dilation = (1, 1)  # Default dilation
+
+         # Input shape
+         batch_size, channels, height, width = input_tensor.shape
+
+         # Calculate output dimensions
+         height_out = (
+             height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+         ) // stride[0] + 1
+         width_out = (
+             width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+         ) // stride[1] + 1
+
+         # Initialize output arrays
+         blocks = []
+
+         # Extract blocks
+         for i in range(0, height - kernel_size[0] * dilation[0] + 1, stride[0]):
+             for j in range(0, width - kernel_size[1] * dilation[1] + 1, stride[1]):
+                 # Extract the block for all channels
+                 block = []
+                 for di in range(kernel_size[0]):
+                     for dj in range(kernel_size[1]):
+                         h_idx = i + di * dilation[0]
+                         w_idx = j + dj * dilation[1]
+                         # Get the block for all channels and add to our list
+                         block.append(input_tensor[:, :, h_idx, w_idx])
+
+                 # Stack the channel-blocks
+                 block = mx.stack(block, axis=1)  # Shape: (B, k*k, C)
+                 block = mx.transpose(block, [0, 2, 1])  # Shape: (B, C, k*k)
+                 blocks.append(block)
+
+         # Stack all blocks together
+         result = mx.stack(blocks, axis=-1)  # Shape: (B, C, k*k, L)
+
+         # Reshape to match PyTorch's unfold output format: (B, C*k*k, L)
+         result = mx.reshape(
+             result,
+             (
+                 batch_size,
+                 channels * kernel_size[0] * kernel_size[1],
+                 height_out * width_out,
+             ),
+         )
+
+         return result
+
+     def __call__(self, hidden_states: mx.array) -> mx.array:
+         hidden_states = self.unfold(hidden_states)
+         hidden_states = hidden_states.swapaxes(1, 2)
+         hidden_states = self.linear(hidden_states)
+         return hidden_states
+
+
+ class Llama4VisionRotaryEmbedding:
+     def __init__(self, config):
+         super().__init__()
+         idx = config.image_size // config.patch_size
+         img_idx = mx.arange(idx**2, dtype=mx.int32).reshape(idx**2, 1)
+         img_idx = mx.concatenate([img_idx, img_idx[:1]], axis=0)
+         img_idx[-1, -1] = -2  # ID_CLS_TOKEN
+         frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
+         frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
+         freq_dim = config.hidden_size // config.num_attention_heads // 2
+         rope_freq = 1.0 / (
+             config.rope_theta
+             ** (
+                 mx.arange(0, freq_dim, 2, dtype=mx.float32)[: (freq_dim // 2)]
+                 / freq_dim
+             )
+         )
+
+         # Expand dimensions for frequencies_x and frequencies_y
+         freqs_x_expanded = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
+         freqs_y_expanded = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
+
+         def repeat_interleave(tensor, repeats, dim=-1):
+             # Get the shape
+             shape = list(tensor.shape)
+
+             # Reshape to add an extra dimension for repeating
+             tensor = mx.reshape(tensor, shape[:-1] + [shape[-1], 1])
+
+             # Repeat along the new dimension
+             tensor = mx.repeat(tensor, repeats, axis=-1)
+
+             # Reshape to flatten the last two dimensions
+             return mx.reshape(tensor, shape[:-1] + [shape[-1] * repeats])
+
+         # Apply interleaving
+         freqs_x = repeat_interleave(freqs_x_expanded, 2)
+         freqs_y = repeat_interleave(freqs_y_expanded, 2)
+         freqs = mx.concatenate([freqs_x, freqs_y], axis=-1).astype(mx.float32)[..., ::2]
+         # Replaced masked_fill with where
+         mask = img_idx.reshape(-1, 1, 1) < 0
+         freqs = mx.where(mask, mx.zeros_like(freqs), freqs)
+         freq_cis = mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1)
+         freq_cis = view_as_complex(freq_cis)
+         self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2
+
+     def __call__(self, hidden_states):
+         return self.freqs_ci
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+         self.hidden_size = config.hidden_size
+         self.num_channels = config.num_channels
+         self.model_type = config.model_type
+         if self.model_type not in ["llama4", "llama4_vision_model"]:
+             raise ValueError(f"Model type {self.model_type} not supported")
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+         self.scale = config.hidden_size**-0.5
+
+         self.class_embedding = self.scale * mx.random.normal((self.hidden_size,))
+         self.positional_embedding_vlm = self.scale * mx.random.normal(
+             (self.num_patches, self.hidden_size)
+         )
+
+         self.patch_embedding = Llama4UnfoldConvolution(config)
+
+         self.rotary_embedding = Llama4VisionRotaryEmbedding(config)
+
+         # layer norms
+         self.layernorm_pre = nn.LayerNorm(self.hidden_size)
+         self.layernorm_post = nn.LayerNorm(self.hidden_size)
+
+         # encoders
+         self.model = Llama4VisionEncoder(config)
+         self.vision_adapter = Llama4VisionPixelShuffleMLP(config)
+
+     def get_input_embeddings(self):
+         """
+         This function is used to fetch the first embedding layer to activate grads on inputs.
+         """
+         return self.patch_embedding
+
+     def __call__(
+         self,
+         pixel_values: mx.array,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         capture_activations: Optional[bool] = True,
+     ):
+
+         batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape
+         num_concurrent_media = 1
+         num_chunks = 1
+
+         hidden_state = self.patch_embedding(pixel_values)
+
+         _, num_patches, hidden_dim = hidden_state.shape
+
+         # Add cls token
+         hidden_state = hidden_state.reshape(
+             batch_size_times_num_tiles * num_concurrent_media * num_chunks,
+             num_patches,
+             hidden_dim,
+         )
+
+         class_embedding = mx.broadcast_to(
+             self.class_embedding, (hidden_state.shape[0], 1, hidden_state.shape[-1])
+         )
+         hidden_state = mx.concatenate([hidden_state, class_embedding], axis=1)
+         num_patches += 1
+
+         # Position embeddings
+         hidden_state = hidden_state.reshape(
+             batch_size_times_num_tiles * num_concurrent_media,
+             num_chunks,
+             num_patches,
+             hidden_dim,
+         )
+
+         positional_embedding = self.positional_embedding_vlm
+         hidden_state = hidden_state + positional_embedding
+
+         hidden_state = self.layernorm_pre(hidden_state)
+
+         hidden_state = hidden_state.reshape(batch_size_times_num_tiles, -1, hidden_dim)
+         freqs_ci = self.rotary_embedding(pixel_values)
+
+         hidden_state = self.model(
+             hidden_state,
+             mask=None,
+             freqs_ci=freqs_ci,
+         )
+
+         hidden_state = self.layernorm_post(hidden_state)
+
+         hidden_state = hidden_state[:, :-1, :]
+
+         # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
+         final_hidden_state = self.vision_adapter(hidden_state)
+
+         # Return only the final state
+         return final_hidden_state
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
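In Llama4UnfoldConvolution the kernel size and stride are both config.patch_size and there is no padding, so unfold extracts non-overlapping patches. For that special case the nested Python loops can be expressed as two reshapes and a transpose. The sketch below only illustrates the resulting (B, C*k*k, L) layout (assuming H and W are divisible by the patch size; the 336/14 sizes are arbitrary) and is not code from the package:

import mlx.core as mx

def unfold_nonoverlapping(x: mx.array, k: int) -> mx.array:
    # (B, C, H, W) -> (B, C*k*k, L) with L = (H//k) * (W//k),
    # matching the (channel, kh, kw) ordering built by the loop above.
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // k, k, W // k, k)      # split H and W into (blocks, k)
    x = x.transpose(0, 1, 3, 5, 2, 4)              # (B, C, k, k, H//k, W//k)
    return x.reshape(B, C * k * k, (H // k) * (W // k))

patches = unfold_nonoverlapping(mx.zeros((1, 3, 336, 336)), 14)
print(patches.shape)                               # (1, 588, 576)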
mlx_vlm/models/llava/__init__.py
@@ -0,0 +1,2 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .llava import LanguageModel, Model, VisionModel
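Each model subpackage exposes the same surface through its __init__.py, as here: the config classes plus Model, LanguageModel and VisionModel re-exported from the module named after the architecture. That uniform layout is what lets a loader resolve an architecture purely by name. A hedged sketch of the pattern (the package's actual loader lives in mlx_vlm/utils.py and may differ in detail; the helper name below is made up for illustration):

import importlib

def resolve_model_class(model_type: str):
    # Illustrative only: import mlx_vlm.models.<model_type> and return its Model class,
    # relying on the per-model __init__.py re-exports shown above.
    module = importlib.import_module(f"mlx_vlm.models.{model_type}")
    return module.Model

# e.g. resolve_model_class("llava") returns the Model defined in mlx_vlm/models/llava/llava.py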