fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/ernie4_5_moe_vl/processor.py
@@ -0,0 +1,686 @@
+"""Image processor and Processor for ERNIE 4.5 VL MoE."""
+
+import math
+import os
+from shutil import copyfile
+from typing import Dict, List, Optional, Tuple, Union
+
+import mlx.core as mx
+import numpy as np
+import sentencepiece as spm
+from PIL import Image
+from transformers import AutoImageProcessor, AutoProcessor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_processing_utils import (
+    BaseImageProcessor as HFBaseImageProcessor,
+)
+from transformers.image_transforms import (
+    normalize,
+    rescale,
+    resize,
+    to_channel_dimension_format,
+)
+from transformers.image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    is_valid_image,
+    to_numpy_array,
+)
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class Ernie4_5_VLTokenizer(PreTrainedTokenizer):
+    """Tokenizer for ERNIE 4.5 VL model using SentencePiece."""
+
+    vocab_files_names = {"vocab_file": "tokenizer.model"}
+    model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
+    padding_side = "right"
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        cls_token="<|begin_of_sentence|>",
+        eos_token="</s>",
+        mask_token="<mask:1>",
+        pad_token="<unk>",
+        sep_token="<|end_of_sentence|>",
+        unk_token="<unk>",
+        additional_special_tokens=None,
+        chat_template=None,
+        **kwargs,
+    ):
+        self.vocab_file = vocab_file
+        self.sp_model = spm.SentencePieceProcessor()
+        self.sp_model.Load(vocab_file)
+
+        if additional_special_tokens is None:
+            additional_special_tokens = ["<mask:1>", "<mask:7>"]
+
+        # Load chat_template from tokenizer_config.json if not provided
+        if chat_template is None:
+            import json
+
+            config_file = os.path.join(
+                os.path.dirname(vocab_file), "tokenizer_config.json"
+            )
+            if os.path.exists(config_file):
+                with open(config_file, "r") as f:
+                    config = json.load(f)
+                chat_template = config.get("chat_template")
+
+        super().__init__(
+            bos_token=bos_token,
+            cls_token=cls_token,
+            eos_token=eos_token,
+            mask_token=mask_token,
+            pad_token=pad_token,
+            sep_token=sep_token,
+            unk_token=unk_token,
+            additional_special_tokens=additional_special_tokens,
+            chat_template=chat_template,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self):
+        return self.sp_model.vocab_size()
+
+    @property
+    def space_token_id(self):
+        return self.sp_model.piece_to_id("<mask:1>")
+
+    @property
+    def gend_token_id(self):
+        return self.sp_model.piece_to_id("<mask:7>")
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text):
+        return self.sp_model.encode_as_pieces(text)
+
+    def _convert_token_to_id(self, token):
+        return self.sp_model.piece_to_id(token)
+
+    def _convert_id_to_token(self, id):
+        return self.sp_model.id_to_piece(id)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        if token_ids_1 is None:
+            return token_ids_0
+        return token_ids_0 + token_ids_1
+
+    def convert_tokens_to_string(self, tokens):
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string
+
+    def save_vocabulary(
+        self, save_directory, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            return None
+        out_vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "")
+            + self.vocab_files_names["vocab_file"],
+        )
+        if os.path.abspath(self.vocab_file) != os.path.abspath(
+            out_vocab_file
+        ) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+        return (out_vocab_file,)
+
+    def _decode(self, *args, **kwargs):
+        kwargs.pop("clean_up_tokenization_spaces", None)
+        kwargs.pop("spaces_between_special_tokens", None)
+        return super()._decode(
+            *args,
+            **kwargs,
+            clean_up_tokenization_spaces=False,
+            spaces_between_special_tokens=False,
+        )
+
+
+def _validate_images_text_input_order(images, text):
+    if isinstance(images, str) and text is None:
+        return None, images
+    if images is not None and text is not None:
+        if isinstance(images, str) and not isinstance(text, str):
+            return text, images
+    return images, text
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    return math.floor(number / factor) * factor
+
+
+def smart_resize(
+    height: int,
+    width: int,
+    factor: int = 28,
+    min_pixels: int = 56 * 56,
+    max_pixels: int = 28 * 28 * 1280,
+) -> Tuple[int, int]:
+    MAX_RATIO = 200
+    if height / width > MAX_RATIO:
+        width = height // MAX_RATIO
+    elif width / height > MAX_RATIO:
+        height = width // MAX_RATIO
+
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(int(height / beta), factor)
+        w_bar = floor_by_factor(int(width / beta), factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(int(height * beta), factor)
+        w_bar = ceil_by_factor(int(width * beta), factor)
+
+    h_bar = max(factor, h_bar)
+    w_bar = max(factor, w_bar)
+
+    return h_bar, w_bar
+
+
+class ImageProcessor(HFBaseImageProcessor):
+    """Image processor for ERNIE 4.5 VL MoE model."""
+
+    model_input_names = ["pixel_values", "image_grid_thw"]
+
+    def __init__(
+        self,
+        image_mean: Tuple[float, ...] = (0.48145466, 0.4578275, 0.40821073),
+        image_std: Tuple[float, ...] = (0.26862954, 0.26130258, 0.27577711),
+        size: Tuple[int, int] = (224, 224),
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        rescale_factor: float = 1 / 255,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        patch_size: int = 14,
+        merge_size: int = 2,
+        temporal_patch_size: int = 2,
+        min_pixels: int = 56 * 56,
+        max_pixels: int = 28 * 28 * 1280,
+        config=None,
+        **kwargs,
+    ):
+        if config is not None:
+            if isinstance(config, dict):
+                vision_config = config.get("vision_config", {})
+                image_mean = config.get("image_mean", image_mean)
+                image_std = config.get("image_std", image_std)
+                min_pixels = config.get("min_pixels", min_pixels)
+                max_pixels = config.get("max_pixels", max_pixels)
+                patch_size = vision_config.get(
+                    "patch_size", config.get("patch_size", patch_size)
+                )
+                merge_size = vision_config.get(
+                    "spatial_merge_size", config.get("spatial_merge_size", merge_size)
+                )
+                temporal_patch_size = vision_config.get(
+                    "temporal_patch_size",
+                    config.get("temporal_patch_size", temporal_patch_size),
+                )
+            else:
+                patch_size = getattr(config, "patch_size", patch_size)
+                merge_size = getattr(
+                    config,
+                    "spatial_merge_size",
+                    getattr(config, "merge_size", merge_size),
+                )
+                temporal_patch_size = getattr(
+                    config, "temporal_patch_size", temporal_patch_size
+                )
+
+        HFBaseImageProcessor.__init__(self, **kwargs)
+
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.size = size
+        self.resample = resample
+        self.rescale_factor = rescale_factor
+        self.data_format = data_format
+        self.patch_size = patch_size
+        self.merge_size = merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+        self.factor = patch_size * merge_size
+
+    def get_smart_resize(
+        self,
+        height: int,
+        width: int,
+        min_pixels: Optional[int] = None,
+        max_pixels: Optional[int] = None,
+    ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        actual_min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+        actual_max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+
+        resized_height, resized_width = smart_resize(
+            height,
+            width,
+            factor=self.factor,
+            min_pixels=actual_min_pixels,
+            max_pixels=actual_max_pixels,
+        )
+
+        grid_h = resized_height // self.patch_size
+        grid_w = resized_width // self.patch_size
+
+        return (resized_height, resized_width), (grid_h, grid_w)
+
+    def _extract_patches(
+        self,
+        image: np.ndarray,
+        grid_h: int,
+        grid_w: int,
+    ) -> np.ndarray:
+        C, H, W = image.shape
+
+        patches = image.reshape(
+            C,
+            grid_h // self.merge_size,
+            self.merge_size,
+            self.patch_size,
+            grid_w // self.merge_size,
+            self.merge_size,
+            self.patch_size,
+        )
+
+        patches = patches.transpose(1, 4, 2, 5, 0, 3, 6)
+
+        num_patches = (
+            (grid_h // self.merge_size)
+            * (grid_w // self.merge_size)
+            * (self.merge_size**2)
+        )
+        patches = patches.reshape(num_patches, C * self.patch_size * self.patch_size)
+
+        return patches
+
+    def preprocess(
+        self,
+        images: Union[Image.Image, List[Image.Image]],
+        return_grid_thw: bool = True,
+    ) -> Union[np.ndarray, Dict]:
+        if isinstance(images, Image.Image):
+            images = [images]
+
+        all_patches = []
+        all_grid_thw = []
+
+        for image in images:
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+
+            (resized_h, resized_w), (grid_h, grid_w) = self.get_smart_resize(
+                image.height, image.width
+            )
+
+            img_array = to_numpy_array(image)
+            img_array = resize(
+                img_array,
+                size=(resized_h, resized_w),
+                resample=self.resample,
+                data_format=ChannelDimension.LAST,
+                input_data_format=ChannelDimension.LAST,
+            )
+
+            img_array = rescale(
+                img_array,
+                scale=self.rescale_factor,
+                data_format=ChannelDimension.LAST,
+                input_data_format=ChannelDimension.LAST,
+            )
+
+            img_array = normalize(
+                img_array,
+                mean=self.image_mean,
+                std=self.image_std,
+                data_format=ChannelDimension.LAST,
+                input_data_format=ChannelDimension.LAST,
+            )
+
+            img_array = to_channel_dimension_format(
+                img_array,
+                channel_dim=ChannelDimension.FIRST,
+                input_channel_dim=ChannelDimension.LAST,
+            )
+
+            patches = self._extract_patches(img_array, grid_h, grid_w)
+            all_patches.append(patches)
+            all_grid_thw.append([1, grid_h, grid_w])
+
+        pixel_values = np.concatenate(all_patches, axis=0)
+
+        if return_grid_thw:
+            return {
+                "pixel_values": pixel_values,
+                "image_grid_thw": np.array(all_grid_thw, dtype=np.int64),
+            }
+
+        return pixel_values
+
+    def preprocess_video(
+        self,
+        frames: List[Image.Image],
+        return_grid_thw: bool = True,
+    ) -> Union[np.ndarray, Dict]:
+        if not frames:
+            raise ValueError("frames list cannot be empty")
+
+        first_frame = frames[0]
+        if first_frame.mode != "RGB":
+            first_frame = first_frame.convert("RGB")
+
+        (resized_h, resized_w), (grid_h, grid_w) = self.get_smart_resize(
+            first_frame.height, first_frame.width
+        )
+
+        all_patches = []
+
+        for frame in frames:
+            if frame.mode != "RGB":
+                frame = frame.convert("RGB")
+
+            img_array = to_numpy_array(frame)
+            img_array = resize(
+                img_array,
+                size=(resized_h, resized_w),
+                resample=self.resample,
+                data_format=ChannelDimension.LAST,
+                input_data_format=ChannelDimension.LAST,
+            )
+
+            img_array = rescale(
+                img_array,
+                scale=self.rescale_factor,
+                data_format=ChannelDimension.LAST,
+                input_data_format=ChannelDimension.LAST,
+            )
+
+            img_array = normalize(
+                img_array,
+                mean=self.image_mean,
+                std=self.image_std,
+                data_format=ChannelDimension.LAST,
+                input_data_format=ChannelDimension.LAST,
+            )
+
+            img_array = to_channel_dimension_format(
+                img_array,
+                channel_dim=ChannelDimension.FIRST,
+                input_channel_dim=ChannelDimension.LAST,
+            )
+
+            patches = self._extract_patches(img_array, grid_h, grid_w)
+            all_patches.append(patches)
+
+        pixel_values = np.concatenate(all_patches, axis=0)
+        num_frames = len(frames)
+        grid_t = num_frames
+
+        if return_grid_thw:
+            return {
+                "pixel_values": pixel_values,
+                "video_grid_thw": np.array([[grid_t, grid_h, grid_w]], dtype=np.int64),
+            }
+
+        return pixel_values
+
+    def __call__(
+        self,
+        images: ImageInput,
+        **kwargs,
+    ) -> BatchFeature:
+        return self.preprocess(images, **kwargs)
+
+
+class Ernie4_5_VLProcessor(ProcessorMixin):
+    """Processor for ERNIE 4.5 VL that wraps image processor and tokenizer."""
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template", "spatial_conv_size", "temporal_conv_size"]
+    image_processor_class = "ImageProcessor"
+    tokenizer_class = "Ernie4_5_VLTokenizer"
+
+    IMG_START = "<|IMAGE_START|>"
+    IMG_END = "<|IMAGE_END|>"
+    VID_START = "<|VIDEO_START|>"
+    VID_END = "<|VIDEO_END|>"
+    IMAGE_PLACEHOLDER = "<|IMAGE_PLACEHOLDER|>"
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        chat_template=None,
+        spatial_conv_size: int = 2,
+        temporal_conv_size: int = 2,
+        **kwargs,
+    ):
+        if image_processor is None:
+            image_processor = ImageProcessor()
+        self.spatial_conv_size = spatial_conv_size
+        self.temporal_conv_size = temporal_conv_size
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    @property
+    def pad_token(self):
+        return self.tokenizer.pad_token if self.tokenizer else None
+
+    @property
+    def pad_token_id(self):
+        return self.tokenizer.pad_token_id if self.tokenizer else None
+
+    @property
+    def eos_token(self):
+        return self.tokenizer.eos_token if self.tokenizer else None
+
+    @property
+    def eos_token_id(self):
+        return self.tokenizer.eos_token_id if self.tokenizer else None
+
+    @property
+    def bos_token(self):
+        return self.tokenizer.bos_token if self.tokenizer else None
+
+    @property
+    def bos_token_id(self):
+        return self.tokenizer.bos_token_id if self.tokenizer else None
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[
+            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
+        ] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        if images is None and text is None:
+            raise ValueError("You have to specify at least one of `images` or `text`.")
+
+        images, text = _validate_images_text_input_order(images, text)
+        kwargs.pop("return_tensors", None)
+
+        if images is not None:
+            if is_valid_image(images):
+                images = [images]
+
+            image_inputs = self.image_processor(images)
+            image_grid_thw = image_inputs["image_grid_thw"]
+        else:
+            image_inputs = {}
+            image_grid_thw = None
+
+        if isinstance(text, str):
+            text = [text]
+        elif text is not None and not isinstance(text, list):
+            raise ValueError(
+                "Invalid input text. Please provide a string, or a list of strings"
+            )
+
+        if image_grid_thw is not None and text is not None:
+            merge_length = self.spatial_conv_size * self.spatial_conv_size
+            index = 0
+            for i in range(len(text)):
+                # Handle <|image@placeholder|> format used in chat templates
+                placeholder = f"{self.IMG_START}<|image@placeholder|>{self.IMG_END}"
+                while placeholder in text[i]:
+                    if index < len(image_grid_thw):
+                        grid_thw = image_grid_thw[index]
+                        # grid_thw is [t, h, w], compute number of tokens
+                        num_patches = int(np.prod(grid_thw))
+                        num_placeholders = num_patches // merge_length
+                        replacement = (
+                            f"{self.IMG_START}"
+                            f"{self.IMAGE_PLACEHOLDER * num_placeholders}"
+                            f"{self.IMG_END}"
+                        )
+                        text[i] = text[i].replace(placeholder, replacement, 1)
+                        index += 1
+                    else:
+                        break
+
+        if text is not None:
+            all_input_ids = []
+            for t in text:
+                ids = self.tokenizer.encode(t)
+                all_input_ids.append(ids)
+
+            max_len = max(len(ids) for ids in all_input_ids)
+            pad_token_id = self.tokenizer.pad_token_id or 0
+
+            padded_input_ids = []
+            attention_masks = []
+            for ids in all_input_ids:
+                padding_length = max_len - len(ids)
+                padded_ids = ids + [pad_token_id] * padding_length
+                mask = [1] * len(ids) + [0] * padding_length
+                padded_input_ids.append(padded_ids)
+                attention_masks.append(mask)
+
+            if images is None:
+                if len(padded_input_ids) == 1:
+                    text_inputs = {
+                        "input_ids": padded_input_ids[0],
+                        "attention_mask": attention_masks[0],
+                    }
+                else:
+                    text_inputs = {
+                        "input_ids": padded_input_ids,
+                        "attention_mask": attention_masks,
+                    }
+            else:
+                text_inputs = {
+                    "input_ids": mx.array(padded_input_ids),
+                    "attention_mask": mx.array(attention_masks),
+                }
+        else:
+            text_inputs = {}
+
+        if image_inputs:
+            image_inputs = {
+                "pixel_values": mx.array(image_inputs["pixel_values"]),
+                "image_grid_thw": mx.array(image_inputs["image_grid_thw"]),
+            }
+
+        return BatchFeature(data={**text_inputs, **image_inputs})
+
+    def batch_decode(self, *args, **kwargs):
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def apply_chat_template(
+        self,
+        conversation,
+        chat_template=None,
+        add_generation_prompt=False,
+        tokenize=False,
+        **kwargs,
+    ):
+        if chat_template is None:
+            chat_template = self.chat_template
+        if chat_template is None:
+            chat_template = getattr(self.tokenizer, "chat_template", None)
+        if chat_template is None:
+            raise ValueError(
+                "No chat template found. Please provide a chat_template argument "
+                "or ensure the tokenizer has a chat_template attribute."
+            )
+
+        # Use jinja2 to render the template
+        try:
+            from jinja2 import Template
+        except ImportError:
+            raise ImportError("jinja2 is required for apply_chat_template")
+
+        template = Template(chat_template)
+        rendered = template.render(
+            messages=conversation,
+            add_generation_prompt=add_generation_prompt,
+            **kwargs,
+        )
+
+        if tokenize:
+            return self.tokenizer.encode(rendered)
+        return rendered
+
+    @staticmethod
+    def from_pretrained(pretrained_model_name_or_path, **kwargs):
+        from pathlib import Path
+
+        if not Path(pretrained_model_name_or_path).exists():
+            from huggingface_hub import snapshot_download
+
+            pretrained_model_name_or_path = snapshot_download(
+                pretrained_model_name_or_path,
+                allow_patterns=["*.json", "*.model", "*.txt"],
+            )
+
+        tokenizer = Ernie4_5_VLTokenizer.from_pretrained(pretrained_model_name_or_path)
+        image_processor = ImageProcessor()
+
+        return Ernie4_5_VLProcessor(
+            image_processor=image_processor, tokenizer=tokenizer
+        )
+
+
+MODEL_TYPE = "ernie4_5_moe_vl"
+
+try:
+    AutoImageProcessor.register(MODEL_TYPE, slow_image_processor_class=ImageProcessor)
+    AutoProcessor.register(MODEL_TYPE, Ernie4_5_VLProcessor)
+except Exception as e:
+    raise Exception(f"Error registering {MODEL_TYPE} processor: {e}")
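
For orientation, below is a minimal usage sketch of the processor added in this file. It is not part of the package diff: the checkpoint path and prompt wording are illustrative, and the import path simply mirrors the file location listed above (mlx_vlm/models/ernie4_5_moe_vl/processor.py).

    # Minimal sketch (not shipped in the wheel). Assumes the package is installed and
    # that a local ERNIE 4.5 VL checkpoint directory with tokenizer.model and
    # tokenizer_config.json exists; the path below is hypothetical.
    from PIL import Image

    from mlx_vlm.models.ernie4_5_moe_vl.processor import Ernie4_5_VLProcessor

    processor = Ernie4_5_VLProcessor.from_pretrained("path/to/ernie4_5_vl_checkpoint")

    # The processor expands this placeholder block into one <|IMAGE_PLACEHOLDER|>
    # token per merged patch, based on image_grid_thw and spatial_conv_size.
    prompt = "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>Describe the image."
    image = Image.open("example.jpg")

    inputs = processor(images=image, text=prompt)
    # With images present, input_ids and attention_mask come back as mx.array;
    # pixel_values holds flattened patches and image_grid_thw holds [t, h, w] per image.
    print(inputs["input_ids"].shape, inputs["pixel_values"].shape, inputs["image_grid_thw"])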