fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/glm_ocr/language.py
@@ -0,0 +1,585 @@
+ from typing import Any, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from .config import ModelConfig, TextConfig
+
+
+ def _compute_default_rope_parameters(
+     config: Optional[TextConfig] = None,
+     **rope_kwargs,
+ ) -> tuple[mx.array, float]:
+
+     if config is not None and len(rope_kwargs) > 0:
+         raise ValueError(
+             "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+             f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
+         )
+     if len(rope_kwargs) > 0:
+         base = rope_kwargs["base"]
+         dim = rope_kwargs["dim"]
+     elif config is not None:
+         base = config.rope_theta
+         partial_rotary_factor = config.partial_rotary_factor
+         head_dim = config.head_dim
+         dim = int(head_dim * partial_rotary_factor)
+
+     attention_factor = 1.0
+
+     inv_freq = 1.0 / (
+         base ** (mx.arange(0, dim, 2, dtype=mx.int64).astype(mx.float32) / dim)
+     )
+     return inv_freq, attention_factor
+
+
+ class GlmOcrRotaryEmbedding(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+
+         self.rope_type = config.rope_parameters.get("rope_type", "default")
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.mrope_section = config.rope_parameters.get("mrope_section", [16, 24, 24])
+
+         self.rope_init_fn = _compute_default_rope_parameters
+
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config)
+         self._inv_freq = mx.array(inv_freq, dtype=mx.float32)
+         self._original_inv_freq = mx.array(inv_freq, dtype=mx.float32)
+
+     def apply_mrope(self, freqs, mrope_section):
+         """Apply M-RoPE by selecting different dimensions for T, H, W."""
+         split_indices = np.cumsum(mrope_section)[:-1].tolist()
+         chunks = mx.split(freqs, split_indices, axis=-1)
+         result = mx.concatenate(
+             [chunk[i % 3] for i, chunk in enumerate(chunks)], axis=-1
+         )
+         return result
+
+     def __call__(self, x, position_ids):
+         inv_freq_expanded = self._inv_freq[None, None, :, None].astype(mx.float32)
+         inv_freq_expanded = mx.broadcast_to(
+             inv_freq_expanded, (3, position_ids.shape[1], self._inv_freq.shape[0], 1)
+         )
+         position_ids_expanded = position_ids[:, :, None, :].astype(mx.float32)
+
+         freqs = (
+             inv_freq_expanded.astype(mx.float32)
+             @ position_ids_expanded.astype(mx.float32)
+         ).transpose(0, 1, 3, 2)
+
+         freqs = self.apply_mrope(freqs, self.mrope_section)
+
+         emb = mx.concatenate((freqs, freqs), axis=-1)
+         cos = mx.cos(emb) * self.attention_scaling
+         sin = mx.sin(emb) * self.attention_scaling
+
+         return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
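The `apply_mrope` helper above splits the rotary frequency channels between the temporal, height, and width position streams. As a quick illustration (a standalone sketch, not part of the packaged file), the default `mrope_section=[16, 24, 24]` splits 64 channels like this:

```python
# Standalone sketch of the channel split performed by apply_mrope.
import numpy as np

mrope_section = [16, 24, 24]
split_indices = np.cumsum(mrope_section)[:-1].tolist()
print(split_indices)  # [16, 40]: channels 0-15 use T, 16-39 use H, 40-63 use W
```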
+ def rotate_half_llm(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., 0::2]
+     x2 = x[..., 1::2]
+     return mx.flatten(mx.stack([-x2, x1], axis=-1), start_axis=-2, end_axis=-1)
+
+
+ def repeat_interleave(x, repeats, axis=-1):
+     """
+     Repeat elements of an array along an axis, interleaving the repeated values.
+     Like torch.repeat_interleave: [a,b,c] with repeats=2 -> [a,a,b,b,c,c]
+     """
+     shape = list(x.shape)
+     x = mx.expand_dims(x, axis=axis + 1 if axis >= 0 else axis)
+     tile_shape = [1] * len(x.shape)
+     tile_shape[axis + 1 if axis >= 0 else axis] = repeats
+     x = mx.tile(x, tile_shape)
+     new_shape = shape.copy()
+     new_shape[axis] = shape[axis] * repeats
+     return x.reshape(new_shape)
+
+
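A standalone sketch (not part of the packaged file) of the interleaving that `repeat_interleave` performs, using the same MLX ops:

```python
# Standalone sketch of the interleaved repeat along the last axis.
import mlx.core as mx

x = mx.array([1, 2, 3])
x = mx.expand_dims(x, axis=-1)   # (3, 1)
x = mx.tile(x, [1, 2])           # (3, 2): [[1, 1], [2, 2], [3, 3]]
print(x.reshape(6))              # [1, 1, 2, 2, 3, 3]
```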
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     """
+     Applies Rotary Position Embedding to the query and key tensors.
+     Matches PyTorch's GLM-OCR implementation exactly.
+
+     Args:
+         q: Query tensor of shape (batch, n_heads, seq_len, head_dim)
+         k: Key tensor of shape (batch, n_kv_heads, seq_len, head_dim)
+         cos: Cosine tensor of shape (batch, seq_len, head_dim)
+         sin: Sine tensor of shape (batch, seq_len, head_dim)
+     """
+     cos = cos[:, None, :, :]
+     sin = sin[:, None, :, :]
+
+     cos = repeat_interleave(cos[..., : cos.shape[-1] // 2], repeats=2, axis=-1)
+     sin = repeat_interleave(sin[..., : sin.shape[-1] // 2], repeats=2, axis=-1)
+
+     rotary_dim = cos.shape[-1]
+     q_rot = q[..., :rotary_dim]
+     q_pass = q[..., rotary_dim:]
+
+     k_rot = k[..., :rotary_dim]
+     k_pass = k[..., rotary_dim:]
+
+     q_embed = (q_rot * cos) + (rotate_half_llm(q_rot) * sin)
+     k_embed = (k_rot * cos) + (rotate_half_llm(k_rot) * sin)
+
+     q_embed = mx.concatenate([q_embed, q_pass], axis=-1)
+     k_embed = mx.concatenate([k_embed, k_pass], axis=-1)
+
+     return q_embed, k_embed
+
+
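To see why `apply_rotary_pos_emb` pairs `cos`/`sin` with `rotate_half_llm`, here is a standalone sketch (not part of the packaged file) showing that, for a single (even, odd) channel pair, `x * cos + rotate_half_llm(x) * sin` is an ordinary 2-D rotation:

```python
# Standalone sketch: one (even, odd) channel pair rotated by theta.
import mlx.core as mx

theta = 0.3
cos, sin = mx.cos(mx.array(theta)), mx.sin(mx.array(theta))
x = mx.array([1.0, 2.0])              # (x_even, x_odd)
rot_half = mx.stack([-x[1], x[0]])    # what rotate_half_llm produces for this pair
out = x * cos + rot_half * sin        # [x0*cos - x1*sin, x1*cos + x0*sin]
print(out)
```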
+ class GlmOcrAttention(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+
+         dim = args.hidden_size
+         self.n_heads = n_heads = args.num_attention_heads
+         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+         self.head_dim = args.head_dim
+         self.scale = self.head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=args.attention_bias)
+         self.k_proj = nn.Linear(
+             dim, n_kv_heads * self.head_dim, bias=args.attention_bias
+         )
+         self.v_proj = nn.Linear(
+             dim, n_kv_heads * self.head_dim, bias=args.attention_bias
+         )
+         self.o_proj = nn.Linear(n_heads * self.head_dim, dim, bias=False)
+
+         self.rope_parameters = args.rope_parameters
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+         position_embeddings: Optional[mx.array] = None,
+     ) -> mx.array:
+         B, L, _ = x.shape
+
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         queries = queries.reshape(B, L, self.n_heads, -1)
+         keys = keys.reshape(B, L, self.n_kv_heads, -1)
+
+         queries = queries.transpose(0, 2, 1, 3)
+         keys = keys.transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         cos, sin = position_embeddings
+
+         queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         if mask is not None and isinstance(mask, mx.array):
+             mask = mask[..., : keys.shape[-2]]
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache=cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class GlmOcrMLP(nn.Module):
+     def __init__(
+         self, config: TextConfig, hidden_size: int = None, intermediate_size: int = None
+     ):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+         self.intermediate_size = (
+             config.intermediate_size if intermediate_size is None else intermediate_size
+         )
+
+         self.gate_up_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size * 2, bias=False
+         )
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+     def __call__(self, x):
+         x = self.gate_up_proj(x)
+         gate, x = mx.split(x, 2, axis=-1)
+         return self.down_proj(nn.silu(gate) * x)
+
+
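A standalone sketch (not part of the packaged file) of the fused gate/up projection used by `GlmOcrMLP`: a single linear layer produces both halves, which are split and combined as `silu(gate) * up` before the down projection.

```python
# Standalone sketch of the fused gate/up projection; layer names here are
# illustrative and unrelated to the packaged module.
import mlx.core as mx
import mlx.nn as nn

hidden, inter = 8, 16
gate_up = nn.Linear(hidden, inter * 2, bias=False)   # one matmul produces both halves
down = nn.Linear(inter, hidden, bias=False)

x = mx.random.normal((1, 4, hidden))
gate, up = mx.split(gate_up(x), 2, axis=-1)           # each (1, 4, inter)
y = down(nn.silu(gate) * up)                          # SwiGLU-style combination
print(y.shape)                                        # (1, 4, 8)
```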
+ class GlmOcrDecoderLayer(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.self_attn = GlmOcrAttention(config)
+         self.mlp = GlmOcrMLP(config)
+
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_self_attn_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_mlp_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+         position_embeddings: Optional[mx.array] = None,
+     ) -> mx.array:
+         r = x
+
+         x = self.self_attn(self.input_layernorm(x), mask, cache, position_embeddings)
+
+         x = self.post_self_attn_layernorm(x)
+         x = r + x
+
+         r = x
+         x = self.post_attention_layernorm(x)
+         x = self.mlp(x)
+         x = self.post_mlp_layernorm(x)
+         x = r + x
+         return x
+
+
+ class GlmOcrTextModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             GlmOcrDecoderLayer(config) for _ in range(config.num_hidden_layers)
+         ]
+         self.start_idx = 0
+         self.end_idx = len(self.layers)
+         self.num_layers = self.end_idx
+
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.rotary_emb = GlmOcrRotaryEmbedding(config)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+         mask: Optional[mx.array] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds.astype(self.norm.weight.dtype)
+
+         if position_ids is None:
+             position_ids = mx.arange(cache[0].offset, cache[0].offset + h.shape[-2])
+             position_ids = mx.expand_dims(position_ids, axis=0)
+             position_ids = mx.tile(position_ids, (3, 1, 1))
+
+         position_embeddings = self.rotary_emb(h, position_ids)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         if cache is None:
+             cache = [None] * self.num_layers
+
+         for i in range(self.num_layers):
+             h = self.layers[self.start_idx + i](h, mask, cache[i], position_embeddings)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, args: TextConfig, config: ModelConfig = None):
+         super().__init__()
+         self.args = args
+         self.config = config
+         self.model_type = args.model_type
+         self.model = GlmOcrTextModel(args)
+         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+         self._rope_deltas = None
+         self._position_ids = None
+
+     def get_rope_index(
+         self,
+         input_ids: mx.array,
+         image_grid_thw: Optional[mx.array] = None,
+         video_grid_thw: Optional[mx.array] = None,
+         attention_mask: Optional[mx.array] = None,
+     ):
+         batch_size, seq_length = input_ids.shape
+         position_ids = mx.arange(seq_length, dtype=mx.int32)
+         position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+         spatial_merge_size = self.config.vision_config.spatial_merge_size
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         image_start_token_id = self.config.image_start_token_id
+         mrope_position_deltas = []
+         if input_ids is not None and (
+             image_grid_thw is not None or video_grid_thw is not None
+         ):
+             total_input_ids = input_ids
+             if (
+                 attention_mask is None
+                 or attention_mask.shape[-1] != input_ids.shape[-1]
+             ):
+                 attention_mask = mx.ones_like(input_ids)
+             position_ids = mx.ones(
+                 (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+             )
+             image_index, video_index = 0, 0
+             for i, input_ids in enumerate(total_input_ids):
+                 input_ids = mx.where(
+                     attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                 )
+                 image_nums, video_nums = 0, 0
+                 vision_start_indices = mx.sum(
+                     mx.where(
+                         input_ids == image_start_token_id,
+                         mx.arange(input_ids.shape[0]),
+                         mx.zeros_like(input_ids),
+                     )
+                 )
+                 vision_tokens = input_ids[vision_start_indices + 1]
+                 image_nums = (vision_tokens == image_token_id).sum().item()
+                 video_nums = (vision_tokens == video_token_id).sum().item()
+                 input_tokens = input_ids.tolist()
+                 llm_pos_ids_list: list = []
+                 st = 0
+                 remain_images, remain_videos = image_nums, video_nums
+                 for _ in range(image_nums + video_nums):
+                     if image_token_id in input_tokens and remain_images > 0:
+                         ed_image = input_tokens.index(image_token_id, st)
+                     else:
+                         ed_image = len(input_tokens) + 1
+                     if video_token_id in input_tokens and remain_videos > 0:
+                         ed_video = input_tokens.index(video_token_id, st)
+                     else:
+                         ed_video = len(input_tokens) + 1
+                     if ed_image < ed_video:
+                         t, h, w = (
+                             image_grid_thw[image_index][0],
+                             image_grid_thw[image_index][1],
+                             image_grid_thw[image_index][2],
+                         )
+                         image_index += 1
+                         remain_images -= 1
+                         ed = ed_image
+                     else:
+                         t, h, w = (
+                             video_grid_thw[video_index][0],
+                             video_grid_thw[video_index][1],
+                             video_grid_thw[video_index][2],
+                         )
+                         video_index += 1
+                         remain_videos -= 1
+                         ed = ed_video
+                     llm_grid_t, llm_grid_h, llm_grid_w = (
+                         t.item(),
+                         h.item() // spatial_merge_size,
+                         w.item() // spatial_merge_size,
+                     )
+                     text_len = ed - st
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     index = mx.arange(text_len).reshape(1, text_len)
+                     index = mx.broadcast_to(index, (3, text_len))
+                     index = index + st_idx
+                     llm_pos_ids_list.append(index)
+                     t_index = mx.arange(llm_grid_t).reshape(llm_grid_t, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                     )
+                     t_index = t_index.flatten()
+
+                     h_index = mx.arange(llm_grid_h).reshape(1, llm_grid_h, 1)
+                     h_index = mx.broadcast_to(
+                         h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )
+                     h_index = h_index.flatten()
+
+                     w_index = mx.arange(llm_grid_w).reshape(1, 1, llm_grid_w)
+                     w_index = mx.broadcast_to(
+                         w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )
+                     w_index = w_index.flatten()
+
+                     llm_pos_ids_list.append(
+                         mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                     )
+                     st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                 if st < len(input_tokens):
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     text_len = len(input_tokens) - st
+
+                     t_index = mx.arange(text_len).reshape(1, text_len)
+                     t_index = mx.broadcast_to(t_index, (3, text_len))
+
+                     llm_pos_ids_list.append(t_index + st_idx)
+
+                 llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                 mask = mx.array(attention_mask[i] == 1)
+                 expanded_mask = mx.expand_dims(mask, axis=0)
+                 expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                 expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                 new_positions = mx.where(
+                     expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                 )
+                 updated_position_ids = mx.concatenate(
+                     [
+                         position_ids[:, :i, :],
+                         new_positions,
+                         position_ids[:, i + 1 :, :],
+                     ],
+                     axis=1,
+                 )
+                 position_ids = updated_position_ids
+                 mrope_position_deltas.append(
+                     llm_positions.max() + 1 - len(total_input_ids[i])
+                 )
+             mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+             return position_ids, mrope_position_deltas
+         else:
+             if attention_mask is not None:
+                 position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                 position_ids = mx.where(
+                     attention_mask == 0, mx.ones_like(position_ids), position_ids
+                 )
+                 position_ids = mx.expand_dims(position_ids[0], axis=0)
+                 position_ids = mx.tile(position_ids, (3, 1, 1))
+                 max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                     -1, keepdims=True
+                 )[0]
+                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+             else:
+                 position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                 )
+                 mrope_position_deltas = mx.zeros(
+                     [input_ids.shape[0], 1],
+                     dtype=input_ids.dtype,
+                 )
+             return position_ids, mrope_position_deltas
+
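A standalone sketch (not part of the packaged file) of the position layout `get_rope_index` produces for a toy sequence of 4 text tokens followed by one image with a merged 1x2x2 grid: text tokens advance all three streams together, while image tokens reuse the text offset and index the grid along T, H, and W.

```python
# Standalone sketch of the M-RoPE position layout; values mirror what
# get_rope_index builds, but nothing here touches the packaged classes.
import mlx.core as mx

text_len, t, h, w = 4, 1, 2, 2   # 4 text tokens, then a 1x2x2 merged image grid
text = mx.broadcast_to(mx.arange(text_len).reshape(1, text_len), (3, text_len))
t_idx = mx.broadcast_to(mx.arange(t).reshape(t, 1), (t, h * w)).flatten()
h_idx = mx.broadcast_to(mx.arange(h).reshape(1, h, 1), (t, h, w)).flatten()
w_idx = mx.broadcast_to(mx.arange(w).reshape(1, 1, w), (t, h, w)).flatten()
image = mx.stack([t_idx, h_idx, w_idx]) + text_len
print(mx.concatenate([text, image], axis=1))
# T stream: [0 1 2 3 4 4 4 4]
# H stream: [0 1 2 3 4 4 5 5]
# W stream: [0 1 2 3 4 5 4 5]
```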
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+
+         position_ids = kwargs.pop("position_ids", None)
+         pixel_values = kwargs.pop("pixel_values", None)
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+         video_grid_thw = kwargs.pop("video_grid_thw", None)
+         if pixel_values is not None:
+             self._rope_deltas = None
+
+         cache_offset = 0
+         if cache and cache[0] is not None:
+             offset = cache[0].offset
+             if isinstance(offset, int):
+                 cache_offset = offset
+             elif isinstance(offset, mx.array):
+                 cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+             else:
+                 raise ValueError(f"Unexpected cache offset type: {type(offset)}")
+
+         # Check if mask shape matches input shape (for chunked prefill compatibility)
+         rope_mask = mask
+         if mask is not None and mask.shape[-1] != inputs.shape[-1]:
+             rope_mask = None
+
+         if position_ids is None and (rope_mask is None or rope_mask.ndim == 2):
+             # Calculate RoPE index once per generation in the pre-fill stage only
+             if (
+                 (cache is not None and cache[0] is not None and (cache_offset == 0))
+                 or self._rope_deltas is None
+                 or cache is None
+             ):
+                 # Use cached position_ids if available (pre-computed in get_input_embeddings)
+                 if self._position_ids is not None:
+                     seq_length = inputs.shape[1]
+                     position_ids = self._position_ids[
+                         :, :, cache_offset : cache_offset + seq_length
+                     ]
+                 else:
+                     position_ids, rope_deltas = self.get_rope_index(
+                         inputs, image_grid_thw, video_grid_thw, rope_mask
+                     )
+                     self._rope_deltas = rope_deltas
+                     self._position_ids = position_ids
+             else:
+                 # Use the prev pre-calculated rope-deltas to get the correct position ids
+                 batch_size, seq_length = inputs.shape
+                 delta = mx.array(
+                     cache_offset + self._rope_deltas if cache is not None else 0
+                 )
+                 position_ids = mx.arange(seq_length).reshape(1, -1)
+                 position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+
+                 if cache_offset is not None:
+                     if delta.ndim == 0:
+                         delta = mx.expand_dims(delta, axis=0)
+
+                     if delta.shape[0] < batch_size:
+                         delta = mx.tile(delta, (batch_size, 1))
+                     else:
+                         delta = delta[:batch_size]
+
+                 position_ids = mx.add(position_ids, delta)[None, ...]
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, batch_size, seq_length)
+                 )
+
+         out = self.model(
+             inputs,
+             cache=cache,
+             inputs_embeds=inputs_embeds,
+             position_ids=position_ids,
+             mask=mask,
+         )
+
+         out = self.lm_head(out)
+         return LanguageModelOutput(logits=out)
+
+     def sanitize(self, weights):
+         return weights
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def n_kv_heads(self):
+         return self.args.num_key_value_heads
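For context, the wheel installs the `mlx_vlm` top-level package listed above. The sketch below follows upstream mlx-vlm's documented usage and is only an assumption here: the `load`, `load_config`, `apply_chat_template`, and `generate` names and signatures come from the upstream project, and the model path is hypothetical; this repackaged distribution may differ.

```python
# Assumed upstream mlx-vlm-style usage; not verified against this specific package.
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"  # hypothetical model path
model, processor = load(model_path)
config = load_config(model_path)

images = ["http://images.cocodataset.org/val2017/000000039769.jpg"]
prompt = apply_chat_template(processor, config, "Describe this image.", num_images=len(images))
output = generate(model, processor, prompt, images, verbose=False)
print(output)
```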