fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py
@@ -0,0 +1,624 @@
+ """
+ From https://github.com/deepseek-ai/DeepSeek-VL2
+ """
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Literal, Optional, Tuple
+
+ import mlx.core as mx
+ import numpy as np
+ from PIL import Image, ImageOps
+ from transformers import LlamaTokenizerFast
+ from transformers.processing_utils import ProcessorMixin
+
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float("inf")
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
+     return best_ratio
+
+
+ def dynamic_preprocess(
+     image, min_num=2, max_num=9, image_size=640, use_thumbnail=False
+ ):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # calculate the existing image aspect ratio
+     target_ratios = set(
+         (i, j)
+         for n in range(min_num, max_num + 1)
+         for i in range(1, n + 1)
+         for j in range(1, n + 1)
+         if i * j <= max_num and i * j >= min_num
+     )
+     # print(target_ratios)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size
+     )
+
+     # print(target_aspect_ratio)
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size,
+         )
+         # split the image
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images, target_aspect_ratio
+
+
+ class DictOutput(object):
+     def keys(self):
+         return self.__dict__.keys()
+
+     def __getitem__(self, item):
+         if isinstance(item, int):
+             return list(self.__dict__.values())[item]
+         if item not in self.__dict__:
+             raise KeyError(item)
+         return self.__dict__[item]
+
+     def __setitem__(self, key, value):
+         self.__dict__[key] = value
+
+
+ @dataclass
+ class VLChatProcessorOutput(DictOutput):
+     sft_format: str
+     input_ids: mx.array
+     target_ids: mx.array
+     images: mx.array
+     images_seq_mask: mx.array
+     images_spatial_crop: mx.array
+     num_image_tokens: List[int]
+
+     def __len__(self):
+         return len(self.input_ids)
+
+
+ @dataclass
+ class BatchCollateOutput(DictOutput):
+     sft_format: List[str]
+     input_ids: mx.array
+     labels: mx.array
+     images: mx.array
+     attention_mask: mx.array
+     images_seq_mask: mx.array
+     images_spatial_crop: mx.array
+     seq_lens: List[int]
+
+
+ class ImageTransform:
+     def __init__(
+         self,
+         mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+         std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+         normalize: bool = True,
+     ):
+         self.mean = mean
+         self.std = std
+         self.normalize = normalize
+
+     def __call__(self, pil_img: Image.Image):
+         # Convert PIL image to numpy array and normalize
+
+         img = mx.array(np.array(pil_img)) / 255.0
+
+         # Transpose from HWC to CHW format
+         img = mx.transpose(img, [2, 0, 1])
+
+         if self.normalize:
+             mean = mx.array(self.mean).reshape(-1, 1, 1)
+             std = mx.array(self.std).reshape(-1, 1, 1)
+             img = (img - mean) / std
+
+         return img
+
+
+ class DeepseekOCR2Processor(ProcessorMixin):
+     tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+     attributes = ["tokenizer"]
+
+     def __init__(
+         self,
+         tokenizer: LlamaTokenizerFast,
+         candidate_resolutions: Tuple[Tuple[int, int]],
+         patch_size: int,
+         downsample_ratio: int,
+         image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+         image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+         normalize: bool = True,
+         image_token: str = "<image>",
+         pad_token: str = "<|▁pad▁|>",
+         add_special_token: bool = False,
+         sft_format: str = "deepseek",
+         mask_prompt: bool = True,
+         ignore_id: int = -100,
+         **kwargs,
+     ):
+         self.candidate_resolutions = candidate_resolutions
+         self.image_size = candidate_resolutions[0][0]
+         self.patch_size = patch_size
+         self.image_mean = image_mean
+         self.image_std = image_std
+         self.normalize = normalize
+         self.downsample_ratio = downsample_ratio
+
+         self.image_transform = ImageTransform(
+             mean=image_mean, std=image_std, normalize=normalize
+         )
+         self.tokenizer = tokenizer
+         self.tokenizer.padding_side = "left"
+
+         # Add special tokens
+         if tokenizer.pad_token is None:
+             self.tokenizer.add_special_tokens({"pad_token": pad_token})
+             print(
+                 f"Add pad token = ['{pad_token}'] to the tokenizer\n"
+                 f"{pad_token}:{tokenizer.encode(pad_token, add_special_tokens=False)[0]}"
+             )
+
+         image_token_id = self.tokenizer.vocab.get(image_token)
+         if image_token_id is None:
+             special_tokens = [image_token]
+             special_tokens_dict = {"additional_special_tokens": special_tokens}
+             self.tokenizer.add_special_tokens(special_tokens_dict)
+         self.image_token_id = self.tokenizer.vocab.get(image_token)
+         print(
+             f"Add image token = ['{image_token}'] to the tokenizer\n"
+             f"{image_token}:{tokenizer.encode(image_token, add_special_tokens=False)[0]}"
+         )
+
+         # Add grounding-related tokens
+         special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
+         special_tokens_dict = {"additional_special_tokens": special_tokens}
+         self.tokenizer.add_special_tokens(special_tokens_dict)
+         print("Added grounding-related tokens")
+
+         # Add chat tokens
+         special_tokens = ["<|User|>", "<|Assistant|>"]
+         special_tokens_dict = {"additional_special_tokens": special_tokens}
+         self.tokenizer.add_special_tokens(special_tokens_dict)
+         print("Added chat tokens")
+
+         self.image_token = image_token
+         self.pad_token = pad_token
+         self.add_special_token = add_special_token
+         self.sft_format = sft_format
+         self.mask_prompt = mask_prompt
+         self.ignore_id = ignore_id
+
+         super().__init__(tokenizer, **kwargs)
+
+         # Add chat template
+         self.chat_template = kwargs.pop("chat_template", self.default_chat_template)
+
+     @property
+     def default_chat_template(self):
+         return (
+             "{% for message in messages %}"
+             "{% if message['role'] == 'user' %}"
+             "{% elif message['role'] == 'assistant' %}{% endif %}"
+             "{{message['content']}} "
+             "{% endfor %}"
+             "{% if add_generation_prompt %}{% endif %}"
+         )
+
+     @property
+     def bos_id(self):
+         return self.tokenizer.bos_token_id
+
+     @property
+     def eos_id(self):
+         return self.tokenizer.eos_token_id
+
+     @property
+     def pad_id(self):
+         return self.tokenizer.pad_token_id
+
+     def encode(self, text: str, bos: bool = True, eos: bool = False):
+         t = self.tokenizer.encode(text, add_special_tokens=False)
+
+         if bos:
+             t = [self.bos_id] + t
+         if eos:
+             t = t + [self.eos_id]
+
+         return t
+
+     def decode(self, t: List[int], **kwargs) -> str:
+         return self.tokenizer.decode(t, **kwargs)
+
+     def process_one(
+         self,
+         prompt: str = None,
+         images: List[Image.Image] = None,
+         inference_mode: bool = True,
+         base_size: int = 1024,
+         image_size: int = 768,
+         cropping: bool = True,
+         min_patches: int = 1,
+         max_patches: int = 6,
+     ):
+
+         sft_format = prompt
+         (
+             tokenized_str,
+             images_list,
+             images_seq_mask,
+             images_spatial_crop,
+             num_image_tokens,
+         ) = self.tokenize_with_images(
+             sft_format,
+             images,
+             base_size=base_size,
+             image_size=image_size,
+             cropping=cropping,
+             min_patches=min_patches,
+             max_patches=max_patches,
+         )
+
+         masked_tokenized_str = []
+         for token_index in tokenized_str:
+             if token_index != self.image_token_id:
+                 masked_tokenized_str.append(token_index)
+             else:
+                 masked_tokenized_str.append(self.ignore_id)
+
+         input_ids = mx.array(tokenized_str)
+         target_ids = mx.array(masked_tokenized_str)
+         images_seq_mask = mx.array(images_seq_mask)
+
+         # Set ignored indices
+         target_ids = mx.where(
+             (input_ids < 0) | (input_ids == self.image_token_id),
+             self.ignore_id,
+             target_ids,
+         )
+         input_ids = mx.where(input_ids < 0, self.pad_id, input_ids)
+
+         if inference_mode:
+             input_ids = input_ids[:-1]
+             target_ids = target_ids[:-1]
+             images_seq_mask = images_seq_mask[:-1]
+
+         return {
+             "input_ids": input_ids[None, :],
+             "attention_mask": input_ids != self.pad_id,
+             "labels": target_ids,
+             "images": images_list,
+             "images_seq_mask": images_seq_mask[None, ...],
+             "images_spatial_crop": images_spatial_crop,
+             "num_image_tokens": num_image_tokens,
+         }
+
+     def pad_sequence(self, sequences, padding_value):
+         # Get max length of sequences
+         max_len = max(len(seq) for seq in sequences)
+
+         # Pad each sequence to max length
+         padded_seqs = []
+         for seq in sequences:
+             pad_length = max_len - len(seq)
+             if pad_length > 0:
+                 padding = mx.full((pad_length,), padding_value)
+                 padded_seq = mx.concatenate([seq, padding])
+             else:
+                 padded_seq = seq
+             padded_seqs.append(padded_seq)
+
+         return mx.stack(padded_seqs)
+
+     def tokenize_with_images(
+         self,
+         conversation: str,
+         images: List[Image.Image],
+         base_size: int = 1024,
+         image_size: int = 768,
+         cropping: bool = True,
+         min_patches: int = 1,
+         max_patches: int = 6,
+     ):
+         """Tokenize text with <image> tags.
+
+         For DeepSeek-OCR-2 with Qwen2 encoder:
+         - Global view (1024x1024): 256 tokens from Qwen2 encoder
+         - Local patches (768x768): 144 tokens each from Qwen2 encoder
+         - Plus 1 view_separator token
+
+         Dynamic resolution:
+         - Total tokens = (num_patches * 144) + 256 + 1
+         - Default: 0-6 patches at 768x768 + 1 global at 1024x1024
+         """
+         # Token counts for Qwen2 encoder
+         TOKENS_PER_PATCH = 144  # 12x12 SAM features for 768x768
+         TOKENS_PER_GLOBAL = 256  # 16x16 SAM features for 1024x1024
+         TOKENS_VIEW_SEP = 1
+
+         assert conversation.count(self.image_token) == len(
+             images
+         ), f"The number of image tokens in the prompt does not match the number of images: {conversation.count(self.image_token)} != {len(images)}"
+
+         text_splits = conversation.split(self.image_token)
+
+         all_patches_list = []
+         all_global_list = []
+         images_seq_mask = []
+         tokenized_str = []
+         images_spatial_crop = []
+         num_image_tokens_list = []
+
+         for text_sep, image in zip(text_splits, images):
+             # Tokenize the text before this image
+             tokenized_sep = self.encode(text_sep, bos=False, eos=False)
+             tokenized_str += tokenized_sep
+             images_seq_mask += [False] * len(tokenized_sep)
+
+             # Process global view: pad to base_size x base_size (1024x1024)
+             global_view = ImageOps.pad(
+                 image,
+                 (base_size, base_size),
+                 color=tuple(int(x * 255) for x in self.image_transform.mean),
+             )
+             global_tensor = self.image_transform(global_view).astype(mx.bfloat16)
+             all_global_list.append(global_tensor)
+
+             # Process local patches using dynamic resolution
+             if cropping and min_patches > 0:
+                 # Use dynamic_preprocess to split image into patches
+                 patches, (rows, cols) = dynamic_preprocess(
+                     image,
+                     min_num=min_patches,
+                     max_num=max_patches,
+                     image_size=image_size,  # 768x768 patches
+                     use_thumbnail=False,
+                 )
+                 num_patches = len(patches)
+
+                 # Transform each patch
+                 patch_tensors = []
+                 for patch in patches:
+                     patch_tensor = self.image_transform(patch).astype(mx.bfloat16)
+                     patch_tensors.append(patch_tensor)
+
+                 if patch_tensors:
+                     patches_stacked = mx.stack(patch_tensors, axis=0)
+                     all_patches_list.append(patches_stacked)
+
+                 images_spatial_crop.append([rows, cols])
+             else:
+                 # No patches, only global view
+                 num_patches = 0
+                 images_spatial_crop.append([0, 0])
+
+             # Calculate number of image tokens for this image
+             # Order: [local_patches, global_view, view_separator]
+             num_image_tokens = (
+                 (num_patches * TOKENS_PER_PATCH) + TOKENS_PER_GLOBAL + TOKENS_VIEW_SEP
+             )
+             num_image_tokens_list.append(num_image_tokens)
+
+             # Add image tokens to sequence
+             tokenized_image = [self.image_token_id] * num_image_tokens
+             tokenized_str += tokenized_image
+             images_seq_mask += [True] * len(tokenized_image)
+
+         # Tokenize the text after the last image
+         tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
+         tokenized_str += tokenized_sep
+         images_seq_mask += [False] * len(tokenized_sep)
+
+         # Add the bos token
+         bos_id = 0
+         tokenized_str = [bos_id] + tokenized_str
+         images_seq_mask = [False] + images_seq_mask
+
+         images_seq_mask = mx.array(images_seq_mask)
+
+         # Stack global images
+         if len(all_global_list) == 0:
+             images_ori = mx.zeros((1, 3, base_size, base_size))
+             images_spatial_crop = mx.zeros((1, 2))
+         else:
+             images_ori = mx.stack(all_global_list, axis=0)
+             images_spatial_crop = mx.array(images_spatial_crop)
+
+         # Stack patches (or zeros if no patches)
+         if all_patches_list:
+             # Concatenate all patches from all images
+             images_patches = mx.concatenate(all_patches_list, axis=0)
+         else:
+             images_patches = mx.zeros((1, 3, image_size, image_size))
+
+         assert len(tokenized_str) == len(
+             images_seq_mask
+         ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to images_seq_mask's length {len(images_seq_mask)}"
+
+         return (
+             tokenized_str,
+             [images_patches, images_ori],
+             images_seq_mask,
+             images_spatial_crop,
+             num_image_tokens_list[0] if num_image_tokens_list else 257,
+         )
+
+     def __call__(
+         self,
+         *,
+         text: str = None,
+         images: List[Image.Image] = None,
+         inference_mode: bool = True,
+         image_size: int = 768,
+         base_size: int = 1024,
+         cropping: bool = True,
+         min_patches: int = 1,
+         max_patches: int = 6,
+         padding: bool = True,
+         return_tensors: Literal["np", "mx", "pt"] = "mx",
+         **kwargs,
+     ):
+         """Process text and images for DeepSeek-OCR-2.
+
+         Args:
+             text (str or List[str]): the formatted prompt(s)
+             images (List[ImageType]): the list of images (one per prompt for batched inputs)
+             inference_mode (bool): if True, remove the last eos token
+             image_size (int): size of local patches (default 768)
+             base_size (int): size of global view (default 1024)
+             cropping (bool): whether to use dynamic resolution with local patches
+             min_patches (int): minimum number of patches (default 1)
+             max_patches (int): maximum number of patches (default 6)
+
+         Returns:
+             outputs (dict): the output of the processor,
+                 - input_ids (mx.array): [batch_size, N + image tokens]
+                 - images (List[mx.array]): [patches, global_images]
+                 - images_seq_mask (mx.array): mask for image token positions
+                 - images_spatial_crop (mx.array): patch grid dimensions
+         """
+
+         # Handle batched inputs (list of prompts with list of images)
+         if isinstance(text, list):
+             if images is None:
+                 images = [None] * len(text)
+
+             batch_results = []
+             for i, prompt in enumerate(text):
+                 # Each prompt has one image
+                 img = [images[i]] if images[i] is not None else None
+                 result = self.process_one(
+                     prompt=prompt,
+                     images=img,
+                     inference_mode=inference_mode,
+                     image_size=image_size,
+                     base_size=base_size,
+                     cropping=cropping,
+                     min_patches=min_patches,
+                     max_patches=max_patches,
+                 )
+                 batch_results.append(result)
+
+             # Collate batch results
+             return self._collate_batch(batch_results, padding=padding)
+
+         # Single input case
+         prepare = self.process_one(
+             prompt=text,
+             images=images,
+             inference_mode=inference_mode,
+             image_size=image_size,
+             base_size=base_size,
+             cropping=cropping,
+             min_patches=min_patches,
+             max_patches=max_patches,
+         )
+
+         return prepare
+
+     def _collate_batch(self, batch_results: List[Dict], padding: bool = True) -> Dict:
+         """Collate multiple processed results into a batch."""
+         if not batch_results:
+             return {}
+
+         batch_size = len(batch_results)
+
+         # Get max sequence length for padding
+         max_seq_len = max(r["input_ids"].shape[1] for r in batch_results)
+
+         # Pad and stack input_ids
+         padded_input_ids = []
+         padded_images_seq_mask = []
+         for r in batch_results:
+             seq_len = r["input_ids"].shape[1]
+             pad_len = max_seq_len - seq_len
+
+             if pad_len > 0:
+                 # Pad input_ids on the left
+                 input_ids = mx.concatenate(
+                     [
+                         mx.full((1, pad_len), self.pad_id, dtype=r["input_ids"].dtype),
+                         r["input_ids"],
+                     ],
+                     axis=1,
+                 )
+                 # Pad images_seq_mask on the left with False
+                 seq_mask = mx.concatenate(
+                     [mx.zeros((1, pad_len), dtype=mx.bool_), r["images_seq_mask"]],
+                     axis=1,
+                 )
+             else:
+                 input_ids = r["input_ids"]
+                 seq_mask = r["images_seq_mask"]
+
+             padded_input_ids.append(input_ids)
+             padded_images_seq_mask.append(seq_mask)
+
+         # Stack into batch
+         input_ids = mx.concatenate(padded_input_ids, axis=0)
+         images_seq_mask = mx.concatenate(padded_images_seq_mask, axis=0)
+
+         # Combine images: [patches, global_images]
+         all_patches = []
+         all_global_images = []
+         all_spatial_crops = []
+
+         for r in batch_results:
+             patches, global_img = r["images"]
+             # Only add non-zero patches
+             if mx.sum(patches).item() != 0:
+                 all_patches.append(patches)
+             all_global_images.append(global_img)
+             all_spatial_crops.append(r["images_spatial_crop"])
+
+         # Stack patches and global images
+         if all_patches:
+             combined_patches = mx.concatenate(all_patches, axis=0)
+         else:
+             combined_patches = mx.zeros((1, 3, 1024, 1024))
+
+         combined_global_images = mx.concatenate(all_global_images, axis=0)
+         combined_spatial_crops = mx.concatenate(all_spatial_crops, axis=0)
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": input_ids != self.pad_id,
+             "images": [combined_patches, combined_global_images],
+             "images_seq_mask": images_seq_mask,
+             "images_spatial_crop": combined_spatial_crops,
+         }
+
+
+ # Install a composable AutoProcessor patch for DeepSeek-OCR-2
+ from ..base import install_auto_processor_patch
+
+ install_auto_processor_patch("deepseekocr_2", DeepseekOCR2Processor)
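
As a rough, self-contained illustration of how the dynamic_preprocess helper in the file above picks a tiling, the snippet below feeds it a synthetic 2:1 landscape image. The image and the min/max patch settings are arbitrary example values, not values shipped with the package.

from PIL import Image

from mlx_vlm.models.deepseekocr_2.processing_deepseekocr import dynamic_preprocess

# Synthetic 2:1 page; dynamic_preprocess should settle on a 2x1 grid of
# 768x768 tiles, since (2, 1) is the candidate ratio closest to 2.0.
page = Image.new("RGB", (2400, 1200), color="white")
tiles, grid = dynamic_preprocess(
    page, min_num=1, max_num=6, image_size=768, use_thumbnail=False
)
print(len(tiles), grid)  # 2 (2, 1)

DeepseekOCR2Processor.tokenize_with_images unpacks the same tuple as patches, (rows, cols) and records the grid in images_spatial_crop.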
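
The per-image token accounting spelled out in the tokenize_with_images docstring (num_patches * 144 + 256 + 1) can likewise be checked with plain arithmetic. The constants below are copied from the file; the helper name expected_image_tokens is only for this sketch and is not part of the package.

# Mirrors the token accounting in DeepseekOCR2Processor.tokenize_with_images;
# the function itself is illustrative and not exported by the package.
TOKENS_PER_PATCH = 144   # each 768x768 local patch
TOKENS_PER_GLOBAL = 256  # the single 1024x1024 global view
TOKENS_VIEW_SEP = 1      # view-separator token


def expected_image_tokens(num_patches: int) -> int:
    """Number of <image> placeholder tokens emitted for one image."""
    return num_patches * TOKENS_PER_PATCH + TOKENS_PER_GLOBAL + TOKENS_VIEW_SEP


print(expected_image_tokens(0))  # 257, the fallback value in tokenize_with_images
print(expected_image_tokens(6))  # 1121 for the default maximum of six patches

The 257-token case corresponds to cropping being disabled, where only the global view and the view separator are emitted.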