fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,522 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import ModelConfig, TextConfig
+
+
+ class PaddleOCRRotaryEmbedding:
+     def __init__(self, dim, max_position_embeddings=8192, base=500000):
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+
+         inv_freq = 1.0 / (
+             self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+         )
+         self.inv_freq = inv_freq
+         self.attention_scaling = 1.0
+
+     def __call__(self, x, position_ids):
+         inv_freq_expanded = mx.broadcast_to(
+             self.inv_freq[None, None, :, None].astype(mx.float32),
+             (3, position_ids.shape[1], self.inv_freq.shape[0], 1),
+         )
+         position_ids_expanded = position_ids[:, :, None, :].astype(
+             mx.float32
+         )  # shape (3, bs, 1, positions)
+
+         freqs = inv_freq_expanded @ position_ids_expanded
+         freqs = mx.swapaxes(freqs, 2, 3)
+         emb = mx.concatenate([freqs, freqs], axis=-1)
+         cos = mx.cos(emb) * self.attention_scaling
+         sin = mx.sin(emb) * self.attention_scaling
+
+         return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section):
+     """
+     Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
+     Args:
+         q (mx.array): The query tensor.
+         k (mx.array): The key tensor.
+         cos (mx.array): The cosine part of the rotary embedding.
+         sin (mx.array): The sine part of the rotary embedding.
+         mrope_section (List[int]): Multimodal rope section for channel dimension of temporal, height and width.
+     Returns:
+         tuple(mx.array): The rotated query and key tensors.
+     """
+     mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()
+
+     cos = mx.concatenate(
+         [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))], axis=-1
+     )[
+         :, None, :, :
+     ]  # unsqueeze dim 1
+     sin = mx.concatenate(
+         [m[i % 3] for i, m in enumerate(mx.split(sin, mrope_section, axis=-1))], axis=-1
+     )[:, None, :, :]
+
+     rotary_dim = cos.shape[-1]
+     q_rot = q[..., :rotary_dim]
+     q_pass = q[..., rotary_dim:]
+
+     k_rot = k[..., :rotary_dim]
+     k_pass = k[..., rotary_dim:]
+
+     q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+     k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+     q_embed = mx.concatenate([q_embed, q_pass], axis=-1)
+     k_embed = mx.concatenate([k_embed, k_pass], axis=-1)
+
+     return q_embed, k_embed
+
+
+ class Attention(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+
+         dim = args.hidden_size
+         self.n_heads = n_heads = args.num_attention_heads
+         assert args.num_key_value_heads is not None
+         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+         self.head_dim = head_dim = getattr(
+             args, "head_dim", args.hidden_size // n_heads
+         )
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.use_bias)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.use_bias)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.use_bias)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.use_bias)
+
+         self.rope_parameters = args.rope_parameters or args.rope_scaling
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_embeddings: Optional[mx.array] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+
+         cos, sin = position_embeddings
+
+         queries, keys = apply_multimodal_rotary_pos_emb(
+             queries, keys, cos, sin, self.rope_parameters["mrope_section"]
+         )
+
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         if mask is not None and isinstance(mask, mx.array):
+             mask = mask[..., : keys.shape[-2]]
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class PaddleOCRDecoderLayer(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.num_attention_heads = args.num_attention_heads
+         self.hidden_size = args.hidden_size
+         self.self_attn = Attention(args)
+         self.mlp = MLP(args.hidden_size, args.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             args.hidden_size, eps=args.rms_norm_eps
+         )
+         self.args = args
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_embeddings: Optional[mx.array] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache, position_embeddings)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class PaddleOCRModel(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.args = args
+         self.vocab_size = args.vocab_size
+         self.num_hidden_layers = args.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+         self.layers = [
+             PaddleOCRDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+         self.rotary_emb = PaddleOCRRotaryEmbedding(
+             args.head_dim,
+             max_position_embeddings=args.max_position_embeddings,
+             base=args.rope_theta,
+         )
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         position_ids: Optional[mx.array] = None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if position_ids is None:
+             position_ids = mx.arange(cache[0].offset, cache[0].offset + h.shape[-2])
+             position_ids = mx.expand_dims(position_ids, axis=0)
+             position_ids = mx.tile(position_ids, (3, 1, 1))
+
+         position_embeddings = self.rotary_emb(h, position_ids)
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c, position_embeddings)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, args: TextConfig, config: ModelConfig):
+         super().__init__()
+         self.args = args
+         self.config = config
+         self.model_type = args.model_type
+         self.model = PaddleOCRModel(args)
+         self._rope_deltas = None
+         self._position_ids = None
+         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+     def get_rope_index(
+         self,
+         input_ids: mx.array,
+         image_grid_thw: Optional[mx.array] = None,
+         video_grid_thw: Optional[mx.array] = None,
+         attention_mask: Optional[mx.array] = None,
+     ):
+         # Calculate RoPE index for image/video tokens
+         batch_size, seq_length = input_ids.shape
+         position_ids = mx.arange(seq_length, dtype=mx.int32)
+         position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+         spatial_merge_size = self.config.vision_config.spatial_merge_size
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+         mrope_position_deltas = []
+         if input_ids is not None and (
+             image_grid_thw is not None or video_grid_thw is not None
+         ):
+             total_input_ids = input_ids
+             if attention_mask is None:
+                 attention_mask = mx.ones_like(input_ids)
+             position_ids = mx.ones(
+                 (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+             )
+             image_index, video_index = 0, 0
+             for i, input_ids in enumerate(total_input_ids):
+                 input_ids = mx.where(
+                     attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                 )
+                 image_nums, video_nums = 0, 0
+                 vision_start_indices = mx.sum(
+                     mx.where(
+                         input_ids == vision_start_token_id,
+                         mx.arange(input_ids.shape[0]),
+                         mx.zeros_like(input_ids),
+                     )
+                 )
+                 vision_tokens = input_ids[vision_start_indices + 1]
+                 image_nums = (vision_tokens == image_token_id).sum().item()
+                 video_nums = (vision_tokens == video_token_id).sum().item()
+                 input_tokens = input_ids.tolist()
+                 llm_pos_ids_list: list = []
+                 st = 0
+                 remain_images, remain_videos = image_nums, video_nums
+                 for _ in range(image_nums + video_nums):
+                     if image_token_id in input_tokens and remain_images > 0:
+                         ed_image = input_tokens.index(image_token_id, st)
+                     else:
+                         ed_image = len(input_tokens) + 1
+                     if video_token_id in input_tokens and remain_videos > 0:
+                         ed_video = input_tokens.index(video_token_id, st)
+                     else:
+                         ed_video = len(input_tokens) + 1
+                     if ed_image < ed_video:
+                         t, h, w = (
+                             image_grid_thw[image_index][0],
+                             image_grid_thw[image_index][1],
+                             image_grid_thw[image_index][2],
+                         )
+                         image_index += 1
+                         remain_images -= 1
+                         ed = ed_image
+                     else:
+                         t, h, w = (
+                             video_grid_thw[video_index][0],
+                             video_grid_thw[video_index][1],
+                             video_grid_thw[video_index][2],
+                         )
+                         video_index += 1
+                         remain_videos -= 1
+                         ed = ed_video
+                     llm_grid_t, llm_grid_h, llm_grid_w = (
+                         t.item(),
+                         h.item() // spatial_merge_size,
+                         w.item() // spatial_merge_size,
+                     )
+                     text_len = ed - st
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     index = mx.arange(text_len).reshape(1, text_len)
+                     index = mx.broadcast_to(index, (3, text_len))
+                     index = index + st_idx
+                     llm_pos_ids_list.append(index)
+                     t_index = mx.arange(llm_grid_t).reshape(
+                         llm_grid_t, 1
+                     )  # Equivalent to .view(-1, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                     )  # Equivalent to expand()
+                     t_index = t_index.flatten()  # Flattens to 1D
+
+                     h_index = mx.arange(llm_grid_h).reshape(
+                         1, llm_grid_h, 1
+                     )  # Equivalent to .view(1, -1)
+                     h_index = mx.broadcast_to(
+                         h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )  # Equivalent to expand()
+                     h_index = h_index.flatten()  # Flattens to 1D
+
+                     w_index = mx.arange(llm_grid_w).reshape(
+                         1, 1, llm_grid_w
+                     )  # Equivalent to .view(1, -1)
+                     w_index = mx.broadcast_to(
+                         w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )  # Equivalent to expand()
+                     w_index = w_index.flatten()  # Flattens to 1D
+
+                     llm_pos_ids_list.append(
+                         mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                     )
+                     st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                 if st < len(input_tokens):
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     text_len = len(input_tokens) - st
+
+                     t_index = mx.arange(text_len).reshape(
+                         1, text_len
+                     )  # Equivalent to .view(-1, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (3, text_len)
+                     )  # Equivalent to expand(3, -1)
+
+                     llm_pos_ids_list.append(t_index + st_idx)
+
+                 llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                 mask = mx.array(attention_mask[i] == 1)
+                 expanded_mask = mx.expand_dims(mask, axis=0)
+                 expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                 expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                 new_positions = mx.where(
+                     expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                 )
+                 updated_position_ids = mx.concatenate(
+                     [
+                         position_ids[:, :i, :],
+                         new_positions,
+                         position_ids[:, i + 1 :, :],
+                     ],
+                     axis=1,
+                 )
+                 position_ids = updated_position_ids
+                 mrope_position_deltas.append(
+                     llm_positions.max() + 1 - len(total_input_ids[i])
+                 )
+             mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+             return position_ids, mrope_position_deltas
+         else:
+             if attention_mask is not None:
+                 position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                 position_ids = mx.where(
+                     attention_mask == 0, mx.ones_like(position_ids), position_ids
+                 )
+                 position_ids = mx.expand_dims(position_ids[0], axis=0)
+                 position_ids = mx.tile(position_ids, (3, 1, 1))
+                 max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                     -1, keepdims=True
+                 )[0]
+                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+             else:
+                 position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                 )
+                 mrope_position_deltas = mx.zeros(
+                     [input_ids.shape[0], 1],
+                     dtype=input_ids.dtype,
+                 )
+             return position_ids, mrope_position_deltas
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+
+         position_ids = kwargs.pop("position_ids", None)
+         pixel_values = kwargs.pop("pixel_values", None)
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+         video_grid_thw = kwargs.pop("video_grid_thw", None)
+         # reset rope_deltas and position_ids when processing a new image/video
+         if pixel_values is not None:
+             self._rope_deltas = None
+             self._position_ids = None
+
+         cache_offset = 0
+         if cache and cache[0] is not None:
+             offset = cache[0].offset
+             if isinstance(offset, int):
+                 cache_offset = offset
+             elif isinstance(offset, mx.array):
+                 cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+             else:
+                 raise ValueError(f"Unexpected cache offset type: {type(offset)}")
+
+         # Check if mask shape matches input shape
+         rope_mask = mask
+         if mask is not None and mask.shape[-1] != inputs.shape[-1]:
+             rope_mask = None
+
+         if position_ids is None and (rope_mask is None or rope_mask.ndim == 2):
+             # Calculate RoPE index once per generation in the pre-fill stage only
+             if (
+                 (cache is not None and cache[0] is not None and (cache_offset == 0))
+                 or self._rope_deltas is None
+                 or cache is None
+             ):
+                 # Check if we have stored position_ids from chunked prefill
+                 if self._position_ids is not None:
+                     # Use stored position_ids, sliced for current chunk
+                     seq_length = inputs.shape[1]
+                     position_ids = self._position_ids[
+                         :, :, cache_offset : cache_offset + seq_length
+                     ]
+                 else:
+                     position_ids, rope_deltas = self.get_rope_index(
+                         inputs, image_grid_thw, video_grid_thw, rope_mask
+                     )
+                     self._rope_deltas = rope_deltas
+                     # Store full position_ids for chunked prefill
+                     self._position_ids = position_ids
+             else:
+                 # Use the prev pre-calculated rope-deltas to get the correct position ids
+                 batch_size, seq_length = inputs.shape
+                 delta = mx.array(
+                     cache_offset + self._rope_deltas if cache is not None else 0
+                 )
+                 position_ids = mx.arange(seq_length).reshape(1, -1)
+                 position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+
+                 if cache_offset is not None:
+                     if delta.ndim == 0:
+                         delta = mx.expand_dims(delta, axis=0)
+
+                     if delta.shape[0] < batch_size:
+                         delta = mx.tile(delta, (batch_size, 1))
+                     else:
+                         # Slice delta to match batch
+                         delta = delta[:batch_size]
+
+                 position_ids = mx.add(position_ids, delta)[None, ...]
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, batch_size, seq_length)
+                 )
+
+         out = self.model(
+             inputs, cache=cache, inputs_embeds=inputs_embeds, position_ids=position_ids
+         )
+         out = self.lm_head(out)
+         return LanguageModelOutput(logits=out)
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.args.hidden_size // self.args.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.args.num_key_value_heads