fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
|
|
6
|
+
from ..base import InputEmbeddingsFeatures, install_auto_processor_patch
|
|
7
|
+
from .config import ModelConfig
|
|
8
|
+
from .language import LanguageModel
|
|
9
|
+
from .processing_paddleocr_vl import PaddleOCRVLProcessor
|
|
10
|
+
from .vision import VisionModel
|
|
11
|
+
|
|
12
|
+
# Runs at import time: hands PaddleOCRVLProcessor to the shared
# install_auto_processor_patch helper for the "paddleocr_vl" model type
# (presumably so auto-processor loading resolves to this class — see ..base).
install_auto_processor_patch("paddleocr_vl", PaddleOCRVLProcessor)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Model(nn.Module):
    """PaddleOCR-VL: a vision encoder whose features are spliced into a language model.

    ``get_input_embeddings`` embeds the prompt tokens, runs ``self.visual`` on
    the image patches and overwrites the embeddings at
    ``config.image_token_id`` positions with the resulting features.
    ``__call__`` then runs ``self.language_model`` on those embeddings.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.visual = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config, config)

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):
        """Embed ``input_ids`` and splice in vision features when images are given.

        Args:
            input_ids: Token IDs, unpacked as ``(batch_size, seq_len)`` by the
                merge step.
            pixel_values: Image patches from the processor, or ``None`` for a
                text-only prompt. Cast to the patch-embedding weight dtype.
            **kwargs: ``image_grid_thw``, ``video_grid_thw`` and ``mask`` are
                read from here (image grid takes precedence over video grid).

        Returns:
            ``InputEmbeddingsFeatures`` whose ``inputs_embeds`` holds the
            merged embeddings.

        Side effects:
            With ``pixel_values=None`` the language model's cached
            ``_position_ids`` / ``_rope_deltas`` are reset to ``None``. With a
            grid available they are set from ``language_model.get_rope_index``
            so that later (chunked) prefill steps can reuse them.
        """
        image_grid_thw = kwargs.pop("image_grid_thw", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)
        mask = kwargs.pop("mask", None)
        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw

        if pixel_values is None:
            # Reset position state for text-only generation
            self.language_model._position_ids = None
            self.language_model._rope_deltas = None
            return InputEmbeddingsFeatures(
                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
            )

        dtype = self.visual.embeddings.patch_embedding.weight.dtype
        pixel_values = mx.array(pixel_values, dtype=dtype)

        # Text embeddings from the language model's token table
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        # Output hidden states of the vision tower
        hidden_states = self.visual(pixel_values, grid_thw, output_hidden_states=False)

        # Overwrite image-token positions with the vision features
        final_inputs_embeds = self.merge_input_ids_with_image_features(
            self.config.image_token_id,
            hidden_states,
            inputs_embeds,
            input_ids,
        )

        # Pre-calculate position_ids for chunked prefill
        if grid_thw is not None:
            position_ids, rope_deltas = self.language_model.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, mask
            )
            self.language_model._position_ids = position_ids
            self.language_model._rope_deltas = rope_deltas

        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)

    @staticmethod
    def merge_input_ids_with_image_features(
        image_token_id,
        image_features,
        inputs_embeds,
        input_ids,
    ):
        """Merge image features into input embeddings at image token positions.

        Features are consumed sequentially: batch item 0 takes the first
        ``k0`` rows of ``image_features`` (``k0`` = its image-token count),
        item 1 the next ``k1`` rows, and so on.

        Args:
            image_token_id: Token ID marking image placeholder positions.
            image_features: Vision features from the vision tower [num_features, hidden_dim]
            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
            input_ids: Input token IDs [batch_size, seq_len]

        Returns:
            Updated input embeddings with image features inserted, stacked
            back to [batch_size, seq_len, hidden_dim].

        Raises:
            ValueError: If fewer feature rows remain than a batch item has
                image tokens.
        """
        # Positions of <image> tokens in input_ids
        image_positions = input_ids == image_token_id

        batch_size, _ = input_ids.shape

        batch_outputs = []
        feature_start_idx = 0

        for batch_idx in range(batch_size):
            image_mask = image_positions[batch_idx]
            num_positions = mx.sum(image_mask).item()

            if num_positions == 0:
                # No image tokens in this batch item
                batch_outputs.append(inputs_embeds[batch_idx])
                continue

            batch_features = image_features[
                feature_start_idx : feature_start_idx + num_positions
            ]

            # A short slice means the prompt has more image tokens than features
            if batch_features.shape[0] != num_positions:
                raise ValueError(
                    f"Number of image token positions ({num_positions}) does not match "
                    f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
                )

            # The k-th image token (0-based) takes feature row k; non-image
            # positions gather row 0, which the mask below discards.
            cumsum = mx.cumsum(image_mask.astype(mx.int32))
            feature_indices = mx.where(image_mask, cumsum - 1, 0)
            gathered_features = batch_features[feature_indices]

            image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
            batch_outputs.append(
                mx.where(
                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
                )
            )

            feature_start_idx += num_positions

        return mx.stack(batch_outputs, axis=0)

    @property
    def layers(self):
        """Decoder layers of the underlying language model."""
        return self.language_model.model.layers

    def __call__(
        self,
        input_ids: mx.array,
        pixel_values: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
        **kwargs,
    ):
        """Run a forward pass and return the language model's output.

        ``pixel_values`` and any extra ``kwargs`` (e.g. ``image_grid_thw``) are
        forwarded to both ``get_input_embeddings`` and ``self.language_model``.

        NOTE(review): ``mask`` is passed to the language model but not to
        ``get_input_embeddings``, so ``get_rope_index`` sees ``mask=None`` on
        this path — confirm that is intended.
        """
        input_embeddings_features = self.get_input_embeddings(
            input_ids, pixel_values, **kwargs
        )
        kwargs = {
            "pixel_values": pixel_values,
            **kwargs,
        }
        logits = self.language_model(
            input_ids,
            input_embeddings_features.inputs_embeds,
            mask=mask,
            cache=cache,
            **kwargs,
        )
        return logits

    def sanitize(self, weights):
        """Rename checkpoint weights to this module's parameter layout.

        - Drops ``packing_position_embedding`` and ``vision_model.head`` weights.
        - Fuses each vision ``q_proj``/``k_proj``/``v_proj`` triple into one
          ``qkv`` weight (concatenated along axis 0 in q, k, v order). If either
          partner is missing, the q weight is dropped along with it.
        - Maps ``visual.vision_model.*`` -> ``visual.*``, ``mlp_AR`` ->
          ``visual.projector``, ``model.*`` -> ``language_model.model.*`` and
          ``lm_head`` -> ``language_model.lm_head``.

        Keys already starting with ``language_model.`` are left unchanged, so
        running this on weights that were already sanitized is a no-op for
        them instead of prefixing them a second time.

        Args:
            weights: Mapping of parameter name to array.

        Returns:
            A new dict with renamed (and fused) weights.
        """
        ignored_substrings = (
            "packing_position_embedding",
            "vision_model.head",
        )

        def transform_key(key):
            # Already in MLX layout: the "model"/"lm_head" rules below would
            # otherwise match "language_model" and double the prefix.
            if key.startswith("language_model."):
                return key
            if "visual.vision_model" in key:
                if "embeddings" in key or "post_layernorm" in key:
                    key = key.replace("visual.vision_model", "visual")
                elif "encoder" in key:
                    key = key.replace("visual.vision_model.encoder", "visual")
            elif "mlp_AR" in key:
                key = key.replace("mlp_AR", "visual.projector")
            elif "model" in key:
                key = key.replace("model", "language_model.model")
            elif "lm_head" in key:
                key = key.replace("lm_head", "language_model.lm_head")
            return key

        new_weights = {}
        for k, v in weights.items():
            is_visual = "visual" in k
            if any(name in k for name in ignored_substrings):
                continue
            if is_visual and ("k_proj" in k or "v_proj" in k):
                # Folded into the fused qkv weight when its q_proj is visited.
                continue
            if is_visual and "q_proj" in k:
                k_proj = weights.get(k.replace("q_proj", "k_proj"), None)
                v_proj = weights.get(k.replace("q_proj", "v_proj"), None)
                if k_proj is not None and v_proj is not None:
                    merged_tensor = mx.concatenate([v, k_proj, v_proj], axis=0)
                    merged_key = transform_key(k).replace("q_proj", "qkv")
                    new_weights[merged_key] = merged_tensor
                continue
            new_weights[transform_key(k)] = v

        return new_weights
|
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import math
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import List, Optional, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from transformers import AutoTokenizer
|
|
8
|
+
from transformers.feature_extraction_utils import BatchFeature
|
|
9
|
+
from transformers.image_processing_utils import BaseImageProcessor
|
|
10
|
+
from transformers.image_transforms import convert_to_rgb
|
|
11
|
+
from transformers.image_utils import (
|
|
12
|
+
ImageInput,
|
|
13
|
+
PILImageResampling,
|
|
14
|
+
make_flat_list_of_images,
|
|
15
|
+
to_numpy_array,
|
|
16
|
+
valid_images,
|
|
17
|
+
)
|
|
18
|
+
from transformers.processing_utils import ProcessorMixin
|
|
19
|
+
from transformers.utils import logging
|
|
20
|
+
|
|
21
|
+
# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
):
    """Choose a resize target whose sides are multiples of ``factor``.

    Sides shorter than ``factor`` are first lifted to ``factor`` (scaling the
    other side to keep the aspect ratio). Both sides are then rounded to the
    nearest multiple of ``factor``. If the rounded area is above
    ``max_pixels`` or below ``min_pixels``, the size is rescaled toward that
    range while keeping the aspect ratio approximately. Each returned side is
    at least ``factor``. When a side has to be clamped up to ``factor``, the
    area can end up above ``max_pixels``.

    Args:
        height: Source height in pixels.
        width: Source width in pixels.
        factor: Granularity of the output sides (the caller passes
            ``patch_size * merge_size``).
        min_pixels: Target lower bound on the output area.
        max_pixels: Target upper bound on the output area.

    Returns:
        ``(resized_height, resized_width)``.

    Raises:
        ValueError: If ``height`` or ``width`` is not positive, or if the
            aspect ratio after lifting short sides exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got height={height}, width={width}"
        )

    if height < factor:
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        height = round((height * factor) / width)
        width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Clamp to one factor so a very thin image never collapses to a
        # zero-sized side (which the resize step cannot handle).
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ImageProcessor(BaseImageProcessor):
    """
    MLX-native image processor for PaddleOCRVL that doesn't require torch.

    Each image is independently resized with ``smart_resize`` (sides become
    multiples of ``patch_size * merge_size``), rescaled, normalized and cut
    into ``patch_size x patch_size`` patches. ``preprocess`` returns
    ``pixel_values`` (every image's patches concatenated, wrapped in a leading
    axis of size 1) and ``image_grid_thw`` (one ``(t, h, w)`` row per image,
    measured in patches).
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: dict[str, int] | None = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: int | float = 1 / 255,
        do_normalize: bool = True,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 147384,
        max_pixels: int = 2822400,
        patch_size: int = 14,
        temporal_patch_size: int = 1,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if size is not None:
            if "shortest_edge" not in size or "longest_edge" not in size:
                raise ValueError(
                    "size must contain 'shortest_edge' and 'longest_edge' keys."
                )
            # Copy so the overrides below never mutate the caller's dict.
            size = dict(size)
        else:
            size = {"shortest_edge": 147384, "longest_edge": 2822400}
        # backward compatibility: override size with min_pixels and max_pixels if they are provided
        if min_pixels is not None:
            size["shortest_edge"] = min_pixels
        if max_pixels is not None:
            size["longest_edge"] = max_pixels
        self.min_pixels = size["shortest_edge"]
        self.max_pixels = size["longest_edge"]
        self.size = size
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.do_convert_rgb = do_convert_rgb

    def _preprocess_image(
        self,
        image,
        *,
        do_resize,
        resample,
        do_rescale,
        rescale_factor,
        do_normalize,
        image_mean,
        image_std,
        patch_size,
        temporal_patch_size,
        merge_size,
        min_pixels,
        max_pixels,
    ):
        """Resize, rescale and normalize one image, then cut it into patches.

        Args:
            image: An image exposing ``.size`` as ``(width, height)`` and
                ``.resize((w, h), resample)`` (a PIL image, as produced by
                ``convert_to_rgb``).

        Returns:
            ``(flatten_patches, (grid_t, grid_h, grid_w))`` where
            ``flatten_patches`` has shape
            ``(grid_t * grid_h * grid_w, channels, patch_size, patch_size)``.

        Raises:
            ValueError: If ``temporal_patch_size`` is not 1.
        """
        if temporal_patch_size != 1:
            raise ValueError(
                f"temporal_patch_size must be 1!, but got {temporal_patch_size}!"
            )

        width, height = image.size
        if do_resize:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=patch_size * merge_size,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            image = image.resize((resized_width, resized_height), resample)
        else:
            resized_height, resized_width = height, width

        img_array = to_numpy_array(image)

        if do_rescale:
            img_array = img_array * rescale_factor

        if do_normalize:
            # (1, 1, -1) broadcasts both per-channel lists and scalar values.
            mean = np.reshape(np.asarray(image_mean), (1, 1, -1))
            std = np.reshape(np.asarray(image_std), (1, 1, -1))
            img_array = (img_array - mean) / std

        # Leading frame axis, then channels-first: (1, C, H, W).
        patches = img_array[np.newaxis]
        if patches.shape[1] > 3:
            patches = patches.transpose(0, 3, 1, 2)

        channel = patches.shape[1]
        grid_t = patches.shape[0] // temporal_patch_size
        grid_h = resized_height // patch_size
        grid_w = resized_width // patch_size
        patches = patches.reshape(
            grid_t,
            temporal_patch_size,
            channel,
            grid_h,
            patch_size,
            grid_w,
            patch_size,
        )
        # -> (grid_t, grid_h, grid_w, channel, temporal, patch, patch)
        patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel, patch_size, patch_size
        )
        return flatten_patches, (grid_t, grid_h, grid_w)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        resample: Optional[PILImageResampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        patch_size: Optional[int] = None,
        temporal_patch_size: Optional[int] = None,
        merge_size: Optional[int] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Convert images into patch tensors plus per-image grid sizes.

        Every ``None`` argument falls back to the value stored on the
        instance. An explicit ``size`` sets the pixel budget
        (``shortest_edge`` -> min pixels, ``longest_edge`` -> max pixels) and
        takes precedence over ``min_pixels``/``max_pixels``.

        Returns:
            ``BatchFeature`` with ``pixel_values`` (shape
            ``(1, total_patches, C, patch_size, patch_size)``) and
            ``image_grid_thw`` (shape ``(num_images, 3)``), or an empty
            ``BatchFeature`` when ``images`` is ``None``.

        Raises:
            ValueError: On a malformed ``size``, invalid image types, or an
                unsupported ``temporal_patch_size``.
        """
        min_pixels = min_pixels if min_pixels is not None else self.min_pixels
        max_pixels = max_pixels if max_pixels is not None else self.max_pixels

        if size is not None:
            if "shortest_edge" not in size or "longest_edge" not in size:
                raise ValueError(
                    "size must contain 'shortest_edge' and 'longest_edge' keys."
                )
            min_pixels = size["shortest_edge"]
            max_pixels = size["longest_edge"]

        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        patch_size = patch_size if patch_size is not None else self.patch_size
        temporal_patch_size = (
            temporal_patch_size
            if temporal_patch_size is not None
            else self.temporal_patch_size
        )
        merge_size = merge_size if merge_size is not None else self.merge_size
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        data = {}
        if images is None:
            return BatchFeature(data, tensor_type=return_tensors)

        images = make_flat_list_of_images(images)
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # Patchify each image on its own so images of different sizes work and
        # every image gets its own grid row (the processor expands image
        # tokens per row).
        pixel_values, vision_grid_thws = [], []
        for image in images:
            flatten_patches, image_grid_thw = self._preprocess_image(
                image,
                do_resize=do_resize,
                resample=resample,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                patch_size=patch_size,
                temporal_patch_size=temporal_patch_size,
                merge_size=merge_size,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            pixel_values.extend(flatten_patches)
            vision_grid_thws.append(image_grid_thw)

        data["pixel_values"] = np.array([pixel_values])
        data["image_grid_thw"] = np.array(vision_grid_thws)

        return BatchFeature(data, tensor_type=return_tensors)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class PaddleOCRVLProcessor(ProcessorMixin):
|
|
241
|
+
attributes = ["image_processor", "tokenizer"]
|
|
242
|
+
valid_kwargs = ["chat_template"]
|
|
243
|
+
image_processor_class = "AutoImageProcessor"
|
|
244
|
+
tokenizer_class = "AutoTokenizer"
|
|
245
|
+
|
|
246
|
+
def __init__(
|
|
247
|
+
self,
|
|
248
|
+
image_processor=None,
|
|
249
|
+
tokenizer=None,
|
|
250
|
+
chat_template=None,
|
|
251
|
+
**kwargs,
|
|
252
|
+
):
|
|
253
|
+
|
|
254
|
+
if image_processor is None:
|
|
255
|
+
image_processor = ImageProcessor(**kwargs)
|
|
256
|
+
|
|
257
|
+
self.tokenizer = tokenizer
|
|
258
|
+
self.image_token = (
|
|
259
|
+
"<|IMAGE_PLACEHOLDER|>"
|
|
260
|
+
if not hasattr(tokenizer, "image_token")
|
|
261
|
+
else tokenizer.image_token
|
|
262
|
+
)
|
|
263
|
+
self.image_processor = image_processor
|
|
264
|
+
|
|
265
|
+
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
|
266
|
+
|
|
267
|
+
def __call__(
    self,
    images=None,
    text: Union[str, List[str]] = None,
    **kwargs,
) -> BatchFeature:
    """Tokenize text and preprocess images into a single model input batch.

    When ``images`` is given, every occurrence of ``self.image_token`` in the
    prompts is expanded into
    ``image_grid_thw[k].prod() // merge_size // merge_size`` copies of the
    image token. ``k`` advances across all prompts in order, so the k-th image
    token in the batch is sized by the k-th image.

    Args:
        images: Image input(s), passed unchanged to ``self.image_processor``.
        text: A single prompt or a list of prompts. ``None`` is treated as one
            empty prompt. The caller's list is never modified.
        **kwargs: Extra keyword arguments forwarded to the tokenizer.
            ``return_tensors`` is consumed here and used as the tensor type of
            the returned BatchFeature. ``add_special_tokens`` defaults to
            False but may be overridden.

    Returns:
        BatchFeature merging the tokenizer outputs with the image processor
        outputs (which include ``image_grid_thw`` when images are given).
    """
    image_inputs = {}

    if images is not None:
        image_inputs = self.image_processor(images=images)
        image_grid_thw = image_inputs["image_grid_thw"]

    if text is None:
        text = [""]
    elif not isinstance(text, list):
        text = [text]

    # Work on a copy so the caller's list is not modified by the expansion below.
    text = list(text)

    if images is not None:
        merge_size = self.image_processor.merge_size
        index = 0
        for i, prompt in enumerate(text):
            while self.image_token in prompt:
                # NOTE(review): more image tokens than images surfaces as an
                # IndexError from indexing image_grid_thw here.
                num_image_tokens = (
                    image_grid_thw[index].prod() // merge_size // merge_size
                )
                # Expand via a temporary marker so the while-condition only
                # sees image tokens that have not been expanded yet.
                prompt = prompt.replace(
                    self.image_token, "<|placeholder|>" * num_image_tokens, 1
                )
                index += 1
            text[i] = prompt.replace("<|placeholder|>", self.image_token)

    # Pop return_tensors so the final BatchFeature performs the conversion.
    return_tensors = kwargs.pop("return_tensors", None)

    # Default to no special tokens (the chat template supplies them). Use
    # setdefault so an explicit caller value no longer causes a duplicate
    # keyword TypeError.
    kwargs.setdefault("add_special_tokens", False)
    text_inputs = self.tokenizer(text, **kwargs)

    return BatchFeature(
        data={**text_inputs, **image_inputs},
        tensor_type=return_tensors,
    )
|
|
337
|
+
|
|
338
|
+
def batch_decode(self, *args, **kwargs):
    """Turn batches of token IDs back into strings.

    This is a pure pass-through. All positional and keyword arguments go
    unchanged to ``self.tokenizer.batch_decode``, and its result is returned
    as-is.
    """
    delegate = self.tokenizer.batch_decode
    return delegate(*args, **kwargs)
|
|
341
|
+
|
|
342
|
+
def decode(self, *args, **kwargs):
    """Turn a sequence of token IDs back into a string.

    This is a pure pass-through. All positional and keyword arguments go
    unchanged to ``self.tokenizer.decode``, and its result is returned as-is.
    """
    delegate = self.tokenizer.decode
    return delegate(*args, **kwargs)
|
|
345
|
+
|
|
346
|
+
def apply_chat_template(self, *args, **kwargs):
    """Render a conversation with the tokenizer's chat template.

    This is a pure pass-through. All positional and keyword arguments go
    unchanged to ``self.tokenizer.apply_chat_template``, and its result is
    returned as-is.
    """
    delegate = self.tokenizer.apply_chat_template
    return delegate(*args, **kwargs)
|
|
349
|
+
|
|
350
|
+
@property
def model_input_names(self):
    """Return the input names expected by the model.

    The tokenizer's names come first, followed by the image processor's.
    Duplicates are removed, keeping the first occurrence. Either source may
    be any iterable (list or tuple). A missing (``None``) tokenizer
    contributes nothing.
    """
    # Compare against None explicitly: tokenizers define __len__, so a
    # truthiness test would wrongly treat an empty-vocab tokenizer as absent.
    tokenizer_input_names = (
        list(self.tokenizer.model_input_names)
        if self.tokenizer is not None
        else []
    )
    # Coerce to list so a tuple here does not make the concatenation fail.
    image_processor_input_names = list(self.image_processor.model_input_names)
    # dict.fromkeys keeps insertion order, giving an order-preserving de-dup.
    return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
|
358
|
+
|
|
359
|
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """Build the processor from a local directory or a Hugging Face Hub repo id.

    The tokenizer is loaded with ``AutoTokenizer.from_pretrained``. A fixed
    set of image-processor settings is read from ``preprocessor_config.json``
    when available. If that file cannot be fetched or parsed, a warning is
    emitted and ``ImageProcessor`` defaults are used instead.

    Args:
        pretrained_model_name_or_path: Local model directory or Hub repo id.
        **kwargs: ``trust_remote_code`` (default True) and
            ``local_files_only`` (default False; always True for local
            directories) are consumed here. All remaining keyword arguments
            are forwarded to ``AutoTokenizer.from_pretrained`` and to the
            processor constructor.

    Returns:
        A new processor instance.
    """
    import warnings

    from huggingface_hub import hf_hub_download

    # NOTE(review): defaulting to True lets the tokenizer load execute code
    # shipped with the model repo; only load repositories you trust.
    trust_remote_code = kwargs.pop("trust_remote_code", True)

    model_path = Path(pretrained_model_name_or_path)
    is_local = model_path.exists() and model_path.is_dir()

    # Pop a caller-supplied flag instead of letting it collide with the
    # explicit keyword below (previously a duplicate-keyword TypeError).
    # The flag is also honoured for the hub download of the preprocessor config.
    local_files_only = bool(kwargs.pop("local_files_only", False)) or is_local

    # Suppress warning about mrope_section in rope_parameters
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message="Unrecognized keys in `rope_parameters`"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            str(model_path) if is_local else pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            local_files_only=local_files_only,
            **kwargs,
        )

    # Only these preprocessor_config.json keys are passed to ImageProcessor.
    relevant_keys = (
        "min_pixels",
        "max_pixels",
        "patch_size",
        "temporal_patch_size",
        "merge_size",
        "image_mean",
        "image_std",
        "do_resize",
        "do_rescale",
        "do_normalize",
        "do_convert_rgb",
    )

    image_processor_config = {}
    try:
        if is_local:
            config_path = model_path / "preprocessor_config.json"
        else:
            config_path = Path(
                hf_hub_download(
                    pretrained_model_name_or_path,
                    "preprocessor_config.json",
                    local_files_only=local_files_only,
                )
            )
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                preprocessor_config = json.load(f)
            image_processor_config = {
                key: preprocessor_config[key]
                for key in relevant_keys
                if key in preprocessor_config
            }
    except Exception as err:
        # Loading the config is best-effort: fall back to defaults, but say so
        # instead of silently discarding the failure.
        warnings.warn(
            f"Could not load preprocessor_config.json for "
            f"{pretrained_model_name_or_path!r} ({err!r}); "
            "using default ImageProcessor settings.",
            stacklevel=2,
        )

    image_processor = ImageProcessor(**image_processor_config)
    return cls(image_processor=image_processor, tokenizer=tokenizer, **kwargs)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
# Public names exported by ``from <module> import *``.
__all__ = ["PaddleOCRVLProcessor", "ImageProcessor"]
|