fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/deepseekocr/processing_deepseekocr.py
@@ -0,0 +1,655 @@
+ """
+ From https://github.com/deepseek-ai/DeepSeek-VL2
+ """
+
+ import math
+ from dataclasses import dataclass
+ from typing import Dict, List, Literal, Optional, Tuple
+
+ import mlx.core as mx
+ import numpy as np
+ from PIL import Image, ImageOps
+ from transformers import LlamaTokenizerFast
+ from transformers.processing_utils import ProcessorMixin
+
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float("inf")
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
+     return best_ratio
+
+
+ def dynamic_preprocess(
+     image, min_num=2, max_num=9, image_size=640, use_thumbnail=False
+ ):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # calculate the existing image aspect ratio
+     target_ratios = set(
+         (i, j)
+         for n in range(min_num, max_num + 1)
+         for i in range(1, n + 1)
+         for j in range(1, n + 1)
+         if i * j <= max_num and i * j >= min_num
+     )
+     # print(target_ratios)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size
+     )
+
+     # print(target_aspect_ratio)
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size,
+         )
+         # split the image
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images, target_aspect_ratio
+
+
+ class DictOutput(object):
+     def keys(self):
+         return self.__dict__.keys()
+
+     def __getitem__(self, item):
+         if isinstance(item, int):
+             return list(self.__dict__.values())[item]
+         if item not in self.__dict__:
+             raise KeyError(item)
+         return self.__dict__[item]
+
+     def __setitem__(self, key, value):
+         self.__dict__[key] = value
+
+
+ @dataclass
+ class VLChatProcessorOutput(DictOutput):
+     sft_format: str
+     input_ids: mx.array
+     target_ids: mx.array
+     images: mx.array
+     images_seq_mask: mx.array
+     images_spatial_crop: mx.array
+     num_image_tokens: List[int]
+
+     def __len__(self):
+         return len(self.input_ids)
+
+
+ @dataclass
+ class BatchCollateOutput(DictOutput):
+     sft_format: List[str]
+     input_ids: mx.array
+     labels: mx.array
+     images: mx.array
+     attention_mask: mx.array
+     images_seq_mask: mx.array
+     images_spatial_crop: mx.array
+     seq_lens: List[int]
+
+
+ class ImageTransform:
+     def __init__(
+         self,
+         mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+         std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+         normalize: bool = True,
+     ):
+         self.mean = mean
+         self.std = std
+         self.normalize = normalize
+
+     def __call__(self, pil_img: Image.Image):
+         # Convert PIL image to numpy array and normalize
+
+         img = mx.array(np.array(pil_img)) / 255.0
+
+         # Transpose from HWC to CHW format
+         img = mx.transpose(img, [2, 0, 1])
+
+         if self.normalize:
+             mean = mx.array(self.mean).reshape(-1, 1, 1)
+             std = mx.array(self.std).reshape(-1, 1, 1)
+             img = (img - mean) / std
+
+         return img
+
+
+ class DeepseekOCRProcessor(ProcessorMixin):
+     tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+     attributes = ["tokenizer"]
+
+     def __init__(
+         self,
+         tokenizer: LlamaTokenizerFast,
+         candidate_resolutions: Tuple[Tuple[int, int]],
+         patch_size: int,
+         downsample_ratio: int,
+         image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+         image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+         normalize: bool = True,
+         image_token: str = "<image>",
+         pad_token: str = "<|▁pad▁|>",
+         add_special_token: bool = False,
+         sft_format: str = "deepseek",
+         mask_prompt: bool = True,
+         ignore_id: int = -100,
+         **kwargs,
+     ):
+         self.candidate_resolutions = candidate_resolutions
+         self.image_size = candidate_resolutions[0][0]
+         self.patch_size = patch_size
+         self.image_mean = image_mean
+         self.image_std = image_std
+         self.normalize = normalize
+         self.downsample_ratio = downsample_ratio
+
+         self.image_transform = ImageTransform(
+             mean=image_mean, std=image_std, normalize=normalize
+         )
+         self.tokenizer = tokenizer
+         self.tokenizer.padding_side = "left"
+
+         # Add special tokens
+         if tokenizer.pad_token is None:
+             self.tokenizer.add_special_tokens({"pad_token": pad_token})
+             print(
+                 f"Add pad token = ['{pad_token}'] to the tokenizer\n"
+                 f"{pad_token}:{tokenizer.encode(pad_token, add_special_tokens=False)[0]}"
+             )
+
+         image_token_id = self.tokenizer.vocab.get(image_token)
+         if image_token_id is None:
+             special_tokens = [image_token]
+             special_tokens_dict = {"additional_special_tokens": special_tokens}
+             self.tokenizer.add_special_tokens(special_tokens_dict)
+         self.image_token_id = self.tokenizer.vocab.get(image_token)
+         print(
+             f"Add image token = ['{image_token}'] to the tokenizer\n"
+             f"{image_token}:{tokenizer.encode(image_token, add_special_tokens=False)[0]}"
+         )
+
+         # Add grounding-related tokens
+         special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
+         special_tokens_dict = {"additional_special_tokens": special_tokens}
+         self.tokenizer.add_special_tokens(special_tokens_dict)
+         print("Added grounding-related tokens")
+
+         # Add chat tokens
+         special_tokens = ["<|User|>", "<|Assistant|>"]
+         special_tokens_dict = {"additional_special_tokens": special_tokens}
+         self.tokenizer.add_special_tokens(special_tokens_dict)
+         print("Added chat tokens")
+
+         self.image_token = image_token
+         self.pad_token = pad_token
+         self.add_special_token = add_special_token
+         self.sft_format = sft_format
+         self.mask_prompt = mask_prompt
+         self.ignore_id = ignore_id
+
+         super().__init__(tokenizer, **kwargs)
+
+         # Add chat template
+         self.chat_template = kwargs.pop("chat_template", self.default_chat_template)
+
+     @property
+     def default_chat_template(self):
+         return (
+             "{% for message in messages %}"
+             "{% if message['role'] == 'user' %}"
+             "{% elif message['role'] == 'assistant' %}{% endif %}"
+             "{{message['content']}} "
+             "{% endfor %}"
+             "{% if add_generation_prompt %}{% endif %}"
+         )
+
+     @property
+     def bos_id(self):
+         return self.tokenizer.bos_token_id
+
+     @property
+     def eos_id(self):
+         return self.tokenizer.eos_token_id
+
+     @property
+     def pad_id(self):
+         return self.tokenizer.pad_token_id
+
+     def encode(self, text: str, bos: bool = True, eos: bool = False):
+         t = self.tokenizer.encode(text, add_special_tokens=False)
+
+         if bos:
+             t = [self.bos_id] + t
+         if eos:
+             t = t + [self.eos_id]
+
+         return t
+
+     def decode(self, t: List[int], **kwargs) -> str:
+         return self.tokenizer.decode(t, **kwargs)
+
+     def process_one(
+         self,
+         prompt: str = None,
+         images: List[Image.Image] = None,
+         inference_mode: bool = True,
+         base_size: int = 1024,
+         image_size: int = 640,
+         cropping: bool = True,
+     ):
+
+         sft_format = prompt
+         (
+             tokenized_str,
+             images_list,
+             images_seq_mask,
+             images_spatial_crop,
+             num_image_tokens,
+         ) = self.tokenize_with_images(
+             sft_format,
+             images,
+             base_size=base_size,
+             image_size=image_size,
+             cropping=cropping,
+         )
+
+         masked_tokenized_str = []
+         for token_index in tokenized_str:
+             if token_index != self.image_token_id:
+                 masked_tokenized_str.append(token_index)
+             else:
+                 masked_tokenized_str.append(self.ignore_id)
+
+         input_ids = mx.array(tokenized_str)
+         target_ids = mx.array(masked_tokenized_str)
+         images_seq_mask = mx.array(images_seq_mask)
+
+         # Set ignored indices
+         target_ids = mx.where(
+             (input_ids < 0) | (input_ids == self.image_token_id),
+             self.ignore_id,
+             target_ids,
+         )
+         input_ids = mx.where(input_ids < 0, self.pad_id, input_ids)
+
+         if inference_mode:
+             input_ids = input_ids[:-1]
+             target_ids = target_ids[:-1]
+             images_seq_mask = images_seq_mask[:-1]
+
+         return {
+             "input_ids": input_ids[None, :],
+             "attention_mask": input_ids != self.pad_id,
+             "labels": target_ids,
+             "images": images_list,
+             "images_seq_mask": images_seq_mask[None, ...],
+             "images_spatial_crop": images_spatial_crop,
+             "num_image_tokens": num_image_tokens,
+         }
+
+     def pad_sequence(self, sequences, padding_value):
+         # Get max length of sequences
+         max_len = max(len(seq) for seq in sequences)
+
+         # Pad each sequence to max length
+         padded_seqs = []
+         for seq in sequences:
+             pad_length = max_len - len(seq)
+             if pad_length > 0:
+                 padding = mx.full((pad_length,), padding_value)
+                 padded_seq = mx.concatenate([seq, padding])
+             else:
+                 padded_seq = seq
+             padded_seqs.append(padded_seq)
+
+         return mx.stack(padded_seqs)
+
+     def tokenize_with_images(
+         self,
+         conversation: str,
+         images: List[Image.Image],
+         base_size: int = 1024,
+         image_size: int = 640,
+         cropping: bool = True,
+     ):
+
+         patch_size = 16
+         downsample_ratio = 4
+         valid_img_tokens = 0
+         ratio = 1
+
+         image_draw = images[0].copy()
+
+         w, h = image_draw.size
+
+         ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
+
+         """Tokenize text with <image> tags."""
+         assert conversation.count(self.image_token) == len(
+             images
+         ), f"The number of image tokens in the prompt does not match the number of images: {conversation.count(self.image_token)} != {len(images)}"
+         text_splits = conversation.split(self.image_token)
+
+         images_list, images_crop_list, images_seq_mask = [], [], []
+         tokenized_str = []
+         images_spatial_crop = []
+         for text_sep, image in zip(text_splits, images):
+
+             tokenized_sep = self.encode(text_sep, bos=False, eos=False)
+             tokenized_str += tokenized_sep
+             images_seq_mask += [False] * len(tokenized_sep)
+
+             if cropping:
+
+                 if image.size[0] <= 640 and image.size[1] <= 640:
+                     crop_ratio = [1, 1]
+
+                 else:
+                     if cropping:
+                         # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
+                         images_crop_raw, crop_ratio = dynamic_preprocess(image)
+
+                     else:
+                         # best_width, best_height = self.image_size, self.image_size
+                         crop_ratio = [1, 1]
+
+                 """process the global view"""
+                 # image = image.resize((base_size, base_size))
+                 global_view = ImageOps.pad(
+                     image,
+                     (base_size, base_size),
+                     color=tuple(int(x * 255) for x in self.image_transform.mean),
+                 )
+
+                 if base_size == 1024:
+                     valid_img_tokens += int(256 * ratio)
+                 elif base_size == 1280:
+                     valid_img_tokens += int(400 * ratio)
+                 elif base_size == 640:
+                     valid_img_tokens += int(100 * ratio)
+
+                 images_list.append(
+                     self.image_transform(global_view).astype(mx.bfloat16)
+                 )
+
+                 width_crop_num, height_crop_num = crop_ratio
+
+                 images_spatial_crop.append([width_crop_num, height_crop_num])
+
+                 if width_crop_num > 1 or height_crop_num > 1:
+                     """process the local views"""
+
+                     for i in range(len(images_crop_raw)):
+                         images_crop_list.append(
+                             self.image_transform(images_crop_raw[i]).astype(mx.bfloat16)
+                         )
+
+                 if image_size == 640:
+                     valid_img_tokens += len(images_crop_list) * 100
+
+                 num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
+                 num_queries_base = math.ceil(
+                     (base_size // patch_size) / downsample_ratio
+                 )
+
+                 """add image tokens"""
+
+                 tokenized_image = (
+                     [self.image_token_id] * num_queries_base + [self.image_token_id]
+                 ) * num_queries_base
+                 tokenized_image += [self.image_token_id]
+                 if width_crop_num > 1 or height_crop_num > 1:
+                     tokenized_image += (
+                         [self.image_token_id] * (num_queries * width_crop_num)
+                         + [self.image_token_id]
+                     ) * (num_queries * height_crop_num)
+                 tokenized_str += tokenized_image
+                 images_seq_mask += [True] * len(tokenized_image)
+
+             else:
+
+                 """process the global view"""
+                 if image_size <= 640:
+                     print("directly resize")
+                     image = image.resize((image_size, image_size))
+
+                 global_view = ImageOps.pad(
+                     image,
+                     (image_size, image_size),
+                     color=tuple(int(x * 255) for x in self.image_transform.mean),
+                 )
+                 images_list.append(
+                     self.image_transform(global_view).astype(mx.bfloat16)
+                 )
+
+                 if base_size == 1024:
+                     valid_img_tokens += int(256 * ratio)
+                 elif base_size == 1280:
+                     valid_img_tokens += int(400 * ratio)
+                 elif base_size == 640:
+                     valid_img_tokens += int(100 * 1)
+                 elif base_size == 512:
+                     valid_img_tokens += int(64 * 1)
+
+                 width_crop_num, height_crop_num = 1, 1
+
+                 images_spatial_crop.append([width_crop_num, height_crop_num])
+
+                 """add image tokens"""
+                 num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
+
+                 tokenized_image = (
+                     [self.image_token_id] * num_queries + [self.image_token_id]
+                 ) * num_queries
+                 tokenized_image += [self.image_token_id]
+
+                 tokenized_str += tokenized_image
+                 images_seq_mask += [True] * len(tokenized_image)
+
+         tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
+         tokenized_str += tokenized_sep
+         images_seq_mask += [False] * len(tokenized_sep)
+
+         """add the bos tokens"""
+         bos_id = 0
+         tokenized_str = [bos_id] + tokenized_str
+         images_seq_mask = [False] + images_seq_mask
+
+         images_seq_mask = mx.array(images_seq_mask)
+
+         if len(images_list) == 0:
+             images_ori = mx.zeros((1, 3, image_size, image_size))
+             images_spatial_crop = mx.zeros((1, 2))
+             images_crop = mx.zeros((1, 3, base_size, base_size))
+
+         else:
+             images_ori = mx.stack(images_list, axis=0)
+             images_spatial_crop = mx.array(images_spatial_crop)
+             if images_crop_list:
+                 images_crop = mx.stack(images_crop_list, axis=0)
+             else:
+                 images_crop = mx.zeros((1, 3, base_size, base_size))
+
+         assert len(tokenized_str) == len(
+             images_seq_mask
+         ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
+
+         return (
+             tokenized_str,
+             [images_crop, images_ori],
+             images_seq_mask,
+             images_spatial_crop,
+             valid_img_tokens,
+         )
+
+     def __call__(
+         self,
+         *,
+         text: str = None,
+         images: List[Image.Image] = None,
+         inference_mode: bool = True,
+         image_size: int = 640,
+         base_size: int = 1024,
+         cropping: bool = True,
+         padding: bool = True,
+         return_tensors: Literal["np", "mx", "pt"] = "mx",
+         **kwargs,
+     ):
+         """
+
+         Args:
+             text (str or List[str]): the formatted prompt(s);
+             images (List[ImageType]): the list of images (one per prompt for batched inputs);
+             inference_mode (bool): if True, then remove the last eos token;
+
+         Returns:
+             outputs (BaseProcessorOutput): the output of the processor,
+                 - input_ids (mx.array): [batch_size, N + image tokens]
+                 - images (mx.array): [n_images, 3, H, W]
+                 - image_id (int): the id of the image token
+                 - num_image_tokens (List[int]): the number of image tokens
+         """
+
+         # Handle batched inputs (list of prompts with list of images)
+         if isinstance(text, list):
+             if images is None:
+                 images = [None] * len(text)
+
+             batch_results = []
+             for i, prompt in enumerate(text):
+                 # Each prompt has one image
+                 img = [images[i]] if images[i] is not None else None
+                 result = self.process_one(
+                     prompt=prompt,
+                     images=img,
+                     inference_mode=inference_mode,
+                     image_size=image_size,
+                     base_size=base_size,
+                     cropping=cropping,
+                 )
+                 batch_results.append(result)
+
+             # Collate batch results
+             return self._collate_batch(batch_results, padding=padding)
+
+         # Single input case
+         prepare = self.process_one(
+             prompt=text,
+             images=images,
+             inference_mode=inference_mode,
+             image_size=image_size,
+             base_size=base_size,
+             cropping=cropping,
+         )
+
+         return prepare
+
+     def _collate_batch(self, batch_results: List[Dict], padding: bool = True) -> Dict:
+         """Collate multiple processed results into a batch."""
+         if not batch_results:
+             return {}
+
+         batch_size = len(batch_results)
+
+         # Get max sequence length for padding
+         max_seq_len = max(r["input_ids"].shape[1] for r in batch_results)
+
+         # Pad and stack input_ids
+         padded_input_ids = []
+         padded_images_seq_mask = []
+         for r in batch_results:
+             seq_len = r["input_ids"].shape[1]
+             pad_len = max_seq_len - seq_len
+
+             if pad_len > 0:
+                 # Pad input_ids on the left
+                 input_ids = mx.concatenate(
+                     [
+                         mx.full((1, pad_len), self.pad_id, dtype=r["input_ids"].dtype),
+                         r["input_ids"],
+                     ],
+                     axis=1,
+                 )
+                 # Pad images_seq_mask on the left with False
+                 seq_mask = mx.concatenate(
+                     [mx.zeros((1, pad_len), dtype=mx.bool_), r["images_seq_mask"]],
+                     axis=1,
+                 )
+             else:
+                 input_ids = r["input_ids"]
+                 seq_mask = r["images_seq_mask"]
+
+             padded_input_ids.append(input_ids)
+             padded_images_seq_mask.append(seq_mask)
+
+         # Stack into batch
+         input_ids = mx.concatenate(padded_input_ids, axis=0)
+         images_seq_mask = mx.concatenate(padded_images_seq_mask, axis=0)
+
+         # Combine images: [patches, global_images]
+         all_patches = []
+         all_global_images = []
+         all_spatial_crops = []
+
+         for r in batch_results:
+             patches, global_img = r["images"]
+             # Only add non-zero patches
+             if mx.sum(patches).item() != 0:
+                 all_patches.append(patches)
+             all_global_images.append(global_img)
+             all_spatial_crops.append(r["images_spatial_crop"])
+
+         # Stack patches and global images
+         if all_patches:
+             combined_patches = mx.concatenate(all_patches, axis=0)
+         else:
+             combined_patches = mx.zeros((1, 3, 1024, 1024))
+
+         combined_global_images = mx.concatenate(all_global_images, axis=0)
+         combined_spatial_crops = mx.concatenate(all_spatial_crops, axis=0)
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": input_ids != self.pad_id,
+             "images": [combined_patches, combined_global_images],
+             "images_seq_mask": images_seq_mask,
+             "images_spatial_crop": combined_spatial_crops,
+         }
+
+
+ # Install a composable AutoProcessor patch for DeepSeek-OCR (v1)
+ from ..base import install_auto_processor_patch
+
+ install_auto_processor_patch("deepseekocr", DeepseekOCRProcessor)
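
For orientation, here is a minimal usage sketch of the processor defined in this file. It is not shipped in the wheel: the checkpoint path, the candidate_resolutions value, and the prompt text are illustrative assumptions; only the constructor signature and the keyword-only __call__ arguments come from the code above.

from PIL import Image
from transformers import LlamaTokenizerFast

from mlx_vlm.models.deepseekocr.processing_deepseekocr import DeepseekOCRProcessor

# Hypothetical local checkpoint path; any LLaMA-style tokenizer directory works for the sketch.
tokenizer = LlamaTokenizerFast.from_pretrained("path/to/deepseek-ocr-checkpoint")

processor = DeepseekOCRProcessor(
    tokenizer=tokenizer,
    candidate_resolutions=((1024, 1024),),  # assumed; the first entry sets self.image_size
    patch_size=16,
    downsample_ratio=4,
)

# One <image> tag per image; the chat markers match the special tokens added in __init__.
prompt = "<|User|><image>\nFree OCR.<|Assistant|>"
image = Image.open("page.png").convert("RGB")

batch = processor(text=prompt, images=[image], base_size=1024, image_size=640)
print(batch["input_ids"].shape)        # (1, sequence length including image tokens)
print(batch["images_spatial_crop"])    # [[width_crop_num, height_crop_num]]

Passing a list of prompts (with one image per prompt) instead routes through _collate_batch, which left-pads input_ids and images_seq_mask and concatenates the per-sample crop and global-view tensors.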