fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/lfm2_vl/lfm2_vl.py
@@ -0,0 +1,223 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import InputEmbeddingsFeatures
+from . import processing_lfm2_vl  # noqa: F401
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import VisionModel
+
+
+class Lfm2VlMultiModalProjector(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
+        if config.projector_use_layernorm:
+            self.layer_norm = nn.LayerNorm(in_channels)
+        else:
+            self.layer_norm = nn.Identity()
+        self.linear_1 = nn.Linear(
+            in_channels,
+            config.projector_hidden_size,
+            bias=config.projector_bias,
+        )
+
+        self.linear_2 = nn.Linear(
+            config.projector_hidden_size,
+            config.text_config.hidden_size,
+            bias=config.projector_bias,
+        )
+
+    def __call__(self, x):
+        x = self.linear_1(self.layer_norm(x))
+        x = self.linear_2(nn.gelu(x))
+        return x
+
+
+class PixelUnshuffleBlock(nn.Module):
+    def __init__(self, factor: int):
+        super().__init__()
+        self.factor = factor
+
+    def __call__(self, x):
+        n, w, h, c = x.shape
+        if w % self.factor != 0:
+            x = mx.concatenate(
+                [
+                    x,
+                    mx.zeros((n, self.factor - (w % self.factor), h, c), dtype=x.dtype),
+                ],
+                axis=1,
+            )
+            n, w, h, c = x.shape
+
+        if h % self.factor != 0:
+            x = mx.concatenate(
+                [
+                    x,
+                    mx.zeros((n, w, self.factor - (h % self.factor), c), dtype=x.dtype),
+                ],
+                axis=2,
+            )
+            n, w, h, c = x.shape
+        x = x.reshape(n, w, int(h / self.factor), int(c * self.factor))
+        x = x.transpose(0, 2, 1, 3)
+        x = x.reshape(
+            n, int(h / self.factor), int(w / self.factor), int(c * self.factor**2)
+        )
+        x = x.transpose(0, 2, 1, 3)
+        return x
+
+
+def masked_scatter(
+    final_embedding: mx.array,
+    image_mask_expanded: mx.array,
+    scaled_image_features: mx.array,
+):
+    # Reshape the tensors to 1D
+    final_embedding_shape = final_embedding.shape
+    scaled_image_features_flattened = mx.flatten(scaled_image_features)
+    final_embedding_flattened = mx.flatten(final_embedding)
+    image_mask_expanded_flattened = mx.flatten(image_mask_expanded)
+
+    # Scatter the scaled image features into the special image token positions
+    image_positions = mx.array(np.where(image_mask_expanded_flattened)[0], mx.uint32)
+    final_embedding_flattened[image_positions] = scaled_image_features_flattened
+
+    # Reshape back to the original shape
+    final_embedding = mx.reshape(final_embedding_flattened, final_embedding_shape)
+
+    return final_embedding
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        self.config = config
+        self.vision_tower = VisionModel(config.vision_config)
+
+        if config.vision_feature_layer != -1:
+            self.vision_tower.encoder.layers = self.vision_tower.encoder.layers[
+                : config.vision_feature_layer + 1
+            ]
+        if config.downsample_factor > 1:
+            self.pixel_unshuffle = PixelUnshuffleBlock(config.downsample_factor)
+        else:
+            self.pixel_unshuffle = nn.Identity()
+
+        self.multi_modal_projector = Lfm2VlMultiModalProjector(config)
+        self.language_model = LanguageModel(config.text_config)
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        spatial_shapes = kwargs.get("spatial_shapes", None)
+        pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)
+
+        # Get the output hidden states from the vision model
+        *_, hidden_states = self.vision_tower(
+            pixel_values, output_hidden_states=True, spatial_shapes=spatial_shapes
+        )
+
+        img_feature_lengths = pixel_attention_mask.sum(axis=1).tolist()
+        image_features = []
+
+        for img_idx in range(hidden_states.shape[0]):
+            feature = hidden_states[img_idx]
+
+            feature = feature[: img_feature_lengths[img_idx], :][None, ...]
+
+            feature_org_h, feature_org_w = spatial_shapes[img_idx]
+            feature = feature.reshape(1, feature_org_h, feature_org_w, -1)
+            feature = self.pixel_unshuffle(feature)
+
+            img_embedding = self.multi_modal_projector(feature)
+
+            img_embedding = img_embedding.reshape(-1, img_embedding.shape[-1])
+            image_features.append(img_embedding)
+
+        image_features = mx.concatenate(image_features, axis=0)
+
+        final_inputs_embeds = self.merge_input_ids_with_image_features(
+            image_features, inputs_embeds, input_ids, self.config.image_token_index
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    @staticmethod
+    def merge_input_ids_with_image_features(
+        image_features, inputs_embeds, input_ids, image_token_index
+    ):
+        special_image_mask = input_ids == image_token_index
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask[..., None]
+        special_image_mask = mx.broadcast_to(special_image_mask, inputs_embeds.shape)
+
+        n_image_features = image_features.shape[0]
+        n_image_mask_elements = special_image_mask.sum()
+        if n_image_mask_elements != image_features.size:
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+
+        inputs_embeds = masked_scatter(
+            inputs_embeds, special_image_mask, image_features
+        )
+
+        return inputs_embeds
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
+    ):
+        spatial_shapes = kwargs.get("spatial_shapes", None)
+        pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, spatial_shapes=spatial_shapes, pixel_attention_mask=pixel_attention_mask
+        )
+
+        logits = self.language_model(
+            input_ids, mask=None, cache=cache, inputs_embeds=input_embeddings_features
+        )
+        return logits
+
+    def sanitize(self, weights):
+        def transform_key(key):
+            if "vision_tower" in key:
+                key = (
+                    key.replace("model.", "")
+                    .replace("vision_encoder", "encoder")
+                    .replace("vision_embeddings", "embeddings")
+                    .replace("vision_post_layernorm", "post_layernorm")
+                )
+
+            if "language_model" in key:
+                key = key.replace("model.language_model", "language_model.model")
+
+            if "multi_modal_projector" in key:
+                key = key.replace(
+                    "model.multi_modal_projector", "multi_modal_projector"
+                )
+
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}
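
For orientation: PixelUnshuffleBlock above is a space-to-depth reshape. A minimal illustrative sketch of its effect for downsample_factor=2, assuming only that mlx is installed (the toy shapes are made up for the example and are not from the package):

import mlx.core as mx

# Toy NWHC feature map: batch 1, 4x4 spatial grid, 3 channels.
x = mx.arange(1 * 4 * 4 * 3).reshape(1, 4, 4, 3)

factor = 2
n, w, h, c = x.shape
# Same reshape/transpose sequence as PixelUnshuffleBlock (padding omitted,
# since 4 is already divisible by the factor).
y = x.reshape(n, w, h // factor, c * factor)
y = y.transpose(0, 2, 1, 3)
y = y.reshape(n, h // factor, w // factor, c * factor**2)
y = y.transpose(0, 2, 1, 3)
print(y.shape)  # (1, 2, 2, 12): spatial dims halved, channels grown by factor**2

Each spatial axis shrinks by the factor while the channel count grows by factor**2, which is why Lfm2VlMultiModalProjector takes in_channels = vision hidden_size * downsample_factor**2.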
mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py
@@ -0,0 +1,320 @@
+"""
+Compatibility patch for Lfm2VlProcessor.
+
+The Lfm2VlProcessorKwargs has a default `return_row_col_info: True` in images_kwargs,
+but this parameter is only supported by the FAST image processor (Lfm2VlImageProcessorFast).
+When using the slow image processor (Siglip2ImageProcessor), this causes a validation error.
+
+This patch:
+1. Removes the unsupported `return_row_col_info` parameter from the defaults
+2. Enables `do_resize: True` to ensure images are properly resized for patch processing
+3. Patches the `__call__` method to handle the slow image processor case, computing
+   `image_rows`, `image_cols`, `image_sizes` when missing and providing sensible
+   defaults for tile-related parameters
+4. Patches the `__init__` to add missing attributes to the slow image processor
+5. Forces the use of the slow image processor to avoid PyTorch tensor requirements
+"""
+
+import math
+
+import numpy as np
+from transformers.models.lfm2_vl.processing_lfm2_vl import (
+    Lfm2VlProcessor,
+    Lfm2VlProcessorKwargs,
+)
+
+# Try to import the slow image processor to force its use
+try:
+    from transformers.models.siglip2.image_processing_siglip2 import (
+        Siglip2ImageProcessor,
+    )
+
+    _SLOW_PROCESSOR_AVAILABLE = True
+except ImportError:
+    _SLOW_PROCESSOR_AVAILABLE = False
+
+# Remove return_row_col_info from the defaults since the slow image processor
+# (Siglip2ImageProcessor) doesn't support it - only the fast version does.
+# Also enable do_resize to ensure images are properly resized to be divisible by patch_size.
+if hasattr(Lfm2VlProcessorKwargs, "_defaults"):
+    if "images_kwargs" in Lfm2VlProcessorKwargs._defaults:
+        Lfm2VlProcessorKwargs._defaults["images_kwargs"].pop(
+            "return_row_col_info", None
+        )
+        # Enable resizing for the slow image processor (model config has do_resize: False
+        # which is intended for the fast processor that handles resizing differently)
+        Lfm2VlProcessorKwargs._defaults["images_kwargs"]["do_resize"] = True
+
+
+# Store the original __init__ method
+_original_init = Lfm2VlProcessor.__init__
+
+
+def _patched_init(self, image_processor, tokenizer, chat_template=None, **kwargs):
+    """Patched __init__ that adds missing attributes to the slow image processor."""
+    # Check if we got the fast image processor and need to replace it with the slow one
+    # The fast processor requires PyTorch tensors which we don't have
+    processor_class_name = type(image_processor).__name__
+    if "Fast" in processor_class_name and _SLOW_PROCESSOR_AVAILABLE:
+        # Replace with slow processor using the same config
+        if hasattr(image_processor, "to_dict"):
+            # Use the config dict to create the slow processor
+            slow_processor = Siglip2ImageProcessor(**image_processor.to_dict())
+        else:
+            # Fallback to copying attributes
+            slow_processor = Siglip2ImageProcessor(
+                **{
+                    k: v
+                    for k, v in image_processor.__dict__.items()
+                    if not k.startswith("_") and k not in ["name_or_path"]
+                }
+            )
+        image_processor = slow_processor
+
+    # Call original __init__
+    _original_init(
+        self, image_processor, tokenizer, chat_template=chat_template, **kwargs
+    )
+
+    # Add missing attributes for the slow image processor (Siglip2ImageProcessor)
+    # These are needed by expand_text_with_placeholders and _get_image_num_tokens
+    if not hasattr(self.image_processor, "tile_size"):
+        self.image_processor.tile_size = 512
+    if not hasattr(self.image_processor, "max_image_tokens"):
+        self.image_processor.max_image_tokens = 256
+    if not hasattr(self.image_processor, "min_image_tokens"):
+        self.image_processor.min_image_tokens = 64
+    if not hasattr(self.image_processor, "downsample_factor"):
+        self.image_processor.downsample_factor = 2
+    if not hasattr(self.image_processor, "encoder_patch_size"):
+        self.image_processor.encoder_patch_size = 16
+    if not hasattr(self.image_processor, "do_image_splitting"):
+        self.image_processor.do_image_splitting = (
+            False  # Disable tiling for slow processor
+        )
+    if not hasattr(self.image_processor, "use_thumbnail"):
+        self.image_processor.use_thumbnail = False
+
+
+# Apply the __init__ patch
+Lfm2VlProcessor.__init__ = _patched_init
+
+
+def _compute_image_grid_info(pixel_values, patch_size: int = 16):
+    """
+    Compute image_rows, image_cols, and image_sizes from pixel_values.
+
+    When using the slow image processor (Siglip2ImageProcessor), these values
+    are not returned. This function computes them from the pixel_values tensor.
+
+    Args:
+        pixel_values: Array of shape (batch, num_patches, patch_dim)
+        patch_size: The patch size used for image processing
+
+    Returns:
+        image_rows: List of rows per image
+        image_cols: List of cols per image
+        image_sizes: List of total patches per image
+    """
+    # pixel_values shape: (batch, num_patches, patch_dim)
+    # For Siglip2, each image is processed independently and has its own num_patches
+    if hasattr(pixel_values, "shape"):
+        batch_size = pixel_values.shape[0]
+        num_patches = pixel_values.shape[1]
+
+        # Estimate rows/cols from num_patches (assuming roughly square)
+        # The actual image was resized to fit max_num_patches while maintaining aspect ratio
+        side_length = int(math.sqrt(num_patches))
+
+        # Return as nested lists (one list per batch, one value per image in batch)
+        image_rows = [[side_length] for _ in range(batch_size)]
+        image_cols = [[side_length] for _ in range(batch_size)]
+        image_sizes = [[num_patches] for _ in range(batch_size)]
+
+        return image_rows, image_cols, image_sizes
+
+    return [[1]], [[1]], [[1]]
+
+
+# Store the original __call__ method
+_original_call = Lfm2VlProcessor.__call__
+
+
+def _ensure_slow_processor(processor_instance):
+    """
+    Ensure we're using the slow image processor, not the fast one.
+    The fast processor only supports PyTorch tensors which we can't use without PyTorch.
+    """
+    image_processor = processor_instance.image_processor
+    processor_class_name = type(image_processor).__name__
+
+    if "Fast" in processor_class_name and _SLOW_PROCESSOR_AVAILABLE:
+        # Need to replace with slow processor
+        # Get the config from the fast processor
+        config = (
+            image_processor.to_dict() if hasattr(image_processor, "to_dict") else {}
+        )
+        # Remove keys that might cause issues
+        config.pop("image_processor_type", None)
+        config.pop("auto_map", None)
+        config.pop("_processor_class", None)
+
+        # Create slow processor with the same config
+        slow_processor = Siglip2ImageProcessor(**config)
+        processor_instance.image_processor = slow_processor
+
+    # Re-add missing attributes
+    if not hasattr(processor_instance.image_processor, "tile_size"):
+        processor_instance.image_processor.tile_size = 512
+    if not hasattr(processor_instance.image_processor, "downsample_factor"):
+        processor_instance.image_processor.downsample_factor = 2
+    if not hasattr(processor_instance.image_processor, "do_image_splitting"):
+        processor_instance.image_processor.do_image_splitting = False
+    if not hasattr(processor_instance.image_processor, "use_thumbnail"):
+        processor_instance.image_processor.use_thumbnail = False
+
+    return processor_instance.image_processor
+
+
+def _patched_call(self, images=None, text=None, **kwargs):
+    """
+    Patched __call__ that handles the slow image processor case.
+
+    The slow Siglip2ImageProcessor doesn't return image_rows, image_cols, image_sizes
+    which are required by expand_text_with_placeholders. This patch intercepts the call
+    and computes these values when they're missing.
+    """
+    from transformers.feature_extraction_utils import BatchFeature
+    from transformers.image_utils import make_nested_list_of_images

+    # Ensure we're using the slow processor (fast requires PyTorch tensors)
+    if images is not None:
+        _ensure_slow_processor(self)
+
+    if images is None and text is not None:
+        # Text-only case
+        output_kwargs = self._merge_kwargs(
+            Lfm2VlProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        output_kwargs["text_kwargs"].pop("use_image_special_tokens", None)
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+        return BatchFeature(text_inputs, tensor_type=return_tensors)
+
+    if text is None and images is None:
+        raise ValueError("You must provide one of `text` or `images`.")
+
+    if images is not None and text is None:
+        raise ValueError(
+            "You must provide `text` when `images` is provided. Minimal text consists of a single image token."
+        )
+
+    # Merge kwargs to get the final settings
+    output_kwargs = self._merge_kwargs(
+        Lfm2VlProcessorKwargs,
+        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+        **kwargs,
+    )
+
+    if isinstance(text, str):
+        text = [text]
+    elif text is not None and not isinstance(text, list):
+        raise TypeError(
+            "Invalid input text. Please provide a string, or a list of strings"
+        )
+
+    n_images_in_text = [sample.count(self.image_token) for sample in text]
+
+    inputs = {}
+    use_image_special_tokens = output_kwargs["text_kwargs"].pop(
+        "use_image_special_tokens", True
+    )
+
+    # Process images
+    images = self.image_processor.fetch_images(images)
+    batched_images = make_nested_list_of_images(images)
+
+    # Override return_tensors for image processing to avoid PyTorch dependency
+    images_kwargs = output_kwargs["images_kwargs"].copy()
+    images_kwargs["return_tensors"] = "np"  # Use numpy instead of pt
+
+    vision_inputs = self.image_processor(batched_images, **images_kwargs)
+
+    n_images_in_images = [len(sublist) for sublist in batched_images]
+    if n_images_in_images != n_images_in_text:
+        raise ValueError(
+            f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
+        )
+
+    # Check if image_rows/cols/sizes are present (fast processor case)
+    if "image_rows" in vision_inputs:
+        image_rows = vision_inputs.pop("image_rows")
+        image_cols = vision_inputs.pop("image_cols")
+        image_sizes = vision_inputs.pop("image_sizes")
+    else:
+        # Slow processor case - compute from spatial_shapes or pixel_attention_mask
+        # The spatial_shapes gives the actual (height, width) in patches for each image
+        spatial_shapes = vision_inputs.get("spatial_shapes")
+        if spatial_shapes is not None:
+            # spatial_shapes is array of shape (batch, 2) with [height, width] in patches
+            image_rows = [[int(ss[0])] for ss in spatial_shapes]
+            image_cols = [[int(ss[1])] for ss in spatial_shapes]
+            image_sizes = [[int(ss[0] * ss[1])] for ss in spatial_shapes]
+        else:
+            # Fallback to computing from pixel_values
+            pixel_values = vision_inputs.get("pixel_values")
+            patch_size = getattr(self.image_processor, "patch_size", 16)
+            image_rows, image_cols, image_sizes = _compute_image_grid_info(
+                pixel_values, patch_size
+            )
+
+    # For slow processor, use simplified text expansion
+    # (no tiling support, just add image tokens)
+    # Account for downsample_factor: the vision tower reduces patches by factor^2
+    downsample_factor = getattr(self.image_processor, "downsample_factor", 2)
+
+    expanded_text = []
+    for sample_text, sample_images, rows, cols, sizes in zip(
+        text, batched_images, image_rows, image_cols, image_sizes
+    ):
+        split_sample = sample_text.split(self.image_token)
+        result = ""
+        for i, _ in enumerate(sample_images):
+            result += split_sample[i]
+            if use_image_special_tokens:
+                result += self.image_start_token
+            # Add image tokens based on the number of patches AFTER downsampling
+            # The vision tower downsamples by factor^2, so divide by that
+            num_patches = sizes[i] if i < len(sizes) else sizes[0]
+            num_image_tokens = num_patches // (downsample_factor**2)
+            result += self.image_token * num_image_tokens
+            if use_image_special_tokens:
+                result += self.image_end_token
+        # Add any remaining text after the last image
+        if len(split_sample) > len(sample_images):
+            result += split_sample[-1]
+        expanded_text.append(result)
+
+    inputs.update(vision_inputs)
+
+    return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+
+    text_inputs = self.tokenizer(expanded_text, **output_kwargs["text_kwargs"])
+    inputs.update(text_inputs)
+
+    # Convert lists to numpy arrays for proper handling by mlx_vlm
+    # The tokenizer returns lists but mlx_vlm expects numpy arrays
+    if isinstance(inputs.get("input_ids"), list):
+        inputs["input_ids"] = np.array(inputs["input_ids"])
+    if isinstance(inputs.get("attention_mask"), list):
+        inputs["attention_mask"] = np.array(inputs["attention_mask"])
+
+    return BatchFeature(
+        inputs, tensor_type=None
+    )  # Don't convert, let mlx_vlm handle it
+
+
+# Apply the patch
+Lfm2VlProcessor.__call__ = _patched_call
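
For context on how this processor patch takes effect: lfm2_vl.py imports processing_lfm2_vl purely for its side effects (the `# noqa: F401` import above), so the monkey-patch is applied as soon as the model module loads. A rough end-to-end usage sketch follows; the load/load_config/apply_chat_template/generate calls are based on mlx_vlm's documented public API, and the checkpoint path is a hypothetical placeholder, neither confirmed by this diff:

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Hypothetical converted LFM2-VL checkpoint; substitute a real model path.
model_path = "mlx-community/LFM2-VL-1.6B-bf16"
model, processor = load(model_path)
config = load_config(model_path)

images = ["example.png"]
prompt = apply_chat_template(processor, config, "Describe this image.", num_images=len(images))
output = generate(model, processor, prompt, images, max_tokens=128, verbose=False)
print(output)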