fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py

```diff
@@ -0,0 +1,624 @@
+"""
+From https://github.com/deepseek-ai/DeepSeek-VL2
+"""
+
+from dataclasses import dataclass
+from typing import Dict, List, Literal, Optional, Tuple
+
+import mlx.core as mx
+import numpy as np
+from PIL import Image, ImageOps
+from transformers import LlamaTokenizerFast
+from transformers.processing_utils import ProcessorMixin
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
+    return best_ratio
+
+
+def dynamic_preprocess(
+    image, min_num=2, max_num=9, image_size=640, use_thumbnail=False
+):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
+    # print(target_ratios)
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # print(target_aspect_ratio)
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size,
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images, target_aspect_ratio
+
+
+class DictOutput(object):
+    def keys(self):
+        return self.__dict__.keys()
+
+    def __getitem__(self, item):
+        if isinstance(item, int):
+            return list(self.__dict__.values())[item]
+        if item not in self.__dict__:
+            raise KeyError(item)
+        return self.__dict__[item]
+
+    def __setitem__(self, key, value):
+        self.__dict__[key] = value
+
+
+@dataclass
+class VLChatProcessorOutput(DictOutput):
+    sft_format: str
+    input_ids: mx.array
+    target_ids: mx.array
+    images: mx.array
+    images_seq_mask: mx.array
+    images_spatial_crop: mx.array
+    num_image_tokens: List[int]
+
+    def __len__(self):
+        return len(self.input_ids)
+
+
+@dataclass
+class BatchCollateOutput(DictOutput):
+    sft_format: List[str]
+    input_ids: mx.array
+    labels: mx.array
+    images: mx.array
+    attention_mask: mx.array
+    images_seq_mask: mx.array
+    images_spatial_crop: mx.array
+    seq_lens: List[int]
+
+
+class ImageTransform:
+    def __init__(
+        self,
+        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+        normalize: bool = True,
+    ):
+        self.mean = mean
+        self.std = std
+        self.normalize = normalize
+
+    def __call__(self, pil_img: Image.Image):
+        # Convert PIL image to numpy array and normalize
+
+        img = mx.array(np.array(pil_img)) / 255.0
+
+        # Transpose from HWC to CHW format
+        img = mx.transpose(img, [2, 0, 1])
+
+        if self.normalize:
+            mean = mx.array(self.mean).reshape(-1, 1, 1)
+            std = mx.array(self.std).reshape(-1, 1, 1)
+            img = (img - mean) / std
+
+        return img
+
+
+class DeepseekOCR2Processor(ProcessorMixin):
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+    attributes = ["tokenizer"]
+
+    def __init__(
+        self,
+        tokenizer: LlamaTokenizerFast,
+        candidate_resolutions: Tuple[Tuple[int, int]],
+        patch_size: int,
+        downsample_ratio: int,
+        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        normalize: bool = True,
+        image_token: str = "<image>",
+        pad_token: str = "<|▁pad▁|>",
+        add_special_token: bool = False,
+        sft_format: str = "deepseek",
+        mask_prompt: bool = True,
+        ignore_id: int = -100,
+        **kwargs,
+    ):
+        self.candidate_resolutions = candidate_resolutions
+        self.image_size = candidate_resolutions[0][0]
+        self.patch_size = patch_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.normalize = normalize
+        self.downsample_ratio = downsample_ratio
+
+        self.image_transform = ImageTransform(
+            mean=image_mean, std=image_std, normalize=normalize
+        )
+        self.tokenizer = tokenizer
+        self.tokenizer.padding_side = "left"
+
+        # Add special tokens
+        if tokenizer.pad_token is None:
+            self.tokenizer.add_special_tokens({"pad_token": pad_token})
+            print(
+                f"Add pad token = ['{pad_token}'] to the tokenizer\n"
+                f"{pad_token}:{tokenizer.encode(pad_token, add_special_tokens=False)[0]}"
+            )
+
+        image_token_id = self.tokenizer.vocab.get(image_token)
+        if image_token_id is None:
+            special_tokens = [image_token]
+            special_tokens_dict = {"additional_special_tokens": special_tokens}
+            self.tokenizer.add_special_tokens(special_tokens_dict)
+        self.image_token_id = self.tokenizer.vocab.get(image_token)
+        print(
+            f"Add image token = ['{image_token}'] to the tokenizer\n"
+            f"{image_token}:{tokenizer.encode(image_token, add_special_tokens=False)[0]}"
+        )
+
+        # Add grounding-related tokens
+        special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
+        special_tokens_dict = {"additional_special_tokens": special_tokens}
+        self.tokenizer.add_special_tokens(special_tokens_dict)
+        print("Added grounding-related tokens")
+
+        # Add chat tokens
+        special_tokens = ["<|User|>", "<|Assistant|>"]
+        special_tokens_dict = {"additional_special_tokens": special_tokens}
+        self.tokenizer.add_special_tokens(special_tokens_dict)
+        print("Added chat tokens")
+
+        self.image_token = image_token
+        self.pad_token = pad_token
+        self.add_special_token = add_special_token
+        self.sft_format = sft_format
+        self.mask_prompt = mask_prompt
+        self.ignore_id = ignore_id
+
+        super().__init__(tokenizer, **kwargs)
+
+        # Add chat template
+        self.chat_template = kwargs.pop("chat_template", self.default_chat_template)
+
+    @property
+    def default_chat_template(self):
+        return (
+            "{% for message in messages %}"
+            "{% if message['role'] == 'user' %}"
+            "{% elif message['role'] == 'assistant' %}{% endif %}"
+            "{{message['content']}} "
+            "{% endfor %}"
+            "{% if add_generation_prompt %}{% endif %}"
+        )
+
+    @property
+    def bos_id(self):
+        return self.tokenizer.bos_token_id
+
+    @property
+    def eos_id(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def pad_id(self):
+        return self.tokenizer.pad_token_id
+
+    def encode(self, text: str, bos: bool = True, eos: bool = False):
+        t = self.tokenizer.encode(text, add_special_tokens=False)
+
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+
+        return t
+
+    def decode(self, t: List[int], **kwargs) -> str:
+        return self.tokenizer.decode(t, **kwargs)
+
+    def process_one(
+        self,
+        prompt: str = None,
+        images: List[Image.Image] = None,
+        inference_mode: bool = True,
+        base_size: int = 1024,
+        image_size: int = 768,
+        cropping: bool = True,
+        min_patches: int = 1,
+        max_patches: int = 6,
+    ):
+
+        sft_format = prompt
+        (
+            tokenized_str,
+            images_list,
+            images_seq_mask,
+            images_spatial_crop,
+            num_image_tokens,
+        ) = self.tokenize_with_images(
+            sft_format,
+            images,
+            base_size=base_size,
+            image_size=image_size,
+            cropping=cropping,
+            min_patches=min_patches,
+            max_patches=max_patches,
+        )
+
+        masked_tokenized_str = []
+        for token_index in tokenized_str:
+            if token_index != self.image_token_id:
+                masked_tokenized_str.append(token_index)
+            else:
+                masked_tokenized_str.append(self.ignore_id)
+
+        input_ids = mx.array(tokenized_str)
+        target_ids = mx.array(masked_tokenized_str)
+        images_seq_mask = mx.array(images_seq_mask)
+
+        # Set ignored indices
+        target_ids = mx.where(
+            (input_ids < 0) | (input_ids == self.image_token_id),
+            self.ignore_id,
+            target_ids,
+        )
+        input_ids = mx.where(input_ids < 0, self.pad_id, input_ids)
+
+        if inference_mode:
+            input_ids = input_ids[:-1]
+            target_ids = target_ids[:-1]
+            images_seq_mask = images_seq_mask[:-1]
+
+        return {
+            "input_ids": input_ids[None, :],
+            "attention_mask": input_ids != self.pad_id,
+            "labels": target_ids,
+            "images": images_list,
+            "images_seq_mask": images_seq_mask[None, ...],
+            "images_spatial_crop": images_spatial_crop,
+            "num_image_tokens": num_image_tokens,
+        }
+
+    def pad_sequence(self, sequences, padding_value):
+        # Get max length of sequences
+        max_len = max(len(seq) for seq in sequences)
+
+        # Pad each sequence to max length
+        padded_seqs = []
+        for seq in sequences:
+            pad_length = max_len - len(seq)
+            if pad_length > 0:
+                padding = mx.full((pad_length,), padding_value)
+                padded_seq = mx.concatenate([seq, padding])
+            else:
+                padded_seq = seq
+            padded_seqs.append(padded_seq)
+
+        return mx.stack(padded_seqs)
+
+    def tokenize_with_images(
+        self,
+        conversation: str,
+        images: List[Image.Image],
+        base_size: int = 1024,
+        image_size: int = 768,
+        cropping: bool = True,
+        min_patches: int = 1,
+        max_patches: int = 6,
+    ):
+        """Tokenize text with <image> tags.
+
+        For DeepSeek-OCR-2 with Qwen2 encoder:
+        - Global view (1024x1024): 256 tokens from Qwen2 encoder
+        - Local patches (768x768): 144 tokens each from Qwen2 encoder
+        - Plus 1 view_separator token
+
+        Dynamic resolution:
+        - Total tokens = (num_patches * 144) + 256 + 1
+        - Default: 0-6 patches at 768x768 + 1 global at 1024x1024
+        """
+        # Token counts for Qwen2 encoder
+        TOKENS_PER_PATCH = 144  # 12x12 SAM features for 768x768
+        TOKENS_PER_GLOBAL = 256  # 16x16 SAM features for 1024x1024
+        TOKENS_VIEW_SEP = 1
+
+        assert conversation.count(self.image_token) == len(
+            images
+        ), f"The number of image tokens in the prompt does not match the number of images: {conversation.count(self.image_token)} != {len(images)}"
+
+        text_splits = conversation.split(self.image_token)
+
+        all_patches_list = []
+        all_global_list = []
+        images_seq_mask = []
+        tokenized_str = []
+        images_spatial_crop = []
+        num_image_tokens_list = []
+
+        for text_sep, image in zip(text_splits, images):
+            # Tokenize the text before this image
+            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
+            tokenized_str += tokenized_sep
+            images_seq_mask += [False] * len(tokenized_sep)
+
+            # Process global view: pad to base_size x base_size (1024x1024)
+            global_view = ImageOps.pad(
+                image,
+                (base_size, base_size),
+                color=tuple(int(x * 255) for x in self.image_transform.mean),
+            )
+            global_tensor = self.image_transform(global_view).astype(mx.bfloat16)
+            all_global_list.append(global_tensor)
+
+            # Process local patches using dynamic resolution
+            if cropping and min_patches > 0:
+                # Use dynamic_preprocess to split image into patches
+                patches, (rows, cols) = dynamic_preprocess(
+                    image,
+                    min_num=min_patches,
+                    max_num=max_patches,
+                    image_size=image_size,  # 768x768 patches
+                    use_thumbnail=False,
+                )
+                num_patches = len(patches)
+
+                # Transform each patch
+                patch_tensors = []
+                for patch in patches:
+                    patch_tensor = self.image_transform(patch).astype(mx.bfloat16)
+                    patch_tensors.append(patch_tensor)
+
+                if patch_tensors:
+                    patches_stacked = mx.stack(patch_tensors, axis=0)
+                    all_patches_list.append(patches_stacked)
+
+                images_spatial_crop.append([rows, cols])
+            else:
+                # No patches, only global view
+                num_patches = 0
+                images_spatial_crop.append([0, 0])
+
+            # Calculate number of image tokens for this image
+            # Order: [local_patches, global_view, view_separator]
+            num_image_tokens = (
+                (num_patches * TOKENS_PER_PATCH) + TOKENS_PER_GLOBAL + TOKENS_VIEW_SEP
+            )
+            num_image_tokens_list.append(num_image_tokens)
+
+            # Add image tokens to sequence
+            tokenized_image = [self.image_token_id] * num_image_tokens
+            tokenized_str += tokenized_image
+            images_seq_mask += [True] * len(tokenized_image)
+
+        # Tokenize the text after the last image
+        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
+        tokenized_str += tokenized_sep
+        images_seq_mask += [False] * len(tokenized_sep)
+
+        # Add the bos token
+        bos_id = 0
+        tokenized_str = [bos_id] + tokenized_str
+        images_seq_mask = [False] + images_seq_mask
+
+        images_seq_mask = mx.array(images_seq_mask)
+
+        # Stack global images
+        if len(all_global_list) == 0:
+            images_ori = mx.zeros((1, 3, base_size, base_size))
+            images_spatial_crop = mx.zeros((1, 2))
+        else:
+            images_ori = mx.stack(all_global_list, axis=0)
+            images_spatial_crop = mx.array(images_spatial_crop)
+
+        # Stack patches (or zeros if no patches)
+        if all_patches_list:
+            # Concatenate all patches from all images
+            images_patches = mx.concatenate(all_patches_list, axis=0)
+        else:
+            images_patches = mx.zeros((1, 3, image_size, image_size))
+
+        assert len(tokenized_str) == len(
+            images_seq_mask
+        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to images_seq_mask's length {len(images_seq_mask)}"
+
+        return (
+            tokenized_str,
+            [images_patches, images_ori],
+            images_seq_mask,
+            images_spatial_crop,
+            num_image_tokens_list[0] if num_image_tokens_list else 257,
+        )
+
+    def __call__(
+        self,
+        *,
+        text: str = None,
+        images: List[Image.Image] = None,
+        inference_mode: bool = True,
+        image_size: int = 768,
+        base_size: int = 1024,
+        cropping: bool = True,
+        min_patches: int = 1,
+        max_patches: int = 6,
+        padding: bool = True,
+        return_tensors: Literal["np", "mx", "pt"] = "mx",
+        **kwargs,
+    ):
+        """Process text and images for DeepSeek-OCR-2.
+
+        Args:
+            text (str or List[str]): the formatted prompt(s)
+            images (List[ImageType]): the list of images (one per prompt for batched inputs)
+            inference_mode (bool): if True, remove the last eos token
+            image_size (int): size of local patches (default 768)
+            base_size (int): size of global view (default 1024)
+            cropping (bool): whether to use dynamic resolution with local patches
+            min_patches (int): minimum number of patches (default 1)
+            max_patches (int): maximum number of patches (default 6)
+
+        Returns:
+            outputs (dict): the output of the processor,
+                - input_ids (mx.array): [batch_size, N + image tokens]
+                - images (List[mx.array]): [patches, global_images]
+                - images_seq_mask (mx.array): mask for image token positions
+                - images_spatial_crop (mx.array): patch grid dimensions
+        """
+
+        # Handle batched inputs (list of prompts with list of images)
+        if isinstance(text, list):
+            if images is None:
+                images = [None] * len(text)
+
+            batch_results = []
+            for i, prompt in enumerate(text):
+                # Each prompt has one image
+                img = [images[i]] if images[i] is not None else None
+                result = self.process_one(
+                    prompt=prompt,
+                    images=img,
+                    inference_mode=inference_mode,
+                    image_size=image_size,
+                    base_size=base_size,
+                    cropping=cropping,
+                    min_patches=min_patches,
+                    max_patches=max_patches,
+                )
+                batch_results.append(result)
+
+            # Collate batch results
+            return self._collate_batch(batch_results, padding=padding)
+
+        # Single input case
+        prepare = self.process_one(
+            prompt=text,
+            images=images,
+            inference_mode=inference_mode,
+            image_size=image_size,
+            base_size=base_size,
+            cropping=cropping,
+            min_patches=min_patches,
+            max_patches=max_patches,
+        )
+
+        return prepare
+
+    def _collate_batch(self, batch_results: List[Dict], padding: bool = True) -> Dict:
+        """Collate multiple processed results into a batch."""
+        if not batch_results:
+            return {}
+
+        batch_size = len(batch_results)
+
+        # Get max sequence length for padding
+        max_seq_len = max(r["input_ids"].shape[1] for r in batch_results)
+
+        # Pad and stack input_ids
+        padded_input_ids = []
+        padded_images_seq_mask = []
+        for r in batch_results:
+            seq_len = r["input_ids"].shape[1]
+            pad_len = max_seq_len - seq_len
+
+            if pad_len > 0:
+                # Pad input_ids on the left
+                input_ids = mx.concatenate(
+                    [
+                        mx.full((1, pad_len), self.pad_id, dtype=r["input_ids"].dtype),
+                        r["input_ids"],
+                    ],
+                    axis=1,
+                )
+                # Pad images_seq_mask on the left with False
+                seq_mask = mx.concatenate(
+                    [mx.zeros((1, pad_len), dtype=mx.bool_), r["images_seq_mask"]],
+                    axis=1,
+                )
+            else:
+                input_ids = r["input_ids"]
+                seq_mask = r["images_seq_mask"]
+
+            padded_input_ids.append(input_ids)
+            padded_images_seq_mask.append(seq_mask)
+
+        # Stack into batch
+        input_ids = mx.concatenate(padded_input_ids, axis=0)
+        images_seq_mask = mx.concatenate(padded_images_seq_mask, axis=0)
+
+        # Combine images: [patches, global_images]
+        all_patches = []
+        all_global_images = []
+        all_spatial_crops = []
+
+        for r in batch_results:
+            patches, global_img = r["images"]
+            # Only add non-zero patches
+            if mx.sum(patches).item() != 0:
+                all_patches.append(patches)
+            all_global_images.append(global_img)
+            all_spatial_crops.append(r["images_spatial_crop"])
+
+        # Stack patches and global images
+        if all_patches:
+            combined_patches = mx.concatenate(all_patches, axis=0)
+        else:
+            combined_patches = mx.zeros((1, 3, 1024, 1024))
+
+        combined_global_images = mx.concatenate(all_global_images, axis=0)
+        combined_spatial_crops = mx.concatenate(all_spatial_crops, axis=0)
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": input_ids != self.pad_id,
+            "images": [combined_patches, combined_global_images],
+            "images_seq_mask": images_seq_mask,
+            "images_spatial_crop": combined_spatial_crops,
+        }
+
+
+# Install a composable AutoProcessor patch for DeepSeek-OCR-2
+from ..base import install_auto_processor_patch
+
+install_auto_processor_patch("deepseekocr_2", DeepseekOCR2Processor)
```
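For orientation, here is a minimal usage sketch of the `DeepseekOCR2Processor` added by this diff. It is not part of the package contents above: the tokenizer path, image file, and the `patch_size`/`downsample_ratio` values are placeholders (the real values ship with the model config), and the prompt simply uses the `<image>`, `<|User|>`, and `<|Assistant|>` tokens the processor registers. The per-image token budget follows the formula documented in `tokenize_with_images`: `num_patches * 144 + 256 + 1`.

```python
# Hypothetical usage sketch (not part of the diff). Assumes a compatible
# LlamaTokenizerFast checkpoint and a local test image are available.
from PIL import Image
from transformers import AutoTokenizer

from mlx_vlm.models.deepseekocr_2.processing_deepseekocr import DeepseekOCR2Processor

tokenizer = AutoTokenizer.from_pretrained("path/to/deepseek-ocr-2")  # placeholder path
processor = DeepseekOCR2Processor(
    tokenizer=tokenizer,
    candidate_resolutions=((1024, 1024),),  # first entry sets processor.image_size
    patch_size=16,          # illustrative; real values come from the model config
    downsample_ratio=4,     # illustrative
)

prompt = "<|User|>: <image>\nFree OCR.<|Assistant|>:"
image = Image.open("page.png").convert("RGB")

out = processor(text=prompt, images=[image])
print(out["input_ids"].shape)       # (1, text tokens + image tokens)
print(out["images_spatial_crop"])   # patch grid per image, e.g. [[rows, cols]]

# Per-image token budget, as documented in tokenize_with_images:
#   num_patches * 144 + 256 + 1   (e.g. 2 local patches -> 545 image tokens)
```

Batched use is the same call with a list of prompts and one image per prompt; `_collate_batch` then left-pads `input_ids` and `images_seq_mask` to the longest sequence in the batch.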
|