fount_vlm_nell_02-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/molmo2/processing.py
@@ -0,0 +1,773 @@
"""Image processor for Molmo2 - MLX-native implementation without torch dependency."""

import logging
from typing import List, Optional, Tuple

import numpy as np
from PIL import Image
from transformers import AutoProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.image_transforms import convert_to_rgb
from transformers.image_utils import (
    ImageInput,
    PILImageResampling,
    make_flat_list_of_images,
    to_numpy_array,
    valid_images,
)
from transformers.processing_utils import ProcessorMixin

logger = logging.getLogger(__name__)

# Special tokens
IMAGE_PATCH_TOKEN = "<im_patch>"
IMAGE_LOW_RES_TOKEN = "<im_low>"
IM_START_TOKEN = "<im_start>"
LOW_RES_IMAGE_START_TOKEN = "<low_res_im_start>"
FRAME_START_TOKEN = "<frame_start>"
IM_END_TOKEN = "<im_end>"
FRAME_END_TOKEN = "<frame_end>"
IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"


def normalize_image(
    image: np.ndarray,
    image_mean: List[float],
    image_std: List[float],
) -> np.ndarray:
    """Normalize image with mean and std."""
    image = image.astype(np.float32)
    image -= np.array(image_mean, dtype=np.float32)[None, None, :]
    image /= np.array(image_std, dtype=np.float32)[None, None, :]
    return image


def resize_image_pil(
    image: np.ndarray,
    desired_output_size: List[int],
    resample: PILImageResampling,
) -> np.ndarray:
    """Resize image using PIL instead of torch."""
    # Convert numpy to PIL
    if image.max() <= 1.0:
        image_uint8 = (image * 255).astype(np.uint8)
    else:
        image_uint8 = image.astype(np.uint8)

    pil_image = Image.fromarray(image_uint8)

    # Map PILImageResampling to PIL resampling
    resample_map = {
        PILImageResampling.NEAREST: Image.Resampling.NEAREST,
        PILImageResampling.BILINEAR: Image.Resampling.BILINEAR,
        PILImageResampling.BICUBIC: Image.Resampling.BICUBIC,
        PILImageResampling.LANCZOS: Image.Resampling.LANCZOS,
        PILImageResampling.BOX: Image.Resampling.BOX,
        PILImageResampling.HAMMING: Image.Resampling.HAMMING,
    }
    pil_resample = resample_map.get(resample, Image.Resampling.BILINEAR)

    # Resize (PIL uses width, height order)
    resized = pil_image.resize(
        (desired_output_size[1], desired_output_size[0]), resample=pil_resample
    )

    # Convert back to numpy and normalize to [0, 1]
    resized_np = np.array(resized, dtype=np.float32) / 255.0
    return resized_np


def select_tiling(
    h: int, w: int, patch_size: int, max_num_crops: int
) -> Tuple[int, int]:
    """Divide an image of size [w, h] into up to max_num_crops of size patch_size."""
    tilings = []
    for i in range(1, max_num_crops + 1):
        for j in range(1, max_num_crops + 1):
            if i * j <= max_num_crops:
                tilings.append((i, j))
    tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
    candidate_tilings = np.array(tilings, dtype=np.int32)
    candidate_resolutions = candidate_tilings * patch_size

    original_size = np.array([h, w], dtype=np.float32)

    with np.errstate(divide="ignore"):
        required_scale = np.min(
            candidate_resolutions.astype(np.float32) / original_size, axis=-1
        )

    if np.all(required_scale < 1):
        ix = np.argmax(required_scale)
    else:
        required_scale = np.where(required_scale < 1.0, 1e9, required_scale)
        ix = np.argmin(required_scale)

    return tuple(candidate_tilings[ix])


def build_resized_image(
    image: np.ndarray,
    base_image_input_size: List[int],
    resample: PILImageResampling,
    image_mean: List[float],
    image_std: List[float],
    image_patch_size: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build a single resized image crop."""
    resized = resize_image_pil(image, base_image_input_size, resample)
    resized = normalize_image(resized, image_mean, image_std)
    if len(resized.shape) == 3:
        resized = np.expand_dims(resized, 0)
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size
    resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape(
        [crop_patch_h, crop_patch_w]
    )
    return resized, resize_idx


def build_overlapping_crops(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: List[int],
    base_image_input_size: List[int],
    resample: PILImageResampling,
    image_mean: List[float],
    image_std: List[float],
    image_patch_size: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Decompose an image into overlapping crops."""
    crop_size = base_image_input_size[0]
    assert base_image_input_size[0] == base_image_input_size[1]

    left_margin, right_margin = overlap_margins
    total_margin_pixels = image_patch_size * (right_margin + left_margin)
    crop_patches = base_image_input_size[0] // image_patch_size
    crop_window_patches = crop_patches - (right_margin + left_margin)
    crop_window_size = crop_window_patches * image_patch_size
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size
    original_image_h, original_image_w = image.shape[:2]

    tiling = select_tiling(
        original_image_h - total_margin_pixels,
        original_image_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    src = resize_image_pil(
        image,
        [
            tiling[0] * crop_window_size + total_margin_pixels,
            tiling[1] * crop_window_size + total_margin_pixels,
        ],
        resample,
    )
    src = normalize_image(src, image_mean, image_std)

    n_crops = tiling[0] * tiling[1]
    crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
    patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)

    on_crop = 0
    for i in range(tiling[0]):
        y0 = i * crop_window_size
        for j in range(tiling[1]):
            x0 = j * crop_window_size
            crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
            patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(
                crop_patch_h, crop_patch_w
            )
            patch_idx += on_crop * crop_patch_h * crop_patch_w

            if i != 0:
                patch_idx[:left_margin, :] = -1
            if j != 0:
                patch_idx[:, :left_margin] = -1
            if i != tiling[0] - 1:
                patch_idx[-right_margin:, :] = -1
            if j != tiling[1] - 1:
                patch_idx[:, -right_margin:] = -1
            patch_idx_arr[on_crop] = patch_idx
            on_crop += 1

    # Transpose the patch_idx_arr to get the full index array
    patch_idx_full = np.zeros(
        [
            tiling[0] * crop_window_patches + left_margin + right_margin,
            tiling[1] * crop_window_patches + left_margin + right_margin,
        ],
        dtype=np.int32,
    )
    for i in range(tiling[0]):
        for j in range(tiling[1]):
            crop_idx = i * tiling[1] + j
            y_start = i * crop_window_patches
            x_start = j * crop_window_patches
            patch_idx_full[
                y_start : y_start + crop_patch_h, x_start : x_start + crop_patch_w
            ] = np.where(
                patch_idx_arr[crop_idx] >= 0,
                patch_idx_arr[crop_idx],
                patch_idx_full[
                    y_start : y_start + crop_patch_h, x_start : x_start + crop_patch_w
                ],
            )

    return crop_arr, patch_idx_full, tiling


def batch_pixels_to_patches(crops: np.ndarray, patch_size: int) -> np.ndarray:
    """Convert image crops to patches."""
    n_crops, h, w, c = crops.shape
    n_patches_h = h // patch_size
    n_patches_w = w // patch_size
    n_patches = n_patches_h * n_patches_w
    patch_dim = patch_size * patch_size * c

    # Reshape to patches
    crops = crops.reshape(n_crops, n_patches_h, patch_size, n_patches_w, patch_size, c)
    crops = crops.transpose(0, 1, 3, 2, 4, 5)
    crops = crops.reshape(n_crops, n_patches, patch_dim)
    return crops


def arange_for_pooling(
    idx_arr: np.ndarray,
    pool_h: int,
    pool_w: int,
) -> np.ndarray:
    """Build pooling indices using centered padding (matches HuggingFace implementation)."""
    h, w = idx_arr.shape
    # Calculate padding to make dimensions divisible by pool size (centered padding)
    h_pad = pool_h * ((h + pool_h - 1) // pool_h) - h
    w_pad = pool_w * ((w + pool_w - 1) // pool_w) - w

    # Apply centered padding
    idx_arr = np.pad(
        idx_arr,
        [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
        mode="constant",
        constant_values=-1,
    )

    # Rearrange into pooling windows: (h dh) (w dw) -> h w (dh dw)
    padded_h, padded_w = idx_arr.shape
    out_h = padded_h // pool_h
    out_w = padded_w // pool_w

    # Reshape to separate pooling dimensions
    idx_arr = idx_arr.reshape(out_h, pool_h, out_w, pool_w)
    # Transpose to get (out_h, out_w, pool_h, pool_w)
    idx_arr = idx_arr.transpose(0, 2, 1, 3)
    # Reshape to (out_h, out_w, pool_h * pool_w)
    idx_arr = idx_arr.reshape(out_h, out_w, pool_h * pool_w)

    return idx_arr


def image_to_patches_and_grids(
    image: np.ndarray,
    max_crops: int,
    overlap_margins: List[int],
    base_image_input_size: List[int],
    resample: PILImageResampling,
    image_mean: List[float],
    image_std: List[float],
    image_patch_size: int,
    image_pooling_w: int,
    image_pooling_h: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple[int, int]]:
    """Convert image to patches with pooling information.

    Returns crops and pooling indices in the order expected by the model:
    - Crops: [low_res_crop, high_res_crops...] (low-res first)
    - Pooling indices: [low_res_pooling, high_res_pooling] (low-res first)
    - Image grid: [lo_h, lo_w, hi_h, hi_w] (low-res dimensions first)
    """
    crop_patch_w = base_image_input_size[1] // image_patch_size
    crop_patch_h = base_image_input_size[0] // image_patch_size

    # Build overlapping crops for high-res
    crop_arr, patch_idx_arr, tiling = build_overlapping_crops(
        image,
        max_crops,
        overlap_margins,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )

    # Build pooling indices for high-res using centered padding (matches HF)
    pooling_idx = arange_for_pooling(patch_idx_arr, image_pooling_h, image_pooling_w)
    h, w = pooling_idx.shape[:2]
    pooling_idx = pooling_idx.reshape(-1, image_pooling_h * image_pooling_w)

    # Build resized image for low-res
    resize_arr, resize_idx = build_resized_image(
        image,
        base_image_input_size,
        resample,
        image_mean,
        image_std,
        image_patch_size,
    )

    # Combine crops: LOW-RES FIRST (matches HuggingFace)
    all_crops = np.concatenate([resize_arr, crop_arr], axis=0)

    # Build pooling indices for low-res
    resize_pooling_idx = arange_for_pooling(
        resize_idx, image_pooling_h, image_pooling_w
    )
    resized_h, resized_w = resize_pooling_idx.shape[:2]
    resize_pooling_idx = resize_pooling_idx.reshape(
        -1, image_pooling_h * image_pooling_w
    )

    # LOW-RES crop is first, so offset HIGH-RES indices by the number of low-res patches
    # (matches HuggingFace: "Global image goes first, so the order of patches in previous crops gets increased")
    pooling_idx = np.where(
        pooling_idx >= 0,
        pooling_idx
        + crop_patch_h * crop_patch_w,  # Offset by one crop (the low-res crop)
        -1,
    )

    # Concatenate pooling indices: LOW-RES FIRST (matches HuggingFace)
    pooling_idx = np.concatenate([resize_pooling_idx, pooling_idx], axis=0)

    # Image grid format: [resized_h, resized_w, h, w] = [lo_h, lo_w, hi_h, hi_w]
    # (matches HuggingFace order)
    image_grid = np.array([[resized_h, resized_w, h, w]], dtype=np.int32)

    return (
        image_grid,
        batch_pixels_to_patches(all_crops, image_patch_size),
        pooling_idx,
        (h, w),  # Return high-res pooled dims for token generation
    )


class Molmo2ImageProcessor(BaseImageProcessor):
    """
    MLX-native image processor for Molmo2 that doesn't require torch.
    """

    model_input_names = [
        "pixel_values",
        "image_token_pooling",
        "image_grids",
        "image_num_crops",
    ]

    def __init__(
        self,
        size: Optional[dict] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        do_convert_rgb: bool = True,
        max_crops: int = 8,
        overlap_margins: List[int] = None,
        patch_size: int = 14,
        pooling_size: List[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 378, "width": 378}
        size = get_size_dict(size, default_to_square=True)
        self.size = size

        self.resample = resample
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_convert_rgb = do_convert_rgb

        self.max_crops = max_crops
        self.overlap_margins = (
            overlap_margins if overlap_margins is not None else [4, 4]
        )
        self.patch_size = patch_size
        self.pooling_size = pooling_size if pooling_size is not None else [2, 2]

    def preprocess(
        self,
        images: ImageInput,
        size: Optional[dict] = None,
        resample: Optional[PILImageResampling] = None,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        do_convert_rgb: Optional[bool] = None,
        max_crops: Optional[int] = None,
        overlap_margins: Optional[List[int]] = None,
        patch_size: Optional[int] = None,
        pooling_size: Optional[List[int]] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess images for Molmo2."""
        if size is not None:
            if "height" not in size or "width" not in size:
                raise ValueError("size must contain 'height' and 'width' keys.")
        else:
            size = {**self.size}

        base_image_input_size = [size["height"], size["width"]]

        resample = resample or self.resample
        image_mean = image_mean or self.image_mean
        image_std = image_std or self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        max_crops = max_crops or self.max_crops
        overlap_margins = overlap_margins or self.overlap_margins
        patch_size = patch_size or self.patch_size
        pooling_size = pooling_size or self.pooling_size

        image_pooling_h, image_pooling_w = pooling_size

        if images is not None:
            images = make_flat_list_of_images(images)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        images = [to_numpy_array(image) for image in images]

        data = {}
        if images is not None:
            batch_grids = []
            batch_crops = []
            batch_pooled_patches_idx = []
            batch_num_crops = []
            batch_hi_pooled_dims = []

            for image in images:
                image_grid, crops, pooled_idx, hi_pooled_dims = (
                    image_to_patches_and_grids(
                        image,
                        max_crops,
                        overlap_margins,
                        base_image_input_size,
                        resample,
                        image_mean,
                        image_std,
                        patch_size,
                        image_pooling_w,
                        image_pooling_h,
                    )
                )
                batch_grids.append(image_grid)
                batch_crops.append(crops)
                batch_pooled_patches_idx.append(pooled_idx)
                batch_num_crops.append(crops.shape[0])
                batch_hi_pooled_dims.append(hi_pooled_dims)

            pixel_values = np.concatenate(batch_crops, 0)
            image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
            image_grids = np.concatenate(batch_grids, 0)
            image_num_crops = np.array(batch_num_crops)

            data.update(
                pixel_values=pixel_values,
                image_token_pooling=image_token_pooling,
                image_grids=image_grids,
                image_num_crops=image_num_crops,
                _hi_pooled_dims=batch_hi_pooled_dims,  # Internal use for token generation
            )

        return BatchFeature(data, tensor_type=return_tensors)


class Molmo2Processor(ProcessorMixin):
    """
    Processor for Molmo2 that combines image processor and tokenizer.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "Molmo2ImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        image_use_col_tokens: bool = True,
        use_single_crop_col_tokens: Optional[bool] = None,
        use_single_crop_start_token: bool = True,
        **kwargs,
    ):
        if image_processor is None:
            image_processor = Molmo2ImageProcessor()
        super().__init__(image_processor, tokenizer, **kwargs)
        self.image_use_col_tokens = image_use_col_tokens
        self.use_single_crop_col_tokens = use_single_crop_col_tokens
        self.use_single_crop_start_token = use_single_crop_start_token
        self.image_placeholder_token = IMAGE_PROMPT

    def get_image_tokens(self, image_grid: np.ndarray) -> str:
        """Generate image token string from image grid.

        Args:
            image_grid: Array of [resized_h, resized_w, height, width]
                = [lo_pooled_h, lo_pooled_w, hi_pooled_h, hi_pooled_w]
                (matches HuggingFace format)

        Returns:
            String of image tokens to insert into prompt.
        """
        # Unpack in HuggingFace order: [lo_h, lo_w, hi_h, hi_w]
        resized_h, resized_w, height, width = image_grid

        # Build high-res tokens first (will be appended after low-res)
        per_row = [IMAGE_PATCH_TOKEN] * width
        if self.image_use_col_tokens:
            per_row = per_row + [IM_COL_TOKEN]
        hi_res_tokens = [IM_START_TOKEN] + per_row * height + [IM_END_TOKEN]

        # Build low-res tokens
        per_row = [IMAGE_PATCH_TOKEN] * resized_w
        use_single_crop_col_tokens = (
            self.image_use_col_tokens
            if self.use_single_crop_col_tokens is None
            else self.use_single_crop_col_tokens
        )
        image_start_token = (
            LOW_RES_IMAGE_START_TOKEN
            if self.use_single_crop_start_token
            else IM_START_TOKEN
        )
        if use_single_crop_col_tokens:
            per_row = per_row + [IM_COL_TOKEN]
        lo_res_tokens = [image_start_token] + per_row * resized_h + [IM_END_TOKEN]

        # Low-res comes first, then high-res (matches HuggingFace)
        all_tokens = lo_res_tokens + hi_res_tokens
        return "".join(all_tokens)

    def __call__(
        self,
        text=None,
        images=None,
        padding=False,
        truncation=None,
        max_length=None,
        return_tensors=None,
        **kwargs,
    ):
        """Process text and images for the model."""
        encoding = {}
        image_grids = None

        if images is not None:
            image_inputs = self.image_processor(images, return_tensors=None)
            image_grids = image_inputs.get("image_grids")
            # Remove internal keys before adding to encoding
            hi_pooled_dims = image_inputs.pop("_hi_pooled_dims", None)
            encoding.update(image_inputs)

        if text is not None:
            # Expand image placeholders with actual image tokens
            if image_grids is not None:
                if isinstance(text, str):
                    text = [text]
                    was_string = True
                else:
                    text = list(text)
                    was_string = False

                image_idx = 0
                for i in range(len(text)):
                    num_images = text[i].count(self.image_placeholder_token)
                    for _ in range(num_images):
                        if image_idx < len(image_grids):
                            image_tokens = self.get_image_tokens(image_grids[image_idx])
                            text[i] = text[i].replace(
                                self.image_placeholder_token, image_tokens, 1
                            )
                            image_idx += 1

                if was_string:
                    text = text[0]

            text_inputs = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding.update(text_inputs)

        # Convert to requested tensor type
        if return_tensors is not None:
            encoding = BatchFeature(encoding, tensor_type=return_tensors)

        return encoding

    def apply_chat_template(
        self,
        conversation,
        chat_template=None,
        add_generation_prompt=False,
        tokenize=False,
        **kwargs,
    ):
        """Apply chat template to conversation."""
        if chat_template is None:
            chat_template = getattr(self, "chat_template", None)
        if chat_template is None:
            chat_template = getattr(self.tokenizer, "chat_template", None)
        if chat_template is None:
            # Default Molmo2 chat template
            chat_template = (
                "{% for message in messages %}"
                "{% if message['role'] == 'user' %}"
                "User: {{ message['content'] }}\n"
                "{% elif message['role'] == 'assistant' %}"
                "Assistant: {{ message['content'] }}\n"
                "{% endif %}"
                "{% endfor %}"
                "{% if add_generation_prompt %}Assistant: {% endif %}"
            )

        from jinja2 import Environment

        # Use Environment with loopcontrols extension to support {% continue %} and {% break %}
        env = Environment(extensions=["jinja2.ext.loopcontrols"])
        template = env.from_string(chat_template)
        rendered = template.render(
            messages=conversation,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )

        if tokenize:
            return self.tokenizer.encode(rendered)
        return rendered

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load processor from pretrained model."""
        import json
        from pathlib import Path

        from huggingface_hub import hf_hub_download
        from transformers import AutoTokenizer

        kwargs.pop("trust_remote_code", None)

        model_path = Path(pretrained_model_name_or_path)
        is_local = model_path.exists() and model_path.is_dir()

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            str(model_path) if is_local else pretrained_model_name_or_path,
            trust_remote_code=True,
            local_files_only=is_local,
        )

        # Load image processor config
        image_processor_config = {}
        try:
            if is_local:
                config_path = model_path / "preprocessor_config.json"
            else:
                config_path = Path(
                    hf_hub_download(
                        pretrained_model_name_or_path, "preprocessor_config.json"
                    )
                )
            if config_path.exists():
                with open(config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                for key in [
                    "size",
                    "max_crops",
                    "overlap_margins",
                    "patch_size",
                    "pooling_size",
                    "image_mean",
                    "image_std",
                    "do_convert_rgb",
                ]:
                    if key in config:
                        image_processor_config[key] = config[key]
        except Exception:
            pass

        image_processor = Molmo2ImageProcessor(**image_processor_config)

        # Load chat template
        chat_template = getattr(tokenizer, "chat_template", None)

        return cls(
            image_processor=image_processor,
            tokenizer=tokenizer,
        )


# Patch AutoProcessor for Molmo2 models
import json
from pathlib import Path

_original_auto_processor_from_pretrained_molmo2 = AutoProcessor.from_pretrained


@classmethod
def _patched_auto_processor_from_pretrained_molmo2(
    cls, pretrained_model_name_or_path, **kwargs
):
    """Patched from_pretrained that returns Molmo2Processor for molmo2 models."""
    from huggingface_hub import hf_hub_download

    model_path = Path(pretrained_model_name_or_path)
    is_local = model_path.exists() and model_path.is_dir()

    # Check if this is a molmo2 model
    is_molmo2 = False
    try:
        if is_local:
            config_path = model_path / "config.json"
        else:
            config_path = Path(
                hf_hub_download(pretrained_model_name_or_path, "config.json")
            )
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        model_type = config.get("model_type", "").lower()
        is_molmo2 = model_type == "molmo2"
    except Exception:
        pass

    if is_molmo2:
        return Molmo2Processor.from_pretrained(pretrained_model_name_or_path, **kwargs)

    return _original_auto_processor_from_pretrained_molmo2.__func__(
        cls, pretrained_model_name_or_path, **kwargs
    )


AutoProcessor.from_pretrained = _patched_auto_processor_from_pretrained_molmo2