fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,220 @@
+ from typing import List, Union
+
+ import numpy as np
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.processing_utils import ProcessorMixin
+
+
+ class Glm46VProcessor(ProcessorMixin):
+     """
+     Processor for GLM-4V that wraps an image processor and tokenizer.
+
+     Handles:
+     - Image preprocessing via image_processor
+     - Token replacement for image/video placeholders based on grid dimensions
+     """
+
+     attributes = ["image_processor", "tokenizer"]
+     valid_kwargs = ["chat_template"]
+     image_processor_class = "AutoImageProcessor"
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(
+         self,
+         image_processor=None,
+         tokenizer=None,
+         chat_template=None,
+         **kwargs,
+     ):
+         self.tokenizer = tokenizer
+         self.image_processor = image_processor
+
+         # Get image/video tokens from tokenizer or use defaults
+         self.image_token = "<|image|>"
+         self.video_token = "<|video|>"
+
+         if tokenizer is not None:
+             self.image_token = getattr(tokenizer, "image_token", "<|image|>")
+             self.video_token = getattr(tokenizer, "video_token", "<|video|>")
+
+             # Get token IDs
+             self.image_token_id = getattr(tokenizer, "image_token_id", None)
+             if self.image_token_id is None:
+                 self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+
+             self.video_token_id = getattr(tokenizer, "video_token_id", None)
+             if self.video_token_id is None:
+                 self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
+         else:
+             self.image_token_id = None
+             self.video_token_id = None
+
+         super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+     def __call__(
+         self,
+         images=None,
+         text: Union[str, List[str]] = None,
+         videos=None,
+         **kwargs,
+     ) -> BatchFeature:
+         """
+         Process images/videos and text for the model.
+
+         Args:
+             images: Single image or list of images (PIL.Image, np.ndarray, etc.)
+             text: Single text or list of texts
+             videos: Video inputs (optional)
+             **kwargs: Additional arguments passed to image_processor and tokenizer
+
+         Returns:
+             BatchFeature with:
+             - input_ids: Token IDs with image/video placeholders expanded
+             - attention_mask: Attention mask
+             - pixel_values: Processed image/video patches
+             - image_grid_thw: Grid dimensions for each image
+             - video_grid_thw: Grid dimensions for each video (if videos provided)
+         """
+         image_inputs = {}
+         video_inputs = {}
+         image_grid_thw = None
+         video_grid_thw = None
+
+         # Pop tokenizer-specific kwargs that shouldn't go to image processor
+         padding = kwargs.pop("padding", False)
+         return_token_type_ids = kwargs.pop("return_token_type_ids", False)
+         return_tensors = kwargs.pop("return_tensors", None)
+
+         # Process images
+         if images is not None and self.image_processor is not None:
+             image_inputs = self.image_processor(images=images)
+             image_grid_thw = image_inputs.get("image_grid_thw")
+
+         # Process videos
+         if videos is not None:
+             if hasattr(self, "video_processor") and self.video_processor is not None:
+                 video_inputs = self.video_processor(videos=videos, **kwargs)
+                 video_grid_thw = video_inputs.get("video_grid_thw")
+
+         # Handle text input
+         if text is None:
+             text = [""]
+         elif not isinstance(text, list):
+             text = [text]
+
+         # Make a copy to avoid modifying original
+         text = [t for t in text]
+
+         # Get merge_size from image_processor
+         merge_size = getattr(self.image_processor, "merge_size", 2)
+         if hasattr(self.image_processor, "spatial_merge_size"):
+             merge_size = self.image_processor.spatial_merge_size
+         merge_length = merge_size**2
+
+         # Expand image tokens based on grid dimensions
+         if image_grid_thw is not None:
+             index = 0
+             for i in range(len(text)):
+                 while self.image_token in text[i]:
+                     # Calculate number of image tokens: prod(grid_thw) / merge_size^2
+                     grid = image_grid_thw[index]
+                     if hasattr(grid, "tolist"):
+                         grid = grid.tolist()
+                     num_image_tokens = int(np.prod(grid) // merge_length)
+
+                     # Replace single image token with correct number of placeholder tokens
+                     text[i] = text[i].replace(
+                         self.image_token,
+                         "<|placeholder|>" * num_image_tokens,
+                         1,
+                     )
+                     index += 1
+                 # Replace placeholders back to image tokens
+                 text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+         # Expand video tokens based on grid dimensions
+         if video_grid_thw is not None:
+             video_index = 0
+             for i in range(len(text)):
+                 while self.video_token in text[i]:
+                     grid = video_grid_thw[video_index]
+                     if hasattr(grid, "tolist"):
+                         grid = grid.tolist()
+
+                     num_frames = grid[0]
+                     # Calculate tokens per frame
+                     num_tokens_per_frame = int(
+                         np.prod(grid) // merge_length // num_frames
+                     )
+
+                     # Build video structure with frame tokens
+                     video_structure = ""
+                     for frame_idx in range(num_frames):
+                         # Add image tokens for this frame
+                         frame_structure = self.image_token * num_tokens_per_frame
+                         video_structure += frame_structure
+
+                     text[i] = text[i].replace(self.video_token, video_structure, 1)
+                     video_index += 1
+
+         # Tokenize text
+         text_inputs = self.tokenizer(
+             text,
+             padding=padding,
+             return_token_type_ids=return_token_type_ids,
+             **kwargs,
+         )
+
+         return BatchFeature(
+             data={**text_inputs, **image_inputs, **video_inputs},
+             tensor_type=return_tensors,
+         )
+
+     def batch_decode(self, *args, **kwargs):
+         """Decode token IDs to text."""
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         """Decode token IDs to text."""
+         return self.tokenizer.decode(*args, **kwargs)
+
+     def apply_chat_template(self, *args, **kwargs):
+         """Apply chat template using the tokenizer."""
+         return self.tokenizer.apply_chat_template(*args, **kwargs)
+
+     @property
+     def model_input_names(self):
+         """Return combined input names from tokenizer and image processor."""
+         tokenizer_input_names = (
+             self.tokenizer.model_input_names if self.tokenizer else []
+         )
+         image_processor_input_names = (
+             self.image_processor.model_input_names
+             if hasattr(self.image_processor, "model_input_names")
+             else []
+         )
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         """Load processor from pretrained model path."""
+         from transformers import AutoTokenizer, Glm4vImageProcessor
+
+         trust_remote_code = kwargs.pop("trust_remote_code", True)
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             pretrained_model_name_or_path,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+
+         image_processor = Glm4vImageProcessor.from_pretrained(
+             pretrained_model_name_or_path,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+
+         return cls(image_processor=image_processor, tokenizer=tokenizer, **kwargs)
+
+
+ __all__ = ["Glm46VProcessor"]
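As the comments in the processor above note, each image placeholder is expanded into prod(grid_thw) // merge_size**2 repeated image tokens before tokenization. A minimal, self-contained sketch of just that arithmetic, using a hypothetical grid of (1, 24, 24) and merge_size 2 (values chosen for illustration, not taken from this package):

# Illustrative only: the placeholder-expansion arithmetic used by Glm46VProcessor.
import numpy as np

image_token = "<|image|>"
grid_thw = (1, 24, 24)        # hypothetical (temporal, height, width) patch grid
merge_size = 2                # hypothetical spatial merge size
merge_length = merge_size ** 2

num_image_tokens = int(np.prod(grid_thw) // merge_length)   # 1*24*24 // 4 = 144
prompt = f"Describe this picture: {image_token}"
expanded = prompt.replace(image_token, image_token * num_image_tokens, 1)
assert expanded.count(image_token) == 144

The placeholder trick in the real code ("<|placeholder|>" followed by a replace back to the image token) exists so that already-expanded tokens are not matched again by the while loop.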
@@ -0,0 +1,406 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..kernels import grid_sample
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 or 5 dimensions
+     if len(shape) == 4:
+         out_channels, kH, KW, _ = shape
+         # Check if out_channels is the largest, and kH and KW are the same
+         return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
+     elif len(shape) == 5:
+         B, out_channels, kH, KW, t = shape
+         # Special case for temporal dimension
+         if t == 3:
+             return True
+         # Check if out_channels is the largest, and kH and KW are the same
+         return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
+     else:
+         return False
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
+     orig_dtype = tensor.dtype
+
+     cos = mx.cos(freqs)
+     sin = mx.sin(freqs)
+
+     cos = mx.expand_dims(cos, axis=1)  # Equivalent to unsqueeze(1)
+     cos = mx.tile(cos, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+     cos = mx.expand_dims(cos, axis=0)  # Equivalent to [None, ...]
+
+     sin = mx.expand_dims(sin, axis=1)  # Equivalent to unsqueeze(1)
+     sin = mx.tile(sin, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+     sin = mx.expand_dims(sin, axis=0)  # Equivalent to [None, ...]
+
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     return output.astype(orig_dtype)
+
+
+ class Glm4vVisionRotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+
+     def __call__(self, seqlen: int) -> mx.array:
+         inv_freq = 1.0 / (
+             self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+         )
+         seq = mx.arange(seqlen.item(), dtype=inv_freq.dtype)
+         freqs = mx.outer(seq, inv_freq)
+         return freqs
+
+
+ class Glm4vVisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     def __call__(self, embeddings, lengths, image_shapes, h_coords, w_coords):
+
+         # Get position embedding parameters
+         pos_embed_weight = self.position_embedding.weight
+         hidden_size = pos_embed_weight.shape[1]
+         total_seq = h_coords.shape[0]
+
+         # Handle empty sequence case
+         if total_seq == 0:
+             adapted_pos_embed = mx.empty(0, hidden_size, dtype=pos_embed_weight.dtype)
+         else:
+             # Convert inputs to tensors if needed
+             if isinstance(lengths, list):
+                 lengths = mx.array(lengths, dtype=mx.int32)
+             if not isinstance(image_shapes, mx.array):
+                 image_shapes = mx.array(image_shapes, dtype=mx.int32)
+
+             # Prepare 2D position embedding
+             orig_size_sq = pos_embed_weight.shape[0]
+             orig_size = int(orig_size_sq**0.5)
+             pos_embed_2d = (
+                 pos_embed_weight.reshape(orig_size, orig_size, hidden_size)
+                 .transpose(2, 0, 1)[None, ...]
+                 .astype(mx.float32)
+             )
+
+             # Calculate target dimensions for each patch
+             target_h = mx.concatenate(
+                 [mx.repeat(image_shapes[i, 1], lengths[i]) for i in range(len(lengths))]
+             ).astype(mx.float32)
+             target_w = mx.concatenate(
+                 [mx.repeat(image_shapes[i, 2], lengths[i]) for i in range(len(lengths))]
+             ).astype(mx.float32)
+
+             # Normalize coordinates to [-1, 1] range for grid_sample
+             h_coords = h_coords.astype(mx.float32)
+             w_coords = w_coords.astype(mx.float32)
+             norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+             norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+             # Create sampling grid
+             grid = mx.stack((norm_w, norm_h), axis=-1)[None, :, None, ...]
+
+             # Perform bicubic interpolation
+             interpolated_embed_fp32 = grid_sample(
+                 pos_embed_2d.transpose(0, 2, 3, 1),
+                 grid,
+             )
+
+             # Reshape and convert back to original dtype
+             adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(1)
+             adapted_pos_embed = adapted_pos_embed_fp32.astype(pos_embed_weight.dtype)
+
+         # Add adapted position encoding to embeddings
+         embeddings = embeddings + adapted_pos_embed
+         return embeddings
+
+
+ class Glm4vVisionPatchEmbed(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.patch_size = config.patch_size
+         self.temporal_patch_size = config.temporal_patch_size
+         self.in_channels = config.in_channels
+         self.embed_dim = config.hidden_size
+
+         kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+         self.proj = nn.Conv3d(
+             self.in_channels,
+             self.embed_dim,
+             kernel_size=kernel_size,
+             stride=kernel_size,
+         )
+
+     def __call__(self, hidden_states: mx.array) -> mx.array:
+         hidden_states = hidden_states.reshape(
+             -1,
+             self.in_channels,
+             self.temporal_patch_size,
+             self.patch_size,
+             self.patch_size,
+         ).moveaxis(1, 4)
+
+         hidden_states = self.proj(hidden_states)
+         hidden_states = hidden_states.reshape(-1, self.embed_dim)
+         return hidden_states
+
+
+ class Glm4vVisionPatchMerger(nn.Module):
+     def __init__(self, dim: int, context_dim: int, bias: bool = False) -> None:
+         super().__init__()
+
+         self.proj = nn.Linear(dim, dim, bias=bias)
+         self.post_projection_norm = nn.LayerNorm(dim)
+         self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
+         self.up_proj = nn.Linear(dim, context_dim, bias=bias)
+         self.down_proj = nn.Linear(context_dim, dim, bias=bias)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.proj(x)
+         x = nn.gelu(self.post_projection_norm(x))
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Glm4vVisionAttention(nn.Module):
+     def __init__(self, dim: int, num_heads: int = 16) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+         self.qkv = nn.Linear(dim, dim * 3, bias=False)
+         self.proj = nn.Linear(dim, dim, bias=False)
+
+     def __call__(
+         self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
+     ) -> mx.array:
+         seq_length = x.shape[0]
+         qkv = (
+             self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
+         )
+         q, k, v = mx.split(qkv, 3)
+
+         q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
+         k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
+
+         q = q.transpose(0, 2, 1, 3)
+         k = k.transpose(0, 2, 1, 3)
+         v = v.transpose(0, 2, 1, 3)
+
+         lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+         splits = [
+             mx.split(tensor, [lengths[0], sum(lengths[:2])], axis=2)
+             for tensor in (q, k, v)
+         ]
+
+         attn_outputs = []
+         for q, k, v in zip(*splits):
+             output = mx.fast.scaled_dot_product_attention(
+                 q, k, v, scale=self.scale, mask=None
+             )
+             attn_outputs.append(output)
+
+         output = mx.concatenate(attn_outputs, axis=2)
+         output = output.transpose(0, 2, 1, 3).reshape(seq_length, -1)
+         return self.proj(output)
+
+
+ class Glm4vVisionMLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Glm4vVisionBlock(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.norm1 = nn.RMSNorm(config.hidden_size, eps=1e-6)
+         self.norm2 = nn.RMSNorm(config.hidden_size, eps=1e-6)
+
+         self.attn = Glm4vVisionAttention(
+             dim=config.hidden_size, num_heads=config.num_heads
+         )
+         self.mlp = Glm4vVisionMLP(
+             dim=config.hidden_size, hidden_dim=config.out_hidden_size
+         )
+
+     def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
+         hidden_states = hidden_states + self.attn(
+             self.norm1(hidden_states),
+             cu_seqlens=cu_seqlens,
+             rotary_pos_emb=rotary_pos_emb,
+         )
+         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+         return hidden_states
+
+
+ class VisionModel(nn.Module):
+
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         if self.model_type not in ["glm4v", "glm4v_vision"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+         self.spatial_merge_size = config.spatial_merge_size
+
+         self.embeddings = Glm4vVisionEmbeddings(config)
+         self.patch_embed = Glm4vVisionPatchEmbed(
+             config=config,
+         )
+
+         self.window_size = config.window_size
+         self.patch_size = config.patch_size
+         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+         head_dim = config.hidden_size // config.num_heads
+         self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
+
+         self.blocks = [Glm4vVisionBlock(config) for _ in range(config.depth)]
+         self.merger = Glm4vVisionPatchMerger(
+             dim=config.out_hidden_size, context_dim=config.intermediate_size
+         )
+
+         self.post_conv_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.downsample = nn.Conv2d(
+             in_channels=config.hidden_size,
+             out_channels=config.out_hidden_size,
+             kernel_size=config.spatial_merge_size,
+             stride=config.spatial_merge_size,
+         )
+         self.post_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def rot_pos_emb(self, grid_thw):
+         pos_ids = []
+
+         for t, h, w in grid_thw.tolist():
+             hpos_ids = mx.expand_dims(mx.arange(h), 1)
+             hpos_ids = mx.repeat(hpos_ids, w, axis=1)
+             hpos_ids = hpos_ids.reshape(
+                 h // self.spatial_merge_size,
+                 self.spatial_merge_size,
+                 w // self.spatial_merge_size,
+                 self.spatial_merge_size,
+             )
+             hpos_ids = mx.transpose(hpos_ids, (0, 2, 1, 3))
+             hpos_ids = hpos_ids.flatten()
+
+             wpos_ids = mx.expand_dims(mx.arange(w), 0)
+             wpos_ids = mx.repeat(wpos_ids, h, axis=0)
+             wpos_ids = wpos_ids.reshape(
+                 h // self.spatial_merge_size,
+                 self.spatial_merge_size,
+                 w // self.spatial_merge_size,
+                 self.spatial_merge_size,
+             )
+             wpos_ids = mx.transpose(wpos_ids, (0, 2, 1, 3))
+             wpos_ids = wpos_ids.flatten()
+
+             stacked_pos_ids = mx.stack([hpos_ids, wpos_ids], axis=-1)
+             pos_ids.append(mx.tile(stacked_pos_ids, (t, 1)))
+
+         pos_ids = mx.concatenate(pos_ids, axis=0)
+         max_grid_size = mx.max(grid_thw[:, 1:])
+         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+         rotary_pos_emb = rotary_pos_emb_full[pos_ids]
+
+         return rotary_pos_emb.reshape(pos_ids.shape[0], -1), pos_ids
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         grid_thw: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+
+         hidden_states = self.patch_embed(hidden_states)
+         hidden_states = self.post_conv_layernorm(hidden_states)
+         rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+
+         seq_lens = grid_thw[:, 1] * grid_thw[:, 2]
+         repeats = grid_thw[:, 0]
+         repeated_values = []
+         for i, (seq_len, repeat_count) in enumerate(
+             zip(seq_lens.tolist(), repeats.tolist())
+         ):
+             repeated_values.extend([seq_len] * repeat_count)
+
+         cu_seqlens = mx.array(repeated_values).cumsum(axis=0)
+         cu_seqlens = mx.pad(cu_seqlens, (1, 0), constant_values=0)
+         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+         hidden_states = self.embeddings(
+             hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
+         )
+
+         for blk in self.blocks:
+             hidden_states = blk(
+                 hidden_states,
+                 cu_seqlens=cu_seqlens,
+                 rotary_pos_emb=rotary_pos_emb,
+             )
+
+         hidden_states = self.post_layernorm(hidden_states)
+
+         hidden_states = hidden_states.reshape(
+             -1,
+             self.spatial_merge_size,
+             self.spatial_merge_size,
+             hidden_states.shape[-1],
+         )
+         hidden_states = self.downsample(hidden_states).reshape(
+             -1, self.config.out_hidden_size
+         )
+
+         hidden_states = self.merger(hidden_states)
+         return hidden_states
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             elif "patch_embed.proj.weight" in k or "downsample.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     if v.ndim == 5:
+                         sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
+                     if v.ndim == 4:
+                         sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
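The sanitize step above only reorders convolution weight axes from the PyTorch layout [out_channels, in_channels, kH, kW] to the MLX layout [out_channels, kH, kW, in_channels]. A minimal sketch of the 4-D branch in isolation, with a hypothetical weight shape (the dimensions are illustrative, not the model's actual sizes):

# Illustrative only: the axis reordering applied to a 4-D conv weight.
import mlx.core as mx

w_torch_layout = mx.zeros((1536, 3, 14, 14))          # hypothetical [O, I, kH, kW]
w_mlx_layout = w_torch_layout.transpose(0, 2, 3, 1)   # -> [O, kH, kW, I], as in the 4-D branch
assert tuple(w_mlx_layout.shape) == (1536, 14, 14, 3)

check_array_shape is a heuristic: if the weight already looks channels-last (square kernel smaller than out_channels, or a trailing temporal dimension of 3), it is passed through unchanged instead of being transposed twice.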
@@ -0,0 +1,3 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .glm4v_moe import LanguageModel, Model, VisionModel
+ from .processing import Glm46VMoEProcessor
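Since top_level.txt installs the code as mlx_vlm (note generate.py and prompt_utils.py in the file list above), a converted checkpoint that registers these GLM-4V classes would normally be driven through the package's high-level entry points. A minimal usage sketch, assuming this build keeps the conventional mlx-vlm load/generate API; the model path and image file are placeholders, and keyword names may differ between versions:

# Illustrative only: high-level inference with the installed mlx_vlm package.
# The checkpoint path, image file, and generate() keyword arguments are
# assumptions about this build's API, not taken from the package metadata.
from mlx_vlm import load, generate

model, processor = load("path/to/glm4v-mlx-checkpoint")   # hypothetical local path
output = generate(
    model,
    processor,
    prompt="Describe this image.",
    image=["example.jpg"],       # hypothetical image file
    max_tokens=256,
    verbose=True,
)
print(output)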