fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/qwen2_5_vl/language.py (new file)
@@ -0,0 +1,541 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import ModelConfig, TextConfig
+
+
+ class Qwen2RotaryEmbedding:
+     def __init__(
+         self, dim, max_position_embeddings=2048, base=10000, rope_scaling=None
+     ):
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+
+         inv_freq = 1.0 / (
+             self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+         )
+         self.inv_freq = inv_freq
+         self.attention_scaling = 1.0  # "default" rope type: no extra scaling
+
+         self.mrope_section = (rope_scaling or {}).get("mrope_section", [24, 20, 20])
+
+     def apply_mrope(self, freqs, mrope_section):
+         """Apply MRoPE to 3D rotary embeddings.
+         Reorganizes the frequency layout into chunked [TTT...HHH...WWW] sections.
+         args:
+             freqs: (3, bs, seq_len, head_dim // 2)
+             mrope_section: (3,)
+         returns:
+             freqs_t: (bs, seq_len, head_dim // 2)
+         """
+         freqs_t = freqs[0]  # start from the temporal (T) grid, then overwrite H/W sections
+         offset = mrope_section[0]
+         for dim, length in enumerate(mrope_section[1:], start=1):  # H, W
+             idx = slice(offset, offset + length)
+             freqs_t[..., idx] = freqs[dim, ..., idx]
+             offset += length
+         return freqs_t
+
+     def __call__(self, x, position_ids):
+
+         # In contrast to other models, Qwen2.5-VL has different position ids for the grids,
+         # so we expand the inv_freq to shape (3, ...)
+         if position_ids.ndim == 2:
+             position_ids = mx.broadcast_to(
+                 position_ids[None, ...],
+                 (3, position_ids.shape[0], position_ids.shape[1]),
+             )
+
+         inv_freq_expanded = mx.broadcast_to(
+             self.inv_freq[None, None, :, None].astype(mx.float32),
+             (3, position_ids.shape[1], self.inv_freq.shape[0], 1),
+         )
+         position_ids_expanded = position_ids[:, :, None, :].astype(
+             mx.float32
+         )  # shape (3, bs, 1, positions)
+
+         freqs = inv_freq_expanded @ position_ids_expanded
+         freqs = mx.swapaxes(freqs, 2, 3)
+         freqs = self.apply_mrope(freqs, self.mrope_section)
+         emb = mx.concatenate([freqs, freqs], axis=-1)
+         cos = mx.cos(emb) * self.attention_scaling
+         sin = mx.sin(emb) * self.attention_scaling
+
+         return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
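The mrope_section above (by default [24, 20, 20], which sums to 64 frequency channels, e.g. for a 128-dimensional head) controls how apply_mrope mixes the three position grids: the first 24 channels keep the temporal frequencies, the next 20 take the height grid, and the last 20 take the width grid. A minimal standalone sketch of that chunked layout, written against plain mlx and not part of the packaged file:

import mlx.core as mx

mrope_section = [24, 20, 20]   # T, H, W channel counts; sums to head_dim // 2
head_half = sum(mrope_section)

# One frequency grid per axis, filled with a constant so the mixing is visible:
# temporal grid = 0.0, height grid = 1.0, width grid = 2.0; shape (3, bs, seq, head_half).
freqs = mx.stack([mx.full((1, 4, head_half), v) for v in (0.0, 1.0, 2.0)])

def chunked_mrope(freqs, mrope_section):
    # Same chunked [T...H...W] layout as Qwen2RotaryEmbedding.apply_mrope above.
    out = freqs[0]
    offset = mrope_section[0]
    for axis, length in enumerate(mrope_section[1:], start=1):
        out[..., offset : offset + length] = freqs[axis, ..., offset : offset + length]
        offset += length
    return out

mixed = chunked_mrope(freqs, mrope_section)
# Channel 0 falls in the T section, 30 in the H section, 50 in the W section.
print(mixed[0, 0, 0].item(), mixed[0, 0, 30].item(), mixed[0, 0, 50].item())  # 0.0 1.0 2.0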
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+     """
+     Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
+     Args:
+         q (mx.array): The query tensor.
+         k (mx.array): The key tensor.
+         cos (mx.array): The cosine part of the rotary embedding.
+         sin (mx.array): The sine part of the rotary embedding.
+         unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.
+     Returns:
+         tuple(mx.array): The rotated query and key tensors.
+     """
+
+     cos = mx.expand_dims(cos, axis=unsqueeze_dim)
+     sin = mx.expand_dims(sin, axis=unsqueeze_dim)
+
+     # Apply rotary embedding
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+
+     return q_embed, k_embed
+
+
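For intuition, rotate_half pairs feature i with feature i + head_dim // 2, so q * cos + rotate_half(q) * sin rotates each such pair by its position-dependent angle; the multimodal variant only changes which position grid supplies that angle per channel. A small self-contained check of this identity for a single 4-dimensional head at position 3, using made-up numbers and a local copy of the helper rather than the package import:

import mlx.core as mx

def rotate_half(x):
    # Mirrors the helper above: swap the two halves and negate the second one.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return mx.concatenate([-x2, x1], axis=-1)

head_dim, pos = 4, 3.0
inv_freq = 1.0 / (10000 ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
angles = pos * inv_freq                               # one angle per (i, i + head_dim // 2) pair
cos = mx.concatenate([mx.cos(angles), mx.cos(angles)])
sin = mx.concatenate([mx.sin(angles), mx.sin(angles)])

q = mx.array([1.0, 2.0, 3.0, 4.0])
q_rot = q * cos + rotate_half(q) * sin

# The pair (q[0], q[2]) is rotated by angles[0]; compare with an explicit 2x2 rotation.
c, s = mx.cos(angles[0]), mx.sin(angles[0])
print(q_rot[0].item(), (q[0] * c - q[2] * s).item())  # same value twice
print(q_rot[2].item(), (q[0] * s + q[2] * c).item())  # same value twice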
+ class Attention(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+
+         dim = args.hidden_size
+         self.n_heads = n_heads = args.num_attention_heads
+         assert args.num_key_value_heads is not None
+         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+         self.head_dim = head_dim = args.hidden_size // n_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         self.rope_scaling = args.rope_scaling
+
+         self.rotary_emb = Qwen2RotaryEmbedding(
+             head_dim,
+             max_position_embeddings=args.max_position_embeddings,
+             base=args.rope_theta,
+             rope_scaling=args.rope_scaling,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+
+         kv_seq_len = keys.shape[-2]
+
+         if position_ids is None:
+             offset = cache.offset if cache is not None else 0
+             kv_seq_len += offset + 1
+             position_ids = mx.arange(offset, offset + L)
+             position_ids = mx.expand_dims(position_ids, axis=0)
+             position_ids = mx.tile(position_ids, (3, 1, 1))
+         else:
+             kv_seq_len += cache.offset + 1 if cache is not None else 0
+
+         cos, sin = self.rotary_emb(values, position_ids)
+
+         if mask is not None and isinstance(mask, mx.array):
+             mask = mask[..., : keys.shape[-2]]
+         queries, keys = apply_multimodal_rotary_pos_emb(
+             queries, keys, cos, sin, unsqueeze_dim=1
+         )
+
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Qwen2VLDecoderLayer(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.num_attention_heads = args.num_attention_heads
+         self.hidden_size = args.hidden_size
+         self.self_attn = Attention(args)
+         self.mlp = MLP(args.hidden_size, args.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             args.hidden_size, eps=args.rms_norm_eps
+         )
+         self.args = args
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class Qwen2Model(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.args = args
+         self.vocab_size = args.vocab_size
+         self.num_hidden_layers = args.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+         self.layers = [
+             Qwen2VLDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         position_ids: Optional[mx.array] = None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c, position_ids)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, args: TextConfig, config: ModelConfig):
+         super().__init__()
+         self.args = args
+         self.config = config
+         self.model_type = args.model_type
+         self.model = Qwen2Model(args)
+         self._rope_deltas = None
+         self._position_ids = None
+
+         if not args.tie_word_embeddings:
+             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+     def get_rope_index(
+         self,
+         input_ids: mx.array,
+         image_grid_thw: Optional[mx.array] = None,
+         video_grid_thw: Optional[mx.array] = None,
+         attention_mask: Optional[mx.array] = None,
+     ):
+         # Calculate RoPE index for image/video tokens
+         batch_size, seq_length = input_ids.shape
+         position_ids = mx.arange(seq_length, dtype=mx.int32)
+         position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+         spatial_merge_size = self.config.vision_config.spatial_merge_size
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+         mrope_position_deltas = []
+         if input_ids is not None and (
+             image_grid_thw is not None or video_grid_thw is not None
+         ):
+             total_input_ids = input_ids
+             if attention_mask is None:
+                 attention_mask = mx.ones_like(input_ids)
+             position_ids = mx.ones(
+                 (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+             )
+             image_index, video_index = 0, 0
+             for i, input_ids in enumerate(total_input_ids):
+                 input_ids = mx.where(
+                     attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                 )
+                 image_nums, video_nums = 0, 0
+                 vision_start_indices = mx.sum(
+                     mx.where(
+                         input_ids == vision_start_token_id,
+                         mx.arange(input_ids.shape[0]),
+                         mx.zeros_like(input_ids),
+                     )
+                 )
+                 vision_tokens = input_ids[vision_start_indices + 1]
+                 image_nums = (vision_tokens == image_token_id).sum().item()
+                 video_nums = (vision_tokens == video_token_id).sum().item()
+                 input_tokens = input_ids.tolist()
+                 llm_pos_ids_list: list = []
+                 st = 0
+                 remain_images, remain_videos = image_nums, video_nums
+                 for _ in range(image_nums + video_nums):
+                     if image_token_id in input_tokens and remain_images > 0:
+                         ed_image = input_tokens.index(image_token_id, st)
+                     else:
+                         ed_image = len(input_tokens) + 1
+                     if video_token_id in input_tokens and remain_videos > 0:
+                         ed_video = input_tokens.index(video_token_id, st)
+                     else:
+                         ed_video = len(input_tokens) + 1
+                     if ed_image < ed_video:
+                         t, h, w = (
+                             image_grid_thw[image_index][0],
+                             image_grid_thw[image_index][1],
+                             image_grid_thw[image_index][2],
+                         )
+                         image_index += 1
+                         remain_images -= 1
+                         ed = ed_image
+                     else:
+                         t, h, w = (
+                             video_grid_thw[video_index][0],
+                             video_grid_thw[video_index][1],
+                             video_grid_thw[video_index][2],
+                         )
+                         video_index += 1
+                         remain_videos -= 1
+                         ed = ed_video
+                     llm_grid_t, llm_grid_h, llm_grid_w = (
+                         t.item(),
+                         h.item() // spatial_merge_size,
+                         w.item() // spatial_merge_size,
+                     )
+                     text_len = ed - st
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     index = mx.arange(text_len).reshape(1, text_len)
+                     index = mx.broadcast_to(index, (3, text_len))
+                     index = index + st_idx
+                     llm_pos_ids_list.append(index)
+                     t_index = mx.arange(llm_grid_t).reshape(
+                         llm_grid_t, 1
+                     )  # Equivalent to .view(-1, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                     )  # Equivalent to expand()
+                     t_index = t_index.flatten()  # Flattens to 1D
+
+                     h_index = mx.arange(llm_grid_h).reshape(
+                         1, llm_grid_h, 1
+                     )  # Equivalent to .view(1, -1)
+                     h_index = mx.broadcast_to(
+                         h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )  # Equivalent to expand()
+                     h_index = h_index.flatten()  # Flattens to 1D
+
+                     w_index = mx.arange(llm_grid_w).reshape(
+                         1, 1, llm_grid_w
+                     )  # Equivalent to .view(1, -1)
+                     w_index = mx.broadcast_to(
+                         w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )  # Equivalent to expand()
+                     w_index = w_index.flatten()  # Flattens to 1D
+
+                     llm_pos_ids_list.append(
+                         mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                     )
+                     st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                 if st < len(input_tokens):
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     text_len = len(input_tokens) - st
+
+                     t_index = mx.arange(text_len).reshape(
+                         1, text_len
+                     )  # Equivalent to .view(-1, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (3, text_len)
+                     )  # Equivalent to expand(3, -1)
+
+                     llm_pos_ids_list.append(t_index + st_idx)
+
+                 llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                 mask = mx.array(attention_mask[i] == 1)
+                 expanded_mask = mx.expand_dims(mask, axis=0)
+                 expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                 expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                 new_positions = mx.where(
+                     expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                 )
+                 updated_position_ids = mx.concatenate(
+                     [
+                         position_ids[:, :i, :],
+                         new_positions,
+                         position_ids[:, i + 1 :, :],
+                     ],
+                     axis=1,
+                 )
+                 position_ids = updated_position_ids
+                 mrope_position_deltas.append(
+                     llm_positions.max() + 1 - len(total_input_ids[i])
+                 )
+             mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+             return position_ids, mrope_position_deltas
+         else:
+             if attention_mask is not None:
+                 position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                 position_ids = mx.where(
+                     attention_mask == 0, mx.ones_like(position_ids), position_ids
+                 )
+                 position_ids = mx.expand_dims(position_ids[0], axis=0)
+                 position_ids = mx.tile(position_ids, (3, 1, 1))
+                 max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                     -1, keepdims=True
+                 )[0]
+                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+             else:
+                 position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                 )
+                 mrope_position_deltas = mx.zeros(
+                     [input_ids.shape[0], 1],
+                     dtype=input_ids.dtype,
+                 )
+             return position_ids, mrope_position_deltas
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+
+         position_ids = kwargs.pop("position_ids", None)
+         pixel_values = kwargs.pop("pixel_values", None)
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+         video_grid_thw = kwargs.pop("video_grid_thw", None)
+         # reset rope_deltas and position_ids when processing a new image/video
+         if pixel_values is not None:
+             self._rope_deltas = None
+             self._position_ids = None
+
+         cache_offset = 0
+         if cache and cache[0] is not None:
+             offset = cache[0].offset
+             if isinstance(offset, int):
+                 cache_offset = offset
+             elif isinstance(offset, mx.array):
+                 cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+             else:
+                 raise ValueError(f"Unexpected cache offset type: {type(offset)}")
+
+         # Check if mask shape matches input shape (for chunked prefill compatibility)
+         rope_mask = mask
+         if mask is not None and mask.shape[-1] != inputs.shape[-1]:
+             rope_mask = None
+
+         if position_ids is None and (rope_mask is None or rope_mask.ndim == 2):
+             # Calculate RoPE index once per generation in the pre-fill stage only
+             if (
+                 (cache is not None and cache[0] is not None and (cache_offset == 0))
+                 or self._rope_deltas is None
+                 or cache is None
+             ):
+                 # Check if we have stored position_ids from chunked prefill
+                 if self._position_ids is not None:
+                     seq_length = inputs.shape[1]
+                     position_ids = self._position_ids[
+                         :, :, cache_offset : cache_offset + seq_length
+                     ]
+                 else:
+                     position_ids, rope_deltas = self.get_rope_index(
+                         inputs, image_grid_thw, video_grid_thw, rope_mask
+                     )
+                     self._rope_deltas = rope_deltas
+                     self._position_ids = position_ids
+             else:
+                 # Use the prev pre-calculated rope-deltas to get the correct position ids
+                 batch_size, seq_length = inputs.shape
+                 delta = mx.array(
+                     cache_offset + self._rope_deltas if cache is not None else 0
+                 )
+                 position_ids = mx.arange(seq_length).reshape(1, -1)
+                 position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+
+                 if cache_offset is not None:
+                     if delta.ndim == 0:
+                         delta = mx.expand_dims(delta, axis=0)
+
+                     if delta.shape[0] < batch_size:
+                         delta = mx.tile(delta, (batch_size, 1))
+                     else:
+                         # Slice delta to match batch
+                         delta = delta[:batch_size]
+
+                 position_ids = mx.add(position_ids, delta)[None, ...]
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, batch_size, seq_length)
+                 )
+
+         out = self.model(
+             inputs, cache=cache, inputs_embeds=inputs_embeds, position_ids=position_ids
+         )
+         if self.args.tie_word_embeddings:
+             out = self.model.embed_tokens.as_linear(out)
+         else:
+             out = self.lm_head(out)
+         return LanguageModelOutput(logits=out)
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.args.hidden_size // self.args.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.args.num_key_value_heads
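Taken together, get_rope_index produces a (3, batch, seq_len) tensor of temporal/height/width positions: text tokens advance all three axes in lockstep, an image's tokens advance them according to its (frames, rows, columns) grid after spatial merging, and the text that follows resumes one past the largest position used so far (mrope_position_deltas records that shift for later decode steps). The sketch below lays out those indices for one made-up prompt, three text tokens, a 1x2x2 image grid, then two more text tokens; it is an illustration of the indexing scheme, not the package's API:

import mlx.core as mx

def toy_mrope_positions(num_prefix_text, grid_t, grid_h, grid_w, num_suffix_text):
    """Build (3, seq_len) T/H/W position ids the way get_rope_index lays them out."""
    chunks = []
    # Text before the image: all three axes count up together.
    prefix = mx.broadcast_to(mx.arange(num_prefix_text)[None, :], (3, num_prefix_text))
    chunks.append(prefix)
    st_idx = num_prefix_text
    # Image tokens: T repeats per frame, H varies by row, W varies by column.
    t = mx.repeat(mx.arange(grid_t), grid_h * grid_w)
    h = mx.tile(mx.repeat(mx.arange(grid_h), grid_w), grid_t)
    w = mx.tile(mx.arange(grid_w), grid_t * grid_h)
    chunks.append(mx.stack([t, h, w]) + st_idx)
    # Text after the image resumes one past the largest position used so far.
    st_idx = int(chunks[-1].max().item()) + 1
    suffix = mx.broadcast_to(mx.arange(num_suffix_text)[None, :], (3, num_suffix_text))
    chunks.append(suffix + st_idx)
    return mx.concatenate(chunks, axis=1)

pos = toy_mrope_positions(num_prefix_text=3, grid_t=1, grid_h=2, grid_w=2, num_suffix_text=2)
print(pos)  # rows are the T, H, W axes; the 4 image tokens share T=3 but differ in H/W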