fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/glm_ocr/processing.py
@@ -0,0 +1,208 @@
+from typing import List, Union
+
+import numpy as np
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.processing_utils import ProcessorMixin
+
+from ..base import install_auto_processor_patch
+
+
+class GlmOcrProcessor(ProcessorMixin):
+    """
+    Processor for GLM-OCR that wraps an image processor and tokenizer.
+
+    Handles:
+    - Image preprocessing via image_processor
+    - Token replacement for image/video placeholders based on grid dimensions
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        chat_template=None,
+        **kwargs,
+    ):
+        self.tokenizer = tokenizer
+        self.image_processor = image_processor
+
+        self.image_token = "<|image|>"
+        self.video_token = "<|video|>"
+
+        if tokenizer is not None:
+            self.image_token = getattr(tokenizer, "image_token", "<|image|>")
+            self.video_token = getattr(tokenizer, "video_token", "<|video|>")
+
+            self.image_token_id = getattr(tokenizer, "image_token_id", None)
+            if self.image_token_id is None:
+                self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+
+            self.video_token_id = getattr(tokenizer, "video_token_id", None)
+            if self.video_token_id is None:
+                self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
+        else:
+            self.image_token_id = None
+            self.video_token_id = None
+
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images=None,
+        text: Union[str, List[str]] = None,
+        videos=None,
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Process images/videos and text for the model.
+
+        Args:
+            images: Single image or list of images (PIL.Image, np.ndarray, etc.)
+            text: Single text or list of texts
+            videos: Video inputs (optional)
+            **kwargs: Additional arguments passed to image_processor and tokenizer
+
+        Returns:
+            BatchFeature with:
+            - input_ids: Token IDs with image/video placeholders expanded
+            - attention_mask: Attention mask
+            - pixel_values: Processed image/video patches
+            - image_grid_thw: Grid dimensions for each image
+            - video_grid_thw: Grid dimensions for each video (if videos provided)
+        """
+        image_inputs = {}
+        video_inputs = {}
+        image_grid_thw = None
+        video_grid_thw = None
+
+        padding = kwargs.pop("padding", False)
+        return_token_type_ids = kwargs.pop("return_token_type_ids", False)
+        return_tensors = kwargs.pop("return_tensors", None)
+
+        if images is not None and self.image_processor is not None:
+            image_inputs = self.image_processor(images=images)
+            image_grid_thw = image_inputs.get("image_grid_thw")
+
+        if videos is not None:
+            if hasattr(self, "video_processor") and self.video_processor is not None:
+                video_inputs = self.video_processor(videos=videos, **kwargs)
+                video_grid_thw = video_inputs.get("video_grid_thw")
+
+        if text is None:
+            text = [""]
+        elif not isinstance(text, list):
+            text = [text]
+
+        text = [t for t in text]
+
+        merge_size = getattr(self.image_processor, "merge_size", 2)
+        if hasattr(self.image_processor, "spatial_merge_size"):
+            merge_size = self.image_processor.spatial_merge_size
+        merge_length = merge_size**2
+
+        if image_grid_thw is not None:
+            index = 0
+            for i in range(len(text)):
+                while self.image_token in text[i]:
+                    grid = image_grid_thw[index]
+                    if hasattr(grid, "tolist"):
+                        grid = grid.tolist()
+                    num_image_tokens = int(np.prod(grid) // merge_length)
+
+                    text[i] = text[i].replace(
+                        self.image_token,
+                        "<|placeholder|>" * num_image_tokens,
+                        1,
+                    )
+                    index += 1
+                text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+        if video_grid_thw is not None:
+            video_index = 0
+            for i in range(len(text)):
+                while self.video_token in text[i]:
+                    grid = video_grid_thw[video_index]
+                    if hasattr(grid, "tolist"):
+                        grid = grid.tolist()
+
+                    num_frames = grid[0]
+                    num_tokens_per_frame = int(
+                        np.prod(grid) // merge_length // num_frames
+                    )
+
+                    video_structure = ""
+                    for frame_idx in range(num_frames):
+                        frame_structure = self.image_token * num_tokens_per_frame
+                        video_structure += frame_structure
+
+                    text[i] = text[i].replace(self.video_token, video_structure, 1)
+                    video_index += 1
+
+        text_inputs = self.tokenizer(
+            text,
+            padding=padding,
+            return_token_type_ids=return_token_type_ids,
+            **kwargs,
+        )
+
+        return BatchFeature(
+            data={**text_inputs, **image_inputs, **video_inputs},
+            tensor_type=return_tensors,
+        )
+
+    def batch_decode(self, *args, **kwargs):
+        """Decode token IDs to text."""
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """Decode token IDs to text."""
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def apply_chat_template(self, *args, **kwargs):
+        """Apply chat template using the tokenizer."""
+        return self.tokenizer.apply_chat_template(*args, **kwargs)
+
+    @property
+    def model_input_names(self):
+        """Return combined input names from tokenizer and image processor."""
+        tokenizer_input_names = (
+            self.tokenizer.model_input_names if self.tokenizer else []
+        )
+        image_processor_input_names = (
+            self.image_processor.model_input_names
+            if hasattr(self.image_processor, "model_input_names")
+            else []
+        )
+        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """Load processor from pretrained model path."""
+        from transformers import AutoImageProcessor, AutoTokenizer
+
+        trust_remote_code = kwargs.pop("trust_remote_code", True)
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        image_processor = AutoImageProcessor.from_pretrained(
+            pretrained_model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        return cls(image_processor=image_processor, tokenizer=tokenizer, **kwargs)
+
+
+__all__ = ["GlmOcrProcessor"]

+# Register the processor with AutoProcessor for the glm_ocr model type
+install_auto_processor_patch("glm_ocr", GlmOcrProcessor)
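The token-expansion rule in this processor replaces each <|image|> marker with prod(grid_thw) // merge_size**2 copies of the image token, one per merged patch. The following is a minimal standalone sketch of that arithmetic, with invented grid sizes rather than values taken from the package:

    import numpy as np

    # Invented patch grids (t, h, w) for two images; merge_size=2 mirrors the default above.
    image_grid_thw = [(1, 28, 28), (1, 14, 14)]
    merge_length = 2**2

    prompt = "Compare <|image|> with <|image|>."
    for grid in image_grid_thw:
        num_image_tokens = int(np.prod(grid) // merge_length)  # 196, then 49
        prompt = prompt.replace("<|image|>", "<|placeholder|>" * num_image_tokens, 1)
    prompt = prompt.replace("<|placeholder|>", "<|image|>")

    print(prompt.count("<|image|>"))  # 245 image tokens for the tokenizer to encode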
mlx_vlm/models/glm_ocr/vision.py
@@ -0,0 +1,342 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    if len(shape) == 4:
+        out_channels, kH, KW, _ = shape
+        return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
+    elif len(shape) == 5:
+        out_channels, kT, kH, KW, _ = shape
+        return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
+    else:
+        return False
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb_vision(
+    q: mx.array, k: mx.array, cos: mx.array, sin: mx.array
+) -> tuple:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.astype(mx.float32), k.astype(mx.float32)
+    cos = mx.expand_dims(cos, axis=-2).astype(mx.float32)
+    sin = mx.expand_dims(sin, axis=-2).astype(mx.float32)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.astype(orig_q_dtype)
+    k_embed = k_embed.astype(orig_k_dtype)
+    return q_embed, k_embed
+
+
+class GlmOcrVisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+
+    def __call__(self, seqlen: int) -> mx.array:
+        inv_freq = 1.0 / (
+            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+        )
+        seq = mx.arange(seqlen, dtype=inv_freq.dtype)
+        freqs = mx.outer(seq, inv_freq)
+        return freqs
+
+
+class GlmOcrVisionPatchEmbed(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.in_channels = config.in_channels
+        self.embed_dim = config.hidden_size
+
+        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+        self.proj = nn.Conv3d(
+            self.in_channels,
+            self.embed_dim,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=True,
+        )
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = hidden_states.reshape(
+            -1,
+            self.in_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        ).moveaxis(1, 4)
+
+        hidden_states = self.proj(hidden_states)
+        hidden_states = hidden_states.reshape(-1, self.embed_dim)
+        return hidden_states
+
+
+class GlmOcrVisionPatchMerger(nn.Module):
+    def __init__(
+        self, dim: int, context_dim: int, hidden_act: str, bias: bool = False
+    ) -> None:
+        super().__init__()
+        self.proj = nn.Linear(dim, dim, bias=bias)
+        self.post_projection_norm = nn.LayerNorm(dim)
+        self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
+        self.up_proj = nn.Linear(dim, context_dim, bias=bias)
+        self.down_proj = nn.Linear(context_dim, dim, bias=bias)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.proj(x)
+        x = nn.gelu(self.post_projection_norm(x))
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class GlmOcrVisionAttention(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.dim = config.hidden_size
+        self.num_heads = config.num_heads
+        self.head_dim = self.dim // self.num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.qkv = nn.Linear(
+            config.hidden_size, config.hidden_size * 3, bias=config.attention_bias
+        )
+        self.proj = nn.Linear(
+            config.hidden_size, config.hidden_size, bias=config.attention_bias
+        )
+
+        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        cu_seqlens: mx.array,
+        position_embeddings: tuple,
+    ) -> mx.array:
+        seq_length = hidden_states.shape[0]
+
+        qkv = self.qkv(hidden_states)
+        qkv = qkv.reshape(seq_length, 3, self.num_heads, -1)
+        qkv = qkv.transpose(1, 0, 2, 3)
+        q, k, v = mx.split(qkv, 3, axis=0)
+        q = q.squeeze(0)
+        k = k.squeeze(0)
+        v = v.squeeze(0)
+
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+        q = q.transpose(1, 0, 2)[None, ...]
+        k = k.transpose(1, 0, 2)[None, ...]
+        v = v.transpose(1, 0, 2)[None, ...]
+
+        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        split_indices = []
+        cumsum = 0
+        for i, length in enumerate(lengths[:-1]):
+            cumsum += length
+            split_indices.append(cumsum)
+
+        q_splits = mx.split(q, split_indices, axis=2)
+        k_splits = mx.split(k, split_indices, axis=2)
+        v_splits = mx.split(v, split_indices, axis=2)
+
+        attn_outputs = []
+        for q_chunk, k_chunk, v_chunk in zip(q_splits, k_splits, v_splits):
+            output = mx.fast.scaled_dot_product_attention(
+                q_chunk, k_chunk, v_chunk, scale=self.scale, mask=None
+            )
+            attn_outputs.append(output)
+
+        attn_output = mx.concatenate(attn_outputs, axis=2)
+        # Transpose from (batch, heads, seq, head_dim) to (batch, seq, heads, head_dim)
+        # then reshape to (seq, hidden_size)
+        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
+class GlmOcrVisionMLP(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.gate_proj = nn.Linear(
+            config.hidden_size, config.intermediate_size, bias=config.attention_bias
+        )
+        self.up_proj = nn.Linear(
+            config.hidden_size, config.intermediate_size, bias=config.attention_bias
+        )
+        self.down_proj = nn.Linear(
+            config.intermediate_size, config.hidden_size, bias=config.attention_bias
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class GlmOcrVisionBlock(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.norm1 = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm2 = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attn = GlmOcrVisionAttention(config)
+        self.mlp = GlmOcrVisionMLP(config)
+
+    def __call__(
+        self, hidden_states: mx.array, cu_seqlens: mx.array, position_embeddings: tuple
+    ) -> mx.array:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.spatial_merge_size = config.spatial_merge_size
+        self.patch_size = config.patch_size
+
+        self.patch_embed = GlmOcrVisionPatchEmbed(config)
+
+        head_dim = config.hidden_size // config.num_heads
+        self.rotary_pos_emb = GlmOcrVisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = [GlmOcrVisionBlock(config) for _ in range(config.depth)]
+
+        self.merger = GlmOcrVisionPatchMerger(
+            dim=config.out_hidden_size,
+            context_dim=config.out_hidden_size * config.in_channels,
+            hidden_act=config.hidden_act,
+        )
+
+        self.downsample = nn.Conv2d(
+            in_channels=config.hidden_size,
+            out_channels=config.out_hidden_size,
+            kernel_size=config.spatial_merge_size,
+            stride=config.spatial_merge_size,
+            bias=True,
+        )
+
+        self.post_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def rot_pos_emb(self, grid_thw: mx.array):
+        pos_ids = []
+
+        for t, h, w in grid_thw.tolist():
+            hpos_ids = mx.expand_dims(mx.arange(h), 1)
+            hpos_ids = mx.repeat(hpos_ids, w, axis=1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = mx.transpose(hpos_ids, (0, 2, 1, 3))
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = mx.expand_dims(mx.arange(w), 0)
+            wpos_ids = mx.repeat(wpos_ids, h, axis=0)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = mx.transpose(wpos_ids, (0, 2, 1, 3))
+            wpos_ids = wpos_ids.flatten()
+
+            stacked_pos_ids = mx.stack([hpos_ids, wpos_ids], axis=-1)
+            pos_ids.append(mx.tile(stacked_pos_ids, (t, 1)))
+
+        pos_ids = mx.concatenate(pos_ids, axis=0)
+        max_grid_size = mx.max(grid_thw[:, 1:])
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size.item())
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].reshape(pos_ids.shape[0], -1)
+
+        emb = mx.concatenate((rotary_pos_emb, rotary_pos_emb), axis=-1)
+        return (mx.cos(emb), mx.sin(emb)), pos_ids
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        grid_thw: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+        hidden_states = self.patch_embed(hidden_states)
+        position_embeddings, _ = self.rot_pos_emb(grid_thw)
+
+        seq_lens = grid_thw[:, 1] * grid_thw[:, 2]
+        repeats = grid_thw[:, 0]
+        repeated_values = []
+        for seq_len, repeat_count in zip(seq_lens.tolist(), repeats.tolist()):
+            repeated_values.extend([seq_len] * repeat_count)
+
+        cu_seqlens = mx.array(repeated_values).cumsum(axis=0)
+        cu_seqlens = mx.pad(cu_seqlens, (1, 0), constant_values=0)
+
+        for blk in self.blocks:
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                position_embeddings=position_embeddings,
+            )
+
+        hidden_states = self.post_layernorm(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            -1,
+            self.spatial_merge_size,
+            self.spatial_merge_size,
+            hidden_states.shape[-1],
+        )
+        hidden_states = self.downsample(hidden_states).reshape(
+            -1, self.config.out_hidden_size
+        )
+
+        merged_hidden_states = self.merger(hidden_states)
+        return merged_hidden_states
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                continue
+            elif "patch_embed.proj.weight" in k or "downsample.weight" in k:
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    if v.ndim == 5:
+                        sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
+                    elif v.ndim == 4:
+                        sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+                    else:
+                        sanitized_weights[k] = v
+            else:
+                sanitized_weights[k] = v

+        return sanitized_weights
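For orientation, the __call__ method above packs the patch tokens of all images into one sequence and uses cu_seqlens (cumulative sequence lengths derived from grid_thw) to run attention separately per image. A simplified NumPy illustration of that bookkeeping, with invented grid values:

    import numpy as np

    # Invented (t, h, w) patch grids for two images.
    grid_thw = np.array([[1, 28, 28], [1, 14, 14]])

    # Each image contributes t blocks of h * w patch tokens; cu_seqlens marks the
    # boundaries at which the attention layer splits the packed sequence.
    seq_lens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)])

    print(cu_seqlens.tolist())  # [0, 784, 980]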
mlx_vlm/models/hunyuan_vl/__init__.py
@@ -0,0 +1,7 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .hunyuan_vl import LanguageModel, Model, VisionModel
+from .processing_hunyuan_vl import (
+    HunYuanVLImageProcessor,
+    HunYuanVLProcessor,
+    ImageProcessor,
+)
mlx_vlm/models/hunyuan_vl/config.py
@@ -0,0 +1,136 @@
+import inspect
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str = "hunyuan_vl"
+    hidden_size: int = 1152
+    out_hidden_size: int = 1024
+    num_hidden_layers: int = 27
+    num_attention_heads: int = 16
+    intermediate_size: int = 4304
+    patch_size: int = 16
+    num_channels: int = 3
+    spatial_merge_size: int = 2
+    attention_dropout: float = 0.0
+    hidden_dropout: float = 0.0
+    rms_norm_eps: float = 1e-5
+    interpolate_mode: str = "bilinear"
+    cat_extra_token: int = 1
+    img_max_token_num: int = 4096
+    max_vit_seq_len: int = 16384
+    add_patchemb_bias: bool = True
+    max_image_size: int = 2048
+    hidden_act: str = "gelu"
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str = "hunyuan_vl"
+    vocab_size: int = 120818
+    org_vocab_size: int = 120818
+    hidden_size: int = 1024
+    num_hidden_layers: int = 24
+    num_attention_heads: int = 16
+    num_key_value_heads: Optional[int] = 8
+    head_dim: Optional[int] = 128
+    attention_head_dim: Optional[int] = 128
+    intermediate_size: int = 3584
+    hidden_act: str = "silu"
+    attention_bias: bool = False
+    mlp_bias: bool = False
+    attention_dropout: float = 0.0
+    use_qk_norm: bool = True
+    rope_theta: float = 10000.0
+    rope_scaling: Optional[Dict[str, Union[float, int, bool, List[int]]]] = field(
+        default_factory=lambda: {
+            "alpha": 1000.0,
+            "beta_fast": 32,
+            "beta_slow": 1,
+            "factor": 1.0,
+            "mscale": 1.0,
+            "mscale_all_dim": 1.0,
+            "type": "xdrope",
+            "xdrope_section": [16, 16, 16, 16],
+        }
+    )
+    max_position_embeddings: int = 32768
+    rms_norm_eps: float = 1e-5
+    norm_type: str = "rms"
+    tie_word_embeddings: bool = True
+    use_cache: bool = True
+    initializer_range: float = 0.02
+    routed_scaling_factor: float = 1.0
+    dtype: str = "bfloat16"
+    bos_token_id: int = 120000
+    eos_token_id: int = 120020
+    eod_token_id: int = 120020
+    pad_token_id: int = -1
+    pad_id: int = 120002
+    sep_token_id: int = 0
+    text_start_id: int = 7
+    text_end_id: int = 8
+    num_experts: int = 1
+    pretraining_tp: int = 1
+    use_cla: bool = False
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+
+        if self.attention_head_dim is None:
+            self.attention_head_dim = self.head_dim
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig = field(default_factory=TextConfig)
+    vision_config: VisionConfig = field(default_factory=VisionConfig)
+    model_type: str = "hunyuan_vl"
+    image_start_token_id: int = 120118
+    image_end_token_id: int = 120119
+    image_token_id: int = 120120
+    image_newline_token_id: int = 120121
+    bos_token_id: int = 120000
+    eos_token_id: int = 120020
+    pad_token_id: int = -1
+    pad_id: int = 120002
+    sep_token_id: int = 0
+    text_start_id: int = 7
+    text_end_id: int = 8
+    vocab_size: int = 120818
+    org_vocab_size: int = 120818
+    routed_scaling_factor: float = 1.0
+    norm_type: str = "rms"
+    dtype: str = "bfloat16"
+    use_cache: bool = True
+    tie_word_embeddings: bool = True
+
+    @classmethod
+    def from_dict(cls, params):
+        text_params = params.get("text_config", {})
+        vision_params = params.get("vision_config", {})
+
+        for key, value in params.items():
+            if key in TextConfig.__dataclass_fields__ and key not in text_params:
+                text_params[key] = value
+            if key in VisionConfig.__dataclass_fields__ and key not in vision_params:
+                vision_params[key] = value
+
+        return cls(
+            text_config=TextConfig.from_dict(text_params),
+            vision_config=VisionConfig.from_dict(vision_params),
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+                and k not in ["text_config", "vision_config"]
+            },
+        )