fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
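
The listing above is simply the wheel's member files. If you want to reproduce it locally, a minimal sketch follows, assuming the wheel has already been fetched (for example with `pip download fount-vlm-nell-02==0.3.11 --no-deps`); the filename below follows the standard wheel naming for this package and version and is otherwise an assumption.

    # Wheels are plain zip archives, so the file list can be read directly.
    from zipfile import ZipFile

    with ZipFile("fount_vlm_nell_02-0.3.11-py3-none-any.whl") as wheel:
        for name in wheel.namelist():
            print(name)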

mlx_vlm/models/paligemma/paligemma.py
@@ -0,0 +1,140 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import InputEmbeddingsFeatures
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ class PaliGemmaMultiModalProjector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.linear = nn.Linear(
+             config.vision_config.hidden_size,
+             config.vision_config.projection_dim,
+             bias=True,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         output = self.linear(x)
+         return output
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.config = config
+
+         self.vision_tower = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         hidden_state, _, _ = self.vision_tower(
+             pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype),
+             output_hidden_states=True,
+         )
+
+         image_features = hidden_state[None, :].astype(pixel_values.dtype)
+         image_features = self.multi_modal_projector(image_features)
+
+         final_inputs_embeds, final_attention_mask_4d = (
+             self._prepare_inputs_for_multimodal(
+                 image_features, inputs_embeds, input_ids, mask
+             )
+         )
+         return InputEmbeddingsFeatures(
+             inputs_embeds=final_inputs_embeds, attention_mask_4d=final_attention_mask_4d
+         )
+
+     def _prepare_inputs_for_multimodal(
+         self, image_features, inputs_embeds, input_ids, attention_mask
+     ):
+         _, _, embed_dim = image_features.shape
+
+         batch_size, sequence_length = input_ids.shape
+         scaled_image_features = image_features / (self.config.hidden_size**0.5)
+         final_embedding = mx.zeros((batch_size, sequence_length, embed_dim))
+
+         text_mask = (input_ids != self.config.image_token_index) & (
+             input_ids != self.config.pad_token_id
+         )
+         image_mask = input_ids == self.config.image_token_index
+         pad_mask = input_ids == self.config.pad_token_id
+
+         # expand masks to match embedding dimension
+         text_mask_expanded = mx.expand_dims(text_mask, -1)
+         text_mask_expanded = mx.repeat(text_mask_expanded, embed_dim, axis=-1)
+         pad_mask_expanded = mx.expand_dims(pad_mask, -1)
+         pad_mask_expanded = mx.repeat(pad_mask_expanded, embed_dim, axis=-1)
+
+         # insert padding and text token embeddings
+         final_embedding = mx.where(text_mask_expanded, inputs_embeds, final_embedding)
+         final_embedding = mx.where(
+             pad_mask_expanded, mx.zeros_like(final_embedding), final_embedding
+         )
+         pad_size = final_embedding.shape[1] - scaled_image_features.shape[1]
+         scaled_image_features = mx.pad(
+             scaled_image_features, ((0, 0), (0, pad_size), (0, 0))
+         )
+         # insert image embeddings - the image mask is always less or equal to the sentence in length
+         image_mask_expanded = mx.expand_dims(image_mask, -1)
+         image_mask_expanded = mx.repeat(image_mask_expanded, embed_dim, axis=-1)
+         final_embedding = mx.where(
+             image_mask_expanded, scaled_image_features, final_embedding
+         )
+
+         final_embedding = mx.where(
+             pad_mask_expanded, mx.zeros_like(final_embedding), final_embedding
+         )
+
+         attention_mask_expanded_1 = mx.expand_dims(attention_mask, 1)
+         attention_mask_expanded_2 = mx.expand_dims(attention_mask, 2)
+         final_attention_mask_4d = attention_mask_expanded_1 * attention_mask_expanded_2
+         final_attention_mask_4d = final_attention_mask_4d
+         final_attention_mask_4d = mx.expand_dims(final_attention_mask_4d, 1)
+         final_embedding = mx.array(final_embedding)
+         return final_embedding, final_attention_mask_4d
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, mask
+         )
+         input_embeddings = input_embeddings_features.inputs_embeds
+         final_attention_mask_4d = input_embeddings_features.attention_mask_4d
+
+         logits = self.language_model(
+             inputs=input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings,
+             mask=final_attention_mask_4d,
+         )
+         return logits
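
For readers skimming the hunk above: `_prepare_inputs_for_multimodal` broadcasts boolean token masks up to the embedding dimension and selects, position by position, between text embeddings, scaled image features, and zeros for padding. A minimal standalone sketch of that selection, assuming only that `mlx` is installed; the shapes and the token ids (99 for image, 0 for pad) are illustrative, not values taken from the package.

    import mlx.core as mx

    # Illustrative sizes and token ids (not the package's real values)
    batch, seq_len, embed_dim, image_token, pad_token = 1, 6, 4, 99, 0

    # PaliGemma places image tokens at the start of the sequence, which is why
    # padding the image features to seq_len lines them up positionally.
    input_ids = mx.array([[image_token, image_token, 5, 7, 3, pad_token]])
    text_embeds = mx.random.normal((batch, seq_len, embed_dim))
    image_feats = mx.random.normal((batch, 2, embed_dim))  # one feature per image token
    image_feats = mx.pad(image_feats, ((0, 0), (0, seq_len - image_feats.shape[1]), (0, 0)))

    # Boolean masks broadcast against (batch, seq_len, embed_dim)
    text_mask = ((input_ids != image_token) & (input_ids != pad_token))[..., None]
    image_mask = (input_ids == image_token)[..., None]

    merged = mx.zeros((batch, seq_len, embed_dim))
    merged = mx.where(text_mask, text_embeds, merged)   # text positions keep text embeddings
    merged = mx.where(image_mask, image_feats, merged)  # image positions take image features
    print(merged.shape)  # (1, 6, 4); pad positions stay zero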

mlx_vlm/models/paligemma/vision.py
@@ -0,0 +1,218 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dims: int,
+         num_heads: int,
+         query_input_dims: Optional[int] = None,
+         key_input_dims: Optional[int] = None,
+         value_input_dims: Optional[int] = None,
+         value_dims: Optional[int] = None,
+         value_output_dims: Optional[int] = None,
+         bias: bool = True,
+     ):
+         super().__init__()
+
+         if (dims % num_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({dims} % {num_heads}) != 0"
+             )
+
+         query_input_dims = query_input_dims or dims
+         key_input_dims = key_input_dims or dims
+         value_input_dims = value_input_dims or key_input_dims
+         value_dims = value_dims or dims
+         value_output_dims = value_output_dims or dims
+
+         self.num_heads = num_heads
+         head_dim = dims // num_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+         self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+         self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+         self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+     def __call__(self, x, mask=None):
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         num_heads = self.num_heads
+         B, L, D = queries.shape
+         _, S, _ = keys.shape
+         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.out_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.activation_fn = nn.GELU(approx="precise")
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.fc1(x)
+         x = self.activation_fn(x)
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = Attention(
+             config.hidden_size, config.num_attention_heads, bias=True
+         )
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = MLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         r = self.self_attn(self.layer_norm1(x), mask)
+         h = x + r
+         r = self.mlp(self.layer_norm2(h))
+         return h + r
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+         mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         encoder_states = (x,) if output_hidden_states else None
+         h = x
+         for l in self.layers:
+             x = l(x, mask=mask)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+         h = x[0]
+
+         return (h, encoder_states)
+
+
+ class VisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         patch_embeddings = self.patch_embedding(x)
+         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+         position_ids = mx.array(np.arange(self.num_positions)[None, :])
+         embeddings = patch_embeddings
+         embeddings += self.position_embedding(position_ids)
+         return embeddings
+
+
+ class SigLipVisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embeddings = VisionEmbeddings(config)
+         self.encoder = Encoder(config)
+         self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+         x = self.embeddings(x)
+
+         encoder_outputs = self.encoder(
+             x=x, output_hidden_states=output_hidden_states, mask=None
+         )
+
+         pooler_output = self.post_layernorm(encoder_outputs[0])
+
+         return pooler_output, x, encoder_outputs[-1]
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         if self.model_type != "siglip_vision_model":
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.vision_model = SigLipVisionModel(config)
+
+     def __call__(
+         self, x: mx.array, output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         return self.vision_model(x, output_hidden_states)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             elif "patch_embedding.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
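
The `sanitize` hook above exists because PyTorch checkpoints store Conv2d weights as [out_channels, in_channels, kH, kW] while MLX expects [out_channels, kH, kW, in_channels]. A small sketch of that conversion and of the `check_array_shape` heuristic, assuming `mlx` is available; the 1152/3/14 dimensions are illustrative SigLIP-like values, not read from the package.

    import mlx.core as mx

    out_channels, in_channels, k = 1152, 3, 14  # illustrative patch-embedding sizes

    # A PyTorch-layout conv weight: [out_channels, in_channels, kH, kW]
    torch_layout = mx.zeros((out_channels, in_channels, k, k))

    # MLX layout: [out_channels, kH, kW, in_channels]
    mlx_layout = torch_layout.transpose(0, 2, 3, 1)
    print(mlx_layout.shape)  # (1152, 14, 14, 3)

    # The heuristic treats a weight as already converted when out_channels is the
    # largest axis and the two middle axes (the kernel) are equal.
    def already_mlx_layout(w):
        o, kh, kw, _ = w.shape
        return o >= kh and o >= kw and kh == kw

    print(already_mlx_layout(mlx_layout))    # True  -> left untouched
    print(already_mlx_layout(torch_layout))  # False -> gets transposed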

mlx_vlm/models/phi3_v/__init__.py
@@ -0,0 +1,5 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .language import LanguageModel
+ from .phi3_v import Model
+ from .processing_phi3_v import Phi3VImageProcessor, Phi3VProcessor
+ from .vision import VisionModel

mlx_vlm/models/phi3_v/config.py
@@ -0,0 +1,55 @@
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional, Union
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: "TextConfig" = field(default_factory=lambda: TextConfig())
+     vision_config: "VisionConfig" = field(default_factory=lambda: VisionConfig())
+     model_type: str = "phi3_v"
+     vocab_size: int = 32064
+
+     num_hidden_layers: int = 32
+     intermediate_size: int = 8192
+     num_attention_heads: int = 32
+     rms_norm_eps: float = 1e-5
+
+     ignore_index: int = -100
+     image_token_index: int = 257152
+     hidden_size: int = 2048
+     pad_token_id: int = 0
+
+     num_key_value_heads: int = None
+     rope_theta: float = 10000
+     rope_traditional: bool = False
+     partial_rotary_factor: float = 1.0
+     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+     max_position_embeddings: int = 131072
+     original_max_position_embeddings: int = 4096
+     eos_token_id: Optional[List[int]] = None
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     max_position_embeddings: int = 4096
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str = "phi3_v"
+     num_hidden_layers: int = 24
+     hidden_size: int = 1024
+     intermediate_size: int = 4096
+     num_attention_heads: int = 16
+     image_size: int = 336
+     patch_size: int = 14
+     projection_dim: int = 768
+     vocab_size: int = 32000
+     num_channels: int = 3
+     layer_norm_eps: float = 1e-5
+     image_dim_out: int = (1024,)
+     model_name: str = "openai/clip-vit-large-patch14-336"
+     name: str = "clip_vision_model"
+     num_img_tokens: int = 144

mlx_vlm/models/phi3_v/language.py
@@ -0,0 +1,2 @@
+ class LanguageModel:
+     pass

mlx_vlm/models/phi3_v/phi3_v.py
@@ -0,0 +1,239 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from mlx_lm.models.rope_utils import SuScaledRoPE
+
+ from ..base import InputEmbeddingsFeatures, LanguageModelOutput, create_attention_mask
+ from ..cache import KVCache
+
+ # Import processor to register it with AutoProcessor
+ from . import processing_phi3_v  # noqa: F401
+ from .config import ModelConfig, TextConfig
+ from .vision import VisionModel
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+         self.num_hidden_layers = config.num_hidden_layers
+
+         self.head_dim = head_dim = config.hidden_size // n_heads
+         self.scale = head_dim**-0.5
+
+         op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
+         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         rope_dim = int(head_dim * config.partial_rotary_factor)
+
+         # Check for Su-scaled RoPE by type or presence of short/long factors
+         rope_type = config.rope_scaling.get("type") if config.rope_scaling else None
+         has_su_factors = (
+             config.rope_scaling
+             and "short_factor" in config.rope_scaling
+             and "long_factor" in config.rope_scaling
+         )
+
+         if rope_type == "su" or has_su_factors:
+             self.rope = SuScaledRoPE(
+                 rope_dim,
+                 base=config.rope_theta,
+                 max_position_embeddings=config.max_position_embeddings,
+                 original_max_position_embeddings=config.original_max_position_embeddings,
+                 short_factor=config.rope_scaling["short_factor"],
+                 long_factor=config.rope_scaling["long_factor"],
+             )
+         else:
+             rope_scale = 1.0
+             if config.rope_scaling and rope_type == "linear":
+                 rope_scale = 1 / config.rope_scaling["factor"]
+             self.rope = nn.RoPE(
+                 rope_dim,
+                 traditional=config.rope_traditional,
+                 base=config.rope_theta,
+                 scale=rope_scale,
+             )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         B, L, _ = x.shape
+
+         qkv = self.qkv_proj(x)
+         query_pos = self.n_heads * self.head_dim
+         queries, keys, values = mx.split(
+             qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
+         )
+
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         if cache is not None:
+             queries = self.rope(queries, offset=cache.offset)
+             keys = self.rope(keys, offset=cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+         else:
+             queries = self.rope(queries)
+             keys = self.rope(keys)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         x = self.gate_up_proj(x)
+         gate, x = mx.split(x, 2, axis=-1)
+         return self.down_proj(nn.silu(gate) * x)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size = config.hidden_size
+         self.self_attn = Attention(config)
+         self.mlp = MLP(config.hidden_size, config.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.config = config
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class Phi3V(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.num_hidden_layers = config.num_hidden_layers
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.vision_embed_tokens = VisionModel(config)
+         self.layers = [
+             TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache[0])
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c)
+
+         return self.norm(h)
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.model = Phi3V(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.config = config
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         pixel_values=None,
+         mask=None,
+         cache=None,
+         **kwargs,
+     ):
+         if inputs_embeds is None:
+             input_embeddings_features = self.get_input_embeddings(
+                 inputs, pixel_values, **kwargs
+             )
+             inputs_embeds = input_embeddings_features.inputs_embeds
+
+         out = self.model(inputs, inputs_embeds, mask=mask, cache=cache)
+         logits = self.lm_head(out)
+
+         return LanguageModelOutput(logits=logits)
+
+     def get_input_embeddings(
+         self,
+         inputs: mx.array,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         image_sizes = kwargs.get("image_sizes", None) if kwargs else None
+
+         # Get text embeddings
+         inputs_embeds = self.model.embed_tokens(inputs)
+
+         # Find positions where inputs < 0 (image token positions)
+         inputs_list = inputs.tolist()
+         p = np.argwhere(np.array(inputs_list) < 0).tolist()
+
+         if pixel_values is not None:
+             inputs_embeds = self.model.vision_embed_tokens(
+                 pixel_values, inputs_embeds, image_sizes, p
+             )
+
+         return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.hidden_size // self.config.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
+
+     @property
+     def language_model(self):
+         return self
+
+     @property
+     def vision_model(self):
+         return self.model.vision_embed_tokens
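
The attention block in the hunk above uses a single fused `qkv_proj` and recovers queries, keys, and values with `mx.split` at two offsets. A standalone sketch of that offset arithmetic, assuming `mlx` is installed; the head counts and dimensions are small made-up values, not the real Phi-3.5-vision configuration.

    import mlx.core as mx

    n_heads, n_kv_heads, head_dim = 8, 4, 16   # illustrative, not the model's real sizes
    op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)   # 128 + 2 * 64 = 256

    qkv = mx.random.normal((1, 5, op_size))    # stand-in for qkv_proj(x): (batch, seq, op_size)

    query_pos = n_heads * head_dim             # queries occupy the first 128 features
    q, k, v = mx.split(qkv, [query_pos, query_pos + n_kv_heads * head_dim], axis=-1)
    print(q.shape, k.shape, v.shape)           # (1, 5, 128) (1, 5, 64) (1, 5, 64)

    # Reshape into per-head tensors exactly as the attention block above does
    q = q.reshape(1, 5, n_heads, -1).transpose(0, 2, 1, 3)      # (1, 8, 5, 16)
    k = k.reshape(1, 5, n_kv_heads, -1).transpose(0, 2, 1, 3)   # (1, 4, 5, 16)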