fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,81 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional
3
+
4
+ from ..base import BaseModelConfig
5
+
6
+
7
@dataclass
class TextConfig(BaseModelConfig):
    """Configuration of the GLM-4V MoE text (language) backbone.

    Fields without a default are required. Field order is significant:
    dataclass fields without defaults must precede those with defaults, and
    any positional construction depends on the order below.

    Fields defaulting to ``None`` (``rope_theta``, ``rope_parameters``,
    ``tie_word_embeddings``) may be absent from a checkpoint config; ``None``
    means "not provided".
    """

    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    max_position_embeddings: int
    moe_intermediate_size: int
    norm_topk_prob: bool
    num_attention_heads: int
    n_group: int
    head_dim: int
    topk_group: int
    n_shared_experts: int
    n_routed_experts: int
    routed_scaling_factor: float
    num_experts_per_tok: int
    first_k_dense_replace: int
    num_hidden_layers: int
    num_key_value_heads: int
    rms_norm_eps: float
    use_qk_norm: bool
    attention_bias: bool
    partial_rotary_factor: float
    rope_theta: Optional[float] = None
    rope_parameters: Optional[Dict] = None
    # default_factory gives every instance its own dict, so mutating one
    # config's rope_scaling cannot leak into another. "mrope_section" is
    # presumably the per-axis split of rotary dims for multimodal RoPE —
    # TODO confirm against language.py.
    rope_scaling: Dict = field(
        default_factory=lambda: {"type": "default", "mrope_section": [64, 32, 32]}
    )
    tie_word_embeddings: Optional[bool] = None
    scoring_func: str = "sigmoid"
    topk_method: str = "noaux_tc"
39
+
40
+
41
@dataclass
class VisionConfig(BaseModelConfig):
    """Configuration of the GLM-4V MoE vision tower.

    ``model_type``, ``depth``, ``hidden_size``, ``intermediate_size``,
    ``num_heads`` and ``patch_size`` are required; every other field has a
    default. Field order matters for dataclass construction (fields without
    defaults must come first), so keep it unchanged.
    """

    model_type: str
    depth: int  # presumably the number of vision transformer blocks — TODO confirm in vision.py
    hidden_size: int
    intermediate_size: int
    num_heads: int
    patch_size: int
    window_size: int = 112
    image_size: int = 336
    in_channels: int = 3
    rms_norm_eps: float = 1e-05
    attention_bias: bool = False
    attention_dropout: float = 0.0
    hidden_act: str = "silu"
    initializer_range: float = 0.02
    out_hidden_size: int = 4096  # presumably the width of features handed to the LM — TODO confirm
    spatial_merge_size: int = 2
    temporal_patch_size: int = 2
60
+
61
+
62
@dataclass
class ModelConfig(BaseModelConfig):
    """Top-level GLM-4V MoE configuration.

    Bundles the text and vision sub-configs with the special token ids used
    to mark image/video placeholders in the token stream. When no
    ``eos_token_id`` is supplied (or it is explicitly ``None``), a default
    list of stop ids is filled in after construction.
    """

    text_config: TextConfig
    vision_config: VisionConfig
    model_type: str
    vocab_size: int = 257152
    ignore_index: int = -100
    image_token_index: int = 151363
    image_token_id: int = 151363
    video_token_index: int = 151364
    video_token_id: int = 151364
    vision_start_token_id: int = 151339
    vision_end_token_id: int = 151340
    hidden_size: int = 2048
    pad_token_id: int = 0
    eos_token_id: Optional[List[int]] = None

    def __post_init__(self):
        """Supply the default stop-token ids when none were configured."""
        if self.eos_token_id is not None:
            return
        # A fresh list per instance, so callers may mutate it safely.
        self.eos_token_id = [151329, 151336, 151338]
@@ -0,0 +1,176 @@
1
+ from typing import Optional
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+
6
+ from ..base import InputEmbeddingsFeatures
7
+ from .config import ModelConfig
8
+ from .language import LanguageModel
9
+ from .processing import Glm46VMoEProcessor
10
+ from .vision import VisionModel
11
+
12
+ # Register the processor with the name expected by the model config
13
+ try:
14
+ from transformers import AutoProcessor
15
+
16
+ # Register for both possible processor class names
17
+ AutoProcessor.register("Glm46VMoEProcessor", Glm46VMoEProcessor)
18
+ except Exception as e:
19
+ print(f"Error registering glm4v_moe processor: {e}")
20
+
21
+
22
+ class Model(nn.Module):
23
+ def __init__(self, config: ModelConfig):
24
+ super().__init__()
25
+ self.config = config
26
+ self.vision_tower = VisionModel(config.vision_config)
27
+ self.language_model = LanguageModel(config.text_config, config)
28
+
29
+ def get_input_embeddings(
30
+ self,
31
+ input_ids: Optional[mx.array] = None,
32
+ pixel_values: Optional[mx.array] = None,
33
+ **kwargs,
34
+ ):
35
+
36
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
37
+ video_grid_thw = kwargs.pop("video_grid_thw", None)
38
+ grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
39
+
40
+ if pixel_values is None:
41
+ return InputEmbeddingsFeatures(
42
+ inputs_embeds=self.language_model.model.embed_tokens(input_ids)
43
+ )
44
+
45
+ dtype = self.vision_tower.patch_embed.proj.weight.dtype
46
+ pixel_values = pixel_values.astype(dtype)
47
+
48
+ # Get the input embeddings from the language model
49
+ inputs_embeds = self.language_model.model.embed_tokens(input_ids)
50
+
51
+ # Get the ouptut hidden states from the vision model
52
+ hidden_states = self.vision_tower(
53
+ pixel_values, grid_thw, output_hidden_states=False
54
+ )
55
+
56
+ # Insert special image tokens in the input_ids
57
+ final_inputs_embeds = self.merge_input_ids_with_image_features(
58
+ self.config.image_token_id,
59
+ self.config.video_token_id,
60
+ hidden_states,
61
+ inputs_embeds,
62
+ input_ids,
63
+ )
64
+ return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
65
+
66
+ @staticmethod
67
+ def merge_input_ids_with_image_features(
68
+ image_token_id,
69
+ video_token_id,
70
+ image_features,
71
+ inputs_embeds,
72
+ input_ids,
73
+ ):
74
+ """Merge image features into input embeddings at image token positions.
75
+
76
+ Args:
77
+ image_token_id: The token ID for image placeholders
78
+ video_token_id: The token ID for video placeholders (fallback)
79
+ image_features: Vision features from the vision tower [num_features, hidden_dim]
80
+ inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
81
+ input_ids: Input token IDs [batch_size, seq_len]
82
+ grid_thw: Grid dimensions for each image (optional, not used in simple case)
83
+
84
+ Returns:
85
+ Updated input embeddings with image features inserted
86
+ """
87
+ # Find positions of image tokens
88
+ image_positions = input_ids == image_token_id
89
+ if mx.sum(image_positions) == 0:
90
+ image_positions = input_ids == video_token_id
91
+
92
+ # Get dimensions
93
+ batch_size, seq_len = input_ids.shape
94
+
95
+ # Process each batch item
96
+ batch_outputs = []
97
+ feature_start_idx = 0
98
+
99
+ for batch_idx in range(batch_size):
100
+ # Get mask for this batch
101
+ image_mask = image_positions[batch_idx]
102
+ num_positions = mx.sum(image_mask).item()
103
+
104
+ if num_positions > 0:
105
+ # Extract features for this batch
106
+ batch_features = image_features[
107
+ feature_start_idx : feature_start_idx + num_positions
108
+ ]
109
+
110
+ # Validate we have the right number of features
111
+ if batch_features.shape[0] != num_positions:
112
+ raise ValueError(
113
+ f"Number of image token positions ({num_positions}) does not match "
114
+ f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
115
+ )
116
+
117
+ # Create indices for gathering
118
+ cumsum = mx.cumsum(image_mask.astype(mx.int32))
119
+ feature_indices = mx.where(image_mask, cumsum - 1, 0)
120
+
121
+ # Gather features
122
+ gathered_features = batch_features[feature_indices]
123
+
124
+ # Combine with original embeddings
125
+ image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
126
+ batch_output = mx.where(
127
+ image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
128
+ )
129
+
130
+ feature_start_idx += num_positions
131
+ else:
132
+ # No image tokens in this batch item
133
+ batch_output = inputs_embeds[batch_idx]
134
+
135
+ batch_outputs.append(batch_output)
136
+
137
+ # Stack all batch outputs
138
+ return mx.stack(batch_outputs, axis=0)
139
+
140
+ @property
141
+ def layers(self):
142
+ return self.language_model.model.layers
143
+
144
+ def __call__(
145
+ self,
146
+ input_ids: mx.array,
147
+ pixel_values: Optional[mx.array] = None,
148
+ mask: Optional[mx.array] = None,
149
+ cache=None,
150
+ **kwargs,
151
+ ):
152
+ input_embeddings_features = self.get_input_embeddings(
153
+ input_ids, pixel_values, **kwargs
154
+ )
155
+ logits = self.language_model(
156
+ input_ids,
157
+ input_embeddings_features.inputs_embeds,
158
+ mask=mask,
159
+ cache=cache,
160
+ **kwargs,
161
+ )
162
+
163
+ return logits
164
+
165
+ def sanitize(self, weights):
166
+ def transform_key(key):
167
+ if "visual" in key:
168
+ if "vision_tower" not in key:
169
+ key = key.replace("model.", "").replace("visual", "vision_tower")
170
+ if "model.language_model" in key:
171
+ key = key.replace("model.language_model", "language_model.model")
172
+ if "lm_head" in key and not key.startswith("language_model"):
173
+ key = key.replace("lm_head", "language_model.lm_head")
174
+ return key
175
+
176
+ return {transform_key(k): v for k, v in weights.items()}