fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/glm4v_moe/processing.py
@@ -0,0 +1,229 @@
+ """Processor for GLM-4V-MoE model.
+
+ Handles image/video token expansion based on grid dimensions and merge size.
+ Based on the HuggingFace transformers GLM-4.6V processor implementation.
+ """
+
+ from typing import List, Union
+
+ import numpy as np
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.processing_utils import ProcessorMixin
+
+
+ class Glm46VMoEProcessor(ProcessorMixin):
+     """
+     Processor for GLM-4V-MoE that wraps an image processor and tokenizer.
+
+     Handles:
+     - Image preprocessing via image_processor
+     - Token replacement for image/video placeholders based on grid dimensions
+     """
+
+     attributes = ["image_processor", "tokenizer"]
+     valid_kwargs = ["chat_template"]
+     image_processor_class = "AutoImageProcessor"
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(
+         self,
+         image_processor=None,
+         tokenizer=None,
+         chat_template=None,
+         **kwargs,
+     ):
+         self.tokenizer = tokenizer
+         self.image_processor = image_processor
+
+         # Get image/video tokens from tokenizer or use defaults
+         self.image_token = "<|image|>"
+         self.video_token = "<|video|>"
+
+         if tokenizer is not None:
+             self.image_token = getattr(tokenizer, "image_token", "<|image|>")
+             self.video_token = getattr(tokenizer, "video_token", "<|video|>")
+
+             # Get token IDs
+             self.image_token_id = getattr(tokenizer, "image_token_id", None)
+             if self.image_token_id is None:
+                 self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+
+             self.video_token_id = getattr(tokenizer, "video_token_id", None)
+             if self.video_token_id is None:
+                 self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
+         else:
+             self.image_token_id = None
+             self.video_token_id = None
+
+         super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+     def __call__(
+         self,
+         images=None,
+         text: Union[str, List[str]] = None,
+         videos=None,
+         **kwargs,
+     ) -> BatchFeature:
+         """
+         Process images/videos and text for the model.
+
+         Args:
+             images: Single image or list of images (PIL.Image, np.ndarray, etc.)
+             text: Single text or list of texts
+             videos: Video inputs (optional)
+             **kwargs: Additional arguments passed to image_processor and tokenizer
+
+         Returns:
+             BatchFeature with:
+             - input_ids: Token IDs with image/video placeholders expanded
+             - attention_mask: Attention mask
+             - pixel_values: Processed image/video patches
+             - image_grid_thw: Grid dimensions for each image
+             - video_grid_thw: Grid dimensions for each video (if videos provided)
+         """
+         image_inputs = {}
+         video_inputs = {}
+         image_grid_thw = None
+         video_grid_thw = None
+
+         # Pop tokenizer-specific kwargs that shouldn't go to image processor
+         padding = kwargs.pop("padding", False)
+         return_token_type_ids = kwargs.pop("return_token_type_ids", False)
+         return_tensors = kwargs.pop("return_tensors", None)
+
+         # Process images
+         if images is not None and self.image_processor is not None:
+             image_inputs = self.image_processor(images=images)
+             image_grid_thw = image_inputs.get("image_grid_thw")
+
+         # Process videos
+         if videos is not None:
+             if hasattr(self, "video_processor") and self.video_processor is not None:
+                 video_inputs = self.video_processor(videos=videos, **kwargs)
+                 video_grid_thw = video_inputs.get("video_grid_thw")
+
+         # Handle text input
+         if text is None:
+             text = [""]
+         elif not isinstance(text, list):
+             text = [text]
+
+         # Make a copy to avoid modifying original
+         text = [t for t in text]
+
+         # Get merge_size from image_processor
+         merge_size = getattr(self.image_processor, "merge_size", 2)
+         if hasattr(self.image_processor, "spatial_merge_size"):
+             merge_size = self.image_processor.spatial_merge_size
+         merge_length = merge_size**2
+
+         # Expand image tokens based on grid dimensions
+         if image_grid_thw is not None:
+             index = 0
+             for i in range(len(text)):
+                 while self.image_token in text[i]:
+                     # Calculate number of image tokens: prod(grid_thw) / merge_size^2
+                     grid = image_grid_thw[index]
+                     if hasattr(grid, "tolist"):
+                         grid = grid.tolist()
+                     num_image_tokens = int(np.prod(grid) // merge_length)
+
+                     # Replace single image token with correct number of placeholder tokens
+                     text[i] = text[i].replace(
+                         self.image_token,
+                         "<|placeholder|>" * num_image_tokens,
+                         1,
+                     )
+                     index += 1
+                 # Replace placeholders back to image tokens
+                 text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+         # Expand video tokens based on grid dimensions
+         if video_grid_thw is not None:
+             video_index = 0
+             for i in range(len(text)):
+                 while self.video_token in text[i]:
+                     grid = video_grid_thw[video_index]
+                     if hasattr(grid, "tolist"):
+                         grid = grid.tolist()
+
+                     num_frames = grid[0]
+                     # Calculate tokens per frame
+                     num_tokens_per_frame = int(
+                         np.prod(grid) // merge_length // num_frames
+                     )
+
+                     # Build video structure with frame tokens
+                     video_structure = ""
+                     for frame_idx in range(num_frames):
+                         # Add image tokens for this frame
+                         frame_structure = self.image_token * num_tokens_per_frame
+                         video_structure += frame_structure
+
+                     text[i] = text[i].replace(self.video_token, video_structure, 1)
+                     video_index += 1
+
+         # Tokenize text
+         text_inputs = self.tokenizer(
+             text,
+             padding=padding,
+             return_token_type_ids=return_token_type_ids,
+             **kwargs,
+         )
+
+         return BatchFeature(
+             data={**text_inputs, **image_inputs, **video_inputs},
+             tensor_type=return_tensors,
+         )
+
+     def batch_decode(self, *args, **kwargs):
+         """Decode token IDs to text."""
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         """Decode token IDs to text."""
+         return self.tokenizer.decode(*args, **kwargs)
+
+     def apply_chat_template(self, *args, **kwargs):
+         """Apply chat template using the tokenizer."""
+         return self.tokenizer.apply_chat_template(*args, **kwargs)
+
+     @property
+     def model_input_names(self):
+         """Return combined input names from tokenizer and image processor."""
+         tokenizer_input_names = (
+             self.tokenizer.model_input_names if self.tokenizer else []
+         )
+         image_processor_input_names = (
+             self.image_processor.model_input_names
+             if hasattr(self.image_processor, "model_input_names")
+             else []
+         )
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         """Load processor from pretrained model path."""
+         from transformers import AutoTokenizer, Glm4vImageProcessor
+
+         trust_remote_code = kwargs.pop("trust_remote_code", True)
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             pretrained_model_name_or_path,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+
+         image_processor = Glm4vImageProcessor.from_pretrained(
+             pretrained_model_name_or_path,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+
+         return cls(image_processor=image_processor, tokenizer=tokenizer, **kwargs)
+
+
+ __all__ = ["Glm46VMoEProcessor"]
+
+ # Alias for backwards compatibility
+ Glm4VMoEProcessor = Glm46VMoEProcessor
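
For reference, a minimal standalone sketch of the token-expansion arithmetic this processor applies: each `<|image|>` placeholder is replaced by `prod(grid_thw) // merge_size**2` copies of the image token. The grid values and merge size below are made-up example numbers, not values taken from the package.

# Token-expansion sketch (hypothetical values; mirrors the loop in Glm46VMoEProcessor.__call__).
import numpy as np

image_token = "<|image|>"
merge_size = 2                      # spatial merge size assumed to be 2
merge_length = merge_size**2

image_grid_thw = [[1, 28, 28]]      # one image: 1 temporal step x 28 x 28 patches
text = [f"Describe this image: {image_token}"]

index = 0
for i in range(len(text)):
    while image_token in text[i]:
        num_image_tokens = int(np.prod(image_grid_thw[index]) // merge_length)
        text[i] = text[i].replace(
            image_token, "<|placeholder|>" * num_image_tokens, 1
        )
        index += 1
    text[i] = text[i].replace("<|placeholder|>", image_token)

assert text[0].count(image_token) == 196   # 1 * 28 * 28 // 4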
mlx_vlm/models/glm4v_moe/vision.py
@@ -0,0 +1,405 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..kernels import grid_sample
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 or 5 dimensions
+     if len(shape) == 4:
+         out_channels, kH, KW, _ = shape
+         # Check if out_channels is the largest, and kH and KW are the same
+         return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
+     elif len(shape) == 5:
+         B, out_channels, kH, KW, t = shape
+         # Special case for temporal dimension
+         if t == 3:
+             return True
+         # Check if out_channels is the largest, and kH and KW are the same
+         return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
+     else:
+         return False
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
+     orig_dtype = tensor.dtype
+
+     cos = mx.cos(freqs)
+     sin = mx.sin(freqs)
+
+     cos = mx.expand_dims(cos, axis=1)  # Equivalent to unsqueeze(1)
+     cos = mx.tile(cos, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+     cos = mx.expand_dims(cos, axis=0)  # Equivalent to [None, ...]
+
+     sin = mx.expand_dims(sin, axis=1)  # Equivalent to unsqueeze(1)
+     sin = mx.tile(sin, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+     sin = mx.expand_dims(sin, axis=0)  # Equivalent to [None, ...]
+
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     return output.astype(orig_dtype)
+
+
+ class Glm4vMoeVisionRotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+
+     def __call__(self, seqlen: int) -> mx.array:
+         inv_freq = 1.0 / (
+             self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+         )
+         seq = mx.arange(seqlen.item(), dtype=inv_freq.dtype)
+         freqs = mx.outer(seq, inv_freq)
+         return freqs
+
+
+ class Glm4vVisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     def __call__(self, embeddings, lengths, image_shapes, h_coords, w_coords):
+
+         # Get position embedding parameters
+         pos_embed_weight = self.position_embedding.weight
+         hidden_size = pos_embed_weight.shape[1]
+         total_seq = h_coords.shape[0]
+
+         # Handle empty sequence case
+         if total_seq == 0:
+             adapted_pos_embed = mx.empty(0, hidden_size, dtype=pos_embed_weight.dtype)
+         else:
+             # Convert inputs to tensors if needed
+             if isinstance(lengths, list):
+                 lengths = mx.array(lengths, dtype=mx.int32)
+             if not isinstance(image_shapes, mx.array):
+                 image_shapes = mx.array(image_shapes, dtype=mx.int32)
+
+             # Prepare 2D position embedding
+             orig_size_sq = pos_embed_weight.shape[0]
+             orig_size = int(orig_size_sq**0.5)
+             pos_embed_2d = (
+                 pos_embed_weight.reshape(orig_size, orig_size, hidden_size)
+                 .transpose(2, 0, 1)[None, ...]
+                 .astype(mx.float32)
+             )
+
+             # Calculate target dimensions for each patch
+             target_h = mx.concatenate(
+                 [mx.repeat(image_shapes[i, 1], lengths[i]) for i in range(len(lengths))]
+             ).astype(mx.float32)
+             target_w = mx.concatenate(
+                 [mx.repeat(image_shapes[i, 2], lengths[i]) for i in range(len(lengths))]
+             ).astype(mx.float32)
+
+             # Normalize coordinates to [-1, 1] range for grid_sample
+             h_coords = h_coords.astype(mx.float32)
+             w_coords = w_coords.astype(mx.float32)
+             norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+             norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+             # Create sampling grid
+             grid = mx.stack((norm_w, norm_h), axis=-1)[None, :, None, ...]
+
+             # Perform bicubic interpolation
+             interpolated_embed_fp32 = grid_sample(
+                 pos_embed_2d.transpose(0, 2, 3, 1),
+                 grid,
+             )
+
+             # Reshape and convert back to original dtype
+             adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(1)
+             adapted_pos_embed = adapted_pos_embed_fp32.astype(pos_embed_weight.dtype)
+
+         # Add adapted position encoding to embeddings
+         embeddings = embeddings + adapted_pos_embed
+         return embeddings
+
+
+ class Glm4vMoeVisionPatchEmbed(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.patch_size = config.patch_size
+         self.temporal_patch_size = config.temporal_patch_size
+         self.in_channels = config.in_channels
+         self.embed_dim = config.hidden_size
+
+         kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+         self.proj = nn.Conv3d(
+             self.in_channels,
+             self.embed_dim,
+             kernel_size=kernel_size,
+             stride=kernel_size,
+         )
+
+     def __call__(self, hidden_states: mx.array) -> mx.array:
+         hidden_states = hidden_states.reshape(
+             -1,
+             self.in_channels,
+             self.temporal_patch_size,
+             self.patch_size,
+             self.patch_size,
+         ).moveaxis(1, 4)
+
+         hidden_states = self.proj(hidden_states)
+         hidden_states = hidden_states.reshape(-1, self.embed_dim)
+         return hidden_states
+
+
+ class Glm4vMoeVisionPatchMerger(nn.Module):
+     def __init__(self, dim: int, context_dim: int, bias: bool = False) -> None:
+         super().__init__()
+
+         self.proj = nn.Linear(dim, dim, bias=bias)
+         self.post_projection_norm = nn.LayerNorm(dim)
+         self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
+         self.up_proj = nn.Linear(dim, context_dim, bias=bias)
+         self.down_proj = nn.Linear(context_dim, dim, bias=bias)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.proj(x)
+         x = nn.gelu(self.post_projection_norm(x))
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Glm4vMoeVisionAttention(nn.Module):
+     def __init__(self, dim: int, num_heads: int = 16) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+         self.qkv = nn.Linear(dim, dim * 3, bias=False)
+         self.proj = nn.Linear(dim, dim, bias=False)
+
+     def __call__(
+         self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
+     ) -> mx.array:
+         seq_length = x.shape[0]
+         qkv = (
+             self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
+         )
+         q, k, v = mx.split(qkv, 3)
+
+         q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
+         k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
+
+         attention_mask = mx.full(
+             (1, seq_length, seq_length), mx.finfo(q.dtype).min, dtype=q.dtype
+         )
+
+         for i in range(1, len(cu_seqlens)):
+             start = int(cu_seqlens[i - 1])
+             end = int(cu_seqlens[i])
+             attention_mask[..., start:end, start:end] = 0
+
+         q = q.transpose(0, 2, 1, 3)
+         k = k.transpose(0, 2, 1, 3)
+         v = v.transpose(0, 2, 1, 3)
+
+         output = mx.fast.scaled_dot_product_attention(
+             q, k, v, scale=self.scale, mask=attention_mask
+         )
+         output = output.transpose(0, 2, 1, 3)
+         output = output.reshape(seq_length, -1)
+         return self.proj(output)
+
+
+ class Glm4vMoeVisionMLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Glm4vMoeVisionBlock(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.norm1 = nn.RMSNorm(config.hidden_size, eps=1e-6)
+         self.norm2 = nn.RMSNorm(config.hidden_size, eps=1e-6)
+
+         self.attn = Glm4vMoeVisionAttention(
+             dim=config.hidden_size, num_heads=config.num_heads
+         )
+         self.mlp = Glm4vMoeVisionMLP(
+             dim=config.hidden_size, hidden_dim=config.out_hidden_size
+         )
+
+     def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
+         hidden_states = hidden_states + self.attn(
+             self.norm1(hidden_states),
+             cu_seqlens=cu_seqlens,
+             rotary_pos_emb=rotary_pos_emb,
+         )
+         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+         return hidden_states
+
+
+ class VisionModel(nn.Module):
+
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         if self.model_type not in ["glm4v_moe", "glm4v_moe_vision"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+         self.spatial_merge_size = config.spatial_merge_size
+
+         self.embeddings = Glm4vVisionEmbeddings(config)
+         self.patch_embed = Glm4vMoeVisionPatchEmbed(
+             config=config,
+         )
+
+         self.window_size = config.window_size
+         self.patch_size = config.patch_size
+         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+         head_dim = config.hidden_size // config.num_heads
+         self.rotary_pos_emb = Glm4vMoeVisionRotaryEmbedding(head_dim // 2)
+
+         self.blocks = [Glm4vMoeVisionBlock(config) for _ in range(config.depth)]
+         self.merger = Glm4vMoeVisionPatchMerger(
+             dim=config.out_hidden_size, context_dim=config.intermediate_size
+         )
+
+         self.post_conv_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.downsample = nn.Conv2d(
+             in_channels=config.hidden_size,
+             out_channels=config.out_hidden_size,
+             kernel_size=config.spatial_merge_size,
+             stride=config.spatial_merge_size,
+         )
+         self.post_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def rot_pos_emb(self, grid_thw):
+         pos_ids = []
+
+         for t, h, w in grid_thw.tolist():
+             hpos_ids = mx.expand_dims(mx.arange(h), 1)
+             hpos_ids = mx.repeat(hpos_ids, w, axis=1)
+             hpos_ids = hpos_ids.reshape(
+                 h // self.spatial_merge_size,
+                 self.spatial_merge_size,
+                 w // self.spatial_merge_size,
+                 self.spatial_merge_size,
+             )
+             hpos_ids = mx.transpose(hpos_ids, (0, 2, 1, 3))
+             hpos_ids = hpos_ids.flatten()
+
+             wpos_ids = mx.expand_dims(mx.arange(w), 0)
+             wpos_ids = mx.repeat(wpos_ids, h, axis=0)
+             wpos_ids = wpos_ids.reshape(
+                 h // self.spatial_merge_size,
+                 self.spatial_merge_size,
+                 w // self.spatial_merge_size,
+                 self.spatial_merge_size,
+             )
+             wpos_ids = mx.transpose(wpos_ids, (0, 2, 1, 3))
+             wpos_ids = wpos_ids.flatten()
+
+             stacked_pos_ids = mx.stack([hpos_ids, wpos_ids], axis=-1)
+             pos_ids.append(mx.tile(stacked_pos_ids, (t, 1)))
+
+         pos_ids = mx.concatenate(pos_ids, axis=0)
+         max_grid_size = mx.max(grid_thw[:, 1:])
+         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+         rotary_pos_emb = rotary_pos_emb_full[pos_ids]
+
+         return rotary_pos_emb.reshape(pos_ids.shape[0], -1), pos_ids
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         grid_thw: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+
+         hidden_states = self.patch_embed(hidden_states)
+         hidden_states = self.post_conv_layernorm(hidden_states)
+         rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+
+         seq_lens = grid_thw[:, 1] * grid_thw[:, 2]
+         repeats = grid_thw[:, 0]
+         repeated_values = []
+         for i, (seq_len, repeat_count) in enumerate(
+             zip(seq_lens.tolist(), repeats.tolist())
+         ):
+             repeated_values.extend([seq_len] * repeat_count)
+
+         cu_seqlens = mx.array(repeated_values).cumsum(axis=0)
+         cu_seqlens = mx.pad(cu_seqlens, (1, 0), constant_values=0)
+         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+         hidden_states = self.embeddings(
+             hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
+         )
+
+         for blk in self.blocks:
+             hidden_states = blk(
+                 hidden_states,
+                 cu_seqlens=cu_seqlens,
+                 rotary_pos_emb=rotary_pos_emb,
+             )
+
+         hidden_states = self.post_layernorm(hidden_states)
+
+         hidden_states = hidden_states.reshape(
+             -1,
+             self.spatial_merge_size,
+             self.spatial_merge_size,
+             hidden_states.shape[-1],
+         )
+         hidden_states = self.downsample(hidden_states).reshape(
+             -1, self.config.out_hidden_size
+         )
+
+         hidden_states = self.merger(hidden_states)
+         return hidden_states
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             elif "patch_embed.proj.weight" in k or "downsample.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     if v.ndim == 5:
+                         sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
+                     if v.ndim == 4:
+                         sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
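
As a standalone illustration of the layout conversion described in the `sanitize` comments above (channels-first PyTorch convolution weights reordered to the channels-last layout MLX expects), here is a minimal sketch; the shapes are hypothetical and NumPy stands in for the real weight arrays.

# Weight-layout sketch (hypothetical shapes, not the model's actual dimensions).
import numpy as np

# PyTorch Conv3d patch-embed weight: [out_channels, in_channels, t, kH, kW]
w_conv3d = np.zeros((1024, 3, 2, 14, 14))
# Channels-last layout for MLX Conv3d: [out_channels, t, kH, kW, in_channels]
assert w_conv3d.transpose(0, 2, 3, 4, 1).shape == (1024, 2, 14, 14, 3)

# PyTorch Conv2d downsample weight: [out_channels, in_channels, kH, kW]
w_conv2d = np.zeros((2048, 1024, 2, 2))
# Channels-last layout for MLX Conv2d: [out_channels, kH, kW, in_channels]
assert w_conv2d.transpose(0, 2, 3, 1).shape == (2048, 2, 2, 1024)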
mlx_vlm/models/glm_ocr/__init__.py
@@ -0,0 +1,3 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .glm_ocr import LanguageModel, Model, VisionModel
+ from .processing import GlmOcrProcessor