fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/qwen2_vl/language.py
@@ -0,0 +1,539 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import ModelConfig, TextConfig
+
+
+ class Qwen2RotaryEmbedding:
+     def __init__(
+         self, dim, max_position_embeddings=2048, base=10000, rope_scaling=None
+     ):
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+
+         inv_freq = 1.0 / (
+             self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+         )
+         self.inv_freq = inv_freq
+         self.attention_scaling = 1.0  # type: default
+
+         self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 20])
+
+     def apply_mrope(self, freqs, mrope_section):
+         """Apply MRoPE to 3D rotary embeddings.
+         Reorganizes frequency layout to chunked [TTT...HHH...WWW]
+         args:
+             freqs: (3, bs, seq_len, head_dim // 2)
+             mrope_section: (3,)
+         returns:
+             x_t: (bs, seq_len, head_dim // 2)
+         """
+         freqs_t = freqs[0]  # just overwrite the first dimension T
+         offset = mrope_section[0]
+         for dim, length in enumerate(mrope_section[1:], start=1):  # H, W
+             idx = slice(offset, offset + length)
+             freqs_t[..., idx] = freqs[dim, ..., idx]
+             offset += length
+         return freqs_t
+
+     def __call__(self, x, position_ids):
+         # In contrast to other models, Qwen3VL has different position ids for the grids
+         # So we expand the inv_freq to shape (3, ...)
+         if position_ids.ndim == 2:
+             position_ids = mx.broadcast_to(
+                 position_ids[None, ...],
+                 (3, position_ids.shape[0], position_ids.shape[1]),
+             )
+
+         inv_freq_expanded = mx.broadcast_to(
+             self.inv_freq[None, None, :, None].astype(mx.float32),
+             (3, position_ids.shape[1], self.inv_freq.shape[0], 1),
+         )
+         position_ids_expanded = position_ids[:, :, None, :].astype(
+             mx.float32
+         )  # shape (3, bs, 1, positions)
+
+         freqs = inv_freq_expanded @ position_ids_expanded
+         freqs = mx.swapaxes(freqs, 2, 3)
+         freqs = self.apply_mrope(freqs, self.mrope_section)
+         emb = mx.concatenate([freqs, freqs], axis=-1)
+         cos = mx.cos(emb) * self.attention_scaling
+         sin = mx.sin(emb) * self.attention_scaling
+
+         return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+     """
+     Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
+     Args:
+         q (mx.array): The query tensor.
+         k (mx.array): The key tensor.
+         cos (mx.array): The cosine part of the rotary embedding.
+         sin (mx.array): The sine part of the rotary embedding.
+         unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.
+     Returns:
+         tuple(mx.array): The rotated query and key tensors.
+     """
+
+     cos = mx.expand_dims(cos, axis=unsqueeze_dim)
+     sin = mx.expand_dims(sin, axis=unsqueeze_dim)
+
+     # Apply rotary embedding
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+
+     return q_embed, k_embed
+
+
+ class Attention(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+
+         dim = args.hidden_size
+         self.n_heads = n_heads = args.num_attention_heads
+         assert args.num_key_value_heads is not None
+         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+         self.head_dim = head_dim = args.hidden_size // n_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         self.rope_scaling = args.rope_scaling
+
+         self.rotary_emb = Qwen2RotaryEmbedding(
+             head_dim,
+             max_position_embeddings=args.max_position_embeddings,
+             base=args.rope_theta,
+             rope_scaling=args.rope_scaling,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+
+         kv_seq_len = keys.shape[-2]
+
+         if position_ids is None:
+             kv_seq_len += cache.offset + 1
+             position_ids = mx.arange(cache.offset, cache.offset + L)
+             position_ids = mx.expand_dims(position_ids, axis=0)
+             position_ids = mx.tile(position_ids, (3, 1, 1))
+         else:
+             kv_seq_len += cache.offset + 1 if cache is not None else 0
+
+         cos, sin = self.rotary_emb(values, position_ids)
+
+         if mask is not None and isinstance(mask, mx.array):
+             mask = mask[..., : keys.shape[-2]]
+         queries, keys = apply_multimodal_rotary_pos_emb(
+             queries, keys, cos, sin, unsqueeze_dim=1
+         )
+
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Qwen2VLDecoderLayer(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.num_attention_heads = args.num_attention_heads
+         self.hidden_size = args.hidden_size
+         self.self_attn = Attention(args)
+         self.mlp = MLP(args.hidden_size, args.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             args.hidden_size, eps=args.rms_norm_eps
+         )
+         self.args = args
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class Qwen2Model(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.args = args
+         self.vocab_size = args.vocab_size
+         self.num_hidden_layers = args.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+         self.layers = [
+             Qwen2VLDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         position_ids: Optional[mx.array] = None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache[0])
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c, position_ids)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, args: TextConfig, config: ModelConfig):
+         super().__init__()
+         self.args = args
+         self.config = config
+         self.model_type = args.model_type
+         self.model = Qwen2Model(args)
+         self._rope_deltas = None
+         self._position_ids = None
+
+         if not args.tie_word_embeddings:
+             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+     def get_rope_index(
+         self,
+         input_ids: mx.array,
+         image_grid_thw: Optional[mx.array] = None,
+         video_grid_thw: Optional[mx.array] = None,
+         attention_mask: Optional[mx.array] = None,
+     ):
+         # Calculate RoPE index for image/video tokens
+         batch_size, seq_length = input_ids.shape
+         position_ids = mx.arange(seq_length, dtype=mx.int32)
+         position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+         spatial_merge_size = self.config.vision_config.spatial_merge_size
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+         mrope_position_deltas = []
+         if input_ids is not None and (
+             image_grid_thw is not None or video_grid_thw is not None
+         ):
+             total_input_ids = input_ids
+             if attention_mask is None:
+                 attention_mask = mx.ones_like(input_ids)
+             position_ids = mx.ones(
+                 (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+             )
+             image_index, video_index = 0, 0
+             for i, input_ids in enumerate(total_input_ids):
+                 input_ids = mx.where(
+                     attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                 )
+                 image_nums, video_nums = 0, 0
+                 vision_start_indices = mx.sum(
+                     mx.where(
+                         input_ids == vision_start_token_id,
+                         mx.arange(input_ids.shape[0]),
+                         mx.zeros_like(input_ids),
+                     )
+                 )
+                 vision_tokens = input_ids[vision_start_indices + 1]
+                 image_nums = (vision_tokens == image_token_id).sum().item()
+                 video_nums = (vision_tokens == video_token_id).sum().item()
+                 input_tokens = input_ids.tolist()
+                 llm_pos_ids_list: list = []
+                 st = 0
+                 remain_images, remain_videos = image_nums, video_nums
+                 for _ in range(image_nums + video_nums):
+                     if image_token_id in input_tokens and remain_images > 0:
+                         ed_image = input_tokens.index(image_token_id, st)
+                     else:
+                         ed_image = len(input_tokens) + 1
+                     if video_token_id in input_tokens and remain_videos > 0:
+                         ed_video = input_tokens.index(video_token_id, st)
+                     else:
+                         ed_video = len(input_tokens) + 1
+                     if ed_image < ed_video:
+                         t, h, w = (
+                             image_grid_thw[image_index][0],
+                             image_grid_thw[image_index][1],
+                             image_grid_thw[image_index][2],
+                         )
+                         image_index += 1
+                         remain_images -= 1
+                         ed = ed_image
+                     else:
+                         t, h, w = (
+                             video_grid_thw[video_index][0],
+                             video_grid_thw[video_index][1],
+                             video_grid_thw[video_index][2],
+                         )
+                         video_index += 1
+                         remain_videos -= 1
+                         ed = ed_video
+                     llm_grid_t, llm_grid_h, llm_grid_w = (
+                         t.item(),
+                         h.item() // spatial_merge_size,
+                         w.item() // spatial_merge_size,
+                     )
+                     text_len = ed - st
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     index = mx.arange(text_len).reshape(1, text_len)
+                     index = mx.broadcast_to(index, (3, text_len))
+                     index = index + st_idx
+                     llm_pos_ids_list.append(index)
+                     t_index = mx.arange(llm_grid_t).reshape(
+                         llm_grid_t, 1
+                     )  # Equivalent to .view(-1, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                     )  # Equivalent to expand()
+                     t_index = t_index.flatten()  # Flattens to 1D
+
+                     h_index = mx.arange(llm_grid_h).reshape(
+                         1, llm_grid_h, 1
+                     )  # Equivalent to .view(1, -1)
+                     h_index = mx.broadcast_to(
+                         h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )  # Equivalent to expand()
+                     h_index = h_index.flatten()  # Flattens to 1D
+
+                     w_index = mx.arange(llm_grid_w).reshape(
+                         1, 1, llm_grid_w
+                     )  # Equivalent to .view(1, -1)
+                     w_index = mx.broadcast_to(
+                         w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )  # Equivalent to expand()
+                     w_index = w_index.flatten()  # Flattens to 1D
+
+                     llm_pos_ids_list.append(
+                         mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                     )
+                     st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                 if st < len(input_tokens):
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     text_len = len(input_tokens) - st
+
+                     t_index = mx.arange(text_len).reshape(
+                         1, text_len
+                     )  # Equivalent to .view(-1, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (3, text_len)
+                     )  # Equivalent to expand(3, -1)
+
+                     llm_pos_ids_list.append(t_index + st_idx)
+
+                 llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                 mask = mx.array(attention_mask[i] == 1)
+                 expanded_mask = mx.expand_dims(mask, axis=0)
+                 expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                 expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                 new_positions = mx.where(
+                     expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                 )
+                 updated_position_ids = mx.concatenate(
+                     [
+                         position_ids[:, :i, :],
+                         new_positions,
+                         position_ids[:, i + 1 :, :],
+                     ],
+                     axis=1,
+                 )
+                 position_ids = updated_position_ids
+                 mrope_position_deltas.append(
+                     llm_positions.max() + 1 - len(total_input_ids[i])
+                 )
+             mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+             return position_ids, mrope_position_deltas
+         else:
+             if attention_mask is not None:
+                 position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                 position_ids = mx.where(
+                     attention_mask == 0, mx.ones_like(position_ids), position_ids
+                 )
+                 position_ids = mx.expand_dims(position_ids[0], axis=0)
+                 position_ids = mx.tile(position_ids, (3, 1, 1))
+                 max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                     -1, keepdims=True
+                 )[0]
+                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+             else:
+                 position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                 )
+                 mrope_position_deltas = mx.zeros(
+                     [input_ids.shape[0], 1],
+                     dtype=input_ids.dtype,
+                 )
+             return position_ids, mrope_position_deltas
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         position_ids = kwargs.pop("position_ids", None)
+         pixel_values = kwargs.pop("pixel_values", None)
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+         video_grid_thw = kwargs.pop("video_grid_thw", None)
+         # reset rope_deltas and position_ids when processing a new image/video
+         if pixel_values is not None:
+             self._rope_deltas = None
+             self._position_ids = None
+
+         cache_offset = 0
+         if cache and cache[0] is not None:
+             offset = cache[0].offset
+             if isinstance(offset, int):
+                 cache_offset = offset
+             elif isinstance(offset, mx.array):
+                 cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+             else:
+                 raise ValueError(f"Unexpected cache offset type: {type(offset)}")
+
+         # Check if mask shape matches input shape (for chunked prefill compatibility)
+         rope_mask = mask
+         if mask is not None and mask.shape[-1] != inputs.shape[-1]:
+             rope_mask = None
+
+         if position_ids is None and (rope_mask is None or rope_mask.ndim == 2):
+             # Calculate RoPE index once per generation in the pre-fill stage only
+             if (
+                 (cache is not None and cache[0] is not None and (cache_offset == 0))
+                 or self._rope_deltas is None
+                 or cache is None
+             ):
+                 # Check if we have stored position_ids from chunked prefill
+                 if self._position_ids is not None:
+                     seq_length = inputs.shape[1]
+                     position_ids = self._position_ids[
+                         :, :, cache_offset : cache_offset + seq_length
+                     ]
+                 else:
+                     position_ids, rope_deltas = self.get_rope_index(
+                         inputs, image_grid_thw, video_grid_thw, rope_mask
+                     )
+                     self._rope_deltas = rope_deltas
+                     self._position_ids = position_ids
+             else:
+                 # Use the prev pre-calculated rope-deltas to get the correct position ids
+                 batch_size, seq_length = inputs.shape
+                 delta = mx.array(
+                     cache_offset + self._rope_deltas if cache is not None else 0
+                 )
+                 position_ids = mx.arange(seq_length).reshape(1, -1)
+                 position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+
+                 if cache_offset is not None:
+                     if delta.ndim == 0:
+                         delta = mx.expand_dims(delta, axis=0)
+
+                     if delta.shape[0] < batch_size:
+                         delta = mx.tile(delta, (batch_size, 1))
+                     else:
+                         # Slice delta to match batch
+                         delta = delta[:batch_size]
+
+                 position_ids = mx.add(position_ids, delta)[None, ...]
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, batch_size, seq_length)
+                 )
+
+         out = self.model(
+             inputs, cache=cache, inputs_embeds=inputs_embeds, position_ids=position_ids
+         )
+         if self.args.tie_word_embeddings:
+             out = self.model.embed_tokens.as_linear(out)
+         else:
+             out = self.lm_head(out)
+         return LanguageModelOutput(logits=out)
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.args.hidden_size // self.args.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.args.num_key_value_heads
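
As a quick orientation aid for the hunk above, the following standalone sketch (not part of the wheel) reproduces the chunked [TTT...HHH...WWW] frequency layout that Qwen2RotaryEmbedding.apply_mrope describes. The head dimension of 128, the [24, 20, 20] mrope_section, and the toy sequence length are illustrative assumptions, not values read from this package's configs.

import mlx.core as mx

# Illustrative assumptions: head_dim, mrope_section, and seq_len.
head_dim = 128
mrope_section = [24, 20, 20]  # T/H/W split of head_dim // 2 == 64
seq_len = 6

# Same inverse-frequency construction as Qwen2RotaryEmbedding.__init__.
inv_freq = 1.0 / (10000 ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))

# One set of position ids per axis (temporal, height, width); identical here for simplicity.
position_ids = mx.stack([mx.arange(seq_len)] * 3).astype(mx.float32)  # (3, seq_len)
freqs = position_ids[:, :, None] * inv_freq[None, None, :]            # (3, seq_len, 64)

# Chunked selection as in apply_mrope: start from the temporal copy, then
# overwrite the H and W sections with frequencies from their own axes.
freqs_t = freqs[0]
offset = mrope_section[0]
for axis, length in enumerate(mrope_section[1:], start=1):
    idx = slice(offset, offset + length)
    freqs_t[..., idx] = freqs[axis, ..., idx]
    offset += length

emb = mx.concatenate([freqs_t, freqs_t], axis=-1)  # (seq_len, head_dim)
cos, sin = mx.cos(emb), mx.sin(emb)
print(cos.shape, sin.shape)  # (6, 128) (6, 128)

Because all three axes carry the same position ids in this toy example, the result matches plain 1-D RoPE; the chunked layout only changes the output once get_rope_index assigns different temporal, height, and width indices to vision tokens.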