fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
|
@@ -0,0 +1,686 @@
|
|
|
1
|
+
"""Image processor and Processor for ERNIE 4.5 VL MoE."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import os
|
|
5
|
+
from shutil import copyfile
|
|
6
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
import mlx.core as mx
|
|
9
|
+
import numpy as np
|
|
10
|
+
import sentencepiece as spm
|
|
11
|
+
from PIL import Image
|
|
12
|
+
from transformers import AutoImageProcessor, AutoProcessor
|
|
13
|
+
from transformers.feature_extraction_utils import BatchFeature
|
|
14
|
+
from transformers.image_processing_utils import (
|
|
15
|
+
BaseImageProcessor as HFBaseImageProcessor,
|
|
16
|
+
)
|
|
17
|
+
from transformers.image_transforms import (
|
|
18
|
+
normalize,
|
|
19
|
+
rescale,
|
|
20
|
+
resize,
|
|
21
|
+
to_channel_dimension_format,
|
|
22
|
+
)
|
|
23
|
+
from transformers.image_utils import (
|
|
24
|
+
ChannelDimension,
|
|
25
|
+
ImageInput,
|
|
26
|
+
PILImageResampling,
|
|
27
|
+
is_valid_image,
|
|
28
|
+
to_numpy_array,
|
|
29
|
+
)
|
|
30
|
+
from transformers.processing_utils import ProcessorMixin
|
|
31
|
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
32
|
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Ernie4_5_VLTokenizer(PreTrainedTokenizer):
    """Tokenizer for ERNIE 4.5 VL model using SentencePiece.

    Wraps a SentencePiece ``tokenizer.model`` file behind the Hugging Face
    ``PreTrainedTokenizer`` interface. When no ``chat_template`` is given, the
    template stored in a ``tokenizer_config.json`` that sits in the same
    directory as the vocab file is used, if that file exists.
    """

    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
    padding_side = "right"

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        cls_token="<|begin_of_sentence|>",
        eos_token="</s>",
        mask_token="<mask:1>",
        pad_token="<unk>",
        sep_token="<|end_of_sentence|>",
        unk_token="<unk>",
        additional_special_tokens=None,
        chat_template=None,
        **kwargs,
    ):
        """Load the SentencePiece model and initialise the HF tokenizer base.

        Args:
            vocab_file: Path to the SentencePiece model file.
            bos_token, cls_token, eos_token, mask_token, pad_token, sep_token,
            unk_token: Special-token strings forwarded to
                ``PreTrainedTokenizer``.
            additional_special_tokens: Extra special tokens; defaults to
                ``["<mask:1>", "<mask:7>"]`` when ``None``.
            chat_template: Chat template string. When ``None`` it is read from
                ``tokenizer_config.json`` next to ``vocab_file`` if present.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # The SentencePiece model is loaded before super().__init__() because
        # the HF base initialiser may already need token/vocab lookups.
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

        if additional_special_tokens is None:
            additional_special_tokens = ["<mask:1>", "<mask:7>"]

        # Load chat_template from tokenizer_config.json if not provided
        if chat_template is None:
            import json

            config_file = os.path.join(
                os.path.dirname(vocab_file), "tokenizer_config.json"
            )
            if os.path.exists(config_file):
                # JSON is UTF-8 by specification; do not rely on the platform's
                # default encoding, which may not be UTF-8.
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
                chat_template = config.get("chat_template")

        super().__init__(
            bos_token=bos_token,
            cls_token=cls_token,
            eos_token=eos_token,
            mask_token=mask_token,
            pad_token=pad_token,
            sep_token=sep_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            chat_template=chat_template,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of pieces in the underlying SentencePiece model."""
        return self.sp_model.vocab_size()

    @property
    def space_token_id(self):
        """SentencePiece id of the ``<mask:1>`` piece."""
        return self.sp_model.piece_to_id("<mask:1>")

    @property
    def gend_token_id(self):
        """SentencePiece id of the ``<mask:7>`` piece."""
        return self.sp_model.piece_to_id("<mask:7>")

    def get_vocab(self):
        """Return a token -> id mapping of the SentencePiece vocab plus added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Split ``text`` into SentencePiece pieces (strings)."""
        return self.sp_model.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        """Map a piece string to its SentencePiece id."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, id):
        """Map a SentencePiece id back to its piece string."""
        return self.sp_model.id_to_piece(id)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Return the id sequence(s) unchanged: no special tokens are inserted.

        For a pair, the two sequences are simply concatenated.
        """
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def convert_tokens_to_string(self, tokens):
        """Join pieces into text.

        Runs of ordinary pieces are decoded with SentencePiece; special tokens
        are copied into the output verbatim between those runs.
        """
        # Look the special tokens up once rather than once per token.
        special_tokens = set(self.all_special_tokens)
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            if token in special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Optional[Tuple[str]]:
        """Write the SentencePiece model file into ``save_directory``.

        The original file is copied when it exists on disk; otherwise the
        in-memory serialized model is written.

        Args:
            save_directory: Target directory; must already exist.
            filename_prefix: Optional prefix, joined to the file name with "-".

        Returns:
            A 1-tuple with the output path, or ``None`` when
            ``save_directory`` is not an existing directory.
        """
        if not os.path.isdir(save_directory):
            return None
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + self.vocab_files_names["vocab_file"],
        )
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)

    def _decode(self, *args, **kwargs):
        """Decode via the HF base class with space handling forced off.

        Any caller-supplied ``clean_up_tokenization_spaces`` or
        ``spaces_between_special_tokens`` is discarded and replaced by False.
        """
        kwargs.pop("clean_up_tokenization_spaces", None)
        kwargs.pop("spaces_between_special_tokens", None)
        return super()._decode(
            *args,
            **kwargs,
            clean_up_tokenization_spaces=False,
            spaces_between_special_tokens=False,
        )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _validate_images_text_input_order(images, text):
|
|
163
|
+
if isinstance(images, str) and text is None:
|
|
164
|
+
return None, images
|
|
165
|
+
if images is not None and text is not None:
|
|
166
|
+
if isinstance(images, str) and not isinstance(text, str):
|
|
167
|
+
return text, images
|
|
168
|
+
return images, text
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``number``.

    Ties follow Python's built-in ``round`` (round half to even).
    """
    multiples = round(number / factor)
    return multiples * factor
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    multiples = math.ceil(number / factor)
    return multiples * factor
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    multiples = math.floor(number / factor)
    return multiples * factor
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 56 * 56,
    max_pixels: int = 28 * 28 * 1280,
) -> Tuple[int, int]:
    """Choose a resize target ``(height, width)`` for an image.

    Both returned sides are multiples of ``factor`` and at least ``factor``.
    If rounding to the factor grid leaves the area outside
    ``[min_pixels, max_pixels]``, both sides are rescaled by a common factor
    (preserving aspect ratio) and snapped down (too large) or up (too small)
    to the grid. Aspect ratios beyond 200:1 are first clamped by enlarging
    the short side.

    Args:
        height: Source image height in pixels.
        width: Source image width in pixels.
        factor: Grid both output sides must be multiples of.
        min_pixels: Lower bound for the output area.
        max_pixels: Upper bound for the output area.

    Returns:
        ``(resized_height, resized_width)``.
    """
    MAX_RATIO = 200
    if height / width > MAX_RATIO:
        width = height // MAX_RATIO
    elif width / height > MAX_RATIO:
        height = width // MAX_RATIO

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(int(height / beta), factor)
        w_bar = floor_by_factor(int(width / beta), factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        # Round the scaled value up directly. Truncating with int() first turns
        # e.g. 28.3 into 28, an exact multiple that ceil leaves unchanged,
        # so the result could fall below min_pixels.
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    h_bar = max(factor, h_bar)
    w_bar = max(factor, w_bar)

    return h_bar, w_bar
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class ImageProcessor(HFBaseImageProcessor):
    """Image processor for ERNIE 4.5 VL MoE model.

    Each image is resized so both sides are multiples of
    ``patch_size * merge_size`` (see ``smart_resize``), rescaled, normalised
    and flattened into one row per ``patch_size`` x ``patch_size`` patch.
    Rows are ordered so that every ``merge_size`` x ``merge_size`` window of
    neighbouring patches is contiguous.
    """

    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(
        self,
        image_mean: Tuple[float, ...] = (0.48145466, 0.4578275, 0.40821073),
        image_std: Tuple[float, ...] = (0.26862954, 0.26130258, 0.27577711),
        size: Tuple[int, int] = (224, 224),
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        rescale_factor: float = 1 / 255,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        patch_size: int = 14,
        merge_size: int = 2,
        temporal_patch_size: int = 2,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        config=None,
        **kwargs,
    ):
        """Store preprocessing settings, optionally overridden from ``config``.

        Args:
            image_mean: Per-channel mean used by normalisation.
            image_std: Per-channel std used by normalisation.
            size: Stored only; the resize target comes from ``smart_resize``.
            resample: Resampling filter used when resizing.
            rescale_factor: Multiplier applied to pixel values before
                normalisation.
            data_format: Stored only; preprocessing always emits
                channels-first arrays.
            patch_size: Side length of one vision patch.
            merge_size: Patches per side of one merge window.
            temporal_patch_size: Stored; not used by the methods of this class.
            min_pixels: Default lower area bound for ``smart_resize``.
            max_pixels: Default upper area bound for ``smart_resize``.
            config: Optional overrides. A ``dict`` may supply ``image_mean``,
                ``image_std``, ``min_pixels``, ``max_pixels`` and the patch
                sizes, with a nested ``vision_config`` taking precedence for
                the patch sizes. Any other object is read via ``getattr`` for
                ``patch_size``, ``spatial_merge_size``/``merge_size`` and
                ``temporal_patch_size`` only.
            **kwargs: Forwarded to the Hugging Face base image processor.
        """
        if config is not None:
            if isinstance(config, dict):
                # ``or {}`` also tolerates an explicit ``"vision_config": null``.
                vision_config = config.get("vision_config") or {}
                image_mean = config.get("image_mean", image_mean)
                image_std = config.get("image_std", image_std)
                min_pixels = config.get("min_pixels", min_pixels)
                max_pixels = config.get("max_pixels", max_pixels)
                patch_size = vision_config.get(
                    "patch_size", config.get("patch_size", patch_size)
                )
                merge_size = vision_config.get(
                    "spatial_merge_size", config.get("spatial_merge_size", merge_size)
                )
                temporal_patch_size = vision_config.get(
                    "temporal_patch_size",
                    config.get("temporal_patch_size", temporal_patch_size),
                )
            else:
                patch_size = getattr(config, "patch_size", patch_size)
                merge_size = getattr(
                    config,
                    "spatial_merge_size",
                    getattr(config, "merge_size", merge_size),
                )
                temporal_patch_size = getattr(
                    config, "temporal_patch_size", temporal_patch_size
                )

        HFBaseImageProcessor.__init__(self, **kwargs)

        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.temporal_patch_size = temporal_patch_size
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        # Both resized sides must be multiples of this so the patch grid
        # divides evenly into merge windows.
        self.factor = patch_size * merge_size

    def get_smart_resize(
        self,
        height: int,
        width: int,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
    ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Compute the resize target and resulting patch grid for an image.

        Args:
            height: Source height in pixels.
            width: Source width in pixels.
            min_pixels: Overrides ``self.min_pixels`` when not ``None``.
            max_pixels: Overrides ``self.max_pixels`` when not ``None``.

        Returns:
            ``((resized_height, resized_width), (grid_h, grid_w))`` where the
            grid counts ``patch_size`` patches along each side.
        """
        actual_min_pixels = min_pixels if min_pixels is not None else self.min_pixels
        actual_max_pixels = max_pixels if max_pixels is not None else self.max_pixels

        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=self.factor,
            min_pixels=actual_min_pixels,
            max_pixels=actual_max_pixels,
        )

        grid_h = resized_height // self.patch_size
        grid_w = resized_width // self.patch_size

        return (resized_height, resized_width), (grid_h, grid_w)

    def _extract_patches(
        self,
        image: np.ndarray,
        grid_h: int,
        grid_w: int,
    ) -> np.ndarray:
        """Flatten a channels-first image into patch rows.

        Args:
            image: Array of shape ``(C, H, W)``; the reshape requires
                ``H == grid_h * patch_size`` and ``W == grid_w * patch_size``,
                with ``grid_h`` and ``grid_w`` divisible by ``merge_size``.
            grid_h: Patches along the height.
            grid_w: Patches along the width.

        Returns:
            Array of shape ``(grid_h * grid_w, C * patch_size * patch_size)``.
            Rows go row-major over merge windows, then row-major over the
            ``merge_size`` x ``merge_size`` patches inside each window.
        """
        C, H, W = image.shape

        patches = image.reshape(
            C,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )

        # -> (win_row, win_col, row_in_win, col_in_win, C, patch_row, patch_col)
        patches = patches.transpose(1, 4, 2, 5, 0, 3, 6)

        num_patches = (
            (grid_h // self.merge_size)
            * (grid_w // self.merge_size)
            * (self.merge_size**2)
        )
        patches = patches.reshape(num_patches, C * self.patch_size * self.patch_size)

        return patches

    def _resize_and_normalize(
        self, image: Image.Image, resized_h: int, resized_w: int
    ) -> np.ndarray:
        """Turn a PIL image into a resized, rescaled, normalised CHW array.

        Non-RGB images are converted to RGB first.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        img_array = to_numpy_array(image)
        img_array = resize(
            img_array,
            size=(resized_h, resized_w),
            resample=self.resample,
            data_format=ChannelDimension.LAST,
            input_data_format=ChannelDimension.LAST,
        )

        img_array = rescale(
            img_array,
            scale=self.rescale_factor,
            data_format=ChannelDimension.LAST,
            input_data_format=ChannelDimension.LAST,
        )

        img_array = normalize(
            img_array,
            mean=self.image_mean,
            std=self.image_std,
            data_format=ChannelDimension.LAST,
            input_data_format=ChannelDimension.LAST,
        )

        return to_channel_dimension_format(
            img_array,
            channel_dim=ChannelDimension.FIRST,
            input_channel_dim=ChannelDimension.LAST,
        )

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_grid_thw: bool = True,
    ) -> Union[np.ndarray, Dict]:
        """Convert one or more images into flattened patch rows.

        Each image is resized to its own ``smart_resize`` target.

        Args:
            images: A PIL image or an iterable of PIL images.
            return_grid_thw: When True return a dict, otherwise only the
                pixel array.

        Returns:
            ``{"pixel_values": ..., "image_grid_thw": ...}`` where
            ``pixel_values`` stacks every image's patch rows along axis 0 and
            ``image_grid_thw`` is an int64 array with one ``[1, grid_h,
            grid_w]`` row per image; or just ``pixel_values``.

        Raises:
            ValueError: If no images are given.
        """
        if isinstance(images, Image.Image):
            images = [images]
        images = list(images)
        if not images:
            raise ValueError("images list cannot be empty")

        all_patches = []
        all_grid_thw = []

        for image in images:
            # Converting to RGB does not change the size, so the target can be
            # computed from the original image.
            (resized_h, resized_w), (grid_h, grid_w) = self.get_smart_resize(
                image.height, image.width
            )
            img_array = self._resize_and_normalize(image, resized_h, resized_w)
            all_patches.append(self._extract_patches(img_array, grid_h, grid_w))
            all_grid_thw.append([1, grid_h, grid_w])

        pixel_values = np.concatenate(all_patches, axis=0)

        if return_grid_thw:
            return {
                "pixel_values": pixel_values,
                "image_grid_thw": np.array(all_grid_thw, dtype=np.int64),
            }

        return pixel_values

    def preprocess_video(
        self,
        frames: List[Image.Image],
        return_grid_thw: bool = True,
    ) -> Union[np.ndarray, Dict]:
        """Convert a list of video frames into flattened patch rows.

        Every frame is resized to the target computed for the first frame, so
        all frames share one patch grid.

        Args:
            frames: Non-empty list of PIL frames.
            return_grid_thw: When True return a dict, otherwise only the
                pixel array.

        Returns:
            ``{"pixel_values": ..., "video_grid_thw": ...}`` where
            ``video_grid_thw`` is ``[[num_frames, grid_h, grid_w]]`` (int64);
            or just ``pixel_values``.

        Raises:
            ValueError: If ``frames`` is empty.
        """
        if not frames:
            raise ValueError("frames list cannot be empty")

        first_frame = frames[0]
        (resized_h, resized_w), (grid_h, grid_w) = self.get_smart_resize(
            first_frame.height, first_frame.width
        )

        all_patches = [
            self._extract_patches(
                self._resize_and_normalize(frame, resized_h, resized_w),
                grid_h,
                grid_w,
            )
            for frame in frames
        ]

        pixel_values = np.concatenate(all_patches, axis=0)
        # One temporal step per frame; temporal_patch_size is not applied here.
        grid_t = len(frames)

        if return_grid_thw:
            return {
                "pixel_values": pixel_values,
                "video_grid_thw": np.array([[grid_t, grid_h, grid_w]], dtype=np.int64),
            }

        return pixel_values

    def __call__(
        self,
        images: ImageInput,
        **kwargs,
    ) -> Union[np.ndarray, Dict]:
        """Alias for :meth:`preprocess`; keyword arguments are forwarded as-is."""
        return self.preprocess(images, **kwargs)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class Ernie4_5_VLProcessor(ProcessorMixin):
|
|
469
|
+
"""Processor for ERNIE 4.5 VL that wraps image processor and tokenizer."""
|
|
470
|
+
|
|
471
|
+
attributes = ["image_processor", "tokenizer"]
|
|
472
|
+
valid_kwargs = ["chat_template", "spatial_conv_size", "temporal_conv_size"]
|
|
473
|
+
image_processor_class = "ImageProcessor"
|
|
474
|
+
tokenizer_class = "Ernie4_5_VLTokenizer"
|
|
475
|
+
|
|
476
|
+
IMG_START = "<|IMAGE_START|>"
|
|
477
|
+
IMG_END = "<|IMAGE_END|>"
|
|
478
|
+
VID_START = "<|VIDEO_START|>"
|
|
479
|
+
VID_END = "<|VIDEO_END|>"
|
|
480
|
+
IMAGE_PLACEHOLDER = "<|IMAGE_PLACEHOLDER|>"
|
|
481
|
+
|
|
482
|
+
def __init__(
|
|
483
|
+
self,
|
|
484
|
+
image_processor=None,
|
|
485
|
+
tokenizer=None,
|
|
486
|
+
chat_template=None,
|
|
487
|
+
spatial_conv_size: int = 2,
|
|
488
|
+
temporal_conv_size: int = 2,
|
|
489
|
+
**kwargs,
|
|
490
|
+
):
|
|
491
|
+
if image_processor is None:
|
|
492
|
+
image_processor = ImageProcessor()
|
|
493
|
+
self.spatial_conv_size = spatial_conv_size
|
|
494
|
+
self.temporal_conv_size = temporal_conv_size
|
|
495
|
+
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
|
496
|
+
|
|
497
|
+
@property
|
|
498
|
+
def pad_token(self):
|
|
499
|
+
return self.tokenizer.pad_token if self.tokenizer else None
|
|
500
|
+
|
|
501
|
+
@property
|
|
502
|
+
def pad_token_id(self):
|
|
503
|
+
return self.tokenizer.pad_token_id if self.tokenizer else None
|
|
504
|
+
|
|
505
|
+
@property
|
|
506
|
+
def eos_token(self):
|
|
507
|
+
return self.tokenizer.eos_token if self.tokenizer else None
|
|
508
|
+
|
|
509
|
+
@property
|
|
510
|
+
def eos_token_id(self):
|
|
511
|
+
return self.tokenizer.eos_token_id if self.tokenizer else None
|
|
512
|
+
|
|
513
|
+
@property
|
|
514
|
+
def bos_token(self):
|
|
515
|
+
return self.tokenizer.bos_token if self.tokenizer else None
|
|
516
|
+
|
|
517
|
+
@property
|
|
518
|
+
def bos_token_id(self):
|
|
519
|
+
return self.tokenizer.bos_token_id if self.tokenizer else None
|
|
520
|
+
|
|
521
|
+
def __call__(
    self,
    images: ImageInput = None,
    text: Union[
        TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
    ] = None,
    **kwargs,
) -> BatchFeature:
    """Prepare prompts and/or images for the ERNIE 4.5 VL model.

    Every ``IMG_START + "<|image@placeholder|>" + IMG_END`` marker in the
    prompts is expanded, in order, into ``IMG_START``, one
    ``IMAGE_PLACEHOLDER`` per merged patch of the matching image, and
    ``IMG_END``. Markers left over once all images are consumed stay as-is.
    Each prompt is then tokenized and right-padded to the longest one.

    Args:
        images: A single image or a list of images, passed to
            ``self.image_processor``.
        text: A prompt string or a list of prompts. A caller-supplied list is
            copied and never modified in place.
        **kwargs: Accepted for API compatibility; ``return_tensors`` is
            discarded and any other key is ignored.

    Returns:
        A ``BatchFeature`` holding ``input_ids``/``attention_mask`` when text
        is given and ``pixel_values``/``image_grid_thw`` when images are
        given. Without images the token data are Python lists (unwrapped to a
        single list for a single prompt); with images they are ``mx.array``.

    Raises:
        ValueError: If neither ``images`` nor ``text`` is provided, or if
            ``text`` is neither a string nor a list.
    """
    if images is None and text is None:
        raise ValueError("You have to specify at least one of `images` or `text`.")

    images, text = _validate_images_text_input_order(images, text)
    kwargs.pop("return_tensors", None)

    if images is not None:
        if is_valid_image(images):
            images = [images]
        image_inputs = self.image_processor(images)
        image_grid_thw = image_inputs["image_grid_thw"]
    else:
        image_inputs = {}
        image_grid_thw = None

    if isinstance(text, str):
        text = [text]
    elif isinstance(text, list):
        # Work on a copy: the placeholder expansion below rewrites entries,
        # which must not leak back into the caller's list.
        text = list(text)
    elif text is not None:
        raise ValueError(
            "Invalid input text. Please provide a string, or a list of strings"
        )

    if image_grid_thw is not None and text is not None:
        merge_length = self.spatial_conv_size * self.spatial_conv_size
        # Marker emitted by chat templates for each image (loop invariant).
        placeholder = f"{self.IMG_START}<|image@placeholder|>{self.IMG_END}"
        index = 0
        for i, prompt in enumerate(text):
            while placeholder in prompt and index < len(image_grid_thw):
                # grid_thw is [t, h, w]; every merge_length patches collapse
                # into a single placeholder token.
                num_patches = int(np.prod(image_grid_thw[index]))
                num_placeholders = num_patches // merge_length
                replacement = (
                    f"{self.IMG_START}"
                    f"{self.IMAGE_PLACEHOLDER * num_placeholders}"
                    f"{self.IMG_END}"
                )
                prompt = prompt.replace(placeholder, replacement, 1)
                index += 1
            text[i] = prompt

    if text is not None:
        all_input_ids = [self.tokenizer.encode(t) for t in text]

        # default=0 keeps an empty prompt list from crashing max().
        max_len = max((len(ids) for ids in all_input_ids), default=0)
        pad_token_id = self.tokenizer.pad_token_id or 0

        # Right padding: padded positions get pad_token_id and mask 0.
        padded_input_ids = [
            ids + [pad_token_id] * (max_len - len(ids)) for ids in all_input_ids
        ]
        attention_masks = [
            [1] * len(ids) + [0] * (max_len - len(ids)) for ids in all_input_ids
        ]

        if images is None:
            if len(padded_input_ids) == 1:
                text_inputs = {
                    "input_ids": padded_input_ids[0],
                    "attention_mask": attention_masks[0],
                }
            else:
                text_inputs = {
                    "input_ids": padded_input_ids,
                    "attention_mask": attention_masks,
                }
        else:
            text_inputs = {
                "input_ids": mx.array(padded_input_ids),
                "attention_mask": mx.array(attention_masks),
            }
    else:
        text_inputs = {}

    if image_inputs:
        image_inputs = {
            "pixel_values": mx.array(image_inputs["pixel_values"]),
            "image_grid_thw": mx.array(image_inputs["image_grid_thw"]),
        }

    return BatchFeature(data={**text_inputs, **image_inputs})
|
|
618
|
+
|
|
619
|
+
def batch_decode(self, *args, **kwargs):
    """Decode several token-id sequences via ``self.tokenizer.batch_decode``.

    All positional and keyword arguments are forwarded unchanged.
    """
    tokenizer_batch_decode = self.tokenizer.batch_decode
    return tokenizer_batch_decode(*args, **kwargs)
|
|
621
|
+
|
|
622
|
+
def decode(self, *args, **kwargs):
    """Decode one token-id sequence via ``self.tokenizer.decode``.

    All positional and keyword arguments are forwarded unchanged.
    """
    tokenizer_decode = self.tokenizer.decode
    return tokenizer_decode(*args, **kwargs)
|
|
624
|
+
|
|
625
|
+
def apply_chat_template(
    self,
    conversation,
    chat_template=None,
    add_generation_prompt=False,
    tokenize=False,
    **kwargs,
):
    """Render ``conversation`` with a Jinja chat template.

    The template is taken from ``chat_template``, else ``self.chat_template``,
    else ``self.tokenizer.chat_template``.

    Rendering uses jinja2's ``ImmutableSandboxedEnvironment``. Chat templates
    come with model files and may be untrusted, so they must not be able to
    reach Python internals. ``trim_blocks``/``lstrip_blocks``, the
    ``loopcontrols`` extension and a ``raise_exception`` helper are enabled to
    match how Hugging Face ``transformers`` renders the same templates.

    Args:
        conversation: Message list, exposed to the template as ``messages``.
        chat_template: Optional template string overriding the defaults.
        add_generation_prompt: Exposed to the template under the same name.
        tokenize: If True, return ``self.tokenizer.encode(rendered)``
            instead of the rendered string.
        **kwargs: Extra variables made available to the template.

    Returns:
        The rendered prompt string, or its token ids when ``tokenize`` is
        True.

    Raises:
        ValueError: If no chat template can be found.
        ImportError: If jinja2 is not installed.
        jinja2.exceptions.TemplateError: If the template is invalid, calls
            ``raise_exception``, or attempts a sandbox-forbidden operation.
    """
    if chat_template is None:
        chat_template = self.chat_template
    if chat_template is None:
        chat_template = getattr(self.tokenizer, "chat_template", None)
    if chat_template is None:
        raise ValueError(
            "No chat template found. Please provide a chat_template argument "
            "or ensure the tokenizer has a chat_template attribute."
        )

    try:
        from jinja2.exceptions import TemplateError
        from jinja2.sandbox import ImmutableSandboxedEnvironment
    except ImportError as err:
        raise ImportError("jinja2 is required for apply_chat_template") from err

    def raise_exception(message):
        # Lets templates abort rendering with a readable error message.
        raise TemplateError(message)

    env = ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=["jinja2.ext.loopcontrols"],
    )
    env.globals["raise_exception"] = raise_exception

    rendered = env.from_string(chat_template).render(
        messages=conversation,
        add_generation_prompt=add_generation_prompt,
        **kwargs,
    )

    if tokenize:
        return self.tokenizer.encode(rendered)
    return rendered
|
|
659
|
+
|
|
660
|
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """Build a processor from a local directory or a Hugging Face Hub repo id.

    If ``pretrained_model_name_or_path`` is not an existing local path, the
    tokenizer files (``*.json``, ``*.model``, ``*.txt``) are fetched with
    ``huggingface_hub.snapshot_download`` first. A default
    ``ImageProcessor()`` is always used.

    Args:
        pretrained_model_name_or_path: Local directory or Hub repo id.
        **kwargs: ``chat_template``, ``spatial_conv_size`` and
            ``temporal_conv_size`` are forwarded to the constructor; any
            other loader option is ignored.

    Returns:
        An instance of ``cls`` (a classmethod, so subclasses get their own
        type).
    """
    from pathlib import Path

    if not Path(pretrained_model_name_or_path).exists():
        from huggingface_hub import snapshot_download

        pretrained_model_name_or_path = snapshot_download(
            pretrained_model_name_or_path,
            allow_patterns=["*.json", "*.model", "*.txt"],
        )

    tokenizer = Ernie4_5_VLTokenizer.from_pretrained(pretrained_model_name_or_path)
    image_processor = ImageProcessor()

    # Only the options the constructor understands are forwarded; generic
    # loader kwargs (e.g. trust_remote_code) are dropped as before.
    init_kwargs = {
        key: kwargs[key]
        for key in ("chat_template", "spatial_conv_size", "temporal_conv_size")
        if key in kwargs
    }

    return cls(image_processor=image_processor, tokenizer=tokenizer, **init_kwargs)
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
MODEL_TYPE = "ernie4_5_moe_vl"

# Register this model type's image processor and processor with the
# transformers Auto* factories at import time. Any failure is re-raised with
# the model type in the message, chained so the root cause stays visible.
try:
    AutoImageProcessor.register(MODEL_TYPE, slow_image_processor_class=ImageProcessor)
    AutoProcessor.register(MODEL_TYPE, Ernie4_5_VLProcessor)
except Exception as e:
    raise RuntimeError(f"Error registering {MODEL_TYPE} processor: {e}") from e
|