fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/kimi_vl/processing_kimi_vl.py
@@ -0,0 +1,560 @@
+"""
+MLX-based KimiVL Processor.
+
+This module provides an MLX-native processor for KimiVL models that:
+1. Uses a pre-converted fast tokenizer (no tiktoken dependency)
+2. Provides an MLX-based image processor (no torch/torchvision dependency)
+3. Patches missing functions for transformers 5.0 compatibility
+"""
+
+import json
+import math
+import warnings
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import mlx.core as mx
+import transformers.processing_utils as processing_utils
+from PIL import Image
+from transformers import AutoTokenizer
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_processing_utils import BaseImageProcessor
+from transformers.image_utils import ImageInput, make_list_of_images, valid_images
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.utils import TensorType
+
+from .config import ModelConfig
+
+
+def _validate_images_text_input_order(images, text):
+    """
+    Validate and potentially swap the order of images and text arguments.
+
+    This function checks if the arguments are in the correct order (images first, text second)
+    for backward compatibility. If text is passed as the first argument and images as the second,
+    it swaps them and issues a deprecation warning.
+
+    Args:
+        images: The images argument (should be image-like objects or None)
+        text: The text argument (should be strings or None)
+
+    Returns:
+        Tuple of (images, text) in the correct order
+    """
+    # Check if arguments are swapped (text passed as images, images passed as text)
+    if images is not None and text is not None:
+        # If 'images' looks like text and 'text' looks like images, swap them
+        images_is_text = isinstance(images, str) or (
+            isinstance(images, (list, tuple))
+            and len(images) > 0
+            and isinstance(images[0], str)
+        )
+        text_is_image = not isinstance(text, str) and not (
+            isinstance(text, (list, tuple))
+            and len(text) > 0
+            and isinstance(text[0], str)
+        )
+
+        if images_is_text and text_is_image:
+            warnings.warn(
+                "You passed text as the first argument and images as the second. "
+                "This is deprecated and will be removed in a future version. "
+                "Please pass images first and text second.",
+                FutureWarning,
+            )
+            return text, images
+
+    return images, text
+
+
+# Add the function to transformers.processing_utils if it doesn't exist
+if not hasattr(processing_utils, "_validate_images_text_input_order"):
+    processing_utils._validate_images_text_input_order = (
+        _validate_images_text_input_order
+    )
+
+# Also add Unpack if it doesn't exist (for older Python versions)
+if not hasattr(processing_utils, "Unpack"):
+    try:
+        from typing import Unpack
+
+        processing_utils.Unpack = Unpack
+    except ImportError:
+        from typing_extensions import Unpack
+
+        processing_utils.Unpack = Unpack
+
+
+# CLIP-style normalization constants
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+class KimiVLImageProcessor(BaseImageProcessor):
+
+    model_input_names = ["pixel_values", "image_grid_hws"]
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        pad_input: bool = False,
+        image_mean: Tuple[float, float, float] = OPENAI_DATASET_MEAN,
+        image_std: Tuple[float, float, float] = OPENAI_DATASET_STD,
+        in_token_limit: int = 4096,
+        merge_kernel_size: List[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.in_token_limit = in_token_limit
+        self.patch_size = patch_size
+        self.pad_input = pad_input
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.merge_kernel_size = (
+            merge_kernel_size if merge_kernel_size is not None else [2, 2]
+        )
+
+    def rescale(
+        self, image: Image.Image, merge_kernel_size: List[int] = None
+    ) -> Image.Image:
+        """Rescale image to fit within token limits and pad/crop to patch boundaries."""
+        if merge_kernel_size is None:
+            merge_kernel_size = self.merge_kernel_size
+
+        w, h = image.size
+        patch_size = self.patch_size
+
+        # Rescale if exceeds token limit
+        if (w // patch_size) * (h // patch_size) > self.in_token_limit:
+            scale = math.sqrt(
+                self.in_token_limit / ((w // patch_size) * (h // patch_size))
+            )
+            new_w, new_h = int(w * scale), int(h * scale)
+            image = image.resize((new_w, new_h), Image.Resampling.BICUBIC)
+
+        if self.pad_input:
+            new_w, new_h = image.size
+            pad_size_h = merge_kernel_size[0] * patch_size
+            pad_size_w = merge_kernel_size[1] * patch_size
+
+            pad_h = (pad_size_h - new_h % pad_size_h) % pad_size_h
+            pad_w = (pad_size_w - new_w % pad_size_w) % pad_size_w
+
+            if pad_h > 0 or pad_w > 0:
+                # Pad image (bottom and right padding)
+                new_image = Image.new(
+                    image.mode, (new_w + pad_w, new_h + pad_h), (0, 0, 0)
+                )
+                new_image.paste(image, (0, 0))
+                image = new_image
+        else:
+            new_w, new_h = image.size
+            # Ensure dimensions are divisible by merge_kernel_size * patch_size
+            # so that the grid dimensions are divisible by merge_kernel_size
+            crop_size_w = merge_kernel_size[1] * patch_size
+            crop_size_h = merge_kernel_size[0] * patch_size
+            new_w = new_w - new_w % crop_size_w
+            new_h = new_h - new_h % crop_size_h
+            # Center crop
+            left = (image.size[0] - new_w) // 2
+            top = (image.size[1] - new_h) // 2
+            image = image.crop((left, top, left + new_w, top + new_h))
+
+        w, h = image.size
+        if w // patch_size >= 512 or h // patch_size >= 512:
+            raise ValueError("Exceed pos emb")
+
+        return image
+
+    def to_mlx(self, image: Image.Image) -> mx.array:
+        """Convert PIL image to MLX array in CHW format, normalized to [0, 1]."""
+        image = image.convert("RGB")
+        w, h = image.size
+        # Convert PIL image to MLX array directly via bytes
+        arr = mx.array(list(image.getdata()), dtype=mx.float32).reshape(h, w, 3) / 255.0
+        # Convert from HWC to CHW format
+        arr = arr.transpose(2, 0, 1)
+        return arr
+
+    def normalize(self, image: mx.array) -> mx.array:
+        """Normalize image with CLIP-style mean and std."""
+        mean = mx.array(self.image_mean, dtype=mx.float32).reshape(3, 1, 1)
+        std = mx.array(self.image_std, dtype=mx.float32).reshape(3, 1, 1)
+        return (image - mean) / std
+
+    def patchify(self, image: mx.array) -> Tuple[mx.array, Tuple[int, int]]:
+        """Convert image to patches."""
+        patch_size = self.patch_size
+        C, H, W = image.shape
+
+        # Reshape to (C, H//p, p, W//p, p) then to (num_patches, C, p, p)
+        patches = image.reshape(
+            C, H // patch_size, patch_size, W // patch_size, patch_size
+        )
+        # Permute to (H//p, W//p, C, p, p)
+        patches = patches.transpose(1, 3, 0, 2, 4)
+        # Flatten to (num_patches, C, p, p)
+        patches = patches.reshape(-1, C, patch_size, patch_size)
+
+        grid_hw = (H // patch_size, W // patch_size)
+        return patches, grid_hw
+
+    def _preprocess(self, image: ImageInput) -> Tuple[mx.array, Tuple[int, int]]:
+        """
+        Preprocess image and patchify it.
+
+        Args:
+            image: Image to preprocess.
+
+        Returns:
+            patches: mx.array
+            grid_hw: Tuple[int, int]
+        """
+        image = self.rescale(image, self.merge_kernel_size)
+        image = self.to_mlx(image)
+        image = self.normalize(image)
+        patches, grid_hw = self.patchify(image)
+        return patches, grid_hw
+
+    def preprocess(
+        self,
+        images: ImageInput,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        """Process images and return BatchFeature."""
+        images = make_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image or mx.array."
+            )
+
+        pixel_values_list = []
+        image_grid_hws = []
+
+        for image in images:
+            # Convert MLX arrays to PIL Images if needed
+            if isinstance(image, mx.array):
+                # Ensure we're working with the array values
+                arr = image
+                if arr.ndim == 3 and arr.shape[0] in [1, 3, 4]:
+                    # CHW format, convert to HWC
+                    arr = arr.transpose(1, 2, 0)
+                # Convert to uint8 for PIL
+                if arr.dtype in [mx.float32, mx.float16, mx.bfloat16]:
+                    arr = (arr * 255).astype(mx.uint8)
+                # Convert to PIL via list (MLX -> list -> PIL)
+                h, w, _ = arr.shape
+                flat_data = arr.reshape(-1).tolist()
+                image = Image.frombytes("RGB", (w, h), bytes(flat_data))
+
+            patches, image_grid_hw = self._preprocess(image)
+            pixel_values_list.append(patches)
+            image_grid_hws.append(image_grid_hw)
+
+        pixel_values = mx.concatenate(pixel_values_list, axis=0)
+        image_grid_hws = mx.array(image_grid_hws)
+
+        # Return MLX arrays directly
+        data = {
+            "pixel_values": pixel_values,
+            "image_grid_hws": image_grid_hws,
+        }
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def __call__(
+        self,
+        images: ImageInput,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        """Make the image processor callable."""
+        return self.preprocess(images, return_tensors=return_tensors, **kwargs)
+
+
+class KimiVLProcessor(ProcessorMixin):
+    """
+    MLX-based processor for KimiVL that doesn't require torch/torchvision.
+
+    Constructs a KimiVL processor which wraps a KimiVL image processor and a tokenizer
+    into a single processor.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template"]
+    image_processor_class = "KimiVLImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        chat_template=None,
+        **kwargs,
+    ):
+        self.image_token = "<|media_pad|>"
+        if image_processor is None:
+            image_processor = KimiVLImageProcessor()
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[
+            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
+        ] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Main method to prepare for the model one or several sequences(s) and image(s).
+
+        Args:
+            images: The image or batch of images to be prepared.
+            text: The sequence or batch of sequences to be encoded.
+            return_tensors: If set, will return tensors of a particular framework.
+
+        Returns:
+            BatchFeature with input_ids, attention_mask, and pixel_values.
+        """
+        if images is None and text is None:
+            raise ValueError("You have to specify at least one of `images` or `text`.")
+
+        # Check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        # Extract return_tensors from kwargs (unused, we always return MLX arrays)
+        kwargs.pop("return_tensors", None)
+
+        # Process images
+        if images is not None:
+            image_inputs = self.image_processor(images)
+            image_grid_hws = image_inputs["image_grid_hws"]
+        else:
+            image_inputs = {}
+            image_grid_hws = None
+
+        # Process text
+        if isinstance(text, str):
+            text = [text]
+        elif text is not None and not isinstance(text, list):
+            raise ValueError(
+                "Invalid input text. Please provide a string, or a list of strings"
+            )
+
+        # Replace image tokens with the correct number of placeholder tokens
+        if image_grid_hws is not None and text is not None:
+            merge_length = (
+                self.image_processor.merge_kernel_size[0]
+                * self.image_processor.merge_kernel_size[1]
+            )
+            index = 0
+            for i in range(len(text)):
+                while self.image_token in text[i]:
+                    # Use mx.prod for MLX arrays
+                    grid_hw = image_grid_hws[index]
+                    num_placeholders = int(mx.prod(grid_hw).item()) // merge_length
+                    text[i] = text[i].replace(
+                        self.image_token,
+                        "<|placeholder|>" * num_placeholders,
+                        1,
+                    )
+                    index += 1
+                text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+        # Tokenize text
+        # Note: The TikToken tokenizer doesn't work properly with transformers' standard
+        # __call__ method due to issues with the pad function. We use encode() directly.
+        if text is not None:
+            # Encode each text and build the result manually
+            all_input_ids = []
+            for t in text:
+                ids = self.tokenizer.encode(t)
+                all_input_ids.append(ids)
+
+            # Pad sequences to the same length if needed
+            max_len = max(len(ids) for ids in all_input_ids)
+            pad_token_id = self.tokenizer.pad_token_id or 0
+
+            padded_input_ids = []
+            attention_masks = []
+            for ids in all_input_ids:
+                padding_length = max_len - len(ids)
+                padded_ids = ids + [pad_token_id] * padding_length
+                mask = [1] * len(ids) + [0] * padding_length
+                padded_input_ids.append(padded_ids)
+                attention_masks.append(mask)
+
+            # Convert to MLX arrays
+            text_inputs = {
+                "input_ids": mx.array(padded_input_ids),
+                "attention_mask": mx.array(attention_masks),
+            }
+        else:
+            text_inputs = {}
+
+        return BatchFeature(data={**text_inputs, **image_inputs})
+
+    def batch_decode(self, *args, **kwargs):
+        """Forward to tokenizer's batch_decode."""
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """Forward to tokenizer's decode."""
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def apply_chat_template(
+        self,
+        conversation,
+        chat_template=None,
+        add_generation_prompt=False,
+        tokenize=False,
+        **kwargs,
+    ):
+        """Apply chat template to the conversation."""
+        # Use provided template, processor's template, or tokenizer's template
+        if chat_template is None:
+            chat_template = self.chat_template
+        if chat_template is None:
+            chat_template = getattr(self.tokenizer, "chat_template", None)
+        if chat_template is None:
+            raise ValueError(
+                "No chat template found. Please provide a chat_template argument "
+                "or ensure the tokenizer has a chat_template attribute."
+            )
+
+        # Use jinja2 to render the template
+        try:
+            from jinja2 import Template
+        except ImportError:
+            raise ImportError("jinja2 is required for apply_chat_template")
+
+        template = Template(chat_template)
+        rendered = template.render(
+            messages=conversation,
+            add_generation_prompt=add_generation_prompt,
+            **kwargs,
+        )
+
+        if tokenize:
+            return self.tokenizer.encode(rendered)
+        return rendered
+
+    @property
+    def model_input_names(self):
+        """Get the model input names from tokenizer and image processor."""
+        tokenizer_input_names = self.tokenizer.model_input_names
+        image_processor_input_names = self.image_processor.model_input_names
+        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """Load the processor from a pretrained model path."""
+        from huggingface_hub import hf_hub_download
+
+        kwargs.pop("trust_remote_code", None)
+
+        model_path = Path(pretrained_model_name_or_path)
+        is_local = model_path.exists() and model_path.is_dir()
+        tokenizer = AutoTokenizer.from_pretrained(
+            str(model_path) if is_local else pretrained_model_name_or_path,
+            trust_remote_code=True,
+            local_files_only=is_local,
+        )
+
+        # Load image processor config and create our processor
+        image_processor_config = {}
+        try:
+            if is_local:
+                config_path = model_path / "config.json"
+            else:
+                config_path = Path(
+                    hf_hub_download(pretrained_model_name_or_path, "config.json")
+                )
+            with open(config_path, "r", encoding="utf-8") as f:
+                config_dict = json.load(f)
+            config = ModelConfig.from_dict(config_dict)
+            if hasattr(config, "vision_config"):
+                vision_config = config.vision_config
+                if hasattr(vision_config, "patch_size"):
+                    image_processor_config["patch_size"] = vision_config.patch_size
+                if hasattr(vision_config, "in_token_limit"):
+                    image_processor_config["in_token_limit"] = (
+                        vision_config.in_token_limit
+                    )
+                if hasattr(vision_config, "merge_kernel_size"):
+                    image_processor_config["merge_kernel_size"] = (
+                        vision_config.merge_kernel_size
+                    )
+        except Exception:
+            pass
+
+        image_processor = KimiVLImageProcessor(**image_processor_config)
+
+        # Load chat template from jinja file if not already set on tokenizer
+        chat_template = getattr(tokenizer, "chat_template", None)
+        if chat_template is None:
+            try:
+                if is_local:
+                    jinja_path = model_path / "chat_template.jinja"
+                else:
+                    jinja_path = Path(
+                        hf_hub_download(
+                            pretrained_model_name_or_path, "chat_template.jinja"
+                        )
+                    )
+                if jinja_path.exists():
+                    chat_template = jinja_path.read_text(encoding="utf-8")
+                    # Set chat_template on tokenizer so apply_chat_template works
+                    tokenizer.chat_template = chat_template
+            except Exception:
+                pass
+
+        return cls(
+            image_processor=image_processor,
+            tokenizer=tokenizer,
+            chat_template=chat_template,
+        )
+
+
+from transformers import AutoProcessor
+
+_original_auto_processor_from_pretrained = AutoProcessor.from_pretrained
+
+
+@classmethod
+def _patched_auto_processor_from_pretrained(
+    cls, pretrained_model_name_or_path, **kwargs
+):
+    """Patched from_pretrained that returns KimiVLProcessor for kimi_vl models."""
+    from huggingface_hub import hf_hub_download
+
+    model_path = Path(pretrained_model_name_or_path)
+    is_local = model_path.exists() and model_path.is_dir()
+
+    # Check if this is a kimi_vl model
+    is_kimi_vl = False
+    try:
+        if is_local:
+            config_path = model_path / "config.json"
+        else:
+            config_path = Path(
+                hf_hub_download(pretrained_model_name_or_path, "config.json")
+            )
+        with open(config_path, "r", encoding="utf-8") as f:
+            config = json.load(f)
+        is_kimi_vl = config.get("model_type", "").lower() == "kimi_vl"
+    except Exception:
+        pass
+
+    if is_kimi_vl:
+        return KimiVLProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+    return _original_auto_processor_from_pretrained.__func__(
+        cls, pretrained_model_name_or_path, **kwargs
+    )
+
+
+AutoProcessor.from_pretrained = _patched_auto_processor_from_pretrained
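The token bookkeeping that KimiVLProcessor.__call__ performs when it expands each <|media_pad|> marker can be reproduced independently of the package. The sketch below is illustrative only and is not code from the wheel: it assumes the defaults visible in the diff (patch_size=14, merge_kernel_size=[2, 2], pad_input=False) and skips the in_token_limit rescale, showing only how an image's pixel size maps to the number of image tokens inserted into the prompt.

# Illustrative sketch, not part of the fount_vlm_nell_02 wheel. It mirrors the
# placeholder arithmetic in KimiVLProcessor.__call__ under the default
# patch_size=14 and merge_kernel_size=[2, 2], ignoring the in_token_limit rescale.
def expected_media_tokens(width, height, patch_size=14,
                          merge_kernel_size=(2, 2), pad_input=False):
    """How many <|media_pad|> tokens a width x height image expands to."""
    crop_w = merge_kernel_size[1] * patch_size
    crop_h = merge_kernel_size[0] * patch_size
    if pad_input:
        # rescale() with pad_input=True pads bottom/right up to the next multiple
        width += (crop_w - width % crop_w) % crop_w
        height += (crop_h - height % crop_h) % crop_h
    else:
        # the pad_input=False branch center-crops down to a multiple
        width -= width % crop_w
        height -= height % crop_h
    grid_h, grid_w = height // patch_size, width // patch_size
    merge_length = merge_kernel_size[0] * merge_kernel_size[1]
    return (grid_h * grid_w) // merge_length

print(expected_media_tokens(448, 336))  # -> 192

With those defaults a 448x336 image becomes a 24x32 patch grid (768 patches), which the 2x2 merge kernel collapses to 192 prompt tokens, matching the int(mx.prod(grid_hw).item()) // merge_length expression in the diff.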