fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/paddleocr_vl/paddleocr_vl.py
@@ -0,0 +1,207 @@
1
+ from typing import Optional
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+
6
+ from ..base import InputEmbeddingsFeatures, install_auto_processor_patch
7
+ from .config import ModelConfig
8
+ from .language import LanguageModel
9
+ from .processing_paddleocr_vl import PaddleOCRVLProcessor
10
+ from .vision import VisionModel
11
+
12
+ install_auto_processor_patch("paddleocr_vl", PaddleOCRVLProcessor)
13
+
14
+
15
+ class Model(nn.Module):
16
+ def __init__(self, config: ModelConfig):
17
+ super().__init__()
18
+ self.config = config
19
+ self.visual = VisionModel(config.vision_config)
20
+ self.language_model = LanguageModel(config.text_config, config)
21
+
22
+ def get_input_embeddings(
23
+ self,
24
+ input_ids: Optional[mx.array] = None,
25
+ pixel_values: Optional[mx.array] = None,
26
+ **kwargs,
27
+ ):
28
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
29
+ video_grid_thw = kwargs.pop("video_grid_thw", None)
30
+ mask = kwargs.pop("mask", None)
31
+ grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
32
+
33
+ if pixel_values is None:
34
+ # Reset position state for text-only generation
35
+ self.language_model._position_ids = None
36
+ self.language_model._rope_deltas = None
37
+ return InputEmbeddingsFeatures(
38
+ inputs_embeds=self.language_model.model.embed_tokens(input_ids)
39
+ )
40
+
41
+ dtype = self.visual.embeddings.patch_embedding.weight.dtype
42
+ pixel_values = mx.array(pixel_values, dtype=dtype)
43
+
44
+ # Get the input embeddings from the language model
45
+ inputs_embeds = self.language_model.model.embed_tokens(input_ids)
46
+
47
+ # Get the output hidden states from the vision model
48
+ hidden_states = self.visual(pixel_values, grid_thw, output_hidden_states=False)
49
+
50
+ # Insert special image tokens in the input_ids
51
+ final_inputs_embeds = self.merge_input_ids_with_image_features(
52
+ self.config.image_token_id,
53
+ hidden_states,
54
+ inputs_embeds,
55
+ input_ids,
56
+ )
57
+
58
+ # Pre-calculate position_ids for chunked prefill
59
+ if image_grid_thw is not None or video_grid_thw is not None:
60
+ position_ids, rope_deltas = self.language_model.get_rope_index(
61
+ input_ids, image_grid_thw, video_grid_thw, mask
62
+ )
63
+ self.language_model._position_ids = position_ids
64
+ self.language_model._rope_deltas = rope_deltas
65
+
66
+ return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
67
+
68
+ @staticmethod
69
+ def merge_input_ids_with_image_features(
70
+ image_token_id,
71
+ image_features,
72
+ inputs_embeds,
73
+ input_ids,
74
+ ):
75
+ """Merge image features into input embeddings at image token positions.
76
+
77
+ Args:
78
+ image_features: Vision features from the vision tower [num_features, hidden_dim]
79
+ inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
80
+ input_ids: Input token IDs [batch_size, seq_len]
81
+
82
+ Returns:
83
+ Updated input embeddings with image features inserted
84
+ """
85
+
86
+ # Positions of <image> tokens in input_ids
87
+ image_positions = input_ids == image_token_id
88
+
89
+ # Get dimensions
90
+ batch_size, seq_len = input_ids.shape
91
+
92
+ # Process each batch item
93
+ batch_outputs = []
94
+ feature_start_idx = 0
95
+
96
+ for batch_idx in range(batch_size):
97
+ # Get mask for this batch
98
+ image_mask = image_positions[batch_idx]
99
+ num_positions = mx.sum(image_mask).item()
100
+
101
+ if num_positions > 0:
102
+ # Extract features for this batch
103
+ batch_features = image_features[
104
+ feature_start_idx : feature_start_idx + num_positions
105
+ ]
106
+
107
+ # Validate we have the right number of features
108
+ if batch_features.shape[0] != num_positions:
109
+ raise ValueError(
110
+ f"Number of image token positions ({num_positions}) does not match "
111
+ f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
112
+ )
113
+
114
+ # Create indices for gathering
115
+ cumsum = mx.cumsum(image_mask.astype(mx.int32))
116
+ feature_indices = mx.where(image_mask, cumsum - 1, 0)
117
+
118
+ # Gather features
119
+ gathered_features = batch_features[feature_indices]
120
+
121
+ # Combine with original embeddings
122
+ image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
123
+ batch_output = mx.where(
124
+ image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
125
+ )
126
+
127
+ feature_start_idx += num_positions
128
+ else:
129
+ # No image tokens in this batch item
130
+ batch_output = inputs_embeds[batch_idx]
131
+
132
+ batch_outputs.append(batch_output)
133
+
134
+ # Stack all batch outputs
135
+ return mx.stack(batch_outputs, axis=0)
136
+
137
+ @property
138
+ def layers(self):
139
+ return self.language_model.model.layers
140
+
141
+ def __call__(
142
+ self,
143
+ input_ids: mx.array,
144
+ pixel_values: Optional[mx.array] = None,
145
+ mask: Optional[mx.array] = None,
146
+ cache=None,
147
+ **kwargs,
148
+ ):
149
+
150
+ input_embeddings_features = self.get_input_embeddings(
151
+ input_ids, pixel_values, **kwargs
152
+ )
153
+ kwargs = {
154
+ "pixel_values": pixel_values,
155
+ **kwargs,
156
+ }
157
+ logits = self.language_model(
158
+ input_ids,
159
+ input_embeddings_features.inputs_embeds,
160
+ mask=mask,
161
+ cache=cache,
162
+ **kwargs,
163
+ )
164
+ return logits
165
+
166
+ def sanitize(self, weights):
167
+ _keys_to_ignore_on_load_unexpected = [
168
+ "packing_position_embedding",
169
+ "vision_model.head",
170
+ ]
171
+
172
+ def transform_key(key):
173
+ if "visual.vision_model" in key:
174
+ if "embeddings" in key or "post_layernorm" in key:
175
+ key = key.replace("visual.vision_model", "visual")
176
+ elif "encoder" in key:
177
+ key = key.replace("visual.vision_model.encoder", "visual")
178
+ elif "mlp_AR" in key:
179
+ key = key.replace("mlp_AR", "visual.projector")
180
+ elif "model" in key:
181
+ key = key.replace("model", "language_model.model")
182
+ elif "lm_head" in key:
183
+ key = key.replace("lm_head", "language_model.lm_head")
184
+
185
+ return key
186
+
187
+ new_weights = {}
188
+ for k, v in weights.items():
189
+ if (
190
+ "packing_position_embedding" in k
191
+ or "vision_model.head" in k
192
+ or ("visual" in k and "k_proj" in k)
193
+ or ("visual" in k and "v_proj" in k)
194
+ ):
195
+ continue
196
+ elif "visual" in k and "q_proj" in k:
197
+ new_key = transform_key(k)
198
+ k_proj = weights.get(k.replace("q_proj", "k_proj"), None)
199
+ v_proj = weights.get(k.replace("q_proj", "v_proj"), None)
200
+ if k_proj is not None and v_proj is not None:
201
+ merged_tensor = mx.concatenate([v, k_proj, v_proj], axis=0)
202
+ merged_key = new_key.replace("q_proj", "qkv")
203
+ new_weights[merged_key] = merged_tensor
204
+ else:
205
+ new_weights[transform_key(k)] = v
206
+
207
+ return new_weights
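
Note: merge_input_ids_with_image_features above scatters the vision-tower features into the text embedding sequence at every position holding the image placeholder token, using a cumulative-sum index and mx.where. A minimal self-contained sketch of that gather with toy shapes and made-up values (illustration only, not part of the package):

import mlx.core as mx

image_token_id = 99
input_ids = mx.array([1, 99, 99, 2])                             # one sequence with two <image> slots
inputs_embeds = mx.zeros((4, 3))                                 # [seq_len, hidden] text embeddings
image_features = mx.arange(6, dtype=mx.float32).reshape(2, 3)    # [num_image_tokens, hidden]

image_mask = input_ids == image_token_id                         # True at image-token positions
cumsum = mx.cumsum(image_mask.astype(mx.int32))                  # running count of image tokens
feature_indices = mx.where(image_mask, cumsum - 1, 0)            # which feature row feeds each position
gathered = image_features[feature_indices]                       # [seq_len, hidden]
merged = mx.where(mx.expand_dims(image_mask, -1), gathered, inputs_embeds)
print(merged)   # rows 1 and 2 now hold the two image feature vectors; rows 0 and 3 keep the text embeddings
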
mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py
@@ -0,0 +1,425 @@
1
+ import json
2
+ import math
3
+ from pathlib import Path
4
+ from typing import List, Optional, Union
5
+
6
+ import numpy as np
7
+ from transformers import AutoTokenizer
8
+ from transformers.feature_extraction_utils import BatchFeature
9
+ from transformers.image_processing_utils import BaseImageProcessor
10
+ from transformers.image_transforms import convert_to_rgb
11
+ from transformers.image_utils import (
12
+ ImageInput,
13
+ PILImageResampling,
14
+ make_flat_list_of_images,
15
+ to_numpy_array,
16
+ valid_images,
17
+ )
18
+ from transformers.processing_utils import ProcessorMixin
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ def smart_resize(
25
+ height: int,
26
+ width: int,
27
+ factor: int,
28
+ min_pixels: int,
29
+ max_pixels: int,
30
+ ):
31
+ if height < factor:
32
+ width = round((width * factor) / height)
33
+ height = factor
34
+
35
+ if width < factor:
36
+ height = round((height * factor) / width)
37
+ width = factor
38
+
39
+ if max(height, width) / min(height, width) > 200:
40
+ raise ValueError(
41
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
42
+ )
43
+ h_bar = round(height / factor) * factor
44
+ w_bar = round(width / factor) * factor
45
+ if h_bar * w_bar > max_pixels:
46
+ beta = math.sqrt((height * width) / max_pixels)
47
+ h_bar = math.floor(height / beta / factor) * factor
48
+ w_bar = math.floor(width / beta / factor) * factor
49
+ elif h_bar * w_bar < min_pixels:
50
+ beta = math.sqrt(min_pixels / (height * width))
51
+ h_bar = math.ceil(height * beta / factor) * factor
52
+ w_bar = math.ceil(width * beta / factor) * factor
53
+ return h_bar, w_bar
54
+
55
+
56
+ class ImageProcessor(BaseImageProcessor):
57
+ """
58
+ MLX-native image processor for PaddleOCRVL that doesn't require torch.
59
+ """
60
+
61
+ model_input_names = ["pixel_values"]
62
+
63
+ def __init__(
64
+ self,
65
+ do_resize: bool = True,
66
+ size: dict[str, int] | None = None,
67
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
68
+ do_rescale: bool = True,
69
+ rescale_factor: int | float = 1 / 255,
70
+ do_normalize: bool = True,
71
+ image_mean: float | list[float] | None = None,
72
+ image_std: float | list[float] | None = None,
73
+ do_convert_rgb: bool = True,
74
+ min_pixels: int = 147384,
75
+ max_pixels: int = 2822400,
76
+ patch_size: int = 14,
77
+ temporal_patch_size: int = 1,
78
+ merge_size: int = 2,
79
+ **kwargs,
80
+ ) -> None:
81
+ super().__init__(**kwargs)
82
+ if size is not None:
83
+ if "shortest_edge" not in size or "longest_edge" not in size:
84
+ raise ValueError(
85
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
86
+ )
87
+ else:
88
+ size = {"shortest_edge": 147384, "longest_edge": 2822400}
89
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
90
+ if min_pixels is not None:
91
+ size["shortest_edge"] = min_pixels
92
+ if max_pixels is not None:
93
+ size["longest_edge"] = max_pixels
94
+ self.min_pixels = size["shortest_edge"]
95
+ self.max_pixels = size["longest_edge"]
96
+ self.size = size
97
+ self.do_resize = do_resize
98
+ self.resample = resample
99
+ self.do_rescale = do_rescale
100
+ self.rescale_factor = rescale_factor
101
+ self.do_normalize = do_normalize
102
+ self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
103
+ self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
104
+ self.patch_size = patch_size
105
+ self.temporal_patch_size = temporal_patch_size
106
+ self.merge_size = merge_size
107
+ self.do_convert_rgb = do_convert_rgb
108
+
109
+ def preprocess(
110
+ self,
111
+ images: ImageInput,
112
+ do_resize: Optional[bool] = None,
113
+ size: Optional[dict[str, int]] = None,
114
+ min_pixels: Optional[int] = None,
115
+ max_pixels: Optional[int] = None,
116
+ resample: Optional[PILImageResampling] = None,
117
+ do_rescale: Optional[bool] = None,
118
+ rescale_factor: Optional[float] = None,
119
+ do_normalize: Optional[bool] = None,
120
+ image_mean: Optional[Union[float, list[float]]] = None,
121
+ image_std: Optional[Union[float, list[float]]] = None,
122
+ patch_size: Optional[int] = None,
123
+ temporal_patch_size: Optional[int] = None,
124
+ merge_size: Optional[int] = None,
125
+ do_convert_rgb: Optional[bool] = None,
126
+ return_tensors: Optional[str] = None,
127
+ **kwargs,
128
+ ) -> BatchFeature:
129
+ min_pixels = min_pixels if min_pixels is not None else self.min_pixels
130
+ max_pixels = max_pixels if max_pixels is not None else self.max_pixels
131
+
132
+ if size is not None:
133
+ if "shortest_edge" not in size or "longest_edge" not in size:
134
+ raise ValueError(
135
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
136
+ )
137
+ elif min_pixels is not None and max_pixels is not None:
138
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
139
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
140
+
141
+ do_resize = do_resize if do_resize is not None else self.do_resize
142
+ resample = resample if resample is not None else self.resample
143
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
144
+ rescale_factor = (
145
+ rescale_factor if rescale_factor is not None else self.rescale_factor
146
+ )
147
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
148
+ image_mean = image_mean if image_mean is not None else self.image_mean
149
+ image_std = image_std if image_std is not None else self.image_std
150
+ patch_size = patch_size if patch_size is not None else self.patch_size
151
+ temporal_patch_size = (
152
+ temporal_patch_size
153
+ if temporal_patch_size is not None
154
+ else self.temporal_patch_size
155
+ )
156
+ merge_size = merge_size if merge_size is not None else self.merge_size
157
+ do_convert_rgb = (
158
+ do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
159
+ )
160
+
161
+ if images is not None:
162
+ images = make_flat_list_of_images(images)
163
+
164
+ if images is not None and not valid_images(images):
165
+ raise ValueError(
166
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
167
+ "torch.Tensor, tf.Tensor or jax.ndarray."
168
+ )
169
+
170
+ if do_convert_rgb:
171
+ images = [convert_to_rgb(image) for image in images]
172
+
173
+ data = {}
174
+ pixel_values, vision_grid_thws = [], []
175
+ if images is not None:
176
+ processed_images = []
177
+ for image in images:
178
+ width, height = image.size
179
+ resized_height, resized_width = smart_resize(
180
+ height,
181
+ width,
182
+ factor=patch_size * merge_size,
183
+ min_pixels=min_pixels,
184
+ max_pixels=max_pixels,
185
+ )
186
+ image = image.resize((resized_width, resized_height), resample)
187
+ img_array = to_numpy_array(image)
188
+
189
+ if do_rescale:
190
+ img_array = img_array / 255.0
191
+
192
+ if do_normalize:
193
+ mean = np.array(self.image_mean).reshape(1, 1, 3)
194
+ std = np.array(self.image_std).reshape(1, 1, 3)
195
+ img_array = (img_array - mean) / std
196
+
197
+ processed_images.append(img_array)
198
+
199
+ patches = np.array(processed_images)
200
+
201
+ if patches.shape[1] > 3:
202
+ patches = patches.transpose(0, 3, 1, 2)
203
+ if patches.shape[0] == 1:
204
+ patches = np.tile(patches, (temporal_patch_size, 1, 1, 1))
205
+
206
+ channel = patches.shape[1]
207
+ grid_t = patches.shape[0] // temporal_patch_size
208
+ grid_h, grid_w = (
209
+ resized_height // patch_size,
210
+ resized_width // patch_size,
211
+ )
212
+ patches = patches.reshape(
213
+ grid_t,
214
+ temporal_patch_size,
215
+ channel,
216
+ grid_h,
217
+ patch_size,
218
+ grid_w,
219
+ patch_size,
220
+ )
221
+ patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
222
+ if temporal_patch_size != 1:
223
+ raise ValueError(
224
+ f"temporal_patch_size must be 1!, but got {temporal_patch_size}!"
225
+ )
226
+ flatten_patches = patches.reshape(
227
+ grid_t * grid_h * grid_w, channel, patch_size, patch_size
228
+ )
229
+ image_grid_thw = (grid_t, grid_h, grid_w)
230
+ pixel_values.extend(flatten_patches)
231
+ vision_grid_thws.append(image_grid_thw)
232
+
233
+ pixel_values = np.array([pixel_values])
234
+ vision_grid_thws = np.array(vision_grid_thws)
235
+ data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})
236
+
237
+ return BatchFeature(data, tensor_type=return_tensors)
238
+
239
+
240
+ class PaddleOCRVLProcessor(ProcessorMixin):
241
+ attributes = ["image_processor", "tokenizer"]
242
+ valid_kwargs = ["chat_template"]
243
+ image_processor_class = "AutoImageProcessor"
244
+ tokenizer_class = "AutoTokenizer"
245
+
246
+ def __init__(
247
+ self,
248
+ image_processor=None,
249
+ tokenizer=None,
250
+ chat_template=None,
251
+ **kwargs,
252
+ ):
253
+
254
+ if image_processor is None:
255
+ image_processor = ImageProcessor(**kwargs)
256
+
257
+ self.tokenizer = tokenizer
258
+ self.image_token = (
259
+ "<|IMAGE_PLACEHOLDER|>"
260
+ if not hasattr(tokenizer, "image_token")
261
+ else tokenizer.image_token
262
+ )
263
+ self.image_processor = image_processor
264
+
265
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
266
+
267
+ def __call__(
268
+ self,
269
+ images=None,
270
+ text: Union[str, List[str]] = None,
271
+ **kwargs,
272
+ ) -> BatchFeature:
273
+ """Process images and text for the model.
274
+
275
+ Args:
276
+ images: Single image or list of images
277
+ text: Single text or list of texts
278
+ videos: Video inputs (not currently supported)
279
+ **kwargs: Additional arguments passed to tokenizer
280
+
281
+ Returns:
282
+ BatchFeature with:
283
+ - input_ids: Token IDs with image placeholders replaced
284
+ - attention_mask: Attention mask
285
+ - pixel_values: Processed image patches
286
+ - image_grid_thw: Grid dimensions for each image
287
+ - position_ids: 4D position IDs for xdrope
288
+ """
289
+ image_inputs = {}
290
+
291
+ if images is not None:
292
+ image_inputs = self.image_processor(images=images)
293
+ image_grid_thw = image_inputs["image_grid_thw"]
294
+
295
+ if text is None:
296
+ text = [""]
297
+ elif not isinstance(text, list):
298
+ text = [text]
299
+
300
+ text = [t for t in text] # Copy to avoid modifying original
301
+
302
+ if images is not None:
303
+ index = 0
304
+ for i in range(len(text)):
305
+ while self.image_token in text[i]:
306
+ text[i] = text[i].replace(
307
+ self.image_token,
308
+ "<|placeholder|>"
309
+ * (
310
+ image_grid_thw[index].prod()
311
+ // self.image_processor.merge_size
312
+ // self.image_processor.merge_size
313
+ ),
314
+ 1,
315
+ )
316
+ index += 1
317
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
318
+
319
+ # Pop return_tensors to handle it ourselves at the end
320
+ return_tensors = kwargs.pop("return_tensors", None)
321
+
322
+ # Tokenize text
323
+ text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
324
+
325
+ # Get input_ids and convert to numpy array for processing
326
+ input_ids = text_inputs["input_ids"]
327
+ if hasattr(input_ids, "tolist"):
328
+ # Handle mlx arrays or torch tensors
329
+ input_ids = np.array(input_ids.tolist())
330
+ elif isinstance(input_ids, list):
331
+ input_ids = np.array(input_ids)
332
+
333
+ return BatchFeature(
334
+ data={**text_inputs, **image_inputs},
335
+ tensor_type=return_tensors,
336
+ )
337
+
338
+ def batch_decode(self, *args, **kwargs):
339
+ """Decode token IDs to text."""
340
+ return self.tokenizer.batch_decode(*args, **kwargs)
341
+
342
+ def decode(self, *args, **kwargs):
343
+ """Decode token IDs to text."""
344
+ return self.tokenizer.decode(*args, **kwargs)
345
+
346
+ def apply_chat_template(self, *args, **kwargs):
347
+ """Apply chat template using the tokenizer."""
348
+ return self.tokenizer.apply_chat_template(*args, **kwargs)
349
+
350
+ @property
351
+ def model_input_names(self):
352
+ """Return combined input names from tokenizer and image processor."""
353
+ tokenizer_input_names = (
354
+ self.tokenizer.model_input_names if self.tokenizer else []
355
+ )
356
+ image_processor_input_names = self.image_processor.model_input_names
357
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
358
+
359
+ @classmethod
360
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
361
+ """Load processor from pretrained model path."""
362
+ import warnings
363
+
364
+ from huggingface_hub import hf_hub_download
365
+
366
+ trust_remote_code = kwargs.pop("trust_remote_code", True)
367
+
368
+ model_path = Path(pretrained_model_name_or_path)
369
+ is_local = model_path.exists() and model_path.is_dir()
370
+
371
+ # Suppress warning about mrope_section in rope_parameters
372
+ with warnings.catch_warnings():
373
+ warnings.filterwarnings(
374
+ "ignore", message="Unrecognized keys in `rope_parameters`"
375
+ )
376
+ tokenizer = AutoTokenizer.from_pretrained(
377
+ str(model_path) if is_local else pretrained_model_name_or_path,
378
+ trust_remote_code=trust_remote_code,
379
+ local_files_only=is_local,
380
+ **kwargs,
381
+ )
382
+
383
+ # Load image processor config from preprocessor_config.json
384
+ image_processor_config = {}
385
+ try:
386
+ if is_local:
387
+ config_path = model_path / "preprocessor_config.json"
388
+ else:
389
+ config_path = Path(
390
+ hf_hub_download(
391
+ pretrained_model_name_or_path, "preprocessor_config.json"
392
+ )
393
+ )
394
+ if config_path.exists():
395
+ with open(config_path, "r", encoding="utf-8") as f:
396
+ preprocessor_config = json.load(f)
397
+ # Extract relevant image processor parameters
398
+ relevant_keys = [
399
+ "min_pixels",
400
+ "max_pixels",
401
+ "patch_size",
402
+ "temporal_patch_size",
403
+ "merge_size",
404
+ "image_mean",
405
+ "image_std",
406
+ "do_resize",
407
+ "do_rescale",
408
+ "do_normalize",
409
+ "do_convert_rgb",
410
+ ]
411
+ for key in relevant_keys:
412
+ if key in preprocessor_config:
413
+ image_processor_config[key] = preprocessor_config[key]
414
+
415
+ except Exception:
416
+ pass
417
+
418
+ image_processor = ImageProcessor(**image_processor_config)
419
+ return cls(image_processor=image_processor, tokenizer=tokenizer, **kwargs)
420
+
421
+
422
+ __all__ = [
423
+ "PaddleOCRVLProcessor",
424
+ "ImageProcessor",
425
+ ]
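
Note: with the ImageProcessor defaults above (patch_size=14, merge_size=2), smart_resize snaps each image to multiples of 28 while keeping the pixel count inside [min_pixels, max_pixels]; the resulting patch grid then fixes how many image placeholder tokens the processor inserts. A condensed re-run of the main branch of that arithmetic with an example size (the small-dimension and aspect-ratio guards are omitted; snapped_size is an illustrative helper, not part of the package):

import math

def snapped_size(height, width, factor=28, min_pixels=147384, max_pixels=2822400):
    # Round both sides to the nearest multiple of `factor`, then rescale if the
    # area falls outside the configured pixel budget (same math as smart_resize).
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar

h, w = snapped_size(1000, 800)          # -> (1008, 812)
grid_h, grid_w = h // 14, w // 14       # -> (72, 58) patches per side
placeholders = (grid_h * grid_w) // 4   # merge_size**2 = 4 -> 1044 image tokens
print(h, w, grid_h, grid_w, placeholders)
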