fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
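The hunks below show three of the newly added files: the GLM-4V model package under mlx_vlm/models/glm4v/ (__init__.py, config.py, and glm4v.py). For orientation, here is a minimal usage sketch, not taken from the diff, that assumes this wheel keeps the upstream mlx-vlm Python API (load and generate exported from the top-level package, apply_chat_template in prompt_utils); the model ID and image path are placeholders.

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template

# Placeholder model ID; any checkpoint supported by the package should work.
model, processor = load("mlx-community/some-vlm-4bit")

# Format the prompt with the model's chat template, expecting one image.
prompt = apply_chat_template(
    processor, model.config, "Describe this image.", num_images=1
)

# Placeholder image path.
output = generate(model, processor, prompt, image=["example.jpg"], max_tokens=100)
print(output)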
mlx_vlm/models/glm4v/__init__.py
@@ -0,0 +1,3 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .glm4v import LanguageModel, Model, VisionModel
+ from .processing import Glm46VProcessor
mlx_vlm/models/glm4v/config.py
@@ -0,0 +1,79 @@
+ from dataclasses import asdict, dataclass, field
+ from typing import Dict, List, Optional
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str = "glm4v_text"
+     vocab_size: int = 151552
+     hidden_size: int = 4096
+     eos_token_id: List[int] = field(
+         default_factory=lambda: [151329, 151336, 151338, 151348]
+     )
+     intermediate_size: int = 13696
+     max_position_embeddings: int = 65536
+     num_attention_heads: int = 32
+     num_hidden_layers: int = 40
+     num_key_value_heads: int = 2
+     rms_norm_eps: float = 1e-05
+     rope_theta: float = 10000
+     attention_bias: bool = True
+     attention_dropout: float = 0.0
+     hidden_act: str = "silu"
+     initializer_range: float = 0.02
+     partial_rotary_factor: float = 0.5
+     rope_scaling: Dict = field(
+         default_factory=lambda: {"rope_type": "default", "mrope_section": [8, 12, 12]}
+     )
+     pad_token_id: int = 151329
+     use_cache: bool = True
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str
+     depth: int
+     hidden_size: int
+     intermediate_size: int
+     num_heads: int
+     patch_size: int
+     window_size: int = 112
+     image_size: int = 336
+     in_channels: int = 3
+     rms_norm_eps: float = 1e-05
+     attention_bias: bool = False
+     attention_dropout: float = 0.0
+     hidden_act: str = "silu"
+     initializer_range: float = 0.02
+     out_hidden_size: int = 4096
+     spatial_merge_size: int = 2
+     temporal_patch_size: int = 2
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str
+     vocab_size: int = 257152
+     ignore_index: int = -100
+     image_token_index: int = 151363
+     image_token_id: int = 151363
+     video_token_index: int = 151364
+     video_token_id: int = 151364
+     vision_start_token_id: int = 151339
+     vision_end_token_id: int = 151340
+     hidden_size: int = 2048
+     pad_token_id: int = 0
+     eos_token_id: Optional[List[int]] = None
+
+     def __post_init__(self):
+         if self.eos_token_id is None:
+             text_config = (
+                 asdict(self.text_config)
+                 if isinstance(self.text_config, TextConfig)
+                 else self.text_config
+             )
+             self.eos_token_id = text_config["eos_token_id"]
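A quick check of how these dataclasses compose (not part of the diff): constructing ModelConfig with a default TextConfig leaves eos_token_id unset, so __post_init__ copies it from the text config. The VisionConfig values are illustrative, and this assumes BaseModelConfig introduces no required fields of its own.

from mlx_vlm.models.glm4v.config import ModelConfig, TextConfig, VisionConfig

cfg = ModelConfig(
    text_config=TextConfig(),
    vision_config=VisionConfig(  # illustrative values for the required fields
        model_type="glm4v",
        depth=24,
        hidden_size=1536,
        intermediate_size=13696,
        num_heads=12,
        patch_size=14,
    ),
    model_type="glm4v",
)

# eos_token_id was inherited from TextConfig by __post_init__
assert cfg.eos_token_id == [151329, 151336, 151338, 151348]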
mlx_vlm/models/glm4v/glm4v.py
@@ -0,0 +1,188 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import InputEmbeddingsFeatures
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .processing import Glm46VProcessor
+ from .vision import VisionModel
+
+ # Register the processor with the name expected by the model config
+ try:
+     from transformers import AutoProcessor
+
+     # The model's preprocessor_config.json specifies "processor_class": "Glm46VProcessor"
+     AutoProcessor.register("Glm46VProcessor", Glm46VProcessor)
+ except Exception as e:
+     print(f"Error registering glm4v processor: {e}")
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vision_tower = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config, config)
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+         video_grid_thw = kwargs.pop("video_grid_thw", None)
+         grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
+
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+             )
+
+         dtype = self.vision_tower.patch_embed.proj.weight.dtype
+         pixel_values = pixel_values.astype(dtype)
+
+         # Get the input embeddings from the language model
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         # Get the output hidden states from the vision model
+         hidden_states = self.vision_tower(
+             pixel_values, grid_thw, output_hidden_states=False
+         )
+
+         split_sizes = (
+             image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2
+         ).tolist()
+         hidden_states = mx.split(
+             hidden_states, [split_sizes[0], sum(split_sizes[:2])], axis=0
+         )
+
+         hidden_states = mx.concatenate(hidden_states, axis=0).astype(
+             hidden_states[0].dtype
+         )
+
+         # Insert special image tokens in the input_ids
+         final_inputs_embeds = self.merge_input_ids_with_image_features(
+             self.config.image_token_id,
+             self.config.video_token_id,
+             hidden_states,
+             inputs_embeds,
+             input_ids,
+         )
+         return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+     @staticmethod
+     def merge_input_ids_with_image_features(
+         image_token_id,
+         video_token_id,
+         image_features,
+         inputs_embeds,
+         input_ids,
+     ):
+         """Merge image features into input embeddings at image token positions.
+
+         Args:
+             image_token_id: The token ID for image placeholders
+             video_token_id: The token ID for video placeholders (fallback)
+             image_features: Vision features from the vision tower [num_features, hidden_dim]
+             inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
+             input_ids: Input token IDs [batch_size, seq_len]
+             grid_thw: Grid dimensions for each image (optional, not used in simple case)
+
+         Returns:
+             Updated input embeddings with image features inserted
+         """
+         # Find positions of image tokens
+         image_positions = input_ids == image_token_id
+         if mx.sum(image_positions) == 0:
+             image_positions = input_ids == video_token_id
+
+         # Get dimensions
+         batch_size, seq_len = input_ids.shape
+
+         # Process each batch item
+         batch_outputs = []
+         feature_start_idx = 0
+
+         for batch_idx in range(batch_size):
+             # Get mask for this batch
+             image_mask = image_positions[batch_idx]
+             num_positions = mx.sum(image_mask).item()
+
+             if num_positions > 0:
+                 # Extract features for this batch
+                 batch_features = image_features[
+                     feature_start_idx : feature_start_idx + num_positions
+                 ]
+
+                 # Validate we have the right number of features
+                 if batch_features.shape[0] != num_positions:
+                     raise ValueError(
+                         f"Number of image token positions ({num_positions}) does not match "
+                         f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
+                     )
+
+                 # Create indices for gathering
+                 cumsum = mx.cumsum(image_mask.astype(mx.int32))
+                 feature_indices = mx.where(image_mask, cumsum - 1, 0)
+
+                 # Gather features
+                 gathered_features = batch_features[feature_indices]
+
+                 # Combine with original embeddings
+                 image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
+                 batch_output = mx.where(
+                     image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
+                 )
+
+                 feature_start_idx += num_positions
+             else:
+                 # No image tokens in this batch item
+                 batch_output = inputs_embeds[batch_idx]
+
+             batch_outputs.append(batch_output)
+
+         # Stack all batch outputs
+         return mx.stack(batch_outputs, axis=0)
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, **kwargs
+         )
+
+         logits = self.language_model(
+             input_ids,
+             input_embeddings_features.inputs_embeds,
+             mask=mask,
+             cache=cache,
+             **kwargs,
+         )
+
+         return logits
+
+     def sanitize(self, weights):
+         def transform_key(key):
+             if "visual" in key:
+                 if "vision_tower" not in key:
+                     key = key.replace("model.", "").replace("visual", "vision_tower")
+             if "model.language_model" in key:
+                 key = key.replace("model.language_model", "language_model.model")
+             if "lm_head" in key and not key.startswith("language_model"):
+                 key = key.replace("lm_head", "language_model.lm_head")
+             return key
+
+         return {transform_key(k): v for k, v in weights.items()}
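To see what sanitize does to checkpoint keys, here is a small illustrative check (not part of the diff): a few Hugging Face-style keys are remapped onto this module's attribute names. The keys are hypothetical, None stands in for real weight arrays, and sanitize is called unbound with None for self since it never uses it.

from mlx_vlm.models.glm4v import Model

hf_keys = [
    "model.visual.blocks.0.attn.qkv.weight",  # vision tower weights
    "model.language_model.layers.0.mlp.gate_proj.weight",
    "lm_head.weight",
]

# Apply the same key renaming that load-time weight sanitization performs.
sanitized = Model.sanitize(None, {k: None for k in hf_keys})

assert list(sanitized) == [
    "vision_tower.blocks.0.attn.qkv.weight",
    "language_model.model.layers.0.mlp.gate_proj.weight",
    "language_model.lm_head.weight",
]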