fount-vlm-nell-02 0.3.11 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/kimi_vl/processing_kimi_vl.py
@@ -0,0 +1,560 @@
"""
MLX-based KimiVL Processor.

This module provides an MLX-native processor for KimiVL models that:
1. Uses a pre-converted fast tokenizer (no tiktoken dependency)
2. Provides an MLX-based image processor (no torch/torchvision dependency)
3. Patches missing functions for transformers 5.0 compatibility
"""

import json
import math
import warnings
from pathlib import Path
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import transformers.processing_utils as processing_utils
from PIL import Image
from transformers import AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import ImageInput, make_list_of_images, valid_images
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import TensorType

from .config import ModelConfig


def _validate_images_text_input_order(images, text):
    """
    Validate and potentially swap the order of images and text arguments.

    This function checks if the arguments are in the correct order (images first, text second)
    for backward compatibility. If text is passed as the first argument and images as the second,
    it swaps them and issues a deprecation warning.

    Args:
        images: The images argument (should be image-like objects or None)
        text: The text argument (should be strings or None)

    Returns:
        Tuple of (images, text) in the correct order
    """
    # Check if arguments are swapped (text passed as images, images passed as text)
    if images is not None and text is not None:
        # If 'images' looks like text and 'text' looks like images, swap them
        images_is_text = isinstance(images, str) or (
            isinstance(images, (list, tuple))
            and len(images) > 0
            and isinstance(images[0], str)
        )
        text_is_image = not isinstance(text, str) and not (
            isinstance(text, (list, tuple))
            and len(text) > 0
            and isinstance(text[0], str)
        )

        if images_is_text and text_is_image:
            warnings.warn(
                "You passed text as the first argument and images as the second. "
                "This is deprecated and will be removed in a future version. "
                "Please pass images first and text second.",
                FutureWarning,
            )
            return text, images

    return images, text


# Add the function to transformers.processing_utils if it doesn't exist
if not hasattr(processing_utils, "_validate_images_text_input_order"):
    processing_utils._validate_images_text_input_order = (
        _validate_images_text_input_order
    )

# Also add Unpack if it doesn't exist (for older Python versions)
if not hasattr(processing_utils, "Unpack"):
    try:
        from typing import Unpack

        processing_utils.Unpack = Unpack
    except ImportError:
        from typing_extensions import Unpack

        processing_utils.Unpack = Unpack


# CLIP-style normalization constants
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


class KimiVLImageProcessor(BaseImageProcessor):

    model_input_names = ["pixel_values", "image_grid_hws"]

    def __init__(
        self,
        patch_size: int = 14,
        pad_input: bool = False,
        image_mean: Tuple[float, float, float] = OPENAI_DATASET_MEAN,
        image_std: Tuple[float, float, float] = OPENAI_DATASET_STD,
        in_token_limit: int = 4096,
        merge_kernel_size: List[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_token_limit = in_token_limit
        self.patch_size = patch_size
        self.pad_input = pad_input
        self.image_mean = image_mean
        self.image_std = image_std
        self.merge_kernel_size = (
            merge_kernel_size if merge_kernel_size is not None else [2, 2]
        )

    def rescale(
        self, image: Image.Image, merge_kernel_size: List[int] = None
    ) -> Image.Image:
        """Rescale image to fit within token limits and pad/crop to patch boundaries."""
        if merge_kernel_size is None:
            merge_kernel_size = self.merge_kernel_size

        w, h = image.size
        patch_size = self.patch_size

        # Rescale if exceeds token limit
        if (w // patch_size) * (h // patch_size) > self.in_token_limit:
            scale = math.sqrt(
                self.in_token_limit / ((w // patch_size) * (h // patch_size))
            )
            new_w, new_h = int(w * scale), int(h * scale)
            image = image.resize((new_w, new_h), Image.Resampling.BICUBIC)

        if self.pad_input:
            new_w, new_h = image.size
            pad_size_h = merge_kernel_size[0] * patch_size
            pad_size_w = merge_kernel_size[1] * patch_size

            pad_h = (pad_size_h - new_h % pad_size_h) % pad_size_h
            pad_w = (pad_size_w - new_w % pad_size_w) % pad_size_w

            if pad_h > 0 or pad_w > 0:
                # Pad image (bottom and right padding)
                new_image = Image.new(
                    image.mode, (new_w + pad_w, new_h + pad_h), (0, 0, 0)
                )
                new_image.paste(image, (0, 0))
                image = new_image
        else:
            new_w, new_h = image.size
            # Ensure dimensions are divisible by merge_kernel_size * patch_size
            # so that the grid dimensions are divisible by merge_kernel_size
            crop_size_w = merge_kernel_size[1] * patch_size
            crop_size_h = merge_kernel_size[0] * patch_size
            new_w = new_w - new_w % crop_size_w
            new_h = new_h - new_h % crop_size_h
            # Center crop
            left = (image.size[0] - new_w) // 2
            top = (image.size[1] - new_h) // 2
            image = image.crop((left, top, left + new_w, top + new_h))

        w, h = image.size
        if w // patch_size >= 512 or h // patch_size >= 512:
            raise ValueError("Exceed pos emb")

        return image

    def to_mlx(self, image: Image.Image) -> mx.array:
        """Convert PIL image to MLX array in CHW format, normalized to [0, 1]."""
        image = image.convert("RGB")
        w, h = image.size
        # Convert PIL image to MLX array directly via bytes
        arr = mx.array(list(image.getdata()), dtype=mx.float32).reshape(h, w, 3) / 255.0
        # Convert from HWC to CHW format
        arr = arr.transpose(2, 0, 1)
        return arr

    def normalize(self, image: mx.array) -> mx.array:
        """Normalize image with CLIP-style mean and std."""
        mean = mx.array(self.image_mean, dtype=mx.float32).reshape(3, 1, 1)
        std = mx.array(self.image_std, dtype=mx.float32).reshape(3, 1, 1)
        return (image - mean) / std

    def patchify(self, image: mx.array) -> Tuple[mx.array, Tuple[int, int]]:
        """Convert image to patches."""
        patch_size = self.patch_size
        C, H, W = image.shape

        # Reshape to (C, H//p, p, W//p, p) then to (num_patches, C, p, p)
        patches = image.reshape(
            C, H // patch_size, patch_size, W // patch_size, patch_size
        )
        # Permute to (H//p, W//p, C, p, p)
        patches = patches.transpose(1, 3, 0, 2, 4)
        # Flatten to (num_patches, C, p, p)
        patches = patches.reshape(-1, C, patch_size, patch_size)

        grid_hw = (H // patch_size, W // patch_size)
        return patches, grid_hw

    def _preprocess(self, image: ImageInput) -> Tuple[mx.array, Tuple[int, int]]:
        """
        Preprocess image and patchify it.

        Args:
            image: Image to preprocess.

        Returns:
            patches: mx.array
            grid_hw: Tuple[int, int]
        """
        image = self.rescale(image, self.merge_kernel_size)
        image = self.to_mlx(image)
        image = self.normalize(image)
        patches, grid_hw = self.patchify(image)
        return patches, grid_hw

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Process images and return BatchFeature."""
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image or mx.array."
            )

        pixel_values_list = []
        image_grid_hws = []

        for image in images:
            # Convert MLX arrays to PIL Images if needed
            if isinstance(image, mx.array):
                # Ensure we're working with the array values
                arr = image
                if arr.ndim == 3 and arr.shape[0] in [1, 3, 4]:
                    # CHW format, convert to HWC
                    arr = arr.transpose(1, 2, 0)
                # Convert to uint8 for PIL
                if arr.dtype in [mx.float32, mx.float16, mx.bfloat16]:
                    arr = (arr * 255).astype(mx.uint8)
                # Convert to PIL via list (MLX -> list -> PIL)
                h, w, _ = arr.shape
                flat_data = arr.reshape(-1).tolist()
                image = Image.frombytes("RGB", (w, h), bytes(flat_data))

            patches, image_grid_hw = self._preprocess(image)
            pixel_values_list.append(patches)
            image_grid_hws.append(image_grid_hw)

        pixel_values = mx.concatenate(pixel_values_list, axis=0)
        image_grid_hws = mx.array(image_grid_hws)

        # Return MLX arrays directly
        data = {
            "pixel_values": pixel_values,
            "image_grid_hws": image_grid_hws,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)

    def __call__(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Make the image processor callable."""
        return self.preprocess(images, return_tensors=return_tensors, **kwargs)


class KimiVLProcessor(ProcessorMixin):
    """
    MLX-based processor for KimiVL that doesn't require torch/torchvision.

    Constructs a KimiVL processor which wraps a KimiVL image processor and a tokenizer
    into a single processor.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "KimiVLImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        **kwargs,
    ):
        self.image_token = "<|media_pad|>"
        if image_processor is None:
            image_processor = KimiVLImageProcessor()
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s).

        Args:
            images: The image or batch of images to be prepared.
            text: The sequence or batch of sequences to be encoded.
            return_tensors: If set, will return tensors of a particular framework.

        Returns:
            BatchFeature with input_ids, attention_mask, and pixel_values.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        # Check if images and text inputs are reversed for BC
        images, text = _validate_images_text_input_order(images, text)

        # Extract return_tensors from kwargs (unused, we always return MLX arrays)
        kwargs.pop("return_tensors", None)

        # Process images
        if images is not None:
            image_inputs = self.image_processor(images)
            image_grid_hws = image_inputs["image_grid_hws"]
        else:
            image_inputs = {}
            image_grid_hws = None

        # Process text
        if isinstance(text, str):
            text = [text]
        elif text is not None and not isinstance(text, list):
            raise ValueError(
                "Invalid input text. Please provide a string, or a list of strings"
            )

        # Replace image tokens with the correct number of placeholder tokens
        if image_grid_hws is not None and text is not None:
            merge_length = (
                self.image_processor.merge_kernel_size[0]
                * self.image_processor.merge_kernel_size[1]
            )
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    # Use mx.prod for MLX arrays
                    grid_hw = image_grid_hws[index]
                    num_placeholders = int(mx.prod(grid_hw).item()) // merge_length
                    text[i] = text[i].replace(
                        self.image_token,
                        "<|placeholder|>" * num_placeholders,
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        # Tokenize text
        # Note: The TikToken tokenizer doesn't work properly with transformers' standard
        # __call__ method due to issues with the pad function. We use encode() directly.
        if text is not None:
            # Encode each text and build the result manually
            all_input_ids = []
            for t in text:
                ids = self.tokenizer.encode(t)
                all_input_ids.append(ids)

            # Pad sequences to the same length if needed
            max_len = max(len(ids) for ids in all_input_ids)
            pad_token_id = self.tokenizer.pad_token_id or 0

            padded_input_ids = []
            attention_masks = []
            for ids in all_input_ids:
                padding_length = max_len - len(ids)
                padded_ids = ids + [pad_token_id] * padding_length
                mask = [1] * len(ids) + [0] * padding_length
                padded_input_ids.append(padded_ids)
                attention_masks.append(mask)

            # Convert to MLX arrays
            text_inputs = {
                "input_ids": mx.array(padded_input_ids),
                "attention_mask": mx.array(attention_masks),
            }
        else:
            text_inputs = {}

        return BatchFeature(data={**text_inputs, **image_inputs})

    def batch_decode(self, *args, **kwargs):
        """Forward to tokenizer's batch_decode."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to tokenizer's decode."""
        return self.tokenizer.decode(*args, **kwargs)

    def apply_chat_template(
        self,
        conversation,
        chat_template=None,
        add_generation_prompt=False,
        tokenize=False,
        **kwargs,
    ):
        """Apply chat template to the conversation."""
        # Use provided template, processor's template, or tokenizer's template
        if chat_template is None:
            chat_template = self.chat_template
        if chat_template is None:
            chat_template = getattr(self.tokenizer, "chat_template", None)
        if chat_template is None:
            raise ValueError(
                "No chat template found. Please provide a chat_template argument "
                "or ensure the tokenizer has a chat_template attribute."
            )

        # Use jinja2 to render the template
        try:
            from jinja2 import Template
        except ImportError:
            raise ImportError("jinja2 is required for apply_chat_template")

        template = Template(chat_template)
        rendered = template.render(
            messages=conversation,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )

        if tokenize:
            return self.tokenizer.encode(rendered)
        return rendered

    @property
    def model_input_names(self):
        """Get the model input names from tokenizer and image processor."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the processor from a pretrained model path."""
        from huggingface_hub import hf_hub_download

        kwargs.pop("trust_remote_code", None)

        model_path = Path(pretrained_model_name_or_path)
        is_local = model_path.exists() and model_path.is_dir()
        tokenizer = AutoTokenizer.from_pretrained(
            str(model_path) if is_local else pretrained_model_name_or_path,
            trust_remote_code=True,
            local_files_only=is_local,
        )

        # Load image processor config and create our processor
        image_processor_config = {}
        try:
            if is_local:
                config_path = model_path / "config.json"
            else:
                config_path = Path(
                    hf_hub_download(pretrained_model_name_or_path, "config.json")
                )
            with open(config_path, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
            config = ModelConfig.from_dict(config_dict)
            if hasattr(config, "vision_config"):
                vision_config = config.vision_config
                if hasattr(vision_config, "patch_size"):
                    image_processor_config["patch_size"] = vision_config.patch_size
                if hasattr(vision_config, "in_token_limit"):
                    image_processor_config["in_token_limit"] = (
                        vision_config.in_token_limit
                    )
                if hasattr(vision_config, "merge_kernel_size"):
                    image_processor_config["merge_kernel_size"] = (
                        vision_config.merge_kernel_size
                    )
        except Exception:
            pass

        image_processor = KimiVLImageProcessor(**image_processor_config)

        # Load chat template from jinja file if not already set on tokenizer
        chat_template = getattr(tokenizer, "chat_template", None)
        if chat_template is None:
            try:
                if is_local:
                    jinja_path = model_path / "chat_template.jinja"
                else:
                    jinja_path = Path(
                        hf_hub_download(
                            pretrained_model_name_or_path, "chat_template.jinja"
                        )
                    )
                if jinja_path.exists():
                    chat_template = jinja_path.read_text(encoding="utf-8")
                    # Set chat_template on tokenizer so apply_chat_template works
                    tokenizer.chat_template = chat_template
            except Exception:
                pass

        return cls(
            image_processor=image_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
        )


from transformers import AutoProcessor

_original_auto_processor_from_pretrained = AutoProcessor.from_pretrained


@classmethod
def _patched_auto_processor_from_pretrained(
    cls, pretrained_model_name_or_path, **kwargs
):
    """Patched from_pretrained that returns KimiVLProcessor for kimi_vl models."""
    from huggingface_hub import hf_hub_download

    model_path = Path(pretrained_model_name_or_path)
    is_local = model_path.exists() and model_path.is_dir()

    # Check if this is a kimi_vl model
    is_kimi_vl = False
    try:
        if is_local:
            config_path = model_path / "config.json"
        else:
            config_path = Path(
                hf_hub_download(pretrained_model_name_or_path, "config.json")
            )
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        is_kimi_vl = config.get("model_type", "").lower() == "kimi_vl"
    except Exception:
        pass

    if is_kimi_vl:
        return KimiVLProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)

    return _original_auto_processor_from_pretrained.__func__(
        cls, pretrained_model_name_or_path, **kwargs
    )


AutoProcessor.from_pretrained = _patched_auto_processor_from_pretrained
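
For orientation, a minimal usage sketch of the KimiVLProcessor defined in the file above. This is an editor's illustration, not part of the package diff; the checkpoint directory and image filename are placeholder names, and in practice the prompt would normally be built with apply_chat_template from the model's chat template.

# Illustration only -- not included in the wheel contents listed above.
# "path/to/kimi-vl-checkpoint" and "example.jpg" are placeholders.
from PIL import Image

from mlx_vlm.models.kimi_vl.processing_kimi_vl import KimiVLProcessor

processor = KimiVLProcessor.from_pretrained("path/to/kimi-vl-checkpoint")

# The processor expands each "<|media_pad|>" occurrence into one image token per
# merged patch (grid_h * grid_w // merge_length) before tokenizing, so the prompt
# must contain the image token once per image passed in.
prompt = "<|media_pad|>Describe this image."
inputs = processor(images=Image.open("example.jpg"), text=prompt)

# BatchFeature of MLX arrays: input_ids, attention_mask, pixel_values, image_grid_hws.
print(inputs["input_ids"].shape, inputs["pixel_values"].shape, inputs["image_grid_hws"])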