fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
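The listing shows a single top-level mlx_vlm package (per top_level.txt), with high-level entry points (generate.py, prompt_utils.py, chat.py, server.py) alongside per-architecture subpackages under mlx_vlm/models/. For orientation only, the sketch below shows how such a package is typically driven; it assumes this build preserves the upstream mlx-vlm API (load, generate, apply_chat_template, load_config), and the repo id and image path are placeholders, not taken from this diff.

# Hypothetical usage sketch; assumes the upstream mlx-vlm API is unchanged in this wheel.
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/Qwen3-VL-4B-Instruct-4bit"  # placeholder repo id
model, processor = load(model_path)
config = load_config(model_path)

images = ["example.jpg"]  # placeholder image
prompt = apply_chat_template(processor, config, "Describe this image.", num_images=len(images))
output = generate(model, processor, prompt, images, verbose=False)
print(output)

The diff below covers mlx_vlm/models/qwen3_vl/language.py (entry 236 above, +596 lines).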
mlx_vlm/models/qwen3_vl/language.py
@@ -0,0 +1,596 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from ..cache import KVCache
+from .config import ModelConfig, TextConfig
+
+
+class Qwen3VLRotaryEmbedding:
+    def __init__(
+        self, dim, max_position_embeddings=2048, base=10000, rope_scaling=None
+    ):
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+
+        inv_freq = 1.0 / (
+            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+        )
+        self.inv_freq = inv_freq
+
+        self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 20])
+
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THTHWHTHW...TT], preserving frequency continuity.
+        args:
+            freqs: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            freqs_t: (bs, seq_len, head_dim // 2)
+        """
+        freqs_t = freqs[0]  # just overwrite the first dimension T
+        for dim, offset in enumerate((1, 2), start=1):  # H, W
+            length = mrope_section[dim] * 3
+            idx = slice(offset, length, 3)
+            freqs_t[..., idx] = freqs[dim, ..., idx]
+        return freqs_t
+
+    def __call__(self, x, position_ids):
+
+        # In contrast to other models, Qwen3VL has different position ids for the grids
+        # So we expand the inv_freq to shape (3, ...)
+        if position_ids.ndim == 2:
+            position_ids = mx.broadcast_to(
+                position_ids[None, ...],
+                (3, position_ids.shape[0], position_ids.shape[1]),
+            )
+
+        inv_freq_expanded = mx.broadcast_to(
+            self.inv_freq[None, None, :, None].astype(mx.float32),
+            (3, position_ids.shape[1], self.inv_freq.shape[0], 1),
+        )
+        position_ids_expanded = position_ids[:, :, None, :].astype(
+            mx.float32
+        )  # shape (3, bs, 1, positions)
+
+        freqs = inv_freq_expanded @ position_ids_expanded
+        freqs = mx.swapaxes(freqs, 2, 3)
+        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+        emb = mx.concatenate([freqs, freqs], axis=-1)
+        cos = mx.cos(emb)
+        sin = mx.sin(emb)
+
+        return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    """
+    Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
+    Args:
+        q (mx.array): The query tensor.
+        k (mx.array): The key tensor.
+        cos (mx.array): The cosine part of the rotary embedding.
+        sin (mx.array): The sine part of the rotary embedding.
+        unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.
+    Returns:
+        tuple(mx.array): The rotated query and key tensors.
+    """
+
+    cos = mx.expand_dims(cos, axis=unsqueeze_dim)
+    sin = mx.expand_dims(sin, axis=unsqueeze_dim)
+
+    # Apply rotary embedding
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    return q_embed, k_embed
+
+
+class Attention(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        self.head_dim = head_dim = getattr(
+            args, "head_dim", args.hidden_size // args.num_attention_heads
+        )
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+        self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+
+        self.rope_scaling = args.rope_scaling
+
+        self.rotary_emb = Qwen3VLRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=args.max_position_embeddings,
+            base=args.rope_theta,
+            rope_scaling=self.rope_scaling,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = self.q_norm(
+            queries.reshape(B, L, self.n_heads, self.head_dim)
+        ).transpose(0, 2, 1, 3)
+        keys = self.k_norm(
+            keys.reshape(B, L, self.n_kv_heads, self.head_dim)
+        ).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+            0, 2, 1, 3
+        )
+
+        kv_seq_len = keys.shape[-2]
+
+        if position_ids is None:
+            kv_seq_len += cache.offset + 1
+            position_ids = mx.arange(cache.offset, cache.offset + L)
+            position_ids = mx.expand_dims(position_ids, axis=0)
+            position_ids = mx.tile(position_ids, (3, 1, 1))
+        else:
+            kv_seq_len += cache.offset + 1 if cache is not None else 0
+
+        cos, sin = self.rotary_emb(values, position_ids)
+
+        if mask is not None and isinstance(mask, mx.array):
+            mask = mask[..., :kv_seq_len]
+
+        queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin)
+
+        if cache is not None:
+            keys, values = cache.update_and_fetch(keys, values)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Qwen3VLDecoderLayer(nn.Module):
+    def __init__(self, args: TextConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.args = args
+        self.mlp = MLP(args.hidden_size, args.intermediate_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class Qwen3VLModel(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            Qwen3VLDecoderLayer(args=args, layer_idx=layer_idx)
+            for layer_idx in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        position_ids: Optional[mx.array] = None,
+        # args for deepstack
+        visual_pos_masks: Optional[mx.array] = None,
+        deepstack_visual_embeds: Optional[mx.array] = None,
+    ):
+        if inputs_embeds is None:
+            h = self.embed_tokens(inputs)
+        else:
+            h = inputs_embeds
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+        for layer_idx, (layer, c) in enumerate(zip(self.layers, cache)):
+            h = layer(h, mask, c, position_ids)
+            # Add deepstack visual embeds
+            # add visual features to the hidden states of first several layers
+            if deepstack_visual_embeds is not None and layer_idx in range(
+                len(deepstack_visual_embeds)
+            ):
+                h = self._deepstack_process(
+                    h,
+                    visual_pos_masks,
+                    deepstack_visual_embeds[layer_idx],
+                )
+
+        return self.norm(h)
+
+    def _deepstack_process(
+        self,
+        hidden_states: mx.array,
+        visual_pos_masks: mx.array,
+        visual_embeds: mx.array,
+    ):
+        batch_size = hidden_states.shape[0]
+
+        updated_batches = []
+        for b in range(batch_size):
+            batch_mask = visual_pos_masks[b]
+            batch_hidden = hidden_states[b]
+
+            batch_indices = mx.array(np.where(batch_mask)[0], dtype=mx.uint32)
+
+            if len(batch_indices) == 0:
+                updated_batches.append(batch_hidden)
+                continue
+
+            batch_result = mx.array(batch_hidden)  # avoid modifying in-place
+            batch_result = batch_result.at[batch_indices].add(visual_embeds)
+
+            updated_batches.append(batch_result)
+
+        return mx.stack(updated_batches, axis=0)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, args: TextConfig, config: ModelConfig = None):
+        super().__init__()
+        self.args = args
+        self.config = config
+        self.model_type = args.model_type
+        self.model = Qwen3VLModel(args)
+        self._rope_deltas = None
+        self._position_ids = None
+
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def get_rope_index(
+        self,
+        input_ids: mx.array,
+        image_grid_thw: Optional[mx.array] = None,
+        video_grid_thw: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+    ):
+        # Calculate RoPE index for image/video tokens
+        batch_size, seq_length = input_ids.shape
+        position_ids = mx.arange(seq_length, dtype=mx.int32)
+        position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+        spatial_merge_size = self.config.vision_config.spatial_merge_size
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = mx.ones_like(input_ids)
+            position_ids = mx.ones(
+                (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+            )
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                input_ids = mx.where(
+                    attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                )
+                image_nums, video_nums = 0, 0
+                vision_start_indices = mx.sum(
+                    mx.where(
+                        input_ids == vision_start_token_id,
+                        mx.arange(input_ids.shape[0]),
+                        mx.zeros_like(input_ids),
+                    )
+                )
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum().item()
+                video_nums = (vision_tokens == video_token_id).sum().item()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    index = mx.arange(text_len).reshape(1, text_len)
+                    index = mx.broadcast_to(index, (3, text_len))
+                    index = index + st_idx
+                    llm_pos_ids_list.append(index)
+                    t_index = mx.arange(llm_grid_t).reshape(
+                        llm_grid_t, 1
+                    )  # Equivalent to .view(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                    )  # Equivalent to expand()
+                    t_index = t_index.flatten()  # Flattens to 1D
+
+                    h_index = mx.arange(llm_grid_h).reshape(
+                        1, llm_grid_h, 1
+                    )  # Equivalent to .view(1, -1)
+                    h_index = mx.broadcast_to(
+                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    )  # Equivalent to expand()
+                    h_index = h_index.flatten()  # Flattens to 1D
+
+                    w_index = mx.arange(llm_grid_w).reshape(
+                        1, 1, llm_grid_w
+                    )  # Equivalent to .view(1, -1)
+                    w_index = mx.broadcast_to(
+                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    )  # Equivalent to expand()
+                    w_index = w_index.flatten()  # Flattens to 1D
+
+                    llm_pos_ids_list.append(
+                        mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+
+                    t_index = mx.arange(text_len).reshape(
+                        1, text_len
+                    )  # Equivalent to .view(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (3, text_len)
+                    )  # Equivalent to expand(3, -1)
+
+                    llm_pos_ids_list.append(t_index + st_idx)
+
+                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                mask = mx.array(attention_mask[i] == 1)
+                expanded_mask = mx.expand_dims(mask, axis=0)
+                expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                new_positions = mx.where(
+                    expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                )
+                updated_position_ids = mx.concatenate(
+                    [
+                        position_ids[:, :i, :],
+                        new_positions,
+                        position_ids[:, i + 1 :, :],
+                    ],
+                    axis=1,
+                )
+                position_ids = updated_position_ids
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                position_ids = mx.where(
+                    attention_mask == 0, mx.ones_like(position_ids), position_ids
+                )
+                position_ids = mx.expand_dims(position_ids[0], axis=0)
+                position_ids = mx.tile(position_ids, (3, 1, 1))
+                max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                    -1, keepdims=True
+                )[0]
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            else:
+                position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                )
+                mrope_position_deltas = mx.zeros(
+                    [input_ids.shape[0], 1],
+                    dtype=input_ids.dtype,
+                )
+            return position_ids, mrope_position_deltas
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        # args for deepstack
+        visual_pos_masks: Optional[mx.array] = None,
+        deepstack_visual_embeds: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        # Slicing visual_pos_masks when prefilling
+        n_to_process = kwargs.get("n_to_process", None)
+        if n_to_process is not None:
+            visual_pos_masks = visual_pos_masks[:, n_to_process:]
+
+        position_ids = kwargs.pop("position_ids", None)
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_grid_thw = kwargs.pop("image_grid_thw", None)
+        video_grid_thw = kwargs.pop("video_grid_thw", None)
+        # reset rope_deltas when processing a new image/video
+        if pixel_values is not None:
+            self._rope_deltas = None
+
+        cache_offset = 0
+        if cache and cache[0] is not None:
+            offset = cache[0].offset
+            if isinstance(offset, int):
+                cache_offset = offset
+            elif isinstance(offset, mx.array):
+                cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+            else:
+                raise ValueError(f"Unexpected cache offset type: {type(offset)}")
+
+        # Check if mask shape matches input shape (for chunked prefill compatibility)
+        rope_mask = mask
+        if mask is not None and mask.shape[-1] != inputs.shape[-1]:
+            rope_mask = None
+
+        if position_ids is None and (rope_mask is None or rope_mask.ndim == 2):
+            # Calculate RoPE index once per generation in the pre-fill stage only
+            if (
+                (cache is not None and cache[0] is not None and (cache_offset == 0))
+                or self._rope_deltas is None
+                or cache is None
+            ):
+                if self._position_ids is not None:
+                    seq_length = inputs.shape[1]
+                    position_ids = self._position_ids[
+                        :, :, cache_offset : cache_offset + seq_length
+                    ]
+                else:
+                    position_ids, rope_deltas = self.get_rope_index(
+                        inputs, image_grid_thw, video_grid_thw, rope_mask
+                    )
+                    self._rope_deltas = rope_deltas
+                    self._position_ids = position_ids
+            else:
+                # Use the prev pre-calculated rope-deltas to get the correct position ids
+                batch_size, seq_length = inputs.shape
+                delta = mx.array(
+                    cache_offset + self._rope_deltas if cache is not None else 0
+                )
+                position_ids = mx.arange(seq_length).reshape(1, -1)
+                position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+
+                if cache_offset is not None:
+                    if delta.ndim == 0:
+                        delta = mx.expand_dims(delta, axis=0)
+
+                    if delta.shape[0] < batch_size:
+                        delta = mx.tile(delta, (batch_size, 1))
+                    else:
+                        # Slice delta to match batch
+                        delta = delta[:batch_size]
+
+                position_ids = mx.add(position_ids, delta)[None, ...]
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, batch_size, seq_length)
+                )
+
+        out = self.model(
+            inputs,
+            cache=cache,
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            visual_pos_masks=visual_pos_masks,
+            deepstack_visual_embeds=deepstack_visual_embeds,
+        )
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return LanguageModelOutput(logits=out)
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
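As an editorial aside (not part of the wheel), the toy check below replays the slicing in Qwen3VLRotaryEmbedding.apply_interleaved_mrope on a tiny tensor, showing which positions of the last axis end up carrying temporal (T), height (H), and width (W) frequencies. NumPy is used for clarity; the mrope_section value and shapes are made up for the example.

import numpy as np

def apply_interleaved_mrope(freqs, mrope_section):
    # Same slicing as the method above: start from the T plane, then overwrite
    # every third position (offset 1 for H, offset 2 for W) within the first
    # mrope_section[dim] * 3 positions of the last axis.
    freqs_t = freqs[0].copy()  # .copy(): NumPy indexing returns a view, unlike MLX
    for dim, offset in enumerate((1, 2), start=1):  # H, W
        length = mrope_section[dim] * 3
        idx = slice(offset, length, 3)
        freqs_t[..., idx] = freqs[dim, ..., idx]
    return freqs_t

# Toy example: batch=1, seq_len=1, head_dim // 2 = 4, mrope_section = [2, 1, 1].
# Each frequency plane is filled with a marker (0 = T, 1 = H, 2 = W) so the
# output shows which plane each position was taken from.
freqs = np.stack([np.full((1, 1, 4), v, dtype=np.float32) for v in (0, 1, 2)])
print(apply_interleaved_mrope(freqs, [2, 1, 1])[0, 0])  # [0. 1. 2. 0.] -> T, H, W, T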