fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/multi_modality/multi_modality.py
@@ -0,0 +1,338 @@
+ from typing import List, Optional, Tuple, Union
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from PIL import Image
+ from transformers.image_processing_utils import BatchFeature
+ from transformers.image_utils import to_numpy_array
+
+ from ..base import BaseImageProcessor, InputEmbeddingsFeatures, expand2square
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ class ImageProcessor(BaseImageProcessor):
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         config,
+         image_size: int = 384,
+         min_size: int = 14,
+         image_mean: Union[Tuple[float, float, float], List[float]] = (
+             0.5,
+             0.5,
+             0.5,
+         ),
+         image_std: Union[Tuple[float, float, float], List[float]] = (
+             0.5,
+             0.5,
+             0.5,
+         ),
+         rescale_factor: float = 1.0 / 255.0,
+         do_normalize: bool = True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         if "high_res_cfg" in config["vision_config"]["params"]:
+             self.image_size = config["vision_config"]["params"]["high_res_cfg"][
+                 "image_size"
+             ]
+             self.image_mean = config["vision_config"]["params"]["high_res_cfg"][
+                 "pixel_mean"
+             ]
+             self.image_std = config["vision_config"]["params"]["high_res_cfg"][
+                 "pixel_std"
+             ]
+             self.do_normalize = False
+         else:
+             self.image_size = image_size
+             self.image_mean = image_mean
+             self.image_std = image_std
+             self.do_normalize = do_normalize
+
+         self.rescale_factor = rescale_factor
+         self.min_size = min_size
+
+         if image_mean is None:
+             self.background_color = (127, 127, 127)
+         else:
+             self.background_color = tuple([int(x * 255) for x in self.image_mean])
+
+     def resize(self, pil_img: Image) -> np.ndarray:
+         """
+
+         Args:
+             pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
+
+         Returns:
+             x (np.ndarray): [3, self.image_size, self.image_size]
+         """
+
+         width, height = pil_img.size
+         max_size = max(width, height)
+
+         size = [
+             max(int(height / max_size * self.image_size), self.min_size),
+             max(int(width / max_size * self.image_size), self.min_size),
+         ]
+
+         if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
+             print(f"orig size = {pil_img.size}, new size = {size}")
+             raise ValueError("Invalid size!")
+
+         pil_img = pil_img.resize(size=tuple(size[::-1]), resample=Image.BICUBIC)
+
+         pil_img = expand2square(pil_img, self.background_color)
+         x = to_numpy_array(pil_img)
+
+         # [H, W, 3] -> [3, H, W]
+         x = np.transpose(x, (2, 0, 1))
+
+         return x
+
+     def preprocess(self, images, **kwargs) -> BatchFeature:
+         # resize and pad to [self.image_size, self.image_size]
+         # then convert from [H, W, 3] to [3, H, W]
+         images: List[np.ndarray] = [self.resize(image) for image in images]
+
+         # rescale from [0, 255] -> [0, 1]
+         images = [
+             self.rescale(
+                 image=image,
+                 scale=self.rescale_factor,
+                 input_data_format="channels_first",
+             )
+             for image in images
+         ]
+
+         # normalize
+         if self.do_normalize:
+             images = [
+                 self.normalize(
+                     image=image,
+                     mean=self.image_mean,
+                     std=self.image_std,
+                     input_data_format="channels_first",
+                 )
+                 for image in images
+             ]
+
+         return images
+
+
+ class MlpProjector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+
+         if config.projector_config.params["projector_type"] == "mlp_gelu":
+             self.layers = [
+                 nn.Linear(
+                     config.vision_config.hidden_size,
+                     config.text_config.hidden_size,
+                     bias=True,
+                 )
+             ]
+             mlp_depth = config.projector_config.params["depth"]
+             for _ in range(1, mlp_depth):
+                 self.layers.append(nn.GELU())
+                 self.layers.append(
+                     nn.Linear(
+                         config.text_config.hidden_size,
+                         config.text_config.hidden_size,
+                         bias=True,
+                     )
+                 )
+         elif (
+             config.projector_config.params["projector_type"]
+             == "low_high_hybrid_split_mlp_gelu"
+         ):
+             mlp_depth = config.projector_config.params["depth"]
+             self.high_up_proj = nn.Linear(
+                 config.vision_config.hidden_size, config.text_config.hidden_size // 2
+             )
+             self.low_up_proj = nn.Linear(
+                 config.vision_config.hidden_size, config.text_config.hidden_size // 2
+             )
+
+             self.layers = []
+             for _ in range(1, mlp_depth):
+                 self.layers.append(nn.GELU())
+                 self.layers.append(
+                     nn.Linear(
+                         config.text_config.hidden_size, config.text_config.hidden_size
+                     )
+                 )
+
+         else:
+             projector_type = config.projector_config.params["projector_type"]
+             raise ValueError(f"Unknown projector type: {projector_type}")
+
+     def __call__(self, x: Union[mx.array, Tuple]) -> mx.array:
+
+         if isinstance(x, tuple):
+             high_x, low_x = x
+
+             high_x = self.high_up_proj(high_x)
+             low_x = self.low_up_proj(low_x)
+
+             B, D = high_x.shape[0], high_x.shape[-1]
+             high_x = high_x.reshape(B, -1, D)
+
+             x = mx.concatenate([high_x, low_x], axis=-1)
+
+         for layer in self.layers:
+             x = layer(x)
+
+         return x
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vision_model = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.aligner = MlpProjector(config)
+         self.vision_feature_layer = config.select_layer
+         self.vision_feature_select_strategy = config.vision_feature_select_strategy
+
+     def add_image_token(
+         self,
+         image_indices: list,
+         input_ids: np.ndarray,
+         image_token_index: int,
+         num_image_tokens: int,
+         add_special_token: bool = False,
+     ):
+         """
+         Inserts image tokens into an array of input IDs at specified indices.
+
+         Args:
+             image_indices (List[int]): Indices where image tokens should be inserted.
+             input_ids (np.ndarray): Original array of input IDs, expected to be two-dimensional.
+             image_token_index (int): The ID used to represent an image token.
+             num_image_tokens (int): Number of image tokens to insert at each index.
+             add_special_token (bool): If True, adjusts the indices to include a special token.
+
+         Returns:
+             Tuple of (np.ndarray, np.ndarray):
+                 - Updated array of input IDs with image tokens inserted.
+                 - Array indicating the number of image tokens added at each position.
+         """
+         input_slices = []
+
+         start = 0
+         flat_input_ids = input_ids.flatten()
+
+         for index in image_indices:
+             end = (index[0] + 1) if add_special_token else index[0]
+
+             input_slices.append(flat_input_ids[start:end])
+             input_slices.append(
+                 np.full((num_image_tokens,), image_token_index, dtype=np.int64)
+             )
+             start = index[0] + 1  # Move start past the current image insertion point
+
+         input_slices.append(flat_input_ids[start:])
+
+         input_ids = np.concatenate(input_slices, axis=0)
+         num_image_tokens_array = np.array(
+             [num_image_tokens] * len(image_indices), dtype=np.int64
+         )
+         input_ids = input_ids.reshape(1, -1)
+
+         return input_ids, num_image_tokens_array
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         image_token_index = self.config.image_token_index
+         num_image_tokens = self.config.num_image_tokens
+
+         image_token_mask = np.array(input_ids[0] == image_token_index).astype(bool)
+         image_indices = np.nonzero(image_token_mask)
+
+         input_ids, num_image_tokens = self.add_image_token(
+             image_indices=image_indices,
+             input_ids=np.array(input_ids),
+             image_token_index=image_token_index,
+             num_image_tokens=num_image_tokens,
+         )
+
+         input_ids = mx.array(input_ids)
+
+         # Get the input embeddings from the language model
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         # Get the output hidden states from the vision model
+         if self.config.vision_config.cls == "HybridVisionTower":
+             hidden_states = self.vision_model(
+                 pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
+             )
+         else:
+             hidden_states, _, _ = self.vision_model(
+                 pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
+             )
+
+         # Pass image features through the multi-modal projector
+         image_features = self.aligner(hidden_states)
+
+         # Insert special image tokens in the input_ids
+         final_inputs_embeds = self._merge_input_ids_with_image_features(
+             image_features, inputs_embeds, input_ids
+         )
+         return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+     def _merge_input_ids_with_image_features(
+         self, image_features, inputs_embeds, input_ids
+     ):
+         image_token_index = self.config.image_token_index
+
+         # Positions of <image> tokens in input_ids, assuming batch size is 1
+         image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+         text_segments = []
+         start_idx = 0
+
+         for position in image_positions:
+             text_segments.append(inputs_embeds[:, start_idx:position])
+             start_idx = position + 1
+
+         image_embeddings = mx.split(image_features, image_features.shape[0])
+         final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
+         final_embeddings += [inputs_embeds[:, start_idx:]]
+
+         # Create a final embedding of shape
+         # (1, num_image_patches*num_images + sequence_len, embed_dim)
+         return mx.concatenate(final_embeddings, axis=1)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: mx.array,
+         cache=None,
+         **kwargs,
+     ):
+
+         input_embeddings_features = self.get_input_embeddings(input_ids, pixel_values)
+         logits = self.language_model(
+             input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+         )
+         return logits
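
For reference, a minimal sketch (not part of the package) of the token expansion that Model.add_image_token performs before embeddings are merged: each image placeholder id in the prompt is replaced by num_image_tokens copies of the image token id. The token id (100) and count (3) below are made-up illustrative values, not this model's actual configuration.

import numpy as np

# Hypothetical values for illustration only; the real ones come from the model
# config (config.image_token_index and config.num_image_tokens).
IMAGE_TOKEN_ID = 100
NUM_IMAGE_TOKENS = 3

input_ids = np.array([[1, 5, IMAGE_TOKEN_ID, 7, 2]])  # one image placeholder at index 2
position = int(np.nonzero(input_ids[0] == IMAGE_TOKEN_ID)[0][0])

# Expand the single placeholder into NUM_IMAGE_TOKENS image tokens, mirroring
# the slices that add_image_token concatenates in input_slices.
expanded = np.concatenate(
    [
        input_ids[0, :position],                                       # text before the image
        np.full((NUM_IMAGE_TOKENS,), IMAGE_TOKEN_ID, dtype=np.int64),  # expanded image slot
        input_ids[0, position + 1 :],                                  # text after the image
    ]
).reshape(1, -1)

print(expanded)  # [[  1   5 100 100 100   7   2]]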